# Font Embedding Model Training (Phase 2)

This notebook fine-tunes OpenCLIP ViT-B-32 on your font dataset to create font embeddings.

**Prerequisites:** You must have the `font_dataset/` folder (with `metadata.json` and `samples/`) ready to upload.

## Workflow
1. Set your project path in the config cell below
2. Upload your `font_dataset/` folder to Google Drive
3. Run all cells
4. Download the trained `best_model.pt` back to your local project

## 0. Configuration

**Change `PROJECT_DIR` to match where you placed `font_dataset/` inside Google Drive.**

In [None]:
# === CONFIGURATION - EDIT THIS ===
#
# PROJECT_DIR is the Drive folder that contains font_dataset/.
# Once Drive is mounted (next section), "My Drive" is visible at /content/drive/MyDrive/.
# Example: font_dataset/ uploaded into Drive > proj1_check_fonts > check_fonts
#   -> PROJECT_DIR = "/content/drive/MyDrive/proj1_check_fonts/check_fonts"
PROJECT_DIR = "/content/drive/MyDrive/proj1_check_fonts/check_fonts"

# Training hyperparameters
MODEL_NAME = "ViT-B-32"     # OpenCLIP architecture name
PRETRAINED = "openai"       # OpenCLIP pretrained-weights tag
EPOCHS = 15
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
AUGMENTATION = True         # Set False to disable augmentation
NUM_WORKERS = 2             # DataLoader worker processes

# Derived paths, all computed from PROJECT_DIR (no need to edit)
DATASET_DIR = f"{PROJECT_DIR}/font_dataset"
METADATA = f"{DATASET_DIR}/metadata.json"
SAVE_DIR = f"{PROJECT_DIR}/models"

## 1. Mount Google Drive & Install Dependencies

In [None]:
# Mount Google Drive; PROJECT_DIR (configured above) lives under /content/drive/MyDrive/.
from google.colab import drive

MOUNT_POINT = "/content/drive"
drive.mount(MOUNT_POINT)

In [None]:
# Install OpenCLIP plus the Pillow / tqdm / NumPy packages imported in section 2.
!pip install -q open-clip-torch pillow tqdm numpy

