# High Precision OCR Model (Target: 95+ Score)

## Key Improvements:
1. **Stronger Backbone**: EfficientNet-b5 / ConvNeXt
2. **Higher Resolution**: 1536x1536
3. **Improved Loss**: Dice + BCE + Focal + Boundary Loss
4. **Multi-scale TTA**: Ensemble multiple scales
5. **Refined Post-processing**: Morphological ops + Adaptive threshold
6. **Pseudo Label**: Additional training data (a dataset class is provided; it is not yet wired into the training pipeline)
7. **Precision/Recall Metrics**: Track actual performance

In [None]:
# Cell 1: Data Download
# Fetch the competition archive, extract it in the working directory, and list
# the extracted folder. The later cells expect the archive to unpack to ./data.
!wget -O data.tar.gz "https://aistages-api-public-prod.s3.amazonaws.com/app/Competitions/000377/data/data.tar.gz"
!tar -xzf data.tar.gz
!ls -la data/

In [None]:
# Cell 2: Install Dependencies
# Third-party packages imported by the cells below; timm is presumably needed
# for the 'tu-' (timm universal) encoder selected in Config -- TODO confirm.
!pip install -q segmentation-models-pytorch albumentations opencv-python-headless timm

In [None]:
# Cell 3: Imports and Configuration
import os
import json
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.amp import autocast, GradScaler
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm import tqdm
import gc
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# GPU Optimization
# Ask the CUDA caching allocator for expandable segments.
# NOTE(review): this is set after `import torch`; presumably it only needs to
# be in place before the first CUDA allocation -- confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Allow TF32 in matmul and cuDNN kernels, and let cuDNN benchmark algorithms.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Paths
# Competition data layout: per-split image folders plus one JSON per split.
BASE_PATH = './data/datasets'
TRAIN_IMG_DIR = os.path.join(BASE_PATH, 'images/train')
VAL_IMG_DIR = os.path.join(BASE_PATH, 'images/val')
TEST_IMG_DIR = os.path.join(BASE_PATH, 'images/test')
TRAIN_JSON = os.path.join(BASE_PATH, 'jsons/train.json')
VAL_JSON = os.path.join(BASE_PATH, 'jsons/val.json')
TEST_JSON = os.path.join(BASE_PATH, 'jsons/test.json')
SAMPLE_SUB = os.path.join(BASE_PATH, 'sample_submission.csv')

# Pseudo Label Paths
# Image folders of the extra receipt datasets (no label files are referenced).
PSEUDO_BASE = './data/pseudo_label'
SROIE_TRAIN_IMG = os.path.join(PSEUDO_BASE, 'sroie/images/train')
SROIE_TEST_IMG = os.path.join(PSEUDO_BASE, 'sroie/images/test')
WILDRECEIPT_IMG = os.path.join(PSEUDO_BASE, 'wildreceipt/images')
CORDV2_IMG = os.path.join(PSEUDO_BASE, 'cord-v2/images')

# Use the GPU when available, otherwise fall back to the CPU, and report
# which device (and how much GPU memory) was found.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Cell 4: Hyperparameters (Optimized for High Precision)
class Config:
    """Hyperparameters shared by the training and inference cells.

    All values are class attributes; the module-level ``cfg`` instance below
    is what the other cells read.
    """
    # Model
    ENCODER = 'tu-convnext_base'  # ConvNeXt for better feature extraction
    ENCODER_WEIGHTS = 'imagenet'
    
    # Training
    RESIZE_TARGET = 1536  # Higher resolution for small text (images are resized to a square)
    BATCH_SIZE = 2        # Reduced for high resolution
    ACCUMULATION_STEPS = 16  # Effective batch size = BATCH_SIZE * ACCUMULATION_STEPS = 32
    EPOCHS = 40
    LEARNING_RATE = 5e-5  # Lower LR for stability
    WARMUP_EPOCHS = 3     # Epochs of linear warmup before cosine annealing
    
    # Loss weights
    # NOTE(review): the training cell additionally passes boundary_w=0.1, so
    # the four loss weights sum to 1.1 rather than 1.0 -- confirm intended.
    DICE_WEIGHT = 0.4
    BCE_WEIGHT = 0.3
    FOCAL_WEIGHT = 0.3
    
    # Inference
    THRESHOLD = 0.5  # Probability cut-off used to binarize predicted masks
    MIN_AREA = 50  # Minimum polygon area
    
    # TTA scales
    TTA_SCALES = [1.0, 0.75, 1.25]
    
    # Use pseudo labels
    # NOTE(review): this flag is not read anywhere in the visible code.
    USE_PSEUDO_LABELS = True

cfg = Config()
print(f"Config: Resolution={cfg.RESIZE_TARGET}, Encoder={cfg.ENCODER}")

