In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
from tqdm.auto import tqdm
from torch.nn.utils import spectral_norm


In [17]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cuda
GPU: NVIDIA L4


## Calculating the Normalization mean and std specifically for the dataset

In [18]:
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

In [43]:
minority_data = torch.load('class_0_reduced.pt')
print(minority_data.keys())

dict_keys(['data', 'labels', 'indices'])


In [45]:
from torch.utils.data import TensorDataset

minority_data = torch.load('class_0_reduced.pt')
minority_dataset = TensorDataset(minority_data['data'], minority_data['labels'])
minority_loader = DataLoader(minority_dataset, batch_size=32, shuffle=True)

minority_class_idx = 0
minority_class_name = "Class 0"
batch_size = 32

#  Defining the Models Classes

In [4]:
# Generator
class VanillaGenerator(nn.Module):
    def __init__(self, latent_dim=100, img_size=128):
        super(VanillaGenerator, self).__init__()
        self.img_size = img_size
        
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),

            nn.Linear(1024,img_size * img_size),
            nn.Tanh()
        )
    
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, self.img_size, self.img_size)
        return img

# Discriminator
class VanillaDiscriminator(nn.Module):
    def __init__(self, img_size=128):
        super(VanillaDiscriminator, self).__init__()
        
        self.model = nn.Sequential(
            nn.Linear(img_size * img_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

In [5]:
class DCGANGenerator(nn.Module):
    def __init__(self, latent_dim=100, img_channels=1):
        super().__init__()
        self.init_size = 7  # Changed from 4 to 7 (28/4 = 7)
        self.l1 = nn.Linear(latent_dim, 128 * self.init_size ** 2)
        
        self.model = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 7->14
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, img_channels, 4, 2, 1),  # 14->28
            nn.Tanh()
        )
    
    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        return self.model(out)

class DCGANDiscriminator(nn.Module):
    def __init__(self, img_channels=1):
        super().__init__()
        self.model = nn.Sequential(
            spectral_norm(nn.Conv2d(img_channels, 64, 4, 2, 1)),  # 28->14
            nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)),  # 14->7
            nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(128, 1, 7, 1, 0))  # 7->1
        )
    
    def forward(self, img):
        out = self.model(img)
        return torch.sigmoid(out.view(-1, 1))

In [6]:
class CGANGenerator(nn.Module):
    def __init__(self, latent_dim=100, num_classes=10, img_channels=1):
        super(CGANGenerator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        self.init_size = 7
        self.l1 = nn.Linear(latent_dim * 2, 128 * self.init_size ** 2)
        
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 7->14
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, img_channels, 4, 2, 1),  # 14->28
            nn.Tanh()
        )
    
    def forward(self, z, labels):
        label_input = self.label_emb(labels)
        gen_input = torch.cat([z, label_input], -1)
        out = self.l1(gen_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

class CGANDiscriminator(nn.Module):
    def __init__(self, num_classes=10, img_channels=1):
        super(CGANDiscriminator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, 28 * 28)  # Changed from 128*128
        
        self.model = nn.Sequential(
            nn.Conv2d(img_channels + 1, 64, 4, stride=2, padding=1),  # 28->14
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 14->7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, 7, stride=1, padding=0),  # 7->1
            nn.Sigmoid()
        )
    
    def forward(self, img, labels):
        label_input = self.label_emb(labels).view(labels.shape[0], 1, 28, 28)
        d_in = torch.cat([img, label_input], 1)
        validity = self.model(d_in)
        return validity.view(-1, 1)

In [11]:
import time
from torchvision.utils import save_image
from scipy import linalg
import numpy as np
from torchvision.models import inception_v3

# FID Score Calculation
def calculate_fid(real_images, fake_images, device):
    """Calculate Fr√©chet Inception Distance"""
    inception_model = inception_v3(pretrained=True, transform_input=False).to(device)
    inception_model.fc = nn.Identity()
    inception_model.eval()
    
    def get_activations(images):
        with torch.no_grad():

            if images.shape[1] == 1:
                images = images.repeat(1, 3, 1, 1)
                
            # Resize to 299x299 for Inception
            images_resized = nn.functional.interpolate(images, size=(299, 299), mode='bilinear', align_corners=True)
            pred = inception_model(images_resized)
        return pred.cpu().numpy()
    
    act_real = get_activations(real_images)
    act_fake = get_activations(fake_images)
    
    mu_real, sigma_real = act_real.mean(axis=0), np.cov(act_real, rowvar=False)
    mu_fake, sigma_fake = act_fake.mean(axis=0), np.cov(act_fake, rowvar=False)
    
    diff = mu_real - mu_fake
    covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)
    
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    
    fid = diff.dot(diff) + np.trace(sigma_real + sigma_fake - 2 * covmean)
    return fid

# Save sample images
def save_sample_images(generator, epoch, latent_dim, save_path, num_samples=25, labels=None):
    generator.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim).to(device)
        if labels is not None:
            # For CGAN
            sample_labels = torch.full((num_samples,), labels, dtype=torch.long).to(device)
            gen_imgs = generator(z, sample_labels)
        else:
            gen_imgs = generator(z)
        
        save_image(gen_imgs.data, f"{save_path}/epoch_{epoch}.png", nrow=5, normalize=True)
    generator.train()


