**Importing necessary modules**

In [None]:
import torch
import torch.nn as nn 
import torch.optim as optim 
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import torchvision.datasets as datasets
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import torchvision.utils as vutils

**Setting up Hyperparameters**

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so weight init, noise draws and DataLoader shuffling
# are repeatable across Restart & Run All (cuDNN kernels may still be
# nondeterministic on GPU).
seed = 42
torch.manual_seed(seed)

# Optimisation
learning_rate = 5e-5
batch_size = 64
num_epochs = 5

# Image / latent geometry — the Generator and Discriminator layer stacks are
# sized for 64x64 images and (noise_dim, 1, 1) latent inputs.
image_size = 64
img_channels = 3
noise_dim = 100

# Base channel widths of the two networks
features_disc = 64
features_gen = 64

# WGAN specifics: critic updates per generator update, and the bound used to
# clamp critic weights after each critic step.
critic_iterations = 5
weight_clip = 0.01


**Building the discriminator (WGAN critic)**

In [None]:
class Discriminator(nn.Module):
    """WGAN critic: maps a batch of images to one unbounded score per image.

    Layer stack: a stride-2 Conv2d + LeakyReLU(0.2), then three stride-2
    ``_block``s that double the channel count each time, then a 4x4 stride-2
    Conv2d down to a single channel. For 64x64 inputs the spatial size goes
    64 -> 32 -> 16 -> 8 -> 4 -> 1, giving an output of shape (N, 1, 1, 1).
    No sigmoid is applied; the scores are used directly by the Wasserstein loss.

    Args:
        img_channels: number of channels in the input images.
        features_d: width of the first conv; later layers use 2x, 4x and 8x.
    """

    def __init__(self, img_channels, features_d):
        super().__init__()

        widths = [features_d * mult for mult in (1, 2, 4, 8)]

        layers = [
            nn.Conv2d(img_channels, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(self._block(c_in, c_out, 4, 2, 1))
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))

        # The attribute name `disc` and this exact layer order determine the
        # state_dict keys (disc.0.*, disc.2.0.*, ..., disc.5.*) used by saved
        # checkpoints.
        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a bias-free Conv2d followed by LeakyReLU(0.2)."""
        # NOTE(review): bias=False is normally paired with a following norm
        # layer, but none is used here, so these convs carry no bias at all —
        # confirm this is intentional.
        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        return nn.Sequential(conv, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Return the raw critic scores for ``x``."""
        return self.disc(x)

**Building the Generator**

In [3]:
class Generator(nn.Module):
    """DCGAN-style generator: latent noise of shape (N, channels_noise, 1, 1) -> images.

    Four ``_block``s (ConvTranspose2d + ReLU) grow the spatial size
    1 -> 4 -> 8 -> 16 -> 32 while the channel count falls from
    features_g*16 to features_g*2. A final stride-2 ConvTranspose2d reaches
    64x64 with ``img_channels`` channels, and Tanh bounds pixels to [-1, 1].

    Args:
        channels_noise: number of latent channels in the input.
        img_channels: number of channels in the generated images.
        features_g: base width; hidden layers use 16x, 8x, 4x and 2x of it.
    """

    def __init__(self, channels_noise, img_channels, features_g):
        super().__init__()

        widths = [features_g * mult for mult in (16, 8, 4, 2)]

        # First block projects the 1x1 latent to 4x4 (stride 1, no padding);
        # the following blocks each double the spatial size.
        layers = [self._block(channels_noise, widths[0], 4, 1, 0)]
        layers += [
            self._block(c_in, c_out, 4, 2, 1)
            for c_in, c_out in zip(widths, widths[1:])
        ]
        layers += [
            nn.ConvTranspose2d(widths[-1], img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]

        # Attribute name and layer order fix the state_dict keys (gen.0.0.*, ..., gen.4.*).
        self.gen = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a bias-free ConvTranspose2d followed by ReLU."""
        # NOTE(review): bias=False without a following norm layer means these
        # layers have no bias — confirm this is intentional.
        upconv = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        return nn.Sequential(upconv, nn.ReLU())

    def forward(self, x):
        """Map latent noise ``x`` to images in [-1, 1]."""
        return self.gen(x)

**Initialising weights**

In [4]:
def initialize_weights(model):
    """Apply DCGAN-style initialisation, in place, to every submodule of ``model``.

    - Conv2d / ConvTranspose2d weights ~ N(0, 0.02); their biases keep
      PyTorch's default initialisation.
    - BatchNorm2d scale ~ N(1, 0.02) and shift = 0. Centring the scale at 0
      (a plain N(0, 0.02)) would nearly zero every normalised activation.
      BatchNorm layers created with ``affine=False`` have no parameters and
      are skipped.

    Args:
        model: any ``nn.Module``; its parameters are overwritten in place.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
            

**Loading dataset and setting up transforms**

In [None]:
# Preprocessing: resize, centre-crop to image_size x image_size, convert to a
# tensor in [0, 1], then normalise each channel with mean 0.5 / std 0.5 so
# pixels land in [-1, 1] — the same range as the generator's Tanh output.
# Named `transform` (not `transforms`) so it does not shadow the
# torchvision.transforms module; shadowing it made this cell fail on re-run.
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(img_channels)],
            [0.5 for _ in range(img_channels)],
        ),
    ]
)

