**Imports**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import torchvision.utils as vutils


**Hyperparameters**

In [None]:
# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- Optimisation ---
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
NUM_EPOCHS = 100
CRITIC_ITERATIONS = 5   # critic updates performed per generator update
LAMBDA_GP = 10          # weight of the gradient-penalty term in the critic loss

# --- Data ---
IMAGE_SIZE = 64         # images are resized / centre-cropped to IMAGE_SIZE x IMAGE_SIZE
CHANNELS_IMG = 3

# --- Architecture ---
Z_DIM = 100             # channels of the (N, Z_DIM, 1, 1) latent noise fed to the generator
FEATURES_CRITIC = 16    # base channel width of the critic
FEATURES_GEN = 16       # base channel width of the generator

**Gradient Penalty**

In [None]:
def gradient_penalty(critic, real, fake, device=None):
    """WGAN-GP gradient penalty.

    Draws one uniform mixing coefficient per sample, evaluates the critic on the
    points ``alpha * real + (1 - alpha) * fake`` and penalises the deviation of
    the per-sample L2 norm of d(critic)/d(input) from 1.

    Args:
        critic: callable mapping a batch shaped like ``real`` to scores.
        real: batch of real samples; dim 0 is the batch dimension (any rank).
        fake: generated batch with the same shape as ``real``. It is detached,
            so the penalty never back-propagates into the generator.
        device: device for the mixing coefficients. ``None`` (default) uses
            ``real.device``, so GPU tensors work even when the caller omits it.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)``. The gradient is taken
        with ``create_graph=True`` so the penalty itself is differentiable
        with respect to the critic's parameters.
    """
    if device is None:
        device = real.device
    batch_size = real.shape[0]
    # One coefficient per sample, broadcast over all remaining dimensions.
    alpha_shape = (batch_size,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)
    interpolated = real * alpha + fake.detach() * (1 - alpha)
    interpolated.requires_grad_(True)

    mixed_scores = critic(interpolated)
    # create_graph=True also implies retain_graph=True.
    gradient = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
    )[0]
    # reshape (not view): the gradient tensor is not guaranteed to be contiguous.
    gradient = gradient.reshape(batch_size, -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

**Discriminator (Critic)**

In [None]:
class Discriminator(nn.Module):
    """WGAN-GP critic: maps an image batch to one unbounded score per image.

    Architecture: a stride-2 conv stem with LeakyReLU(0.2), three stride-2
    conv -> InstanceNorm2d(affine) -> LeakyReLU(0.2) blocks that double the
    channel width each time, and a final 4x4 stride-2 conv to one channel with
    no output activation. For a 64x64 input the spatial size goes
    64 -> 32 -> 16 -> 8 -> 4 -> 1, giving an (N, 1, 1, 1) output.

    Args:
        channels_img: number of input image channels.
        features_d: base channel width; the blocks use 2x, 4x and 8x this.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        # Channel width after the stem and after each block: d, 2d, 4d, 8d.
        widths = [features_d * 2 ** i for i in range(4)]
        layers = [
            nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
            # The stem has no normalisation layer, only the activation.
            nn.LeakyReLU(0.2),
        ]
        for c_in, c_out in zip(widths, widths[1:]):
            layers.append(self._block(c_in, c_out, 4, 2, 1))
        # Collapse the feature map to a single raw score (no sigmoid).
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
        # The attribute name and layer order determine the state_dict keys that
        # saved checkpoints rely on.
        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv (no bias) -> InstanceNorm2d(affine=True) -> LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        act = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Return the critic scores for the image batch ``x``."""
        scores = self.disc(x)
        return scores

**Generator**

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator: latent noise (N, channels_noise, 1, 1) -> images.

    The first transposed-conv block maps the 1x1 latent to a 4x4 map with
    16x ``features_g`` channels. Three stride-2 blocks then halve the width
    while doubling the spatial size, and a final stride-2 transposed conv plus
    Tanh produce ``channels_img`` channels in [-1, 1]. The spatial size goes
    1 -> 4 -> 8 -> 16 -> 32 -> 64.

    Args:
        channels_noise: channels of the latent input.
        channels_img: channels of the generated image.
        features_g: base channel width (blocks use 16x, 8x, 4x, 2x of it).
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        widths = [features_g * k for k in (16, 8, 4, 2)]
        # 1x1 latent -> 4x4 feature map.
        layers = [self._block(channels_noise, widths[0], 4, 1, 0)]
        # Each stride-2 block doubles the spatial size and halves the width.
        layers.extend(
            self._block(c_in, c_out, 4, 2, 1) for c_in, c_out in zip(widths, widths[1:])
        )
        layers.extend([
            nn.ConvTranspose2d(widths[-1], channels_img, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ])
        # The attribute name and layer order determine the checkpoint state_dict keys.
        self.net = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Transposed conv (no bias) -> BatchNorm2d -> ReLU."""
        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Decode the latent batch ``x`` into images in [-1, 1]."""
        images = self.net(x)
        return images

**Weight Initialization**

In [None]:
def initialize_weights(model, std=0.02):
    """Apply DCGAN-style initialisation to every conv / norm layer in ``model``.

    - Conv2d / ConvTranspose2d weights ~ N(0, std). Conv biases, where present,
      are left at PyTorch's defaults.
    - Affine BatchNorm2d / InstanceNorm2d layers get scale ~ N(1, std) and a
      zero shift, so they start close to the identity transform. Drawing the
      scale around 0 would shrink every normalised activation towards zero.
      Covering InstanceNorm2d as well keeps the critic consistent with the
      generator.

    Args:
        model: any nn.Module; every sub-module from ``model.modules()`` is visited.
        std: standard deviation of the normal initialisers (default 0.02).
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, std)
        elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)) and m.weight is not None:
            # weight is None for non-affine norm layers (e.g. InstanceNorm2d default).
            nn.init.normal_(m.weight, 1.0, std)
            nn.init.zeros_(m.bias)

**Dataset Transforms**

In [None]:
# Preprocessing for real images:
#   1. Resize so the shorter side is IMAGE_SIZE, then centre-crop to a square.
#   2. ToTensor scales pixels to [0, 1].
#   3. Normalize with mean 0.5 / std 0.5 per channel maps [0, 1] to [-1, 1],
#      the same range as the generator's final Tanh.
transforms_pipeline = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * CHANNELS_IMG, std=[0.5] * CHANNELS_IMG),
    ]
)

Success, tests passed!


**Setup Networks and Checkpoints**

In [None]:
# ImageFolder expects images inside class sub-directories; labels are ignored
# during training. NOTE(review): presumably the Kaggle CelebA copy nests the
# JPEGs one level below this root — confirm if the loader finds no images.
loader = DataLoader(
    dataset=datasets.ImageFolder(
        root="/kaggle/input/celeba-dataset/img_align_celeba",
        transform=transforms_pipeline,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Pull one batch to confirm the image tensor shape before building the models.
print("Verifying dataset dimensions...")
sample_batch = next(iter(loader))
sample_images = sample_batch[0]
print(f"Real batch shape: {sample_images.shape}")
del sample_batch, sample_images

gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

# Shape smoke test for the generator. Running it in eval mode under no_grad
# means it builds no autograd graph and does not update the generator's
# BatchNorm running statistics, as a train-mode forward pass would.
# Train mode is restored immediately afterwards.
print("Verifying generator output...")
gen.eval()
with torch.no_grad():
    test_noise = torch.randn(1, Z_DIM, 1, 1).to(device)
    test_fake = gen(test_noise)
gen.train()
print(f"Fake batch shape: {test_fake.shape}")
del test_noise, test_fake

# Split each batch across GPUs when more than one is visible.
if torch.cuda.device_count() > 1:
    gen, critic = nn.DataParallel(gen), nn.DataParallel(critic)

# Adam with beta1 = 0 and beta2 = 0.9 for both networks, as in the WGAN-GP paper.
opt_gen = optim.Adam(params=gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(params=critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

# 32 latent vectors reused at the end of every epoch to render sample grids.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
step = 0

gen.train()
critic.train()

os.makedirs("checkpoints", exist_ok=True)
os.makedirs("outputs", exist_ok=True)

# Book-keeping defaults; replaced by saved values when a checkpoint is resumed.
best_gen_loss = best_critic_loss = float('inf')
start_epoch = 0
checkpoint_path = "checkpoints/latest_checkpoint.pth"
best_gen_path = "outputs/best_generator.pth"
best_critic_path = "outputs/best_critic.pth"

def get_model_state_dict(model):
    """Return the state_dict of ``model``, unwrapping ``nn.DataParallel`` first.

    Saving the inner module keeps checkpoint keys free of the ``module.``
    prefix, so they load into both wrapped and unwrapped models.
    """
    if isinstance(model, nn.DataParallel):
        model = model.module
    return model.state_dict()

def load_model_state_dict(model, state_dict):
    """Load ``state_dict`` into ``model``, targeting the inner module of a DataParallel wrapper.

    Counterpart of ``get_model_state_dict``: it expects keys without the
    ``module.`` prefix. Returns None.
    """
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(state_dict)

# Resume from the latest checkpoint if one exists: model and optimizer state,
# the next epoch to run, best-loss trackers and the step counter.
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    load_model_state_dict(gen, checkpoint['gen_state'])
    load_model_state_dict(critic, checkpoint['critic_state'])
    opt_gen.load_state_dict(checkpoint['opt_gen_state'])
    opt_critic.load_state_dict(checkpoint['opt_critic_state'])
    start_epoch = checkpoint['epoch']
    best_gen_loss = checkpoint['best_gen_loss']
    best_critic_loss = checkpoint['best_critic_loss']
    step = checkpoint['step']
    # Reuse the saved sampling noise so per-epoch sample grids stay comparable
    # across resumed sessions. Checkpoints without this key keep the freshly
    # drawn fixed_noise.
    if 'fixed_noise' in checkpoint:
        fixed_noise = checkpoint['fixed_noise'].to(device)
    print(f"Resumed from epoch {start_epoch}")
else:
    print("Starting fresh training")

**Training Configuration**

In [None]:
import time as time_module

# Wall-clock bookkeeping read by the training loop's ETA report.
training_start_time = time_module.time()
epoch_times = []

# One print with a newline separator produces the same three lines as three
# separate prints; the static time estimate is hard-coded text.
print(
    f"Device: {device} | GPUs: {torch.cuda.device_count()}",
    f"Config: BS={BATCH_SIZE}, IMG={IMAGE_SIZE}x{IMAGE_SIZE}, Epochs={NUM_EPOCHS}",
    "Est. time: ~220h (~9.2 days) | Opt.: ~100-150h (~4.2-6.3 days)\n",
    sep="\n",
)

**Training Loop**

In [None]:
# WGAN-GP training: CRITIC_ITERATIONS critic updates per generator update.
# After every epoch: save a sample grid, the best-so-far weights and a resumable checkpoint.
for epoch in range(start_epoch, NUM_EPOCHS):
    epoch_start_time = time_module.time()
    epoch_gen_losses = []
    epoch_critic_losses = []

    for batch_idx, (real, _) in enumerate(tqdm(loader)):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # --- Critic updates (fresh fakes each iteration, same real batch) ---
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            # The critic loss never back-propagates into the generator, so skip
            # recording an autograd graph for the fakes. This saves memory and
            # time; BatchNorm running stats still update as before.
            with torch.no_grad():
                fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake, device=device)
            # -(E[D(real)] - E[D(fake)]) + lambda * gradient penalty
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        # Only the last critic iteration's loss is recorded for this batch.
        epoch_critic_losses.append(loss_critic.item())

        # --- Generator update: maximise the critic score of fresh fakes ---
        # This backward also writes critic .grad buffers. They are cleared by
        # critic.zero_grad() before the next critic backward, and only opt_gen
        # steps here.
        noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
        fake = gen(noise)
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        epoch_gen_losses.append(loss_gen.item())

        if batch_idx % 100 == 0 and batch_idx > 0:
            print(f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} | D: {loss_critic:.4f}, G: {loss_gen:.4f}")
            step += 1

    avg_gen_loss = sum(epoch_gen_losses) / len(epoch_gen_losses)
    avg_critic_loss = sum(epoch_critic_losses) / len(epoch_critic_losses)

    # Sample grid from the fixed latent vectors.
    # NOTE(review): gen is still in train mode here, so BatchNorm uses (and
    # updates its running stats from) this batch.
    with torch.no_grad():
        fake = gen(fixed_noise)
        img_grid = torchvision.utils.make_grid(fake[:32], normalize=True)
        vutils.save_image(img_grid, f"outputs/epoch_{epoch:03d}.png")

    gen_state = get_model_state_dict(gen)
    critic_state = get_model_state_dict(critic)

    # NOTE(review): "best" means lowest epoch-average loss. WGAN losses are
    # measured against a changing critic, so they may not be comparable
    # across epochs; treat these snapshots with care.
    if avg_gen_loss < best_gen_loss:
        best_gen_loss = avg_gen_loss
        torch.save(gen_state, best_gen_path)
        print(f"✓ Best generator (loss: {best_gen_loss:.4f})")

    if avg_critic_loss < best_critic_loss:
        best_critic_loss = avg_critic_loss
        torch.save(critic_state, best_critic_path)
        print(f"✓ Best critic (loss: {best_critic_loss:.4f})")

    checkpoint = {
        'epoch': epoch + 1,
        'gen_state': gen_state,
        'critic_state': critic_state,
        'opt_gen_state': opt_gen.state_dict(),
        'opt_critic_state': opt_critic.state_dict(),
        'best_gen_loss': best_gen_loss,
        'best_critic_loss': best_critic_loss,
        'step': step,
        # Saved so sample grids stay comparable after a resume.
        'fixed_noise': fixed_noise,
    }
    # Write to a temporary file, then atomically swap it in (os.replace is
    # atomic on POSIX). An interrupted save therefore cannot leave a truncated
    # file as the only resume point.
    tmp_checkpoint_path = checkpoint_path + ".tmp"
    torch.save(checkpoint, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, checkpoint_path)

    # ETA from the mean duration of the epochs run in this session.
    epoch_time = time_module.time() - epoch_start_time
    epoch_times.append(epoch_time)
    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    remaining_epochs = NUM_EPOCHS - (epoch + 1)
    eta_seconds = remaining_epochs * avg_epoch_time
    eta_hours = eta_seconds / 3600
    eta_days = eta_hours / 24
    elapsed_time = time_module.time() - training_start_time
    elapsed_hours = elapsed_time / 3600

    print(f"Epoch [{epoch}/{NUM_EPOCHS}] | G: {avg_gen_loss:.4f} | D: {avg_critic_loss:.4f} | Time: {epoch_time/60:.1f}m | Elapsed: {elapsed_hours:.1f}h | ETA: {eta_hours:.1f}h ({eta_days:.2f}d)\n")

print(f"\n✓ Training complete! Total: {(time_module.time() - training_start_time)/3600:.1f}h")
print(f"Best models: {best_gen_path} & {best_critic_path}")

**Visualize Training Progress**

In [None]:
import glob
from PIL import Image

# Show the (up to) four most recent per-epoch sample grids side by side.
# The zero-padded epoch numbers in the file names make the lexical sort chronological.
sample_files = sorted(glob.glob("outputs/epoch_*.png"))
if sample_files:
    recent_files = sample_files[-4:]
    # squeeze=False always yields a 2-D axes array, even for a single panel.
    fig, axes = plt.subplots(1, len(recent_files), figsize=(15, 4), squeeze=False)
    for ax, sample_file in zip(axes[0], recent_files):
        # convert() returns an independent, fully loaded copy, so the file
        # handle can be closed before the image is drawn.
        with Image.open(sample_file) as pil_img:
            frame = pil_img.convert("RGB")
        ax.imshow(frame)
        epoch_num = sample_file.split('_')[-1].replace('.png', '')
        ax.set_title(f"Epoch {epoch_num}")
        ax.axis('off')
    plt.tight_layout()
    plt.show()