# Defining the Training Loop for each Model

In [20]:
def compute_gradient_penalty(discriminator, real_imgs, fake_imgs, labels=None):
    batch_size = real_imgs.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1).to(device)
    
    interpolates = (alpha * real_imgs + (1 - alpha) * fake_imgs).requires_grad_(True)
    
    if labels is not None:
        d_interpolates = discriminator(interpolates, labels)
    else:
        d_interpolates = discriminator(interpolates)
    
    fake = torch.ones(batch_size, 1).to(device).requires_grad_(False)  # Changed
    
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,  # Changed
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    
    gradients = gradients.view(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [21]:
def train_vanilla_gan(generator, discriminator, dataloader, num_epochs=100, starting_epoch=1, latent_dim=100, 
                      save_path='generated_images/vanilla_gan', model_path='models/vanilla_gan'):
    
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)
    
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.00005, betas=(0.5, 0.999))
    criterion = nn.BCELoss()

    # Create metrics file
    metrics_file = f'{save_path}/training_metrics.txt'
    with open(metrics_file, 'w') as f:
        f.write("Epoch,Time(s),D_Loss,G_Loss,FID_Score\n")
    
    print(f"\n{'='*80}")
    print(f"STARTING VANILLA GAN TRAINING - {num_epochs} EPOCHS")
    print(f"{'='*80}\n")
    
    generator.train()
    discriminator.train()
    
    for epoch in range(starting_epoch, num_epochs + 1):
        epoch_start = time.time()
        epoch_d_loss = 0
        epoch_g_loss = 0
        num_batches = 0
        
        for i, (imgs, _) in enumerate(dataloader):
            batch_size = imgs.size(0)
            real_imgs = imgs.to(device)

            
            # Labels
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)
            
            if epoch <= 10 or i % 2 == 0:
                optimizer_D.zero_grad()
                real_discriminator_prediction = discriminator(real_imgs)
                real_loss = criterion(real_discriminator_prediction, real_labels)
                
                z = torch.randn(batch_size, latent_dim).to(device)
                fake_imgs = generator(z)
                fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
                
                gp = compute_gradient_penalty(discriminator, real_imgs, fake_imgs)
                d_loss = (real_loss + fake_loss) / 2 + 5 * gp
                d_loss.backward()
                optimizer_D.step()
            else:
                d_loss = torch.tensor(0.0)

            # Train Generator
            optimizer_G.zero_grad()
            z = torch.randn(batch_size, latent_dim).to(device)
            gen_imgs = generator(z)
            g_loss = criterion(discriminator(gen_imgs), real_labels)
            g_loss.backward()
            optimizer_G.step()
            
            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            num_batches += 1
        
        # Calculate average losses
        avg_d_loss = epoch_d_loss / num_batches
        avg_g_loss = epoch_g_loss / num_batches
        

        fid_score = 0.0
        if epoch % 2 == 0:
            with torch.no_grad():
                real_batch = next(iter(dataloader))[0][:128].to(device)
                z = torch.randn(128, latent_dim).to(device)
                fake_batch = generator(z)
                fid_score = calculate_fid(real_batch, fake_batch, device)
        
        epoch_time = time.time() - epoch_start
        
        # Print metrics
        print(f"Epoch {epoch} | Time: {epoch_time:.2f}s")
        print(f"  D_Loss: {avg_d_loss:.6f} | G_Loss: {avg_g_loss:.6f}")
        if epoch % 2 == 0:
            print(f"  FID Score: {fid_score:.4f}")
        print()
        
        # Save metrics to file
        with open(metrics_file, 'a') as f:
            f.write(f"{epoch},{epoch_time:.2f},{avg_d_loss:.6f},{avg_g_loss:.6f},{fid_score:.4f}\n")
        
        save_sample_images(generator, epoch, latent_dim, save_path)
        
        # Save model after epoch 5
        if epoch > 20 and epoch % 4 == 0:
            torch.save({
                'epoch': epoch,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
            }, f"{model_path}/model_epoch_{epoch}.pt")

        del imgs, real_imgs, fake_imgs, gen_imgs, z
        if i % 10 == 0:
            torch.cuda.empty_cache()
    
    print(f"\n{'='*80}")
    print(f"VANILLA GAN TRAINING COMPLETE")
    print(f"{'='*80}\n")
    
    return generator, discriminator


