In [None]:
# Install the GPU (CUDA) build of PyTorch into the interpreter running this kernel.
# The wheel index below targets CUDA 12.1; for other versions, see https://pytorch.org
import subprocess
import sys

# Spell the pip invocation out as an argument list so no shell quoting is involved.
pip_install_command = [
    sys.executable, "-m", "pip", "install", "--quiet",
    "torch", "torchvision", "torchaudio",
    "--index-url", "https://download.pytorch.org/whl/cu121",
]

# check_call raises CalledProcessError if pip exits non-zero, so the message
# below is only printed after a successful install.
subprocess.check_call(pip_install_command)
print("PyTorch with CUDA 12.1 installed successfully!")

In [None]:
# scikit-learn supplies train_test_split, which this notebook uses for the train/validation split.
%pip install scikit-learn

Cell 1: Setup & Configuration

In [None]:
import copy
import time

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models, transforms

# ==========================================
# 1. CONFIGURATION
# ==========================================
DATA_DIR = '../Public_dataset'      # Folder containing paper, plastic, aluminum
BATCH_SIZE = 16             # Reduced slightly for 512x512 images to avoid memory errors
LEARNING_RATE = 1e-4        # Lower learning rate for fine-tuning
NUM_EPOCHS = 15
NUM_CLASSES = 3             # paper, plastic, aluminum
SEED = 42                   # Same value the train/val split uses as its random_state
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's random number generator so that shuffling, augmentation and
# dropout draw the same random numbers on every Restart & Run All.
# NOTE(review): some GPU kernels are non-deterministic, so runs on CUDA may still
# differ slightly — see the PyTorch reproducibility notes if exact repeatability matters.
torch.manual_seed(SEED)

print(f"Using device: {DEVICE}")

Cell 2: Data Preparation (Auto-Split)


In [None]:
# Settings shared by both pipelines: every image is resized to the same square
# and normalised per channel with the same statistics.
# NOTE(review): the mean/std values look like the usual ImageNet statistics that
# pretrained torchvision models expect — confirm.
IMAGE_SIZE = (512, 512)
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Random perturbations applied to training images only.
train_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),  # robustness to lighting changes
]

# Training pipeline: resize -> augment -> tensor -> normalise
train_transform = transforms.Compose(
    [transforms.Resize(IMAGE_SIZE)]
    + train_augmentations
    + [transforms.ToTensor(), transforms.Normalize(NORM_MEAN, NORM_STD)]
)

# Validation pipeline: resize -> tensor -> normalise (no augmentation)
val_transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

# Load the dataset twice from the same folder: one view with training
# augmentation, one clean view for validation. Only this step is guarded,
# so that unrelated failures further down are not misreported as a missing
# dataset.
try:
    full_data_train = datasets.ImageFolder(DATA_DIR, transform=train_transform)
    full_data_val = datasets.ImageFolder(DATA_DIR, transform=val_transform)
except (FileNotFoundError, NotADirectoryError):
    print("\nERROR: Could not find dataset!")
    print(f"Make sure you have a folder named '{DATA_DIR}' with subfolders for each class.")
    # Re-raise instead of calling exit(): the original traceback stays visible
    # and the notebook kernel is not shut down.
    raise

# Class names as discovered by ImageFolder
class_names = full_data_train.classes
print(f"Classes detected: {class_names}")

# Create indices for split (80% Train, 20% Val); fixed random_state keeps the
# split identical between runs.
train_idx, val_idx = train_test_split(
    list(range(len(full_data_train))),
    test_size=0.2,
    random_state=42
)

# Both views index the same underlying files, so the disjoint index lists give
# an augmented training subset and a clean validation subset with no overlap.
train_dataset = Subset(full_data_train, train_idx)
val_dataset = Subset(full_data_val, val_idx)

# Data Loaders (only the training data is shuffled)
dataloaders = {
    'train': DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True),
    'val': DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}

print(f"Training on {dataset_sizes['train']} images")
print(f"Validating on {dataset_sizes['val']} images")

Cell 3: Model Setup (ResNet50 + Dropout)

In [None]:
print("\nInitializing ResNet50...")
pretrained_weights = models.ResNet50_Weights.DEFAULT
model = models.resnet50(weights=pretrained_weights)

# Width of the feature vector feeding the stock classifier layer
# (2048 for ResNet50 by default).
num_ftrs = model.fc.in_features

