In [None]:
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18

from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, Resize, CenterCrop, RandomCrop, RandomHorizontalFlip, ColorJitter, Normalize
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import os
import random

In [None]:
data_dir = "dataset/RealWaste"

# Count every file found anywhere below the dataset root (no extension filtering).
total_images = sum(len(files) for _, _, files in os.walk(data_dir))
print(f"Total images: {total_images}")

dataset = ImageFolder(root=data_dir, transform=ToTensor())

# Bucket sample positions by class id; each entry of `dataset.samples`
# carries the class id in its second slot.
class_indices = {label: [] for label in range(len(dataset.classes))}
for idx, sample in enumerate(dataset.samples):
    class_indices[sample[1]].append(idx)

# Split every class on its own (80% train, then the remainder halved into
# val/test) so each split keeps the per-class proportions.
train_indices, val_indices, test_indices = [], [], []

for label, indices in class_indices.items():
    # Stage 1: train vs. held-out (val + test)
    train_idx, temp_idx = train_test_split(
        indices,
        train_size=0.8,
        random_state=42,
        stratify=[label] * len(indices),
    )
    # Stage 2: held-out part is divided evenly into val and test
    val_idx, test_idx = train_test_split(
        temp_idx,
        train_size=0.5,
        random_state=42,
        stratify=[label] * len(temp_idx),
    )
    train_indices += train_idx
    val_indices += val_idx
    test_indices += test_idx

# Index-based views over the single underlying dataset
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

In [None]:
# Normalization statistics
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Augmentation for training
train_transforms = Compose([
    Resize(256),                      # Resize to 256
    RandomCrop(224),                  # Random crop to 224x224
    RandomHorizontalFlip(p=0.5),      # Horizontal flip
    ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.1), # Color jitter
    ToTensor(),
    Normalize(IMAGENET_MEAN, IMAGENET_STD)  # Normalization
])

# Transforms for val/test (no augmentation)
val_test_transforms = Compose([
    Resize(256),
    CenterCrop(224),                  # Center crop
    ToTensor(),
    Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

# Apply transforms.
# The three subsets were created as views over ONE shared dataset object, so
# assigning `<subset>.dataset.transform` three times in a row just overwrites
# the same attribute and the last assignment (val/test) would also apply to
# training, silently disabling augmentation. Give the training split and the
# evaluation splits their own ImageFolder instance instead; the split indices
# are reused unchanged, which assumes ImageFolder lists the files under
# `data_dir` in the same order on every scan.
train_base = ImageFolder(root=data_dir, transform=train_transforms)
eval_base = ImageFolder(root=data_dir, transform=val_test_transforms)

train_dataset = Subset(train_base, train_indices)
val_dataset = Subset(eval_base, val_indices)
test_dataset = Subset(eval_base, test_indices)

# Create DataLoaders (only the training loader shuffles)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

In [None]:
# Defines the model ensuring consistency with model.py
class WasteClassifier(nn.Module):
    """ResNet-18 whose final fully connected layer is replaced by a
    ``num_classes``-way linear head."""

    def __init__(self, num_classes=9, weights='DEFAULT'):
        """Build the backbone and swap its classification head.

        Args:
            num_classes: Number of outputs of the new final linear layer.
            weights: Passed through unchanged to ``resnet18(weights=...)``.
        """
        super().__init__()
        # `weights=` is used rather than `pretrained=True` to avoid deprecation warnings
        net = resnet18(weights=weights)
        # The new head keeps the input width of the layer it replaces
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        self.backbone = net

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        return self.backbone(x)

model = WasteClassifier(num_classes=9)
# Move to GPU when one is available, otherwise keep the model where it is
model = model.cuda() if torch.cuda.is_available() else model

In [None]:
# Inverse-frequency class weights: total sample count divided by each
# class's sample count, so classes with fewer samples get larger weights.
class_counts = list(map(len, class_indices.values()))
total_samples = sum(class_counts)
class_weights = [total_samples / n for n in class_counts]

# Float32 tensor, placed on the GPU when one is available
weights_tensor = torch.tensor(class_weights, dtype=torch.float32)
weights_tensor = weights_tensor.cuda() if torch.cuda.is_available() else weights_tensor

# Cross-entropy loss weighted per class
criterion = nn.CrossEntropyLoss(weight=weights_tensor)

In [None]:
# Hyperparameters
EPOCHS = 25
LEARNING_RATE = 0.001

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3)

# Resolve the target device once from the model itself, so batches always
# land where the model's parameters live and CUDA availability is not
# re-queried on every iteration.
device = next(model.parameters()).device

best_val_acc = 0

for epoch in range(EPOCHS):
    # ---- Train ----
    model.train()
    train_loss = 0
    train_correct = 0

    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS} - Train'):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_correct += (outputs.argmax(1) == labels).sum().item()

    # Report the mean loss per batch rather than the raw sum, which grows with
    # the number of batches and is not comparable between train and val.
    train_loss /= max(len(train_loader), 1)
    train_acc = train_correct / len(train_dataset)

    # ---- Validation ----
    model.eval()
    val_loss = 0
    val_correct = 0

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc=f'Epoch {epoch+1}/{EPOCHS} - Val'):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            val_correct += (outputs.argmax(1) == labels).sum().item()

    val_loss /= max(len(val_loader), 1)
    val_acc = val_correct / len(val_dataset)
    # The LR scheduler monitors validation loss (mode='min')
    scheduler.step(val_loss)

    print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    # Save the best model: checkpoint whenever validation accuracy improves
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_waste_model.pth')
        print(f'New best model saved with accuracy: {val_acc:.4f}')