## Cell 1 — Mount Google Drive

In [None]:
from google.colab import drive

# Cell 7 copies the trained checkpoint to /content/drive/MyDrive/claimlens,
# so Drive has to be mounted at this path before then.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)
print('Drive mounted.')

## Cell 2 — Install dependencies

In [None]:
!pip install timm kaggle -q
print('Done.')

## Cell 3 — Upload kaggle.json and configure credentials

In [None]:
from google.colab import files
import os
import re

print('Upload your kaggle.json file when prompted below.')
print('Get it from: kaggle.com -> Account -> API -> Create New Token')
uploaded = files.upload()   # {filename: file bytes}

# Browsers save repeat downloads as e.g. 'kaggle (1).json', so accept that
# form too, preferring an exact 'kaggle.json' when both were uploaded.
candidates = sorted(name for name in uploaded
                    if re.fullmatch(r'kaggle( \(\d+\))?\.json', name))
if not candidates:
    raise FileNotFoundError('kaggle.json not uploaded. Please re-run this cell.')
kaggle_name = 'kaggle.json' if 'kaggle.json' in candidates else candidates[0]

# Write the uploaded bytes directly instead of `cp kaggle.json` out of the
# current working directory, so the result doesn't depend on the cwd or on an
# older kaggle.json that may already be sitting there.
kaggle_dir = os.path.expanduser('~/.kaggle')
os.makedirs(kaggle_dir, exist_ok=True)
kaggle_path = os.path.join(kaggle_dir, 'kaggle.json')
with open(kaggle_path, 'wb') as fh:
    fh.write(uploaded[kaggle_name])
# Owner read/write only (equivalent of `chmod 600`): the file holds an API secret.
os.chmod(kaggle_path, 0o600)
print('Kaggle credentials configured successfully.')

## Cell 4 — Download dataset

Downloads `anujms/car-damage-detection` from Kaggle.
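Expected structure after extraction into `/content/data/raw/damage_images` (if the archive unpacks differently — e.g. under an extra top-level folder — Cell 5 falls back to searching the whole tree for `00-damage` / `01-whole` folders):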
```
damage_images/
  training/
    00-damage/   <- damaged cars
    01-whole/    <- undamaged cars
  validation/
    00-damage/
    01-whole/
```

In [None]:
import os
import subprocess

RAW_DIR        = '/content/data/raw/damage_images'
DATASET        = 'anujms/car-damage-detection'
IMG_EXTS       = ('.jpg', '.jpeg', '.png')
FORCE_DOWNLOAD = False   # set True to re-fetch even if images are already present


def _count_images(names):
    """Return how many of the file names in `names` have an image extension."""
    return sum(name.lower().endswith(IMG_EXTS) for name in names)


os.makedirs(RAW_DIR, exist_ok=True)

# Skip the download when an earlier run already extracted images here, so
# Restart & Run All doesn't re-fetch the archive every time.
already_present = any(_count_images(names) for _, _, names in os.walk(RAW_DIR))
if already_present and not FORCE_DOWNLOAD:
    print(f'Images already present under {RAW_DIR}; skipping download.')
else:
    # Unlike a bare `!kaggle ...` line, this stops the notebook if the CLI exits
    # non-zero (bad credentials, network error, ...) instead of carrying on.
    result = subprocess.run(
        ['kaggle', 'datasets', 'download', '-d', DATASET, '-p', RAW_DIR, '--unzip'],
        capture_output=True, text=True,
    )
    print(result.stdout)
    if result.returncode != 0:
        raise RuntimeError(f'Kaggle download failed:\n{result.stderr}')

# Print the full folder tree so we can see exactly what was downloaded.
# The loop variable is `filenames`, not `files`, so google.colab.files from
# Cell 3 is not clobbered.
print('\n--- Downloaded folder structure ---')
for root, dirs, filenames in os.walk(RAW_DIR):
    dirs.sort()  # deterministic traversal order for the printout
    level = root.replace(RAW_DIR, '').count(os.sep)
    print(f"{'  ' * level}{os.path.basename(root)}/ ({_count_images(filenames)} images)")

## Cell 5 — Reorganize into minor / moderate / severe

The dataset only has two classes: `00-damage` and `01-whole`.
We map them like this:
- `01-whole` (undamaged cars) -> **minor**
- first half of `00-damage`   -> **moderate**
- second half of `00-damage`  -> **severe**

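This gives us 3 classes for the severity classifier, with two caveats:

1. They are only balanced if the source has about twice as many damaged images as whole ones. `minor` receives every `01-whole` image, while `00-damage` is split in half.
2. The `moderate` / `severe` split is random, not a real severity label, so the model cannot learn to tell those two apart. Replace it with genuinely severity-labelled data before relying on those predictions.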

In [None]:
import os
import shutil
import random

BASE     = '/content/data/raw/damage_images'
DEST     = '/content/data/organized'
SPLITS   = ['training', 'validation']
CLASSES  = ['minor', 'moderate', 'severe']
IMG_EXTS = ('.jpg', '.jpeg', '.png')
SEED     = 42

# NOTE(review): `moderate` vs `severe` below is a random halving of the
# `00-damage` images, not a real severity label, so no model can learn to tell
# them apart. Replace with genuinely severity-labelled data before relying on
# those two classes.


def _list_images(folder):
    """Return sorted full paths of the image files directly inside `folder`."""
    return sorted(os.path.join(folder, f) for f in os.listdir(folder)
                  if f.lower().endswith(IMG_EXTS))


def copy_to(img_list, cls_name):
    """Copy every path in `img_list` into DEST/<cls_name>/ as <cls_name>_NNNN<ext>."""
    dest_dir = os.path.join(DEST, cls_name)
    for i, src in enumerate(img_list):
        ext = os.path.splitext(src)[1].lower() or '.jpg'
        shutil.copy2(src, os.path.join(dest_dir, f'{cls_name}_{i:04d}{ext}'))


# Start from an empty DEST so re-running this cell never mixes stale copies
# from an earlier run into the counts (or into training).
shutil.rmtree(DEST, ignore_errors=True)
for cls in CLASSES:
    os.makedirs(os.path.join(DEST, cls), exist_ok=True)

whole_imgs  = []
damage_imgs = []

# Collect images from BASE/<split>/<folder>. Only if no split folder exists,
# fall back to a flat BASE/<folder> (some versions extract flat). The fallback
# is checked once per class, not once per split, so a flat layout is not
# collected twice (duplicates would leak across the later train/val split).
for folder, bucket in [('01-whole', whole_imgs), ('00-damage', damage_imgs)]:
    sources = [os.path.join(BASE, split, folder) for split in SPLITS]
    sources = [p for p in sources if os.path.isdir(p)]
    if not sources and os.path.isdir(os.path.join(BASE, folder)):
        sources = [os.path.join(BASE, folder)]
    for src_dir in sources:
        bucket.extend(_list_images(src_dir))

# If standard paths not found, do a deep search for any *whole* / *damage* folders
if not whole_imgs and not damage_imgs:
    print('Standard paths not found. Searching entire tree...')
    for root, dirs, filenames in os.walk(BASE):
        dirs.sort()
        # BASE's own name ('damage_images') contains 'damage'; skip it so stray
        # top-level files are not labelled as damaged cars.
        if root == BASE:
            continue
        folder_name = os.path.basename(root).lower()
        imgs = sorted(os.path.join(root, f) for f in filenames
                      if f.lower().endswith(IMG_EXTS))
        if 'whole' in folder_name:
            whole_imgs.extend(imgs)
        elif 'damage' in folder_name:
            damage_imgs.extend(imgs)

print(f'Found: {len(whole_imgs)} whole images, {len(damage_imgs)} damage images')

assert len(whole_imgs) + len(damage_imgs) > 0, \
    'No images found at all. Check the folder tree printed in Cell 4.'

# Sort before the seeded shuffle: directory-listing order is filesystem-
# dependent, so without it the moderate/severe split would differ across machines.
damage_imgs.sort()
random.Random(SEED).shuffle(damage_imgs)
mid = len(damage_imgs) // 2
moderate_imgs = damage_imgs[:mid]
severe_imgs   = damage_imgs[mid:]

copy_to(whole_imgs,    'minor')
copy_to(moderate_imgs, 'moderate')
copy_to(severe_imgs,   'severe')

# Final count
print('\n=== Organized Dataset ===')
total = 0
for cls in CLASSES:
    n = len(os.listdir(os.path.join(DEST, cls)))
    total += n
    print(f'  {cls}: {n} images')
print(f'  Total: {total} images')
assert total > 0, 'Organized folder is empty!'
print('\nDataset ready for training.')

## Cell 6 — Training

In [None]:
import os
import torch
import torch.nn as nn
import timm
from torch.utils.data import DataLoader, random_split, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from PIL import Image

CLASSES = ['minor', 'moderate', 'severe']
CLASS_TO_IDX = dict(zip(CLASSES, range(len(CLASSES))))

# ImageNet channel mean/std, shared by both pipelines below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Training pipeline: fixed 224x224 resize, light geometric + colour
# augmentation, then tensor conversion and normalisation.
TRAIN_TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=20),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.1),
    transforms.RandomGrayscale(p=0.05),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation pipeline: the same resize and normalisation, no augmentation.
VAL_TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

class CarDamageDataset(Dataset):
    """Image-folder dataset with one sub-directory per class under `image_dir`.

    Args:
        image_dir: root folder containing one sub-folder per class name.
        transform: optional callable applied to each PIL image in __getitem__.
        classes: class sub-folder names in label order (label i <- classes[i]);
            defaults to the module-level CLASSES.

    `samples` is a list of (path, label) pairs. Files are listed in sorted
    order so `samples` -- and therefore any seeded random_split over this
    dataset -- is identical on every filesystem (os.listdir order is arbitrary).
    """
    def __init__(self, image_dir, transform=None, classes=None):
        self.samples   = []
        self.transform = transform
        self.classes   = list(CLASSES if classes is None else classes)
        for label, cls in enumerate(self.classes):
            cls_dir = os.path.join(image_dir, cls)
            if not os.path.exists(cls_dir):
                print(f'WARNING: {cls_dir} not found, skipping.')
                continue
            fnames = sorted(f for f in os.listdir(cls_dir)
                            if f.lower().endswith(('.jpg', '.jpeg', '.png')))
            self.samples.extend((os.path.join(cls_dir, f), label) for f in fnames)
            print(f'  {cls}: {len(fnames)} images')
        print(f'  Total samples loaded: {len(self.samples)}')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (transformed image, integer label) for sample `idx`.

        An unreadable file is replaced by a grey 224x224 placeholder so one bad
        file doesn't abort an epoch, but a warning is printed so the
        substitution is no longer silent.
        """
        path, label = self.samples[idx]
        try:
            # `with` closes the file handle as soon as the pixels are decoded.
            with Image.open(path) as src:
                img = src.convert('RGB')
        except Exception as exc:
            print(f'WARNING: could not read {path} ({exc}); using grey placeholder.')
            img = Image.new('RGB', (224, 224), (128, 128, 128))
        if self.transform is not None:
            img = self.transform(img)
        return img, label

class ValDataset(Dataset):
    """Re-wrap a `random_split` Subset so its samples use a different transform.

    The Subset's parent dataset must expose `samples` as (path, label) pairs;
    images are re-read from disk here, so the parent's (training) transform is
    bypassed without mutating the parent dataset.
    """
    def __init__(self, subset, transform):
        self.subset    = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        """Return (transformed image, integer label) for position `idx` of the subset.

        Unreadable files fall back to a grey 224x224 placeholder, with a
        printed warning so the substitution is visible.
        """
        path, label = self.subset.dataset.samples[self.subset.indices[idx]]
        try:
            # `with` closes the file handle as soon as the pixels are decoded.
            with Image.open(path) as src:
                img = src.convert('RGB')
        except Exception as exc:
            print(f'WARNING: could not read {path} ({exc}); using grey placeholder.')
            img = Image.new('RGB', (224, 224), (128, 128, 128))
        if self.transform is not None:
            img = self.transform(img)
        return img, label

class DamageClassifier(nn.Module):
    """timm EfficientNet-B0 with a `num_classes`-way classification head.

    The timm model is stored as `self.backbone`; the training code reaches into
    its `classifier` and `blocks` attributes for staged freezing.
    """
    def __init__(self, num_classes=3, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            'efficientnet_b0', pretrained=pretrained, num_classes=num_classes
        )

    def forward(self, x):
        return self.backbone(x)

    def save(self, path):
        """Write this module's state_dict (tensors only) to `path`."""
        torch.save(self.state_dict(), path)

    @classmethod
    def load(cls, path, num_classes=3):
        """Build an un-pretrained model, load weights from `path`, return it in eval mode.

        weights_only=True makes torch.load reject arbitrary pickled objects, so
        a tampered checkpoint cannot execute code on load; a plain state_dict
        of tensors (what `save` writes) still loads unchanged.
        """
        model = cls(num_classes=num_classes, pretrained=False)
        state = torch.load(path, map_location='cpu', weights_only=True)
        model.load_state_dict(state)
        model.eval()
        return model

# ── Config ───────────────────────────────────────────────────────
DEVICE    = 'cuda' if torch.cuda.is_available() else 'cpu'
DATA_DIR  = '/content/data/organized'
SAVE_PATH = '/content/best_model.pt'
EPOCHS    = 20
BATCH     = 16
LR        = 3e-4
PATIENCE  = 7
WARMUP_EPOCHS   = 2   # head-only epochs before partially unfreezing the backbone
UNFREEZE_BLOCKS = 3   # trailing backbone blocks made trainable after warm-up
SEED      = 42

print(f'Device: {DEVICE}')

# Seed torch's global RNG so head initialisation and DataLoader shuffling are
# reproducible under Restart & Run All (random_split has its own generator).
torch.manual_seed(SEED)

# ── Dataset ──────────────────────────────────────────────────────
full_ds = CarDamageDataset(DATA_DIR, transform=TRAIN_TRANSFORMS)
assert len(full_ds) > 0, 'Dataset empty — did Cell 5 complete successfully?'

n_val   = max(1, int(len(full_ds) * 0.2))
n_train = len(full_ds) - n_val
train_ds, val_ds = random_split(full_ds, [n_train, n_val],
                                generator=torch.Generator().manual_seed(SEED))
# Validation re-reads the same files with the deterministic VAL_TRANSFORMS.
val_ds = ValDataset(val_ds, VAL_TRANSFORMS)

# Pinned memory only helps host->GPU copies; don't request it on CPU.
PIN_MEMORY = DEVICE == 'cuda'
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,
                          num_workers=2, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False,
                          num_workers=2, pin_memory=PIN_MEMORY)

print(f'Train: {n_train} | Val: {n_val}')

# ── Model ────────────────────────────────────────────────────────
model = DamageClassifier(num_classes=len(CLASSES)).to(DEVICE)

# Freeze the entire backbone first; only the classifier head trains in warm-up.
for p in model.backbone.parameters():
    p.requires_grad = False
for p in model.backbone.classifier.parameters():
    p.requires_grad = True

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LR, weight_decay=0.01
)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`; return the training accuracy."""
    model.train()
    correct, total = 0, 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        out  = model(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        correct += (out.detach().argmax(1) == labels).sum().item()
        total   += labels.size(0)
    return correct / max(total, 1)   # empty loader -> 0.0, not ZeroDivisionError


def _evaluate(model, loader, device):
    """Return the accuracy of `model` on `loader`, with gradients disabled."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            correct += (model(imgs).argmax(1) == labels).sum().item()
            total   += labels.size(0)
    return correct / max(total, 1)


# Sentinel below any real accuracy so the first epoch always writes a
# checkpoint; starting at 0.0, a run stuck at 0% would never save one and the
# existence check in the next cell would fail.
# NOTE(review): with labels from a random moderate/severe split, val accuracy
# cannot approach 1.0 — interpret it accordingly.
best_val_acc     = -1.0
patience_counter = 0
unfrozen         = False

# ── Training loop ────────────────────────────────────────────────
for epoch in range(EPOCHS):

    # After the warm-up epochs, unfreeze the last UNFREEZE_BLOCKS blocks and
    # rebuild optimiser + scheduler so the newly trainable params are included.
    if epoch == WARMUP_EPOCHS and not unfrozen:
        blocks = model.backbone.blocks
        # Explicit start index: blocks[-0:] would select *every* block.
        for p in blocks[max(0, len(blocks) - UNFREEZE_BLOCKS):].parameters():
            p.requires_grad = True
        optimizer = AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=LR / 5, weight_decay=0.01
        )
        scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS - WARMUP_EPOCHS)
        unfrozen  = True
        print(f'  [Epoch {WARMUP_EPOCHS + 1}] Unfroze last {UNFREEZE_BLOCKS} blocks for fine-tuning.')

    train_acc = _train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
    val_acc   = _evaluate(model, val_loader, DEVICE)
    scheduler.step()
    print(f'Epoch {epoch+1:02d}/{EPOCHS} | '
          f'Train: {train_acc:.3f} | Val: {val_acc:.3f}')

    if val_acc > best_val_acc:
        best_val_acc     = val_acc
        patience_counter = 0
        model.save(SAVE_PATH)
        print(f'  Saved best model (val_acc={val_acc:.3f})')
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print('Early stopping triggered.')
            break

print(f'\nTraining complete. Best val_acc: {best_val_acc:.3f}')

## Cell 7 — Save best_model.pt to Google Drive

In [None]:
import os
import shutil

SRC_PATH   = '/content/best_model.pt'
DRIVE_ROOT = '/content/drive/MyDrive'
DRIVE_DIR  = os.path.join(DRIVE_ROOT, 'claimlens')
DST_PATH   = os.path.join(DRIVE_DIR, 'best_model.pt')

assert os.path.exists(SRC_PATH), \
    'best_model.pt not found — did training complete without errors?'
# Without this check, creating DRIVE_DIR would silently build
# /content/drive/MyDrive on the VM's local disk when Drive isn't mounted, and
# the "saved" checkpoint would disappear with the runtime.
assert os.path.isdir(DRIVE_ROOT), \
    f'{DRIVE_ROOT} not found — run Cell 1 to mount Google Drive first.'

os.makedirs(DRIVE_DIR, exist_ok=True)
shutil.copyfile(SRC_PATH, DST_PATH)

# Size is read from the Drive copy, so it also confirms the copy landed.
size_mb = os.path.getsize(DST_PATH) / 1024 / 1024
print('Saved to Google Drive.')
print(f'Size: {size_mb:.1f} MB')
print(f'Path: {DST_PATH}')

## Cell 8 — Final confirmation

In [None]:
# Build the whole summary first, then emit it in a single print call.
rule = '=' * 50
summary_lines = [
    rule,
    'PHASE 2 TRAINING COMPLETE',
    rule,
    f'Best validation accuracy: {best_val_acc:.3f}',
    '',
    'Next steps:',
    '1. Open Google Drive -> claimlens -> download best_model.pt',
    '2. Place it at: models/damage_classifier/best_model.pt',
    '3. Tell your coding agent: best_model.pt is placed, run Phase 2 verification',
]
print('\n'.join(summary_lines))