# Fruit Freshness Classification
Important Notes (Please Read Before Using This Notebook):

- This notebook was originally developed and tested in Google Colab using files stored in a personal Google Drive.

- The version published on GitHub is provided for code reference only.

- Running this notebook as-is will not work, because:

    - Dataset files are not included in the repository (due to size and privacy).

    - All file paths point to my personal Drive directory, which other users will not have access to.

    - The notebook pushed from VS Code contains no runtime outputs, because it was not executed before uploading.

If you wish to run this notebook yourself, you will need to:

1. Provide your own dataset following the folder structure described in the report.

2. Update all paths to match your environment (local, Colab, or custom).

3. Mount your own Google Drive (if using Colab) or adjust file loading accordingly.

All results, charts, metrics, comparisons, and discussion can be found in our project report and presentation, which serve as the authoritative sources for model performance.

In [None]:
# pip install for google colab
!pip install -q ultralytics
!pip install timm -q
!pip install opencv-python pillow-heif ultralytics --quiet

In [None]:
import os, sys, zipfile, math, random, copy, shutil, time, datetime, csv, gc, json, glob, pillow_heif
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch, Rectangle
%matplotlib inline
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm, trange
from pathlib import Path
from sklearn.metrics import (
    confusion_matrix,
    ConfusionMatrixDisplay,
    accuracy_score,
    precision_recall_fscore_support,
    roc_curve,
    roc_auc_score,
    classification_report,
)

# torch and nn 
import torch
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset, random_split
from torchvision import transforms, utils, models, datasets
import torch.nn as nn

from ultralytics import YOLO
import timm

import cv2

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

In [None]:
# Run only when using Google Colab: mounts Google Drive at /content/drive,
# which is the prefix of the Drive paths used in this notebook.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Fruit types grouped by data source (kept for reference).
fruit_classes = dict(
    kaggle=('apples', 'bananas', 'oranges'),  # 3 types
    custom=(
        'avocado', 'grapes', 'lemon', 'mango',
        'pineapple', 'strawberry', 'watermelon',
    ),  # 7 types
    ood=('peach', 'tomato'),  # out of distribution (won't be seen by model)
)

In [None]:
# paths
# modify if needed
dataset_root = "/content/project/fruits_binary"
ood_test_root = '/content/project/ood_test'
test_root = os.path.join(dataset_root, "test")
base_logs_root = Path("/content/drive/MyDrive/CSC2503")

# Per-model log directories, each named "<model>_training_logs" under base_logs_root,
# e.g. /content/drive/MyDrive/CSC2503/yolo_training_logs
yolo_logs_root, swin_logs_root, repvgg_logs_root = (
    base_logs_root / f"{model_key}_training_logs"
    for model_key in ("yolo", "swin", "repvgg")
)

# Training 
**(Skip this section if the models have already been trained, and proceed to the next section, where the trained weights are loaded.)**

In [None]:
# Extract the dataset archives from Google Drive into /content/project.
# Modify the archive paths as needed for your own Drive layout.

# unzip the complete dataset
!unzip /content/drive/MyDrive/CSC2503/fruits_binary.zip -d /content/project
# unzip the ood (out-of-distribution) test set
!unzip /content/drive/MyDrive/CSC2503/ood_test.zip -d /content/project

# useful bash commands for checking files/dirs (uncomment to run)
# !ls /content/drive/MyDrive/CSC2503
# !ls -F /content/project/fruits_binary
# !ls -F /content/project/ood_test

### YOLO

In [None]:
yolo_logs_root.mkdir(parents=True, exist_ok=True)

# set this True if you want to force retrain even if logs exist
force_retrain = False

# --- 1) Look for previously trained weights in the Drive logs ---
# Step 3 below mirrors every run to yolo_logs_root/<run name>/, so the weights end
# up at yolo_logs_root/train*/weights/best.pt. A flat yolo_logs_root/weights/best.pt
# layout is still accepted for logs that were arranged by hand.
# Candidates are ordered by modification time so the last entry is the newest
# (a plain name sort would place "train10" before "train2").
best_candidates = sorted(
    {
        *yolo_logs_root.glob("weights/best.pt"),
        *yolo_logs_root.glob("*/weights/best.pt"),
    },
    key=lambda p: p.stat().st_mtime,
)

if best_candidates and not force_retrain:
    # pick the most recently written checkpoint
    best_weight = best_candidates[-1]
    print(f"Found existing trained model: {best_weight}")
    print("Loading model without retraining...")
    yolo_model = YOLO(str(best_weight))
    print("model loaded! call 'yolo_model' to access the model.")

