In [1]:
# --- 1. IMPORTS AND SETUP ---
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset
from torchvision.utils import save_image
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
import os
from tqdm import tqdm

In [2]:
class Config_Q9:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers = 2
    dataset_path = '/kaggle/input/dgm-animals/Animals_data/animals/animals'
    q6_output_dir = '/kaggle/input/dgm_q6_results/pytorch/default/1/Q6_cGAN_20_classes_G_heavy'
    q8_output_dir = '/kaggle/input/dgm_q8_results/pytorch/default/1/Q8_ResNet_20_classes'
    aug_data_dir = '/kaggle/working/Q9_augmented_data'
    num_images_per_class_to_gen = 100
    num_classes = 20
    batch_size = 32
    num_epochs = 20
    lr = 0.001
    output_dir = '/kaggle/working/Q9_ResNet_augmented'

config = Config_Q9()
os.makedirs(config.output_dir, exist_ok=True)
os.makedirs(config.aug_data_dir, exist_ok=True)
print(f"Configuration loaded. Using device: {config.device}")

Configuration loaded. Using device: cuda


In [3]:
class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim, num_classes, embedding_dim, channels_img, features_g):
        super(ConditionalGenerator, self).__init__(); self.embed = nn.Embedding(num_classes, embedding_dim)
        self.net = nn.Sequential(
            self._block(latent_dim+embedding_dim, features_g*16, 4, 1, 0), self._block(features_g*16, features_g*8, 4, 2, 1),
            self._block(features_g*8, features_g*4, 4, 2, 1), self._block(features_g*4, features_g*2, 4, 2, 1),
            self._block(features_g*2, features_g, 4, 2, 1), nn.ConvTranspose2d(features_g, channels_img, 4, 2, 1), nn.Tanh())
    def _block(self, i, o, k, s, p): return nn.Sequential(nn.ConvTranspose2d(i, o, k, s, p, bias=False), nn.BatchNorm2d(o), nn.ReLU(True))
    def forward(self, z, l): return self.net(torch.cat([z, self.embed(l).unsqueeze(2).unsqueeze(3)], 1))

class ClassSubsetDataset(Dataset):
    def __init__(self, subset, class_mapping): self.subset, self.class_mapping = subset, class_mapping
    def __getitem__(self, i): img, orig_l = self.subset[i]; new_l = self.class_mapping[orig_l]; return img, new_l
    def __len__(self): return len(self.subset)

In [4]:
# --- GENERATE AUGMENTED DATA ---
gen_q6 = ConditionalGenerator(100, config.num_classes, 100, 3, 64).to(config.device)
gen_q6.load_state_dict(torch.load(os.path.join(config.q6_output_dir, 'generator_20_class.pth')))
gen_q6.eval()
print("Loaded c-GAN from Question 6.")

selected_class_names = np.load(os.path.join(config.q6_output_dir, 'selected_class_names.npy'))

print(f"Generating {config.num_images_per_class_to_gen} images for each of the {len(selected_class_names)} classes...")
for i, class_name in enumerate(tqdm(selected_class_names, desc="Generating Classes")):
    class_dir = os.path.join(config.aug_data_dir, class_name)
    os.makedirs(class_dir, exist_ok=True)
    with torch.no_grad():
        noise = torch.randn(config.num_images_per_class_to_gen, 100, 1, 1, device=config.device)
        labels = torch.full((config.num_images_per_class_to_gen,), i, device=config.device, dtype=torch.long)
        fake_images = gen_q6(noise, labels).detach().cpu()
        for j, image in enumerate(fake_images):
            save_image(image, os.path.join(class_dir, f"fake_{j+1}.png"), normalize=True)
print("Augmented data generation complete.")

Loaded c-GAN from Question 6.
Generating 100 images for each of the 20 classes...


Generating Classes: 100%|██████████| 20/20 [00:11<00:00,  1.75it/s]

Augmented data generation complete.





