In [2]:
import os
import torch
import torchvision
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from glob import glob
import torchvision.transforms as T
import torch.nn as nn
from tqdm import tqdm
from torch.autograd import grad
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torchvision.datasets import ImageFolder
from PIL import UnidentifiedImageError

In [3]:
!nvidia-smi

Tue Jun 17 13:46:26 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.94                 Driver Version: 560.94         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4070 ...  WDDM  |   00000000:01:00.0  On |                  N/A |
| N/A   47C    P8              3W /  140W |    1365MiB /   8188MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
import torch
print(torch.cuda.is_available())  # Should print True
print(torch.cuda.get_device_name(0))

True
NVIDIA GeForce RTX 4070 Laptop GPU


In [5]:
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 128
z_dim = 128
batch_size = 64
lambda_gp = 10
num_epochs = 100
lr = 1e-6

In [6]:
# Transforms
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

In [7]:
# Dataset
dataset_path = "C:/College/Projects/X-RayComparison/Data/train"  # Update this to the correct path
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)


In [8]:
class Generator(nn.Module):
    def __init__(self, z_dim=128, img_channels=1, features_g=64):
        super().__init__()
        self.gen = nn.Sequential(
            nn.ConvTranspose2d(z_dim, features_g * 16, 4, 1, 0),  # 4x4
            nn.BatchNorm2d(features_g * 16),
            nn.ReLU(),

            nn.ConvTranspose2d(features_g * 16, features_g * 8, 4, 2, 1),  # 8x8
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(),

            nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1),  # 16x16
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(),

            nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1),  # 32x32
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(),

            nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1),  # 64x64
            nn.BatchNorm2d(features_g),
            nn.ReLU(),

            nn.ConvTranspose2d(features_g, img_channels, 4, 2, 1),  # 128x128
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)


In [9]:
# Gradient Penalty
def gradient_penalty(critic, real, fake):
    batch_size, C, H, W = real.shape
    epsilon = torch.rand(batch_size, 1, 1, 1, device=device).expand_as(real)
    interpolated = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)

    mixed_scores = critic(interpolated)
    gradient = grad(
        outputs=mixed_scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(batch_size, -1)
    gp = ((gradient.norm(2, dim=1) - 1) ** 2).mean()
    return gp

In [12]:
# Initialize
gen = Generator(z_dim=z_dim).to(device)
critic = Discriminator().to(device)

opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=lr, betas=(0.0, 0.9))

NameError: name 'Critic' is not defined

In [11]:
class Discriminator(nn.Module):
    def __init__(self, img_channels=1, feature_d=64):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            # Input: N x 1 x 128 x 128
            nn.Conv2d(img_channels, feature_d, 4, 2, 1),  # 64x64
            nn.LeakyReLU(0.2),

            nn.Conv2d(feature_d, feature_d * 2, 4, 2, 1),  # 32x32
            nn.BatchNorm2d(feature_d * 2), # Corrected: should be feature_d * 2
            nn.LeakyReLU(0.2),

            nn.Conv2d(feature_d * 2, feature_d * 4, 4, 2, 1),  # 16x16
            nn.BatchNorm2d(feature_d * 4),
            nn.LeakyReLU(0.2),

            nn.Conv2d(feature_d * 4, feature_d * 8, 4, 2, 1),  # 8x8
            nn.BatchNorm2d(feature_d * 8),
            nn.LeakyReLU(0.2),

            nn.Conv2d(feature_d * 8, feature_d * 16, 4, 2, 1),  # 4x4
            nn.BatchNorm2d(feature_d * 16),
            nn.LeakyReLU(0.2),

            nn.Conv2d(feature_d * 16, 1, 4, 1, 0),  # 1x1
        )

    def forward(self, x):
        return self.model(x).view(-1)

In [13]:
class SafeImageFolder(ImageFolder):
    def __getitem__(self, index):
        path, target = self.samples[index]
        try:
            sample = self.loader(path)
        except UnidentifiedImageError:
            # Replace with a black image or skip with random image
            sample = Image.new("L", (128, 128))  # Grayscale fallback
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

In [14]:
def gradient_penalty(critic, real, fake, device):
    batch_size, c, h, w = real.shape
    alpha = torch.rand(batch_size, 1, 1, 1, device=device).expand_as(real)

    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

    critic_interpolated = critic(interpolated)
    grad_outputs = torch.ones_like(critic_interpolated, device=device)

    gradients = torch.autograd.grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    gradients = gradients.view(batch_size, -1)
    grad_norm = gradients.norm(2, dim=1)
    gp = ((grad_norm - 1) ** 2).mean()
    return gp


In [63]:
import os
os.makedirs("generated_samples", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)

In [64]:

critic_iterations = 5  # Number of Critic updates per Generator update
step = 0
fixed_noise = torch.randn(64, z_dim, 1, 1).to(device)

