# üçé OCR Training - Apple MPS Optimized

**Optimized for Mac with Apple Silicon (M1/M2/M3)**

## Key Settings:
- ‚úÖ MPS device acceleration
- ‚úÖ IMG_WIDTH = 1024 (stable)
- ‚úÖ Batch size = 16
- ‚úÖ Light augmentation
- ‚úÖ Conservative learning rate

**Dataset:** `./2/dataset/`

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torch.nn.functional as F
import math
import random
import warnings
from tqdm import tqdm
warnings.filterwarnings('ignore')

# MPS setup for Apple Silicon
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print('‚úÖ Using Apple MPS (Metal Performance Shaders)')
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print('‚úÖ Using CUDA')
else:
    device = torch.device('cpu')
    print('‚ö†Ô∏è  Using CPU')

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
        
set_seed(42)
print(f'Device: {device}\n')

‚úÖ Using Apple MPS (Metal Performance Shaders)
Device: mps



## Load Local Dataset

In [2]:
# Local paths
BASE_PATH = './2/dataset/'
IMG_DIR = os.path.join(BASE_PATH, 'images')
labels_path = os.path.join(BASE_PATH, 'labels.csv')

print(f'Loading from: {BASE_PATH}')
print(f'Images dir: {IMG_DIR}')
print(f'Labels: {labels_path}\n')

df = pd.read_csv(labels_path)
print(f'Total rows in CSV: {len(df)}')

# Clean data
df = df[df['text'].notna()].reset_index(drop=True)
df = df[df['text'].str.strip() != ''].reset_index(drop=True)

# Filter by length
MAX_TEXT_LEN = 130
df = df[df['text'].str.len() <= MAX_TEXT_LEN].reset_index(drop=True)

# Verify images exist
df['img_path'] = df['file_name'].apply(lambda x: os.path.join(IMG_DIR, x))
df['exists'] = df['img_path'].apply(os.path.isfile)
df = df[df['exists']].reset_index(drop=True)

print(f'Valid samples: {len(df)}')
print(f'Max text length: {df["text"].str.len().max()}')
print(f'\nSample entry:')
print(df.iloc[0]['text'][:100])

Loading from: ./2/dataset/
Images dir: ./2/dataset/images
Labels: ./2/dataset/labels.csv

Total rows in CSV: 20000
Valid samples: 6137
Max text length: 130

Sample entry:
LIST_0 ‚Üê [ 55 , 53 , 55 , 50 ]
VAR_0 ‚Üê taille ( LIST_0 )


## Build Vocabulary

In [3]:
# Extract all unique characters
unique_chars = sorted(set(''.join(df['text'].tolist())))
char_list = ['<blank>'] + unique_chars  # blank for CTC
char_to_idx = {ch: i for i, ch in enumerate(char_list)}
idx_to_char = {i: ch for i, ch in enumerate(char_list)}
vocab_size = len(char_list)

print(f'Vocabulary size: {vocab_size}')
print(f'Characters: {"".join(unique_chars[:50])}...')

Vocabulary size: 57
Characters: 	
 ()*+,-/0123456789<=>ADILRSTVX[]_adefghilmnopqrs...


## Configuration - OPTIMIZED

In [4]:
# Image config - FIXED for stability
IMG_HEIGHT = 64
IMG_WIDTH = 1024  # ‚úÖ Reduced from 3200

# Training config
BATCH_SIZE = 16   # ‚úÖ Increased from 8
NUM_EPOCHS = 40
PATIENCE = 12

print(f'Image size: {IMG_HEIGHT}x{IMG_WIDTH}')
print(f'Expected sequence length: ~{IMG_WIDTH // 4}')
print(f'Batch size: {BATCH_SIZE}')
print(f'Epochs: {NUM_EPOCHS} (patience: {PATIENCE})')

Image size: 64x1024
Expected sequence length: ~256
Batch size: 16
Epochs: 40 (patience: 12)


## Dataset & DataLoader

In [5]:
class OCRDataset(Dataset):
    def __init__(self, dataframe, char_to_idx, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.char_to_idx = char_to_idx
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = Image.open(row['img_path']).convert('RGB')
        
        if self.transform:
            img = self.transform(img)
            
        encoded = [self.char_to_idx[ch] for ch in row['text']]
        return img, torch.tensor(encoded, dtype=torch.long), len(encoded)

def collate_fn(batch):
    imgs, texts, lens = zip(*batch)
    imgs = torch.stack(imgs, dim=0)
    lens = torch.tensor(lens, dtype=torch.long)
    max_len = max(lens)
    padded = torch.zeros(len(texts), max_len, dtype=torch.long)
    for i, text in enumerate(texts):
        padded[i, :len(text)] = text
    return imgs, padded, lens

# Light augmentation - text friendly
train_transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.1, hue=0.0)
    ], p=0.5),
    transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 0.5))], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Split 80/20