In [None]:
# Cell 5: Advanced Data Augmentation
def get_train_transform(size):
    """Build the training augmentation pipeline.

    The pipeline resizes to ``size`` x ``size``, applies random geometric,
    photometric, blur and noise augmentations, normalizes with the ImageNet
    mean/std and converts the result to a tensor.

    Args:
        size: Side length (pixels) of the square output.

    Returns:
        An ``albumentations.Compose`` callable.
    """
    geometric = [
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.3),
        A.Perspective(scale=(0.02, 0.08), p=0.4),
        A.Affine(scale=(0.9, 1.1), rotate=(-5, 5), shear=(-5, 5), p=0.3),
    ]

    # Exactly one colour adjustment is picked when this group fires.
    colour = A.OneOf(
        [
            A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8)),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.RandomGamma(gamma_limit=(80, 120)),
        ],
        p=0.5,
    )

    # Likewise, at most one kind of blur per sample.
    blur = A.OneOf(
        [
            A.GaussianBlur(blur_limit=(3, 5)),
            A.MotionBlur(blur_limit=3),
            A.MedianBlur(blur_limit=3),
        ],
        p=0.2,
    )

    noise = A.GaussNoise(std_range=(0.02, 0.1), p=0.2)

    steps = [A.Resize(size, size)]
    steps.extend(geometric)
    steps.extend([colour, blur, noise])
    steps.append(A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    steps.append(ToTensorV2())
    return A.Compose(steps)

def get_val_transform(size):
    """Build the deterministic validation/test pipeline.

    Args:
        size: Side length (pixels) of the square output.

    Returns:
        An ``albumentations.Compose`` that resizes, normalizes with the
        ImageNet mean/std and converts to a tensor (no random augmentation).
    """
    steps = [
        A.Resize(size, size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

print("Transforms defined.")

In [None]:
# Cell 6: Dataset Classes
class ReceiptDataset(Dataset):
    """Receipt images with word-polygon annotations loaded from a JSON file.

    The JSON file must contain an ``'images'`` mapping keyed by image file
    name. In train/val mode each entry may carry a ``'words'`` mapping whose
    values hold a ``'points'`` polygon; all polygons are rasterized into one
    binary float32 mask.

    Args:
        img_dir: Directory containing the image files.
        json_path: Path to the annotation JSON.
        transform: Optional albumentations-style callable. It is invoked as
            ``transform(image=..., mask=...)`` in train/val mode and
            ``transform(image=...)`` in test mode.
        is_test: When True no mask is built and ``__getitem__`` returns
            ``(image, image_name, (height, width))`` instead of
            ``(image, mask)``.
    """

    def __init__(self, img_dir, json_path, transform=None, is_test=False):
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test

        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)['images']
        self.image_names = list(self.data.keys())

    def __len__(self):
        return len(self.image_names)

    def _read_with_fallback(self, idx):
        """Return ``(name, BGR image)`` for ``idx`` or the next readable sample.

        Unreadable files are skipped by moving to the following index
        (wrapping around). Every sample is tried at most once, so a dataset
        with no readable image raises instead of recursing without bound.
        """
        # Direct indexing first so an out-of-range idx still raises IndexError.
        img_name = self.image_names[idx]
        image = cv2.imread(os.path.join(self.img_dir, img_name))

        total = len(self.image_names)
        attempts = 1
        while image is None:
            if attempts >= total:
                raise RuntimeError(f"No readable image found in {self.img_dir}")
            idx = (idx + 1) % total
            img_name = self.image_names[idx]
            image = cv2.imread(os.path.join(self.img_dir, img_name))
            attempts += 1
        return img_name, image

    def __getitem__(self, idx):
        img_name, image = self._read_with_fallback(idx)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]

        if self.is_test:
            if self.transform:
                augmented = self.transform(image=image)
                return augmented['image'], img_name, (h, w)
            return image, img_name, (h, w)

        # Rasterize every annotated word polygon into a single binary mask.
        mask = np.zeros((h, w), dtype=np.float32)
        words = self.data[img_name].get('words', {})
        for word_info in words.values():
            points = np.array(word_info['points'], dtype=np.int32)
            cv2.fillPoly(mask, [points], 1.0)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            return augmented['image'], augmented['mask']
        return image, mask


class PseudoLabelDataset(Dataset):
    """Image-only dataset for pseudo-labelling (masks come from a model).

    Collects every ``.jpg`` / ``.jpeg`` / ``.png`` file (case-insensitive)
    found directly inside the given directories. Directories that do not
    exist are skipped silently.

    Args:
        img_dirs: Iterable of directory paths to scan.
        transform: Optional albumentations-style callable invoked as
            ``transform(image=...)``.
    """

    def __init__(self, img_dirs, transform=None):
        self.transform = transform
        self.images = []

        for img_dir in img_dirs:
            if not os.path.isdir(img_dir):
                continue
            # Sort so the index -> image mapping is reproducible across runs;
            # os.listdir() returns entries in arbitrary order.
            for fname in sorted(os.listdir(img_dir)):
                if fname.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.images.append(os.path.join(img_dir, fname))

        print(f"Pseudo Label: Found {len(self.images)} images")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Direct indexing first so an out-of-range idx still raises IndexError.
        image = cv2.imread(self.images[idx])

        # Skip unreadable files by moving to the next index (wrapping around),
        # trying each image at most once instead of recursing without bound.
        total = len(self.images)
        attempts = 1
        while image is None:
            if attempts >= total:
                raise RuntimeError("No readable pseudo-label image found")
            idx = (idx + 1) % total
            image = cv2.imread(self.images[idx])
            attempts += 1

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=image)
            return augmented['image']
        return image

