# Waste Classification Training Pipeline
**Architecture**: EfficientNet-B0  
**Objective**: Robust classification of 9 waste categories using a Composite Augmentation Strategy.

## Methodology
To address real-world challenges such as distinguishing **White Metal Cans** from **White Paper**, this pipeline implements a specialized data augmentation strategy that targets specific material properties:

1.  **Scale Invariance**: `RandomAffine` ensures the model recognizes objects at various zoom levels (close-ups vs far shots).
2.  **Material Sheen**: `ColorJitter` allows the model to learn specular highlights (shine) characteristic of metal.
3.  **Structural Learning**: `RandomGrayscale` forces the model to rely on object geometry rather than just color.
4.  **Texture Enhancement**: `RandomAdjustSharpness` emphasizes edges and creases, critical for identifying crumpled paper.

In [None]:
import numpy as np
from PIL import Image
import os
import random
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, Resize, CenterCrop, RandomCrop, Normalize, RandomHorizontalFlip, RandomRotation, RandomGrayscale, RandomAdjustSharpness, RandomAutocontrast, RandomAffine, ColorJitter
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

# Select the compute device: use CUDA when PyTorch reports it as available,
# otherwise run on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

## 1. Data Initialization

In [None]:
data_dir = "dataset"

# Count every file found anywhere below data_dir (all files are counted,
# whatever their extension).
total_images = 0
for _root, _dirs, filenames in os.walk(data_dir):
    total_images += len(filenames)
print(f"Total images found: {total_images}")

## 2. Composite Augmentation Pipeline
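The order of the transforms matters: the geometric transforms (rotation, scaling) run on the resized 256-px image *before* the 224-px crop, so that fewer of the empty (black) border pixels they introduce end up in the final patch. The photometric augmentations are applied afterwards to maximize feature diversity.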

In [None]:
# Per-channel mean / std used to normalise inputs (ImageNet statistics).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# The training pipeline is assembled from named stages; the stages are
# concatenated in the order in which they must run.

# Stage 1 - geometric transformations, applied before cropping to minimise
# 'void' (black) regions in the final patch.
geometric_stage = [
    Resize(256),
    RandomRotation(degrees=180),
    # Scale invariance: random zoom between 0.8x and 1.2x, no rotation/shift.
    RandomAffine(degrees=0, translate=None, scale=(0.8, 1.2)),
]

# Stage 2 - patch extraction.
patch_stage = [
    RandomCrop(224),
    RandomHorizontalFlip(p=0.5),
]

# Stage 3 - material property augmentation (metal vs paper): mild colour
# jitter so highlights are perturbed but not destroyed.
material_stage = [
    ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05),
]

# Stage 4 - structural augmentation: with p=0.3 drop colour so the model
# has to rely on shape/geometry.
structure_stage = [
    RandomGrayscale(p=0.3),
]

# Stage 5 - texture enhancement: emphasise creases and high-frequency
# detail (paper texture).
texture_stage = [
    RandomAdjustSharpness(sharpness_factor=1.5, p=0.3),
    RandomAutocontrast(p=0.3),
]

# Final stage - convert to a tensor and normalise.
tensor_stage = [
    ToTensor(),
    Normalize(IMAGENET_MEAN, IMAGENET_STD),
]

train_transforms = Compose(
    geometric_stage
    + patch_stage
    + material_stage
    + structure_stage
    + texture_stage
    + tensor_stage
)

# Deterministic pipeline for validation and test: resize, centre crop,
# tensor conversion and the same normalisation, with no random augmentation.
val_test_transforms = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Dataset Wrapper
class TransformedDataset(torch.utils.data.Dataset):
    """Wrap an indexable collection of ``(x, y)`` pairs and transform ``x`` lazily.

    This lets several views of the same underlying data (e.g. train and
    validation subsets) each use their own transform.
    """

    def __init__(self, subset, transform=None):
        """Store the wrapped ``subset`` and the optional ``transform`` callable."""
        self.subset = subset
        self.transform = transform

    def __len__(self):
        """Return the number of items in the wrapped subset."""
        return len(self.subset)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for ``idx``, with ``x`` transformed when a transform is set."""
        sample, target = self.subset[idx]
        if not self.transform:
            # No (truthy) transform configured: hand the sample back untouched.
            return sample, target
        return self.transform(sample), target

# Split Strategy
# The base dataset carries no transform; each split gets its own transform
# through TransformedDataset further down.
base_dataset = ImageFolder(root=data_dir, transform=None)

# Stratified Split (80% Train, 10% Val, 10% Test)
# Group sample indices by class id, then split every class separately so each
# class keeps the same 80/10/10 proportions.
class_indices = {class_id: [] for class_id in range(len(base_dataset.classes))}
for sample_idx, (_path, class_id) in enumerate(base_dataset.samples):
    class_indices[class_id].append(sample_idx)

train_indices, val_indices, test_indices = [], [], []

for class_id, members in class_indices.items():
    # First cut: 80% train, 20% held out.
    train_part, holdout = train_test_split(
        members,
        train_size=0.8,
        random_state=42,
        stratify=[class_id] * len(members),
    )
    # Second cut: halve the held-out part into validation and test.
    val_part, test_part = train_test_split(
        holdout,
        train_size=0.5,
        random_state=42,
        stratify=[class_id] * len(holdout),
    )
    train_indices += train_part
    val_indices += val_part
    test_indices += test_part

# Only the training split receives the random augmentation pipeline.
train_dataset = TransformedDataset(Subset(base_dataset, train_indices), train_transforms)
val_dataset = TransformedDataset(Subset(base_dataset, val_indices), val_test_transforms)
test_dataset = TransformedDataset(Subset(base_dataset, test_indices), val_test_transforms)

# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

## 3. Regularization (Mixup)

In [None]:
def mixup_data(x, y, alpha=0.2, use_cuda=True):
    """Blend every sample of a batch with a randomly chosen partner (mixup).

    Args:
        x: Batch of inputs; the first dimension is the batch dimension.
        y: Batch of targets, indexed along its first dimension like ``x``.
        alpha: Parameter of the symmetric ``Beta(alpha, alpha)`` distribution
            from which the mixing coefficient is drawn. With ``alpha <= 0``
            no mixing takes place (``lam`` is 1).
        use_cuda: Kept for backward compatibility and ignored. The permutation
            index is always placed on the device ``x`` lives on, so the
            function works for CPU and GPU batches alike.

    Returns:
        A tuple ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y``,
        ``y_b`` is ``y[perm]`` and ``lam`` is the mixing coefficient.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    # Follow the input's device instead of trusting a flag: calling .cuda()
    # on a CPU-only machine (or for a CPU batch) would fail.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the mixup loss: ``criterion`` against both target sets, blended by ``lam``.

    The loss against ``y_a`` is weighted by ``lam`` and the loss against
    ``y_b`` by ``1 - lam``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

## 4. Model Definition

In [None]:
class WasteClassifier(nn.Module):
    """EfficientNet-B0 backbone whose final linear layer is resized to ``num_classes`` outputs."""

    def __init__(self, num_classes=9):
        """Load the default pretrained EfficientNet-B0 weights and swap in a new head.

        Args:
            num_classes: Number of output classes of the replacement linear layer.
        """
        super().__init__()
        self.backbone = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)

        # classifier[1] is the final linear layer; keep its input width and
        # change only the number of outputs.
        old_head = self.backbone.classifier[1]
        self.backbone.classifier[1] = nn.Linear(old_head.in_features, num_classes)

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        return self.backbone(x)

# Build the 9-class classifier and move its parameters to the selected device.
model = WasteClassifier(num_classes=9).to(device)

## 5. Training Loop
Includes class weighting, label smoothing, Mixup, plateau-based learning-rate scheduling, and checkpointing.

In [None]:
# Handle Class Imbalance
# Inverse-frequency weights: every class is weighted by total_samples / count,
# so classes with fewer samples contribute more to the loss. The counts are
# taken from class_indices, in class-id order.
class_counts = [len(indices) for indices in class_indices.values()]
total_samples = sum(class_counts)
class_weights = [total_samples / count for count in class_counts]
weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

# Loss with Label Smoothing
criterion = nn.CrossEntropyLoss(weight=weights_tensor, label_smoothing=0.1)

EPOCHS = 30
LEARNING_RATE = 1e-4

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
# Reduce the learning rate when the validation loss stops decreasing.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)

# Save Directory
os.makedirs('weights', exist_ok=True)

best_val_acc = 0

for epoch in range(EPOCHS):
    # ---- Training ----
    model.train()
    train_loss_sum = 0.0
    train_correct = 0.0

    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS} - Train'):
        images, labels = images.to(device), labels.to(device)

        # Mixup: the device flag follows the batch tensor itself rather than
        # the global CUDA availability.
        inputs, targets_a, targets_b, lam = mixup_data(
            images, labels, alpha=0.2, use_cuda=images.is_cuda
        )

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)

        loss.backward()
        optimizer.step()

        train_loss_sum += loss.item()

        # Under mixup a prediction counts lam-weighted against the first
        # target set and (1 - lam)-weighted against the second.
        predicted = outputs.argmax(1)
        hits_a = predicted.eq(targets_a).sum().item()
        hits_b = predicted.eq(targets_b).sum().item()
        train_correct += lam * hits_a + (1 - lam) * hits_b

    # Report the mean loss per batch so the value does not scale with the
    # number of batches and is comparable to the validation loss.
    train_loss = train_loss_sum / len(train_loader)
    train_acc = train_correct / len(train_dataset)

    # ---- Validation ----
    model.eval()
    val_loss_sum = 0.0
    val_correct = 0

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc=f'Epoch {epoch+1}/{EPOCHS} - Val'):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss_sum += loss.item()
            val_correct += (outputs.argmax(1) == labels).sum().item()

    val_loss = val_loss_sum / len(val_loader)
    val_acc = val_correct / len(val_dataset)
    scheduler.step(val_loss)

    print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    # Save Checkpoint (Every Epoch)
    torch.save(model.state_dict(), f'weights/model_epoch_{epoch+1}.pth')

    # Save Best Model (highest validation accuracy seen so far)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'weights/best_waste_model.pth')
        print(f'New best model saved! Accuracy: {val_acc:.4f}')