train_size = int(0.8 * len(df))
val_size = len(df) - train_size

train_idx, val_idx = random_split(
    range(len(df)), [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

train_ds = OCRDataset(df.iloc[train_idx.indices], char_to_idx, train_transform)
val_ds = OCRDataset(df.iloc[val_idx.indices], char_to_idx, val_transform)

train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0)
val_loader = DataLoader(val_ds, BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=0)

print(f'Train: {len(train_ds)}, Val: {len(val_ds)}')
print(f'Train batches: {len(train_loader)}, Val batches: {len(val_loader)}')

Train: 4909, Val: 1228
Train batches: 307, Val batches: 77


## Model Architecture

In [6]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class ImprovedOCR(nn.Module):
    def __init__(self, num_classes, hidden_dim=256, nhead=8, num_layers=4):
        super().__init__()
        
        # CNN backbone
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 32x512
            
            nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 16x256
            
            nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(),
            
            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, None))  # 1xW
        )
        
        self.proj = nn.Linear(512, hidden_dim)
        self.pos_enc = PositionalEncoding(hidden_dim)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=nhead, 
            dim_feedforward=hidden_dim * 4,
            dropout=0.2, activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.norm = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)
        
    def forward(self, x):
        x = self.conv(x)  # [B, 512, 1, W]
        x = x.squeeze(2).permute(0, 2, 1)  # [B, W, 512]
        x = self.proj(x)
        x = self.pos_enc(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = self.fc(x)
        return x.permute(1, 0, 2)  # [T, B, C] for CTC

model = ImprovedOCR(num_classes=vocab_size, hidden_dim=256, nhead=8, num_layers=4).to(device)

params = sum(p.numel() for p in model.parameters())
print(f'‚úÖ Model loaded')
print(f'Parameters: {params:,}')

# Test forward pass
test_in = torch.randn(2, 3, IMG_HEIGHT, IMG_WIDTH).to(device)
test_out = model(test_in)
print(f'Output shape: {test_out.shape}  (T={test_out.shape[0]}, B={test_out.shape[1]}, C={test_out.shape[2]})')

‚úÖ Model loaded
Parameters: 5,449,017
Output shape: torch.Size([256, 2, 57])  (T=256, B=2, C=57)


## Training Setup

In [7]:
criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=3e-4,
    epochs=NUM_EPOCHS,
    steps_per_epoch=len(train_loader),
    pct_start=0.3
)

def decode(logits, idx_to_char):
    preds = []
    logits = logits.permute(1, 0, 2)
    for i in range(logits.size(0)):
        indices = torch.argmax(logits[i], dim=-1).tolist()
        decoded, prev = [], None
        for idx in indices:
            if idx != 0 and idx != prev:
                decoded.append(idx_to_char.get(idx, ''))
            prev = idx
        preds.append(''.join(decoded))
    return preds

def accuracy(preds, targets, lens):
    total, correct = 0, 0
    for pred, target, l in zip(preds, targets, lens):
        true = ''.join([idx_to_char[target[j].item()] for j in range(l)])
        total += len(true)
        for i, c in enumerate(true):
            if i < len(pred) and pred[i] == c:
                correct += 1
    return correct / total if total > 0 else 0.0

print('‚úÖ Training setup complete')

‚úÖ Training setup complete


## Training Loop

In [None]:
best_val_acc = 0.0
best_epoch = 0
train_losses, val_losses = [], []
train_accs, val_accs = [], []

print('üöÄ Training on MPS...\n')