In [None]:
# Verify GPU is available and report its name and memory.
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available:  {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU:             {torch.cuda.get_device_name(0)}")
    # CUDA device properties expose the byte count as `total_memory`
    # (there is no `total_mem` attribute).
    print(f"VRAM:            {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected. Go to Runtime > Change runtime type > T4 GPU")

In [None]:
import os
import json
from pathlib import Path

# Fail fast with a clear message if PROJECT_DIR points at the wrong Drive folder.
# Explicit raises (rather than assert) so the checks are not stripped under `python -O`.
dataset_path = Path(DATASET_DIR)
metadata_path = Path(METADATA)

if not dataset_path.is_dir():
    raise FileNotFoundError(
        f"Dataset directory not found: {DATASET_DIR}\n"
        "Make sure you uploaded font_dataset/ to the correct Drive path."
    )
if not metadata_path.is_file():
    raise FileNotFoundError(f"metadata.json not found: {METADATA}")

# Font names may be non-ASCII; don't depend on the locale's default encoding.
with metadata_path.open("r", encoding="utf-8") as f:
    meta = json.load(f)

# Use the stored counts when present, otherwise count the "fonts"/"samples"
# lists directly (those lists are what FontDataset reads).
num_fonts_listed = meta.get("num_fonts", len(meta["fonts"]))
num_samples_listed = sum(fd.get("num_samples", len(fd["samples"])) for fd in meta["fonts"])
print(f"Dataset verified: {num_fonts_listed} fonts, "
      f"{num_samples_listed} total samples")

## 2. Define Dataset & Augmentation Classes

These are identical to `train_embedding_model.py` so that checkpoints are compatible.

In [None]:
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from pathlib import Path
from PIL import Image, ImageFilter, ImageEnhance
import numpy as np
from tqdm.notebook import tqdm
import random
import open_clip


# ── Augmentation helpers ──────────────────────────────────────────

class GaussianBlur:
    """Blur a PIL image with a Gaussian kernel whose radius is drawn at random.

    Args:
        radius_range: (low, high) bounds passed to ``random.uniform`` for the radius.
    """

    def __init__(self, radius_range=(0.5, 2.0)):
        self.radius_range = radius_range

    def __call__(self, img):
        blur_filter = ImageFilter.GaussianBlur(radius=random.uniform(*self.radius_range))
        return img.filter(blur_filter)


class AddGaussianNoise:
    """Add per-pixel Gaussian noise with a randomly drawn standard deviation.

    Args:
        mean: Mean of the noise, in 0-255 pixel units.
        std_range: (low, high) bounds passed to ``random.uniform`` for the std.
    """

    def __init__(self, mean=0, std_range=(5, 25)):
        self.mean = mean
        self.std_range = std_range

    def __call__(self, img):
        sigma = random.uniform(*self.std_range)
        pixels = np.array(img).astype(np.float32)
        noisy = pixels + np.random.normal(self.mean, sigma, pixels.shape)
        # Clip back into the valid 8-bit range before rebuilding the image.
        return Image.fromarray(np.clip(noisy, 0, 255).astype(np.uint8))


class RandomPerspective:
    """With probability ``p``, apply a random perspective warp to a PIL image.

    Each image corner is jittered independently by up to ``distortion_scale``
    times the image width (x) and height (y).

    Args:
        distortion_scale: Maximum corner displacement as a fraction of width/height.
        p: Probability of applying the warp; otherwise the image is returned as-is.
    """

    def __init__(self, distortion_scale=0.1, p=0.3):
        self.distortion_scale = distortion_scale
        self.p = p

    def __call__(self, img):
        if random.random() >= self.p:
            return img
        w, h = img.size
        startpoints = [[0, 0], [w, 0], [w, h], [0, h]]
        endpoints = []
        for x, y in startpoints:
            dx = random.uniform(-w * self.distortion_scale, w * self.distortion_scale)
            dy = random.uniform(-h * self.distortion_scale, h * self.distortion_scale)
            endpoints.append([x + dx, y + dy])
        # PIL's PERSPECTIVE transform maps each OUTPUT pixel back to an INPUT
        # position, so solve for the endpoints -> startpoints homography; the
        # original corners then land on the sampled endpoints.
        coeffs = self._get_perspective_coeffs(endpoints, startpoints)
        return img.transform(img.size, Image.Transform.PERSPECTIVE, coeffs, Image.Resampling.BILINEAR)

    def _get_perspective_coeffs(self, src, dst):
        """Return the 8 homography coefficients (a, b, c, d, e, f, g, h) mapping src onto dst.

        For each point pair, (x, y) in ``src`` maps to
        ((a*x + b*y + c) / (g*x + h*y + 1), (d*x + e*y + f) / (g*x + h*y + 1)) in ``dst``.
        This is the 8-tuple layout ``Image.transform`` expects for PERSPECTIVE.
        """
        rows = []
        for (x, y), (u, v) in zip(src, dst):
            rows.append([x, y, 1, 0, 0, 0, -u * x, -u * y])
            rows.append([0, 0, 0, x, y, 1, -v * x, -v * y])
        # float64: the system mixes pixel coordinates and their products, and
        # float32 loses noticeable precision for large images.
        A = np.array(rows, dtype=np.float64)
        B = np.array(dst, dtype=np.float64).reshape(8)
        solution = np.linalg.lstsq(A, B, rcond=None)[0]
        return tuple(float(c) for c in solution)


# ── Dataset ────────────────────────────────────────────────────────

class FontDataset(Dataset):
    """Rendered font-sample images, each labelled with an integer font index.

    ``metadata_file`` is a JSON object whose ``"fonts"`` list holds entries with
    a ``"name"`` and a ``"samples"`` list of ``{"path": ...}`` records. Sample
    paths are relative to ``dataset_dir`` and may use Windows-style backslash
    separators. Font indices are assigned in first-seen order of ``"name"``.
    Samples whose image file does not exist are skipped and counted as missing.

    Args:
        dataset_dir: Root folder that sample paths are relative to.
        metadata_file: Path to the metadata JSON file.
        transform: Optional callable applied to each image after augmentation.
        augment: If True, ``__getitem__`` applies random augmentation.
    """

    def __init__(self, dataset_dir, metadata_file, transform=None, augment=False):
        self.dataset_dir = Path(dataset_dir)
        self.transform = transform
        self.augment = augment

        # Explicit UTF-8: font names may be non-ASCII and the default encoding is locale-dependent.
        with open(metadata_file, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        self.samples = []       # list of (image path str, font label int)
        self.font_to_idx = {}
        self.idx_to_font = {}
        found = missing = 0

        for font_data in metadata["fonts"]:
            font_name = font_data["name"]
            if font_name not in self.font_to_idx:
                new_idx = len(self.font_to_idx)
                self.font_to_idx[font_name] = new_idx
                self.idx_to_font[new_idx] = font_name
            label = self.font_to_idx[font_name]
            for sample in font_data["samples"]:
                # Normalise escaped or single backslashes to forward slashes.
                sample_path = sample["path"].replace("\\\\", "/").replace("\\", "/")
                image_path = self.dataset_dir / sample_path
                if image_path.exists():
                    self.samples.append((str(image_path), label))
                    found += 1
                else:
                    missing += 1

        self.num_fonts = len(self.font_to_idx)
        print(f"Loaded {found} samples from {self.num_fonts} fonts  (missing: {missing})")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, label = self.samples[idx]
        # Context manager closes the file; convert() returns a fully loaded copy.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.augment:
            image = self._apply_augmentation(image)
        if self.transform:
            image = self.transform(image)
        return image, label

    def _apply_augmentation(self, image):
        """Independently apply mild rotation, brightness, contrast, blur, noise and perspective."""
        if random.random() < 0.5:
            image = image.rotate(random.uniform(-5, 5), fillcolor="white", resample=Image.Resampling.BILINEAR)
        if random.random() < 0.5:
            image = ImageEnhance.Brightness(image).enhance(random.uniform(0.8, 1.2))
        if random.random() < 0.5:
            image = ImageEnhance.Contrast(image).enhance(random.uniform(0.8, 1.2))
        if random.random() < 0.3:
            image = GaussianBlur(radius_range=(0.5, 1.5))(image)
        if random.random() < 0.3:
            image = AddGaussianNoise(std_range=(5, 15))(image)
        if random.random() < 0.2:
            image = RandomPerspective(distortion_scale=0.05, p=1.0)(image)
        return image


def split_dataset(dataset, train_ratio=0.7, val_ratio=0.15, seed=42):
    """Randomly partition ``dataset`` into train / val / test ``Subset`` objects.

    The test split receives whatever remains after the train and val fractions
    (each rounded down to a whole sample count).

    Side effect: re-seeds the global ``random``, NumPy and torch RNGs with ``seed``.

    Args:
        dataset: Any sized dataset (must support ``len``).
        train_ratio: Fraction of samples for training, in [0, 1].
        val_ratio: Fraction of samples for validation, in [0, 1].
        seed: RNG seed; the default 42 reproduces the historical split.

    Returns:
        (train_ds, val_ds, test_ds) as ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: If a ratio is outside [0, 1] or the two ratios sum above 1.
    """
    if not (0 <= train_ratio <= 1 and 0 <= val_ratio <= 1) or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"Invalid split ratios: train={train_ratio}, val={val_ratio} "
            "(each must be in [0, 1] and their sum must not exceed 1)"
        )
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    train_end = int(len(dataset) * train_ratio)
    val_end   = train_end + int(len(dataset) * val_ratio)
    train_ds = Subset(dataset, indices[:train_end])
    val_ds   = Subset(dataset, indices[train_end:val_end])
    test_ds  = Subset(dataset, indices[val_end:])
    print(f"Split: train={len(train_ds)}  val={len(val_ds)}  test={len(test_ds)}")
    return train_ds, val_ds, test_ds


print("Classes defined.")

## 3. Load Dataset & Prepare DataLoaders

In [None]:
# Build the dataset once, without a transform. The model's preprocess transform
# is attached in section 4, after the model has been created.
dataset = FontDataset(DATASET_DIR, METADATA, transform=None)
splits = split_dataset(dataset)
train_ds, val_ds, test_ds = splits

## 4. Initialize Model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Load OpenCLIP; the middle return value of create_model_and_transforms is not used here.
model, _, preprocess = open_clip.create_model_and_transforms(
    MODEL_NAME, pretrained=PRETRAINED, device=device
)
model = model.to(device)

# Freeze the ENTIRE CLIP model. Only encode_image is used, so the text tower
# must be frozen too; otherwise it is counted as trainable and handed to the
# optimizer. Then unfreeze the last 2 vision transformer blocks.
for param in model.parameters():
    param.requires_grad = False
if hasattr(model.visual, "transformer"):
    for param in model.visual.transformer.resblocks[-2:].parameters():
        param.requires_grad = True
else:
    print("WARNING: vision tower has no .transformer; only the classifier head will be trained.")

# Classification head on top of the image embedding
embedding_dim = model.visual.output_dim if hasattr(model.visual, "output_dim") else 512
num_fonts = dataset.num_fonts
classifier = nn.Linear(embedding_dim, num_fonts).to(device)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
trainable += sum(p.numel() for p in classifier.parameters())
total = sum(p.numel() for p in model.parameters()) + sum(p.numel() for p in classifier.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,}  ({100*trainable/total:.1f}%)")
print(f"Fonts: {num_fonts}")

In [None]:
import copy

# All three Subsets from split_dataset() wrap the SAME FontDataset object, so
# setting .transform/.augment through one split overwrites the others (disabling
# augmentation for val would also disable it for training). Give the training
# split and the evaluation splits their own shallow copies: the sample list and
# label maps are shared read-only, while transform/augment are per-copy.
train_base = copy.copy(dataset)
train_base.transform = preprocess
train_base.augment = AUGMENTATION

eval_base = copy.copy(dataset)
eval_base.transform = preprocess
eval_base.augment = False

train_ds = Subset(train_base, train_ds.indices)
val_ds = Subset(eval_base, val_ds.indices)
test_ds = Subset(eval_base, test_ds.indices)

# Pinned host memory only helps (and only avoids a warning) when copying to a GPU.
use_pinned = device.type == "cuda"
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=use_pinned)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=use_pinned)

