In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
from torchvision import models
from sklearn.model_selection import StratifiedShuffleSplit
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import os
from pathlib import Path
from sklearn.metrics import confusion_matrix, classification_report
import time

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cuda
GPU: NVIDIA GeForce RTX 3050 Ti Laptop GPU


## Dataset Parameters

In [2]:
IMAGE_SIZE = 224
BATCH_SIZE = 32
NUM_CLASSES = 9
NUM_EPOCHS = 60
DATA_DIR = '../realwaste/realwaste-main/RealWaste'  # ✅ Kaggle path

# Class names
CLASS_NAMES = [
    'Cardboard',
    'Food Organics',
    'Glass',
    'Metal',
    'Miscellaneous Trash',
    'Paper',
    'Plastic',
    'Textile Trash',
    'Vegetation'
]

print("\n" + "=" * 70)
print("KAGGLE ENVIRONMENT - IMPROVED RESNET50 TRAINING")
print("=" * 70)
print(f"📁 Data directory: {DATA_DIR}")
print(f"🎯 Target epochs: {NUM_EPOCHS}")
print(f"🔢 Classes: {NUM_CLASSES}")
print("=" * 70)



KAGGLE ENVIRONMENT - IMPROVED RESNET50 TRAINING
📁 Data directory: ../realwaste/realwaste-main/RealWaste
🎯 Target epochs: 60
🔢 Classes: 9


## Custom Dataset Class

In [3]:
class WasteDataset(Dataset):
    """Custom Dataset for loading waste material images"""
    
    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.images = []
        self.labels = []
        
        for class_idx, class_name in enumerate(CLASS_NAMES):
            class_dir = self.root_dir / class_name
            if class_dir.exists():
                image_files = sorted(class_dir.glob('*.*'))
                for img_path in image_files:
                    if img_path.suffix.lower() in ['.jpg', '.jpeg', '.png']:
                        self.images.append(str(img_path))
                        self.labels.append(class_idx)
            else:
                print(f"⚠️ Warning: Directory not found: {class_dir}")
        
        print(f"Loaded {len(self.images)} images from {len(CLASS_NAMES)} classes")
        
        if len(self.images) == 0:
            raise RuntimeError(f"No images found in {root_dir}")
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            image = Image.new('RGB', (224, 224), (0, 0, 0))
        
        if self.transform:
            image = self.transform(image)
        
        return image, label


### Mixup Augmentation

