# 🎴 Card Recognition Training

**Train on Colab → Deploy on Jetson Nano**

Features: MobileNetV3-Small backbone, HSV color-histogram branch, CosFace loss, on-the-fly augmentation

---

## 1️⃣ Setup

In [None]:
# Check GPU
!nvidia-smi

import torch
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Clone from GitHub (skips if already done)
GITHUB_REPO = "Krishan552Patel/Card-recognition-fab"

import os
os.chdir('/content')

WORK_DIR = "/content/card_recognition"

if os.path.exists(WORK_DIR) and os.path.exists(f"{WORK_DIR}/.git"):
    print("‚úì Repo already cloned, pulling latest...")
    os.chdir(WORK_DIR)
    !git pull
else:
    if os.path.exists(WORK_DIR):
        !rm -rf {WORK_DIR}
    !git clone https://github.com/{GITHUB_REPO}.git {WORK_DIR}
    os.chdir(WORK_DIR)
    print("‚úì Cloned successfully")

print(f"Working directory: {os.getcwd()}")

In [None]:
# Install dependencies (skips if already installed)
try:
    import timm, albumentations
    print("‚úì Dependencies already installed")
except ImportError:
    print("Installing dependencies...")
    !pip install -q timm albumentations opencv-python-headless tqdm tensorboard imagehash
    print("‚úì Installed")

## 2️⃣ Load Data from Google Drive

In [None]:
# Mount Google Drive so the /content/drive/... paths used by this notebook resolve
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

In [None]:
# Extract ZIP (skips if already done)
ZIP_PATH = "/content/drive/MyDrive/CardData/card_images.zip"
IMAGE_DIR = "/content/card_images"
EXTRACTED_MARKER = f"{IMAGE_DIR}/.extracted"

import shutil
import zipfile
from pathlib import Path

# Count the same files CardDatasetWithRotation loads: top-level files whose
# lower-cased suffix is one of these (a case-sensitive '*.jpg'/'*.png' glob
# misses '.JPG', '.jpeg' and '.webp' images).
IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png', '.webp'}


def _list_images(folder):
    """Return the sorted top-level image files in *folder* (non-recursive)."""
    return sorted(p for p in Path(folder).iterdir()
                  if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES)


if os.path.exists(EXTRACTED_MARKER):
    images = _list_images(IMAGE_DIR)
    print(f"‚úì Already extracted: {len(images):,} images")
elif os.path.exists(ZIP_PATH):
    print(f"Extracting {ZIP_PATH}...")
    # Start from an empty folder so files from an older archive don't linger.
    if os.path.exists(IMAGE_DIR):
        shutil.rmtree(IMAGE_DIR)
    os.makedirs(IMAGE_DIR, exist_ok=True)

    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall(IMAGE_DIR)

    # Marker is written only after a complete extraction, so an interrupted
    # run re-extracts next time.
    Path(EXTRACTED_MARKER).touch()

    images = _list_images(IMAGE_DIR)
    print(f"‚úì Extracted {len(images):,} images")
    if not images:
        # The dataset reads only the top level of IMAGE_DIR, so an archive
        # that wraps the images in a folder yields zero training samples.
        print(f"WARNING: no images directly inside {IMAGE_DIR}; if the ZIP "
              f"contains a folder, move its images up one level.")
else:
    print(f"‚ùå ZIP not found: {ZIP_PATH}")

In [None]:
# Validate images and save corrupted list (skips if already done)
from PIL import Image
from tqdm.notebook import tqdm
import json
from datetime import datetime

VALIDATED_MARKER = f"{IMAGE_DIR}/.validated"
CORRUPTED_LOG = "/content/drive/MyDrive/CardData/corrupted_images.json"
# Validate exactly what CardDatasetWithRotation loads: top-level files whose
# lower-cased suffix is one of these. Case-sensitive '*.jpg'-style globs skip
# '.JPG' and '.webp' files, letting corrupt ones through to training.
IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png', '.webp'}