for epoch in range(num_epochs):
    loop = tqdm(dataloader, leave=True)
    
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(device)
        cur_batch_size = real.size(0)

        # === Train Critic ===
        for _ in range(critic_iterations):
            noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
            fake = gen(noise)

            critic_real = critic(real).view(-1)
            critic_fake = critic(fake.detach()).view(-1)

            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = -torch.mean(critic_real) + torch.mean(critic_fake) + lambda_gp * gp

            opt_critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        # === Train Generator ===
        noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
        fake = gen(noise)
        output = critic(fake).view(-1)
        loss_gen = -torch.mean(output)

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # === Logging ===
        loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
        loop.set_postfix(critic_loss=loss_critic.item(), gen_loss=loss_gen.item())

        # === Save Sample Images ===
        if step % 500 == 0:
            with torch.no_grad():
                fake_samples = gen(fixed_noise)
                utils.save_image(fake_samples, f"generated_samples/sample_{step}.png", normalize=True, nrow=8)
        step += 1

    # === Save Model Checkpoints ===
    if (epoch + 1) % 10 == 0:
        torch.save(gen.state_dict(), f"checkpoints/gen_epoch_{epoch+1}.pth")
        torch.save(critic.state_dict(), f"checkpoints/critic_epoch_{epoch+1}.pth")

  0%|          | 0/576 [00:10<?, ?it/s]


RuntimeError: Given normalized_shape=[128, 64, 64], expected input with shape [*, 128, 64, 64], but got input of size[64, 128, 32, 32]

## The cells below are working finally 😭

In [18]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from torchvision.datasets import ImageFolder
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
from torch.autograd import grad

# ========== Setup ==========
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 128
z_dim = 128
batch_size = 64
lambda_gp = 5
num_epochs = 100
lr = 1e-4  # Increased for faster and stable convergence
critic_iterations = 2

# ========== Transforms ==========
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# ========== Safe Dataset Loader ==========
class SafeImageFolder(ImageFolder):
    def __getitem__(self, index):
        path, target = self.samples[index]
        try:
            sample = self.loader(path)
        except UnidentifiedImageError:
            sample = Image.new("L", (128, 128))
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

dataset_path = "C:/College/Projects/X-RayComparison/Data/train"
dataset = SafeImageFolder(root=dataset_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)

# ========== Generator ==========
class Generator(nn.Module):
    def __init__(self, z_dim=128, img_channels=1, features_g=64):
        super().__init__()
        self.net = nn.Sequential(
            self._block(z_dim, features_g * 16, 4, 1, 0),  # 4x4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),  # 8x8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),  # 16x16
            self._block(features_g * 4, features_g * 2, 4, 2, 1),  # 32x32
            self._block(features_g * 2, features_g, 4, 2, 1),  # 64x64
            nn.ConvTranspose2d(features_g, img_channels, 4, 2, 1),  # 128x128
            nn.Tanh()
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.net(x)

# ========== Critic ==========
class Critic(nn.Module):
    def __init__(self, img_channels=1, features_d=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(img_channels, features_d, 4, 2, 1),  # 64x64
            nn.LeakyReLU(0.2, inplace=True),

            self._block(features_d, features_d * 2, 4, 2, 1),  # 32x32
            self._block(features_d * 2, features_d * 4, 4, 2, 1),  # 16x16
            self._block(features_d * 4, features_d * 8, 4, 2, 1),  # 8x8
            self._block(features_d * 8, features_d * 16, 4, 2, 1),  # 4x4

            nn.Conv2d(features_d * 16, 1, 4, 1, 0),  # 1x1
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        return self.net(x).view(-1)

# ========== Gradient Penalty ==========
def gradient_penalty(critic, real, fake, device):
    batch_size, c, h, w = real.shape
    alpha = torch.rand(batch_size, 1, 1, 1, device=device).expand_as(real)
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    mixed_scores = critic(interpolated)
    gradients = grad(outputs=mixed_scores, inputs=interpolated,
                     grad_outputs=torch.ones_like(mixed_scores),
                     create_graph=True, retain_graph=True)[0]
    gradients = gradients.view(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# ========== Initialize ==========
gen = Generator(z_dim=z_dim).to(device)
critic = Critic().to(device)
opt_gen = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.0, 0.9))
opt_critic = torch.optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))


os.makedirs("generated_samples", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)

fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)

# ========== Training Loop ==========
for epoch in range(num_epochs):
    loop = tqdm(dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=True)
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(device)
        cur_batch_size = real.size(0)

        # === Train Critic ===
        for _ in range(critic_iterations):
            noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
            fake = gen(noise)
            critic_real = critic(real)
            critic_fake = critic(fake.detach())
            gp = gradient_penalty(critic, real, fake, device)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp

            opt_critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        # === Train Generator ===
        noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
        fake = gen(noise)
        output = critic(fake)
        loss_gen = -torch.mean(output)

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        loop.set_postfix(critic_loss=loss_critic.item(), gen_loss=loss_gen.item())

    with torch.no_grad():
        fake = gen(fixed_noise)
        utils.save_image(fake, f"generated_samples/epoch_{epoch+1}.png", normalize=True, nrow=8)
    torch.save(gen.state_dict(), f"generator_epoch_{epoch+1}.pth")
    torch.save(critic.state_dict(), f"critic_epoch_{epoch+1}.pth")

