# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os

# --- 1. CONFIGURATION ---
BATCH_SIZE = 32
LR = 0.001
EPOCHS = 10  # Short run to compare convergence speed
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to YOUR trained rotation (self-supervised) checkpoint. Its backbone is
# loaded into an ARCH_NAME model, so the two must match.
YOUR_MODEL_PATH = 'model_saving/fullTraining_resnet18.pth'
# torchvision constructor name. Options: resnet18, resnet50, efficientnet_b0, inception_v3
ARCH_NAME = 'resnet18'

# Pascal VOC classes; a class's position here is its index in the label vectors.
VOC_CLASSES = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
    'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]

# --- 2. DATASET HANDLING (MULTI-LABEL) ---
class PascalVOC_Classification(datasets.VOCDetection):
    """Pascal VOC detection data re-labelled for multi-label classification.

    Each item is ``(img, label_vector)``. ``label_vector`` is a float32 multi-hot
    tensor of length ``len(VOC_CLASSES)``: entry ``i`` is 1.0 when at least one
    annotated object in the image has class ``VOC_CLASSES[i]``. Bounding boxes and
    per-object counts are discarded; object names not in ``VOC_CLASSES`` are ignored.

    Args:
        root, year, image_set, transform: forwarded to
            ``torchvision.datasets.VOCDetection``.
        download: forwarded to ``VOCDetection``. Defaults to ``True`` (the previous
            hard-coded behaviour). Pass ``False`` once the data is on disk: with
            ``True``, torchvision re-checks and re-extracts the archive on every
            construction.
    """

    def __init__(self, root, year, image_set, transform=None, download=True):
        super().__init__(root, year, image_set, download=download, transform=transform)
        self.class_to_idx = {name: i for i, name in enumerate(VOC_CLASSES)}

    def __getitem__(self, index):
        img, target = super().__getitem__(index)

        # Convert the parsed XML dictionary into a multi-hot vector (one slot per class).
        label_vector = torch.zeros(len(VOC_CLASSES), dtype=torch.float32)

        # 'object' may hold a single dict or a list of dicts; tolerate it being absent.
        objects = target['annotation'].get('object', [])
        if not isinstance(objects, list):
            objects = [objects]

        for obj in objects:
            idx = self.class_to_idx.get(obj['name'])
            if idx is not None:
                label_vector[idx] = 1.0

        return img, label_vector

# Input resolution. torchvision's inception_v3 expects 299x299 images: its
# auxiliary classifier (active in training mode) applies a 5x5 conv that fails on
# the smaller feature maps a 224x224 input produces. The other options use 224x224.
INPUT_SIZE = 299 if ARCH_NAME == 'inception_v3' else 224