print(f"Train batches: {len(train_loader)}  |  Val batches: {len(val_loader)}")
print(f"Augmentation: {'ON' if AUGMENTATION else 'OFF'}")

## 5. Train

In [None]:
# Optimizer & loss. The optimizer receives the model parameters that still
# require grad (the unfrozen vision blocks) plus the classification head.
trainable_params = [p for p in model.parameters() if p.requires_grad]
trainable_params += list(classifier.parameters())

optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

save_path = Path(SAVE_DIR)
save_path.mkdir(parents=True, exist_ok=True)
best_ckpt_path = save_path / "best_model.pt"

# Start below any achievable accuracy so the first epoch always writes a
# checkpoint. With 0.0 and a strict ">", a run stuck at 0% would save nothing
# and loading best_model.pt afterwards would fail.
best_val_acc = -1.0
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

print(f"Training {EPOCHS} epochs  |  LR={LEARNING_RATE}  |  Batch={BATCH_SIZE}")
print(f"Saving checkpoints to: {SAVE_DIR}")
print("-" * 60)

for epoch in range(EPOCHS):
    # ── Train ─────────────────────────────────────────────
    model.train()
    classifier.train()
    running_loss = 0.0
    correct = seen = 0  # `seen`, not `total`: `total` holds the parameter count

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [train]")
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        logits = classifier(model.encode_image(images))
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = labels.size(0)
        # CrossEntropyLoss returns the batch mean; weight it by batch size so
        # a short final batch does not skew the epoch average.
        running_loss += loss.item() * n
        correct += (logits.argmax(1) == labels).sum().item()
        seen += n
        pbar.set_postfix(loss=f"{loss.item():.4f}", acc=f"{100*correct/seen:.1f}%")

    # max(..., 1) guards against an empty loader.
    train_loss = running_loss / max(seen, 1)
    train_acc  = 100 * correct / max(seen, 1)

    # ── Validate ──────────────────────────────────────────
    model.eval()
    classifier.eval()
    val_loss_sum = 0.0
    val_correct = val_total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            logits = classifier(model.encode_image(images))
            n = labels.size(0)
            val_loss_sum += criterion(logits, labels).item() * n
            val_correct += (logits.argmax(1) == labels).sum().item()
            val_total += n

    val_loss = val_loss_sum / max(val_total, 1)
    val_acc  = 100 * val_correct / max(val_total, 1)

    # ── Log ───────────────────────────────────────────────
    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

    marker = ""
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "classifier_state_dict": classifier.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "val_acc": val_acc,
            "num_fonts": num_fonts,
            # Extra metadata so the label -> font-name mapping and the model
            # identity travel with the weights. Loaders that read only the
            # keys above are unaffected.
            "idx_to_font": dataset.idx_to_font,
            "model_name": MODEL_NAME,
            "pretrained": PRETRAINED,
        }, best_ckpt_path)
        marker = "  << saved best"

    print(f"Epoch {epoch+1:>2}/{EPOCHS}  "
          f"train_loss={train_loss:.4f}  train_acc={train_acc:.1f}%  "
          f"val_loss={val_loss:.4f}  val_acc={val_acc:.1f}%{marker}")