# Replacement classifier head: dropout -> hidden layer -> ReLU -> dropout -> logits.
# Layers stay positional inside nn.Sequential so parameter names are fc.<index>.*
classifier_layers = [
    nn.Dropout(0.5),              # strong dropout to prevent overfitting
    nn.Linear(num_ftrs, 512),     # intermediate layer
    nn.ReLU(),
    nn.Dropout(0.3),              # mild dropout
    nn.Linear(512, NUM_CLASSES),  # final output (3 classes)
]
model.fc = nn.Sequential(*classifier_layers)

# Move the whole network to the selected device
model = model.to(DEVICE)

Cell 4: Training Loop

In [None]:
# Loss function for the classification task
criterion = nn.CrossEntropyLoss()

# Adam over every model parameter (generally converges faster than SGD)
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Cut the learning rate by 10x once the monitored metric (validation accuracy,
# hence mode='max') has failed to improve for 3 consecutive epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='max',
    factor=0.1,
    patience=3,
)

# ==========================================
# TRAINING LOOP
# ==========================================
def train_model(model, criterion, optimizer, scheduler, num_epochs=10):
    """Train ``model`` and return it loaded with its best-validation weights.

    Every epoch runs a 'train' pass followed by a 'val' pass over the
    module-level ``dataloaders`` dict, moving batches to the module-level
    ``DEVICE``. After each validation pass the scheduler is stepped with the
    validation accuracy, and the weights are snapshotted whenever that
    accuracy beats the best seen so far.

    Args:
        model: Network to optimise; modified in place.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        scheduler: Object whose ``step(metric)`` is called once per epoch
            with the validation accuracy.
        num_epochs: Number of train+val epochs to run.

    Returns:
        The same ``model`` object, with the best-scoring weights loaded.
    """
    since = time.time()

    # state_dict() hands back references to the live parameter tensors, which
    # the optimizer keeps updating in place. A deep copy is required for the
    # snapshot to actually stay frozen at the best epoch.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'\n{"="*60}')
        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'{"="*60}')

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # enables dropout / batch-norm updates
            else:
                model.eval()   # inference behaviour for dropout / batch-norm

            loader = dataloaders[phase]
            num_batches = len(loader)

            running_loss = 0.0
            running_corrects = 0
            samples_seen = 0

            for batch_idx, (inputs, labels) in enumerate(loader):
                inputs = inputs.to(DEVICE)
                labels = labels.to(DEVICE)

                # Gradients only exist (and need clearing) while training
                if is_train:
                    optimizer.zero_grad()

                # Track history only in the training phase
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Statistics: weight by the real batch size so a short final
                # batch neither inflates nor deflates the running averages.
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                running_corrects += torch.sum(preds == labels).item()
                samples_seen += batch_size

                # Print batch progress on the first batch and every 5th batch
                if (batch_idx + 1) % 5 == 0 or batch_idx == 0:
                    current_loss = running_loss / samples_seen
                    current_acc = running_corrects / samples_seen
                    print(f'  {phase.upper()} | Batch {batch_idx+1}/{num_batches} | Loss: {current_loss:.4f} | Acc: {current_acc:.4f}')

            # Guard against an empty loader so the summary cannot divide by zero
            denominator = max(samples_seen, 1)
            epoch_loss = running_loss / denominator
            epoch_acc = running_corrects / denominator

            # Get current learning rate
            current_lr = optimizer.param_groups[0]['lr']

            print(f'\n  {phase.upper()} SUMMARY | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f} | LR: {current_lr:.2e}')

            if phase == 'val':
                # Snapshot the weights if this is the best validation score so far
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    print(f'  *** NEW BEST MODEL! Validation Accuracy: {epoch_acc:.4f} ***')

                # Step scheduler based on validation accuracy
                scheduler.step(epoch_acc)

    time_elapsed = time.time() - since
    print(f'\n{"="*60}')
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')
    print(f'{"="*60}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model

# ==========================================
# RUN & SAVE
# ==========================================
if __name__ == '__main__':
    # Destination file for the learned weights
    save_path = 'waste_classifier_resnet50.pth'

    # Run the full training schedule; the returned model carries the weights
    # train_model selected as best
    trained_model = train_model(
        model,
        criterion,
        optimizer,
        scheduler,
        num_epochs=NUM_EPOCHS,
    )

    # Persist only the state_dict (weights), not the pickled module
    torch.save(trained_model.state_dict(), save_path)
    print(f"Model saved to {save_path}")