# Transforms (standard ImageNet normalisation statistics)
transform = transforms.Compose([
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load Datasets
print("Preparing Data...")
# PascalVOC_Classification passes download=True to VOCDetection by default, so the
# VOC archive is fetched into ./data when it is not already present.
train_dataset = PascalVOC_Classification(root='./data', year='2012', image_set='train', transform=transform)
val_dataset = PascalVOC_Classification(root='./data', year='2012', image_set='val', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
# NOTE(review): nothing in this section evaluates on val_loader; the comparison
# below is made on training loss only.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# --- 3. MODEL BUILDER ---
def build_model(architecture, weights_type):
    """Build a torchvision classifier whose head predicts the Pascal VOC classes.

    Args:
        architecture: name of a ``torchvision.models`` constructor, e.g.
            ``'resnet18'``, ``'resnet50'``, ``'efficientnet_b0'``, ``'inception_v3'``.
        weights_type: ``'imagenet'`` -- torchvision's default pretrained weights;
            ``'custom'`` -- backbone weights read from ``YOUR_MODEL_PATH`` (the
            checkpoint's classification heads are discarded);
            ``'random'`` -- no pretrained weights.

    Returns:
        The model with a freshly initialised ``len(VOC_CLASSES)``-way linear head,
        moved to ``DEVICE``.

    Raises:
        ValueError: ``weights_type`` is not one of the options above, or the
            architecture's head layout is not one this function can replace.
        RuntimeError: a ``'custom'`` checkpoint shares no backbone parameter names
            with the model, so nothing would actually be loaded.
    """
    if weights_type not in ('imagenet', 'custom', 'random'):
        raise ValueError(
            f"Unknown weights_type {weights_type!r}; expected 'imagenet', 'custom' or 'random'."
        )

    print(f"Building {architecture} with {weights_type} weights...")

    # Top-level module names of torchvision classification heads. Match on the first
    # key component, not a substring: EfficientNet's squeeze-excitation layers are
    # named '...block.N.fc1' / 'fc2' and belong to the backbone.
    head_prefixes = ('fc', 'classifier', 'AuxLogits')

    def is_head_key(key):
        return key.split('.', 1)[0] in head_prefixes

    # 1. Initialize base
    constructor = getattr(models, architecture)
    model = constructor(weights='DEFAULT' if weights_type == 'imagenet' else None)

    # 2. Load custom backbone weights
    if weights_type == 'custom':
        state_dict = torch.load(YOUR_MODEL_PATH, map_location='cpu')

        # Strip the 'module.' prefix (typically present when the model was saved
        # while wrapped in nn.DataParallel).
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        backbone_dict = {k: v for k, v in state_dict.items() if not is_head_key(k)}

        # strict=False tolerates the head we deliberately dropped, but also hides
        # any other mismatch, so inspect what was actually matched.
        result = model.load_state_dict(backbone_dict, strict=False)
        if len(backbone_dict) - len(result.unexpected_keys) == 0:
            raise RuntimeError(
                f"No backbone weights from {YOUR_MODEL_PATH!r} match {architecture}; "
                "check the checkpoint format and ARCH_NAME."
            )
        missing_backbone = [k for k in result.missing_keys if not is_head_key(k)]
        if missing_backbone or result.unexpected_keys:
            print(
                f"-> WARNING: {len(missing_backbone)} backbone keys missing from checkpoint, "
                f"{len(result.unexpected_keys)} checkpoint keys unused by the model."
            )
        print("-> Custom rotation weights loaded (Head ignored).")

    # 3. Replace head for the VOC classes
    num_classes = len(VOC_CLASSES)
    if 'resnet' in architecture or 'inception' in architecture:
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif 'efficientnet' in architecture:
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    else:
        raise ValueError(f"Don't know how to replace the classification head of {architecture!r}.")

    return model.to(DEVICE)

# --- 4. TRAINING LOOP ---
def train_fine_tuning(model, name, loader=None, epochs=None, lr=None):
    """Fine-tune ``model`` for multi-label classification and record training loss.

    Args:
        model: network producing one logit per class; already on its target device.
            Inputs are moved to the device of the model's first parameter.
        name: label used in progress messages.
        loader: iterable of ``(inputs, multi_hot_labels)`` batches; defaults to
            the global ``train_loader``.
        epochs: number of passes over ``loader``; defaults to ``EPOCHS``.
        lr: SGD learning rate; defaults to ``LR``.

    Returns:
        list[float]: one entry per epoch, the mean per-sample BCE training loss.

    Raises:
        ValueError: ``loader`` yielded no samples in an epoch.
    """
    loader = train_loader if loader is None else loader
    epochs = EPOCHS if epochs is None else epochs
    lr = LR if lr is None else lr

    device = next(model.parameters()).device
    criterion = nn.BCEWithLogitsLoss()  # independent sigmoid per class: multi-label
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    loss_history = []

    print(f"--- Starting Fine-Tuning: {name} ---")
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        total_samples = 0
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            # Some torchvision models (e.g. inception_v3 with aux_logits) return a
            # namedtuple (logits, aux_logits) in training mode; keep the main logits.
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The criterion averages over the batch; weight by batch size so a short
            # final batch counts proportionally in the epoch mean.
            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

        if total_samples == 0:
            raise ValueError(f"[{name}] loader yielded no samples in epoch {epoch + 1}")

        epoch_loss = total_loss / total_samples
        loss_history.append(epoch_loss)
        print(f"[{name}] Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}")

    return loss_history

# --- 5. EXECUTION & COMPARISON ---
# Re-seed before each run so every model starts from the same RNG state (head
# initialisation, shuffle order). This reduces run-to-run noise in the comparison;
# GPU kernels may still be nondeterministic.
SEED = 0

# A. Train Baseline (ImageNet Supervised)
torch.manual_seed(SEED)
model_sup = build_model(ARCH_NAME, 'imagenet')
hist_sup = train_fine_tuning(model_sup, "ImageNet_Baseline")

# B. Train Your Model (Rotation Self-Supervised)
torch.manual_seed(SEED)
model_self = build_model(ARCH_NAME, 'custom')
hist_self = train_fine_tuning(model_self, "My_Rotation_Model")

# C. (Optional) Random Init - lower bound showing whether pretraining helped at all
torch.manual_seed(SEED)
model_rand = build_model(ARCH_NAME, 'random')
hist_rand = train_fine_tuning(model_rand, "Random_Init")

# --- 6. PLOT RESULTS ---
# The training log numbers epochs from 1, so use the same numbering on the x-axis
# (plotting the bare lists would start at 0).
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(hist_sup) + 1), hist_sup, label='Supervised (ImageNet)', marker='o')
plt.plot(range(1, len(hist_self) + 1), hist_self, label='Self-Supervised (Rotation)', marker='x')
plt.plot(range(1, len(hist_rand) + 1), hist_rand, label='Random Init', linestyle='--')

plt.title(f'Fine-Tuning Efficiency on Pascal VOC ({ARCH_NAME})')
plt.xlabel('Epochs')
plt.ylabel('Training BCE Loss (Lower is better)')
plt.legend()
plt.grid(True)
plt.show()