In [None]:
def train_dcgan(generator, discriminator, dataloader, starting_epoch=1, num_epochs=100, latent_dim=100,
                save_path='generated_images/dcgan', model_path='models/dcgan'):
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)
    
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.00005, betas=(0.5, 0.999))
    criterion = nn.BCELoss()
    
    metrics_file = f'{save_path}/training_metrics.txt'
    with open(metrics_file, 'w') as f:
        f.write("Epoch,Time(s),D_Loss,G_Loss,FID_Score\n")
    
    print(f"\n{'='*80}")
    print(f"STARTING DCGAN TRAINING - {num_epochs} EPOCHS")
    print(f"{'='*80}\n")
    
    generator.train()
    discriminator.train()
    
    for epoch in range(starting_epoch, num_epochs + starting_epoch):
        epoch_start = time.time()
        epoch_d_loss = 0
        epoch_g_loss = 0
        num_batches = 0
        
        for i, (imgs, _) in enumerate(dataloader):
            batch_size = imgs.size(0)
            real_imgs = imgs.to(device)

            real_labels = torch.ones(batch_size, 1).to(device) 
            fake_labels = torch.zeros(batch_size, 1).to(device)
            
            if epoch <= 10 or i % 3 == 0:
                optimizer_D.zero_grad()
                real_loss = criterion(discriminator(real_imgs), real_labels)
                
                z = torch.randn(batch_size, latent_dim).to(device)
                fake_imgs = generator(z)
                fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
                
                gp = compute_gradient_penalty(discriminator, real_imgs, fake_imgs)
                d_loss = (real_loss + fake_loss) / 2 + 2 * gp
                d_loss.backward()
                optimizer_D.step()
            else:
                d_loss = torch.tensor(0.0)

            # Train Generator
            optimizer_G.zero_grad()
            z = torch.randn(batch_size, latent_dim).to(device)
            gen_imgs = generator(z)
            g_loss = criterion(discriminator(gen_imgs), real_labels)
            g_loss.backward()
            optimizer_G.step()
            
            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            num_batches += 1
        
        avg_d_loss = epoch_d_loss / num_batches
        avg_g_loss = epoch_g_loss / num_batches
        
        # Calculate FID every 2 epochs
        fid_score = 0.0
        if epoch % 2 == 0:
            with torch.no_grad():
                real_batch = next(iter(dataloader))[0][:128].to(device)
                z = torch.randn(128, latent_dim).to(device)
                fake_batch = generator(z)
                fid_score = calculate_fid(real_batch, fake_batch, device)
        
        epoch_time = time.time() - epoch_start
        
        print(f"Epoch {epoch} | Time: {epoch_time:.2f}s")
        print(f"  D_Loss: {avg_d_loss:.6f} | G_Loss: {avg_g_loss:.6f}")
        if epoch % 2 == 0:
            print(f"  FID Score: {fid_score:.4f}")
        print()
        
        with open(metrics_file, 'a') as f:
            f.write(f"{epoch},{epoch_time:.2f},{avg_d_loss:.6f},{avg_g_loss:.6f},{fid_score:.4f}\n")
        
       
        save_sample_images(generator, epoch, latent_dim, save_path)
        
        if epoch > 20 and epoch % 4 == 0:
            torch.save({
                'epoch': epoch,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
            }, f"{model_path}/model_epoch_{epoch}.pt")
    
    print(f"\n{'='*80}")
    print(f"DCGAN TRAINING COMPLETE")
    print(f"{'='*80}\n")
    
    return generator, discriminator


In [25]:
def train_cgan(generator, discriminator, dataloader, minority_class_idx, num_epochs=100, starting_epoch=1, 
               latent_dim=100, save_path='generated_images/cgan', model_path='models/cgan'):
    

    os.makedirs(save_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)

    optimizer_G = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.00005, betas=(0.5, 0.999))
    criterion = nn.BCELoss()
    
    metrics_file = f'{save_path}/training_metrics.txt'
    with open(metrics_file, 'w') as f:
        f.write("Epoch,Time(s),D_Loss,G_Loss,FID_Score\n")
    
    print(f"\n{'='*80}")
    print(f"STARTING CGAN TRAINING - {num_epochs} EPOCHS")
    print(f"{'='*80}\n")
    
    generator.train()
    discriminator.train()
    
    for epoch in range(starting_epoch, num_epochs + 1):
        epoch_start = time.time()
        epoch_d_loss = 0
        epoch_g_loss = 0
        num_batches = 0
        
        for i, (imgs, labels) in enumerate(dataloader):
            batch_size = imgs.size(0)
            real_imgs = imgs.to(device)
            labels = labels.to(device)
            
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)
            
            # Train Discriminator but skip training it every 3 epochs

            if epoch <= 5 or i % 2 == 0:
                optimizer_D.zero_grad()
                real_loss = criterion(discriminator(real_imgs, labels), real_labels)
                
                z = torch.randn(batch_size, latent_dim).to(device)
                fake_imgs = generator(z, labels)
                fake_loss = criterion(discriminator(fake_imgs.detach(), labels), fake_labels)
                
                gp = compute_gradient_penalty(discriminator, real_imgs, fake_imgs, labels)
                d_loss = (real_loss + fake_loss) / 2 + 5 * gp
                d_loss.backward()
                optimizer_D.step()
            else:
                d_loss = torch.tensor(0.0)

            # Train G twice
            for _ in range(2):
                optimizer_G.zero_grad()
                z = torch.randn(batch_size, latent_dim).to(device)
                gen_imgs = generator(z, labels)
                g_loss = criterion(discriminator(gen_imgs, labels), real_labels)
                g_loss.backward()
                optimizer_G.step()
            
            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            num_batches += 1
        
        avg_d_loss = epoch_d_loss / num_batches
        avg_g_loss = epoch_g_loss / num_batches
        
        # Calculate FID every 2 epochs
        fid_score = 0.0
        if epoch % 2 == 0:
            with torch.no_grad():
                real_batch = next(iter(dataloader))[0][:128].to(device)
                z = torch.randn(128, latent_dim).to(device)
                sample_labels = torch.full((128,), minority_class_idx, dtype=torch.long).to(device)
                fake_batch = generator(z, sample_labels)
                fid_score = calculate_fid(real_batch, fake_batch, device)
        
        epoch_time = time.time() - epoch_start
        
        print(f"Epoch {epoch} | Time: {epoch_time:.2f}s")
        print(f"  D_Loss: {avg_d_loss:.6f} | G_Loss: {avg_g_loss:.6f}")
        if epoch % 2 == 0:
            print(f"  FID Score: {fid_score:.4f}")
        print()
        
        with open(metrics_file, 'a') as f:
            f.write(f"{epoch},{epoch_time:.2f},{avg_d_loss:.6f},{avg_g_loss:.6f},{fid_score:.4f}\n")
        
        
        save_sample_images(generator, epoch, latent_dim, save_path, labels=minority_class_idx)
        
        if epoch > 20 and epoch % 4 == 0:
            torch.save({
                'epoch': epoch,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
            }, f"{model_path}/model_epoch_{epoch}.pt")
    
    print(f"\n{'='*80}")
    print(f"CGAN TRAINING COMPLETE")
    print(f"{'='*80}\n")
    
    return generator, discriminator

