# ResNet with attention training



In [None]:
# ResNet with attention training

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import time
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# ---------------------------------------------------------------------------
# Pipeline setup: banner, compute device, output folders and saved config.
# ---------------------------------------------------------------------------
print("ResNet50 + Attention Training Pipeline", "=" * 45, sep="\n")

# Use the GPU whenever CUDA reports one is available; otherwise stay on CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

# Folders that receive model checkpoints and figures (created on demand).
models_dir = Path("../../models/saved_models")
results_plots_dir = Path("../../results/plots")
for out_dir in (models_dir, results_plots_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

print(f"Created directories:")
for out_dir in (models_dir, results_plots_dir):
    print(f"   {out_dir}")

# Inputs read from disk: the training configuration (JSON) and the per-class
# weight tensor, which is always mapped onto the CPU when loaded.
processed_root = Path("../../data/processed")
results_dir = Path("../../results")

config_path = processed_root / "training_config.json"
with open(config_path, "r") as f:
    config = json.load(f)

class_weights_path = processed_root / "fast_class_weights.pt"
class_weights = torch.load(class_weights_path, map_location='cpu')

print(f"\nTraining Configuration:")
print(f"   Classes: {config['num_classes']}")
print(f"   Class names: {config['class_names']}")
print(f"   Device: {device}")

# Load dataset code (reusing from EfficientNet)
import xml.etree.ElementTree as ET
from collections import Counter
from PIL import Image

# Lookup tables used while building the dataset:
#   class_mapping          -> simplified class name to integer class index
#   detailed_to_simplified -> annotation (detailed) name to simplified name
class_mapping_path = results_dir / "simplified_class_mapping.json"
with class_mapping_path.open("r") as f:
    class_mapping = json.load(f)

detailed_mapping_path = results_dir / "detailed_to_simplified_mapping.json"
with detailed_mapping_path.open("r") as f:
    detailed_to_simplified = json.load(f)

class FastVehicleDataset:
    """Lightweight dataset that yields cropped object images with class indices.

    Samples are collected once, at construction time, from the XML annotation
    files under ``../../data/raw/<split_name>/annos``; the matching images are
    looked up at ``../../data/raw/<split_name>/images/<stem>.jpg``.

    The module-level ``detailed_to_simplified`` and ``class_mapping``
    dictionaries translate annotation names into class indices.
    """

    def __init__(self, split_name, max_files=100, transform=None):
        """Collect the samples of one split.

        Args:
            split_name: Sub-directory of ``../../data/raw`` to read.
            max_files: Upper bound on the number of annotation files parsed.
            transform: Optional callable applied to every cropped image.
        """
        self.transform = transform
        self.data = self._load_data(split_name, max_files)
        print(f"{split_name} dataset: {len(self.data)} samples")

    def _load_data(self, split_name, max_files):
        """Parse the annotation files and return a list of sample dicts.

        Every dict carries ``img_path``, ``bbox`` (``[xmin, ymin, xmax,
        ymax]``), ``class_idx`` and ``class_name``. Loading is best effort:
        files or objects that cannot be parsed are skipped, objects whose
        class maps to ``'unknown'`` are dropped, and boxes must be more than
        30 pixels wide and tall.
        """
        data_root = Path("../../data/raw")
        annos_dir = data_root / split_name / "annos"
        images_dir = data_root / split_name / "images"

        xml_files = list(annos_dir.glob("*.xml"))[:max_files]
        data_samples = []

        for xml_file in xml_files:
            # ``except Exception`` (not a bare ``except``) keeps the best-effort
            # skipping but lets KeyboardInterrupt / SystemExit propagate, so a
            # long scan can still be interrupted.
            try:
                root = ET.parse(xml_file).getroot()

                img_path = images_dir / f"{xml_file.stem}.jpg"
                if not img_path.exists():
                    continue

                for obj in root.findall('object'):
                    try:
                        name_elem = obj.find('name')
                        if name_elem is None:
                            continue

                        detailed_class = name_elem.text.strip().lower()
                        simplified_class = detailed_to_simplified.get(detailed_class, 'unknown')
                        if simplified_class == 'unknown':
                            continue

                        bbox_elem = obj.find('bndbox')
                        if bbox_elem is None:
                            continue

                        # Coordinates may be written as floats; the top-left
                        # corner is clamped so it never goes negative.
                        xmin = max(0, int(float(bbox_elem.find('xmin').text)))
                        ymin = max(0, int(float(bbox_elem.find('ymin').text)))
                        xmax = int(float(bbox_elem.find('xmax').text))
                        ymax = int(float(bbox_elem.find('ymax').text))

                        if xmax > xmin + 30 and ymax > ymin + 30:
                            data_samples.append({
                                'img_path': str(img_path),
                                'bbox': [xmin, ymin, xmax, ymax],
                                'class_idx': class_mapping[simplified_class],
                                'class_name': simplified_class
                            })
                    except Exception:
                        # Malformed object entry: skip it, keep the rest.
                        continue
            except Exception:
                # Unreadable or malformed annotation file: skip it.
                continue

        return data_samples

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(cropped_image, class_idx)`` for sample ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range. The lookup is done outside
                the ``try`` block so that a bad index is not turned into a
                dummy sample and plain ``for`` iteration terminates.

        If the image cannot be opened, cropped or transformed, a zero tensor
        of shape ``(3, 224, 224)`` with label ``0`` is returned instead.
        """
        item = self.data[idx]
        try:
            # The context manager releases the file handle right after the
            # pixel data has been copied by ``convert``.
            with Image.open(item['img_path']) as source:
                image = source.convert('RGB')
            cropped = image.crop(item['bbox'])

            if self.transform:
                cropped = self.transform(cropped)

            return cropped, item['class_idx']
        except Exception:
            dummy_img = torch.zeros(3, 224, 224)
            return dummy_img, 0

# CBAM Attention Module
class ChannelAttention(nn.Module):
    """Channel gate of CBAM.

    Average- and max-pooled channel descriptors are passed through the same
    two-layer 1x1-convolution bottleneck, summed, and squashed with a sigmoid.

    Args:
        in_planes: Number of channels of the incoming feature map.
        ratio: Reduction factor of the bottleneck. The hidden width is
            ``in_planes // ratio`` but never less than 1, so the module also
            works when ``in_planes`` is smaller than ``ratio``.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Guard against a zero-channel bottleneck for narrow feature maps.
        hidden_planes = max(in_planes // ratio, 1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden_planes, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_planes, in_planes, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel gates with spatial size 1x1 and values in (0, 1)."""
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)

class SpatialAttention(nn.Module):
    """Spatial gate of CBAM.

    The input is collapsed along the channel axis twice (mean and max), the
    two resulting maps are stacked and fed to a single bias-free convolution,
    and a sigmoid turns the result into a one-channel gate.

    Args:
        kernel_size: Size of the convolution kernel; padding is
            ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()

        half_kernel = kernel_size // 2
        self.conv1 = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            padding=half_kernel,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a single-channel gate in (0, 1) computed from ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(descriptor))

class CBAM(nn.Module):
    """Channel attention followed by spatial attention.

    Args:
        in_planes: Channel count handed to ``ChannelAttention``.
        ratio: Reduction ratio handed to ``ChannelAttention``.
        kernel_size: Kernel size handed to ``SpatialAttention``.
    """

    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        """Scale ``x`` by the channel gate, then scale that result by the spatial gate."""
        channel_refined = x * self.ca(x)
        # The spatial gate is computed from the channel-refined features.
        spatial_gate = self.sa(channel_refined)
        return channel_refined * spatial_gate

# ResNet50 with CBAM Attention
class ResNetWithAttention(nn.Module):
    """ResNet50 feature extractor, CBAM block and an MLP classification head.

    Args:
        num_classes: Size of the final linear layer's output.
        pretrained: Forwarded to ``models.resnet50``.
    """

    def __init__(self, num_classes=6, pretrained=True):
        super().__init__()

        # Keep every child module of ResNet50 except the last two
        # (NOTE(review): expected to be the pooling and fc layers - confirm
        # against the installed torchvision version).
        resnet = models.resnet50(pretrained=pretrained)
        trunk_layers = list(resnet.children())[:-2]
        self.backbone = nn.Sequential(*trunk_layers)

        # Attention over the 2048-channel feature map.
        self.attention = CBAM(2048)

        # Collapses each channel to a single value.
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Head: 2048 -> 512 -> 256 -> num_classes with decreasing dropout.
        head_layers = [
            nn.Dropout(0.4),
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        feature_map = self.backbone(x)
        gated = self.attention(feature_map)

        # Pool, then flatten everything after the batch dimension.
        pooled = self.global_pool(gated)
        flat = pooled.reshape(pooled.size(0), -1)

        return self.classifier(flat)

# Enhanced transforms for ResNet
def get_resnet_transforms(stage='train'):
    """Build the preprocessing pipeline for one stage.

    Args:
        stage: ``'train'`` selects the augmenting pipeline; any other value
            selects the plain resize pipeline.

    Returns:
        A ``transforms.Compose``. Both variants end with ``ToTensor`` and the
        same ``Normalize`` call.
    """
    # Tail shared by both pipelines.
    tensor_tail = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    if stage != 'train':
        return transforms.Compose([transforms.Resize((224, 224))] + tensor_tail)

    # Training: resize larger, random-crop to 224, then photometric and
    # geometric augmentation.
    augmentations = [
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    ]
    return transforms.Compose(augmentations + tensor_tail)

# ---------------------------------------------------------------------------
# Datasets and loaders
# ---------------------------------------------------------------------------
print(f"\nCREATING RESNET TRAINING DATASETS", "=" * 40, sep="\n")

train_transform = get_resnet_transforms('train')
val_transform = get_resnet_transforms('val')

train_dataset = FastVehicleDataset('train', max_files=100, transform=train_transform)
val_dataset = FastVehicleDataset('val', max_files=50, transform=val_transform)

# Batches of 6 (kept small for ResNet); only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=6, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=6, shuffle=False, num_workers=0)

print(f"Data loaders ready:")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
print(f"\nCREATING RESNET50 + ATTENTION MODEL", "=" * 40, sep="\n")

model = ResNetWithAttention(num_classes=config['num_classes']).to(device)

# Parameter counts: all parameters, and those with requires_grad set.
total_param_count = sum(p.numel() for p in model.parameters())
trainable_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"ResNet50 + Attention model created")
print(f"   Parameters: {total_param_count:,}")
print(f"   Trainable: {trainable_param_count:,}")

# Focal Loss for better handling of class imbalance
class FocalLoss(nn.Module):
    """Focal loss on top of (optionally class-weighted) cross-entropy.

    Per sample the loss is ``alpha * (1 - p_t) ** gamma * w_y * (-log p_t)``,
    where ``p_t`` is the softmax probability of the target class and ``w_y``
    is the target's class weight (1 when ``weight`` is ``None``). The batch
    mean of these values is returned.

    Args:
        alpha: Scalar factor applied to every sample.
        gamma: Exponent of the ``(1 - p_t)`` modulating factor.
        weight: Optional 1-D tensor with one weight per class.
    """

    def __init__(self, alpha=0.25, gamma=2, weight=None):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight

    def forward(self, inputs, targets):
        """Return the mean focal loss for logits ``inputs`` and class indices ``targets``."""
        log_probs = F.log_softmax(inputs, dim=1)

        # p_t must come from the *unweighted* negative log-likelihood.
        # Deriving it from the weighted value would give p_t ** w_y and
        # distort the modulating factor for every class whose weight != 1.
        nll = F.nll_loss(log_probs, targets, reduction='none')
        pt = torch.exp(-nll)

        if self.weight is None:
            weighted_nll = nll
        else:
            weighted_nll = F.nll_loss(log_probs, targets, weight=self.weight, reduction='none')

        focal_loss = self.alpha * (1 - pt) ** self.gamma * weighted_nll
        return focal_loss.mean()

# ---------------------------------------------------------------------------
# Loss, optimizer and learning-rate schedule
# ---------------------------------------------------------------------------
print(f"\nTRAINING SETUP", "=" * 20, sep="\n")

# Focal loss; the class weights are moved to the training device first.
criterion = FocalLoss(alpha=0.25, gamma=2, weight=class_weights.to(device))

# Split the parameters in two groups by name: anything belonging to the
# classifier head or the attention block versus everything else (backbone).
backbone_params, classifier_params = [], []
for param_name, tensor in model.named_parameters():
    is_new_layer = any(tag in param_name for tag in ('classifier', 'attention'))
    (classifier_params if is_new_layer else backbone_params).append(tensor)

# Backbone group first (lower LR), new layers second (10x higher LR).
param_groups = [
    {'params': backbone_params, 'lr': 0.0001},
    {'params': classifier_params, 'lr': 0.001},
]
optimizer = optim.AdamW(param_groups, weight_decay=0.01)

# Halve every group's learning rate after each 5 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

print(f"Training setup complete:")
print(f"   Loss: FocalLoss (alpha=0.25, gamma=2) with class weights")
print(f"   Optimizer: AdamW with different LRs")
print(f"   Scheduler: StepLR (step=5, gamma=0.5)")

# Training and validation functions
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module to train; switched to training mode.
        train_loader: Iterable of ``(images, labels)`` batches.
        criterion: Callable ``(outputs, labels) -> scalar loss``. The epoch
            loss assumes it returns the mean over the batch.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean loss per sample, accuracy in percent)``, or ``(0.0, 0.0)``
        when the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    with tqdm(train_loader, desc="Training", leave=False) as pbar:
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            batch_size = labels.size(0)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            batch_loss = loss.item()
            # Weight each batch by its size so a smaller final batch does not
            # count as much as a full one in the epoch average.
            loss_sum += batch_loss * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'Loss': f'{batch_loss:.3f}',
                'Acc': f'{100.*correct/total:.1f}%'
            })

    # An empty loader would otherwise end in a ZeroDivisionError.
    if total == 0:
        return 0.0, 0.0

    return loss_sum / total, 100. * correct / total

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        val_loader: Iterable of ``(images, labels)`` batches.
        criterion: Callable ``(outputs, labels) -> scalar loss``. The reported
            loss assumes it returns the mean over the batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean loss per sample, accuracy in percent)``, or ``(0.0, 0.0)``
        when the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            batch_size = labels.size(0)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight each batch by its size so a smaller final batch does not
            # count as much as a full one in the overall average.
            loss_sum += loss.item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

    # An empty loader would otherwise end in a ZeroDivisionError.
    if total == 0:
        return 0.0, 0.0

    return loss_sum / total, 100. * correct / total

# Training loop
print(f"\nSTARTING RESNET TRAINING")
print("=" * 30)

num_epochs = 12  # More epochs for ResNet
best_acc = 0.0
train_losses = []
train_accs = []
val_losses = []
val_accs = []

# Bound before the loop so the path exists as a name even if no epoch ever
# improves on the initial best accuracy; the summary printed after training
# reads it.
model_save_path = models_dir / "resnet_attention_best.pth"

start_time = time.time()

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 25)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # Update scheduler
    scheduler.step()

    # Save metrics
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # Print results
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    # First parameter group only (the backbone group), read after the step.
    print(f"LR: {scheduler.get_last_lr()[0]:.6f}")

    # Save best model. The first epoch is always saved so that a checkpoint
    # exists even when validation accuracy never rises above 0%.
    if epoch == 0 or val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), model_save_path)
        print(f"New best model saved! (Val Acc: {val_acc:.2f}%)")

training_time = time.time() - start_time

print(f"\nRESNET TRAINING COMPLETE!")
print("=" * 30)
print(f"Training time: {training_time:.2f}s")
print(f"Best validation accuracy: {best_acc:.2f}%")

# Plot training curves: loss, accuracy, and a comparison panel.
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
plt.plot(val_losses, label='Val Loss', color='red', linewidth=2)
plt.title('ResNet + Attention: Loss Curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 2)
plt.plot(train_accs, label='Train Acc', color='blue', linewidth=2)
plt.plot(val_accs, label='Val Acc', color='red', linewidth=2)
plt.title('ResNet + Attention: Accuracy Curves')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True, alpha=0.3)

# Compare with EfficientNet if results exist
plt.subplot(1, 3, 3)
try:
    # Only reading and validating the saved result is guarded. A missing or
    # unreadable file, invalid JSON, a missing key or a non-numeric value all
    # lead to the placeholder text; genuine plotting errors are not hidden.
    with open(results_dir / "efficientnet_training_results.json", "r") as f:
        efficientnet_results = json.load(f)
    efficientnet_best = float(efficientnet_results['best_val_accuracy'])
except (OSError, ValueError, KeyError, TypeError):
    plt.text(0.5, 0.5, 'EfficientNet results\nnot found', ha='center', va='center', transform=plt.gca().transAxes)
else:
    best_scores = [efficientnet_best, best_acc]
    plt.bar(['EfficientNet-B3', 'ResNet50+Attention'],
            best_scores,
            color=['skyblue', 'lightcoral'])
    plt.title('Model Comparison')
    plt.ylabel('Best Validation Accuracy (%)')
    plt.grid(True, alpha=0.3)

    # Value label slightly above each bar.
    for i, v in enumerate(best_scores):
        plt.text(i, v + 1, f'{v:.1f}%', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.savefig(results_plots_dir / "resnet_attention_training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

# ---------------------------------------------------------------------------
# Persist the run summary as JSON and print the final report
# ---------------------------------------------------------------------------
parameter_count = sum(p.numel() for p in model.parameters())

# Free-text description of the setup, stored alongside the metrics.
architecture_features = {
    'attention_mechanism': 'CBAM (Channel + Spatial)',
    'loss_function': 'FocalLoss',
    'learning_rate_schedule': 'Different LRs for backbone/classifier',
    'regularization': 'Dropout + BatchNorm + Gradient Clipping'
}

training_results = {
    'model': 'ResNet50+Attention',
    'best_val_accuracy': float(best_acc),
    'final_train_accuracy': float(train_accs[-1]),
    'final_val_accuracy': float(val_accs[-1]),
    'training_time': float(training_time),
    'epochs': num_epochs,
    'train_losses': train_losses,
    'train_accs': train_accs,
    'val_losses': val_losses,
    'val_accs': val_accs,
    'parameters': parameter_count,
    # Size estimate: 4 bytes per parameter, expressed in MiB.
    'model_size_mb': parameter_count * 4 / (1024 * 1024),
    'architecture_features': architecture_features
}

results_json_path = results_dir / "resnet_attention_training_results.json"
with open(results_json_path, "w") as f:
    json.dump(training_results, f, indent=2)

print(f"\nRESULTS SAVED:")
print(f"   Model weights: {model_save_path}")
print(f"   Training curves: {results_plots_dir}/resnet_attention_training_curves.png")
print(f"   Results: {results_dir}/resnet_attention_training_results.json")

print(f"\nFINAL RESULTS:")
print(f"   Best Val Accuracy: {best_acc:.2f}%")
print(f"   Final Train Accuracy: {train_accs[-1]:.2f}%")
print(f"   Training Time: {training_time:.1f}s")
print(f"   Model Size: {training_results['model_size_mb']:.1f}MB")

print(f"\nRESNET + ATTENTION TRAINING COMPLETE!")
print("Next: Train YOLO model for object detection")

ResNet50 + Attention Training Pipeline
Device: cpu
Created directories:
   ..\..\models\saved_models
   ..\..\results\plots

Training Configuration:
   Classes: 6
   Class names: ['auto_rickshaw', 'bus', 'car', 'motorcycle', 'scooter', 'truck']
   Device: cpu

CREATING RESNET TRAINING DATASETS
train dataset: 414 samples
val dataset: 203 samples
Data loaders ready:
   Train batches: 69
   Val batches: 34

CREATING RESNET50 + ATTENTION MODEL


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to C:\Users\abhir/.cache\torch\hub\checkpoints\resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:30<00:00, 3.40MB/s]


ResNet50 + Attention model created
   Parameters: 25,215,912
   Trainable: 25,215,912

TRAINING SETUP
Training setup complete:
   Loss: FocalLoss (alpha=0.25, gamma=2) with class weights
   Optimizer: AdamW with different LRs
   Scheduler: StepLR (step=5, gamma=0.5)

STARTING RESNET TRAINING

Epoch 1/12
-------------------------


                                                                                

Train Loss: 0.2734, Train Acc: 39.86%
Val Loss: 0.1929, Val Acc: 59.11%
LR: 0.000100
New best model saved! (Val Acc: 59.11%)

Epoch 2/12
-------------------------


                                                                                

Train Loss: 0.2082, Train Acc: 50.72%
Val Loss: 0.1766, Val Acc: 57.64%
LR: 0.000100

Epoch 3/12
-------------------------


                                                                                

Train Loss: 0.2104, Train Acc: 53.38%
Val Loss: 0.1797, Val Acc: 43.35%
LR: 0.000100

Epoch 4/12
-------------------------


Training:  17%|█▋        | 12/69 [00:23<01:50,  1.95s/it, Loss=0.128, Acc=54.2%]