Epoch [1/100]: 100%|██████████| 576/576 [05:49<00:00,  1.65it/s, critic_loss=-19.6, gen_loss=68.2]
Epoch [2/100]: 100%|██████████| 576/576 [05:52<00:00,  1.64it/s, critic_loss=-15.7, gen_loss=123] 
Epoch [3/100]: 100%|██████████| 576/576 [05:52<00:00,  1.63it/s, critic_loss=-12.6, gen_loss=121] 
Epoch [4/100]: 100%|██████████| 576/576 [05:51<00:00,  1.64it/s, critic_loss=-8.92, gen_loss=107] 
Epoch [5/100]: 100%|██████████| 576/576 [05:51<00:00,  1.64it/s, critic_loss=-14.7, gen_loss=137] 
Epoch [6/100]: 100%|██████████| 576/576 [05:51<00:00,  1.64it/s, critic_loss=-20.9, gen_loss=100] 
Epoch [7/100]: 100%|██████████| 576/576 [05:51<00:00,  1.64it/s, critic_loss=-11, gen_loss=113]   
Epoch [8/100]: 100%|██████████| 576/576 [05:51<00:00,  1.64it/s, critic_loss=-16.4, gen_loss=110]
Epoch [9/100]: 100%|██████████| 576/576 [05:51<00:00,  1.64it/s, critic_loss=-27.8, gen_loss=136]
Epoch [10/100]: 100%|██████████| 576/576 [05:51<00:00,  1.64it/s, critic_loss=-22.8, gen_loss=138]
Epoch [11/10

KeyboardInterrupt: 

In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torchvision.utils as vutils
import torch.nn.utils as utils
import os
import random
from torch.autograd import grad

# Set random seed for reproducibility
manualSeed = 999
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# Parameters
image_size = 128
batch_size = 64
z_dim = 128
lr_gen = 2e-4
lr_critic = 1e-4
beta1, beta2 = 0.0, 0.9
critic_iter = 2
lambda_gp = 5
num_epochs = 100

# Dataset and Dataloader
dataset = datasets.ImageFolder(
    root="./Data/train",
    transform=transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Generator
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight)

class Generator(nn.Module):
    def __init__(self, z_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 512, 4, 1, 0),      # -> (512, 4, 4)
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            nn.ConvTranspose2d(512, 256, 4, 2, 1),         # -> (256, 8, 8)
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1),         # -> (128, 16, 16)
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1),          # -> (64, 32, 32)
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.ConvTranspose2d(64, 32, 4, 2, 1),           # -> (32, 64, 64)
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            nn.ConvTranspose2d(32, 1, 4, 2, 1),            # -> (1, 128, 128)
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)
    
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            utils.spectral_norm(nn.Conv2d(1, 64, 4, 2, 1)),
            nn.LeakyReLU(0.2),
            utils.spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)),
            nn.LeakyReLU(0.2),
            utils.spectral_norm(nn.Conv2d(128, 256, 4, 2, 1)),
            nn.LeakyReLU(0.2),
            utils.spectral_norm(nn.Conv2d(256, 512, 4, 2, 1)),
            nn.LeakyReLU(0.2),
            utils.spectral_norm(nn.Conv2d(512, 1, 4, 1, 0))
        )

    def forward(self, x):
        return self.model(x)

# Gradient Penalty
def gradient_penalty(critic, real, fake, device):
    batch_size = real.size(0)
    epsilon = torch.rand((batch_size, 1, 1, 1), device=device)
    interpolated = epsilon * real + (1 - epsilon) * fake
    interpolated.requires_grad_(True)

    mixed_scores = critic(interpolated)
    grad_outputs = torch.ones_like(mixed_scores)

    gradients = grad(
        outputs=mixed_scores,
        inputs=interpolated,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    gradients = gradients.reshape(batch_size, -1)  # Fix here
    gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gp



# Training Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen = Generator(z_dim).to(device)
gen.apply(weights_init)
critic = Critic().to(device)

opt_gen = optim.Adam(gen.parameters(), lr=lr_gen, betas=(beta1, beta2))
opt_critic = optim.Adam(critic.parameters(), lr=lr_critic, betas=(beta1, beta2))

fixed_noise = torch.randn(64, z_dim, 1, 1).to(device)



In [14]:
noise = torch.randn(1, z_dim, 1, 1).to(device)
fake = gen(noise)
print(fake.shape)  # should be torch.Size([1, 1, 128, 128])


torch.Size([1, 1, 128, 128])


In [9]:
# Training Loop
for epoch in range(num_epochs):
    for i, (real, _) in enumerate(dataloader):
        real = real.to(device)
        real += 0.01 * torch.randn_like(real)  # Add small noise to real images
        cur_batch_size = real.size(0)

        # Train Critic
        for _ in range(critic_iter):
            noise = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
            fake = gen(noise)

            critic_real = critic(real).view(-1)
            critic_fake = critic(fake.detach()).view(-1)
            gp = gradient_penalty(critic, real, fake, device)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp

            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator
        noise = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
        fake = gen(noise)
        gen_loss = -torch.mean(critic(fake))

        gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}] Batch {i}/{len(dataloader)} \
                Loss D: {loss_critic.item():.4f}, loss G: {gen_loss.item():.4f}")

    with torch.no_grad():
        fake_images = gen(fixed_noise)
        vutils.save_image(fake_images.detach(), f"output/fake_epoch_{epoch:03d}.png", normalize=True)

Epoch [0/100] Batch 0/576                 Loss D: 2.2373, loss G: 0.4114
Epoch [0/100] Batch 100/576                 Loss D: -3.8755, loss G: 2.1199
Epoch [0/100] Batch 200/576                 Loss D: -3.0142, loss G: 2.6960
Epoch [0/100] Batch 300/576                 Loss D: -0.9504, loss G: 1.1503
Epoch [0/100] Batch 400/576                 Loss D: -0.3754, loss G: 0.2039


KeyboardInterrupt: 