print("Dataset classes defined.")

In [None]:
# Cell 7: Advanced Loss Functions
class FocalLoss(nn.Module):
    """Binary focal loss on raw logits, for imbalanced foreground/background.

    Args:
        alpha: Weight applied to positive pixels (``target == 1``); all other
            pixels are weighted by ``1 - alpha``.
        gamma: Focusing exponent; larger values down-weight easy pixels more.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target):
        """Return the mean focal loss.

        Args:
            pred: Raw logits (sigmoid is applied internally).
            target: Tensor of the same shape as ``pred`` with values in [0, 1].

        Returns:
            Scalar float32 tensor.
        """
        # Work in float32 and on logits: F.binary_cross_entropy on
        # probabilities is rejected inside CUDA autocast regions, and clamping
        # probabilities to 1 - 1e-7 is a no-op in bfloat16.
        logits = pred.float()
        target = target.to(logits.dtype)

        prob = torch.sigmoid(logits)
        is_positive = target == 1

        # pt is the probability the model assigns to the true class.
        pt = torch.where(is_positive, prob, 1 - prob)
        alpha_t = torch.where(
            is_positive,
            torch.full_like(prob, self.alpha),
            torch.full_like(prob, 1 - self.alpha),
        )
        focal_weight = alpha_t * (1 - pt) ** self.gamma

        bce = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
        return (focal_weight * bce).mean()


class BoundaryLoss(nn.Module):
    """Edge-matching loss: MSE between Sobel edge maps of prediction and target."""

    def __init__(self):
        super().__init__()
        # Fixed 3x3 Sobel filters, stored as buffers so they follow the
        # module across devices without being trained.
        horizontal = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32
        )
        vertical = torch.tensor(
            [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32
        )
        self.register_buffer('sobel_x', horizontal.view(1, 1, 3, 3))
        self.register_buffer('sobel_y', vertical.view(1, 1, 3, 3))

    def _edge_magnitude(self, x):
        """Return the Sobel gradient magnitude of a single-channel map."""
        grad_x = F.conv2d(x, self.sobel_x, padding=1)
        grad_y = F.conv2d(x, self.sobel_y, padding=1)
        # The small constant keeps sqrt differentiable where the gradient is 0.
        return torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)

    def forward(self, pred, target):
        """Compare edge maps of ``sigmoid(pred)`` and ``target``.

        Args:
            pred: Raw logits with a single channel, shape (N, 1, H, W).
            target: Mask of shape (N, H, W) or (N, 1, H, W).

        Returns:
            Scalar MSE between the two edge-magnitude maps.
        """
        if target.dim() == 3:
            target = target.unsqueeze(1)

        pred_edge = self._edge_magnitude(torch.sigmoid(pred))
        target_edge = self._edge_magnitude(target)
        return F.mse_loss(pred_edge, target_edge)


class CombinedLoss(nn.Module):
    """Weighted sum of Dice, BCE-with-logits, Focal and Boundary losses.

    Args:
        dice_w: Weight of the Dice term.
        bce_w: Weight of the BCE-with-logits term.
        focal_w: Weight of the focal term.
        boundary_w: Weight of the boundary (edge) term.
    """

    def __init__(self, dice_w=0.4, bce_w=0.3, focal_w=0.2, boundary_w=0.1):
        super().__init__()
        self.dice = smp.losses.DiceLoss(mode='binary')
        self.bce = nn.BCEWithLogitsLoss()
        self.focal = FocalLoss(alpha=0.25, gamma=2.0)
        self.boundary = BoundaryLoss()

        self.dice_w, self.bce_w = dice_w, bce_w
        self.focal_w, self.boundary_w = focal_w, boundary_w

    def forward(self, pred, target):
        """Return the weighted total loss for logits ``pred`` against ``target``.

        A 3-D ``target`` gets a channel axis inserted at dim 1 before the
        individual terms are evaluated.
        """
        if target.dim() == 3:
            target = target.unsqueeze(1)

        total = self.dice_w * self.dice(pred, target)
        total = total + self.bce_w * self.bce(pred, target)
        total = total + self.focal_w * self.focal(pred, target)
        total = total + self.boundary_w * self.boundary(pred, target)
        return total

print("Loss functions defined.")

In [None]:
# Cell 8: Model Architecture with Attention
class AttentionBlock(nn.Module):
    """Channel attention followed by spatial attention.

    Args:
        in_channels: Number of channels of the input feature map. The channel
            attention bottleneck uses ``in_channels // 8`` channels, but never
            fewer than one, so inputs with fewer than 8 channels are supported.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Guard against a zero-width bottleneck when in_channels < 8.
        reduced = max(in_channels // 8, 1)

        # Channel attention: a shared bottleneck applied to pooled descriptors.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, reduced, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced, in_channels, 1, bias=False)
        )
        # Spatial attention: a 7x7 conv over the [mean, max] channel statistics.
        self.conv = nn.Conv2d(2, 1, 7, padding=3, bias=False)

    def forward(self, x):
        """Return ``x`` re-weighted per channel and then per spatial location.

        The output has the same shape as the input.
        """
        # Channel attention
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        channel_att = torch.sigmoid(avg_out + max_out)
        x = x * channel_att

        # Spatial attention
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        spatial_att = torch.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1)))
        x = x * spatial_att

        return x