else:
    print("No existing model found in logs OR retrain requested.")
    print("Training a new YOLO classification model...")

    # 2) Train from pretrained YOLO11m classification weights
    yolo_model = YOLO("yolo11m-cls.pt")  # or "yolo11s-cls.pt", etc.

    results = yolo_model.train(
        data=dataset_root,      # root with train/val(/test)
        epochs=15,
        imgsz=224,
        batch=16,
        # first GPU when CUDA is available (Colab GPU runtime), CPU otherwise
        device=0 if torch.cuda.is_available() else "cpu",
        # classes=2,            # YOLO infers from folders ['bad', 'good']
        auto_augment="randaugment",
        fliplr=0.5,
        degrees=15,
        scale=0.5,
        # NOTE(review): the stated intent is "freeze all but first layer + head".
        # Confirm against the Ultralytics `freeze` documentation that a `range`
        # object is honoured and that it freezes exactly the intended layers.
        freeze=range(1, 9999),
    )
    print("YOLO model trained; you can access the model by calling 'yolo_model'.")

    # 3) Mirror YOLO runs from /content/runs/classify/train* → Drive logs
    runs_root = Path("runs") / "classify"   # usually /content/runs/classify
    if runs_root.exists():
        train_runs = list(runs_root.glob("train*"))
        if train_runs:
            for run in train_runs:
                dest = yolo_logs_root / run.name
                print(f"Copying {run} → {dest}")
                shutil.copytree(run, dest, dirs_exist_ok=True)
        else:
            print("WARNING: no train* runs found under runs/classify.")
    else:
        print("WARNING: runs/classify folder not found; nothing copied to yolo_logs_root.")

### RepVGG and Swin Transformer  

In [None]:
# datasets & loaders
img_size = 224
batch_size = 32


def _make_transform(augment):
    """Build the preprocessing pipeline; `augment` adds random flip and rotation."""
    steps = [transforms.Resize((img_size, img_size))]
    if augment:
        steps += [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
        ]
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


# Training images are augmented; validation images are only resized and normalised.
train_transform = _make_transform(augment=True)
val_transform = _make_transform(augment=False)

train_dir = os.path.join(dataset_root, "train")
val_dir = os.path.join(dataset_root, "val")

train_ds = datasets.ImageFolder(train_dir, transform=train_transform)
val_ds = datasets.ImageFolder(val_dir, transform=val_transform)

# Options common to both loaders; only the training loader shuffles.
_loader_opts = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_opts)

print("Classes:", train_ds.classes, "num_classes:", len(train_ds.classes))


In [None]:
num_classes = 2  # size of the classifier head used by the model factories