for epoch in range(1, NUM_EPOCHS + 1):
    # TRAIN
    model.train()
    t_loss = 0
    t_preds, t_targets, t_lens = [], [], []
    valid = 0
    skipped = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}/{NUM_EPOCHS}', leave=False)
    for imgs, targets, lens in pbar:
        imgs = imgs.to(device)
        
        with torch.no_grad():
            test_out = model(imgs[:1])
            seq_len = test_out.size(0)
        
        if (lens > seq_len).any():
            skipped += 1
            continue
        
        optimizer.zero_grad()
        logits = model(imgs)
        in_lens = torch.full((imgs.size(0),), seq_len, dtype=torch.long)
        log_probs = F.log_softmax(logits, dim=-1)
        # Move to CPU for CTC loss (MPS doesn't support CTC yet)
        loss = criterion(log_probs.cpu(), targets, in_lens, lens)
        
        if torch.isnan(loss) or torch.isinf(loss) or loss.item() > 100:
            skipped += 1
            continue
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        optimizer.step()
        scheduler.step()
        
        t_loss += loss.item()
        valid += 1
        
        with torch.no_grad():
            t_preds.extend(decode(logits.cpu(), idx_to_char))
            t_targets.extend(targets)
            t_lens.extend(lens)
        
        pbar.set_postfix({'loss': f'{loss.item():.3f}'})
    
    if valid == 0:
        print(f'Epoch {epoch}: No valid batches!')
        continue
    
    avg_t = t_loss / valid
    t_acc = accuracy(t_preds, t_targets, t_lens)
    train_losses.append(avg_t)
    train_accs.append(t_acc)
    
    # VAL
    model.eval()
    v_loss = 0
    v_preds, v_targets, v_lens = [], [], []
    valid_v = 0
    
    with torch.no_grad():
        for imgs, targets, lens in val_loader:
            imgs = imgs.to(device)
            logits = model(imgs)
            seq_len = logits.size(0)
            in_lens = torch.full((imgs.size(0),), seq_len, dtype=torch.long)
            
            if (lens > seq_len).any():
                continue
            
            # Move to CPU for CTC loss (MPS doesn't support CTC yet)
            loss = criterion(log_probs.cpu(), targets, in_lens, lens)
            
            if not torch.isnan(loss) and not torch.isinf(loss) and loss.item() < 100:
                v_loss += loss.item()
                valid_v += 1
            
            v_preds.extend(decode(logits.cpu(), idx_to_char))
            v_targets.extend(targets)
            v_lens.extend(lens)
    
    avg_v = v_loss / valid_v if valid_v > 0 else float('inf')
    v_acc = accuracy(v_preds, v_targets, v_lens)
    val_losses.append(avg_v)
    val_accs.append(v_acc)
    
    if v_acc > best_val_acc:
        best_val_acc = v_acc
        best_epoch = epoch
        torch.save(model.state_dict(), 'best_mps_model.pth')
        print(f'‚úÖ Best: {v_acc*100:.2f}% (epoch {epoch})')
    
    skip_msg = f' | Skip: {skipped}' if skipped > 0 else ''
    print(f'Epoch {epoch:2d}/{NUM_EPOCHS} | Train: {avg_t:.4f}/{t_acc*100:.2f}% | Val: {avg_v:.4f}/{v_acc*100:.2f}%{skip_msg}')
    
    if epoch - best_epoch >= PATIENCE:
        print(f'\nüõë Early stop (no improvement for {PATIENCE} epochs)')
        print(f'Best: {best_val_acc*100:.2f}% at epoch {best_epoch}')
        break

print(f'\n‚úÖ Training complete!')

print(f'Best validation accuracy: {best_val_acc*100:.2f}% (epoch {best_epoch})')print(f'Best validation accuracy: {best_val_acc*100:.2f}% (epoch {best_epoch})')

üöÄ Training on MPS...



                                                   

NotImplementedError: The operator 'aten::_ctc_loss' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 449b1768410104d3ed79d3bcfe4ba1d65c7f22c0. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

## Plot Results

In [None]:
plt.figure(figsize=(14, 5))

plt.subplot(1, 2, 1)
plt.plot(train_losses, 'o-', label='Train', lw=2)
plt.plot(val_losses, 's-', label='Val', lw=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Progress')
plt.legend()
plt.grid(alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot([a*100 for a in train_accs], 'o-', label='Train', lw=2)
plt.plot([a*100 for a in val_accs], 's-', label='Val', lw=2)
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Character Accuracy')
plt.legend()
plt.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('mps_training_results.png', dpi=150)
plt.show()

print(f'Best: {best_val_acc*100:.2f}% at epoch {best_epoch}')

## Evaluate & Show Predictions

In [None]:
model.load_state_dict(torch.load('best_mps_model.pth'))
model.eval()

final_preds, final_targets, final_lens = [], [], []

print('Evaluating on validation set...\n')
with torch.no_grad():
    for imgs, targets, lens in tqdm(val_loader, desc='Eval'):
        imgs = imgs.to(device)
        logits = model(imgs)
        final_preds.extend(decode(logits.cpu(), idx_to_char))
        final_targets.extend(targets)
        final_lens.extend(lens)

final_acc = accuracy(final_preds, final_targets, final_lens)

print(f'\nüéØ FINAL ACCURACY: {final_acc*100:.2f}%\n')
print('='*100)
print('Sample Predictions:\n')

for i in range(min(15, len(final_preds))):
    true = ''.join([idx_to_char[final_targets[i][j].item()] for j in range(final_lens[i])])
    pred = final_preds[i]
    char_acc = sum(1 for a,b in zip(true, pred) if a==b) / len(true) * 100 if len(true) > 0 else 0
    
    print(f'Sample {i+1:2d} | Acc: {char_acc:5.1f}%')
    print(f'  True: {true[:90]}')
    print(f'  Pred: {pred[:90]}')
    print('-'*100)

print(f'\n‚úÖ Final Accuracy: {final_acc*100:.2f}%')
print(f'Model saved as: best_mps_model.pth')