# DeepSeek-OCR with Tversky Neural Networks - Sinhala OCR

This notebook trains DeepSeek-OCR enhanced with Tversky Projection layers for Sinhala OCR.

**Dataset:**
- Images: `/kaggle/input/sinhala-printed-text-dataset-400/images` (400 images)
- Annotations: `/kaggle/input/sinhala-printed-text-dataset-400/annotations.csv`

**Requirements:**
- Kaggle GPU: T4 x2 or P100
- RAM: 16GB+

**Steps:**
1. Environment Setup
2. Load Tversky module
3. Load base model
4. Apply Tversky conversion
5. Prepare Sinhala dataset
6. Train with mixed precision
7. Evaluate and save

## 1. Environment Setup

In [None]:
# Check GPU availability
!nvidia-smi

import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"Number of GPUs: {torch.cuda.device_count()}")

for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    print(f"  Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.1f} GB")

In [None]:
# Install required packages
!pip install -q transformers>=4.37.0 accelerate>=0.25.0 bitsandbytes>=0.41.0
!pip install -q datasets pillow tqdm
!pip install -q ninja packaging pandas

## 2. Tversky Module Setup

- **Option A:** clone the repository from GitHub.
- **Option B:** upload the `tversky` folder as a Kaggle dataset.

In [None]:
# Option A: Clone from GitHub (replace with your repo)
# !git clone https://github.com/YOUR_USERNAME/DeepSeek-OCR.git /kaggle/working/DeepSeek-OCR

# Option B: If uploaded as Kaggle dataset named 'tversky-ocr-code'
# !cp -r /kaggle/input/tversky-ocr-code/tversky /kaggle/working/

import sys
import os

# Candidate directories that may contain the `tversky` package, in order of
# preference: when several exist, earlier entries take precedence on import.
TVERSKY_PATHS = [
    '/kaggle/working/DeepSeek-OCR/DeepSeek-OCR-master/DeepSeek-OCR-vllm',
    '/kaggle/working',
    '/kaggle/input/tversky-ocr-code'
]

# Prepend all existing candidates as one slice so they keep the list's order at
# the front of sys.path (inserting each at index 0 would reverse the priority).
# Entries already on sys.path are skipped so re-running this cell is idempotent.
new_paths = [p for p in TVERSKY_PATHS if os.path.exists(p) and p not in sys.path]
sys.path[:0] = new_paths
for path in new_paths:
    print(f"Added to path: {path}")

In [None]:
# Verify Tversky import: fail loudly (but without stopping the kernel) if the
# package is not on sys.path yet.
try:
    from tversky import (
        SINHALA_OCR_TVERSKY_CONFIG,
        TverskyLMHead,
        TverskyProjection,
        TverskyTrainingConfig,
        analyze_tversky_parameters,
        create_tversky_optimizer,
        get_tversky_regularization_loss,
        monitor_tversky_health,
    )
except ImportError as import_err:
    print(f"Import error: {import_err}")
    print("Please upload the tversky folder or clone the repo")
else:
    print("Tversky module imported successfully!")

## 3. Dataset Paths & Configuration

In [None]:
# ============================================
# YOUR KAGGLE DATASET PATHS
# ============================================
KAGGLE_IMAGES_DIR = "/kaggle/input/sinhala-printed-text-dataset-400/images"
KAGGLE_ANNOTATIONS_CSV = "/kaggle/input/sinhala-printed-text-dataset-400/annotations.csv"

# Sanity-check that both inputs are mounted before anything tries to read them.
import os

images_dir_found = os.path.exists(KAGGLE_IMAGES_DIR)
print(f"Images directory exists: {images_dir_found}")
print(f"Annotations file exists: {os.path.exists(KAGGLE_ANNOTATIONS_CSV)}")

# Peek at the image directory: file count and the first few names.
if images_dir_found:
    images = os.listdir(KAGGLE_IMAGES_DIR)
    print(f"Number of images: {len(images)}")
    print(f"Sample images: {images[:5]}")

In [None]:
# Inspect the CSV file: dimensions, column names and a preview of the first rows.
import pandas as pd

df = pd.read_csv(KAGGLE_ANNOTATIONS_CSV)
print("CSV shape: " + str(df.shape))
print("\nColumns: " + str(df.columns.tolist()))
print("\nFirst 10 rows:")
df.head(10)

In [None]:
from dataclasses import dataclass
from typing import Optional
import os

@dataclass
class KaggleTrainingConfig:
    """Hyper-parameters and paths for fine-tuning on a Kaggle T4 / P100 GPU.

    Every field has a default, so ``KaggleTrainingConfig()`` is a complete
    configuration; override individual fields by keyword.
    """

    # ---- Data locations ---------------------------------------------------
    images_dir: str = KAGGLE_IMAGES_DIR
    annotations_csv: str = KAGGLE_ANNOTATIONS_CSV

    # ---- CSV schema (SinhalaOCRDataset tries common alternatives if absent)
    image_col: str = 'image'   # column holding image file names
    text_col: str = 'text'     # column holding the ground-truth transcription

    # ---- Base model ---------------------------------------------------------
    model_name: str = "deepseek-ai/deepseek-vl-1.3b-chat"
    use_4bit: bool = True
    use_flash_attention: bool = False  # NOTE(review): not read by load_model in this notebook

    # ---- Tversky head -------------------------------------------------------
    num_features: int = 512
    conversion_strategy: str = 'lm_head_only'  # NOTE(review): convert_to_tversky does not read this
    feature_activation: str = 'softplus'
    use_smooth_min: bool = True
    smooth_min_temperature: float = 0.5
    init_alpha: float = 0.3
    init_beta: float = 0.7
    init_gamma: float = 15.0

    # ---- Optimisation (sized for a ~400-sample dataset) -----------------------
    batch_size: int = 2                    # per-step micro-batch; small for a 16 GB T4
    gradient_accumulation_steps: int = 4   # effective batch = batch_size * this = 8
    learning_rate: float = 5e-5
    tversky_lr_multiplier: float = 0.05
    num_epochs: int = 20                   # many epochs because the dataset is tiny
    warmup_ratio: float = 0.1
    max_seq_length: int = 256              # tune to the longest transcription

    # ---- Numeric precision ------------------------------------------------------
    fp16: bool = True
    bf16: bool = False

    # ---- Regularisation ---------------------------------------------------------
    diversity_weight: float = 0.02
    sparsity_weight: float = 0.001
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0

    # ---- Train/validation split -------------------------------------------------
    val_split: float = 0.1  # fraction held out (~40 of 400 images)

    # ---- Outputs ----------------------------------------------------------------
    output_dir: str = '/kaggle/working/outputs'

    # ---- Logging cadence (NOTE(review): not referenced by the Trainer below) -----
    logging_steps: int = 10
    eval_steps: int = 50
    save_steps: int = 100

config = KaggleTrainingConfig()

# Create the destination for checkpoints and plots up front.
os.makedirs(config.output_dir, exist_ok=True)

print("Training Configuration:")
print("\n".join(f"  {field_name}: {field_value}" for field_name, field_value in vars(config).items()))

## 4. Dataset Class for CSV Annotations

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import pandas as pd
from pathlib import Path
import torchvision.transforms as T

class SinhalaOCRDataset(Dataset):
    """
    Dataset for Sinhala OCR with CSV annotations.

    Each CSV row names an image file (relative to ``images_dir``) and its
    ground-truth transcription. Rows whose image file cannot be found, or whose
    text cell is missing (NaN), are skipped and reported as warnings.

    Each item is a dict with:
      - ``pixel_values``: the RGB image resized to ``image_size`` and normalized
        with ImageNet mean/std.
      - ``input_ids`` / ``attention_mask``: the text tokenized and padded or
        truncated to ``max_length``.
      - ``labels``: a copy of ``input_ids`` whose padding positions
        (``attention_mask == 0``) are set to ``IGNORE_INDEX`` (-100), so the
        language-model loss and token-accuracy metrics ignore padding.
    """

    # Fallback column names tried when the requested column is not in the CSV.
    IMAGE_COL_CANDIDATES = ('image', 'image_name', 'filename', 'file', 'img', 'Image', 'path')
    TEXT_COL_CANDIDATES = ('text', 'label', 'ground_truth', 'gt', 'Text', 'annotation', 'transcription')
    # Suffixes tried, in order, when the CSV stores file names without an extension.
    IMAGE_EXTENSIONS = ('', '.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG')
    # Label value skipped by PyTorch / Hugging Face cross-entropy.
    IGNORE_INDEX = -100

    def __init__(
        self,
        images_dir: str,
        annotations_csv: str,
        tokenizer,
        image_col: str = 'image',
        text_col: str = 'text',
        max_length: int = 256,
        image_size: tuple = (384, 384)
    ):
        """
        Args:
            images_dir: Directory containing the image files.
            annotations_csv: CSV file with one row per image.
            tokenizer: Tokenizer called as ``tokenizer(text, max_length=...,
                padding='max_length', truncation=True, return_tensors='pt')``.
            image_col: Preferred image-name column. If absent, the
                IMAGE_COL_CANDIDATES are tried, then the first column.
            text_col: Preferred transcription column. If absent, the
                TEXT_COL_CANDIDATES are tried, then the second column.
            max_length: Token length every sample is padded/truncated to.
            image_size: Target size passed to ``T.Resize``.

        Raises:
            ValueError: if a column must fall back to a position the CSV does not have.
        """
        self.images_dir = Path(images_dir)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.image_size = image_size

        df = pd.read_csv(annotations_csv)
        image_col = self._resolve_column(df, image_col, self.IMAGE_COL_CANDIDATES, fallback_index=0)
        text_col = self._resolve_column(df, text_col, self.TEXT_COL_CANDIDATES, fallback_index=1)
        print(f"Using columns: image='{image_col}', text='{text_col}'")

        # Iterate the two columns directly: unlike iterrows(), this keeps each
        # column's own dtype (iterrows upcasts e.g. int ids to float -> "1.0").
        self.samples = []
        missing_count = 0
        empty_text_count = 0
        for img_name, text in zip(df[image_col], df[text_col]):
            if pd.isna(text):
                # str(NaN) would otherwise become the training target "nan".
                empty_text_count += 1
                continue
            img_path = self._find_image(str(img_name))
            if img_path is None:
                missing_count += 1
                continue
            self.samples.append({'image_path': img_path, 'text': str(text)})

        print(f"Loaded {len(self.samples)} samples")
        if missing_count > 0:
            print(f"Warning: {missing_count} images not found")
        if empty_text_count > 0:
            print(f"Warning: {empty_text_count} rows skipped (missing text)")

        # Image transforms
        self.transform = T.Compose([
            T.Resize(image_size),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _resolve_column(df, preferred, candidates, fallback_index):
        """Return `preferred` if present, else the first present candidate, else the column at `fallback_index`."""
        if preferred in df.columns:
            return preferred
        for col in candidates:
            if col in df.columns:
                return col
        if fallback_index >= len(df.columns):
            raise ValueError(
                f"Column '{preferred}' not found and the CSV has only "
                f"{len(df.columns)} column(s): {df.columns.tolist()}"
            )
        return df.columns[fallback_index]

    def _find_image(self, img_name):
        """Return the path of `img_name` in images_dir (trying IMAGE_EXTENSIONS), or None."""
        for ext in self.IMAGE_EXTENSIONS:
            candidate = self.images_dir / f"{img_name}{ext}"
            if candidate.exists():
                return candidate
        return None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Use a context manager so the file handle is released promptly (each
        # DataLoader worker opens every image once per epoch).
        with Image.open(sample['image_path']) as image:
            pixel_values = self.transform(image.convert('RGB'))

        encoding = self.tokenizer(
            sample['text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # Exclude padding from the loss. The pad token may be the EOS token, so
        # without this the (mostly padding) 256-token targets would dominate training.
        labels = input_ids.clone()
        labels[attention_mask == 0] = self.IGNORE_INDEX

        return {
            'pixel_values': pixel_values,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

## 5. Load Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig

def load_model(config):
    """Load the tokenizer and the (optionally 4-bit quantized) causal LM.

    Args:
        config: KaggleTrainingConfig-like object. Reads ``model_name``,
            ``use_4bit`` and, if present, ``bf16``.

    Returns:
        ``(model, tokenizer)``. If the tokenizer has no pad token, its EOS
        token is used. Gradient checkpointing is enabled on the model.

    Precision: the non-quantized weights are loaded in float32 on purpose. The
    training loop uses ``torch.cuda.amp`` autocast together with ``GradScaler``,
    and ``GradScaler.unscale_`` raises "Attempting to unscale FP16 gradients."
    if any trainable parameter is itself fp16. Autocast still runs the matmuls
    in half precision, and 4-bit layers compute in ``bnb_4bit_compute_dtype``.
    """
    # Half-precision dtype for the 4-bit matmuls (bf16 only if explicitly requested).
    compute_dtype = torch.bfloat16 if getattr(config, 'bf16', False) else torch.float16

    if config.use_4bit:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
        )
    else:
        bnb_config = None

    tokenizer = AutoTokenizer.from_pretrained(
        config.model_name,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        # fp32 master weights; mixed precision comes from autocast (see docstring).
        torch_dtype=torch.float32,
    )

    model.gradient_checkpointing_enable()

    return model, tokenizer

print("Loading model...")
model, tokenizer = load_model(config)
print("Model loaded")
print(f"  Vocab size: {model.config.vocab_size}")
print(f"  Hidden size: {model.config.hidden_size}")

## 6. Apply Tversky Conversion

In [None]:
import torch.nn as nn
from tversky import TverskyLMHead

def convert_to_tversky(model, config):
    """Replace the model's LM head with a ``TverskyLMHead`` (in place).

    The head is the first of the attributes ``lm_head`` / ``output`` / ``cls``
    found on ``model``. If it is an ``nn.Linear``, its own ``in_features`` /
    ``out_features`` size the Tversky head, so the new logits have exactly the
    old head's width, and it is passed as ``init_from_linear``. Otherwise
    ``model.config.hidden_size`` / ``vocab_size`` are used. The new head is
    placed on the old head's device and dtype.

    Args:
        model: Model whose LM head is replaced.
        config: Reads ``num_features``, ``feature_activation``, ``use_smooth_min``,
            ``smooth_min_temperature``, ``init_alpha``, ``init_beta``, ``init_gamma``.

    Returns:
        The same ``model`` object. If no LM head is found, the model is
        returned unchanged after printing its top-level children.
    """
    # Find LM head: first matching attribute name wins.
    lm_head_attr = None
    old_lm_head = None

    for name in ['lm_head', 'output', 'cls']:
        if hasattr(model, name):
            old_lm_head = getattr(model, name)
            lm_head_attr = name
            break

    if old_lm_head is None:
        print("Could not find LM head. Model structure:")
        for name, module in model.named_children():
            print(f"  {name}: {type(module).__name__}")
        return model

    if isinstance(old_lm_head, nn.Linear):
        # Size from the actual weight matrix. config.vocab_size can differ from
        # the weight's row count in some models, and the replacement must
        # produce logits of the same width as the layer it replaces.
        hidden_size = old_lm_head.in_features
        vocab_size = old_lm_head.out_features
        init_linear = old_lm_head
        config_vocab = getattr(model.config, 'vocab_size', vocab_size)
        if config_vocab != vocab_size:
            print(f"Note: lm_head has {vocab_size} outputs but config.vocab_size is "
                  f"{config_vocab}; using {vocab_size}.")
    else:
        hidden_size = model.config.hidden_size
        vocab_size = model.config.vocab_size
        init_linear = None

    old_params = sum(p.numel() for p in old_lm_head.parameters())

    # Create Tversky head
    new_lm_head = TverskyLMHead(
        hidden_size=hidden_size,
        vocab_size=vocab_size,
        num_features=config.num_features,
        init_from_linear=init_linear,
        feature_activation=config.feature_activation,
        use_smooth_min=config.use_smooth_min,
        smooth_min_temperature=config.smooth_min_temperature,
        init_alpha=config.init_alpha,
        init_beta=config.init_beta,
        init_gamma=config.init_gamma
    )

    # Put the new head where the old one lived (same device and dtype).
    ref_param = next(old_lm_head.parameters(), None)
    if ref_param is not None:
        new_lm_head = new_lm_head.to(device=ref_param.device, dtype=ref_param.dtype)
    else:
        # Parameter-free old head: keep the new head's own dtype (the model's
        # first parameter may be a quantized storage dtype) and only co-locate it.
        new_lm_head = new_lm_head.to(device=next(model.parameters()).device)

    setattr(model, lm_head_attr, new_lm_head)

    new_params = sum(p.numel() for p in new_lm_head.parameters())

    print(f"\nTversky conversion complete:")
    print(f"  Original params: {old_params:,}")
    print(f"  Tversky params: {new_params:,}")
    if old_params > 0:
        print(f"  Reduction: {(1 - new_params/old_params)*100:.1f}%")

    return model

model = convert_to_tversky(model, config)

## 7. Create Data Loaders

In [None]:
# Create dataset
full_dataset = SinhalaOCRDataset(
    images_dir=config.images_dir,
    annotations_csv=config.annotations_csv,
    tokenizer=tokenizer,
    image_col=config.image_col,
    text_col=config.text_col,
    max_length=config.max_seq_length
)

# Fail fast with an actionable message instead of an obscure error later
# (empty loaders, zero-length LR schedule, division by zero in metrics).
if len(full_dataset) == 0:
    raise RuntimeError(
        "No samples loaded: check config.images_dir, config.annotations_csv "
        "and the image/text column names."
    )

# Split into train/val with a fixed seed so the split is reproducible.
SPLIT_SEED = 42
val_size = int(len(full_dataset) * config.val_split)
# int() truncation can produce an empty validation set on small datasets,
# and the per-epoch validation average would then divide by zero.
if val_size == 0 and config.val_split > 0 and len(full_dataset) > 1:
    val_size = 1
train_size = len(full_dataset) - val_size

train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED)
)

print(f"\nDataset split:")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val: {len(val_dataset)} samples")

# Pinned host memory only helps (and only avoids a warning) when CUDA is present.
use_pinned_memory = torch.cuda.is_available()

# Create loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=use_pinned_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=use_pinned_memory
)

print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

In [None]:
# Pull a single training batch to sanity-check tensor shapes.
sample_batch = next(iter(train_loader))
print("Sample batch shapes:")
for key, tensor in sample_batch.items():
    print(f"  {key}: {tensor.shape}")

# Round-trip the first sequence through the tokenizer (special tokens dropped).
first_ids = sample_batch['input_ids'][0]
sample_text = tokenizer.decode(first_ids, skip_special_tokens=True)
print(f"\nSample text: {sample_text[:100]}...")

## 8. Training Loop

In [None]:
from tversky import create_tversky_optimizer, get_tversky_regularization_loss, monitor_tversky_health, analyze_tversky_parameters
from torch.cuda.amp import GradScaler, autocast
from tqdm.auto import tqdm
import time

class Trainer:
    """Minimal training loop for the Tversky-converted causal LM.

    Each epoch runs the following steps:
      - One pass over ``train_loader`` with gradient accumulation, optional
        fp16 autocast + GradScaler, gradient clipping, and a per-optimizer-step
        LR schedule (linear warmup over ``warmup_ratio`` of the steps, then
        cosine decay to 0).
      - Validation loss and next-token accuracy on ``val_loader``.
      - Tversky parameter analysis and health warnings.
      - A checkpoint of the Tversky parameters whenever validation loss improves.

    NOTE(review): only ``input_ids`` / ``attention_mask`` / ``labels`` are fed
    to the model. The batches' ``pixel_values`` are not used, so this trains
    text-only next-token prediction rather than image-conditioned OCR.
    """

    def __init__(self, model, config, train_loader, val_loader):
        import math  # local import: only the LR schedule below needs it

        self.model = model
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        # Inputs are sent to the device of the first parameter (presumably the
        # input-embedding side when device_map="auto" shards the model).
        self.device = next(model.parameters()).device

        # Optimizer
        self.optimizer = create_tversky_optimizer(
            model,
            base_lr=config.learning_rate,
            tversky_lr_multiplier=config.tversky_lr_multiplier,
            weight_decay=config.weight_decay
        )

        # Scheduler: one optimizer step per accumulation window, including a
        # trailing partial window (see train_epoch), hence ceil().
        accum = config.gradient_accumulation_steps
        steps_per_epoch = math.ceil(len(train_loader) / accum)
        total_steps = max(1, steps_per_epoch * config.num_epochs)
        warmup_steps = int(total_steps * config.warmup_ratio)

        def lr_lambda(step):
            # Linear warmup to the base LR, then cosine decay to 0. With
            # warmup_steps == 0 this is CosineAnnealingLR(T_max=total_steps).
            if step < warmup_steps:
                return (step + 1) / warmup_steps
            progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
            return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

        # LambdaLR scales each param group's own initial LR, so per-group
        # multipliers set by create_tversky_optimizer are preserved.
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

        # Mixed precision. A disabled GradScaler is a pass-through: scale()
        # returns the loss unchanged and step() just calls optimizer.step().
        # One code path therefore serves both fp16 and fp32.
        self.scaler = GradScaler(enabled=config.fp16)

        # Tracking
        self.best_val_loss = float('inf')
        self.history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}

    def _compute_loss(self, batch):
        """Forward one micro-batch; return LM loss + Tversky regularization (not yet divided for accumulation)."""
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        labels = batch['labels'].to(self.device)

        with autocast(enabled=self.config.fp16):
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            reg_loss = get_tversky_regularization_loss(
                self.model,
                self.config.diversity_weight,
                self.config.sparsity_weight
            )
            return outputs.loss + reg_loss

    def _optimizer_step(self):
        """Unscale, clip, step optimizer and scheduler, then clear gradients."""
        # Unscale first so clipping sees the true gradient norm.
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.scheduler.step()
        self.optimizer.zero_grad()

    def train_epoch(self, epoch):
        """Train for one epoch; return the mean per-micro-batch loss (task + regularization)."""
        self.model.train()
        total_loss = 0.0
        num_batches = len(self.train_loader)
        accum = self.config.gradient_accumulation_steps
        self.optimizer.zero_grad()

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}")

        for step, batch in enumerate(pbar):
            # Divide so the accumulated gradient is the window mean.
            loss = self._compute_loss(batch) / accum
            self.scaler.scale(loss).backward()
            total_loss += loss.item() * accum

            # Step at the end of each window AND after the last batch, so a
            # trailing partial window is applied. Without the second condition,
            # the zero_grad() at the next epoch's start would discard it. That
            # final window's gradient is scaled by (batches in it) / accum.
            if (step + 1) % accum == 0 or (step + 1) == num_batches:
                self._optimizer_step()

            pbar.set_postfix({'loss': f'{total_loss/(step+1):.4f}'})

        return total_loss / max(1, num_batches)

    @torch.no_grad()
    def evaluate(self):
        """Return validation metrics ``{'loss', 'accuracy'}``.

        ``loss`` is the mean of the model's per-batch loss (without Tversky
        regularization), or NaN when ``val_loader`` is empty.

        ``accuracy`` is teacher-forced next-token accuracy. The argmax at
        position t is compared with the label at position t+1, because the
        causal LM's logits at t score the next token. Positions with label -100
        or ``attention_mask == 0`` (padding) are excluded.
        """
        self.model.eval()
        total_loss = 0.0
        total_correct = 0
        total_tokens = 0
        num_batches = 0

        for batch in self.val_loader:
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['labels'].to(self.device)

            with autocast(enabled=self.config.fp16):
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

            total_loss += outputs.loss.item()
            num_batches += 1

            logits = outputs.logits
            # Shift: logits[:, t] predicts token t+1. Compare on the logits'
            # device, which can differ from self.device when the model is sharded.
            predictions = logits[:, :-1].argmax(dim=-1)
            targets = labels[:, 1:].to(logits.device)
            mask = (targets != -100) & (attention_mask[:, 1:].to(logits.device) != 0)
            total_correct += ((predictions == targets) & mask).sum().item()
            total_tokens += mask.sum().item()

        return {
            'loss': total_loss / num_batches if num_batches else float('nan'),
            'accuracy': total_correct / total_tokens if total_tokens > 0 else 0
        }

    def train(self):
        """Run ``config.num_epochs`` epochs; return the per-epoch history dict."""
        print(f"\n{'='*60}")
        print(f"Starting training - {self.config.num_epochs} epochs")
        print(f"{'='*60}\n")

        for epoch in range(self.config.num_epochs):
            train_loss = self.train_epoch(epoch)
            val_metrics = self.evaluate()

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_metrics['loss'])
            self.history['val_accuracy'].append(val_metrics['accuracy'])

            # Print summary
            print(f"\nEpoch {epoch+1}/{self.config.num_epochs}:")
            print(f"  Train Loss: {train_loss:.4f}")
            print(f"  Val Loss: {val_metrics['loss']:.4f}")
            print(f"  Val Accuracy: {val_metrics['accuracy']*100:.2f}%")

            # Tversky analysis
            analysis = analyze_tversky_parameters(self.model)
            for name, params in analysis.items():
                print(f"  {name}: a={params['alpha']:.3f}, b={params['beta']:.3f}, g={params['gamma']:.3f}")

            # Check health
            warnings = monitor_tversky_health(self.model)
            if warnings:
                print(f"  Warnings: {warnings}")

            # Save best (a NaN loss never compares as an improvement)
            if val_metrics['loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['loss']
                self.save('best_model.pt')
                print(f"  Saved best model")

        print(f"\n{'='*60}")
        print(f"Training complete! Best val loss: {self.best_val_loss:.4f}")
        print(f"{'='*60}")

        return self.history

    def save(self, filename):
        """Save Tversky parameters, config, history and best loss to ``output_dir/filename``.

        Only parameters whose names contain one of the substrings below are
        stored, not the full model weights.
        """
        path = os.path.join(self.config.output_dir, filename)
        tversky_state = {}
        for name, param in self.model.named_parameters():
            if any(k in name for k in ['alpha_raw', 'beta_raw', 'gamma', 'feature_bank', 'prototype_bank']):
                tversky_state[name] = param.data.cpu()

        torch.save({
            'tversky_state_dict': tversky_state,
            'config': vars(self.config),
            'history': self.history,
            'best_val_loss': self.best_val_loss
        }, path)

In [None]:
# Build the trainer from the model and loaders prepared above, then run every
# epoch; `history` collects per-epoch train/val loss and validation accuracy.
trainer = Trainer(
    model=model,
    config=config,
    train_loader=train_loader,
    val_loader=val_loader,
)
history = trainer.train()

## 9. Visualize Results

In [None]:
import matplotlib.pyplot as plt

# X positions numbered from 1 so they match the "Epoch k/N" lines printed
# during training (plotting the bare lists would start the axis at 0).
epochs = range(1, len(history['train_loss']) + 1)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, acc_ax = axes

# Loss
loss_ax.plot(epochs, history['train_loss'], label='Train')
loss_ax.plot(epochs, history['val_loss'], label='Validation')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training & Validation Loss')
loss_ax.legend()
loss_ax.grid(True)

# Accuracy
acc_ax.plot(epochs, [acc * 100 for acc in history['val_accuracy']])
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.set_title('Validation Accuracy')
acc_ax.grid(True)

fig.tight_layout()
fig.savefig(os.path.join(config.output_dir, "training_curves.png"), dpi=150)
plt.show()

## 10. Test Inference

In [None]:
# Test on a few validation samples
@torch.no_grad()
def test_samples(model, tokenizer, val_dataset, num_samples=5):
    model.eval()
    device = next(model.parameters()).device
    
    print("\nSample Predictions:")
    print("="*60)
    
    for i in range(min(num_samples, len(val_dataset))):
        sample = val_dataset[i]
        
        input_ids = sample['input_ids'].unsqueeze(0).to(device)
        
        # Get prediction
        outputs = model(input_ids=input_ids)
        predictions = outputs.logits.argmax(dim=-1)
        
        # Decode
        ground_truth = tokenizer.decode(sample['input_ids'], skip_special_tokens=True)
        predicted = tokenizer.decode(predictions[0], skip_special_tokens=True)
        
        print(f"\nSample {i+1}:")
        print(f"  Ground Truth: {ground_truth[:100]}")
        print(f"  Predicted:    {predicted[:100]}")

test_samples(model, tokenizer, val_dataset, num_samples=5)  # preview a handful of validation predictions

## 11. Save & Download

In [None]:
# Show what training wrote to the output directory, with sizes in MB (1e6 bytes).
print("\nOutput files:")
for entry_name in os.listdir(config.output_dir):
    entry_mb = os.path.getsize(os.path.join(config.output_dir, entry_name)) / 1e6
    print(f"  {entry_name} ({entry_mb:.1f} MB)")

---

## Notes

### CSV Column Names
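The dataset first looks for the columns named by `config.image_col` and `config.text_col`. If they are missing, it tries common alternatives (e.g. `filename`, `label`, `transcription`), and as a last resort it uses the first and second CSV columns. If the columns printed under "Using columns" are wrong, set `image_col` and `text_col` explicitly in `KaggleTrainingConfig`.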

### Memory Issues
If you get OOM errors:
1. Reduce `batch_size` to 1
2. Increase `gradient_accumulation_steps`
3. Reduce `max_seq_length`

### After Training
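1. Download `best_model.pt` from the Output tab. It stores only the parameters whose names match the Tversky parameters (alpha/beta/gamma and the feature/prototype banks), plus the training config and loss history. It does not contain the full model weights.
2. For inference, load the base model, apply the same Tversky conversion, and then load these parameters into the converted layers.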