class HighPrecisionOCRModel(nn.Module):
    """UNet++ segmentation network followed by a small convolutional refinement head.

    Args:
        encoder_name: Encoder identifier passed to ``smp.UnetPlusPlus``.
        encoder_weights: Pretrained-weight name for the encoder, or ``None``
            to skip loading pretrained weights.
    """

    def __init__(self, encoder_name, encoder_weights='imagenet'):
        super().__init__()
        # Single-class UNet++ with scSE attention in the decoder.
        self.base_model = smp.UnetPlusPlus(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=1,
            decoder_attention_type='scse',
        )

        # Refinement head operating on the 1-channel logit map. Kept as one
        # flat Sequential so checkpoint keys stay ``refine.<index>.*``.
        width = 32
        head_layers = [
            nn.Conv2d(1, width, 3, padding=1),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, width, 3, padding=1),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, 1, 1),
        ]
        self.refine = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return refined single-channel logits for the image batch ``x``."""
        coarse_logits = self.base_model(x)
        return self.refine(coarse_logits)

print("Model architecture defined.")

In [None]:
# Cell 9: Precision/Recall Metrics
def compute_iou(mask1, mask2):
    """Return intersection-over-union of two binary masks.

    Non-zero entries count as foreground. When both masks are empty the
    union is zero and 0.0 is returned.
    """
    overlap = np.logical_and(mask1, mask2).sum()
    combined = np.logical_or(mask1, mask2).sum()
    return overlap / combined if combined != 0 else 0.0


def compute_precision_recall(pred_mask, gt_mask, threshold=0.5):
    """Return pixel-level ``(precision, recall, f1)``.

    Args:
        pred_mask: Array of predicted scores; values above ``threshold``
            count as positive.
        gt_mask: Ground-truth array; any non-zero value (after conversion to
            float32) counts as positive.
        threshold: Cut-off applied to ``pred_mask``.

    Returns:
        Tuple ``(precision, recall, f1)``. A small epsilon in each
        denominator yields 0 instead of a division error for empty masks.
    """
    eps = 1e-8
    pred_pos = pred_mask > threshold
    gt_pos = gt_mask.astype(np.float32).astype(bool)

    true_pos = (pred_pos & gt_pos).sum()
    false_pos = (pred_pos & ~gt_pos).sum()
    false_neg = (~pred_pos & gt_pos).sum()

    precision = true_pos / (true_pos + false_pos + eps)
    recall = true_pos / (true_pos + false_neg + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision, recall, f1


def evaluate_model(model, val_loader, device, threshold=0.5):
    """Compute mean per-image precision, recall, F1 and IoU on a validation set.

    The model is switched to eval mode and left there.

    Args:
        model: Network returning single-channel logits of shape (N, 1, H, W).
        val_loader: Iterable of ``(images, masks)`` tensor batches.
        device: Device the images are moved to before the forward pass.
        threshold: Probability cut-off used to binarize predictions.

    Returns:
        Dict with keys ``'precision'``, ``'recall'``, ``'f1'`` and ``'iou'``,
        each averaged over all images. All values are 0.0 when the loader
        yields no images.
    """
    model.eval()

    totals = {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'iou': 0.0}
    count = 0

    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Evaluating"):
            images = images.to(device)
            masks = masks.numpy()

            with autocast('cuda', dtype=torch.bfloat16):
                preds = torch.sigmoid(model(images))
            # bfloat16 cannot be converted to NumPy directly; go via float32.
            preds = preds.float().cpu().numpy()

            # Channel 0 holds the only class, giving one 2-D map per image.
            for pred, gt in zip(preds[:, 0], masks):
                p, r, f1 = compute_precision_recall(pred, gt, threshold)
                totals['precision'] += p
                totals['recall'] += r
                totals['f1'] += f1
                totals['iou'] += compute_iou(pred > threshold, gt > 0.5)
                count += 1

    if count == 0:
        # Nothing was evaluated; report zeros rather than dividing by zero.
        return totals
    return {name: value / count for name, value in totals.items()}

print("Metrics defined.")

In [None]:
# Cell 10: Advanced Post-processing
def apply_morphology(mask, kernel_size=3):
    """Clean a binary mask with a morphological close followed by an open.

    Args:
        mask: Input mask image accepted by ``cv2.morphologyEx``.
        kernel_size: Side length of the square structuring element.

    Returns:
        The processed mask.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))

    # Closing fills small gaps first; opening then removes small specks.
    for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        mask = cv2.morphologyEx(mask, operation, kernel)
    return mask


