In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, random_split, Subset, Dataset
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
import os
from tqdm import tqdm

# --- 1. CONFIGURATION ---
class Config_Q8:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers = 2
    dataset_path = '/kaggle/input/dgm-animals/Animals_data/animals/animals'
    
    # Path to the results from Q6, which contains the class list
    q6_output_dir = '/kaggle/input/dgm_q6_results/pytorch/default/1/Q6_cGAN_20_classes_G_heavy'
    
    # Model and Training Parameters
    num_classes = 20
    batch_size = 32
    num_epochs = 20 # Increased epochs for better convergence
    lr = 0.001
    val_split = 0.2
    output_dir = '/kaggle/working/Q8_ResNet_20_classes'

config = Config_Q8()
os.makedirs(config.output_dir, exist_ok=True)
print(f"Configuration loaded for Question 8. Using device: {config.device}")

Configuration loaded for Question 8. Using device: cuda


In [2]:
# --- CUSTOM DATASET WRAPPER (from Q6) ---
class ClassSubsetDataset(Dataset):
    def __init__(self, subset, class_mapping):
        self.subset, self.class_mapping = subset, class_mapping
    def __getitem__(self, index):
        image, original_label_idx = self.subset[index]
        new_label_idx = self.class_mapping[original_label_idx]
        return image, new_label_idx
    def __len__(self):
        return len(self.subset)

# --- DATA TRANSFORMS (from Q7) ---
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    'val': transforms.Compose([
        transforms.Resize(256), transforms.CenterCrop(224),
        transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

# --- LOAD THE 20 SELECTED CLASSES FROM Q6 ---
selected_class_names = np.load(os.path.join(config.q6_output_dir, 'selected_class_names.npy'))
class_mapping = torch.load(os.path.join(config.q6_output_dir, 'class_mapping.pth'))
print(f"Loaded {len(selected_class_names)} classes from Question 6.")

# --- CREATE FILTERED DATASETS WITH DIFFERENT TRANSFORMS ---
full_dataset_train = datasets.ImageFolder(root=config.dataset_path, transform=data_transforms['train'])
full_dataset_val = datasets.ImageFolder(root=config.dataset_path, transform=data_transforms['val'])

selected_class_indices = [full_dataset_train.class_to_idx[name] for name in selected_class_names]
subset_indices = [i for i, (_, label_idx) in enumerate(full_dataset_train.samples) if label_idx in selected_class_indices]

# Create subsets
train_subset_filtered = Subset(full_dataset_train, subset_indices)
val_subset_filtered = Subset(full_dataset_val, subset_indices)

# Wrap them to remap labels
train_dataset = ClassSubsetDataset(train_subset_filtered, class_mapping)
val_dataset = ClassSubsetDataset(val_subset_filtered, class_mapping)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)
print(f"Data prepared. Total samples for 20 classes: {len(train_dataset)}")

Loaded 20 classes from Question 6.
Data prepared. Total samples for 20 classes: 1193


In [3]:
# --- MODEL SETUP ---
model_q8 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model_q8.fc = nn.Linear(model_q8.fc.in_features, config.num_classes)
model_q8 = model_q8.to(config.device)

# --- TRAINING ---
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_q8.parameters(), lr=config.lr)
best_val_accuracy_q8 = 0.0
print("\nStarting training for Question 8...")

for epoch in range(config.num_epochs):
    model_q8.train()
    # Simplified training loop for brevity
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}"):
        inputs, labels = inputs.to(config.device), labels.to(config.device)
        optimizer.zero_grad()
        outputs = model_q8(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # --- VALIDATION ---
    model_q8.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(config.device), labels.to(config.device)
            outputs = model_q8(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    epoch_val_accuracy = accuracy_score(all_labels, all_preds)
    if epoch_val_accuracy > best_val_accuracy_q8:
        best_val_accuracy_q8 = epoch_val_accuracy
        torch.save(model_q8.state_dict(), os.path.join(config.output_dir, 'best_model_q8.pth'))

print("\nTraining for Question 8 finished.")

# --- FINAL METRIC CALCULATION (Q8) ---
model_q8.load_state_dict(torch.load(os.path.join(config.output_dir, 'best_model_q8.pth')))
final_accuracy_q8 = accuracy_score(all_labels, all_preds)
final_f1_score_q8 = f1_score(all_labels, all_preds, average='weighted')

print("\n" + "="*50)
print("  Question 8: Classifier Performance (Without Augmentation)")
print("="*50)
print(f"Classification Accuracy: {final_accuracy_q8:.4f}")
print(f"Weighted F1 Score:       {final_f1_score_q8:.4f}")
print("="*50)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 182MB/s]



Starting training for Question 8...


Epoch 1/20: 100%|██████████| 38/38 [00:12<00:00,  2.94it/s]
Epoch 2/20: 100%|██████████| 38/38 [00:08<00:00,  4.49it/s]
Epoch 3/20: 100%|██████████| 38/38 [00:08<00:00,  4.57it/s]
Epoch 4/20: 100%|██████████| 38/38 [00:08<00:00,  4.60it/s]
Epoch 5/20: 100%|██████████| 38/38 [00:08<00:00,  4.58it/s]
Epoch 6/20: 100%|██████████| 38/38 [00:08<00:00,  4.51it/s]
Epoch 7/20: 100%|██████████| 38/38 [00:08<00:00,  4.54it/s]
Epoch 8/20: 100%|██████████| 38/38 [00:08<00:00,  4.67it/s]
Epoch 9/20: 100%|██████████| 38/38 [00:08<00:00,  4.75it/s]
Epoch 10/20: 100%|██████████| 38/38 [00:08<00:00,  4.37it/s]
Epoch 11/20: 100%|██████████| 38/38 [00:08<00:00,  4.73it/s]
Epoch 12/20: 100%|██████████| 38/38 [00:08<00:00,  4.66it/s]
Epoch 13/20: 100%|██████████| 38/38 [00:08<00:00,  4.68it/s]
Epoch 14/20: 100%|██████████| 38/38 [00:08<00:00,  4.65it/s]
Epoch 15/20: 100%|██████████| 38/38 [00:08<00:00,  4.40it/s]
Epoch 16/20: 100%|██████████| 38/38 [00:09<00:00,  4.06it/s]
Epoch 17/20: 100%|██████████| 38/


Training for Question 8 finished.

  Question 8: Classifier Performance (Without Augmentation)
Classification Accuracy: 0.9883
Weighted F1 Score:       0.9882