# Load the full CelebA dataset. ImageFolder expects images inside class
# sub-directories; the class labels it produces are ignored during training.
celeba_root = "/kaggle/input/celeba-dataset/img_align_celeba"
celeba_dataset = datasets.ImageFolder(root=celeba_root, transform=transform)

# Multi-worker loader; persistent_workers keeps the worker processes alive
# between iterations over the loader (i.e. across epochs).
dataLoader = DataLoader(
    dataset=celeba_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
    prefetch_factor=2,
    persistent_workers=True,
)

print(f"Dataset size: {len(celeba_dataset)}")
sample_batch, _ = next(iter(dataLoader))
print(f"Sample batch shape: {sample_batch.shape}")

# Verify the loaded images match the configured geometry, instead of just
# echoing the hyperparameter back.
print(f"Dataset image channels: {sample_batch.shape[1]}")
expected_shape = (img_channels, image_size, image_size)
assert tuple(sample_batch.shape[1:]) == expected_shape, (
    f"Expected images of shape {expected_shape}, got {tuple(sample_batch.shape[1:])}"
)

# Pull a few batches so the persistent workers are already running before
# training starts; this only reduces start-up latency.
print("Warming up data loading...")
for i, (batch, _) in enumerate(dataLoader):
    if i >= 5:
        break
print("Data loading optimized!")

Success, tests passed!


**Initialising the networks, optimizers and fixed noise**

In [None]:
"""Initialising the networks"""

# Build both networks on `device`, then apply the N(0, 0.02) initialisation
# from initialize_weights. The function was previously defined but never
# called, so the networks silently kept PyTorch's default init.
gen = Generator(noise_dim, img_channels=img_channels, features_g=features_gen).to(device=device)
disc = Discriminator(img_channels=img_channels, features_d=features_disc).to(device=device)
initialize_weights(gen)
initialize_weights(disc)

# Split batches across GPUs when more than one is visible.
if torch.cuda.device_count() > 1:
    gen = nn.DataParallel(gen)
    disc = nn.DataParallel(disc)

# Smoke-test the output shapes; no_grad avoids building autograd graphs here.
with torch.no_grad():
    test_noise = torch.randn(1, noise_dim, 1, 1, device=device)
    test_real = torch.randn(1, img_channels, image_size, image_size, device=device)

    print(f"Generator output shape: {gen(test_noise).shape}")
    print(f"Discriminator real output shape: {disc(test_real).shape}")
    print(f"Discriminator fake output shape: {disc(gen(test_noise)).shape}")

# Optimizers: RMSprop for both networks at the same learning rate.
opt_disc = optim.RMSprop(disc.parameters(), lr=learning_rate)
opt_gen = optim.RMSprop(gen.parameters(), lr=learning_rate)

# Fixed latent batch, reused every epoch so generated samples are comparable
# across epochs.
fixed_noise = torch.randn(32, noise_dim, 1, 1, device=device)

In [None]:
# Step counter and training mode for both networks before the loop below.
# NOTE(review): `step` is not referenced by the training-loop cell.
step = 0
disc.train(True)
gen.train(True)

**Training loop**

In [None]:
# WGAN training with weight clipping. For every real batch the critic is
# updated `critic_iterations` times (each update followed by clamping all
# critic parameters to [-weight_clip, weight_clip]), then the generator is
# updated once. After each epoch a real-vs-fake image grid is written to
# outputs/, and generator / critic state_dicts are saved whenever their mean
# epoch loss reaches a new minimum.
from typing import cast

if critic_iterations < 1:
    # The generator step reuses `fake` from the last critic iteration.
    raise ValueError("critic_iterations must be at least 1")