In [26]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

**Training Vanilla GAN**

In [28]:
# Set random seed for reproducibility
torch.manual_seed(42)
latent_dim = 100

In [36]:
# Initialize models
vanilla_gen = VanillaGenerator(latent_dim=latent_dim, img_size=28).to(device)
vanilla_disc = VanillaDiscriminator(img_size=28).to(device)
vanilla_gen.apply(weights_init)
vanilla_disc.apply(weights_init)

print(f"Vanilla Generator parameters: {sum(p.numel() for p in vanilla_gen.parameters()):,}")
print(f"Vanilla Discriminator parameters: {sum(p.numel() for p in vanilla_disc.parameters()):,}")

Vanilla Generator parameters: 1,489,424
Vanilla Discriminator parameters: 1,460,225


In [37]:
# Train Vanilla GAN
vanilla_gen, vanilla_disc = train_vanilla_gan(
    vanilla_gen, 
    vanilla_disc, 
    minority_loader, 
    num_epochs=120,
    latent_dim=latent_dim,
    save_path='generated_images/vanilla_gan',
    model_path='models/vanilla_gan'
)


STARTING VANILLA GAN TRAINING - 120 EPOCHS

Epoch 1 | Time: 2.82s
  D_Loss: 5.264101 | G_Loss: 0.560085



  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)
  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)


Epoch 2 | Time: 12.73s
  D_Loss: 4.642878 | G_Loss: 0.859338
  FID Score: 518.0356

Epoch 3 | Time: 1.63s
  D_Loss: 4.456237 | G_Loss: 0.823284

Epoch 4 | Time: 16.53s
  D_Loss: 4.525640 | G_Loss: 0.729868
  FID Score: 463.1566

Epoch 5 | Time: 1.53s
  D_Loss: 4.362242 | G_Loss: 0.765675

Epoch 6 | Time: 15.25s
  D_Loss: 4.093177 | G_Loss: 0.884674
  FID Score: 487.8307

Epoch 7 | Time: 1.68s
  D_Loss: 3.741555 | G_Loss: 1.000810

Epoch 8 | Time: 14.44s
  D_Loss: 3.408796 | G_Loss: 0.904279
  FID Score: 422.4494

Epoch 9 | Time: 2.23s
  D_Loss: 2.909197 | G_Loss: 0.811868

Epoch 10 | Time: 14.21s
  D_Loss: 2.651665 | G_Loss: 0.747097
  FID Score: 385.0221

Epoch 11 | Time: 1.33s
  D_Loss: 1.379952 | G_Loss: 0.582346

Epoch 12 | Time: 15.85s
  D_Loss: 1.306065 | G_Loss: 0.531220
  FID Score: 361.5755

Epoch 13 | Time: 1.36s
  D_Loss: 1.338386 | G_Loss: 0.511708

Epoch 14 | Time: 11.50s
  D_Loss: 1.247839 | G_Loss: 0.507876
  FID Score: 386.5652

Epoch 15 | Time: 1.27s
  D_Loss: 1.147187

The best Epoch is 72

In [38]:
del vanilla_gen, vanilla_disc
torch.cuda.empty_cache()

**Training DCGAN GAN**

