# Trashnet Dataset Training
Training MobileNetV3-Small and ViT-Small models on the garythung/trashnet dataset (6 classes)


In [None]:
# Data Loading and Preprocessing
from __future__ import annotations
import random
from pathlib import Path
from typing import Tuple, Optional, Dict
import platform

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import transforms
from PIL import Image
from datasets import load_dataset

print("Imports successful!")

# Reproducibility
def set_seed(seed: int = 56) -> None:
    """Seed every random number generator the notebook relies on.

    Args:
        seed: Value handed to Python's ``random``, NumPy's legacy global RNG,
            and PyTorch's CPU and CUDA generators.

    Side effects:
        Puts cuDNN into deterministic mode and turns off its autotuner, so
        convolution algorithm selection is the same from run to run.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade some speed for repeatable cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_num_workers(requested_workers: int = 2) -> int:
    """Pick a DataLoader ``num_workers`` value for the current platform.

    Args:
        requested_workers: Worker count to use when the platform is not Windows.

    Returns:
        ``0`` on Windows (after printing a notice), otherwise
        ``requested_workers`` unchanged.
    """
    running_on_windows = platform.system() == 'Windows'
    if not running_on_windows:
        return requested_workers

    print(f"[Platform] Windows detected: Using num_workers=0")
    return 0

# Seed the Python, NumPy and PyTorch RNGs once, before any later cell draws random numbers.
set_seed(56)


In [None]:
# HuggingFace Dataset Wrapper
class HuggingFaceImageDataset(Dataset):
    """Expose a HuggingFace image dataset as a PyTorch ``Dataset``.

    Attributes:
        dataset: The wrapped HuggingFace dataset.
        transform: Optional callable applied to each RGB PIL image.
        label_key: Column holding the label.
        image_key: Column holding the image.
        classes: Class names taken from the label feature's ``names``
            attribute, or ``[]`` when the feature has no such attribute.
        targets: One label per sample, in dataset order (used to build a
            weighted sampler).
    """

    def __init__(self, hf_dataset, transform=None, label_key="label", image_key="image"):
        """Store the dataset and pre-compute class names and targets.

        Args:
            hf_dataset: Dataset supporting ``len()``, integer indexing,
                column access by name, and a ``features`` mapping.
            transform: Optional callable applied to each image in
                ``__getitem__``.
            label_key: Name of the label column.
            image_key: Name of the image column.
        """
        self.dataset = hf_dataset
        self.transform = transform
        self.label_key = label_key
        self.image_key = image_key

        # Label features that carry class names expose them as `.names`;
        # fall back to an empty list for features that do not.
        self.classes = getattr(hf_dataset.features[label_key], 'names', [])

        # Read the label column directly. Iterating over whole rows would
        # materialise (and decode) every image just to read one integer.
        self.targets = list(hf_dataset[label_key])

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to an RGB PIL image before the optional
        transform is applied; the label is returned as stored.
        """
        item = self.dataset[idx]
        image = item[self.image_key]
        label = item[self.label_key]

        # Array-like images are wrapped into a PIL image first.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Grayscale / RGBA / palette images are normalised to 3-channel RGB.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

print("✓ Dataset wrapper defined")


In [None]:
# Load Trashnet Dataset
print("[HuggingFace] Loading trashnet dataset...")
ds = load_dataset("garythung/trashnet")

print(f"[Dataset] Available splits: {list(ds.keys())}")
print(f"[Dataset] Total samples: {len(ds['train'])}")

# Hold out 20% of the 'train' split for validation (fixed seed for a stable split)
train_val_split = ds['train'].train_test_split(test_size=0.2, seed=56)

# Normalisation statistics shared by the training and validation pipelines
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# Training pipeline: light geometric + colour augmentation, then normalise
train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])

# Validation pipeline: deterministic resize + centre crop, then normalise
val_tfms = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])

# Wrap each half of the split with its own transform pipeline
train_ds = HuggingFaceImageDataset(train_val_split['train'], transform=train_tfms)
val_ds = HuggingFaceImageDataset(train_val_split['test'], transform=val_tfms)

num_classes = len(train_ds.classes)
print(f"\n✓ Classes ({num_classes}): {train_ds.classes}")
print(f"✓ Train samples: {len(train_ds)}")
print(f"✓ Val samples: {len(val_ds)}")

# Create weighted sampler for class balancing
train_targets = np.asarray(train_ds.targets, dtype=np.int64)
# minlength guarantees one slot per class even when a class is absent from
# the training split; without it the per-class report below can IndexError.
counts = np.bincount(train_targets, minlength=num_classes)
# Inverse-frequency weights; clip avoids dividing by zero for empty classes.
class_weights = 1.0 / np.clip(counts, 1, None)
sample_weights = class_weights[train_targets]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(train_ds), replacement=True)

print(f"\n✓ Class distribution (train):")
for i, cls in enumerate(train_ds.classes):
    print(f"  {cls:15s}: {counts[i]:4d} samples (weight: {class_weights[i]:.3f})")

# Create DataLoaders
num_workers = get_num_workers(4)
# Pinned host memory only helps host-to-GPU copies; requesting it without
# CUDA just triggers a warning.
pin_memory = torch.cuda.is_available()
train_dl = DataLoader(train_ds, batch_size=64, sampler=sampler,
                      num_workers=num_workers, pin_memory=pin_memory)
val_dl = DataLoader(val_ds, batch_size=128, shuffle=False,
                    num_workers=num_workers, pin_memory=pin_memory)

print(f"\n✓ DataLoaders created (num_workers={num_workers})")


In [None]:
# Create Models
import timm
from torch import nn

def make_model(name, num_classes):
    """Build a pretrained timm model with a fresh classification head.

    Args:
        name: timm model identifier.
        num_classes: Size of the classifier output.

    Returns:
        The model returned by ``timm.create_model`` with pretrained weights,
        dropout 0.2 and stochastic depth 0.1.
    """
    regularisation = {"drop_rate": 0.2, "drop_path_rate": 0.1}
    return timm.create_model(
        name,
        pretrained=True,
        num_classes=num_classes,
        **regularisation,
    )

# Instantiate both architectures with a head sized for this dataset
mobilenet_small = make_model("mobilenetv3_small_100", num_classes)
vit_small = make_model("vit_small_patch16_224", num_classes)

# Report parameter counts in millions
mobilenet_params_m = sum(p.numel() for p in mobilenet_small.parameters()) / 1e6
vit_params_m = sum(p.numel() for p in vit_small.parameters()) / 1e6
print(f"✓ MobileNetV3-Small params: {mobilenet_params_m:.2f}M")
print(f"✓ ViT-Small params: {vit_params_m:.2f}M")


In [None]:
# Training Function with Mixed Precision
from torch.amp import autocast, GradScaler
from sklearn.metrics import f1_score, accuracy_score

def train_model(model, train_dl, val_dl, epochs=20, lr=5e-4, wd=0.05, device="cuda"):
    """
    Train a model with:
    - Mixed precision training
    - AdamW optimizer with weight decay
    - Linear warmup (3 epochs) + Cosine annealing
    - Label smoothing
    - Best model checkpointing based on macro F1
    """
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    
    # Learning rate schedule: warmup + cosine
    warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.1, total_iters=3)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs-3)
    sched = torch.optim.lr_scheduler.SequentialLR(opt, [warmup, cosine], milestones=[3])
    
    crit = nn.CrossEntropyLoss(label_smoothing=0.1)
    scaler = GradScaler('cuda', enabled=(device.startswith("cuda")))
    best = {"f1": -1, "state": None, "epoch": 0}
    
    for ep in range(epochs):
        # Training
        model.train()
        train_loss = 0
        for x, y in train_dl:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            
            with autocast('cuda', enabled=(device.startswith("cuda"))):
                logits = model(x)
                loss = crit(logits, y)
            
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            train_loss += loss.item()
        
        sched.step()
        
        # Validation
        model.eval()
        preds, gts = [], []
        with torch.no_grad():
            for x, y in val_dl:
                x = x.to(device, non_blocking=True)
                logits = model(x)
                preds.append(logits.argmax(1).cpu())
                gts.append(y)
        
        p = torch.cat(preds).numpy()
        g = torch.cat(gts).numpy()
        f1 = f1_score(g, p, average="macro")
        acc = accuracy_score(g, p)
        
        # Save best model
        if f1 > best["f1"]:
            best = {"f1": f1, "state": model.state_dict(), "epoch": ep+1}
        
        print(f"ep {ep+1:2d}: acc {acc:.4f}  macroF1 {f1:.4f}  (best F1: {best['f1']:.4f} @ ep{best['epoch']})")
    
    # Load best model
    model.load_state_dict(best["state"])
    print(f"\n✓ Training complete. Best model from epoch {best['epoch']} (F1: {best['f1']:.4f})")
    return model

print("✓ Training function defined")


In [None]:
# Check CUDA availability
# Prefer the first CUDA device when one is visible; otherwise run on CPU
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
device = "cuda" if cuda_ok else "cpu"
print(f"Device: {torch.cuda.get_device_name(0)}" if cuda_ok else "Device: CPU")


In [None]:
# Train MobileNetV3-Small
# Banner, then fine-tune the MobileNet with its own learning rate
print("=" * 60, "TRAINING MOBILENETV3-SMALL", "=" * 60, sep="\n")
mobilenet_trained = train_model(
    mobilenet_small,
    train_dl,
    val_dl,
    epochs=20,
    lr=5e-4,
    wd=0.05,
    device=device,
)


In [None]:
# Train ViT-Small
# ViT typically needs slightly lower learning rate
# Banner, then fine-tune the ViT at a lower peak learning rate than the CNN
print("\n" + "=" * 60, "TRAINING VIT-SMALL", "=" * 60, sep="\n")
vit_trained = train_model(
    vit_small,
    train_dl,
    val_dl,
    epochs=20,
    lr=3e-4,
    wd=0.05,
    device=device,
)


In [None]:
# Performance Comparison and Benchmarking
import time

def ms_per_image(model, device="cuda", warmup=5, iters=50):
    """Measure single-image inference latency in milliseconds.

    The model is switched to eval mode, moved to ``device`` and fed one random
    ``(1, 3, 224, 224)`` tensor repeatedly with autograd disabled.

    Args:
        model: Module to benchmark.
        device: Device string or ``torch.device`` to run on.
        warmup: Untimed forward passes executed before timing starts.
        iters: Timed forward passes; must be at least 1.

    Returns:
        Mean wall-clock time per forward pass, in milliseconds.

    Raises:
        ValueError: If ``iters`` is smaller than 1.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")

    model.eval().to(device)
    x = torch.randn(1, 3, 224, 224, device=device)
    # Covers "cuda", "cuda:0" and torch.device objects alike; CUDA kernels
    # run asynchronously, so timing is only valid after a synchronize.
    on_cuda = str(device).startswith("cuda")

    # Without no_grad the timing would include autograd bookkeeping that
    # real inference never pays for.
    with torch.no_grad():
        # Warmup
        for _ in range(warmup):
            model(x)
        if on_cuda:
            torch.cuda.synchronize()

        # Benchmark (perf_counter is monotonic and higher resolution than time.time)
        t0 = time.perf_counter()
        for _ in range(iters):
            model(x)
        if on_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0

    return 1000 * elapsed / iters

def get_final_metrics(model, val_dl, device):
    """Evaluate ``model`` on ``val_dl``.

    Args:
        model: Module producing class logits; put into eval mode and moved to
            ``device`` by this call.
        val_dl: Iterable of ``(images, labels)`` batches.
        device: Device on which inference runs.

    Returns:
        Tuple ``(accuracy, macro_f1)`` computed over all batches.
    """
    model.eval().to(device)

    predicted, actual = [], []
    with torch.no_grad():
        for images, labels in val_dl:
            scores = model(images.to(device, non_blocking=True))
            predicted.append(scores.argmax(1).cpu())
            actual.append(labels)

    y_pred = torch.cat(predicted).numpy()
    y_true = torch.cat(actual).numpy()

    accuracy = accuracy_score(y_true, y_pred)
    macro_f1 = f1_score(y_true, y_pred, average="macro")
    return accuracy, macro_f1

# Latency of each trained model on a single random image
print("\n" + "=" * 60, "INFERENCE SPEED COMPARISON", "=" * 60, sep="\n")
mobilenet_speed = ms_per_image(mobilenet_trained, device)
vit_speed = ms_per_image(vit_trained, device)
print(f"MobileNetV3-Small: {mobilenet_speed:.2f} ms/image")
print(f"ViT-Small:         {vit_speed:.2f} ms/image")
print(f"Speed ratio:       {vit_speed/mobilenet_speed:.2f}x (ViT/MobileNet)")

# Accuracy / F1 / size / latency summary, one row per model
print("\n" + "=" * 60, "FINAL VALIDATION PERFORMANCE", "=" * 60, sep="\n")

models = [
    ("MobileNetV3-Small", mobilenet_trained),
    ("ViT-Small", vit_trained),
]

results = []
for name, model in models:
    acc, f1 = get_final_metrics(model, val_dl, device)
    params = sum(p.numel() for p in model.parameters()) / 1e6
    speed = ms_per_image(model, device)
    row = (name, acc, f1, params, speed)
    results.append(row)
    print(f"{name:20s} | Acc: {acc:.4f} | F1: {f1:.4f} | Params: {params:.2f}M | Speed: {speed:.2f}ms")

print("=" * 60)


In [None]:
# Save Models
# Save MobileNetV3-Small in PyTorch and ONNX formats
# The checkpoint bundles the timm architecture name, class names and class
# count alongside the weights, so the model can be rebuilt from the file alone.
# NOTE: state_dict() is captured before the .cpu() call further down, so the
# saved tensors live on whatever device the model is currently on.
torch.save({
    "model": "mobilenetv3_small_100",
    "classes": train_ds.classes,
    "num_classes": num_classes,
    "state_dict": mobilenet_trained.state_dict(),
    "dataset": "garythung/trashnet"
}, "mobilenetv3_small_trashnet.pt")

# Export to ONNX for deployment
# The dummy input is created on CPU, so the model is moved to CPU (in place)
# and put in eval mode before tracing.
dummy = torch.randn(1, 3, 224, 224)
mobilenet_trained.eval().cpu()
# dynamic_axes marks dimension 0 of both the input and the output as a
# variable batch dimension; all other dimensions are fixed by `dummy`.
torch.onnx.export(
    mobilenet_trained, 
    dummy, 
    "mobilenetv3_small_trashnet.onnx",
    input_names=["input"], 
    output_names=["logits"], 
    opset_version=17,
    dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}}
)

# Save ViT-Small
# Same checkpoint layout as the MobileNet file above; only a PyTorch
# checkpoint is written for this model (no ONNX export).
torch.save({
    "model": "vit_small_patch16_224",
    "classes": train_ds.classes,
    "num_classes": num_classes,
    "state_dict": vit_trained.state_dict(),
    "dataset": "garythung/trashnet"
}, "vit_small_trashnet.pt")

# Summary of the files written by this cell
print("✓ Saved MobileNetV3-Small:")
print("  - mobilenetv3_small_trashnet.pt (PyTorch)")
print("  - mobilenetv3_small_trashnet.onnx (ONNX)")
print("✓ Saved ViT-Small:")
print("  - vit_small_trashnet.pt (PyTorch)")


In [None]:
# Optional: Confusion Matrix Analysis
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

def plot_confusion_matrix(model, val_dl, classes, device, title="Confusion Matrix"):
    """Plot a confusion matrix and print a per-class report for ``model``.

    Args:
        model: Module producing class logits; put into eval mode and moved to
            ``device`` by this call.
        val_dl: Iterable of ``(images, labels)`` batches.
        classes: Class names, ordered so that index ``i`` names label ``i``.
        device: Device on which inference runs.
        title: Figure title, also used as the heading of the printed report.
    """
    model.eval().to(device)
    preds, gts = [], []
    with torch.no_grad():
        for x, y in val_dl:
            x = x.to(device, non_blocking=True)
            logits = model(x)
            preds.append(logits.argmax(1).cpu())
            gts.append(y)

    p = torch.cat(preds).numpy()
    g = torch.cat(gts).numpy()

    # Pin the label set to the full class list. Otherwise a class that is
    # missing from both the ground truth and the predictions shrinks the
    # matrix, misaligning the tick labels and making classification_report
    # reject `target_names` for having the wrong length.
    label_ids = list(range(len(classes))) if classes else None

    # Confusion matrix
    cm = confusion_matrix(g, p, labels=label_ids)

    # Plot
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes)
    plt.title(title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

    # Classification report
    print(f"\n{title} - Classification Report:")
    print(classification_report(g, p, labels=label_ids, target_names=classes, digits=4))

# Same analysis for each trained model
print("=" * 60, "CONFUSION MATRIX ANALYSIS", "=" * 60, sep="\n")
plot_confusion_matrix(
    mobilenet_trained,
    val_dl,
    train_ds.classes,
    device,
    title="MobileNetV3-Small Confusion Matrix",
)
plot_confusion_matrix(
    vit_trained,
    val_dl,
    train_ds.classes,
    device,
    title="ViT-Small Confusion Matrix",
)


## Training Summary

This notebook trains two models on the **Trashnet dataset** (6 classes of waste):
- **cardboard, glass, metal, paper, plastic, trash**

### Models Trained:
1. **MobileNetV3-Small** (~1.5M params) - Efficient mobile-optimized CNN
2. **ViT-Small** (~22M params) - Vision Transformer with attention mechanism

### Training Configuration:
- **Optimizer**: AdamW with weight decay (0.05)
- **Learning Rate**: 5e-4 (MobileNet), 3e-4 (ViT)
- **LR Schedule**: Linear warmup (3 epochs) + Cosine annealing
- **Data Augmentation**: RandomResizedCrop, HorizontalFlip, ColorJitter
- **Regularization**: Dropout (0.2), DropPath (0.1), Label smoothing (0.1)
- **Training**: Mixed precision (FP16 autocast) with gradient scaling when running on CUDA; full precision on CPU
- **Epochs**: 20
- **Batch Size**: 64 (train), 128 (val)
- **Class Balancing**: Weighted random sampling

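### Expected Performance:
- **MobileNetV3-Small**: ~85-90% accuracy, fast inference (~5-7ms)
- **ViT-Small**: ~88-93% accuracy, moderate inference (~5-6ms)

These figures are rough estimates. Latency is measured for a single 224×224 image (batch size 1) and depends heavily on hardware, so the two ranges above can overlap; rely on the benchmark cell's measured output for an actual comparison.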

### Output Files:
- `mobilenetv3_small_trashnet.pt` - PyTorch checkpoint
- `mobilenetv3_small_trashnet.onnx` - ONNX export for deployment
- `vit_small_trashnet.pt` - PyTorch checkpoint


In [None]:
# Optional: Test Loading Saved Models
# This cell demonstrates how to load and use the saved models

def load_model_from_checkpoint(checkpoint_path, device="cuda"):
    """Rebuild a trained model from a checkpoint written by this notebook.

    Args:
        checkpoint_path: Path to a file saved with ``torch.save`` containing
            the keys ``"model"`` (timm architecture name), ``"num_classes"``,
            ``"state_dict"`` and ``"classes"``.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        Tuple ``(model, classes)``: the model in eval mode on ``device``, and
        the class-name list stored in the checkpoint.

    Raises:
        KeyError: If the checkpoint lacks one of the required keys.
    """
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)

    required = ("model", "num_classes", "state_dict", "classes")
    missing = [key for key in required if key not in checkpoint]
    if missing:
        raise KeyError(f"checkpoint is missing required keys: {missing}")

    # Recreate model (architecture only; the weights come from the checkpoint)
    model = timm.create_model(
        checkpoint["model"],
        pretrained=False,
        num_classes=checkpoint["num_classes"]
    )

    # Load weights
    model.load_state_dict(checkpoint["state_dict"])
    model.to(device)
    model.eval()

    return model, checkpoint["classes"]

# Round-trip check: reload the MobileNet checkpoint written earlier
print("Testing model loading...")
loaded_model, classes = load_model_from_checkpoint("mobilenetv3_small_trashnet.pt", device)
print(f"✓ Loaded MobileNetV3-Small", f"  Classes: {classes}", sep="\n")

# Smoke-test inference on one random image (the prediction itself is meaningless)
test_input = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    output = loaded_model(test_input)
pred_class = output.argmax(1).item()
print(f"  Test inference: predicted class {pred_class} ({classes[pred_class]})")

print("\n✓ Model loading and inference working correctly!")
