# Final Model — Multi-Head DeiT-Base (Training + Evaluation)

This notebook contains the code used to train and evaluate our **final model**:
**DeiT-Base multi-head** (one classification head per behavior attribute).

**Assumptions (same as the original run):**
- Dataset path in Drive: `MyDrive/dataset_project`
- A labels CSV exists: `labels.csv`
- `labels.csv` includes at least: `filename`, `split`
- Label columns (as used in the code): `Gaze`, `Headphones`, `Environment`, `Privacy`, `ObjectInHand`
- Split folders exist under `dataset_project`: `train/`, `validation/`, `test/`

**Main outputs:**
- Model checkpoints: `models/train_last.pth`, `models/val_best.pth`
- Test metrics CSV: `models/test_results.csv`
- Figures (optional): saved under `plots/` and `baseline_results/`

## 0. Setup

Install dependencies, mount Drive, and define:
- Paths
- Label maps
- Dataset class
- Multi-head DeiT model

In [None]:
# Colab setup: install timm (provides the DeiT backbone created below) and mount
# Google Drive, which holds the dataset and receives checkpoints/plots.
!pip -q install timm
from google.colab import drive
drive.mount('/content/drive')

import torch, torch.nn as nn
import timm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
from pathlib import Path
from tqdm import tqdm

# Project root on Google Drive; every other path is derived from it.
DATASET_DIR = Path("/content/drive/MyDrive/dataset_project")
CSV_PATH = DATASET_DIR / "labels.csv"

# One image folder per split (selected from the CSV's `split` column).
TRAIN_DIR, VAL_DIR, TEST_DIR = (DATASET_DIR / sub for sub in ("train", "validation", "test"))

# Checkpoints and result CSVs are written here; create it up front.
MODELS_DIR = DATASET_DIR / "models"
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Normalised CSV label (see clean_label) -> class index, per task. Several
# spellings map to the same index so that "_" vs " " separators and
# singular/plural variants are all accepted; the first spelling listed for an
# index is the canonical one.
LABEL_MAPS = {
    "gaze": {"camera":0, "not_camera":1, "not camera":1, "eyes_closed":2, "eyes close":2,
             "eyes closed":2},
    "headphones": {"with_headphones":0,"with headphones":0,"with_headphone":0,"with headphone":0,
                   "without_headphones":1,"without headphones":1,"without_headphone":1,
                   "without headphone":1,
                   "unknown":2},
    "environment": {"indoor":0, "outdoor":1},
    "privacy": {"private":0, "public":1},
    "object": {"cup":0, "phone":1, "pen":2, "none":3, "other":4, "unknown":5}
}
# Output size of each classification head.
NUM_CLASSES = {"gaze":3,"headphones":3,"environment":2,"privacy":2,"object":6}

# Task name -> column name in labels.csv.
CSV_COLS = {"gaze":"Gaze","headphones":"Headphones","environment":"Environment","privacy":"Privacy","object":"ObjectInHand"}

# Fail fast if a label map and its head size drift apart: every index in
# 0..NUM_CLASSES[task]-1 must be reachable and none may exceed the head size.
_inconsistent_tasks = [
    task for task, mapping in LABEL_MAPS.items()
    if set(mapping.values()) != set(range(NUM_CLASSES[task]))
]
if _inconsistent_tasks:
    raise ValueError(f"LABEL_MAPS and NUM_CLASSES disagree for: {_inconsistent_tasks}")

def clean_label(v):
    """Normalise a raw CSV label cell for lookup in LABEL_MAPS.

    Missing values (anything ``pd.isna`` reports as missing, e.g. None/NaN)
    become the empty string; any other value is converted with ``str``,
    stripped of surrounding whitespace and lower-cased.
    """
    return "" if pd.isna(v) else str(v).strip().lower()

def resolve_image_path(split_value, filename):
    """Locate the image file for one CSV row.

    The split value selects the folder: "train" -> TRAIN_DIR,
    "val"/"validation" -> VAL_DIR, "test" -> TEST_DIR, anything else ->
    DATASET_DIR. Any directory part of ``filename`` (forward or back slashes)
    is discarded; only the base name is looked up. If that exact name is not a
    file, the same stem is tried with .jpg, .jpeg, .png and .webp.

    Returns:
        The Path of the first matching regular file, or None if none exists.
    """
    split = str(split_value).strip().lower()
    if split == "train":
        folder = TRAIN_DIR
    elif split in ("val", "validation"):
        folder = VAL_DIR
    elif split == "test":
        folder = TEST_DIR
    else:
        folder = DATASET_DIR

    fname = str(filename).replace("\\", "/").strip().split("/")[-1]
    if not fname:
        # e.g. a filename ending in "/": there is no base name to look up, and
        # `folder / ""` would just be the split folder itself.
        return None

    # is_file() rather than exists(): the result is later passed to
    # Image.open(), so a directory (e.g. from "..") must never be returned.
    candidate = folder / fname
    if candidate.is_file():
        return candidate

    stem = Path(fname).stem
    for ext in (".jpg", ".jpeg", ".png", ".webp"):
        candidate = folder / f"{stem}{ext}"
        if candidate.is_file():
            return candidate
    return None

class MultiTaskDataset(Dataset):
    """Dataset of ``(image, targets)`` pairs built from a labels DataFrame.

    Rows whose image file cannot be found via ``resolve_image_path`` are
    dropped at construction time; the number dropped is printed so missing
    files are visible instead of silently shrinking the split.

    Args:
        df: DataFrame with ``split`` and ``filename`` columns plus the label
            columns named in ``CSV_COLS``.
        transform: Optional callable applied to the RGB PIL image.

    Each item is ``(img, targets)`` where ``targets`` maps every task in
    ``CSV_COLS`` to a 0-d ``torch.long`` class-index tensor. A missing or
    unrecognised label becomes the task's ``"unknown"`` class when its label
    map has one, otherwise the fallback in ``DEFAULT_LABELS``.
    """

    # Fallback class for tasks whose label map has no "unknown" entry.
    DEFAULT_LABELS = {"gaze": "camera", "environment": "indoor", "privacy": "private"}

    def __init__(self, df, transform=None):
        self.transform = transform
        keep, paths = [], []
        for pos, (_, r) in enumerate(df.iterrows()):
            p = resolve_image_path(r["split"], r["filename"])
            if p is not None:
                keep.append(pos)
                paths.append(p)

        n_skipped = len(df) - len(keep)
        if n_skipped:
            print(f"MultiTaskDataset: skipped {n_skipped}/{len(df)} rows with no image file found")

        # Positional selection keeps the original columns and dtypes, even when
        # no row survives (rebuilding from iterrows() Series would not).
        self.df = df.iloc[keep].reset_index(drop=True)
        self.paths = paths

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # The context manager closes the file; convert() returns an independent copy.
        with Image.open(self.paths[idx]) as im:
            img = im.convert("RGB")
        if self.transform:
            img = self.transform(img)

        row = self.df.iloc[idx]
        targets = {}
        for task, col in CSV_COLS.items():
            mapping = LABEL_MAPS[task]
            val = clean_label(row[col])
            if val not in mapping:
                val = "unknown" if "unknown" in mapping else self.DEFAULT_LABELS[task]
            targets[task] = torch.tensor(mapping[val], dtype=torch.long)
        return img, targets

class MultiHeadDeiT(nn.Module):
    """Shared timm backbone with one linear classification head per task.

    Args:
        backbone_name: timm model name (e.g. "deit_base_patch16_224").
        num_classes: mapping task name -> number of output classes.
        pretrained: whether timm loads pretrained backbone weights.

    ``forward`` returns a dict mapping each task name to its logits.
    """

    def __init__(self, backbone_name, num_classes, pretrained=True):
        super().__init__()
        # num_classes=0 asks timm for the backbone without its own classifier,
        # so the forward pass yields a feature vector of size num_features.
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0)
        feat_dim = self.backbone.num_features
        task_heads = {}
        for task_name, n_out in num_classes.items():
            task_heads[task_name] = nn.Linear(feat_dim, n_out)
        self.heads = nn.ModuleDict(task_heads)

    def forward(self, x):
        shared = self.backbone(x)
        return {name: layer(shared) for name, layer in self.heads.items()}

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## 1. Load CSV + Build DataLoaders

- Read `labels.csv`
- Split by `split` column (train / val / test)
- Create transforms
- Build DataLoaders

In [None]:
df = pd.read_csv(CSV_PATH)
# Strip stray whitespace from header names so lookups like df["split"] work.
df.columns = [c.strip() for c in df.columns]

# Normalise the split column once and reuse it for all three masks.
split_key = df["split"].astype(str).str.strip().str.lower()
df_train = df[split_key == "train"].reset_index(drop=True)
df_val = df[split_key.isin(["val", "validation"])].reset_index(drop=True)
df_test = df[split_key == "test"].reset_index(drop=True)

# ImageNet normalisation statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training adds a random horizontal flip; evaluation is deterministic.
train_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
eval_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_ds = MultiTaskDataset(df_train, transform=train_tf)
val_ds = MultiTaskDataset(df_val, transform=eval_tf)
test_ds = MultiTaskDataset(df_test, transform=eval_tf)

# Shared loader settings; only the training loader shuffles.
loader_opts = dict(batch_size=32, num_workers=2, pin_memory=True)
train_dl = DataLoader(train_ds, shuffle=True, **loader_opts)
val_dl = DataLoader(val_ds, shuffle=False, **loader_opts)
test_dl = DataLoader(test_ds, shuffle=False, **loader_opts)

## 2. Train DeiT Multi-Head (15 epochs)

- Pretrained DeiT backbone
- Multi-head classification (5 tasks)
- Loss = sum of CrossEntropy across heads
- Saves last checkpoint: `models/train_last.pth`

In [None]:
# Pretrained DeiT-Base backbone with one linear head per task (sizes from NUM_CLASSES).
model = MultiHeadDeiT("deit_base_patch16_224", NUM_CLASSES, pretrained=True).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.05)

# Mixed precision only when a GPU is available; with enabled=False the scaler
# and autocast are pass-throughs, so the same loop also runs on CPU.
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

EPOCHS = 15
TRAIN_SAVE = MODELS_DIR / "train_last.pth"

model.train()
for epoch in range(1, EPOCHS+1):
    total = 0.0
    pbar = tqdm(train_dl, desc=f"Train {epoch}/{EPOCHS}")
    # Count batches explicitly: tqdm only refreshes pbar.n at display intervals
    # and never includes the batch being processed, so dividing by it gave an
    # inflated (off-by-one or stale) running mean.
    for n_batches, (imgs, targets) in enumerate(pbar, start=1):
        imgs = imgs.to(device, non_blocking=True)
        targets = {k: v.to(device, non_blocking=True) for k, v in targets.items()}

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            out = model(imgs)
            # Multi-task objective: unweighted sum of the per-head cross-entropies.
            loss = sum(criterion(out[t], targets[t]) for t in out)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total += loss.item()
        pbar.set_postfix(loss=total / n_batches)

    # Overwritten every epoch, so the file always holds the latest weights.
    torch.save(model.state_dict(), TRAIN_SAVE)
print("Saved:", TRAIN_SAVE)

## 3. Validation pass + Save best checkpoint

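Evaluates the validation loss and saves the weights to `models/val_best.pth`. Note: validation runs once, after all training epochs, and the best-loss tracker is reset each time this cell runs, so the saved "best" checkpoint is always the final-epoch model.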

In [None]:
VAL_BEST = MODELS_DIR / "val_best.pth"
# NOTE: the tracker restarts at +inf every time this cell runs, and validation
# happens once, after all training epochs, so the checkpoint saved below is
# always the current (final-epoch) weights. Real per-epoch model selection
# would require running this pass inside the training loop.
val_best_loss = float("inf")

model.eval()
val_loss = 0.0
with torch.no_grad():
    pbar = tqdm(val_dl, desc="Validation")
    # Explicit batch count: tqdm's pbar.n lags behind the batch being processed.
    for n_batches, (imgs, targets) in enumerate(pbar, start=1):
        imgs = imgs.to(device, non_blocking=True)
        targets = {k: v.to(device, non_blocking=True) for k, v in targets.items()}
        with torch.cuda.amp.autocast(enabled=use_amp):
            out = model(imgs)
            loss = sum(criterion(out[t], targets[t]) for t in out)
        val_loss += loss.item()
        pbar.set_postfix(loss=val_loss / n_batches)

# Mean of the per-batch losses (a smaller last batch is weighted like the others).
avg_val = val_loss / max(1, len(val_dl))
if avg_val < val_best_loss:
    val_best_loss = avg_val
    torch.save(model.state_dict(), VAL_BEST)

print("avg_val_loss:", avg_val)
print("best_path:", VAL_BEST, "best_loss:", val_best_loss)

## 4. Test evaluation (Accuracy per task)

Loads `val_best.pth` (if it exists) and prints the per-task test accuracy.

In [None]:
# Load the validation "best" checkpoint if it exists; otherwise evaluate the
# weights currently in memory.
BEST = MODELS_DIR / "val_best.pth"
if BEST.exists():
    state = torch.load(BEST, map_location=device)
    model.load_state_dict(state)

model.eval()

# Per-task running counts of correct predictions and of evaluated samples.
correct = dict.fromkeys(NUM_CLASSES, 0)
total = dict.fromkeys(NUM_CLASSES, 0)

with torch.no_grad():
    pbar = tqdm(test_dl, desc="Test")
    for imgs, targets in pbar:
        imgs = imgs.to(device, non_blocking=True)
        targets = {name: tensor.to(device, non_blocking=True) for name, tensor in targets.items()}
        logits = model(imgs)

        for name, task_logits in logits.items():
            gold = targets[name]
            hits = task_logits.argmax(dim=1).eq(gold)
            correct[name] += hits.sum().item()
            total[name] += gold.numel()

for name in NUM_CLASSES:
    acc = correct[name] / max(1, total[name])
    print(name, "acc:", acc)

### Export test results to CSV

Saves: `models/test_results.csv`

In [None]:
# One row per task with its accuracy and the raw counts behind it.
results_path = MODELS_DIR / "test_results.csv"
rows = [
    {
        "task": task,
        "accuracy": correct[task] / max(1, total[task]),
        "correct": correct[task],
        "total": total[task],
    }
    for task in NUM_CLASSES
]

pd.DataFrame(rows).to_csv(results_path, index=False)
print("Saved:", results_path)

## 5. Macro-F1 + Overall metrics (Mean Accuracy + Joint Accuracy)

Collect predictions on the test set and compute:
- Accuracy per task
- Macro-F1 per task
- Mean Accuracy (avg over tasks)
- Joint Accuracy (all 5 tasks correct)

In [None]:
!pip -q install scikit-learn

import numpy as np
from sklearn.metrics import f1_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import torch

TASKS = ["gaze", "headphones", "environment", "privacy", "object"]

model.eval()

y_true = {t: [] for t in TASKS}
y_pred = {t: [] for t in TASKS}

with torch.no_grad():
    for imgs, targets in test_dl:
        imgs = imgs.to(device, non_blocking=True)
        targets = {k: v.to(device, non_blocking=True) for k, v in targets.items()}

        out = model(imgs)
        for t in TASKS:
            pred = out[t].argmax(dim=1)
            y_true[t].append(targets[t].cpu().numpy())
            y_pred[t].append(pred.cpu().numpy())

for t in TASKS:
    y_true[t] = np.concatenate(y_true[t])
    y_pred[t] = np.concatenate(y_pred[t])

print("Done collecting predictions.")

In [None]:
# Per-task accuracy and macro-averaged F1 on the collected test predictions.
acc = {t: float((y_pred[t] == y_true[t]).mean()) for t in TASKS}
f1m = {t: float(f1_score(y_true[t], y_pred[t], average="macro")) for t in TASKS}

# Joint accuracy: an image counts as correct only if all 5 tasks are correct.
all_correct = np.logical_and.reduce([y_pred[t] == y_true[t] for t in TASKS])
joint_accuracy = float(all_correct.mean())

# Mean of the per-task accuracies (every task weighted equally).
avg_accuracy = float(np.mean(list(acc.values())))

for title, scores in (("=== Accuracy per task ===", acc), ("\n=== Macro F1 per task ===", f1m)):
    print(title)
    for t in TASKS:
        print(f"{t}: {scores[t]:.4f}")

print("\n=== Overall ===")
print("Average accuracy (mean over tasks):", f"{avg_accuracy:.4f}")
print("Joint accuracy (all tasks correct):", f"{joint_accuracy:.4f}")

## 6. Confusion matrix (default: Object-in-hand)

Change `TASK_CM` if you want a different task confusion matrix.

In [None]:
TASK_CM = "object"  # Can be changed to: "gaze" / "headphones" / "environment" / "privacy" / "object"

# LABEL_MAPS lists several spellings per class (e.g. "not_camera"/"not camera").
# Keep the FIRST spelling listed for each index as the display name; a plain
# {v: k} inversion would keep the last alias listed instead.
idx_to_label = {}
for label_name, label_idx in LABEL_MAPS[TASK_CM].items():
    idx_to_label.setdefault(label_idx, label_name)
n_cls = NUM_CLASSES[TASK_CM]
labels_order = [idx_to_label[i] for i in range(n_cls)]

# Fix the label set so classes absent from this test split still get a row/column.
cm = confusion_matrix(y_true[TASK_CM], y_pred[TASK_CM], labels=list(range(n_cls)))

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels_order)
# Create the axes ourselves: disp.plot() without ax= opens its own default-size
# figure, which left a separate plt.figure(figsize=(7, 7)) empty.
fig, ax = plt.subplots(figsize=(7, 7))
disp.plot(ax=ax, values_format="d", xticks_rotation=45)
ax.set_title(f"Confusion Matrix - {TASK_CM}")
plt.show()

## 7. Export plots used in the presentation

⚠️ This section uses hard-coded results (exactly as in the original run),
so the exported plots match the reported numbers. Note that it overwrites the `acc` dictionary computed in Section 5.

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Where to save the figures (on Drive).
OUT_DIR = Path("/content/drive/MyDrive/dataset_project/plots")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Reported results of the final run, hard-coded so the figures match them.
_REPORTED_TASKS = ("gaze", "headphones", "environment", "privacy", "object")
acc = dict(zip(_REPORTED_TASKS, (0.5575, 0.5977, 0.8908, 0.9655, 0.6897)))
f1 = dict(zip(_REPORTED_TASKS, (0.4357, 0.5699, 0.4711, 0.4912, 0.3442)))

overall_mean_acc = 0.7402
joint_acc = 0.2471

print("✅ OUT_DIR:", OUT_DIR)

In [None]:
# Grouped bar chart: accuracy and macro-F1 side by side for each task.
task_names = list(acc)
centers = np.arange(len(task_names))
bar_w = 0.38

plt.figure(figsize=(10, 5))
plt.bar(centers - bar_w / 2, [acc[t] for t in task_names], width=bar_w, label="Accuracy", color="#1f77b4")
plt.bar(centers + bar_w / 2, [f1[t] for t in task_names], width=bar_w, label="Macro F1", color="#6baed6")

plt.ylim(0, 1.05)
plt.xticks(centers, task_names, rotation=0)
plt.ylabel("Score")
plt.title("Per-Task Performance (Accuracy & Macro F1)")
plt.grid(axis="y", alpha=0.25)
plt.legend()
plt.tight_layout()

path = OUT_DIR / "per_task_accuracy_f1.png"
plt.savefig(path, dpi=300)
plt.show()

print("✅ Saved:", path)

In [None]:
# Two-bar summary: mean per-task accuracy vs. joint (all-tasks-correct) accuracy.
summary = {
    "Mean Accuracy\n(mean over tasks)": overall_mean_acc,
    "Joint Accuracy\n(all tasks correct)": joint_acc,
}

plt.figure(figsize=(7,5))
plt.bar(list(summary), list(summary.values()), color=["#1f77b4", "#3182bd"])
plt.ylim(0, 1.05)
plt.ylabel("Score")
plt.title("Overall Performance")
plt.grid(axis="y", alpha=0.25)

# Annotate each bar with its value just above the bar top.
for pos, score in enumerate(summary.values()):
    plt.text(pos, score + 0.02, f"{score:.3f}", ha="center", fontweight="bold")

path = OUT_DIR / "overall_mean_vs_joint.png"
plt.tight_layout()
plt.savefig(path, dpi=300)
plt.show()

print("✅ Saved:", path)

## 8. (Optional) Baseline comparison figure

Compares baseline train/val behavior to our final DeiT test average accuracy line
(uses `baseline_results/baseline_metrics.csv`).

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

BASELINE_DIR = Path("/content/drive/MyDrive/dataset_project/baseline_results")
METRICS_PATH = BASELINE_DIR / "baseline_metrics.csv"

# From our final model's reported test results:
NEW_AVG_ACC = 0.7402  # Average accuracy (mean over tasks)

# Per-epoch baseline log; expects Epoch, Train_Loss and Val_Avg_Accuracy columns.
dfm = pd.read_csv(METRICS_PATH)

epochs = dfm["Epoch"]
train_loss = dfm["Train_Loss"]
val_avg = dfm["Val_Avg_Accuracy"]

plt.figure(figsize=(10,5))

# Left y-axis: baseline training loss.
ax1 = plt.gca()
ax1.plot(epochs, train_loss, marker="o", label="Baseline Train Loss", color=plt.cm.Blues(0.75))
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Train Loss")
ax1.grid(alpha=0.25)

# Right y-axis (shared epochs): baseline validation accuracy plus our final
# test accuracy as a horizontal reference line. The legend text is formatted
# from NEW_AVG_ACC so it can never disagree with the plotted line.
ax2 = ax1.twinx()
ax2.plot(epochs, val_avg, marker="s", label="Baseline Val Avg Accuracy", color=plt.cm.Blues(0.45))
ax2.axhline(NEW_AVG_ACC, linestyle="--", label=f"Multi-Head DeiT Test Avg Acc ({NEW_AVG_ACC:.4f})", color=plt.cm.Blues(0.9))
ax2.set_ylabel("Validation Avg Accuracy")

# Each twin axis keeps its own legend entries; merge them into a single legend.
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
plt.legend(lines1 + lines2, labels1 + labels2, loc="lower right")

plt.title("Baseline Overfitting: Train Loss vs Validation Accuracy (+ Our Final Test Avg Acc)")
out_path = BASELINE_DIR / "baseline_overfitting_plus_ours.png"
plt.tight_layout()
plt.savefig(out_path, dpi=200)
plt.show()

print("Saved:", out_path)