In [40]:
# Initialize DCGAN
dcgan_gen = DCGANGenerator(latent_dim=latent_dim, img_channels=1).to(device)
dcgan_disc = DCGANDiscriminator().to(device)

dcgan_gen.apply(weights_init)
dcgan_disc.apply(weights_init)

print(f"\nDCGAN Generator parameters: {sum(p.numel() for p in dcgan_gen.parameters()):,}")
print(f"DCGAN Discriminator parameters: {sum(p.numel() for p in dcgan_disc.parameters()):,}")


DCGAN Generator parameters: 766,017
DCGAN Discriminator parameters: 138,561


In [41]:
dcgan_gen, dcgan_disc = train_dcgan(
    dcgan_gen,
    dcgan_disc,
    minority_loader,
    num_epochs=90,
    latent_dim=latent_dim,
)


STARTING DCGAN TRAINING - 90 EPOCHS

Epoch 1 | Time: 1.64s
  D_Loss: 2.260379 | G_Loss: 0.920222



  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)


Epoch 2 | Time: 12.26s
  D_Loss: 1.931814 | G_Loss: 1.500543
  FID Score: 431.7587

Epoch 3 | Time: 1.35s
  D_Loss: 1.810199 | G_Loss: 1.756070



  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)


Epoch 4 | Time: 14.95s
  D_Loss: 1.800336 | G_Loss: 1.370924
  FID Score: 445.4994

Epoch 5 | Time: 1.28s
  D_Loss: 1.716866 | G_Loss: 1.184177

Epoch 6 | Time: 15.31s
  D_Loss: 1.656825 | G_Loss: 0.870761
  FID Score: 369.2838

Epoch 7 | Time: 1.32s
  D_Loss: 1.643644 | G_Loss: 0.917189

Epoch 8 | Time: 11.68s
  D_Loss: 1.522435 | G_Loss: 1.032365
  FID Score: 324.9901

Epoch 9 | Time: 1.32s
  D_Loss: 1.596839 | G_Loss: 0.778700

Epoch 10 | Time: 15.24s
  D_Loss: 1.677904 | G_Loss: 0.732065
  FID Score: 392.4905

Epoch 11 | Time: 0.90s
  D_Loss: 0.566141 | G_Loss: 0.771075

Epoch 12 | Time: 14.17s
  D_Loss: 0.617097 | G_Loss: 0.601899
  FID Score: 308.6601

Epoch 13 | Time: 0.99s
  D_Loss: 0.619508 | G_Loss: 0.652254

Epoch 14 | Time: 12.08s
  D_Loss: 0.582077 | G_Loss: 0.754117
  FID Score: 282.4316

Epoch 15 | Time: 1.06s
  D_Loss: 0.565006 | G_Loss: 0.751214

Epoch 16 | Time: 10.74s
  D_Loss: 0.570861 | G_Loss: 0.663671
  FID Score: 251.4420

Epoch 17 | Time: 1.04s
  D_Loss: 0.5918

In [43]:
dcgan_gen, dcgan_disc = train_dcgan(
    dcgan_gen,
    dcgan_disc,
    minority_loader,
    starting_epoch=91,
    num_epochs=40,
    latent_dim=latent_dim,
)


STARTING DCGAN TRAINING - 40 EPOCHS

Epoch 91 | Time: 0.92s
  D_Loss: 0.332925 | G_Loss: 0.665576



  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)


Epoch 92 | Time: 11.22s
  D_Loss: 0.330868 | G_Loss: 0.667165
  FID Score: 158.4287

Epoch 93 | Time: 0.97s
  D_Loss: 0.323998 | G_Loss: 0.661991



  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)


Epoch 94 | Time: 14.98s
  D_Loss: 0.329644 | G_Loss: 0.669420
  FID Score: 175.9784

Epoch 95 | Time: 0.91s
  D_Loss: 0.327549 | G_Loss: 0.666487

Epoch 96 | Time: 14.00s
  D_Loss: 0.326502 | G_Loss: 0.663914
  FID Score: 162.7652

Epoch 97 | Time: 1.01s
  D_Loss: 0.327408 | G_Loss: 0.668782

Epoch 98 | Time: 10.32s
  D_Loss: 0.321347 | G_Loss: 0.659073
  FID Score: 177.0736

Epoch 99 | Time: 0.90s
  D_Loss: 0.328572 | G_Loss: 0.666037

Epoch 100 | Time: 15.26s
  D_Loss: 0.322016 | G_Loss: 0.662711
  FID Score: 172.6832

Epoch 101 | Time: 0.98s
  D_Loss: 0.320817 | G_Loss: 0.666348

Epoch 102 | Time: 11.05s
  D_Loss: 0.322077 | G_Loss: 0.667308
  FID Score: 198.5219

Epoch 103 | Time: 0.94s
  D_Loss: 0.319539 | G_Loss: 0.665896

Epoch 104 | Time: 14.89s
  D_Loss: 0.320245 | G_Loss: 0.668212
  FID Score: 152.3712

Epoch 105 | Time: 0.86s
  D_Loss: 0.318892 | G_Loss: 0.666660