# helpers
def accuracy(output, target, topk=(1,)):
    """
    Compute top-k accuracy for each k in `topk`.

    `output` is indexed as [sample, class] and `target` holds one class index per
    sample. Every k is clamped to the number of classes in `output`, so asking
    for top-5 on a 2-class head does not fail.

    Returns:
        list of floats (one per requested k), each the fraction of samples whose
        target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        n_cls = output.size(1)
        n_samples = target.size(0)
        widest = min(max(topk), n_cls)  # never request more than n_cls ranks

        # ranked[r, s] is the class placed at rank r for sample s
        ranked = output.topk(widest, dim=1, largest=True, sorted=True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            (hits[:min(k, n_cls)].reshape(-1).float().sum(0, keepdim=True) / n_samples).item()
            for k in topk
        ]

# train / eval loops
def train_one_epoch(model, loader, criterion, optimizer, device):
    """
    Run one optimisation pass over `loader`.

    Returns:
        (loss, acc1, acc5): sample-weighted means over the epoch of the criterion
        value and of the top-1 / top-5 accuracies reported by `accuracy`.
    """
    model.train()
    loss_sum = acc1_sum = acc5_sum = 0.0
    seen = 0

    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        n = labels.size(0)
        top1, top5 = accuracy(logits, labels, topk=(1, 5))

        # weight every batch statistic by its size so a short last batch is not over-counted
        loss_sum += batch_loss.item() * n
        acc1_sum += top1 * n
        acc5_sum += top5 * n
        seen += n

    return loss_sum / seen, acc1_sum / seen, acc5_sum / seen


def evaluate(model, loader, criterion, device):
    """
    Score `model` on `loader` in eval mode with gradients disabled.

    Returns:
        (loss, acc1, acc5): sample-weighted means of the criterion value and of
        the top-1 / top-5 accuracies reported by `accuracy`.
    """
    model.eval()
    totals = {"loss": 0.0, "acc1": 0.0, "acc5": 0.0}
    seen = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = model(images)
            batch_loss = criterion(logits, labels)

            n = labels.size(0)
            top1, top5 = accuracy(logits, labels, topk=(1, 5))

            # weight by batch size so the result is a true per-sample mean
            totals["loss"] += batch_loss.item() * n
            totals["acc1"] += top1 * n
            totals["acc5"] += top5 * n
            seen += n

    return totals["loss"] / seen, totals["acc1"] / seen, totals["acc5"] / seen


# model factory functions
def create_swin_model(num_classes):
    """Return a torchvision Swin-T (IMAGENET1K_V1 weights) with a new `num_classes`-way head."""
    backbone = models.swin_t(weights=models.Swin_T_Weights.IMAGENET1K_V1)
    # swap the pretrained head for a fresh linear layer of the requested width
    head_in = backbone.head.in_features
    backbone.head = nn.Linear(head_in, num_classes)
    return backbone


def create_repvgg_model(num_classes):
    """Return a pretrained timm `repvgg_b0` configured for `num_classes` outputs."""
    return timm.create_model("repvgg_b0", pretrained=True, num_classes=num_classes)


# generic train-or-load wrapper
def train_or_load_model(model_name, create_model_fn, num_epochs=15, lr=1e-4, wd=1e-4, force_retrain=False):
    """
    Return a trained classifier, reusing cached weights when they exist.

    Reads the module-level `base_logs_root`, `num_classes`, `device`,
    `train_loader` and `val_loader`.

    Args:
        model_name: "swin" or "repvgg"; weights and metrics are stored in
            base_logs_root / "{model_name}_training_logs" as best.pth and metrics.json.
        create_model_fn: function taking `num_classes` and returning a fresh model.
        num_epochs: number of training epochs (must be >= 1 when training happens).
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        force_retrain: train even if best.pth already exists.

    Returns:
        (model, history): the model on `device` in eval mode with the best
        validation-accuracy weights loaded, and a dict of per-epoch metric lists.
        `history` is None when weights were loaded but no metrics.json exists.

    Raises:
        ValueError: if training is required and `num_epochs` is smaller than 1.
    """
    logs_root = base_logs_root / f"{model_name}_training_logs"
    logs_root.mkdir(parents=True, exist_ok=True)
    weight_path  = logs_root / "best.pth"
    metrics_path = logs_root / "metrics.json"

    # If weights already exist and we don't want to retrain, just load and return model
    if weight_path.exists() and not force_retrain:
        print(f"[{model_name}] Found existing weights at {weight_path}. Loading model...")
        model = create_model_fn(num_classes)
        state = torch.load(weight_path, map_location=device)
        model.load_state_dict(state)
        model.to(device)
        model.eval()
        # Optionally load metrics as well
        if metrics_path.exists():
            with open(metrics_path, "r") as f:
                history = json.load(f)
            print(f"[{model_name}] Loaded existing metrics from {metrics_path}.")
        else:
            history = None
        return model, history

    # Without at least one epoch nothing would be saved, and the reload at the end
    # would either fail or silently pick up a stale checkpoint from an earlier run.
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1 to train, got {num_epochs}")

    # Otherwise: train from scratch / pretrained backbone
    print(f"[{model_name}] No existing weights found or retrain forced. Training for {num_epochs} epochs...")
    model = create_model_fn(num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    # Start below any reachable accuracy so the first epoch always writes a checkpoint;
    # this guarantees best.pth belongs to THIS run even if val_acc1 stays at 0.
    best_val_acc1 = float("-inf")
    history = {
        "epoch": [],
        "train_loss": [],
        "train_acc1": [],
        "train_acc5": [],
        "val_loss": [],
        "val_acc1": [],
        "val_acc5": [],
    }

    for epoch in tqdm(range(1, num_epochs + 1), desc=f"{model_name} epochs"):
        start_time = time.time()
        train_loss, train_acc1, train_acc5 = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc1, val_acc5       = evaluate(model, val_loader, criterion, device)
        elapsed = time.time() - start_time

        print(
            f"[{model_name}] Epoch {epoch:02d}/{num_epochs} "
            f"- {elapsed:.1f}s "
            f"- train_loss: {train_loss:.4f}, train_acc1: {train_acc1:.4f}, train_acc5: {train_acc5:.4f} "
            f"- val_loss: {val_loss:.4f}, val_acc1: {val_acc1:.4f}, val_acc5: {val_acc5:.4f}"
        )

        # log history
        history["epoch"].append(epoch)
        history["train_loss"].append(train_loss)
        history["train_acc1"].append(train_acc1)
        history["train_acc5"].append(train_acc5)
        history["val_loss"].append(val_loss)
        history["val_acc1"].append(val_acc1)
        history["val_acc5"].append(val_acc5)

        # save best weights
        if val_acc1 > best_val_acc1:
            best_val_acc1 = val_acc1
            torch.save(model.state_dict(), weight_path)
            print(f"[{model_name}] New best val_acc1: {best_val_acc1:.4f}. Saved weights to {weight_path}")

    # save metrics to JSON for later analysis
    with open(metrics_path, "w") as f:
        json.dump(history, f, indent=2)
    print(f"[{model_name}] Training complete. Metrics saved to {metrics_path}")

    # load best weights back into model (in case last epoch wasn't best)
    state = torch.load(weight_path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()

    return model, history

In [None]:
# train or load SWIN
swin_model, swin_history = train_or_load_model(
    model_name="swin",
    create_model_fn=create_swin_model,
    num_epochs=15,
    lr=1e-4,
    wd=1e-4,
    force_retrain=False,    # set True if you want to retrain
)

# Train or load RepVGG
repvgg_model, repvgg_history = train_or_load_model(
    model_name="repvgg",
    create_model_fn=create_repvgg_model,
    num_epochs=15,
    lr=1e-4,
    wd=1e-4,
    force_retrain=False,    # set True if you want to retrain
)

# Load the trained model
**(RUN THIS SECTION)**
- After running this section, the models are available as the variables:
    - yolo_model
    - repvgg_model
    - swin_model

In [None]:
# yolo
# The training cell mirrors each run to yolo_logs_root/<run name>/, so trained weights
# live at yolo_logs_root/train*/weights/best.pt. A flat yolo_logs_root/weights/best.pt
# layout is also accepted. Candidates are ordered by modification time so the last
# entry is the newest one.
best_candidates = sorted(
    {
        *yolo_logs_root.glob("weights/best.pt"),
        *yolo_logs_root.glob("*/weights/best.pt"),
    },
    key=lambda p: p.stat().st_mtime,
)
if best_candidates:
    # pick the most recently written checkpoint
    best_weight = best_candidates[-1]
    print(f"Found existing trained model: {best_weight}")
    print("Loading model without retraining...")
    yolo_model = YOLO(str(best_weight))
    print("model loaded! call 'yolo_model' to access the model.")
else:
    print("No existing YOLO model found. Please run the training cells above to train the model.")

In [None]:
# helpers for repvgg and swin
num_classes = 2  # size of the classifier head expected by the saved weights

def create_swin_model(num_classes):
    """Return a torchvision Swin-T (IMAGENET1K_V1 weights) with a new `num_classes`-way head."""
    backbone = models.swin_t(weights=models.Swin_T_Weights.IMAGENET1K_V1)
    # swap the pretrained head for a fresh linear layer of the requested width
    head_in = backbone.head.in_features
    backbone.head = nn.Linear(head_in, num_classes)
    return backbone


def create_repvgg_model(num_classes):
    """Return a pretrained timm `repvgg_b0` configured for `num_classes` outputs."""
    return timm.create_model("repvgg_b0", pretrained=True, num_classes=num_classes)

def load_existing_model(model_name, create_model_fn):
    """
    Load saved weights (and metrics, if present) for `model_name`.

    Files are read from base_logs_root / "{model_name}_training_logs":
    best.pth for the state dict and metrics.json for the training history.

    Returns:
        (model, history): model on `device` in eval mode and the parsed metrics
        (None when metrics.json is missing); (None, None) when best.pth is missing.
    """
    log_dir = base_logs_root / f"{model_name}_training_logs"
    weight_path = log_dir / "best.pth"
    metrics_path = log_dir / "metrics.json"

    # guard clause: nothing to load
    if not weight_path.exists():
        print(f"[{model_name}] No existing weights found at {weight_path}. Please train the model first.")
        return None, None

    print(f"[{model_name}] Found existing weights at {weight_path}. Loading model...")
    model = create_model_fn(num_classes)
    model.load_state_dict(torch.load(weight_path, map_location=device))
    model.to(device)
    model.eval()

    # metrics are optional
    history = None
    if metrics_path.exists():
        with open(metrics_path, "r") as f:
            history = json.load(f)
        print(f"[{model_name}] Loaded existing metrics from {metrics_path}.")
    return model, history

In [None]:
# repvgg: load saved weights and metrics (both are None when no weights are found)
repvgg_model, repvgg_history = load_existing_model(
    model_name="repvgg",
    create_model_fn=create_repvgg_model,
)

In [None]:
# swin: load saved weights and metrics (both are None when no weights are found)
swin_model, swin_history = load_existing_model(
    model_name="swin",
    create_model_fn=create_swin_model,
)

# Testing with test set data 
**(Skip this section if you do not need to evaluate on the test sets)**
- 2 test sets will be used (complete test set & OOD test set)
    - complete test set: kaggle (3 types) + custom (7 types) + ood (2 types)
    - OOD (out-of-distribution) test set: ood (2 types) i.e., peach and tomato
- Calculate the classification metrics on these two sets

In [None]:
# test loaders & helpers

# Complete test set, preprocessed with the same transform as validation.
test_ds = datasets.ImageFolder(test_root, transform=val_transform)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)

print("Test classes:", test_ds.classes)

# OOD-only test set, same preprocessing and loader settings.
ood_ds = datasets.ImageFolder(root=ood_test_root, transform=val_transform)
ood_loader = DataLoader(ood_ds, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)

def evaluate_on_test(
    model,
    loader,
    device,
    class_names,
    pos_class_index=1,   # index of "positive" class, here: 'good' if classes=['bad','good']
):
    """
    Run model on test loader, compute confusion matrix + metrics, and plot.

    Every metric is computed over the full label set `range(len(class_names))`,
    so the confusion matrix and report keep their shape even when a class is
    missing from the evaluated split.

    Args:
        model: classifier returning logits indexed as [sample, class].
        loader: iterable of (images, labels) batches.
        device: device the images are moved to before the forward pass.
        class_names: class names, indexed by label id.
        pos_class_index: label id treated as the positive class for ROC/AUC.

    Returns:
        cm (ndarray): confusion matrix of shape (len(class_names), len(class_names)),
            rows = true class, columns = predicted class
        metrics (dict): accuracy, per-class and macro precision/recall/F1 and,
            for two classes, "roc_auc" (NaN when the labels contain a single class)

    Raises:
        ValueError: if `loader` yields no batches.
    """
    model.eval()
    logit_batches = []
    label_batches = []

    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))          # logits [B, C]
            logit_batches.append(outputs.cpu())
            label_batches.append(labels.cpu())

    if not logit_batches:
        raise ValueError("loader produced no batches; nothing to evaluate")

    logits = torch.cat(logit_batches, dim=0)
    probs_t = torch.softmax(logits, dim=1)
    all_preds = torch.argmax(probs_t, dim=1).numpy()
    all_probs = probs_t.numpy()
    all_labels = torch.cat(label_batches, dim=0).numpy()
    label_ids = list(range(len(class_names)))

    # -------------------
    # Confusion matrix
    # -------------------
    # Passing `labels` pins the matrix shape to the full class list.
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    disp = ConfusionMatrixDisplay(cm, display_labels=class_names)
    disp.plot(cmap="Blues", values_format="d")
    plt.title("Test Confusion Matrix")
    plt.show()

    # -------------------
    # Basic metrics
    # -------------------
    acc = accuracy_score(all_labels, all_preds)

    # per-class precision/recall/F1 and macro averages
    precision, recall, f1, support = precision_recall_fscore_support(
        all_labels, all_preds, labels=label_ids, zero_division=0
    )

    macro_precision = precision.mean()
    macro_recall    = recall.mean()
    macro_f1        = f1.mean()

    print("Classification report:")
    print(
        classification_report(
            all_labels,
            all_preds,
            labels=label_ids,
            target_names=class_names,
            zero_division=0,
        )
    )

    # -------------------
    # ROC & AUC (binary)
    # -------------------
    metrics = {
        "accuracy": acc,
        "per_class_precision": dict(zip(class_names, precision)),
        "per_class_recall": dict(zip(class_names, recall)),
        "per_class_f1": dict(zip(class_names, f1)),
        "macro_precision": macro_precision,
        "macro_recall": macro_recall,
        "macro_f1": macro_f1,
    }

    if len(class_names) == 2:
        # take probability of positive class (e.g., 'good' -> index 1)
        pos_probs = all_probs[:, pos_class_index]
        # Binarise against pos_class_index so the ROC curve and the AUC agree on
        # which class is positive, whichever index the caller picked.
        is_positive = (all_labels == pos_class_index).astype(int)

        if is_positive.min() == is_positive.max():
            # ROC/AUC are undefined when only one class is present in the labels.
            print("ROC/AUC skipped: the labels contain a single class.")
            metrics["roc_auc"] = float("nan")
        else:
            fpr, tpr, thresholds = roc_curve(is_positive, pos_probs)
            auc_value = roc_auc_score(is_positive, pos_probs)

            metrics["roc_auc"] = auc_value

            plt.figure()
            plt.plot(fpr, tpr, label=f"ROC curve (AUC = {auc_value:.3f})")
            plt.plot([0, 1], [0, 1], "k--", label="Random")
            plt.xlabel("False Positive Rate")
            plt.ylabel("True Positive Rate")
            plt.title("ROC Curve (test)")
            plt.legend(loc="lower right")
            plt.grid(True)
            plt.show()

    return cm, metrics


def compute_metrics_from_matrix(cm):
    """
    Derive per-class precision, recall and F1 from a square confusion matrix.

    For class i the diagonal entry is TP, the rest of column i is FP and the
    rest of row i is FN (i.e. rows are treated as the true class). A ratio
    whose denominator is zero is reported as 0. The results are also printed.

    Returns:
        (precisions, recalls, f1s): three lists with one value per class.
    """
    mat = np.array(cm, dtype=float)
    n_classes = mat.shape[0]

    def _ratio(numerator, denominator):
        # undefined ratios are reported as 0
        return numerator / denominator if denominator > 0 else 0

    precisions, recalls, f1s = [], [], []
    for idx in range(n_classes):
        tp = mat[idx, idx]
        fp = mat[:, idx].sum() - tp
        fn = mat[idx, :].sum() - tp

        p = _ratio(tp, tp + fp)
        r = _ratio(tp, tp + fn)
        precisions.append(p)
        recalls.append(r)
        f1s.append(_ratio(2 * p * r, p + r))

    # Print results
    for idx, (p, r, f) in enumerate(zip(precisions, recalls, f1s)):
        print(f"Class {idx}:")
        print(f"  Precision: {p:.4f}")
        print(f"  Recall:    {r:.4f}")
        print(f"  F1-score:  {f:.4f}")
        print()

    return precisions, recalls, f1s

In [None]:
# yolo
def _report_yolo_confusion(results, names):
    """
    Print the class-only confusion matrix of a YOLO validation run and its metrics.

    Ultralytics fills its confusion matrix as [predicted, true] and, depending on
    the library version, may append one extra "background" row/column. Only the
    real classes are kept and the matrix is transposed so that rows are the true
    class, which is the layout compute_metrics_from_matrix works with.

    Returns the (true x predicted) matrix restricted to `names`.
    """
    n = len(names)
    cm_true_rows = np.asarray(results.confusion_matrix.matrix)[:n, :n].T

    print("confusion matrix (rows = true, columns = predicted):")
    print(pd.DataFrame(cm_true_rows, index=names, columns=names))
    compute_metrics_from_matrix(cm_true_rows)
    return cm_true_rows


class_names = list(yolo_model.names.values())

print("yolo on test set:")
results_yolo_test = yolo_model.val(data=dataset_root, split="test")
cm_yolo_test = _report_yolo_confusion(results_yolo_test, class_names)

print('')

print('yolo on ood set only:')
results_yolo_ood = yolo_model.val(data=ood_test_root)
cm_yolo_ood = _report_yolo_confusion(results_yolo_ood, class_names)

In [None]:
# repvgg
# class_names comes from test_ds in both calls (['bad', 'good']); index 1 ('good')
# is the positive class.
print('repvgg on test set:')
cm_repvgg, metrics_repvgg = evaluate_on_test(
    repvgg_model, test_loader, device, test_ds.classes, pos_class_index=1,
)

print('repvgg on ood set only:')
cm_repvgg_ood, metrics_repvgg_ood = evaluate_on_test(
    repvgg_model, ood_loader, device, test_ds.classes, pos_class_index=1,
)

In [None]:
# swin
# class_names comes from test_ds in both calls (['bad', 'good']); index 1 ('good')
# is the positive class.
print('swin on test set:')
cm_swin, metrics_swin = evaluate_on_test(
    swin_model, test_loader, device, test_ds.classes, pos_class_index=1,
)

print('swin on ood set only:')
cm_swin_ood, metrics_swin_ood = evaluate_on_test(
    swin_model, ood_loader, device, test_ds.classes, pos_class_index=1,
)

# Testing with homemade videos/photos
- iPhone files allowed (.HEIC, .MOV)

### detection model

In [None]:
# YOLO detection model used only to produce bounding boxes: annotate_file draws
# its boxes, and predict_on_video crops its most confident box for classification.
detect_model = YOLO("yolo11n.pt")

### draw bounding boxes 

In [None]:
def _draw_detections(frame, results, names, box_thickness, font_scale):
    """Draw every detected box and its "<label> <conf>" caption onto `frame` in place."""
    for r in results:
        for box in r.boxes:
            x1, y1, x2, y2 = (int(v) for v in box.xyxy.cpu().numpy().astype(int)[0])
            conf = float(box.conf.cpu())
            label = names[int(box.cls.cpu())]

            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), box_thickness)
            # keep the caption inside the frame when the box touches the top edge
            cv2.putText(frame, f"{label} {conf:.2f}", (x1, max(y1 - 10, 20)),
                        cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 255, 0), 2)


def annotate_file(
    file_name,
    input_dir="/content/drive/MyDrive/CSC2503/fruits_vid",
    output_dir="/content/drive/MyDrive/CSC2503/fruits_vid_annotated",
    detect_model=detect_model,
):
    """
    Annotate a single image or video file by filename located in input_dir (default provided).
    If `file_name` is an absolute path or contains a directory part, that path will be used directly.

    Images are saved as "<stem>_det.jpg" and videos as "<stem>_det.mp4" in output_dir.

    Returns the path to the saved annotated file (str), or None when the input is
    missing, cannot be opened, or the output cannot be written.

    Raises:
        ValueError: if the file extension is neither a supported image nor video type.
    """
    # Resolve input file path: use provided absolute/path-like name as-is, otherwise join with input_dir
    file_name = str(file_name)
    p = Path(file_name)
    if p.is_absolute() or p.parent != Path("."):
        file_path = p
    else:
        file_path = Path(input_dir) / p

    if not file_path.exists():
        print(f"File not found: {file_path}")
        return None

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    image_exts = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".heic", ".heif", ".webp", ".gif"}
    video_exts = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".mpg", ".mpeg"}

    ext = file_path.suffix.lower()

    if ext in image_exts:
        print("Processing image:", file_path)
        if ext in {".heic", ".heif"}:
            # Pillow can only decode HEIC/HEIF once pillow_heif has registered its opener.
            pillow_heif.register_heif_opener()

        with Image.open(str(file_path)) as img:
            pil_img = img.convert("RGB")
        frame = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

        results = detect_model(frame, verbose=False)
        _draw_detections(frame, results, detect_model.names, box_thickness=3, font_scale=0.8)

        out_path = out_dir / (file_path.stem + "_det.jpg")
        if not cv2.imwrite(str(out_path), frame):
            print("Could not write:", out_path)
            return None
        print("Saved:", out_path)
        return str(out_path)

    elif ext in video_exts:
        print("Processing video:", file_path)
        cap = cv2.VideoCapture(str(file_path))
        if not cap.isOpened():
            print("Cannot open:", file_path)
            return None

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # some containers report 0 fps

        out_path = out_dir / (file_path.stem + "_det.mp4")
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(str(out_path), fourcc, fps, (width, height))
        if not out.isOpened():
            cap.release()
            out.release()
            print("Cannot open video writer:", out_path)
            return None

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                results = detect_model(frame, verbose=False)
                _draw_detections(frame, results, detect_model.names, box_thickness=2, font_scale=0.7)
                out.write(frame)
        finally:
            # release both handles even if detection or writing raises
            cap.release()
            out.release()

        print("Saved:", out_path)
        return str(out_path)

    else:
        raise ValueError(f"Unsupported file extension: {ext}. Supported image ext: {sorted(image_exts)}; video ext: {sorted(video_exts)}")

In [None]:
# example use case: list the file names (or paths) to annotate, then run the loop
files_to_annotate = []
for file_name in files_to_annotate:
    annotate_file(file_name)

### classify using the three models (yolo, repvgg, swin)

In [None]:
# Label strings indexed by predicted class id.
class_names = ["not_fresh", "fresh"]

# Preprocessing applied to each image before the torch classifiers see it:
# fixed 224x224 resize, tensor conversion, then per-channel normalisation.
inference_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# repvgg and swin
@torch.no_grad()
def predict_torch_model(model, pil_img):
    """Classify one PIL image with a torch model; returns (class name, softmax confidence)."""
    batch = inference_transform(pil_img).unsqueeze(0).to(device)  # add batch dimension
    class_probs = torch.softmax(model(batch), dim=1)[0]
    best_prob, best_idx = class_probs.max(dim=0)
    return class_names[best_idx.item()], best_prob.item()

# yolo
def predict_yolo_cls(yolo_model, pil_img):
    """Classify one PIL image with the YOLO classifier; returns (class name, top-1 confidence)."""
    # the PIL image is passed to the model directly; only the first result is read
    top = yolo_model(pil_img, verbose=False)[0].probs
    return class_names[int(top.top1)], float(top.top1conf)

# combine the three
def predict_all_models(pil_crop):
    """
    Classify `pil_crop` with RepVGG, Swin and YOLO (in that order), timing each call.

    Returns:
        preds: {model name: (label, confidence)}
        times: {model name: elapsed seconds for that model's prediction}
    """
    runners = (
        ("RepVGG", lambda: predict_torch_model(repvgg_model, pil_crop)),
        ("Swin", lambda: predict_torch_model(swin_model, pil_crop)),
        ("YOLO", lambda: predict_yolo_cls(yolo_model, pil_crop)),
    )

    preds, times = {}, {}
    tick = time.time()
    for name, run in runners:
        preds[name] = run()
        tock = time.time()
        times[name] = tock - tick
        tick = tock  # the end of one measurement is the start of the next
    return preds, times

def _open_video_writer(output_path, fps, frame_size):
    """Open an MP4 writer, trying 'avc1' (H.264) first and 'mp4v' second; None if both fail."""
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'avc1'), fps, frame_size)
    if not writer.isOpened():
        print("avc1 codec failed to initialize, falling back to mp4v...")
        writer.release()
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, frame_size)
    if not writer.isOpened():
        writer.release()
        return None
    return writer


def _append_speed_metrics(csv_path, row):
    """Append `row` to the CSV at csv_path, writing the header first if the file is new."""
    write_header = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=list(row))
        if write_header:
            writer.writeheader()
        writer.writerow(row)


# USE THIS FUNCTION TO PREDICT ON VIDEO
def predict_on_video(
    mp4_name,
    save_video=True,
    save_speed_metrics=True,
    input_dir="/content/drive/MyDrive/CSC2503/fruits_vid_annotated",
    output_dir="/content/drive/MyDrive/CSC2503/fruits_vid_classified",
):
    """
    Detect the most confident object in every frame and classify its crop with all three models.

    Args:
        mp4_name (str): name of the mp4 video file inside `input_dir`.
        save_video: write "<stem>_classified.mp4" with boxes and predictions overlaid
            into `output_dir`.
        save_speed_metrics: append the average per-model latencies (ms) to
            speed_metrics.csv in `output_dir`.
        input_dir: directory holding the input video.
        output_dir: directory for the output video and the CSV (created if missing).

    Returns:
        None. Prints a message and returns early if the video cannot be opened.
    """
    video_path = os.path.join(input_dir, mp4_name)
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        print(f"Video could not be opened: {video_path}")
        return

    width  = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Some containers report 0 fps, which a VideoWriter cannot use; same fallback
    # as annotate_file.
    fps    = cap.get(cv2.CAP_PROP_FPS) or 30.0

    # Ensure output directory exists (used for both video and CSV)
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{Path(mp4_name).stem}_classified.mp4")

    out = None
    if save_video:
        out = _open_video_writer(output_path, fps, (width, height))
        if out is None:
            print("Error: Could not open video writer. Disabling video saving.")
        else:
            print(f"Processing video... Saving to: {output_path}")
    else:
        print("Processing video... (video saving disabled)")

    frame_counter = 0
    # Track accumulated time per model
    model_times = {"RepVGG": 0.0, "Swin": 0.0, "YOLO": 0.0}

    # Settings for high-res video
    font_scale = 2.5
    thickness_out = 8
    thickness_in = 3
    line_spacing = 90

    try:
        while True:
            ret, frame_bgr = cap.read()
            if not ret:
                break

            frame_counter += 1

            # 1) DETECTION
            det_results = detect_model(frame_bgr, verbose=False)

            for r in det_results:
                boxes = r.boxes.xyxy.cpu().numpy()
                if len(boxes) == 0:
                    continue
                confs = r.boxes.conf.cpu().numpy()

                # pick the most confident detection for now
                idx = confs.argmax()
                x1, y1, x2, y2 = (int(v) for v in boxes[idx].astype(int))

                # clamp to frame
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(width, x2), min(height, y2)

                # 2) CROP — taken BEFORE anything is drawn on the frame, so the
                # classifiers never see the green box border (cvtColor returns a copy).
                crop_bgr = frame_bgr[y1:y2, x1:x2]
                crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB) if crop_bgr.size else None

                # draw detection box
                cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 5)

                if crop_rgb is None:
                    continue

                # 3) CLASSIFICATION with the three models, timed per model
                preds, times = predict_all_models(Image.fromarray(crop_rgb))
                for m, t in times.items():
                    model_times[m] += t

                # 4) OVERLAY TEXT
                # Determine vertical position (above box if space permits, else inside)
                total_text_height = len(preds) * line_spacing
                y_anchor = y1 - total_text_height - 20
                if y_anchor < 0:
                    y_anchor = y1 + 20

                for i, (name, (label, conf)) in enumerate(preds.items()):
                    text = f"{name}: {label} ({conf:.2f})"
                    y = int(y_anchor + (i + 1) * line_spacing)

                    # thick black pass first, thin white pass on top, for legibility
                    cv2.putText(frame_bgr, text, (x1, y),
                                cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), thickness_out, cv2.LINE_AA)
                    cv2.putText(frame_bgr, text, (x1, y),
                                cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness_in, cv2.LINE_AA)

            # write frame only if saving enabled
            if out is not None:
                out.write(frame_bgr)
    finally:
        # release the handles even if detection or classification raises
        cap.release()
        if out is not None:
            out.release()

    # ---- speed metrics ----
    print(f"Frames processed:    {frame_counter}")
    print("Average latency per frame (classification phase only):")

    # NOTE: totals are divided by ALL processed frames; frames without a detection
    # run no classifier and therefore contribute zero time to the average.
    denom = max(frame_counter, 1)
    avg_ms = {}
    for m in ["RepVGG", "Swin", "YOLO"]:
        avg_ms[m] = (model_times[m] / denom) * 1000
        print(f"  {m}: {avg_ms[m]:.2f} ms")

    # Save speed metrics to CSV if requested
    if save_speed_metrics:
        csv_path = os.path.join(output_dir, "speed_metrics.csv")
        _append_speed_metrics(csv_path, {
            "filename": mp4_name,
            "frames": frame_counter,
            "RepVGG_ms": f"{avg_ms['RepVGG']:.2f}",
            "Swin_ms": f"{avg_ms['Swin']:.2f}",
            "YOLO_ms": f"{avg_ms['YOLO']:.2f}",
            "timestamp": datetime.datetime.now().isoformat(),
        })
        print(f"Saved speed metrics to: {csv_path}")

In [None]:
# example use case: classify each listed video (file names inside fruits_vid_annotated)
files_to_predict = ['IMG_5582_det.mp4', 'IMG_5580_det.mp4']
for video_name in files_to_predict:
    predict_on_video(video_name)