def mask_to_polygons_advanced(mask, min_area=50, epsilon_factor=0.003):
    """Extract simplified polygons from a binary mask.

    Args:
        mask: Binary mask (values 0/1 or bool); it is scaled to 0/255 uint8
            before processing.
        min_area: Contours with a smaller area are discarded.
        epsilon_factor: Fraction of the contour perimeter used as the
            tolerance when simplifying the contour.

    Returns:
        List of polygons, each a list of ``[x, y]`` integer pairs.
    """
    cleaned = apply_morphology((mask * 255).astype(np.uint8))
    contours, _ = cv2.findContours(cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    polygons = []
    for contour in contours:
        if cv2.contourArea(contour) < min_area:
            continue

        # Reject regions whose bounding box is too narrow or too flat.
        _, _, box_w, box_h = cv2.boundingRect(contour)
        if box_w < 5 or box_h < 3:
            continue

        tolerance = epsilon_factor * cv2.arcLength(contour, True)
        simplified = cv2.approxPolyDP(contour, tolerance, True)

        # Fewer than 4 vertices is not accepted as a valid polygon.
        if len(simplified) < 4:
            continue

        # Shapes that are still complex are replaced by their convex hull.
        if len(simplified) > 8:
            simplified = cv2.convexHull(simplified)

        polygons.append(simplified.reshape(-1, 2).tolist())

    return polygons


def polygons_to_string(polygons):
    """Serialize polygons into the submission string format.

    Each polygon becomes space-separated integer ``x y`` pairs; polygons are
    joined with ``|``. An empty or missing polygon list yields ``""``.
    """
    if not polygons:
        return ""

    encoded = []
    for poly in polygons:
        coords = []
        for point in poly:
            coords.append(f"{int(point[0])} {int(point[1])}")
        encoded.append(" ".join(coords))
    return "|".join(encoded)

print("Post-processing functions defined.")

In [None]:
# Cell 11: Multi-scale TTA Inference
def multi_scale_tta_inference(model, image, device, scales=[1.0, 0.75, 1.25], original_size=None):
    """
    Perform multi-scale TTA inference
    Args:
        model: trained model
        image: input tensor (C, H, W)
        device: torch device
        scales: list of scale factors
        original_size: (h, w) for resizing back
    """
    model.eval()
    _, h, w = image.shape
    
    all_preds = []
    
    with torch.no_grad():
        for scale in scales:
            new_h, new_w = int(h * scale), int(w * scale)
            # Ensure divisible by 32
            new_h = (new_h // 32) * 32
            new_w = (new_w // 32) * 32
            if new_h == 0:
                new_h = 32
            if new_w == 0:
                new_w = 32
            
            # Resize image
            scaled_img = F.interpolate(
                image.unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False
            ).to(device)
            
            # Original prediction
            with autocast('cuda', dtype=torch.bfloat16):
                pred = torch.sigmoid(model(scaled_img))
            pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=False)
            all_preds.append(pred.float())
            
            # Horizontal flip
            with autocast('cuda', dtype=torch.bfloat16):
                pred_flip = torch.sigmoid(model(torch.flip(scaled_img, dims=[3])))
            pred_flip = torch.flip(pred_flip, dims=[3])
            pred_flip = F.interpolate(pred_flip, size=(h, w), mode='bilinear', align_corners=False)
            all_preds.append(pred_flip.float())
    
    # Average all predictions
    final_pred = torch.stack(all_preds, dim=0).mean(dim=0)
    
    # Resize to original size if provided
    if original_size is not None:
        final_pred = F.interpolate(
            final_pred, size=original_size, mode='bilinear', align_corners=False
        )
    
    return final_pred.cpu().numpy()[0, 0]

print("Multi-scale TTA inference defined.")

In [None]:
# Cell 12: Training Function
def train_one_epoch(model, train_loader, criterion, optimizer, scaler, device, accumulation_steps):
    """Run one training epoch with gradient accumulation.

    An optimizer step is taken every ``accumulation_steps`` batches and once
    more for the trailing partial group at the end of the epoch, so no
    computed gradient is thrown away.

    Args:
        model: Network to train; it is put in train mode.
        train_loader: Sized iterable of batches whose first two items are
            ``(images, masks)``.
        criterion: Loss callable taking ``(logits, masks)``.
        optimizer: Optimizer updating ``model``'s parameters.
        scaler: ``GradScaler`` used for backward and stepping.
        device: Device the batches are moved to.
        accumulation_steps: Number of batches per optimizer step.

    Returns:
        Mean (un-divided) loss over the epoch's batches; 0 for an empty loader.
    """
    model.train()
    total_loss = 0
    num_batches = len(train_loader)
    optimizer.zero_grad(set_to_none=True)

    pbar = tqdm(train_loader, desc="Training")
    for i, batch in enumerate(pbar):
        images, masks = batch[0], batch[1]
        images = images.to(device, non_blocking=True, memory_format=torch.channels_last)
        masks = masks.to(device, non_blocking=True).unsqueeze(1)

        with autocast('cuda', dtype=torch.bfloat16):
            preds = model(images)
            # Divide so the accumulated gradient is the mean over the group.
            loss = criterion(preds, masks) / accumulation_steps

        scaler.scale(loss).backward()

        group_complete = (i + 1) % accumulation_steps == 0
        last_batch = (i + 1) == num_batches
        if group_complete or last_batch:
            scaler.unscale_(optimizer)
            if not group_complete:
                # Trailing partial group: each loss was divided by
                # accumulation_steps but fewer batches were summed, so rescale
                # the gradients to the mean over the batches actually seen.
                correction = accumulation_steps / ((i + 1) % accumulation_steps)
                for param in model.parameters():
                    if param.grad is not None:
                        param.grad.mul_(correction)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        # Undo the division so the reported loss is per batch.
        total_loss += loss.item() * accumulation_steps
        pbar.set_postfix({'loss': f'{total_loss / (i + 1):.4f}'})

    return total_loss / max(num_batches, 1)


def validate(model, val_loader, criterion, device):
    """Return the mean validation loss over all batches of ``val_loader``.

    The model is put in eval mode and gradients are disabled for the pass.
    """
    model.eval()
    batch_losses = []

    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Validation"):
            images = images.to(device, memory_format=torch.channels_last)
            # Add the channel axis so masks line up with the model output.
            masks = masks.to(device).unsqueeze(1)

            with autocast('cuda', dtype=torch.bfloat16):
                batch_loss = criterion(model(images), masks)

            batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(val_loader)

print("Training functions defined.")

In [None]:
# Cell 13: Main Training Pipeline
def main_training():
    """Train HighPrecisionOCRModel on the receipt train split.

    Each epoch runs one training pass and one validation-loss pass.
    Precision/recall/F1/IoU are computed every 5 epochs and on the final
    configured epoch.

    Side effects:
        - Writes 'best_loss_model.pth' whenever validation loss improves.
        - Writes 'best_f1_model.pth' whenever an evaluated F1 exceeds the
          best F1 seen so far.

    Training stops early after `patience` consecutive epochs without a
    validation-loss improvement.

    Returns:
        The model with the weights of the last epoch that ran (the best
        checkpoints are not reloaded).
    """
    print("="*60)
    print("High Precision OCR Training")
    print("="*60)
    
    # Clear memory
    torch.cuda.empty_cache()
    gc.collect()
    
    # Prepare datasets
    # NOTE(review): cfg.USE_PSEUDO_LABELS is not consulted here; only the
    # labelled train/val splits are used.
    print("\n[1/5] Preparing datasets...")
    train_transform = get_train_transform(cfg.RESIZE_TARGET)
    val_transform = get_val_transform(cfg.RESIZE_TARGET)
    
    train_ds = ReceiptDataset(TRAIN_IMG_DIR, TRAIN_JSON, transform=train_transform)
    val_ds = ReceiptDataset(VAL_IMG_DIR, VAL_JSON, transform=val_transform)
    
    print(f"  Train: {len(train_ds)} images")
    print(f"  Val: {len(val_ds)} images")
    
    # drop_last=True on the training loader discards a final incomplete batch.
    train_loader = DataLoader(
        train_ds, batch_size=cfg.BATCH_SIZE, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=cfg.BATCH_SIZE, shuffle=False,
        num_workers=4, pin_memory=True
    )
    
    # Initialize model
    print("\n[2/5] Initializing model...")
    model = HighPrecisionOCRModel(
        encoder_name=cfg.ENCODER,
        encoder_weights=cfg.ENCODER_WEIGHTS
    ).to(DEVICE)
    # The training/validation loops feed channels_last image tensors.
    model = model.to(memory_format=torch.channels_last)
    
    total_params = sum(p.numel() for p in model.parameters())
    print(f"  Model: {cfg.ENCODER}")
    print(f"  Parameters: {total_params / 1e6:.1f}M")
    
    # Loss and optimizer
    print("\n[3/5] Setting up training...")
    # The boundary weight is fixed at 0.1 here rather than taken from cfg.
    criterion = CombinedLoss(
        dice_w=cfg.DICE_WEIGHT,
        bce_w=cfg.BCE_WEIGHT,
        focal_w=cfg.FOCAL_WEIGHT,
        boundary_w=0.1
    ).to(DEVICE)
    
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.LEARNING_RATE,
        weight_decay=0.01
    )
    
    # Scheduler: Warmup + Cosine
    # Linear warmup from 10% of the base LR over WARMUP_EPOCHS, then cosine
    # annealing down to eta_min over the remaining epochs. The scheduler is
    # stepped once per epoch (see the end of the loop body).
    warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, total_iters=cfg.WARMUP_EPOCHS
    )
    cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.EPOCHS - cfg.WARMUP_EPOCHS, eta_min=1e-7
    )
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[cfg.WARMUP_EPOCHS]
    )
    
    # NOTE(review): the loops autocast to bfloat16; loss scaling is usually
    # associated with float16 -- confirm the scaler is needed here.
    scaler = GradScaler('cuda')
    
    # Training loop
    print("\n[4/5] Training...")
    best_loss = float('inf')
    best_f1 = 0
    patience = 10  # epochs without val-loss improvement before stopping
    patience_counter = 0
    
    for epoch in range(1, cfg.EPOCHS + 1):
        print(f"\n--- Epoch {epoch}/{cfg.EPOCHS} ---")
        print(f"LR: {optimizer.param_groups[0]['lr']:.2e}")
        
        # Train
        train_loss = train_one_epoch(
            model, train_loader, criterion, optimizer, scaler,
            DEVICE, cfg.ACCUMULATION_STEPS
        )
        
        # Validate
        val_loss = validate(model, val_loader, criterion, DEVICE)
        
        # Evaluate metrics every 5 epochs
        # (and on the last configured epoch); this is a second pass over the
        # validation loader on top of validate().
        if epoch % 5 == 0 or epoch == cfg.EPOCHS:
            metrics = evaluate_model(model, val_loader, DEVICE, cfg.THRESHOLD)
            print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
            print(f"Precision: {metrics['precision']:.4f} | Recall: {metrics['recall']:.4f} | F1: {metrics['f1']:.4f} | IoU: {metrics['iou']:.4f}")
            
            # Strict '>' against an initial 0: no F1 checkpoint is written
            # unless some evaluated F1 is positive.
            if metrics['f1'] > best_f1:
                best_f1 = metrics['f1']
                torch.save(model.state_dict(), 'best_f1_model.pth')
                print("New best F1 model saved!")
        else:
            print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        
        # Save best loss model
        # Early-stopping patience is driven by validation loss only, not F1.
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), 'best_loss_model.pth')
            print("Best loss model saved!")
            patience_counter = 0
        else:
            patience_counter += 1
        
        scheduler.step()
        
        # Early stopping
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break
        
        # Memory cleanup
        torch.cuda.empty_cache()
    
    print("\n[5/5] Training completed!")
    print(f"Best Val Loss: {best_loss:.4f}")
    print(f"Best F1 Score: {best_f1:.4f}")
    
    return model