Epoch 106 | Time: 13.55s
  D_Loss: 0.318114 | G_Loss: 0.661518
  FID Score: 187.5616

Epoch 107 | Time: 0.93s
  

In [44]:
dcgan_gen, dcgan_disc = train_dcgan(
    dcgan_gen,
    dcgan_disc,
    minority_loader,
    starting_epoch=131,
    num_epochs=40,
    latent_dim=latent_dim,
)


STARTING DCGAN TRAINING - 40 EPOCHS

Epoch 131 | Time: 0.91s
  D_Loss: 0.301055 | G_Loss: 0.667480



  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)
  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)


Epoch 132 | Time: 11.28s
  D_Loss: 0.300216 | G_Loss: 0.667764
  FID Score: 138.1669

Epoch 133 | Time: 1.03s
  D_Loss: 0.301454 | G_Loss: 0.666016

Epoch 134 | Time: 14.27s
  D_Loss: 0.300604 | G_Loss: 0.671652
  FID Score: 154.1957

Epoch 135 | Time: 0.91s
  D_Loss: 0.297838 | G_Loss: 0.665889

Epoch 136 | Time: 15.81s
  D_Loss: 0.298130 | G_Loss: 0.671982
  FID Score: 138.1341

Epoch 137 | Time: 1.05s
  D_Loss: 0.303703 | G_Loss: 0.668358

Epoch 138 | Time: 15.31s
  D_Loss: 0.299851 | G_Loss: 0.672150
  FID Score: 151.7499

Epoch 139 | Time: 0.90s
  D_Loss: 0.301413 | G_Loss: 0.667649

Epoch 140 | Time: 14.02s
  D_Loss: 0.291866 | G_Loss: 0.669754
  FID Score: 151.3447

Epoch 141 | Time: 0.91s
  D_Loss: 0.296742 | G_Loss: 0.668489

Epoch 142 | Time: 14.39s
  D_Loss: 0.295960 | G_Loss: 0.672417
  FID Score: 127.6377

Epoch 143 | Time: 0.92s
  D_Loss: 0.296623 | G_Loss: 0.667422

Epoch 144 | Time: 13.54s
  D_Loss: 0.297901 | G_Loss: 0.673009
  FID Score: 143.2086

Epoch 145 | Time: 0.

In [45]:
dcgan_gen, dcgan_disc = train_dcgan(
    dcgan_gen,
    dcgan_disc,
    minority_loader,
    starting_epoch=171,
    num_epochs=40,
    latent_dim=latent_dim,
)


STARTING DCGAN TRAINING - 40 EPOCHS

Epoch 171 | Time: 0.87s
  D_Loss: 0.278413 | G_Loss: 0.674006



  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)
  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)


Epoch 172 | Time: 11.21s
  D_Loss: 0.281117 | G_Loss: 0.670733
  FID Score: 114.7207

Epoch 173 | Time: 1.03s
  D_Loss: 0.280320 | G_Loss: 0.670960

Epoch 174 | Time: 12.56s
  D_Loss: 0.278804 | G_Loss: 0.671277
  FID Score: 118.4346

Epoch 175 | Time: 1.09s
  D_Loss: 0.281510 | G_Loss: 0.673314

Epoch 176 | Time: 11.36s
  D_Loss: 0.279320 | G_Loss: 0.674648
  FID Score: 129.0693

Epoch 177 | Time: 0.89s
  D_Loss: 0.279626 | G_Loss: 0.677095

Epoch 178 | Time: 15.00s
  D_Loss: 0.279706 | G_Loss: 0.676262
  FID Score: 134.4134

Epoch 179 | Time: 0.95s
  D_Loss: 0.277238 | G_Loss: 0.678139

Epoch 180 | Time: 14.96s
  D_Loss: 0.277637 | G_Loss: 0.671616
  FID Score: 114.5718

Epoch 181 | Time: 1.02s
  D_Loss: 0.274847 | G_Loss: 0.677895

Epoch 182 | Time: 11.75s
  D_Loss: 0.275118 | G_Loss: 0.678672
  FID Score: 120.3658

Epoch 183 | Time: 0.91s
  D_Loss: 0.275249 | G_Loss: 0.672648

Epoch 184 | Time: 14.61s
  D_Loss: 0.276342 | G_Loss: 0.676224
  FID Score: 130.1348

Epoch 185 | Time: 0.

In [None]:
dcgan_gen, dcgan_disc = train_dcgan(
    dcgan_gen,
    dcgan_disc,
    minority_loader,
    starting_epoch=211,
    num_epochs=30,
    latent_dim=latent_dim,
)


STARTING DCGAN TRAINING - 30 EPOCHS

Epoch 211 | Time: 0.97s
  D_Loss: 0.273095 | G_Loss: 0.673367



  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)
  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)


Epoch 212 | Time: 10.99s
  D_Loss: 0.268685 | G_Loss: 0.677662
  FID Score: 103.2224

Epoch 213 | Time: 0.90s
  D_Loss: 0.268398 | G_Loss: 0.672598