if os.path.exists(VALIDATED_MARKER):
    print("‚úì Images already validated")
    if os.path.exists(CORRUPTED_LOG):
        with open(CORRUPTED_LOG, 'r') as f:
            data = json.load(f)
        print(f"  Previous corrupted files: {len(data.get('corrupted', []))}")
else:
    print("Validating images...")
    image_dir = Path(IMAGE_DIR)
    all_images = sorted(p for p in image_dir.iterdir()
                        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES)

    valid_count = 0
    corrupted_files = []

    for img_path in tqdm(all_images, desc="Checking"):
        try:
            # PIL requires reopening after verify(); load() then decodes the
            # pixel data, which verify() does not.
            with Image.open(img_path) as img:
                img.verify()
            with Image.open(img_path) as img:
                img.load()
            valid_count += 1
        except Exception as e:
            corrupted_files.append({
                'filename': img_path.name,
                'error': str(e)
            })
            print(f"  ‚ö†Ô∏è Corrupted: {img_path.name}")
            img_path.unlink()  # Remove corrupted file so the dataset never sees it

    # Save corrupted list to Google Drive
    log_data = {
        'validated_at': datetime.now().isoformat(),
        'total_checked': len(all_images),
        'valid_count': valid_count,
        'corrupted': corrupted_files
    }

    os.makedirs(os.path.dirname(CORRUPTED_LOG), exist_ok=True)
    with open(CORRUPTED_LOG, 'w') as f:
        json.dump(log_data, f, indent=2)

    # Create marker
    Path(VALIDATED_MARKER).touch()

    print(f"\n‚úì Valid: {valid_count:,}")
    if corrupted_files:
        print(f"‚ö†Ô∏è Corrupted: {len(corrupted_files)} (list saved to {CORRUPTED_LOG})")

## 3️⃣ Model Architecture

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np
import cv2