# Run training
model = main_training()

In [None]:
# Cell 14: Inference with Multi-scale TTA
def run_inference(model_path='best_f1_model.pth', output_csv='submission_high_precision.csv'):
    """Predict text polygons for the test split and write a submission CSV.

    Args:
        model_path: Path to a ``state_dict`` checkpoint of
            ``HighPrecisionOCRModel``.
        output_csv: Destination path of the submission file.

    Returns:
        The submission ``DataFrame`` (the sample submission with its
        ``'polygons'`` column filled in; images without a prediction get "").
    """
    print("="*60)
    print("High Precision Inference (Multi-scale TTA)")
    print("="*60)

    # Load model
    print("\n[1/3] Loading model...")
    model = HighPrecisionOCRModel(
        encoder_name=cfg.ENCODER,
        encoder_weights=None
    ).to(DEVICE)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host;
    # weights_only restricts unpickling to tensor data, which is all a
    # state_dict contains.
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    model = model.to(memory_format=torch.channels_last)
    print(f"  Loaded: {model_path}")

    # Prepare test data
    print("\n[2/3] Running inference...")
    test_transform = get_val_transform(cfg.RESIZE_TARGET)

    with open(TEST_JSON, 'r', encoding='utf-8') as f:
        test_data = json.load(f)['images']
    test_images = list(test_data)

    preds = {}

    for img_name in tqdm(test_images, desc="Inference"):
        img_path = os.path.join(TEST_IMG_DIR, img_name)
        image = cv2.imread(img_path)
        if image is None:
            # Unreadable image: submit an empty prediction for it.
            preds[img_name] = ""
            continue

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = image.shape[:2]

        # Transform
        img_tensor = test_transform(image=image)['image']

        # Multi-scale TTA; the probability map comes back at the original
        # image resolution so polygon coordinates are in original pixels.
        mask = multi_scale_tta_inference(
            model, img_tensor, DEVICE,
            scales=cfg.TTA_SCALES,
            original_size=(orig_h, orig_w)
        )

        # Post-processing
        polygons = mask_to_polygons_advanced(
            mask > cfg.THRESHOLD,
            min_area=cfg.MIN_AREA,
            epsilon_factor=0.003
        )

        preds[img_name] = polygons_to_string(polygons)

    # Save submission
    print("\n[3/3] Saving submission...")
    sample_df = pd.read_csv(SAMPLE_SUB)
    sample_df['polygons'] = sample_df['filename'].map(preds).fillna("")
    sample_df.to_csv(output_csv, index=False)

    # Statistics
    polygon_counts = sample_df['polygons'].apply(lambda x: len(x.split('|')) if x else 0)
    print(f"\nSubmission saved: {output_csv}")
    print(f"  Total images: {len(sample_df)}")
    print(f"  Avg polygons/image: {polygon_counts.mean():.1f}")
    print(f"  Max polygons: {polygon_counts.max()}")
    print(f"  Min polygons: {polygon_counts.min()}")

    return sample_df