In [5]:
# --- CREATE AUGMENTED DATASET AND RETRAIN CLASSIFIER ---
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    'val': transforms.Compose([
        transforms.Resize(256), transforms.CenterCrop(224),
        transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

# Create the original real dataset for the 20 classes
full_dataset = datasets.ImageFolder(root=config.dataset_path, transform=data_transforms['train'])
class_mapping = torch.load(os.path.join(config.q6_output_dir, 'class_mapping.pth'))
selected_class_indices = [full_dataset.class_to_idx[name] for name in selected_class_names]
subset_indices = [i for i, (_, label_idx) in enumerate(full_dataset.samples) if label_idx in selected_class_indices]
real_train_dataset = ClassSubsetDataset(Subset(full_dataset, subset_indices), class_mapping)

# Create the fake dataset from the generated images
fake_dataset = datasets.ImageFolder(root=config.aug_data_dir, transform=data_transforms['train'])
fake_dataset.class_to_idx = {class_name: i for i, class_name in enumerate(selected_class_names)}

# Combine them
augmented_train_dataset = ConcatDataset([real_train_dataset, fake_dataset])

# Create DataLoaders
train_loader_aug = DataLoader(augmented_train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
full_dataset_val = datasets.ImageFolder(root=config.dataset_path, transform=data_transforms['val'])
val_dataset_real = ClassSubsetDataset(Subset(full_dataset_val, subset_indices), class_mapping)
val_loader = DataLoader(val_dataset_real, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)
print(f"Augmented training dataset created with {len(augmented_train_dataset)} total images.")

# Retrain the model on augmented data
model_q9 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model_q9.fc = nn.Linear(model_q9.fc.in_features, config.num_classes)
model_q9 = model_q9.to(config.device)
criterion = nn.CrossEntropyLoss()
optimizer_aug = optim.Adam(model_q9.parameters(), lr=config.lr)
print("\nStarting retraining on augmented data...")

for epoch in range(config.num_epochs):
    model_q9.train()
    for inputs, labels in tqdm(train_loader_aug, desc=f"Epoch {epoch+1}/{config.num_epochs}"):
        inputs, labels = inputs.to(config.device), labels.to(config.device)
        optimizer_aug.zero_grad(); outputs = model_q9(inputs); loss = criterion(outputs, labels)
        loss.backward(); optimizer_aug.step()
print("Retraining finished.")


Augmented training dataset created with 3193 total images.


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 188MB/s]



Starting retraining on augmented data...


Epoch 1/20: 100%|██████████| 100/100 [00:18<00:00,  5.54it/s]
Epoch 2/20: 100%|██████████| 100/100 [00:17<00:00,  5.84it/s]
Epoch 3/20: 100%|██████████| 100/100 [00:17<00:00,  5.81it/s]
Epoch 4/20: 100%|██████████| 100/100 [00:17<00:00,  5.77it/s]
Epoch 5/20: 100%|██████████| 100/100 [00:17<00:00,  5.82it/s]
Epoch 6/20: 100%|██████████| 100/100 [00:17<00:00,  5.82it/s]
Epoch 7/20: 100%|██████████| 100/100 [00:17<00:00,  5.79it/s]
Epoch 8/20: 100%|██████████| 100/100 [00:17<00:00,  5.80it/s]
Epoch 9/20: 100%|██████████| 100/100 [00:17<00:00,  5.80it/s]
Epoch 10/20: 100%|██████████| 100/100 [00:17<00:00,  5.83it/s]
Epoch 11/20: 100%|██████████| 100/100 [00:17<00:00,  5.83it/s]
Epoch 12/20: 100%|██████████| 100/100 [00:17<00:00,  5.75it/s]
Epoch 13/20: 100%|██████████| 100/100 [00:17<00:00,  5.80it/s]
Epoch 14/20: 100%|██████████| 100/100 [00:17<00:00,  5.82it/s]
Epoch 15/20: 100%|██████████| 100/100 [00:17<00:00,  5.83it/s]
Epoch 16/20: 100%|██████████| 100/100 [00:17<00:00,  5.83it/s]
E

Retraining finished.





In [6]:
# --- FINAL COMPARISON ---
model_q9.eval()
all_preds_q9, all_labels_q9 = [], []
with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model_q9(inputs.to(config.device)); _, preds = torch.max(outputs, 1)
        all_preds_q9.extend(preds.cpu().numpy()); all_labels_q9.extend(labels.cpu().numpy())
final_accuracy_q9 = accuracy_score(all_labels_q9, all_preds_q9)
final_f1_score_q9 = f1_score(all_labels_q9, all_preds_q9, average='weighted')

# Evaluate the old Q8 model for comparison
model_q8 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model_q8.fc = nn.Linear(model_q8.fc.in_features, config.num_classes)
model_q8.load_state_dict(torch.load(os.path.join(config.q8_output_dir, 'best_model_q8.pth')))
model_q8 = model_q8.to(config.device)
model_q8.eval()
all_preds_q8, all_labels_q8 = [], []
with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model_q8(inputs.to(config.device)); _, preds = torch.max(outputs, 1)
        all_preds_q8.extend(preds.cpu().numpy()); all_labels_q8.extend(labels.cpu().numpy())
final_accuracy_q8 = accuracy_score(all_labels_q8, all_preds_q8)
final_f1_score_q8 = f1_score(all_labels_q8, all_preds_q8, average='weighted')

# Print the final comparison report
print("\n" + "#"*50)
print("              Question 9: Final Comparison")
print("#"*50)
print(f"\nClassifier WITHOUT Augmentation (Q8):")
print(f"  - Accuracy: {final_accuracy_q8:.4f}")
print(f"  - F1 Score: {final_f1_score_q8:.4f}")
print(f"\nClassifier WITH GAN Augmentation (Q9):")
print(f"  - Accuracy: {final_accuracy_q9:.4f}")
print(f"  - F1 Score: {final_f1_score_q9:.4f}")
acc_diff = final_accuracy_q9 - final_accuracy_q8
f1_diff = final_f1_score_q9 - final_f1_score_q8
print("\n" + "-"*50)
print("Difference (Augmented - Original):")
print(f"  - Accuracy Delta: {acc_diff:+.4f}")
print(f"  - F1 Score Delta: {f1_diff:+.4f}")
print("#"*50)


##################################################
              Question 9: Final Comparison
##################################################

Classifier WITHOUT Augmentation (Q8):
  - Accuracy: 0.9941
  - F1 Score: 0.9941

Classifier WITH GAN Augmentation (Q9):
  - Accuracy: 0.9413
  - F1 Score: 0.9400

--------------------------------------------------
Difference (Augmented - Original):
  - Accuracy Delta: -0.0528
  - F1 Score Delta: -0.0542
##################################################
