# 🎴 Card Recognition Training V3

**Mobile-Ready with Synthetic Backgrounds + Two-Stage Identifier**

- Synthetic background augmentation (solid colors, gradients)
- Enhanced sim-to-real transforms for phone cameras
- Two-stage: CNN→Name, pHash→Exact Printing

---

## 1️⃣ Setup

In [None]:
# Report the visible GPU (nvidia-smi) and the installed PyTorch build / CUDA availability.
!nvidia-smi
import torch
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Name of the first CUDA device; training below assumes a single GPU.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Install the third-party libraries used below (timm backbone, albumentations
# augmentations, headless OpenCV, tqdm progress bars, imagehash for pHash).
!pip install -q timm albumentations opencv-python-headless tqdm imagehash

In [None]:
# Mount Google Drive at /content/drive (source zip, card JSON and model output live there).
from google.colab import drive
drive.mount('/content/drive')

## 2️⃣ Extract Data

In [None]:
import os, zipfile, json
from pathlib import Path
from PIL import Image, ImageOps
from tqdm.notebook import tqdm
from datetime import datetime
import numpy as np
import cv2
import random

ZIP_PATH = "/content/drive/MyDrive/CardData/card_images.zip"
IMAGE_DIR = "/content/card_images"
CHECKPOINT_DIR = '/content/checkpoints'
DRIVE_OUTPUT = '/content/drive/MyDrive/CardRecognition_Models'
CARD_JSON = '/content/drive/MyDrive/CardData/card-flattened-with-phash.json'

for d in [CHECKPOINT_DIR, DRIVE_OUTPUT]:
    os.makedirs(d, exist_ok=True)

# A `.extracted` marker records a completed extraction, so re-running this cell
# is cheap. It is only written after extractall() finishes, so an interrupted
# extraction is redone from scratch next time.
if os.path.exists(f"{IMAGE_DIR}/.extracted"):
    print(f"‚úì Already extracted")
elif os.path.exists(ZIP_PATH):
    import shutil

    print("Extracting...")
    # Start from an empty directory so partial files from an earlier interrupted
    # run do not survive. Pure-Python replacement for the former `!rm -rf` magic.
    shutil.rmtree(IMAGE_DIR, ignore_errors=True)
    os.makedirs(IMAGE_DIR, exist_ok=True)
    with zipfile.ZipFile(ZIP_PATH, 'r') as z:
        z.extractall(IMAGE_DIR)
    Path(f"{IMAGE_DIR}/.extracted").touch()
    print(f"‚úì Done")
else:
    # Previously this case was silent and only surfaced later as an empty dataset.
    print(f"Warning: {ZIP_PATH} not found and {IMAGE_DIR} has not been extracted")

In [None]:
# Validate images once: delete any file Pillow cannot fully decode so the
# DataLoader never trips over it. A `.validated` marker skips this on re-runs.
if os.path.exists(f"{IMAGE_DIR}/.validated"):
    print("‚úì Already validated")
else:
    print("Validating...")
    corrupted = []
    for p in tqdm(list(Path(IMAGE_DIR).glob('*'))):
        if p.suffix.lower() not in ('.jpg', '.jpeg', '.png', '.webp'):
            continue
        try:
            # verify() checks integrity without decoding pixels and leaves the
            # image object unusable, so a second open + load() does a full decode.
            with Image.open(p) as img:
                img.verify()
            with Image.open(p) as img:
                img.load()
        except Exception:
            # Narrowed from a bare `except:`. Interrupting the cell (KeyboardInterrupt)
            # must not be mistaken for corruption and delete a good file.
            corrupted.append(p.name)
            p.unlink()
    Path(f"{IMAGE_DIR}/.validated").touch()
    print(f"‚úì Removed {len(corrupted)} corrupted")

## 3️⃣ Synthetic Background Augmentation

In [None]:
class SyntheticBackground:
    """Generate synthetic backgrounds and composite card images onto them.

    Sim-to-real augmentation: a card image is scaled, optionally perspective-
    warped, and pasted at a random position on a solid, gradient, or noisy
    canvas of size `output_size` = (height, width).
    """

    def __init__(self, output_size=(480, 640)):
        self.output_size = output_size  # (height, width)

    def solid_color(self):
        """Return a uint8 canvas filled with a single random RGB color."""
        color = tuple(random.randint(0, 255) for _ in range(3))
        return np.full((*self.output_size, 3), color, dtype=np.uint8)

    def gradient(self):
        """Return a uint8 canvas with a top-to-bottom linear blend of two random colors."""
        c1 = np.array([random.randint(0, 255) for _ in range(3)])
        c2 = np.array([random.randint(0, 255) for _ in range(3)])
        height, width = self.output_size
        # Row i uses blend factor t = i / height (vectorized; formerly a Python row loop).
        t = (np.arange(height) / height)[:, np.newaxis]          # (H, 1)
        rows = (c1 * (1 - t) + c2 * t).astype(np.uint8)          # (H, 3)
        return np.broadcast_to(rows[:, np.newaxis, :], (height, width, 3)).copy()

    def noise_pattern(self):
        """Return a uint8 canvas: random base color plus per-pixel noise in [-30, 30]."""
        base = np.random.randint(50, 200, 3)
        # randint's upper bound is exclusive: 31 makes the noise symmetric around 0.
        noise = np.random.randint(-30, 31, (*self.output_size, 3))
        return np.clip(base + noise, 0, 255).astype(np.uint8)

    def get_random_bg(self):
        """Pick a background type (solid is twice as likely) and generate it."""
        bg_type = random.choice(['solid', 'solid', 'gradient', 'noise'])
        if bg_type == 'solid': return self.solid_color()
        elif bg_type == 'gradient': return self.gradient()
        else: return self.noise_pattern()

    def composite(self, card_img, apply_perspective=True):
        """
        Composite card onto random background.
        card_img: numpy array (H, W, 3)
        Returns: numpy array (output_size[0], output_size[1], 3)
        """
        bg = self.get_random_bg()
        out_h, out_w = self.output_size
        h, w = card_img.shape[:2]

        # Random scale (40-80% of output height), aspect ratio preserved.
        scale = random.uniform(0.4, 0.8)
        new_h = max(1, int(out_h * scale))
        new_w = max(1, int(new_h * w / h))

        card_resized = cv2.resize(card_img, (new_w, new_h))
        # Coverage mask derived from geometry, not pixel colour. The previous
        # "non-black pixels" mask made dark parts of the card itself (black
        # borders, dark artwork) transparent, so they were replaced by background.
        mask = np.ones((new_h, new_w), dtype=np.float32)

        # Apply perspective warp (70% of the time) to both card and mask.
        if apply_perspective and random.random() > 0.3:
            pts1 = np.float32([[0, 0], [new_w, 0], [0, new_h], [new_w, new_h]])
            offset = int(new_w * 0.1)
            pts2 = np.float32([
                [random.randint(-offset, offset), random.randint(-offset, offset)],
                [new_w + random.randint(-offset, offset), random.randint(-offset, offset)],
                [random.randint(-offset, offset), new_h + random.randint(-offset, offset)],
                [new_w + random.randint(-offset, offset), new_h + random.randint(-offset, offset)]
            ])
            M = cv2.getPerspectiveTransform(pts1, pts2)
            card_resized = cv2.warpPerspective(card_resized, M, (new_w, new_h),
                                               borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0))
            # Linear interpolation gives the mask soft (anti-aliased) edges.
            mask = cv2.warpPerspective(mask, M, (new_w, new_h), flags=cv2.INTER_LINEAR,
                                       borderMode=cv2.BORDER_CONSTANT, borderValue=0)

        # Random position; randint(0, 0) is valid, so no special case is needed.
        y = random.randint(0, max(0, out_h - new_h))
        x = random.randint(0, max(0, out_w - new_w))

        # Clip the pasted region to the canvas (wide cards may overflow horizontally).
        y_end = min(y + new_h, out_h)
        x_end = min(x + new_w, out_w)
        card_crop = card_resized[:y_end - y, :x_end - x]
        alpha = mask[:y_end - y, :x_end - x, np.newaxis]
        region = bg[y:y_end, x:x_end]
        bg[y:y_end, x:x_end] = (card_crop * alpha + region * (1 - alpha)).astype(np.uint8)

        return bg