# Run inference
submission_df = run_inference()

In [None]:
# Cell 15: Optional - Threshold Optimization
def optimize_threshold(model, val_loader, device, thresholds=(0.3, 0.4, 0.5, 0.6, 0.7)):
    """Pick the binarization threshold with the best validation F1.

    Args:
        model: Trained model passed through to ``evaluate_model``.
        val_loader: Validation data loader.
        device: Device used for evaluation.
        thresholds: Candidate thresholds, tried in order.

    Returns:
        The threshold with the highest F1 (the first one on ties). Falls back
        to 0.5 when no candidate reaches an F1 above 0 or when ``thresholds``
        is empty.
    """
    print("Optimizing threshold...")

    best_thresh = 0.5
    best_f1 = 0

    for thresh in thresholds:
        metrics = evaluate_model(model, val_loader, device, threshold=thresh)
        print(f"  Threshold {thresh:.1f}: P={metrics['precision']:.4f}, R={metrics['recall']:.4f}, F1={metrics['f1']:.4f}")

        if metrics['f1'] > best_f1:
            best_f1 = metrics['f1']
            best_thresh = thresh

    print(f"\nOptimal threshold: {best_thresh} (F1={best_f1:.4f})")
    return best_thresh

# Uncomment to run threshold optimization
# val_transform = get_val_transform(cfg.RESIZE_TARGET)
# val_ds = ReceiptDataset(VAL_IMG_DIR, VAL_JSON, transform=val_transform)
# val_loader = DataLoader(val_ds, batch_size=cfg.BATCH_SIZE, shuffle=False, num_workers=4)
# optimal_thresh = optimize_threshold(model, val_loader, DEVICE)