In [4]:
def mixup_data(x, y, alpha=0.2):
    """Applies mixup augmentation to batch"""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Compute mixup loss"""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


## Dataset statistics

In [5]:
def calculate_mean_std(dataset_path, image_size=224, sample_size=1000):
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor()
    ])
    
    temp_dataset = WasteDataset(root_dir=dataset_path, transform=transform)
    
    if sample_size and sample_size < len(temp_dataset):
        indices = np.random.choice(len(temp_dataset), sample_size, replace=False)
        temp_dataset = Subset(temp_dataset, indices)
    
    loader = DataLoader(temp_dataset, batch_size=32, shuffle=False, num_workers=2)
    
    channels_sum = torch.zeros(3)
    channels_squared_sum = torch.zeros(3)
    num_pixels = 0
    
    for images, _ in loader:
        channels_sum += torch.mean(images, dim=[0, 2, 3]) * images.size(0)
        channels_squared_sum += torch.mean(images ** 2, dim=[0, 2, 3]) * images.size(0)
        num_pixels += images.size(0)
    
    mean = channels_sum / num_pixels
    std = torch.sqrt(channels_squared_sum / num_pixels - mean ** 2)
    
    print(f"Dataset Mean (R, G, B): [{mean[0]:.4f}, {mean[1]:.4f}, {mean[2]:.4f}]")
    print(f"Dataset Std (R, G, B): [{std[0]:.4f}, {std[1]:.4f}, {std[2]:.4f}]")
    
    return mean.tolist(), std.tolist()


print("\n" + "=" * 70)
print("CALCULATING DATASET STATISTICS")
print("=" * 70)
dataset_mean, dataset_std = calculate_mean_std(DATA_DIR, IMAGE_SIZE, sample_size=1000)


CALCULATING DATASET STATISTICS
Loaded 4752 images from 9 classes


Dataset Mean (R, G, B): [0.5959, 0.6181, 0.6327]
Dataset Std (R, G, B): [0.1614, 0.1624, 0.1879]


## Enhanced data augmentation

In [6]:
print("\n" + "=" * 70)
print("CREATING ENHANCED TRANSFORMS")
print("=" * 70)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.7, 1.0), ratio=(0.9, 1.1)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.2),
    transforms.RandomRotation(20),
    transforms.ColorJitter(0.25, 0.25, 0.25, 0.05),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=dataset_mean, std=dataset_std),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.25), ratio=(0.3, 3.3))
])

val_test_transform = transforms.Compose([
    transforms.Resize(int(IMAGE_SIZE * 1.15)),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=dataset_mean, std=dataset_std)
])

print("✅ Enhanced augmentation pipeline created")


CREATING ENHANCED TRANSFORMS
✅ Enhanced augmentation pipeline created


## Load and split dataset

In [7]:
print("\n" + "=" * 70)
print("LOADING AND SPLITTING DATASET")
print("=" * 70)

train_ratio = 0.70
val_ratio = 0.15
test_ratio = 0.15

full_dataset = WasteDataset(root_dir=DATA_DIR, transform=None)
labels = np.array([full_dataset[i][1] for i in range(len(full_dataset))])

# Stratified split
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=(val_ratio + test_ratio), random_state=42)
train_idx, temp_idx = next(sss1.split(np.arange(len(labels)), labels))

temp_labels = labels[temp_idx]
relative_test_size = test_ratio / (val_ratio + test_ratio)
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=relative_test_size, random_state=42)
val_idx_rel, test_idx_rel = next(sss2.split(np.arange(len(temp_idx)), temp_labels))

val_idx = temp_idx[val_idx_rel]
test_idx = temp_idx[test_idx_rel]

print(f"Train: {len(train_idx)} | Val: {len(val_idx)} | Test: {len(test_idx)}")

# Create datasets with transforms
train_dataset_full = WasteDataset(root_dir=DATA_DIR, transform=train_transform)
val_dataset_full = WasteDataset(root_dir=DATA_DIR, transform=val_test_transform)
test_dataset_full = WasteDataset(root_dir=DATA_DIR, transform=val_test_transform)

train_dataset = Subset(train_dataset_full, train_idx)
val_dataset = Subset(val_dataset_full, val_idx)
test_dataset = Subset(test_dataset_full, test_idx)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                         num_workers=2, pin_memory=True, drop_last=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                       num_workers=2, pin_memory=True, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=2, pin_memory=True, drop_last=False)

print("✅ Data loaders created")



LOADING AND SPLITTING DATASET
Loaded 4752 images from 9 classes
Train: 3326 | Val: 713 | Test: 713
Loaded 4752 images from 9 classes
Loaded 4752 images from 9 classes
Loaded 4752 images from 9 classes
✅ Data loaders created


## Compute Class weights

In [8]:
print("\n" + "=" * 70)
print("COMPUTING CLASS WEIGHTS")
print("=" * 70)

train_labels = labels[train_idx].tolist()
class_counts = Counter(train_labels)
total_train = len(train_labels)
class_weights = torch.tensor(
    [total_train / (NUM_CLASSES * class_counts[i]) for i in range(NUM_CLASSES)],
    dtype=torch.float32
).to(device)

print("📊 Class weights computed")
for i, (name, weight) in enumerate(zip(CLASS_NAMES, class_weights)):
    print(f"  {name:20s}: {weight:.3f} (count: {class_counts[i]:3d})")



COMPUTING CLASS WEIGHTS
📊 Class weights computed
  Cardboard           : 1.144 (count: 323)
  Food Organics       : 1.283 (count: 288)
  Glass               : 1.257 (count: 294)
  Metal               : 0.668 (count: 553)
  Miscellaneous Trash : 1.068 (count: 346)
  Paper               : 1.056 (count: 350)
  Plastic             : 0.573 (count: 645)
  Textile Trash       : 1.665 (count: 222)
  Vegetation          : 1.212 (count: 305)


## Creating the model

In [9]:
print("\n" + "=" * 70)
print("CREATING RESNET50 MODEL")
print("=" * 70)

model = models.resnet50(weights='IMAGENET1K_V1')  # ✅ Updated for newer PyTorch
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, NUM_CLASSES)
model = model.to(device)

print("✅ ResNet50 loaded with pretrained ImageNet weights")
print(f"   Final layer: {in_features} → {NUM_CLASSES} classes")


CREATING RESNET50 MODEL


✅ ResNet50 loaded with pretrained ImageNet weights
   Final layer: 2048 → 9 classes


## Training

In [None]:
print("\n" + "=" * 70)
print("TRAINING CONFIGURATION")
print("=" * 70)

LEARNING_RATE = 0.001
MIXUP_ALPHA = 0.2

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
print(f"✅ Optimizer: AdamW (lr={LEARNING_RATE}, weight_decay=1e-4)")

scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,
    T_mult=1,
    eta_min=1e-6
)
print(f"✅ Scheduler: CosineAnnealingWarmRestarts (T_0=10)")

criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
print(f"✅ Loss: CrossEntropyLoss + Label Smoothing + Class Weights")
print(f"✅ Mixup: Enabled (alpha={MIXUP_ALPHA})")
print(f"✅ Epochs: {NUM_EPOCHS}")

class EarlyStopping:
    def __init__(self, patience=12, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_model = None
    
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.save_checkpoint(model)
            self.counter = 0
    
    def save_checkpoint(self, model):
        self.best_model = {k: v.cpu().clone() for k, v in model.state_dict().items()}

early_stopping = EarlyStopping(patience=12)
print(f"✅ Early Stopping: patience=12")

def train_epoch(model, dataloader, criterion, optimizer, device, use_mixup=True):
    """Training with MIXUP"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        
        # Apply mixup
        if use_mixup and np.random.rand() > 0.5:
            images, labels_a, labels_b, lam = mixup_data(images, labels, alpha=MIXUP_ALPHA)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (lam * (predicted == labels_a).sum().item() + 
                       (1 - lam) * (predicted == labels_b).sum().item())
        else:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc


def validate_epoch(model, dataloader, criterion, device):
    """Standard validation"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc


def validate_with_tta(model, dataloader, criterion, device, num_tta=5):
    """Validation with Test-Time Augmentation"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    tta_transforms = [
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.RandomRotation(10),
        transforms.RandomRotation(-10),
        transforms.Compose([
            transforms.RandomHorizontalFlip(p=1.0),
            transforms.RandomRotation(5)
        ])
    ]
    
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            
            all_outputs = []
            all_outputs.append(model(images))
            
            # Apply different augmentations
            for i in range(min(num_tta - 1, len(tta_transforms))):
                aug_images = images.clone()
                for j in range(images.size(0)):
                    # Convert to PIL, apply transform, convert back
                    img_pil = transforms.ToPILImage()(images[j].cpu())
                    img_aug = tta_transforms[i](img_pil)
                    aug_images[j] = transforms.ToTensor()(img_aug)
                aug_images = aug_images.to(device)
                all_outputs.append(model(aug_images))
            
            outputs = torch.stack(all_outputs).mean(0)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

print("\n" + "=" * 70)
print("STARTING TRAINING - IMPROVED RESNET50")

history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'lr': []
}

best_val_acc = 0.0
best_epoch = 0
start_time = time.time()

for epoch in range(NUM_EPOCHS):
    epoch_start = time.time()
    
    # Train with mixup
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)
    
    # Update scheduler
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['lr'].append(current_lr)
    
    # Print progress
    epoch_time = time.time() - epoch_start
    print(f"Epoch [{epoch+1:3d}/{NUM_EPOCHS}] ({epoch_time:4.1f}s) | "
          f"LR: {current_lr:.6f} | "
          f"Train: {train_acc:5.2f}% | "
          f"Val: {val_acc:5.2f}%", end='')
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'train_acc': train_acc,
            'history': history
        }, 'best_resnet50_improved.pth')
        print(" ✨ BEST!", end='')
    
    print()
    
    # Milestone messages
    if epoch > 0:
        if val_acc >= 93 and history['val_acc'][-2] < 93:
            print("   🎯 Milestone: 93% accuracy - Excellent!")
        elif val_acc >= 95 and history['val_acc'][-2] < 95:
            print("   🏆 OUTSTANDING: 95% accuracy achieved!")
    
    # Early stopping
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print(f"\nℹ️  Early stopping triggered at epoch {epoch+1}")
        break

total_time = time.time() - start_time

print("\n" + "=" * 70)
print("✅ TRAINING COMPLETE")
print("=" * 70)
print(f"⏱️  Total training time: {total_time/60:.1f} minutes")
print(f"🏆 Best validation accuracy: {best_val_acc:.2f}%")
print(f"📍 Best epoch: {best_epoch}")
print(f"💾 Model saved as: best_resnet50_improved.pth")