synth_bg = SyntheticBackground()
print("‚úì SyntheticBackground ready")

## 4️⃣ Model Architecture

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

class GeM(nn.Module):
    """Generalized-mean pooling with a learnable exponent.

    Computes (mean(clamp(x, eps) ** p)) ** (1 / p) over the spatial dims and
    returns a (batch, channels) tensor. p = 1 is average pooling; large p
    approaches max pooling.
    """

    def __init__(self, p=3.0, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.full((1,), float(p)))
        self.eps = eps

    def forward(self, x):
        powered = x.clamp(min=self.eps).pow(self.p)
        pooled = F.adaptive_avg_pool2d(powered, 1)
        return pooled.pow(1.0 / self.p).view(x.size(0), -1)

class ColorHistogramBranch(nn.Module):
    """Colour descriptor: per-image HSV histogram fed through a small MLP.

    The input is expected to be an ImageNet-normalised (B, 3, H, W) RGB batch.
    It is de-normalised back to 0..255 and converted to HSV with OpenCV one image
    at a time on the CPU. The result is `bins` counts per channel, L1-normalised
    to sum to 1. Because it goes through numpy, no gradient flows back into the
    input; only the MLP is trained.
    """

    def __init__(self, bins=32, output_dim=64):
        super().__init__()
        self.bins = bins
        self.fc = nn.Sequential(
            nn.Linear(bins * 3, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, output_dim),
        )
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def _hsv_histogram(self, rgb):
        """Return the normalised concatenated H/S/V histogram of one uint8 HxWx3 RGB image."""
        hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)
        # OpenCV's 8-bit hue channel spans 0..179; saturation/value span 0..255.
        channel_ranges = ((0, 180), (0, 256), (0, 256))
        counts = [
            np.histogram(hsv[:, :, ch], bins=self.bins, range=rng)[0]
            for ch, rng in enumerate(channel_ranges)
        ]
        hist = np.concatenate(counts).astype(np.float32)
        return hist / (hist.sum() + 1e-8)

    def forward(self, x):
        pixels = ((x * self.std + self.mean) * 255).clamp(0, 255)
        per_image = [
            self._hsv_histogram(pixels[i].permute(1, 2, 0).cpu().numpy().astype(np.uint8))
            for i in range(x.shape[0])
        ]
        batch = torch.tensor(np.stack(per_image), device=x.device, dtype=torch.float32)
        return self.fc(batch)