class GeM(nn.Module):
    """Generalised-mean (GeM) pooling with a learnable exponent ``p``.

    Computes ``mean(clamp(x, eps) ** p) ** (1 / p)`` over the spatial
    dimensions and flattens everything after the first dimension, so a
    (B, C, H, W) feature map becomes a (B, C) tensor.

    Args:
        p: initial value of the learnable exponent.
        eps: lower clamp applied before the power.
    """

    def __init__(self, p=3.0, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(p * torch.ones(1))
        self.eps = eps

    def forward(self, x):
        powered = torch.clamp(x, min=self.eps).pow(self.p)
        pooled = F.adaptive_avg_pool2d(powered, 1)
        return pooled.pow(1.0 / self.p).flatten(1)

class ColorHistogramBranch(nn.Module):
    """Embed a per-image HSV colour histogram with a small MLP.

    ``forward`` takes a batch of ImageNet-normalised RGB images, undoes the
    normalisation, truncates to 8-bit levels, converts to HSV, builds a
    ``bins``-bin histogram per channel, normalises the concatenated
    histogram to sum to 1 and projects it to ``output_dim`` features.

    The histogram is computed with tensor ops only. A NumPy/OpenCV version
    leaves the traced graph: ``torch.onnx.export`` would record its output as
    a constant taken from the export's dummy input, so the exported model
    would ignore the colour of real inputs. Staying in torch also keeps the
    work on the input's device.

    HSV follows OpenCV's 8-bit convention (H in [0, 180), S and V in
    [0, 255]); bins match ``np.histogram`` with ranges (0, 180) for H and
    (0, 256) for S and V. OpenCV uses integer arithmetic, so an individual
    pixel may differ by one level from ``cv2.cvtColor``.

    Args:
        bins: number of histogram bins per HSV channel.
        output_dim: size of the returned feature vector.
    """

    def __init__(self, bins=32, output_dim=64):
        super().__init__()
        self.bins = bins
        self.fc = nn.Sequential(
            nn.Linear(bins * 3, 128), nn.ReLU(), nn.Dropout(0.3), nn.Linear(128, output_dim)
        )
        # ImageNet statistics of the Normalize transform, needed to undo it.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    @staticmethod
    def _rgb_to_hsv(rgb):
        """Convert a (B, 3, H, W) RGB batch in [0, 255] to 8-bit-style HSV planes.

        Returns three (B, H, W) float tensors: H in [0, 180), S and V in
        [0, 255]. Channel ties resolve in OpenCV's order (R, then G, then B).
        """
        r, g, b = rgb[:, 0], rgb[:, 1], rgb[:, 2]
        v = torch.max(torch.max(r, g), b)
        delta = v - torch.min(torch.min(r, g), b)
        # delta <= v, and delta == 0 means r == g == b, so clamping the
        # divisors to 1 only touches cases whose numerator is already 0.
        s = torch.round(255.0 * delta / v.clamp(min=1.0))
        d = delta.clamp(min=1.0)
        h = torch.where(
            v == r, 30.0 * (g - b) / d,
            torch.where(v == g, 60.0 + 30.0 * (b - r) / d, 120.0 + 30.0 * (r - g) / d),
        )
        # Round before wrapping negatives, as OpenCV does, so H stays < 180.
        h = torch.round(h)
        h = torch.where(h < 0, h + 180.0, h)
        return h, s, v

    def _channel_histogram(self, plane, upper):
        """Count a (B, H, W) plane into ``self.bins`` equal-width bins over [0, upper).

        Values equal to ``upper`` fall into the last bin, like ``np.histogram``.
        Returns a (B, bins) float tensor of counts.
        """
        flat = plane.flatten(1)
        # Multiply before dividing so integer-valued inputs land exactly on bin edges.
        idx = torch.floor(flat * self.bins / upper).clamp(0, self.bins - 1)
        # One comparison per bin keeps the export to basic ONNX ops; additive
        # ScatterElements needs opset 16, while this notebook exports opset 11.
        return torch.stack([(idx == k).float().sum(dim=1) for k in range(self.bins)], dim=1)

    def forward(self, x):
        # Undo ImageNet normalisation and truncate to integer 8-bit levels
        # (equivalent to a uint8 cast of the clamped values).
        rgb = torch.floor(((x.float() * self.std + self.mean) * 255).clamp(0, 255))
        h, s, v = self._rgb_to_hsv(rgb)
        hist = torch.cat([
            self._channel_histogram(h, 180.0),
            self._channel_histogram(s, 256.0),
            self._channel_histogram(v, 256.0),
        ], dim=1)
        hist = hist / (hist.sum(dim=1, keepdim=True) + 1e-8)
        return self.fc(hist)

class CardEmbeddingNetV2(nn.Module):
    """Card embedding network: MobileNetV3-Small + GeM fused with a colour branch.

    ``forward`` maps a normalised image batch to L2-normalised vectors of
    size ``embedding_dim``: backbone feature map -> GeM pooling, concatenated
    with ColorHistogramBranch features -> Linear -> BatchNorm1d -> Dropout ->
    L2 normalisation.

    Args:
        embedding_dim: size of the output embedding.
        color_dim: size of the colour-branch feature vector.
        pretrained: passed to ``timm.create_model`` for the backbone.
    """

    def __init__(self, embedding_dim=512, color_dim=64, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model('mobilenetv3_small_100', pretrained=pretrained,
                                          num_classes=0, global_pool='')
        self.num_features = self._probe_num_features(self.backbone)
        self.gem = GeM(p=3.0)
        self.color_branch = ColorHistogramBranch(bins=32, output_dim=color_dim)
        self.fc = nn.Linear(self.num_features + color_dim, embedding_dim)
        self.bn = nn.BatchNorm1d(embedding_dim)
        self.dropout = nn.Dropout(0.5)

    @staticmethod
    def _probe_num_features(backbone, size=224):
        """Return the channel count of ``backbone``'s output for a random input.

        The probe runs in eval mode. A train-mode forward, even under
        ``no_grad``, updates every BatchNorm running mean/var with the random
        input's statistics and so corrupts the pretrained weights. The module's
        previous train/eval mode is restored afterwards.
        """
        was_training = backbone.training
        backbone.eval()
        try:
            with torch.no_grad():
                return backbone(torch.randn(1, 3, size, size)).shape[1]
        finally:
            backbone.train(was_training)

    def forward(self, x):
        visual = self.gem(self.backbone(x))
        color = self.color_branch(x)
        fused = torch.cat([visual, color], dim=1)
        embedding = self.dropout(self.bn(self.fc(fused)))
        return F.normalize(embedding, p=2, dim=1)

# Build a default model once to sanity-check construction and report its size.
model = CardEmbeddingNetV2()
param_count = sum(t.numel() for t in model.parameters())
print(f"‚úì Model: {param_count:,} params")

## 4️⃣ Loss & Dataset

In [None]:
class CosFaceLoss(nn.Module):
    """CosFace (large-margin cosine) classification loss.

    Logits are ``scale * (cos_j - margin * [j == label])``, where ``cos_j``
    is the product of each embedding with the L2-normalised class weight
    row ``j``. The loss is the cross-entropy over these logits. Embeddings
    are used as given, without re-normalisation.

    Args:
        num_classes: number of classes (rows of the learnable weight).
        embedding_dim: embedding size (columns of the weight).
        scale: logit scale ``s``.
        margin: cosine margin ``m`` subtracted from the target-class logit.
    """

    def __init__(self, num_classes, embedding_dim, scale=30.0, margin=0.35):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, embedding_dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, embeddings, labels):
        class_weights = F.normalize(self.weight, p=2, dim=1)
        cosine = F.linear(embeddings, class_weights)
        target_mask = F.one_hot(labels, num_classes=cosine.size(1)).to(cosine.dtype)
        logits = (cosine - target_mask * self.margin) * self.scale
        return F.cross_entropy(logits, labels)

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import random

def get_train_transforms(size=224):
    """Return the training augmentation pipeline for ``size`` x ``size`` inputs.

    Steps: resize, mild geometric jitter (perspective, small affine),
    occasional Gaussian or motion blur, photometric jitter and noise, then
    ImageNet normalisation and conversion to a CHW tensor.

    NOTE(review): ``GaussNoise(var_limit=...)`` is the albumentations 1.x
    argument; newer releases use ``std_range`` instead. Confirm the installed
    version honours ``var_limit`` rather than silently using its defaults.
    """
    geometric = [
        A.Resize(size, size),
        A.Perspective(scale=(0.02, 0.05), p=0.3),
        A.Affine(scale=(0.97, 1.03), rotate=(-2, 2), p=0.3),
    ]
    blur = A.OneOf([A.GaussianBlur(blur_limit=(3, 5)), A.MotionBlur(blur_limit=(3, 5))], p=0.2)
    photometric = [
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.4),
        A.HueSaturationValue(hue_shift_limit=3, sat_shift_limit=10, val_shift_limit=10, p=0.2),
        A.GaussNoise(var_limit=(5, 20), p=0.2),
    ]
    to_tensor = [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(geometric + [blur] + photometric + to_tensor)

def get_val_transforms(size=224):
    """Return the deterministic eval pipeline: resize, ImageNet-normalise, to tensor."""
    steps = [A.Resize(size, size)]
    steps.append(A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    steps.append(ToTensorV2())
    return A.Compose(steps)

class CardDatasetWithRotation(Dataset):
    """One sample per (card image, rotation) pair; the label is the card's index.

    Cards are the top-level files of ``image_dir`` whose lower-cased suffix
    is .jpg/.jpeg/.png/.webp, sorted by path, so label ``i`` is ``images[i]``.

    Args:
        image_dir: folder containing one image per card.
        transform: callable invoked as ``transform(image=array)['image']``;
            with ``None`` the raw HxWx3 uint8 RGB array is returned.
        rotations: clockwise rotations in degrees applied to every card;
            each must be one of 0, 90, 180, 270.

    Raises:
        ValueError: if ``rotations`` contains an unsupported angle.
    """

    _SUPPORTED_ROTATIONS = (0, 90, 180, 270)
    # A sample that fails to load is replaced by a random one; give up after
    # this many consecutive failures instead of recursing without bound.
    _MAX_LOAD_ATTEMPTS = 10

    def __init__(self, image_dir, transform=None, rotations=(0, 90, 180, 270)):
        # Tuple default avoids a shared mutable default; store a list as before.
        rotations = list(rotations)
        unsupported = [r for r in rotations if r not in self._SUPPORTED_ROTATIONS]
        if unsupported:
            raise ValueError(f"Unsupported rotation(s) {unsupported}; "
                             f"expected values from {self._SUPPORTED_ROTATIONS}")
        self.image_dir = Path(image_dir)
        self.transform = transform
        self.rotations = rotations
        self.images = sorted(f for f in self.image_dir.iterdir()
                             if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.webp'])
        self.num_cards = len(self.images)
        self.samples = [(i, r) for i in range(self.num_cards) for r in rotations]
        print(f"Dataset: {self.num_cards} cards √ó {len(rotations)} rot = {len(self.samples)} samples")

    def __len__(self): return len(self.samples)

    def _load_sample(self, idx):
        """Read, rotate and transform sample ``idx``; return ``(image, card_index)``."""
        img_idx, rotation = self.samples[idx]
        with Image.open(self.images[img_idx]) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        if rotation == 90:
            img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        elif rotation == 180:
            img = cv2.rotate(img, cv2.ROTATE_180)
        elif rotation == 270:
            img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
        if self.transform:
            img = self.transform(image=img)['image']
        return img, img_idx

    def __getitem__(self, idx):
        last_error = None
        for _ in range(self._MAX_LOAD_ATTEMPTS):
            try:
                return self._load_sample(idx)
            except Exception as err:  # unreadable/corrupt file: substitute a random sample
                last_error = err
                idx = random.randint(0, len(self.samples) - 1)
        raise RuntimeError(
            f"Could not load a sample after {self._MAX_LOAD_ATTEMPTS} attempts"
        ) from last_error

    def get_num_classes(self): return self.num_cards

def create_dataloaders(image_dir, batch_size=64, val_split=0.15, seed=0):
    """Build train/validation loaders over the card images in ``image_dir``.

    The split is by card, not by sample, so every rotation of a card lands
    on the same side. Training uses all rotations with augmentation.
    Validation uses only the unrotated image with the deterministic
    transform.

    Args:
        image_dir: folder of card images (one image per card/class).
        batch_size: batch size for both loaders; training drops the last
            incomplete batch.
        val_split: fraction of cards held out for validation.
        seed: seed for the card split. Repeated calls (smoke test, training
            cell, a resumed run) then agree on which cards are held out.
            Pass ``None`` to draw from NumPy's global RNG instead.

    Returns:
        ``(train_loader, val_loader, num_classes, train_ds)``. ``num_classes``
        counts every card, including held-out ones.

    NOTE(review): held-out cards are classes the CosFace head never sees a
    positive example for, so val loss is only a loose proxy for retrieval
    quality. Consider validating on augmented views of training cards.
    """
    train_ds = CardDatasetWithRotation(image_dir, get_train_transforms())
    val_ds = CardDatasetWithRotation(image_dir, get_val_transforms(), rotations=[0])
    rng = np.random if seed is None else np.random.RandomState(seed)
    indices = rng.permutation(train_ds.num_cards)
    split = int((1 - val_split) * train_ds.num_cards)
    train_cards = set(indices[:split].tolist())
    val_cards = set(indices[split:].tolist())
    train_samples = [i for i, (card, _) in enumerate(train_ds.samples) if card in train_cards]
    val_samples = [i for i, (card, _) in enumerate(val_ds.samples) if card in val_cards]
    # Pinned host memory is only useful when batches are copied to a GPU.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(torch.utils.data.Subset(train_ds, train_samples),
                              batch_size=batch_size, shuffle=True, num_workers=0,
                              pin_memory=pin, drop_last=True)
    val_loader = DataLoader(torch.utils.data.Subset(val_ds, val_samples),
                            batch_size=batch_size, shuffle=False, num_workers=0,
                            pin_memory=pin)
    print(f"Train: {len(train_samples)} | Val: {len(val_samples)}")
    return train_loader, val_loader, train_ds.get_num_classes(), train_ds

# Smoke-test the data pipeline with a tiny batch size.
loaders = create_dataloaders(IMAGE_DIR, batch_size=4)
train_loader, val_loader, num_classes, train_ds = loaders
print(f"‚úì Classes: {num_classes}")

## 5️⃣ Training

In [None]:
# Training hyper-parameters.
CONFIG = dict(
    epochs=100,
    batch_size=64,
    learning_rate=1e-3,
    weight_decay=1e-4,
    embedding_dim=512,
    patience=15,        # epochs without val-loss improvement before early stopping
    unfreeze_epoch=6,   # epoch at which the backbone starts training
)
print("Config:", CONFIG)

In [None]:
from tqdm.notebook import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_DIR = '/content/checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Check if we can resume
RESUME_PATH = f"{CHECKPOINT_DIR}/best_model.pth"
start_epoch = 1

train_loader, val_loader, num_classes, train_ds = create_dataloaders(IMAGE_DIR, CONFIG['batch_size'])
model = CardEmbeddingNetV2(embedding_dim=CONFIG['embedding_dim']).to(device)
# Created before resuming so its learnable class centres can be restored too.
criterion = CosFaceLoss(num_classes, CONFIG['embedding_dim']).to(device)

if os.path.exists(RESUME_PATH):
    print(f"Found existing checkpoint, loading...")
    # map_location lets a GPU-saved checkpoint load on a CPU runtime and vice versa.
    ckpt = torch.load(RESUME_PATH, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    # Older checkpoints lack the CosFace weights. A class-count mismatch
    # means the dataset changed, so keep freshly initialised centres then.
    if 'criterion_state_dict' in ckpt and ckpt.get('num_classes') == num_classes:
        criterion.load_state_dict(ckpt['criterion_state_dict'])
    start_epoch = ckpt.get('epoch', 0) + 1
    print(f"‚úì Resuming from epoch {start_epoch}")

# Keep the backbone frozen until CONFIG['unfreeze_epoch'], also when resuming
# from a checkpoint saved before that epoch.
backbone_trainable = start_epoch >= CONFIG['unfreeze_epoch']
for p in model.backbone.parameters():
    p.requires_grad = backbone_trainable

print(f"Device: {device}")
# The CosFace class centres are learnable and must be optimised with the
# model; otherwise they stay at their random initialisation.
head_params = [p for name, p in model.named_parameters() if not name.startswith('backbone.')]
param_groups = [{'params': head_params}, {'params': list(criterion.parameters())}]
if backbone_trainable:
    # Backbone at 1/10 of the head LR, matching the training loop's unfreeze step.
    param_groups.insert(0, {'params': list(model.backbone.parameters()),
                            'lr': CONFIG['learning_rate'] / 10})
optimizer = torch.optim.AdamW(param_groups, lr=CONFIG['learning_rate'],
                              weight_decay=CONFIG['weight_decay'])
# Anneal over the epochs that will actually run (all of them on a fresh start).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=max(1, CONFIG['epochs'] - start_epoch + 1))
# Loss scaling only applies to CUDA mixed precision; disabled it is a pass-through.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == 'cuda')

In [None]:
# Training loop (can resume)
best_loss = float('inf')
if start_epoch > 1 and os.path.exists(RESUME_PATH):
    # When resuming, only overwrite the saved best checkpoint if this run
    # beats it, instead of replacing it after the first (possibly worse) epoch.
    best_loss = torch.load(RESUME_PATH, map_location='cpu').get('val_loss', best_loss)
patience_counter = 0
history = {'train': [], 'val': []}
use_amp = device.type == 'cuda'  # autocast to reduced precision only on CUDA

for epoch in range(start_epoch, CONFIG['epochs'] + 1):
    if epoch == CONFIG['unfreeze_epoch']:
        print(f"\nüîì Unfreezing backbone...")
        for p in model.backbone.parameters():
            p.requires_grad = True
        # The backbone fine-tunes at 1/10 of the head LR. The CosFace class
        # centres are learnable too and must stay in the optimizer.
        optimizer = torch.optim.AdamW([
            {'params': model.backbone.parameters(), 'lr': CONFIG['learning_rate'] / 10},
            {'params': model.gem.parameters()},
            {'params': model.color_branch.parameters()},
            {'params': model.fc.parameters()},
            {'params': model.bn.parameters()},
            {'params': criterion.parameters()},
        ], lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])
        # The previous scheduler is bound to the discarded optimizer; anneal
        # the new one over the remaining epochs instead.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max(1, CONFIG['epochs'] - epoch + 1))

    model.train()
    train_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}", leave=False):
        images, labels = images.to(device), labels.to(device)
        with torch.amp.autocast(device.type, enabled=use_amp):
            loss = criterion(model(images), labels)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()
    train_loss /= max(len(train_loader), 1)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            val_loss += criterion(model(images), labels).item()
    # With no validation batches, fall back to the training loss so that
    # checkpointing and early stopping still work (instead of dividing by zero).
    val_loss = val_loss / len(val_loader) if len(val_loader) else train_loss

    scheduler.step()
    history['train'].append(train_loss)
    history['val'].append(val_loss)
    print(f"Epoch {epoch}: Train={train_loss:.4f}, Val={val_loss:.4f}")

    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        torch.save({
            'epoch': epoch, 'model_state_dict': model.state_dict(),
            'criterion_state_dict': criterion.state_dict(),
            'val_loss': val_loss, 'num_classes': num_classes, 'config': CONFIG
        }, RESUME_PATH)
        print(f"  üíæ Saved")
    else:
        patience_counter += 1
        if patience_counter >= CONFIG['patience']:
            print(f"\n‚ö†Ô∏è Early stopping!")
            break

print(f"\n‚úì Training complete! Best loss: {best_loss:.4f}")

## 6️⃣ Export & Save

In [None]:
import matplotlib.pyplot as plt

# Plot per-epoch train/val loss, but only if the training loop recorded anything.
if history['train']:
    plt.figure(figsize=(10, 4))
    for split, label in (('train', 'Train'), ('val', 'Val')):
        plt.plot(history[split], label=label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(f"{CHECKPOINT_DIR}/training.png")
    plt.show()

In [None]:
# Export to ONNX
onnx_path = f"{CHECKPOINT_DIR}/card_recognition.onnx"

onnx_exists = os.path.exists(onnx_path)
ckpt_exists = os.path.exists(RESUME_PATH)

if onnx_exists and (not ckpt_exists or os.path.getmtime(onnx_path) >= os.path.getmtime(RESUME_PATH)):
    print("‚úì ONNX already exists")
elif not ckpt_exists:
    print(f"No checkpoint at {RESUME_PATH}; run the training cells first")
else:
    # (Re-)export when there is no ONNX yet or the checkpoint is newer than it.
    # Otherwise a resumed or retrained model would never reach the deployed file.
    ckpt = torch.load(RESUME_PATH, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    dummy = torch.randn(1, 3, 224, 224).to(device)
    # NOTE: export traces forward(); anything computed outside torch ops
    # (e.g. NumPy/OpenCV) is recorded as a constant derived from `dummy`.
    torch.onnx.export(model, dummy, onnx_path,
                      input_names=['image'], output_names=['embedding'],
                      dynamic_axes={'image': {0: 'batch'}, 'embedding': {0: 'batch'}},
                      opset_version=11)
    print(f"‚úì ONNX exported: {os.path.getsize(onnx_path)/1024/1024:.1f} MB")

In [None]:
# Save to Google Drive
import shutil
DRIVE_OUTPUT = '/content/drive/MyDrive/CardRecognition_Models'
os.makedirs(DRIVE_OUTPUT, exist_ok=True)

for f in ['best_model.pth', 'card_recognition.onnx', 'training.png']:
    src = f"{CHECKPOINT_DIR}/{f}"
    if os.path.exists(src):
        shutil.copy(src, DRIVE_OUTPUT)

print(f"‚úì Saved to: {DRIVE_OUTPUT}")
!ls -lh {DRIVE_OUTPUT}

## 7️⃣ Test Model

In [None]:
# Build reference embeddings: one embedding per card image, computed with the
# deterministic validation transform, in the order of train_ds.images.
print("Building reference embeddings...")
model.eval()
test_transform = get_val_transforms()
reference_embeddings, reference_names = [], []
skipped = []  # (filename, error message) for images that could not be read

with torch.no_grad():
    for img_path in tqdm(train_ds.images, desc="Building refs"):
        # Guard only the file read: an unreadable image is skipped and
        # reported, but transform/model errors (e.g. CUDA OOM) still surface.
        try:
            with Image.open(img_path) as pil_img:
                img = np.array(pil_img.convert('RGB'))
        except Exception as err:
            skipped.append((img_path.name, str(err)))
            continue
        img_tensor = test_transform(image=img)['image'].unsqueeze(0).to(device)
        reference_embeddings.append(model(img_tensor).cpu())
        reference_names.append(img_path.stem)

if skipped:
    print(f"  Skipped {len(skipped)} unreadable image(s); first: {skipped[0][0]} ({skipped[0][1]})")
if not reference_embeddings:
    raise RuntimeError("No reference embeddings were built; check IMAGE_DIR")

reference_embeddings = torch.cat(reference_embeddings, dim=0)
print(f"‚úì {len(reference_embeddings)} embeddings")

In [None]:
# Test on random cards.
# The reference gallery holds these same files embedded with the deterministic
# validation transform. A query through that transform is therefore identical
# to its own reference and always matches it. Querying with the random
# training augmentations makes the check reflect robustness to capture noise.
model.eval()
query_transform = get_train_transforms()
test_cards = random.sample(list(train_ds.images), min(5, len(train_ds.images)))
if not test_cards:
    raise RuntimeError("No card images available to test on")
fig, axes = plt.subplots(1, len(test_cards), figsize=(4*len(test_cards), 5))
if len(test_cards) == 1: axes = [axes]

correct = 0
for i, card_path in enumerate(test_cards):
    with Image.open(card_path) as pil_img:
        img = np.array(pil_img.convert('RGB'))
    with torch.no_grad():
        query = model(query_transform(image=img)['image'].unsqueeze(0).to(device)).cpu()
    sims = F.cosine_similarity(query, reference_embeddings)
    top_idx = sims.argmax().item()

    actual = card_path.stem
    predicted = reference_names[top_idx]
    is_correct = actual == predicted
    if is_correct: correct += 1

    axes[i].imshow(img)
    axes[i].axis('off')
    axes[i].set_title(f"{'‚úÖ' if is_correct else '‚ùå'} {sims[top_idx].item()*100:.1f}%")

plt.tight_layout()
plt.savefig(f"{CHECKPOINT_DIR}/test_results.png")
plt.show()
print(f"Accuracy: {correct}/{len(test_cards)} = {100*correct/len(test_cards):.1f}%")

## ✅ Done!

**Corrupted images list saved to:** `MyDrive/CardData/corrupted_images.json`

**Models saved to:** `MyDrive/CardRecognition_Models/`

**Deploy to Jetson:**
```bash
trtexec --onnx=card_recognition.onnx --saveEngine=card.engine --fp16
```