print("=" * 60)
print(f"Training complete.  Best val accuracy: {best_val_acc:.1f}%")
print(f"Checkpoint saved to: {best_ckpt_path}")

## 6. Training Curves

In [None]:
import matplotlib.pyplot as plt

# Side-by-side loss and accuracy curves for the train and val splits.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
epochs_range = range(1, len(history["train_loss"]) + 1)

panels = (
    (ax1, "loss", "Loss", "Loss"),
    (ax2, "acc", "Accuracy (%)", "Accuracy"),
)
for ax, metric, y_label, title in panels:
    ax.plot(epochs_range, history[f"train_{metric}"], label="Train")
    ax.plot(epochs_range, history[f"val_{metric}"], label="Val")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()

plot_path = save_path / "training_curves.png"
plt.tight_layout()
plt.savefig(plot_path, dpi=150)
plt.show()
print(f"Saved plot to {plot_path}")

## 7. Quick Sanity Check

Load the saved checkpoint and check its predictions on a few random test samples.

In [None]:
# Load the best checkpoint and classify a few random held-out test samples.
ckpt = torch.load(save_path / "best_model.pt", map_location=device)
model.load_state_dict(ckpt["model_state_dict"])
classifier.load_state_dict(ckpt["classifier_state_dict"])
model.eval()
classifier.eval()

