# RealWaste Dataset Training
Training MobileNetV3-Small and ViT-Small models on the RealWaste dataset (9 classes)


In [1]:
# Data Loading and Preprocessing
from __future__ import annotations
import random
from pathlib import Path
from typing import Tuple, Optional, Dict
import platform

import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from torchvision import datasets, transforms

print("Imports successful!")

# Reproducibility
def set_seed(seed: int = 56) -> None:
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Also switches cuDNN to deterministic mode and turns its benchmark mode off.
    """
    # One seed value is pushed to every generator the notebook relies on.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # cuDNN flags: prefer repeatable kernels over autotuned ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def get_num_workers(requested_workers: int = 2) -> int:
    """Pick the DataLoader worker count for the current platform.

    Returns `requested_workers` unchanged, except on Windows where a notice
    is printed and 0 is returned.
    """
    on_windows = platform.system() == 'Windows'
    if not on_windows:
        return requested_workers

    print("[Platform] Windows detected: Using num_workers=0")
    return 0

# Seed every RNG once, up front, so the cells below start from a fixed random state.
set_seed(56)


Imports successful!


In [None]:
# Load RealWaste Dataset
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# Point to RealWaste dataset
DATA_ROOT = Path("realwaste/realwaste-main/RealWaste")

# Stop right away if the dataset path is missing; otherwise report which root is in use.
if DATA_ROOT.exists():
    print(f"Using dataset root: {DATA_ROOT}")
else:
    raise FileNotFoundError(f"Dataset path does not exist: {DATA_ROOT}")

# Filter for JPG images (RealWaste uses .jpg)
def is_jpg_image(p):
    """Return True when the text form of `p` ends with ".jpg", ignoring case."""
    file_name = str(p).lower()
    return file_name.endswith(".jpg")

# Load full dataset to get class info (no transform: only the file index and labels are used here)
full = datasets.ImageFolder(root=str(DATA_ROOT), transform=None, is_valid_file=is_jpg_image)
print(f"Found {len(full.classes)} classes: {full.classes}")
print(f"Total images: {len(full.samples)}")

# Print class distribution (single pass over the labels instead of one scan per class)
targets = [y for _, y in full.samples]
class_counts = np.bincount(targets, minlength=len(full.classes))
print("\nClass distribution:")
for class_name, count in zip(full.classes, class_counts):
    print(f"  {class_name:25s}: {count:4d} images")

# Stratified split (70% train, 20% val, 10% test)
from sklearn.model_selection import train_test_split

# First split: separate test set (10%)
tr_val_idx, te_idx = train_test_split(
    np.arange(len(targets)), test_size=0.1, random_state=56, stratify=targets
)

# Second split: train/val from remaining (70/20 of original = ~77.8/22.2 of remaining)
tr_targets = [targets[i] for i in tr_val_idx]
tr_idx, va_idx = train_test_split(
    tr_val_idx, test_size=0.222, random_state=56, stratify=tr_targets  # 0.222 * 0.9 ≈ 0.2
)

print(f"\n✓ Train samples: {len(tr_idx)} ({len(tr_idx)/len(targets)*100:.1f}%)")
print(f"✓ Val samples: {len(va_idx)} ({len(va_idx)/len(targets)*100:.1f}%)")
print(f"✓ Test samples: {len(te_idx)} ({len(te_idx)/len(targets)*100:.1f}%)")

# Leakage check: the splits must be *pairwise* disjoint. Intersecting all three
# sets at once would still report "no overlap" when only two of them share indices.
tr_set, va_set, te_set = set(tr_idx), set(va_idx), set(te_idx)
no_overlap = (
    tr_set.isdisjoint(va_set)
    and tr_set.isdisjoint(te_set)
    and va_set.isdisjoint(te_set)
)
print(f"✓ No overlap: {no_overlap}")
assert no_overlap, "train/val/test splits share sample indices"
assert len(tr_idx) + len(va_idx) + len(te_idx) == len(targets), "splits do not cover the dataset"


In [None]:
# Create transforms and datasets
# Training pipeline: random crop / flip / colour jitter, then tensor conversion and normalisation.
_train_steps = [
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
]
train_tfms = transforms.Compose(_train_steps)

# Evaluation pipeline: fixed resize + centre crop, same tensor conversion and normalisation.
_eval_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
]
val_tfms = transforms.Compose(_eval_steps)


def _split_view(tfms, indices):
    """Scan DATA_ROOT into a new ImageFolder using `tfms` and restrict it to `indices`."""
    folder = datasets.ImageFolder(str(DATA_ROOT), transform=tfms, is_valid_file=is_jpg_image)
    return Subset(folder, indices)


# Create datasets with appropriate transforms (validation and test share the evaluation pipeline)
train_ds = _split_view(train_tfms, tr_idx)
val_ds = _split_view(val_tfms, va_idx)
test_ds = _split_view(val_tfms, te_idx)

num_classes = len(full.classes)
print(f"\n✓ Number of classes: {num_classes}", f"✓ Classes: {full.classes}", sep="\n")


In [None]:
# Create weighted sampler for class balancing
train_targets = np.asarray(targets)[tr_idx]
# minlength guarantees one count per class, so indexing by class id below cannot
# run out of range even if the highest-numbered classes are absent from the split.
counts = np.bincount(train_targets, minlength=len(full.classes))
# Inverse-frequency weights; clip avoids dividing by zero for an empty class.
class_weights = 1.0 / np.clip(counts, 1, None)
sample_weights = class_weights[train_targets]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(tr_idx), replacement=True)

print("✓ Class distribution (train) with weights:")
for i, cls in enumerate(full.classes):
    print(f"  {cls:25s}: {counts[i]:4d} samples (weight: {class_weights[i]:.3f})")

# Create DataLoaders
num_workers = get_num_workers(4)
# Pinned host memory only helps host-to-GPU copies; request it only when CUDA is available.
pin_memory = torch.cuda.is_available()
train_dl = DataLoader(train_ds, batch_size=64, sampler=sampler,
                      num_workers=num_workers, pin_memory=pin_memory)
val_dl = DataLoader(val_ds, batch_size=128, shuffle=False,
                    num_workers=num_workers, pin_memory=pin_memory)
test_dl = DataLoader(test_ds, batch_size=128, shuffle=False,
                     num_workers=num_workers, pin_memory=pin_memory)

print(f"\n✓ DataLoaders created (num_workers={num_workers})")
print(f"  Train batches: {len(train_dl)}")
print(f"  Val batches: {len(val_dl)}")
print(f"  Test batches: {len(test_dl)}")


In [None]:
# Create Models
import timm
from torch import nn

def make_model(name, num_classes):
    """Build the timm architecture `name` with pretrained weights and a `num_classes`-way head.

    Regularization is fixed at drop_rate=0.2 and drop_path_rate=0.1.
    """
    regularization = {"drop_rate": 0.2, "drop_path_rate": 0.1}
    return timm.create_model(name, pretrained=True, num_classes=num_classes, **regularization)

# Create MobileNetV3-Small and ViT-Small
mobilenet_small = make_model("mobilenetv3_small_100", num_classes)
vit_small = make_model("vit_small_patch16_224", num_classes)

# Report each model's parameter count in millions
for _label, _net in (("MobileNetV3-Small", mobilenet_small), ("ViT-Small", vit_small)):
    _n_params = sum(weight.numel() for weight in _net.parameters())
    print(f"✓ {_label} params: {_n_params/1e6:.2f}M")


In [None]:
# Training Function with Mixed Precision
from torch.amp import autocast, GradScaler
from sklearn.metrics import f1_score, accuracy_score

def train_model(model, train_dl, val_dl, epochs=20, lr=5e-4, wd=0.05, device="cuda"):
    """
    Train a model with:
    - Mixed precision training
    - AdamW optimizer with weight decay
    - Linear warmup (3 epochs) + Cosine annealing
    - Label smoothing
    - Best model checkpointing based on macro F1
    """
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    
    # Learning rate schedule: warmup + cosine
    warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.1, total_iters=3)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs-3)
    sched = torch.optim.lr_scheduler.SequentialLR(opt, [warmup, cosine], milestones=[3])
    
    crit = nn.CrossEntropyLoss(label_smoothing=0.1)
    scaler = GradScaler('cuda', enabled=(device.startswith("cuda")))
    best = {"f1": -1, "state": None, "epoch": 0}
    
    for ep in range(epochs):
        # Training
        model.train()
        train_loss = 0
        for x, y in train_dl:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            
            with autocast('cuda', enabled=(device.startswith("cuda"))):
                logits = model(x)
                loss = crit(logits, y)
            
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            train_loss += loss.item()
        
        sched.step()
        
        # Validation
        model.eval()
        preds, gts = [], []
        with torch.no_grad():
            for x, y in val_dl:
                x = x.to(device, non_blocking=True)
                logits = model(x)
                preds.append(logits.argmax(1).cpu())
                gts.append(y)
        
        p = torch.cat(preds).numpy()
        g = torch.cat(gts).numpy()
        f1 = f1_score(g, p, average="macro")
        acc = accuracy_score(g, p)
        
        # Save best model
        if f1 > best["f1"]:
            best = {"f1": f1, "state": model.state_dict(), "epoch": ep+1}
        
        print(f"ep {ep+1:2d}: acc {acc:.4f}  macroF1 {f1:.4f}  (best F1: {best['f1']:.4f} @ ep{best['epoch']})")
    
    # Load best model
    model.load_state_dict(best["state"])
    print(f"\n✓ Training complete. Best model from epoch {best['epoch']} (F1: {best['f1']:.4f})")
    return model

print("✓ Training function defined")


In [None]:
# Check CUDA availability
print(f"CUDA available: {torch.cuda.is_available()}")
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device: CPU" if device == "cpu" else f"Device: {torch.cuda.get_device_name(0)}")


In [None]:
# Train MobileNetV3-Small
print("=" * 60, "TRAINING MOBILENETV3-SMALL", "=" * 60, sep="\n")
mobilenet_trained = train_model(
    mobilenet_small, train_dl, val_dl,
    epochs=20, lr=5e-4, wd=0.05, device=device,
)


In [None]:
# Train ViT-Small
# ViT typically needs slightly lower learning rate
print("\n" + "=" * 60, "TRAINING VIT-SMALL", "=" * 60, sep="\n")
vit_trained = train_model(
    vit_small, train_dl, val_dl,
    epochs=20, lr=3e-4, wd=0.05, device=device,
)


In [None]:
# Performance Comparison and Benchmarking
import time

def ms_per_image(model, device="cuda", n_warmup=5, n_iters=50):
    """Measure inference speed in milliseconds per image.

    Runs `model` in eval mode on a single random 1x3x224x224 input: `n_warmup`
    untimed forward passes followed by `n_iters` timed ones.

    Args:
        model: module to benchmark; switched to eval mode and moved to `device`.
        device: device string (e.g. "cuda", "cuda:0", "cpu") or `torch.device`.
        n_warmup: number of untimed warmup forward passes.
        n_iters: number of timed forward passes (must be >= 1).

    Returns:
        Mean wall-clock time of one forward pass, in milliseconds.

    Raises:
        ValueError: if `n_iters` is smaller than 1.
    """
    if n_iters < 1:
        raise ValueError("n_iters must be at least 1")

    device = torch.device(device)
    # Compare the device *type* so "cuda:0" and torch.device inputs are synchronized too.
    on_cuda = device.type == "cuda"
    model.eval().to(device)
    x = torch.randn(1, 3, 224, 224, device=device)

    # no_grad: inference only; building autograd graphs would distort the timing.
    with torch.no_grad():
        # Warmup
        for _ in range(n_warmup):
            model(x)
        if on_cuda:
            torch.cuda.synchronize(device)

        # Benchmark (perf_counter is the monotonic, high-resolution clock)
        t0 = time.perf_counter()
        for _ in range(n_iters):
            model(x)
        if on_cuda:
            # CUDA kernels launch asynchronously; wait for them before stopping the clock.
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - t0

    return 1000 * elapsed / n_iters

def get_final_metrics(model, val_dl, device):
    """Return `(accuracy, macro_f1)` for `model` over all batches of `val_dl`.

    The model is put in eval mode and moved to `device`; predictions are the
    argmax over dimension 1 of the model output.
    """
    model.eval().to(device)

    predicted, actual = [], []
    with torch.no_grad():
        for images, labels in val_dl:
            batch_logits = model(images.to(device, non_blocking=True))
            predicted.append(batch_logits.argmax(1).cpu())
            actual.append(labels)

    y_pred = torch.cat(predicted).numpy()
    y_true = torch.cat(actual).numpy()

    accuracy = accuracy_score(y_true, y_pred)
    macro_f1 = f1_score(y_true, y_pred, average="macro")
    return accuracy, macro_f1

# Inference speed comparison
print("\n" + "=" * 60, "INFERENCE SPEED COMPARISON", "=" * 60, sep="\n")
mobilenet_speed = ms_per_image(mobilenet_trained, device)
vit_speed = ms_per_image(vit_trained, device)
print(f"MobileNetV3-Small: {mobilenet_speed:.2f} ms/image")
print(f"ViT-Small:         {vit_speed:.2f} ms/image")
print(f"Speed ratio:       {vit_speed/mobilenet_speed:.2f}x (ViT/MobileNet)")

# Final validation performance
print("\n" + "=" * 60, "FINAL VALIDATION PERFORMANCE", "=" * 60, sep="\n")

# (display name, trained model) pairs; reused by the test-set evaluation cell
models = [
    ("MobileNetV3-Small", mobilenet_trained),
    ("ViT-Small", vit_trained),
]

# One (name, accuracy, macro F1, params in millions, ms per image) row per model
results = []
for name, model in models:
    acc, f1 = get_final_metrics(model, val_dl, device)
    params = sum(weight.numel() for weight in model.parameters()) / 1e6
    speed = ms_per_image(model, device)
    row = (name, acc, f1, params, speed)
    results.append(row)
    print(f"{name:20s} | Acc: {acc:.4f} | F1: {f1:.4f} | Params: {params:.2f}M | Speed: {speed:.2f}ms")

print("=" * 60)


In [None]:
# Test Set Evaluation
print("\n" + "=" * 60, "TEST SET EVALUATION", "=" * 60, sep="\n")

# Score every trained model on the held-out test split
for name, model in models:
    test_acc, test_f1 = get_final_metrics(model, test_dl, device)
    print(f"{name:20s} | Test Acc: {test_acc:.4f} | Test F1: {test_f1:.4f}")

print("=" * 60)


In [None]:
# Save Models
# Save MobileNetV3-Small in PyTorch and ONNX formats
mobilenet_ckpt = {
    "model": "mobilenetv3_small_100",
    "classes": full.classes,
    "num_classes": num_classes,
    "state_dict": mobilenet_trained.state_dict(),
    "dataset": "RealWaste"
}
torch.save(mobilenet_ckpt, "mobilenetv3_small_realwaste.pt")

# Export to ONNX for deployment (traced on CPU with a batch-size-1 dummy input;
# the batch dimension is declared dynamic for both input and output)
dummy = torch.randn(1, 3, 224, 224)
mobilenet_trained.eval().cpu()
batch_axis = {0: "batch_size"}
torch.onnx.export(
    mobilenet_trained,
    dummy,
    "mobilenetv3_small_realwaste.onnx",
    opset_version=17,
    input_names=["input"],
    output_names=["logits"],
    dynamic_axes={"input": dict(batch_axis), "logits": dict(batch_axis)},
)

# Save ViT-Small
vit_ckpt = {
    "model": "vit_small_patch16_224",
    "classes": full.classes,
    "num_classes": num_classes,
    "state_dict": vit_trained.state_dict(),
    "dataset": "RealWaste"
}
torch.save(vit_ckpt, "vit_small_realwaste.pt")

print(
    "✓ Saved MobileNetV3-Small:",
    "  - mobilenetv3_small_realwaste.pt (PyTorch)",
    "  - mobilenetv3_small_realwaste.onnx (ONNX)",
    "✓ Saved ViT-Small:",
    "  - vit_small_realwaste.pt (PyTorch)",
    sep="\n",
)


In [None]:
# Optional: Confusion Matrix Analysis
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

def plot_confusion_matrix(model, val_dl, classes, device, title="Confusion Matrix"):
    """Plot a confusion-matrix heatmap and print a classification report.

    Args:
        model: module to evaluate; switched to eval mode and moved to `device`.
        val_dl: iterable of `(inputs, targets)` batches to evaluate on.
        classes: class names, ordered so that position i names class id i.
        device: device the model and inputs are moved to.
        title: figure title, also used as the report heading.
    """
    model.eval().to(device)
    preds, gts = [], []
    with torch.no_grad():
        for x, y in val_dl:
            x = x.to(device, non_blocking=True)
            logits = model(x)
            preds.append(logits.argmax(1).cpu())
            gts.append(y)

    p = torch.cat(preds).numpy()
    g = torch.cat(gts).numpy()

    # Pin the label set to every class id. Without it, scikit-learn only uses the
    # labels that occur in g/p, so a class missing from the data would shift the
    # rows/columns relative to the `classes` tick labels (and break target_names).
    class_ids = list(range(len(classes)))

    # Confusion matrix
    cm = confusion_matrix(g, p, labels=class_ids)

    # Plot
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes)
    plt.title(title, fontsize=14)
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

    # Classification report
    print(f"\n{title} - Classification Report:")
    print(classification_report(g, p, labels=class_ids, target_names=classes, digits=4))

# Plot for both models
print("=" * 60, "CONFUSION MATRIX ANALYSIS", "=" * 60, sep="\n")
for _net, _title in (
    (mobilenet_trained, "MobileNetV3-Small Confusion Matrix"),
    (vit_trained, "ViT-Small Confusion Matrix"),
):
    plot_confusion_matrix(_net, val_dl, full.classes, device, title=_title)


## Training Summary

This notebook trains two models on the **RealWaste dataset** (9 classes of waste):
- **Cardboard, Food Organics, Glass, Metal, Miscellaneous Trash, Paper, Plastic, Textile Trash, Vegetation**

### Models Trained:
1. **MobileNetV3-Small** (~1.5M params) - Efficient mobile-optimized CNN
2. **ViT-Small** (~22M params) - Vision Transformer with attention mechanism

### Training Configuration:
- **Optimizer**: AdamW with weight decay (0.05)
- **Learning Rate**: 5e-4 (MobileNet), 3e-4 (ViT)
- **LR Schedule**: Linear warmup (3 epochs) + Cosine annealing
- **Data Augmentation**: RandomResizedCrop, HorizontalFlip, ColorJitter
- **Regularization**: Dropout (0.2), DropPath (0.1), Label smoothing (0.1)
- **Training**: Mixed precision (FP16 autocast) with gradient scaling when running on CUDA; full precision on CPU
- **Epochs**: 20
- **Batch Size**: 64 (train), 128 (val/test)
- **Class Balancing**: Weighted random sampling
- **Data Split**: 70% train, 20% val, 10% test (stratified)

### Dataset Characteristics:
- **Total Images**: ~4,752 JPG images
- **Classes**: 9 waste categories
- **Image Format**: RGB JPG files
- **Source**: Local RealWaste dataset

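### Expected Performance:
- **MobileNetV3-Small**: ~80-85% accuracy, fast inference (~5-7 ms/image)
- **ViT-Small**: ~85-90% accuracy, moderate inference (~5-6 ms/image)

Latency figures are rough, hardware-dependent estimates for a single 224×224 image (batch size 1); rely on the values measured by the benchmark cell rather than on these numbers.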

### Output Files:
- `mobilenetv3_small_realwaste.pt` - PyTorch checkpoint
- `mobilenetv3_small_realwaste.onnx` - ONNX export for deployment
- `vit_small_realwaste.pt` - PyTorch checkpoint


In [None]:
# Optional: Test Loading Saved Models
# This cell demonstrates how to load and use the saved models

def load_model_from_checkpoint(checkpoint_path, device="cuda"):
    """Load a trained model from a checkpoint written by the "Save Models" cell.

    The checkpoint is expected to be a dict with the keys "model" (timm
    architecture name), "num_classes", "state_dict" and "classes".

    Args:
        checkpoint_path: path of the `.pt` file to read.
        device: device the tensors are mapped to and the model is moved to.

    Returns:
        `(model, classes)`: the model in eval mode on `device`, and the class
        names stored in the checkpoint.
    """
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot execute arbitrary code. It is passed explicitly
    # because the default differs between PyTorch versions.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)

    # Recreate model (no pretrained download: the weights come from the checkpoint)
    model = timm.create_model(
        checkpoint["model"],
        pretrained=False,
        num_classes=checkpoint["num_classes"]
    )

    # Load weights
    model.load_state_dict(checkpoint["state_dict"])
    model.to(device)
    model.eval()

    return model, checkpoint["classes"]

# Example: Load MobileNetV3-Small
print("Testing model loading...")
loaded_model, classes = load_model_from_checkpoint("mobilenetv3_small_realwaste.pt", device)
print("✓ Loaded MobileNetV3-Small", f"  Classes: {classes}", sep="\n")

# Quick test inference on one random input (checks the forward pass runs, not accuracy)
test_input = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    output = loaded_model(test_input)
pred_class = output.argmax(1).item()
print(f"  Test inference: predicted class {pred_class} ({classes[pred_class]})")

print("\n✓ Model loading and inference working correctly!")


In [None]:
# Optional: Per-Class Performance Analysis
from sklearn.metrics import precision_recall_fscore_support

def analyze_per_class_performance(model, val_dl, classes, device):
    """Print per-class precision, recall, F1-score and support, plus their macro averages.

    Args:
        model: module to evaluate; switched to eval mode and moved to `device`.
        val_dl: iterable of `(inputs, targets)` batches to evaluate on.
        classes: class names, ordered so that position i names class id i.
        device: device the model and inputs are moved to.
    """
    model.eval().to(device)
    preds, gts = [], []
    with torch.no_grad():
        for x, y in val_dl:
            x = x.to(device, non_blocking=True)
            logits = model(x)
            preds.append(logits.argmax(1).cpu())
            gts.append(y)

    p = torch.cat(preds).numpy()
    g = torch.cat(gts).numpy()

    # Per-class metrics. Passing every class id as `labels` keeps the result arrays
    # aligned with `classes`; otherwise a class that never appears in g or p is
    # dropped and the rows below would be mislabeled or indexed out of range.
    class_ids = list(range(len(classes)))
    precision, recall, f1, support = precision_recall_fscore_support(
        g, p, labels=class_ids, average=None
    )

    print("\nPer-Class Performance:")
    print(f"{'Class':<25s} | {'Precision':>9s} | {'Recall':>9s} | {'F1-Score':>9s} | {'Support':>7s}")
    print("-" * 80)
    for i, cls in enumerate(classes):
        print(f"{cls:<25s} | {precision[i]:9.4f} | {recall[i]:9.4f} | {f1[i]:9.4f} | {support[i]:7d}")

    # Overall metrics
    macro_p, macro_r, macro_f1 = precision.mean(), recall.mean(), f1.mean()
    print("-" * 80)
    print(f"{'Macro Average':<25s} | {macro_p:9.4f} | {macro_r:9.4f} | {macro_f1:9.4f} | {support.sum():7d}")

# Analyze both models
print("=" * 80, "PER-CLASS PERFORMANCE ANALYSIS", "=" * 80, sep="\n")

for _label, _net in (("MobileNetV3-Small", mobilenet_trained), ("ViT-Small", vit_trained)):
    # Banner with the model name, then its per-class table
    print("\n" + "=" * 80, _label, "=" * 80, sep="\n")
    analyze_per_class_performance(_net, val_dl, full.classes, device)