num_batches = len(dataLoader)
if num_batches == 0:
    raise ValueError("dataLoader produced no batches; check the dataset path")

os.makedirs("outputs", exist_ok=True)

best_gen_loss = float('inf')
best_disc_loss = float('inf')

total_steps = num_epochs * num_batches

# The context manager closes the progress bar even if training raises.
with tqdm(total=total_steps, desc='Training', dynamic_ncols=True) as pbar:
    for epoch in range(num_epochs):
        epoch_loss_critic = 0.0
        epoch_loss_gen = 0.0

        for batch_idx, (real, _) in enumerate(dataLoader):
            real = real.to(device)
            bsz = real.size(0)

            # Critic: minimise -(E[D(real)] - E[D(fake)]).
            for _ in range(critic_iterations):
                noise = torch.randn(bsz, noise_dim, 1, 1, device=device)
                fake = gen(noise)

                critic_real = disc(real).reshape(-1)
                # detach() stops this loss from backpropagating into the generator.
                critic_fake = disc(fake.detach()).reshape(-1)
                loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
                disc.zero_grad()
                # No retain_graph: nothing backpropagates through this graph
                # again (the generator step builds its own graph via disc(fake)),
                # so retaining it would only hold extra memory.
                loss_critic.backward()
                opt_disc.step()

                # Weight clipping after every critic update.
                for p in disc.parameters():
                    p.data.clamp_(-weight_clip, weight_clip)

            # Only the last critic iteration's loss is accumulated and displayed.
            epoch_loss_critic += loss_critic.item()

            # Generator: minimise -E[D(G(z))] on the last critic iteration's fakes.
            # This backward also accumulates grads in the critic's parameters;
            # disc.zero_grad() clears them before the next critic update.
            output = disc(fake).reshape(-1)
            loss_gen = -torch.mean(output)
            gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()

            epoch_loss_gen += loss_gen.item()

            pbar.set_postfix({
                'epoch': epoch + 1,
                'Loss D': f'{loss_critic.item():.4f}',
                'Loss G': f'{loss_gen.item():.4f}'
            })
            pbar.update(1)

        avg_loss_critic = epoch_loss_critic / num_batches
        avg_loss_gen = epoch_loss_gen / num_batches

        # Save up to 16 images from the epoch's last real batch next to 16
        # samples generated from the fixed noise.
        with torch.no_grad():
            fake_img = gen(fixed_noise)
            real_grid = vutils.make_grid(real[:16], nrow=4, normalize=True)
            fake_grid = vutils.make_grid(fake_img[:16], nrow=4, normalize=True)
            fig, axes = plt.subplots(1, 2, figsize=(8, 4))
            axes[0].imshow(real_grid.permute(1, 2, 0).cpu())
            axes[0].axis('off')
            axes[0].set_title('Real')
            axes[1].imshow(fake_grid.permute(1, 2, 0).cpu())
            axes[1].axis('off')
            axes[1].set_title('Fake')
            fig.savefig(f'outputs/epoch_{epoch+1}_real_vs_fake.png', bbox_inches='tight', pad_inches=0)
            plt.close(fig)

        # Safety: ensure models were not accidentally overwritten by a tensor
        if not isinstance(gen, nn.Module) or not isinstance(disc, nn.Module):
            raise RuntimeError(f"Generator or Discriminator overwritten. Types: gen={type(gen)}, disc={type(disc)}")

        # Unwrap nn.DataParallel (if used) so saved keys carry no "module." prefix.
        gen_to_save = cast(nn.Module, gen.module if hasattr(gen, 'module') else gen)
        disc_to_save = cast(nn.Module, disc.module if hasattr(disc, 'module') else disc)
        assert isinstance(gen_to_save, nn.Module) and isinstance(disc_to_save, nn.Module), \
            f"Unexpected types: gen_to_save={type(gen_to_save)}, disc_to_save={type(disc_to_save)}"

        # NOTE(review): WGAN losses are not a direct sample-quality metric, so
        # the lowest-loss epoch is not necessarily the best-looking generator.
        if avg_loss_gen < best_gen_loss:
            best_gen_loss = avg_loss_gen
            torch.save(gen_to_save.state_dict(), 'outputs/best_generator.pth')

        if avg_loss_critic < best_disc_loss:
            best_disc_loss = avg_loss_critic
            torch.save(disc_to_save.state_dict(), 'outputs/best_discriminator.pth')