In [10]:
# Training Loop with tqdm
for epoch in range(num_epochs):
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1}/{num_epochs}")
    for i, (real, _) in pbar:
        real = real.to(device)
        real += 0.01 * torch.randn_like(real)  # Add small noise to real images
        cur_batch_size = real.size(0)

        # Train Critic
        for _ in range(critic_iter):
            noise = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
            fake = gen(noise)

            critic_real = critic(real).view(-1)
            critic_fake = critic(fake.detach()).view(-1)
            gp = gradient_penalty(critic, real, fake, device)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp

            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator
        noise = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
        fake = gen(noise)
        gen_loss = -torch.mean(critic(fake))

        gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        # tqdm progress
        pbar.set_postfix({
            'Loss_D': f"{loss_critic.item():.4f}",
            'Loss_G': f"{gen_loss.item():.4f}"
        })

    # Save generated samples after each epoch
    with torch.no_grad():
        fake_images = gen(fixed_noise)
        os.makedirs("output", exist_ok=True)
        vutils.save_image(fake_images.detach(), f"output/fake_epoch_{epoch:03d}.png", normalize=True)

Epoch 1/100: 100%|██████████| 576/576 [04:09<00:00,  2.31it/s, Loss_D=-0.1812, Loss_G=-0.5009]
Epoch 2/100: 100%|██████████| 576/576 [03:47<00:00,  2.53it/s, Loss_D=-0.0898, Loss_G=-0.9866]
Epoch 3/100: 100%|██████████| 576/576 [03:47<00:00,  2.53it/s, Loss_D=-0.0297, Loss_G=-0.9117]
Epoch 4/100: 100%|██████████| 576/576 [03:47<00:00,  2.53it/s, Loss_D=-0.1004, Loss_G=-1.3933]
Epoch 5/100: 100%|██████████| 576/576 [03:47<00:00,  2.54it/s, Loss_D=-0.1256, Loss_G=-1.9822]
Epoch 6/100: 100%|██████████| 576/576 [03:47<00:00,  2.53it/s, Loss_D=-0.3410, Loss_G=-1.4743]
Epoch 7/100: 100%|██████████| 576/576 [03:46<00:00,  2.54it/s, Loss_D=-0.1156, Loss_G=-0.8499]
Epoch 8/100: 100%|██████████| 576/576 [03:47<00:00,  2.54it/s, Loss_D=0.1167, Loss_G=-1.0588] 
Epoch 9/100: 100%|██████████| 576/576 [03:47<00:00,  2.54it/s, Loss_D=-0.0921, Loss_G=-0.6659]
Epoch 10/100: 100%|██████████| 576/576 [03:47<00:00,  2.53it/s, Loss_D=-0.0592, Loss_G=-1.2430]
Epoch 11/100: 100%|██████████| 576/576 [03:47<00:

In [16]:
start_epoch = 0
checkpoint_path = "checkpoints/checkpoint_epoch_150.pth"

# Resume from checkpoint if available
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # print(checkpoint['critic_state_dict'])
    gen.load_state_dict(checkpoint['generator_state_dict'])
    critic.load_state_dict(checkpoint['critic_state_dict'])
    opt_gen.load_state_dict(checkpoint['opt_gen_state_dict'])
    opt_critic.load_state_dict(checkpoint['opt_critic_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming training from epoch {start_epoch}")
else:
    print("Starting training from scratch.")



  checkpoint = torch.load(checkpoint_path, map_location=device)


RuntimeError: Error(s) in loading state_dict for Generator:
	Missing key(s) in state_dict: "gen.0.weight", "gen.0.bias", "gen.1.weight", "gen.1.bias", "gen.1.running_mean", "gen.1.running_var", "gen.3.weight", "gen.3.bias", "gen.4.weight", "gen.4.bias", "gen.4.running_mean", "gen.4.running_var", "gen.6.weight", "gen.6.bias", "gen.7.weight", "gen.7.bias", "gen.7.running_mean", "gen.7.running_var", "gen.9.weight", "gen.9.bias", "gen.10.weight", "gen.10.bias", "gen.10.running_mean", "gen.10.running_var", "gen.12.weight", "gen.12.bias", "gen.13.weight", "gen.13.bias", "gen.13.running_mean", "gen.13.running_var", "gen.15.weight", "gen.15.bias". 
	Unexpected key(s) in state_dict: "model.0.weight", "model.0.bias", "model.1.weight", "model.1.bias", "model.1.running_mean", "model.1.running_var", "model.1.num_batches_tracked", "model.3.weight", "model.3.bias", "model.4.weight", "model.4.bias", "model.4.running_mean", "model.4.running_var", "model.4.num_batches_tracked", "model.6.weight", "model.6.bias", "model.7.weight", "model.7.bias", "model.7.running_mean", "model.7.running_var", "model.7.num_batches_tracked", "model.9.weight", "model.9.bias", "model.10.weight", "model.10.bias", "model.10.running_mean", "model.10.running_var", "model.10.num_batches_tracked", "model.12.weight", "model.12.bias", "model.13.weight", "model.13.bias", "model.13.running_mean", "model.13.running_var", "model.13.num_batches_tracked", "model.15.weight", "model.15.bias". 

In [16]:
import torch
import torchvision.utils as vutils
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
import os

# === Setup ===
fixed_noise = torch.randn(64, z_dim, 1, 1).to(device)
total_epochs = start_epoch + 50
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("generated_samples", exist_ok=True)
os.makedirs("output", exist_ok=True)

for epoch in range(start_epoch, total_epochs):
    loop = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch [{epoch+1}/{total_epochs}]")
    gen.train()
    critic.train()

    for batch_idx, (real, _) in loop:
        real = real.to(device)
        real += 0.01 * torch.randn_like(real)  # Label smoothing
        cur_batch_size = real.size(0)

        # === Train Critic ===
        for _ in range(critic_iterations):
            noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
            fake = gen(noise)

            critic_real = critic(real).view(-1)
            critic_fake = critic(fake.detach()).view(-1)

            gp = gradient_penalty(critic, real, fake, device)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp

            opt_critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        # === Train Generator ===
        noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
        fake = gen(noise)
        gen_loss = -torch.mean(critic(fake))

        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        loop.set_postfix(critic_loss=loss_critic.item(), gen_loss=gen_loss.item())

    # === Save generated image every epoch ===
    with torch.no_grad():
        fake_images = gen(fixed_noise)
        vutils.save_image(fake_images, f"output/fake_epoch_{epoch+1:03d}.png", normalize=True, nrow=8)

        # Also save for generated_samples (optional pretty grid format)
        fake_grid = make_grid(fake_images, normalize=True, nrow=8)
        save_image(fake_grid, f"generated_samples/sample_epoch_{epoch+1:03d}.png")

    # === Save checkpoint every 10 epochs ===
    if (epoch + 1) % 10 == 0:
        torch.save(gen.state_dict(), f"checkpoints/generator_epoch_{epoch+1}.pth")
        torch.save(critic.state_dict(), f"checkpoints/critic_epoch_{epoch+1}.pth")

        torch.save({
            'epoch': epoch,
            'generator_state_dict': gen.state_dict(),
            'critic_state_dict': critic.state_dict(),
            'opt_gen_state_dict': opt_gen.state_dict(),
            'opt_critic_state_dict': opt_critic.state_dict(),
        }, f"checkpoints/checkpoint_epoch_{epoch+1}.pth")

        print(f"Checkpoint saved at epoch {epoch+1}")


Epoch [101/150]: 100%|██████████| 576/576 [03:46<00:00,  2.55it/s, critic_loss=-0.0275, gen_loss=0.826]    
Epoch [102/150]: 100%|██████████| 576/576 [03:35<00:00,  2.67it/s, critic_loss=-0.0659, gen_loss=0.741]   
Epoch [103/150]: 100%|██████████| 576/576 [03:39<00:00,  2.63it/s, critic_loss=-0.0964, gen_loss=0.939]  
Epoch [104/150]: 100%|██████████| 576/576 [03:50<00:00,  2.50it/s, critic_loss=-0.0534, gen_loss=0.379]   
Epoch [105/150]: 100%|██████████| 576/576 [03:55<00:00,  2.45it/s, critic_loss=-0.0693, gen_loss=0.349]   
Epoch [106/150]: 100%|██████████| 576/576 [04:49<00:00,  1.99it/s, critic_loss=-0.157, gen_loss=0.0063]  
Epoch [107/150]: 100%|██████████| 576/576 [03:34<00:00,  2.68it/s, critic_loss=-0.0463, gen_loss=0.356]  
Epoch [108/150]: 100%|██████████| 576/576 [03:35<00:00,  2.67it/s, critic_loss=-0.152, gen_loss=0.81]    
Epoch [109/150]: 100%|██████████| 576/576 [03:36<00:00,  2.66it/s, critic_loss=-0.0836, gen_loss=0.902]  
Epoch [110/150]: 100%|██████████| 576/576

Checkpoint saved at epoch 110


Epoch [111/150]: 100%|██████████| 576/576 [03:35<00:00,  2.67it/s, critic_loss=-0.0456, gen_loss=0.773]  
Epoch [112/150]: 100%|██████████| 576/576 [03:34<00:00,  2.69it/s, critic_loss=-0.244, gen_loss=0.215]   
Epoch [113/150]: 100%|██████████| 576/576 [03:35<00:00,  2.68it/s, critic_loss=-0.0533, gen_loss=0.512]  
Epoch [114/150]: 100%|██████████| 576/576 [03:38<00:00,  2.64it/s, critic_loss=0.0377, gen_loss=0.978]   
Epoch [115/150]: 100%|██████████| 576/576 [03:44<00:00,  2.56it/s, critic_loss=-0.0321, gen_loss=0.843]  
Epoch [116/150]: 100%|██████████| 576/576 [03:38<00:00,  2.63it/s, critic_loss=-0.0742, gen_loss=0.79]   
Epoch [117/150]: 100%|██████████| 576/576 [03:34<00:00,  2.69it/s, critic_loss=-0.0543, gen_loss=1.01]    
Epoch [118/150]: 100%|██████████| 576/576 [03:33<00:00,  2.70it/s, critic_loss=-0.116, gen_loss=1.02]    
Epoch [119/150]: 100%|██████████| 576/576 [03:33<00:00,  2.70it/s, critic_loss=0.0153, gen_loss=0.584]   
Epoch [120/150]: 100%|██████████| 576/576 [03

Checkpoint saved at epoch 120


Epoch [121/150]: 100%|██████████| 576/576 [03:33<00:00,  2.70it/s, critic_loss=-0.0346, gen_loss=0.838]  
Epoch [122/150]: 100%|██████████| 576/576 [03:34<00:00,  2.69it/s, critic_loss=-0.00544, gen_loss=0.567] 
Epoch [123/150]: 100%|██████████| 576/576 [03:34<00:00,  2.69it/s, critic_loss=-0.0951, gen_loss=0.821]  
Epoch [124/150]: 100%|██████████| 576/576 [03:33<00:00,  2.70it/s, critic_loss=-0.0611, gen_loss=0.483]  
Epoch [125/150]: 100%|██████████| 576/576 [03:33<00:00,  2.70it/s, critic_loss=-0.0722, gen_loss=1.03]   
Epoch [126/150]: 100%|██████████| 576/576 [03:34<00:00,  2.68it/s, critic_loss=-0.0205, gen_loss=0.59]   
Epoch [127/150]: 100%|██████████| 576/576 [03:33<00:00,  2.70it/s, critic_loss=-0.0372, gen_loss=0.544]  
Epoch [128/150]: 100%|██████████| 576/576 [03:33<00:00,  2.70it/s, critic_loss=0.00157, gen_loss=1.18]   
Epoch [129/150]: 100%|██████████| 576/576 [03:33<00:00,  2.70it/s, critic_loss=-0.0671, gen_loss=1.06]     
Epoch [130/150]: 100%|██████████| 576/576 [0

Checkpoint saved at epoch 130


Epoch [131/150]: 100%|██████████| 576/576 [03:33<00:00,  2.70it/s, critic_loss=0.0335, gen_loss=0.69]    
Epoch [132/150]: 100%|██████████| 576/576 [03:32<00:00,  2.71it/s, critic_loss=-0.12, gen_loss=0.716]    
Epoch [133/150]: 100%|██████████| 576/576 [03:32<00:00,  2.71it/s, critic_loss=-0.11, gen_loss=1.28]     
Epoch [134/150]: 100%|██████████| 576/576 [03:33<00:00,  2.70it/s, critic_loss=0.0386, gen_loss=0.701]   
Epoch [135/150]: 100%|██████████| 576/576 [03:40<00:00,  2.61it/s, critic_loss=-0.00273, gen_loss=0.28]  
Epoch [136/150]: 100%|██████████| 576/576 [03:41<00:00,  2.60it/s, critic_loss=-0.0863, gen_loss=0.601]   
Epoch [137/150]: 100%|██████████| 576/576 [03:36<00:00,  2.66it/s, critic_loss=-0.17, gen_loss=0.461]    
Epoch [138/150]: 100%|██████████| 576/576 [04:09<00:00,  2.31it/s, critic_loss=-0.0235, gen_loss=0.848]  
Epoch [139/150]: 100%|██████████| 576/576 [05:34<00:00,  1.72it/s, critic_loss=-0.0407, gen_loss=0.614]  
Epoch [140/150]: 100%|██████████| 576/576 [04

Checkpoint saved at epoch 140


Epoch [141/150]: 100%|██████████| 576/576 [04:01<00:00,  2.39it/s, critic_loss=-0.0329, gen_loss=0.452]  
Epoch [142/150]: 100%|██████████| 576/576 [03:44<00:00,  2.57it/s, critic_loss=0.0326, gen_loss=0.803]   
Epoch [143/150]: 100%|██████████| 576/576 [03:30<00:00,  2.74it/s, critic_loss=-0.181, gen_loss=0.909]    
Epoch [144/150]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.0974, gen_loss=0.708]  
Epoch [145/150]: 100%|██████████| 576/576 [03:30<00:00,  2.74it/s, critic_loss=-0.0284, gen_loss=0.581]  
Epoch [146/150]: 100%|██████████| 576/576 [03:30<00:00,  2.74it/s, critic_loss=-0.107, gen_loss=0.429]     
Epoch [147/150]: 100%|██████████| 576/576 [03:30<00:00,  2.74it/s, critic_loss=-0.0979, gen_loss=0.608]  
Epoch [148/150]: 100%|██████████| 576/576 [03:30<00:00,  2.74it/s, critic_loss=0.0423, gen_loss=0.583]   
Epoch [149/150]: 100%|██████████| 576/576 [03:30<00:00,  2.74it/s, critic_loss=-0.0502, gen_loss=0.558] 
Epoch [150/150]: 100%|██████████| 576/576 [0

Checkpoint saved at epoch 150


In [35]:
start_epoch = 0
checkpoint_path = "checkpoints/checkpoint_epoch_150.pth"

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    try:
        gen.load_state_dict(checkpoint['generator_state_dict'])
        critic.load_state_dict(checkpoint['critic_state_dict'])
        opt_gen.load_state_dict(checkpoint['opt_gen_state_dict'])
        opt_critic.load_state_dict(checkpoint['opt_critic_state_dict'])
    except RuntimeError as e:
        print("Error loading model states. Check model architecture and z_dim.")
        print(str(e))
        raise
    
    gen.to(device)
    critic.to(device)
    
    start_epoch = checkpoint.get('epoch', 0) + 1

    # Optionally restore random state if you saved it
    if 'random_state' in checkpoint:
        torch.set_rng_state(checkpoint['random_state'])

    print(f"✅ Resumed training from epoch {start_epoch}")
    
    # If you stored loss values:
    if 'gen_loss' in checkpoint and 'critic_loss' in checkpoint:
        print(f"Previous losses - Generator: {checkpoint['gen_loss']:.4f}, Critic: {checkpoint['critic_loss']:.4f}")
else:
    print("🚀 Starting training from scratch.")


✅ Resumed training from epoch 150


  checkpoint = torch.load(checkpoint_path, map_location=device)


In [39]:
import torch
import torchvision.utils as vutils
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
import os

# === Setup ===
checkpoint_epoch = 150
checkpoint_path = f"checkpoints/checkpoint_epoch_{checkpoint_epoch}.pth"
gen_path = f"checkpoints/generator_epoch_{checkpoint_epoch}.pth"
critic_path = f"checkpoints/critic_epoch_{checkpoint_epoch}.pth"

resume = True
z_dim = 128  # or your value
fixed_noise = torch.randn(64, z_dim, 1, 1).to(device)
total_epochs = checkpoint_epoch + 50  # Run 50 more epochs

# Create necessary directories
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("generated_samples", exist_ok=True)
os.makedirs("output", exist_ok=True)

# === Resume Checkpoint ===
if resume and os.path.exists(checkpoint_path):
    print(f"Resuming from checkpoint at epoch {checkpoint_epoch}")

    # Load model weights
    gen.load_state_dict(torch.load(gen_path, map_location=device))
    critic.load_state_dict(torch.load(critic_path, map_location=device))

    # Load optimizer states and epoch
    checkpoint = torch.load(checkpoint_path, map_location=device)
    opt_gen.load_state_dict(checkpoint['opt_gen_state_dict'])
    opt_critic.load_state_dict(checkpoint['opt_critic_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # resume from next epoch
else:
    print("No checkpoint found. Starting from scratch.")
    start_epoch = 0

# === Training Loop ===
for epoch in range(start_epoch, total_epochs):
    loop = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch [{epoch+1}/{total_epochs}]")
    gen.train()
    critic.train()

    for batch_idx, (real, _) in loop:
        real = real.to(device)
        real += 0.01 * torch.randn_like(real)  # Label smoothing
        cur_batch_size = real.size(0)

        # === Train Critic ===
        for _ in range(critic_iterations):
            noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
            fake = gen(noise)

            critic_real = critic(real).view(-1)
            critic_fake = critic(fake.detach()).view(-1)

            gp = gradient_penalty(critic, real, fake, device)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp

            opt_critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        # === Train Generator ===
        noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
        fake = gen(noise)
        gen_loss = -torch.mean(critic(fake))

        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        loop.set_postfix(critic_loss=loss_critic.item(), gen_loss=gen_loss.item())

    # === Save generated image every epoch ===
    with torch.no_grad():
        fake_images = gen(fixed_noise)
        vutils.save_image(fake_images, f"output/fake_epoch_{epoch+1:03d}.png", normalize=True, nrow=8)

        fake_grid = make_grid(fake_images, normalize=True, nrow=8)
        save_image(fake_grid, f"generated_samples/sample_epoch_{epoch+1:03d}.png")

    # === Save checkpoint every 10 epochs ===
    if (epoch + 1) % 10 == 0:
        torch.save(gen.state_dict(), f"checkpoints/generator_epoch_{epoch+1}.pth")
        torch.save(critic.state_dict(), f"checkpoints/critic_epoch_{epoch+1}.pth")
        torch.save({
            'epoch': epoch,
            'generator_state_dict': gen.state_dict(),
            'critic_state_dict': critic.state_dict(),
            'opt_gen_state_dict': opt_gen.state_dict(),
            'opt_critic_state_dict': opt_critic.state_dict(),
        }, f"checkpoints/checkpoint_epoch_{epoch+1}.pth")

        print(f"Checkpoint saved at epoch {epoch+1}")


  gen.load_state_dict(torch.load(gen_path, map_location=device))
  critic.load_state_dict(torch.load(critic_path, map_location=device))
  checkpoint = torch.load(checkpoint_path, map_location=device)


Resuming from checkpoint at epoch 150


Epoch [151/200]: 100%|██████████| 576/576 [03:55<00:00,  2.45it/s, critic_loss=0.0155, gen_loss=0.801]   
Epoch [152/200]: 100%|██████████| 576/576 [03:28<00:00,  2.76it/s, critic_loss=-0.146, gen_loss=0.407]     
Epoch [153/200]: 100%|██████████| 576/576 [03:28<00:00,  2.76it/s, critic_loss=-0.108, gen_loss=0.911]   
Epoch [154/200]: 100%|██████████| 576/576 [03:28<00:00,  2.76it/s, critic_loss=0.03, gen_loss=0.512]     
Epoch [155/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.0979, gen_loss=0.754]  
Epoch [156/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.0131, gen_loss=0.474]   
Epoch [157/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.078, gen_loss=0.513]    
Epoch [158/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.096, gen_loss=1.3]     
Epoch [159/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.14, gen_loss=1.06]       
Epoch [160/200]: 100%|██████████| 576/57

Checkpoint saved at epoch 160


Epoch [161/200]: 100%|██████████| 576/576 [03:31<00:00,  2.73it/s, critic_loss=0.00757, gen_loss=0.813]  
Epoch [162/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.0768, gen_loss=0.583]  
Epoch [163/200]: 100%|██████████| 576/576 [03:30<00:00,  2.74it/s, critic_loss=-0.0737, gen_loss=0.955]    
Epoch [164/200]: 100%|██████████| 576/576 [03:30<00:00,  2.73it/s, critic_loss=-0.16, gen_loss=0.838]    
Epoch [165/200]: 100%|██████████| 576/576 [03:28<00:00,  2.76it/s, critic_loss=-0.035, gen_loss=0.88]    
Epoch [166/200]: 100%|██████████| 576/576 [03:28<00:00,  2.76it/s, critic_loss=0.0304, gen_loss=0.974]   
Epoch [167/200]: 100%|██████████| 576/576 [03:29<00:00,  2.76it/s, critic_loss=-0.013, gen_loss=0.547]   
Epoch [168/200]: 100%|██████████| 576/576 [03:28<00:00,  2.76it/s, critic_loss=-0.094, gen_loss=0.68]    
Epoch [169/200]: 100%|██████████| 576/576 [03:29<00:00,  2.76it/s, critic_loss=-0.0574, gen_loss=0.835]  
Epoch [170/200]: 100%|██████████| 576/576 [0

Checkpoint saved at epoch 170


Epoch [171/200]: 100%|██████████| 576/576 [03:29<00:00,  2.74it/s, critic_loss=-0.00623, gen_loss=0.761] 
Epoch [172/200]: 100%|██████████| 576/576 [03:30<00:00,  2.74it/s, critic_loss=-0.0896, gen_loss=-0.00651]
Epoch [173/200]: 100%|██████████| 576/576 [03:30<00:00,  2.74it/s, critic_loss=0.00226, gen_loss=0.521]  
Epoch [174/200]: 100%|██████████| 576/576 [03:29<00:00,  2.74it/s, critic_loss=-0.0328, gen_loss=0.553]   
Epoch [175/200]: 100%|██████████| 576/576 [03:30<00:00,  2.74it/s, critic_loss=-0.0203, gen_loss=0.727]  
Epoch [176/200]: 100%|██████████| 576/576 [03:29<00:00,  2.74it/s, critic_loss=-0.0688, gen_loss=0.703]   
Epoch [177/200]: 100%|██████████| 576/576 [03:31<00:00,  2.72it/s, critic_loss=-0.0876, gen_loss=0.585]  
Epoch [178/200]: 100%|██████████| 576/576 [03:31<00:00,  2.72it/s, critic_loss=-0.076, gen_loss=0.703]   
Epoch [179/200]: 100%|██████████| 576/576 [03:29<00:00,  2.74it/s, critic_loss=0.00589, gen_loss=0.818]  
Epoch [180/200]: 100%|██████████| 576/576 [

Checkpoint saved at epoch 180


Epoch [181/200]: 100%|██████████| 576/576 [03:32<00:00,  2.71it/s, critic_loss=0.0049, gen_loss=0.623]   
Epoch [182/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.186, gen_loss=0.485]   
Epoch [183/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.103, gen_loss=0.685]   
Epoch [184/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.0805, gen_loss=1.04]   
Epoch [185/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=0.00774, gen_loss=0.789]  
Epoch [186/200]: 100%|██████████| 576/576 [03:29<00:00,  2.74it/s, critic_loss=-0.0116, gen_loss=0.884]  
Epoch [187/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.0756, gen_loss=0.606]  
Epoch [188/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.0371, gen_loss=0.989]  
Epoch [189/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.116, gen_loss=1.56]    
Epoch [190/200]: 100%|██████████| 576/576 [03:

Checkpoint saved at epoch 190


Epoch [191/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.0772, gen_loss=0.975]  
Epoch [192/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=0.00364, gen_loss=1.02]   
Epoch [193/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.00239, gen_loss=0.899]  
Epoch [194/200]: 100%|██████████| 576/576 [03:31<00:00,  2.72it/s, critic_loss=-0.012, gen_loss=0.99]    
Epoch [195/200]: 100%|██████████| 576/576 [03:31<00:00,  2.72it/s, critic_loss=-0.0467, gen_loss=0.175]  
Epoch [196/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.0579, gen_loss=0.454]  
Epoch [197/200]: 100%|██████████| 576/576 [03:30<00:00,  2.74it/s, critic_loss=-0.0455, gen_loss=0.59]   
Epoch [198/200]: 100%|██████████| 576/576 [03:32<00:00,  2.72it/s, critic_loss=-0.112, gen_loss=0.813]   
Epoch [199/200]: 100%|██████████| 576/576 [03:29<00:00,  2.75it/s, critic_loss=-0.036, gen_loss=0.656]   
Epoch [200/200]: 100%|██████████| 576/576 [03

Checkpoint saved at epoch 200


In [37]:
print(gen.model[0])  # should show ConvTranspose2d with in_channels = z_dim

ConvTranspose2d(128, 512, kernel_size=(4, 4), stride=(1, 1))


In [19]:
import torch
import os
from torchvision.utils import save_image
from tqdm import tqdm
from torch import nn

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generator definition (must match your architecture)
class Generator(nn.Module):
    def __init__(self, z_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 512, 4, 1, 0),      # -> (512, 4, 4)
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            nn.ConvTranspose2d(512, 256, 4, 2, 1),         # -> (256, 8, 8)
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1),         # -> (128, 16, 16)
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1),          # -> (64, 32, 32)
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.ConvTranspose2d(64, 32, 4, 2, 1),           # -> (32, 64, 64)
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            nn.ConvTranspose2d(32, 1, 4, 2, 1),            # -> (1, 128, 128)
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

# Configuration
z_dim = 128
checkpoint_path = r"checkpoints\generator_epoch_200.pth"
output_dir = "generated_images"
num_images = 1000
batch_size = 64

# Load the generator
gen = Generator(z_dim).to(device)
gen.load_state_dict(torch.load(checkpoint_path, map_location=device))
gen.eval()

# Make output folder
os.makedirs(output_dir, exist_ok=True)

# Generate and save images
with torch.no_grad():
    for i in tqdm(range(0, num_images, batch_size), desc="Generating Images"):
        cur_batch_size = min(batch_size, num_images - i)
        noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
        fake_images = gen(noise)
        fake_images = (fake_images + 1) / 2  # if Tanh was used

        for j in range(cur_batch_size):
            save_path = os.path.join(output_dir, f"image_{i + j:04d}.png")
            save_image(fake_images[j], save_path)


  gen.load_state_dict(torch.load(checkpoint_path, map_location=device))
Generating Images: 100%|██████████| 16/16 [00:03<00:00,  4.01it/s]
