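# GTSRB Traffic Sign Classifier → ONNX → Unity

This notebook prepares a traffic-sign dataset (GTSRB from Kaggle when a Kaggle token is available, otherwise a small synthetic set), trains a ResNet-18 classifier with PyTorch, exports it to ONNX, sanity-checks the export with ONNX Runtime, and points to the Unity script that consumes the model.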


## Requirements
Run this cell to install required packages (skip if already installed).

In [None]:
%%bash
python -m pip install --upgrade pip
python -m pip install torch torchvision onnx onnxruntime opencv-python pillow numpy kaggle tqdm

## 1 — Download real GTSRB from Kaggle if possible, else create synthetic dataset
This cell checks for `~/.kaggle/kaggle.json`. If it is present, the cell downloads a public GTSRB mirror via the Kaggle CLI. If the download succeeds, it copies the images into the ImageFolder layout under `data/gtsrb_imagefolder`, with a random 80/20 `train/`/`val/` split per class. Otherwise it creates a small synthetic ImageFolder dataset (4 classes) so you can run training immediately.

In [None]:

import os, subprocess, sys, shutil, random
from collections import Counter
from pathlib import Path
from typing import List, Optional

# --- configuration -----------------------------------------------------------
SEED = 42                                      # fixes the train/val split and the synthetic images
DATA_ROOT = Path('data/gtsrb_imagefolder')     # output layout: train/<class>/..., val/<class>/...
RAW_DIR = Path('data/gtsrb')                   # where the Kaggle archive is unzipped
KAGGLE_DATASET = 'ibrahimkaratas/gtsrb-german-traffic-sign-recognition-benchmark'
KAGGLE_TOKEN = Path.home() / '.kaggle' / 'kaggle.json'
TRAIN_FRACTION = 0.8
# Only these suffixes are treated as images; class folders may also contain
# non-image files (e.g. annotation CSVs), which must not be copied.
IMG_EXTS = {'.png', '.jpg', '.jpeg', '.ppm', '.bmp'}

# Private RNG so re-running the cell reproduces the same split / synthetic data.
rng = random.Random(SEED)


def _list_images(folder: Path) -> List[Path]:
    """Return the image files directly inside ``folder``, sorted by name."""
    return sorted(p for p in folder.iterdir() if p.is_file() and p.suffix.lower() in IMG_EXTS)


def _find_class_root(base: Path) -> Optional[Path]:
    """Find the directory under ``base`` whose subfolders are per-class image folders.

    Every directory (including ``base``) is scored by how many of its immediate
    subfolders directly contain image files. The highest score wins; ties prefer
    paths mentioning "train", then shallower paths, then alphabetical order. This
    avoids picking a dataset root whose children are e.g. ``Train``/``Test``/``Meta``
    instead of class folders.

    Returns None if ``base`` does not exist or no directory has at least two
    image-bearing subfolders.
    """
    if not base.is_dir():
        return None
    image_dirs_per_parent = Counter()
    for root, _dirs, files in os.walk(base):
        folder = Path(root)
        if folder != base and any(Path(name).suffix.lower() in IMG_EXTS for name in files):
            image_dirs_per_parent[folder.parent] += 1
    candidates = [d for d, n in image_dirs_per_parent.items() if n >= 2]
    if not candidates:
        return None

    def _rank(d: Path):
        rel = str(d.relative_to(base)).lower()
        return (-image_dirs_per_parent[d], 'train' not in rel, len(d.parts), str(d))

    return min(candidates, key=_rank)


def _split_into_imagefolder(class_root: Path, dest: Path, split_rng: random.Random) -> int:
    """Copy each class folder of ``class_root`` into ``dest/train`` and ``dest/val``.

    ``dest`` is deleted and rebuilt. Each class is shuffled with ``split_rng`` and
    split by TRAIN_FRACTION, clamped so both splits keep at least one image.
    Classes with fewer than two images are skipped: ImageFolder derives label
    indices from the class folders present in each split, so a class missing from
    one split would shift every later label in that split.

    NOTE(review): if the source images come in near-duplicate sequences (e.g.
    consecutive frames), a per-image random split leaks them across train/val and
    inflates val accuracy; splitting by sequence would be stricter.

    Returns the number of classes written.
    """
    if dest.exists():
        shutil.rmtree(dest)
    n_classes = 0
    for cls_dir in sorted(p for p in class_root.iterdir() if p.is_dir()):
        imgs = _list_images(cls_dir)
        if len(imgs) < 2:
            if imgs:
                print(f'Skipping class {cls_dir.name!r}: only {len(imgs)} image')
            continue
        split_rng.shuffle(imgs)
        split = min(max(int(len(imgs) * TRAIN_FRACTION), 1), len(imgs) - 1)
        for subset, files in (('train', imgs[:split]), ('val', imgs[split:])):
            out_dir = dest / subset / cls_dir.name
            out_dir.mkdir(parents=True, exist_ok=True)
            for f in files:
                shutil.copy(f, out_dir / f.name)
        n_classes += 1
    return n_classes


# --- real data: reuse an earlier download, else try the Kaggle CLI ------------
print('Checking for Kaggle token at', KAGGLE_TOKEN)
download_success = False
try:
    src = _find_class_root(RAW_DIR)
    if src is not None:
        # Reusing avoids re-fetching the archive on every Restart & Run All.
        print('Reusing previously downloaded GTSRB data at', src)
    elif KAGGLE_TOKEN.exists():
        print('Kaggle token found. Attempting to download GTSRB dataset via Kaggle CLI (public mirror).')
        if shutil.which('kaggle') is None:
            raise FileNotFoundError('kaggle CLI not found on PATH')
        RAW_DIR.mkdir(parents=True, exist_ok=True)
        cmd = ['kaggle', 'datasets', 'download', '-d', KAGGLE_DATASET, '-p', str(RAW_DIR), '--unzip']
        print('Running:', ' '.join(cmd))
        subprocess.run(cmd, check=True)
        print('Downloaded GTSRB mirror. Preparing ImageFolder structure...')
        src = _find_class_root(RAW_DIR)
        if src is None:
            print('Downloaded content not found in expected locations. Will fallback to synthetic dataset.')
    if src is not None:
        n_classes = _split_into_imagefolder(src, DATA_ROOT, rng)
        if n_classes >= 2:
            download_success = True
            print(f'Prepared ImageFolder at {DATA_ROOT} ({n_classes} classes from {src})')
        else:
            print(f'Only {n_classes} usable class(es) in {src}. Will fallback to synthetic dataset.')
except Exception as e:
    # Deliberately best-effort: any failure falls through to the synthetic dataset.
    print('Kaggle download or preparation failed:', e)
    download_success = False

# --- synthetic fallback: 4 shape classes drawn on random colored discs ----------
if not download_success:
    print('Creating synthetic fallback dataset (small) at', DATA_ROOT)
    from PIL import Image, ImageDraw
    classes = ['left', 'right', 'forward', 'stop']
    n_per_class = 40
    n_val_per_class = 8
    sz = 128
    if DATA_ROOT.exists():
        shutil.rmtree(DATA_ROOT)
    for cls in classes:
        d = DATA_ROOT / 'train' / cls
        d.mkdir(parents=True, exist_ok=True)
        for i in range(n_per_class):
            img = Image.new('RGB', (sz, sz), color=(255, 255, 255))
            draw = ImageDraw.Draw(img)
            color = (rng.randint(50, 220), rng.randint(50, 220), rng.randint(50, 220))
            r = rng.randint(20, 48)
            cx = rng.randint(r, sz - r)
            cy = rng.randint(r, sz - r)
            h = r // 2
            draw.ellipse((cx - r, cy - r, cx + r, cy + r), fill=color, outline=(0, 0, 0))
            if cls == 'left':
                draw.polygon([(cx - h, cy), (cx + h, cy - h), (cx + h, cy + h)], fill=(0, 0, 0))
            elif cls == 'right':
                draw.polygon([(cx + h, cy), (cx - h, cy - h), (cx - h, cy + h)], fill=(0, 0, 0))
            elif cls == 'forward':
                draw.polygon([(cx, cy - h), (cx - h, cy + h), (cx + h, cy + h)], fill=(0, 0, 0))
            else:
                draw.rectangle((cx - h, cy - h, cx + h, cy + h), fill=(0, 0, 0))
            img.save(d / f'{cls}_{i}.png')
    # Move a fixed number of images per class into val/.
    for cls in classes:
        cls_src = DATA_ROOT / 'train' / cls
        dst = DATA_ROOT / 'val' / cls
        dst.mkdir(parents=True, exist_ok=True)
        # Sort first: iterdir() order is filesystem-dependent, which would defeat the seed.
        files = sorted(cls_src.iterdir())
        rng.shuffle(files)
        for f in files[:n_val_per_class]:
            f.rename(dst / f.name)
    print('Synthetic dataset created at', DATA_ROOT)


## 2 — Train (on real GTSRB if downloaded, else synthetic) and export ONNX
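This cell trains on whatever `data/gtsrb_imagefolder` currently contains (real or synthetic). A training set with fewer than 2,000 images is treated as synthetic: the model is trained from scratch and exported as `gtsrb_synthetic.onnx`. Otherwise it starts from ImageNet-pretrained weights and is exported as `gtsrb.onnx`. The checkpoint (weights plus class names) is saved to `gtsrb_model.pth`.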

In [None]:

import inspect
import os
import random

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

# --- configuration -----------------------------------------------------------
DATA_DIR = 'data/gtsrb_imagefolder'
EPOCHS = 6
BATCH = 64
LR = 3e-4
SEED = 42
IMG_SIZE = 128
# ImageNet normalization; any consumer of the exported model must apply the same values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Training sets smaller than this are assumed to be the synthetic fallback.
SYNTHETIC_MAX_IMAGES = 2000
OUT_ONNX_REAL = 'gtsrb.onnx'
OUT_ONNX_SYN = 'gtsrb_synthetic.onnx'
OUT_PTH = 'gtsrb_model.pth'

# Seeds cover shuffling, augmentation and weight init (cuDNN kernels may still be nondeterministic).
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# --- data ----------------------------------------------------------------------
if not os.path.exists(DATA_DIR):
    raise RuntimeError('Data folder not found. Run the download/prep cell first.')
normalize = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
# No horizontal flip: it would turn "left" signs into "right" signs.
tr = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE)),
                         transforms.RandomRotation(10),
                         transforms.ColorJitter(0.1, 0.1, 0.1),
                         transforms.ToTensor(),
                         normalize])
valtr = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE)),
                            transforms.ToTensor(),
                            normalize])
train_ds = datasets.ImageFolder(os.path.join(DATA_DIR, 'train'), transform=tr)
val_ds = datasets.ImageFolder(os.path.join(DATA_DIR, 'val'), transform=valtr)
# Each ImageFolder assigns label indices from its own class folders, so the splits
# must contain identical classes or val labels (and val_acc) are silently wrong.
if val_ds.classes != train_ds.classes:
    raise RuntimeError(f'Train/val class folders differ: {train_ds.classes} vs {val_ds.classes}')
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=0)
print('Num classes:', len(train_ds.classes), train_ds.classes)

# --- model ---------------------------------------------------------------------
use_synthetic = len(train_ds) < SYNTHETIC_MAX_IMAGES
# Synthetic data trains from scratch (no weight download); real data starts from ImageNet weights.
# `weights=` replaces the deprecated `pretrained=` flag.
weights = None if use_synthetic else models.ResNet18_Weights.DEFAULT
model = models.resnet18(weights=weights)
model.fc = nn.Linear(model.fc.in_features, len(train_ds.classes))
model = model.to(device)
opt = optim.Adam(model.parameters(), lr=LR)
crit = nn.CrossEntropyLoss()

# --- training ------------------------------------------------------------------
for e in range(EPOCHS):
    model.train()
    running = 0.0
    for X, y in train_loader:
        X = X.to(device); y = y.to(device)
        opt.zero_grad()
        out = model(X)
        loss = crit(out, y)
        loss.backward()
        opt.step()
        running += loss.item() * X.size(0)  # sum of per-sample losses
    train_loss = running / len(train_ds)
    model.eval()
    correct = 0; total = 0
    with torch.no_grad():
        for X, y in val_loader:
            X = X.to(device); y = y.to(device)
            pred = model(X).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    val_acc = correct / total if total > 0 else 0.0
    print(f'Epoch {e+1}/{EPOCHS} train_loss={train_loss:.4f} val_acc={val_acc:.4f}')

# --- save checkpoint -------------------------------------------------------------
torch.save({'model': model.state_dict(), 'classes': train_ds.classes}, OUT_PTH)
print('Saved checkpoint', OUT_PTH)

# --- export ONNX (fixed 1x3xIMG_SIZExIMG_SIZE input) -----------------------------
model.eval()
dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)
out_name = OUT_ONNX_SYN if use_synthetic else OUT_ONNX_REAL
export_kwargs = dict(input_names=['input'], output_names=['output'], opset_version=11)
# Recent PyTorch releases add a `dynamo` switch and may default to the newer exporter;
# request the TorchScript-based exporter explicitly so an opset-11 graph is produced.
if 'dynamo' in inspect.signature(torch.onnx.export).parameters:
    export_kwargs['dynamo'] = False
torch.onnx.export(model, dummy, out_name, **export_kwargs)
print('Exported ONNX to', out_name)
# The graph outputs raw logits (no softmax); output index i is train_ds.classes[i].
print('Class order (output index -> class name):', dict(enumerate(train_ds.classes)))


## 3 — ONNX runtime check (optional)
This runs a quick inference with ONNX Runtime on one image from the validation set.

In [None]:

try:
    import onnxruntime as rt
    from PIL import Image
    import numpy as np, os

    ONNX_CANDIDATES = ('gtsrb.onnx', 'gtsrb_synthetic.onnx')
    VAL_DIR = 'data/gtsrb_imagefolder/val'
    TRAIN_DIR = 'data/gtsrb_imagefolder/train'
    IMG_SIZE = 128
    IMG_SUFFIXES = ('.png', '.jpg', '.ppm', '.jpeg')
    # Must match the Normalize() used in the training cell.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    existing = [f for f in ONNX_CANDIDATES if os.path.exists(f)]
    if not existing:
        raise FileNotFoundError('no exported ONNX model found (run the training cell first)')
    # Use the most recent export, so a stale gtsrb.onnx from an earlier real-data run
    # is not tested after a synthetic run (or vice versa).
    onnx_file = max(existing, key=os.path.getmtime)
    print('Using ONNX file:', onnx_file)
    # onnxruntime builds with several execution providers require an explicit list.
    sess = rt.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
    input_name = sess.get_inputs()[0].name

    # Pick the first image in sorted order so the check is deterministic.
    img_path = None
    for root, dirs, files in os.walk(VAL_DIR):
        dirs.sort()
        images = sorted(f for f in files if f.lower().endswith(IMG_SUFFIXES))
        if images:
            img_path = os.path.join(root, images[0])
            break

    if img_path is None:
        print('No validation image found to test ONNX.')
    else:
        # BILINEAR matches torchvision's Resize on PIL images (used by the val transform).
        img = Image.open(img_path).convert('RGB').resize((IMG_SIZE, IMG_SIZE), Image.BILINEAR)
        arr = np.asarray(img, dtype=np.float32) / 255.0
        arr = (arr - mean) / std
        arr = np.ascontiguousarray(arr.transpose(2, 0, 1)[None, ...])  # HWC -> 1xCxHxW
        out = sess.run(None, {input_name: arr})
        print('ONNX output shape:', out[0].shape)

        logits = out[0][0]
        pred = int(np.argmax(logits))
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        # Same ordering rule as torchvision ImageFolder: sorted class-folder names.
        class_names = (sorted(e.name for e in os.scandir(TRAIN_DIR) if e.is_dir())
                       if os.path.isdir(TRAIN_DIR) else [])
        pred_name = class_names[pred] if pred < len(class_names) else str(pred)
        true_name = os.path.basename(os.path.dirname(img_path))
        print(f'True class: {true_name} | predicted: {pred_name} (p={probs[pred]:.3f})')
except ImportError as e:
    print('ONNX runtime test skipped: required package not installed:', e)
except Exception as e:
    # Optional check: report the failure without stopping a Run All.
    print('ONNX runtime test failed. Error:', e)


## 4 — Unity script
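The Unity C# script (`WebcamModelController.cs`) is expected at `/mnt/data/WebcamModelController.cs`. This notebook does not generate it, so the cell below only reports whether it is there. Import the script into Unity's `Assets/` folder. Set its `classes` array to the trained class names, in the same order as the `classes` list stored in the checkpoint (`gtsrb_model.pth`); that order matches the model's output indices.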

In [None]:
from pathlib import Path

# The notebook never writes this file, so verify it exists before telling the
# reader it was saved.
UNITY_SCRIPT = r'/mnt/data/WebcamModelController.cs'
if Path(UNITY_SCRIPT).is_file():
    print('Unity script saved to:', UNITY_SCRIPT)
else:
    print('Unity script not found at:', UNITY_SCRIPT,
          '- place WebcamModelController.cs there or update UNITY_SCRIPT.')