if len(test_ds) == 0:
    # next(iter(...)) on an empty loader would raise a bare StopIteration.
    print("Test split is empty; skipping sanity check.")
else:
    test_loader = DataLoader(test_ds, batch_size=5, shuffle=True, num_workers=0)
    images, labels = next(iter(test_loader))

    with torch.no_grad():
        logits = classifier(model.encode_image(images.to(device)))
    # Bring predictions back to CPU so they are compared as plain ints.
    preds = logits.argmax(1).cpu()

    print(f"Sanity check on {len(labels)} test samples:")
    for true_idx, pred_idx in zip(labels.tolist(), preds.tolist()):
        true_font = dataset.idx_to_font[true_idx]
        pred_font = dataset.idx_to_font[pred_idx]
        match = "OK" if true_idx == pred_idx else "WRONG"
        print(f"  [{match:>5}]  True: {true_font:<25}  Predicted: {pred_font}")

## 8. Download Checkpoint

If you want to download `best_model.pt` directly from Colab (instead of grabbing it from Drive):

In [None]:
# Optional: send best_model.pt from the Colab VM straight to your browser's downloads.
from google.colab import files

checkpoint_file = save_path / "best_model.pt"
files.download(str(checkpoint_file))

## Next Steps

1. Copy `best_model.pt` into your local project at `proj1_check_fonts/check_fonts/models/best_model.pt`
2. Run Phase 3 (Vector DB) locally -- it does **not** require a GPU:
   ```
   python phase3_vector_db.py --checkpoint models/best_model.pt
   ```
3. Test search:
   ```
   python search_font.py --image path/to/query.png --top_k 5
   ```