class CardEmbeddingNetV3(nn.Module):
    """Card embedding network: MobileNetV3-Small visual features + HSV colour branch.

    Visual features are GeM-pooled and concatenated with the colour-branch
    output. The result is projected to `embedding_dim`, batch-normalised,
    dropped out and finally L2-normalised, so the dot product of two outputs
    is their cosine similarity.
    """

    def __init__(self, embedding_dim=512, color_dim=64, pretrained=True):
        super().__init__()
        # num_classes=0 removes the classifier; global_pool='' disables timm's pooling
        # so GeM below does the pooling.
        self.backbone = timm.create_model('mobilenetv3_small_100', pretrained=pretrained, num_classes=0, global_pool='')
        # Probe the output channel count with a dummy forward. This is done in eval
        # mode: in train mode every BatchNorm layer would fold the statistics of this
        # random-noise input into its pretrained running mean/var.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            self.num_features = self.backbone(torch.randn(1, 3, 224, 224)).shape[1]
        self.backbone.train(was_training)
        self.gem = GeM()
        self.color_branch = ColorHistogramBranch(bins=32, output_dim=color_dim)
        self.fc = nn.Linear(self.num_features + color_dim, embedding_dim)
        self.bn = nn.BatchNorm1d(embedding_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return L2-normalised embeddings of shape (batch, embedding_dim)."""
        visual = self.gem(self.backbone(x))
        color = self.color_branch(x)
        return F.normalize(self.dropout(self.bn(self.fc(torch.cat([visual, color], dim=1)))), p=2, dim=1)

print("‚úì Model ready")

## 5️⃣ Enhanced Sim-to-Real Dataset

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader

def get_heavy_augmentations(size=224):
    """Aggressive sim-to-real augmentations for phone camera robustness.

    Pipeline: resize -> geometric jitter -> blur -> lighting -> sensor noise ->
    JPEG compression -> ImageNet normalisation -> CHW tensor.

    GaussNoise and ImageCompression changed their constructor arguments in
    albumentations 2.x (`var_limit` -> `std_range`, `quality_lower/upper` ->
    `quality_range`). The 1.x names are not accepted there, so the intended
    ranges would not take effect. The argument style is chosen from the
    installed version.
    """
    albu_major = int(A.__version__.split('.')[0])
    if albu_major >= 2:
        # 2.x takes the noise std as a fraction of the max pixel value (255);
        # this covers the same std span as the 1.x variance range 10..50.
        gauss_noise = A.GaussNoise(std_range=(10 ** 0.5 / 255, 50 ** 0.5 / 255))
        jpeg = A.ImageCompression(quality_range=(60, 100), p=0.3)
    else:
        gauss_noise = A.GaussNoise(var_limit=(10, 50))
        jpeg = A.ImageCompression(quality_lower=60, quality_upper=100, p=0.3)

    return A.Compose([
        A.Resize(size, size),
        # Geometric
        A.Perspective(scale=(0.02, 0.1), p=0.5),
        A.Rotate(limit=15, border_mode=cv2.BORDER_CONSTANT, p=0.5),
        A.Affine(scale=(0.9, 1.1), shear=(-5, 5), p=0.3),
        # Camera effects
        A.OneOf([
            A.MotionBlur(blur_limit=7),
            A.GaussianBlur(blur_limit=5),
            A.MedianBlur(blur_limit=5),
        ], p=0.4),
        # Lighting
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        A.RandomGamma(gamma_limit=(70, 130), p=0.3),
        A.RandomShadow(shadow_roi=(0, 0.3, 1, 1), p=0.2),
        # Noise
        A.OneOf([
            gauss_noise,
            A.ISONoise(intensity=(0.1, 0.5)),
        ], p=0.3),
        # Compression
        jpeg,
        # Normalize
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

def get_val_transforms(size=224):
    """Deterministic evaluation pipeline: resize, ImageNet-normalise, convert to CHW tensor."""
    steps = [
        A.Resize(size, size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

class SimToRealDataset(Dataset):
    """Card-image dataset that expands each image into one sample per rotation.

    Every image file is its own class: the label returned with a sample is the
    index of its source image in the sorted file list. Samples are optionally
    composited onto a synthetic background (50% of the time) and then passed
    through `transform` (an albumentations pipeline called as
    `transform(image=...)`).
    """

    # Upper bound on consecutive fallback attempts when samples fail to load.
    _MAX_LOAD_ATTEMPTS = 10

    def __init__(self, image_dir, transform=None, use_backgrounds=True, rotations=(0, 90, 180, 270)):
        self.image_dir = Path(image_dir)
        self.transform = transform
        self.use_backgrounds = use_backgrounds
        # Copy: the default used to be a shared mutable list aliased into every instance.
        self.rotations = list(rotations)
        self.synth_bg = SyntheticBackground() if use_backgrounds else None

        self.images = sorted(f for f in self.image_dir.iterdir()
                             if f.suffix.lower() in ('.jpg', '.jpeg', '.png', '.webp'))
        self.num_cards = len(self.images)
        self.samples = [(i, r) for i in range(self.num_cards) for r in self.rotations]
        print(f"Dataset: {self.num_cards} cards √ó {len(self.rotations)} rot = {len(self.samples)} samples")

    def __len__(self): return len(self.samples)

    def _load_sample(self, idx):
        """Load, rotate, optionally composite and transform sample `idx`."""
        img_idx, rotation = self.samples[idx]
        with Image.open(self.images[img_idx]) as pil_img:
            img = np.array(pil_img.convert('RGB'))

        # Apply rotation
        if rotation == 90: img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        elif rotation == 180: img = cv2.rotate(img, cv2.ROTATE_180)
        elif rotation == 270: img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)

        # 50% chance to composite onto synthetic background
        if self.use_backgrounds and random.random() > 0.5:
            img = self.synth_bg.composite(img)

        if self.transform:
            img = self.transform(image=img)['image']
        return img, img_idx

    def __getitem__(self, idx):
        """Return (image, label); on failure retry with a random other sample.

        Retries are bounded (formerly unbounded recursion behind a bare `except:`).
        NOTE(review): the fallback index is drawn from all samples of this dataset,
        which under a `Subset` split may include held-out cards.
        """
        last_error = None
        for _ in range(self._MAX_LOAD_ATTEMPTS):
            try:
                return self._load_sample(idx)
            except Exception as err:
                last_error = err
                idx = random.randint(0, len(self.samples) - 1)
        raise RuntimeError(
            f"Failed to load {self._MAX_LOAD_ATTEMPTS} consecutive samples from {self.image_dir}"
        ) from last_error

    def get_num_classes(self): return self.num_cards

print("‚úì SimToRealDataset ready")

## 6️⃣ Training Config & Setup

In [None]:
class CosFaceLoss(nn.Module):
    """Large-margin cosine loss (CosFace / AM-Softmax).

    Learns one prototype row per class. For L2-normalised embeddings the logits
    are scale * (cos(theta_j) - margin * [j == label]), followed by standard
    cross-entropy. Subtracting the margin only from the true-class cosine forces
    a gap between classes.
    """

    def __init__(self, num_classes, embedding_dim, scale=64.0, margin=0.5):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.weight = nn.Parameter(torch.empty(num_classes, embedding_dim, dtype=torch.float32))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, embeddings, labels):
        prototypes = F.normalize(self.weight, p=2, dim=1)
        cosine = F.linear(embeddings, prototypes)
        target_mask = F.one_hot(labels, num_classes=cosine.size(1)).to(cosine.dtype)
        logits = self.scale * (cosine - self.margin * target_mask)
        return F.cross_entropy(logits, labels)

# Training hyper-parameters. `learning_rate` is used for the heads; after
# `unfreeze_epoch` the backbone is trained at learning_rate / 10.
# `patience` counts epochs without val-loss improvement before early stopping.
CONFIG = dict(
    epochs=100,
    batch_size=64,
    learning_rate=3e-4,
    weight_decay=5e-4,
    embedding_dim=512,
    patience=10,
    unfreeze_epoch=15,
)
print("‚úì Config:", CONFIG)

In [None]:
def create_dataloaders(image_dir, batch_size=64, val_split=0.15, seed=0):
    """Build train/val DataLoaders with a card-level split.

    Both datasets index the same sorted file list, so an image index is a
    class label in either. A `val_split` fraction of cards is held out. Training
    uses all rotations with heavy augmentation and synthetic backgrounds;
    validation uses only the 0-degree, lightly transformed view.

    Args:
        image_dir: directory of card images.
        batch_size: batch size for both loaders.
        val_split: fraction of cards held out for validation.
        seed: seed for the card split (pass None for a fresh random split).

    Returns:
        (train_loader, val_loader, num_classes, train_ds)

    NOTE(review): the held-out cards are also CosFace classes, but their
    prototype rows never receive positive training samples. Val loss therefore
    measures something different from train loss. A retrieval metric on the
    held-out cards would be a more meaningful early-stopping signal.
    """
    train_ds = SimToRealDataset(image_dir, get_heavy_augmentations(), use_backgrounds=True)
    val_ds = SimToRealDataset(image_dir, get_val_transforms(), use_backgrounds=False, rotations=[0])

    # Seeded generator: the split was previously non-reproducible across runtime restarts.
    rng = np.random.default_rng(seed)
    indices = rng.permutation(train_ds.num_cards)
    split = int((1 - val_split) * train_ds.num_cards)
    train_idx = set(indices[:split].tolist())
    val_idx = set(indices[split:].tolist())

    train_samples = [i for i, (c, _) in enumerate(train_ds.samples) if c in train_idx]
    val_samples = [i for i, (c, _) in enumerate(val_ds.samples) if c in val_idx]

    train_loader = DataLoader(torch.utils.data.Subset(train_ds, train_samples),
                              batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
    val_loader = DataLoader(torch.utils.data.Subset(val_ds, val_samples),
                            batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    print(f"Train: {len(train_samples):,} | Val: {len(val_samples):,}")
    return train_loader, val_loader, train_ds.get_num_classes(), train_ds

train_loader, val_loader, num_classes, train_ds = create_dataloaders(IMAGE_DIR, CONFIG['batch_size'])
print(f"‚úì Classes: {num_classes:,}")

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CardEmbeddingNetV3(embedding_dim=CONFIG['embedding_dim']).to(device)

# Phase 1: freeze the pretrained backbone; only the heads are trained
# until CONFIG['unfreeze_epoch'].
for p in model.backbone.parameters():
    p.requires_grad = False
print("‚úì Backbone frozen")

criterion = CosFaceLoss(num_classes, CONFIG['embedding_dim']).to(device)
# The CosFace class prototypes are a Parameter of `criterion`, not `model`.
# They must be handed to the optimizer, otherwise they never move from
# their random initialisation.
trainable = [p for p in model.parameters() if p.requires_grad] + list(criterion.parameters())
optimizer = torch.optim.AdamW(trainable, lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['epochs'])
scaler = torch.amp.GradScaler('cuda')
print("‚úì Ready")

## 7️⃣ Training Loop

In [None]:
best_loss = float('inf')
patience_counter = 0
history = {'train': [], 'val': []}
# Path of the best-so-far checkpoint (lowest val loss); nothing here resumes from it.
RESUME_PATH = f"{CHECKPOINT_DIR}/best_model.pth"

for epoch in range(1, CONFIG['epochs'] + 1):
    if epoch == CONFIG['unfreeze_epoch']:
        print("\nüîì Unfreezing backbone...")
        for p in model.backbone.parameters(): p.requires_grad = True
        # Rebuild the optimizer: the pretrained backbone gets a 10x lower LR.
        # The CosFace prototypes on `criterion` are included so they keep training.
        optimizer = torch.optim.AdamW([
            {'params': model.backbone.parameters(), 'lr': CONFIG['learning_rate']/10},
            {'params': model.gem.parameters()},
            {'params': model.color_branch.parameters()},
            {'params': model.fc.parameters()},
            {'params': model.bn.parameters()},
            {'params': criterion.parameters()},
        ], lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['epochs']-epoch)
        # Give the newly unfrozen backbone a full patience window.
        patience_counter = 0

    # Clip exactly the parameters the optimizer updates. GradScaler.unscale_ only
    # unscales those, so clipping any other still-scaled gradient would distort the norm.
    clip_params = [p for group in optimizer.param_groups for p in group['params']]

    model.train()
    train_loss = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}", leave=False):
        images, labels = images.to(device), labels.to(device)
        with torch.amp.autocast('cuda'):
            loss = criterion(model(images), labels)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(clip_params, 5.0)
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            val_loss += criterion(model(images), labels).item()
    val_loss /= len(val_loader)

    scheduler.step()
    history['train'].append(train_loss)
    history['val'].append(val_loss)
    print(f"Epoch {epoch}: Train={train_loss:.4f}, Val={val_loss:.4f}")

    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                    'val_loss': val_loss, 'num_classes': num_classes, 'config': CONFIG}, RESUME_PATH)
        print("  üíæ Saved")
    else:
        patience_counter += 1
        # Never stop during the frozen phase: that would skip backbone fine-tuning entirely.
        if patience_counter >= CONFIG['patience'] and epoch >= CONFIG['unfreeze_epoch']:
            print("\n‚ö†Ô∏è Early stopping!")
            break

print(f"\n‚úì Done! Best: {best_loss:.4f}")

## 8️⃣ Plot, Save & Build Reference Embeddings

In [None]:
import matplotlib.pyplot as plt

# Plot per-epoch train/val loss (skipped if the training loop recorded nothing)
# and save the figure next to the checkpoint.
if history['train']:
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(history['train'], label='Train')
    ax.plot(history['val'], label='Val')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    fig.savefig(f"{CHECKPOINT_DIR}/training.png")
    plt.show()

In [None]:
import shutil
# Mirror whichever training artifacts exist from the local checkpoint dir to Drive.
artifacts = ('best_model.pth', 'training.png')
for f in artifacts:
    src = f"{CHECKPOINT_DIR}/{f}"
    if not os.path.exists(src):
        continue
    shutil.copy(src, DRIVE_OUTPUT)
    print(f"‚úì Saved {f}")

In [None]:
print("Building reference embeddings...")
# map_location lets a checkpoint saved on GPU load on whatever device is active now.
ckpt = torch.load(RESUME_PATH, map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

# One clean (un-augmented, 0-degree) embedding per image file; the file stem is
# used as the reference id.
test_transform = get_val_transforms()
reference_embeddings = []
reference_names = []
skipped = []

with torch.no_grad():
    for img_path in tqdm(train_ds.images, desc="Building refs"):
        try:
            with Image.open(img_path) as pil_img:
                img = np.array(pil_img.convert('RGB'))
        except (OSError, ValueError, Image.DecompressionBombError) as err:
            # Only unreadable images are skipped; model/CUDA errors now surface
            # instead of being swallowed by a bare `except: pass`.
            skipped.append((img_path.name, err))
            continue
        emb = model(test_transform(image=img)['image'].unsqueeze(0).to(device))
        reference_embeddings.append(emb.cpu())
        reference_names.append(img_path.stem)

if skipped:
    print(f"Warning: skipped {len(skipped)} unreadable images (first: {skipped[0][0]})")
if not reference_embeddings:
    raise RuntimeError("No reference embeddings were built; check IMAGE_DIR")

reference_embeddings = torch.cat(reference_embeddings, dim=0)
print(f"‚úì {len(reference_embeddings):,} embeddings")

## 9️⃣ Two-Stage Identifier (CNN + pHash)

In [None]:
import imagehash

# Load card metadata (a JSON list of per-printing records).
with open(CARD_JSON, 'r', encoding='utf-8') as f:
    all_cards = json.load(f)

# printing_unique_id -> full record; reference ids (image file stems) are looked up here.
card_lookup = {c['printing_unique_id']: c for c in all_cards}
print(f'‚úì Loaded {len(all_cards):,} cards')

# Build name -> printings lookup (all printings that share a card name).
name_to_printings = {}
for card in all_cards:
    name_to_printings.setdefault(card.get('name', ''), []).append({
        'printing_id': card['printing_unique_id'],
        'set_id': card.get('set_id', ''),
        'edition': card.get('edition', ''),
        'foiling': card.get('foiling', ''),
        'card_id': card.get('id', ''),
        # 64 hex chars = 256 bits, matching the 16x16 query pHash used at identify
        # time. `or ''` also covers records whose image_phash is null, which
        # previously raised TypeError on slicing.
        'phash': (card.get('image_phash') or '')[:64],
    })

print(f'‚úì Built name->printings for {len(name_to_printings):,} unique card names')

In [None]:
class TwoStageIdentifier:
    """
    Stage 1: CNN identifies card NAME (robust to editions)
    Stage 2: pHash identifies exact PRINTING (specific edition/foil)

    `ref_embeddings` is an (N, D) tensor of reference embeddings aligned with
    `ref_names`. The names are used as keys into `card_lookup`.
    `name_to_printings` maps a card name to a list of printing dicts with
    'printing_id', 'set_id', 'card_id', 'edition', 'foiling' and 'phash' (hex).
    """

    def __init__(self, model, ref_embeddings, ref_names, card_lookup, name_to_printings, device):
        self.model = model
        self.ref_embeddings = ref_embeddings
        self.ref_names = ref_names
        self.card_lookup = card_lookup
        self.name_to_printings = name_to_printings
        self.device = device
        self.transform = get_val_transforms()

    @staticmethod
    def _to_rgb_pil(image_input):
        """Return an RGB PIL image from a path (str/PathLike), numpy array or PIL image."""
        if isinstance(image_input, (str, os.PathLike)):
            # convert() loads the pixels, so the file can be closed afterwards.
            with Image.open(image_input) as img:
                return img.convert('RGB')
        if isinstance(image_input, np.ndarray):
            # Arrays are assumed to be RGB; OpenCV (BGR) frames must be converted by the
            # caller. convert() normalises RGBA/grayscale arrays to 3 channels.
            return Image.fromarray(image_input).convert('RGB')
        return image_input.convert('RGB')

    def identify(self, image_input, top_k=5):
        """Identify a card image.

        Returns a dict with 'name', 'cnn_confidence' (top cosine similarity),
        'cnn_printing_id' and 'cnn_top_matches'. When a pHash match among that
        name's printings is found, it also contains 'printing_id', 'set_id',
        'card_id', 'edition', 'foiling', 'phash_distance' (Hamming distance;
        lower is better) and 'total_printings'.
        """
        if len(self.ref_names) == 0:
            raise ValueError("TwoStageIdentifier has no reference embeddings")

        # NOTE(review): autocontrast is applied to the query only; confirm the
        # references were built with matching preprocessing.
        pil_img = ImageOps.autocontrast(self._to_rgb_pil(image_input), cutoff=1)
        img_array = np.array(pil_img)

        # STAGE 1: CNN - nearest references by cosine similarity.
        with torch.no_grad():
            tensor = self.transform(image=img_array)['image'].unsqueeze(0).to(self.device)
            query_emb = self.model(tensor).cpu()

        sims = F.cosine_similarity(query_emb, self.ref_embeddings)
        top_indices = sims.argsort(descending=True)[:top_k].tolist()

        top_id = self.ref_names[top_indices[0]]
        card_info = self.card_lookup.get(top_id, {})
        card_name = card_info.get('name', 'Unknown')

        result = {
            'name': card_name,
            'cnn_confidence': sims[top_indices[0]].item(),
            'cnn_printing_id': top_id,
            'cnn_top_matches': [
                {'id': self.ref_names[i], 'score': sims[i].item(),
                 'name': self.card_lookup.get(self.ref_names[i], {}).get('name', '?')}
                for i in top_indices
            ]
        }

        # STAGE 2: pHash - closest printing (Hamming distance) among this name's printings.
        query_phash = imagehash.phash(pil_img, hash_size=16)
        printings = self.name_to_printings.get(card_name, [])

        best_match = None
        best_distance = None
        for p in printings:
            if not p['phash']:
                continue
            try:
                distance = query_phash - imagehash.hex_to_hash(p['phash'])
            except (ValueError, TypeError):
                # Malformed hex, or a hash of a different size than the query; skip it.
                continue
            if best_distance is None or distance < best_distance:
                best_match, best_distance = p, distance

        if best_match is not None:
            result['printing_id'] = best_match['printing_id']
            result['set_id'] = best_match['set_id']
            result['card_id'] = best_match['card_id']
            result['edition'] = best_match['edition']
            result['foiling'] = best_match['foiling']
            result['phash_distance'] = best_distance
            result['total_printings'] = len(printings)

        return result

identifier = TwoStageIdentifier(model, reference_embeddings, reference_names,
                                 card_lookup, name_to_printings, device)
print('‚úì TwoStageIdentifier ready!')

## 🔟 Test with Uploaded Images

In [None]:
from google.colab import files

# Upload images through the Colab widget, identify each, print a report and show it.
print('Upload card images to test...')
uploaded = files.upload()

for filename in uploaded:
    print(f'\n{"="*60}')
    result = identifier.identify(filename)

    confidence_pct = result["cnn_confidence"] * 100
    report = [
        f'üé¥ {result["name"]}',
        f'   CNN Confidence: {confidence_pct:.1f}%',
        f'   Printing: {result.get("set_id", "?")} {result.get("card_id", "?")} ({result.get("foiling", "?")})',
        f'   pHash Distance: {result.get("phash_distance", "N/A")} (lower=better)',
        f'   Total printings of this card: {result.get("total_printings", 1)}',
    ]
    print('\n'.join(report))

    plt.figure(figsize=(6, 8))
    plt.imshow(Image.open(filename))
    plt.title(f'{result["name"]}\n{result.get("set_id","")} {result.get("card_id","")}')
    plt.axis('off')
    plt.show()

In [None]:
def quick_id(img_path):
    """Identify one image with the global `identifier`, print a one-line summary, return the result dict."""
    res = identifier.identify(img_path)
    set_id = res.get("set_id", "")
    card_id = res.get("card_id", "")
    pct = res["cnn_confidence"] * 100
    print(f'üé¥ {res["name"]} | {set_id} {card_id} | {pct:.0f}%')
    return res

print('‚úì quick_id() ready')

## ✅ Done!

**Features:**
- Synthetic background augmentation
- Heavy sim-to-real transforms
- Two-stage: CNN→Name, pHash→Printing

**Model:** `MyDrive/CardRecognition_Models/best_model.pth`