Epoch 214 | Time: 13.66s
  D_Loss: 0.270828 | G_Loss: 0.671816
  FID Score: 108.7755

Epoch 215 | Time: 0.91s
  D_Loss: 0.271301 | G_Loss: 0.680873

Epoch 216 | Time: 13.35s
  D_Loss: 0.269378 | G_Loss: 0.679156
  FID Score: 119.3399

Epoch 217 | Time: 1.05s
  D_Loss: 0.270222 | G_Loss: 0.672598

Epoch 218 | Time: 11.11s
  D_Loss: 0.268689 | G_Loss: 0.679771
  FID Score: 96.1246

Epoch 219 | Time: 0.91s
  D_Loss: 0.272238 | G_Loss: 0.675410

Epoch 220 | Time: 12.62s
  D_Loss: 0.266781 | G_Loss: 0.680804
  FID Score: 108.2871

Epoch 221 | Time: 0.92s
  D_Loss: 0.268754 | G_Loss: 0.673535

Epoch 222 | Time: 13.92s
  D_Loss: 0.270613 | G_Loss: 0.676236
  FID Score: 99.6651

Epoch 223 | Time: 0.91s
  D_Loss: 0.270044 | G_Loss: 0.681073

Epoch 224 | Time: 11.04s
  D_Loss: 0.267199 | G_Loss: 0.675593
  FID Score: 95.8992

Epoch 225 | Time: 0.93s

In [47]:
del dcgan_gen, dcgan_disc
torch.cuda.empty_cache()

epoch 236 is the best

**Training CGAN GAN**

In [53]:
# Initialize CGAN
num_classes = 10
cgan_gen = CGANGenerator(latent_dim=latent_dim, num_classes=num_classes, img_channels=1).to(device)
cgan_disc = CGANDiscriminator(num_classes=num_classes, img_channels=1).to(device)

print(f"\nCGAN Generator parameters: {sum(p.numel() for p in cgan_gen.parameters()):,}")
print(f"CGAN Discriminator parameters: {sum(p.numel() for p in cgan_disc.parameters()):,}")


CGAN Generator parameters: 1,394,217
CGAN Discriminator parameters: 147,681


In [56]:
full_dataset = datasets.MNIST(
    root="./mnist",
    train=True,
    download=False,
    transform=transforms.ToTensor()
)

full_loader = DataLoader(full_dataset, batch_size=batch_size, shuffle=True)

In [57]:
cgan_gen, cgan_disc = train_cgan(
    cgan_gen,
    cgan_disc,
    full_loader,
    minority_class_idx,
    num_epochs=100,
    latent_dim=latent_dim,
)


STARTING CGAN TRAINING - 100 EPOCHS

Epoch 1 | Time: 54.86s
  D_Loss: 1.146193 | G_Loss: 0.625094



  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)
  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)


Epoch 2 | Time: 50.10s
  D_Loss: 0.960949 | G_Loss: 0.667081
  FID Score: 370.3369

Epoch 3 | Time: 40.70s
  D_Loss: 0.786557 | G_Loss: 0.648492

Epoch 4 | Time: 57.38s
  D_Loss: 0.649404 | G_Loss: 0.716543
  FID Score: 319.0408

Epoch 5 | Time: 43.28s
  D_Loss: 0.609288 | G_Loss: 0.749667

Epoch 6 | Time: 42.31s
  D_Loss: 0.367074 | G_Loss: 0.626096
  FID Score: 395.3452

Epoch 7 | Time: 30.22s
  D_Loss: 0.362740 | G_Loss: 0.629262

Epoch 8 | Time: 41.84s
  D_Loss: 0.345884 | G_Loss: 0.666167
  FID Score: 303.0732

Epoch 9 | Time: 29.52s
  D_Loss: 0.332199 | G_Loss: 0.692052

Epoch 10 | Time: 42.95s
  D_Loss: 0.324145 | G_Loss: 0.706850
  FID Score: 263.5411

Epoch 11 | Time: 32.94s
  D_Loss: 0.322066 | G_Loss: 0.711228

Epoch 12 | Time: 49.77s
  D_Loss: 0.314607 | G_Loss: 0.720868
  FID Score: 294.5211

Epoch 13 | Time: 31.28s
  D_Loss: 0.303359 | G_Loss: 0.753724

Epoch 14 | Time: 47.99s
  D_Loss: 0.300660 | G_Loss: 0.756405
  FID Score: 257.6621

Epoch 15 | Time: 31.65s
  D_Loss: 0

LinAlgError: Internal error in scipy.linalg.sqrtm: -101

In [52]:
del cgan_gen, cgan_disc
torch.cuda.empty_cache()

best epoch is 40

In [13]:
import os
from torchvision.utils import save_image
from PIL import Image

# Create directories for synthetic images
os.makedirs('synthetic_data/vanilla_gan', exist_ok=True)
os.makedirs('synthetic_data/dcgan', exist_ok=True)
os.makedirs('synthetic_data/cgan', exist_ok=True)