In [None]:
# Cell 16: Final Summary
# Print a human-readable recap of the configuration and generated files.
# Only cfg values are interpolated; nothing is computed or measured here, so
# the "Expected Improvements" section is descriptive text, not a result.
print("="*60)
print("HIGH PRECISION OCR - SUMMARY")
print("="*60)
print(f"""
Key Improvements Applied:
-------------------------
1. Model: UNet++ with {cfg.ENCODER} backbone
2. Resolution: {cfg.RESIZE_TARGET}x{cfg.RESIZE_TARGET} (high res for small text)
3. Loss: Dice({cfg.DICE_WEIGHT}) + BCE({cfg.BCE_WEIGHT}) + Focal({cfg.FOCAL_WEIGHT}) + Boundary(0.1)
4. Augmentation: Perspective, CLAHE, Blur, Noise
5. TTA: Multi-scale ({cfg.TTA_SCALES}) + Horizontal Flip
6. Post-processing: Morphological ops + Adaptive polygon extraction
7. Training: Warmup + Cosine scheduler, gradient clipping

Expected Improvements:
----------------------
- Higher precision through Focal Loss (reduces FP)
- Better small text detection with high resolution
- Sharper boundaries with Boundary Loss
- Robust predictions with Multi-scale TTA
- Cleaner polygons with morphological post-processing

Files Generated:
----------------
- best_loss_model.pth: Best validation loss model
- best_f1_model.pth: Best F1 score model
- submission_high_precision.csv: Final submission
""")
print("Good luck with your submission!")