TRAINING CONFIGURATION
✅ Optimizer: AdamW (lr=0.001, weight_decay=1e-4)
✅ Scheduler: CosineAnnealingWarmRestarts (T_0=10)
✅ Loss: CrossEntropyLoss + Label Smoothing + Class Weights
✅ Mixup: Enabled (alpha=0.2)
✅ Epochs: 60
✅ Early Stopping: patience=12

STARTING TRAINING - IMPROVED RESNET50
Epoch [  1/60] (33.9s) | LR: 0.000976 | Train: 32.22% | Val: 34.78% ✨ BEST!
Epoch [  2/60] (33.1s) | LR: 0.000905 | Train: 46.71% | Val: 43.90% ✨ BEST!
Epoch [  3/60] (33.3s) | LR: 0.000794 | Train: 52.65% | Val: 60.03% ✨ BEST!
Epoch [  4/60] (33.3s) | LR: 0.000655 | Train: 57.65% | Val: 67.46% ✨ BEST!
Epoch [  5/60] (33.3s) | LR: 0.000501 | Train: 62.24% | Val: 69.14% ✨ BEST!
Epoch [  6/60] (33.2s) | LR: 0.000346 | Train: 64.28% | Val: 69.28% ✨ BEST!
Epoch [  7/60] (33.2s) | LR: 0.000207 | Train: 67.90% | Val: 75.18% ✨ BEST!
Epoch [  8/60] (33.5s) | LR: 0.000096 | Train: 70.93% | Val: 80.65% ✨ BEST!
Epoch [  9/60] (33.5s) | LR: 0.000025 | Train: 76.93% | Val: 82.33% ✨ BEST!
Epoch [ 10/60] (33.5s) 

KeyboardInterrupt: 

: 

## Evaluation

In [None]:
print("\n" + "=" * 70)
print("FINAL EVALUATION WITH TEST-TIME AUGMENTATION")
print("=" * 70)

# Load best model
checkpoint = torch.load('best_resnet50_improved.pth')
model.load_state_dict(checkpoint['model_state_dict'])

# Standard test
test_loss, test_acc = validate_epoch(model, test_loader, criterion, device)
print(f"📊 Standard Test Accuracy: {test_acc:.2f}%")

# Test with TTA
print("🔄 Running Test-Time Augmentation...")
tta_loss, tta_acc = validate_with_tta(model, test_loader, criterion, device, num_tta=5)
print(f"🎯 TTA Test Accuracy: {tta_acc:.2f}% (boost: +{tta_acc-test_acc:.2f}%)")


# ================================
# PLOT TRAINING HISTORY
# ================================
print("\n" + "=" * 70)
print("GENERATING TRAINING PLOTS")
print("=" * 70)

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Accuracy plot
axes[0, 0].plot(history['train_acc'], label='Train Accuracy', linewidth=2)
axes[0, 0].plot(history['val_acc'], label='Val Accuracy', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Accuracy (%)')
axes[0, 0].set_title('Training and Validation Accuracy')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Loss plot
axes[0, 1].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0, 1].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].set_title('Training and Validation Loss')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Learning rate plot
axes[1, 0].plot(history['lr'], linewidth=2, color='green')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Learning Rate')
axes[1, 0].set_title('Learning Rate Schedule')
axes[1, 0].grid(True, alpha=0.3)
axes[1, 0].set_yscale('log')

# Summary text
summary_text = f"""
TRAINING SUMMARY
{'='*30}

Best Val Accuracy: {best_val_acc:.2f}%
Best Epoch: {best_epoch}
Total Epochs: {len(history['train_acc'])}
Training Time: {total_time/60:.1f} min

Test Accuracy: {test_acc:.2f}%
TTA Test Accuracy: {tta_acc:.2f}%
TTA Improvement: +{tta_acc-test_acc:.2f}%

Improvements Applied:
✓ Mixup augmentation
✓ Enhanced data augmentation
✓ 40 epochs (optimized)
✓ Test-Time Augmentation
"""
axes[1, 1].text(0.1, 0.5, summary_text, fontsize=11, family='monospace',
                verticalalignment='center', bbox=dict(boxstyle='round', 
                facecolor='wheat', alpha=0.5))
axes[1, 1].axis('off')

plt.tight_layout()
plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
print("✅ Training plots saved as 'training_history.png'")

print("\n" + "=" * 70)
print("✅ ALL DONE!")
print("=" * 70)