print("Created directories for synthetic images")

Created directories for synthetic images


In [19]:
latent_dim = 100

# Load Vanilla GAN
vanilla_gen = VanillaGenerator(latent_dim=latent_dim, img_size=28).to(device)
checkpoint_vanilla = torch.load('gan_saved_models/vanilla_gan/model_epoch_72.pt')
vanilla_gen.load_state_dict(checkpoint_vanilla['generator_state_dict'])
vanilla_gen.eval()
print("‚úì Loaded Vanilla GAN from epoch 72")

# Load DCGAN 
dcgan_gen = DCGANGenerator(latent_dim=latent_dim, img_channels=1).to(device)
checkpoint_dcgan = torch.load('gan_saved_models/dcgan/model_epoch_236.pt')
dcgan_gen.load_state_dict(checkpoint_dcgan['generator_state_dict'])
dcgan_gen.eval()
print("‚úì Loaded DCGAN from epoch 236")

# Load CGAN 
num_classes = 10    
cgan_gen = CGANGenerator(latent_dim=latent_dim, num_classes=num_classes, img_channels=1).to(device)
checkpoint_cgan = torch.load('gan_saved_models/cgan/model_epoch_40.pt')
cgan_gen.load_state_dict(checkpoint_cgan['generator_state_dict'])
cgan_gen.eval()
print("‚úì Loaded CGAN from epoch 40")

‚úì Loaded Vanilla GAN from epoch 72
‚úì Loaded DCGAN from epoch 236
‚úì Loaded CGAN from epoch 40


In [20]:
def generate_synthetic_images(generator, num_images, save_dir, gan_type='vanilla', minority_class_idx=None):
    """
    Generate synthetic images and save them individually
    
    Args:
        generator: trained generator model
        num_images: number of images to generate
        save_dir: directory to save images
        gan_type: 'vanilla', 'dcgan', or 'cgan'
        minority_class_idx: class index for CGAN
    """
    generator.eval()
    batch_size = 64  # Generate in batches to avoid memory issues
    num_batches = (num_images + batch_size - 1) // batch_size
    
    print(f"\n{'='*60}")
    print(f"Generating {num_images} images for {gan_type.upper()}")
    print(f"{'='*60}")
    
    img_count = 0
    
    with torch.no_grad():
        for batch_idx in range(num_batches):
            # Calculate batch size for last batch
            current_batch_size = min(batch_size, num_images - img_count)
            
            # Generate noise
            z = torch.randn(current_batch_size, latent_dim).to(device)
            
            # Generate images
            if gan_type == 'cgan':
                labels = torch.full((current_batch_size,), minority_class_idx, dtype=torch.long).to(device)
                fake_imgs = generator(z, labels)
            else:
                fake_imgs = generator(z)
            
            # Save each image individually
            for i in range(current_batch_size):
                img = fake_imgs[i]
                # Denormalize from [-1, 1] to [0, 1]
                img = (img + 1) / 2
                img = torch.clamp(img, 0, 1)
                
                # Save image
                img_path = os.path.join(save_dir, f'synthetic_{img_count + 4800 :05d}.png')
                save_image(img, img_path)
                img_count += 1
            
            if (batch_idx + 1) % 10 == 0:
                print(f"Generated {img_count}/{num_images} images...")
    
    print(f"‚úì Complete! Generated {img_count} images in {save_dir}")
    return img_count

In [21]:
num_synthetic = 900
minority_class_idx = 0

# Generate for Vanilla GAN
count_vanilla = generate_synthetic_images(
    vanilla_gen, 
    num_synthetic, 
    'synthetic_data/vanilla_gan',
    gan_type='vanilla'
)

# Generate for DCGAN
count_dcgan = generate_synthetic_images(
    dcgan_gen,
    num_synthetic,
    'synthetic_data/dcgan',
    gan_type='dcgan'
)

# Generate for CGAN (class 2 = Tuberculosis)
count_cgan = generate_synthetic_images(
    cgan_gen,
    num_synthetic,
    'synthetic_data/cgan',
    gan_type='cgan',
    minority_class_idx=minority_class_idx
)

print(f"\n{'='*60}")
print(f"GENERATION SUMMARY")
print(f"{'='*60}")
print(f"Vanilla GAN: {count_vanilla} images")
print(f"DCGAN: {count_dcgan} images")
print(f"CGAN: {count_cgan} images")
print(f"Total synthetic images: {count_vanilla + count_dcgan + count_cgan}")
# the output is after adding 4800 already before


Generating 900 images for VANILLA
Generated 640/900 images...
‚úì Complete! Generated 900 images in synthetic_data/vanilla_gan

Generating 900 images for DCGAN
Generated 640/900 images...
‚úì Complete! Generated 900 images in synthetic_data/dcgan

Generating 900 images for CGAN
Generated 640/900 images...
‚úì Complete! Generated 900 images in synthetic_data/cgan

GENERATION SUMMARY
Vanilla GAN: 900 images
DCGAN: 900 images
CGAN: 900 images
Total synthetic images: 2700
