In [1]:
import torch

if torch.cuda.is_available():
    device = torch.device('cuda')
    print('CUDA is available. PyTorch will use the GPU.')
else:
    device = torch.device('cpu')
    print('CUDA is not available. PyTorch will use the CPU.')

CUDA is available. PyTorch will use the GPU.


In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.autograd import grad
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

In [3]:
# ================================
# Configuration
# ================================

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16
IMAGE_SIZE = 256
Z_DIM = 128
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10
EPOCHS = 300
LEARNING_RATE = 1e-4
SAMPLE_DIR = r"D:\\Faisal\\WGAN\\generated_images"
METRICS_DIR = r"D:\\Faisal\\WGAN\\metrics"
CHECKPOINT_PATH = "D:\\Faisal\\WGAN\\checkpoint.pth"
start_epoch = 0

os.makedirs(SAMPLE_DIR, exist_ok=True)
os.makedirs(METRICS_DIR, exist_ok=True)

In [4]:
# ================================
# Dataset
# ================================

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

class MRIDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.png') or img.endswith('.jpg')]
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = Image.open(self.images[idx]).convert('L')
        if self.transform:
            img = self.transform(img)
        return img

dataset = MRIDataset(r"D:\\Faisal\\Datasets\\tumordatasetnew\\NO-PREPROCESSED", transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)

In [5]:
# ================================
# Models
# ================================

class Generator(nn.Module):
    def __init__(self, z_dim):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 4*4*1024),
            nn.ReLU(),
            nn.Unflatten(1, (1024, 4, 4)),
            nn.ConvTranspose2d(1024, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)

class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.critic = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.InstanceNorm2d(512, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1, 4, 1, 0),
            nn.Flatten()
        )

    def forward(self, x):
        return self.critic(x)


In [6]:
# ================================
# Gradient Penalty
# ================================

def gradient_penalty(critic, real, fake):
    BATCH_SIZE, C, H, W = real.shape
    epsilon = torch.rand((BATCH_SIZE, 1, 1, 1), device=DEVICE)
    interpolated = real * epsilon + fake * (1 - epsilon)
    interpolated.requires_grad_(True)

    mixed_scores = critic(interpolated)
    gradient = grad(outputs=mixed_scores, inputs=interpolated,
                    grad_outputs=torch.ones_like(mixed_scores),
                    create_graph=True, retain_graph=True)[0]
    gradient = gradient.view(gradient.size(0), -1)
    gp = ((gradient.norm(2, dim=1) - 1) ** 2).mean()
    return gp

In [7]:
# ================================
# Initialize models and optimizers
# ================================

gen = Generator(Z_DIM).to(DEVICE)
critic = Critic().to(DEVICE)
opt_gen = torch.optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

In [None]:
# ================================
# Inception and FID
# ================================

fid = FrechetInceptionDistance(normalize=True).to(DEVICE)
inception = InceptionScore().to(DEVICE)



In [9]:
# ================================
# Load Checkpoint
# ================================

gen_losses = []
critic_losses = []

if os.path.exists(CHECKPOINT_PATH):
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    gen.load_state_dict(checkpoint["gen_state_dict"])
    critic.load_state_dict(checkpoint["critic_state_dict"])
    opt_gen.load_state_dict(checkpoint["opt_gen_state_dict"])
    opt_critic.load_state_dict(checkpoint["opt_critic_state_dict"])
    gen_losses = checkpoint["gen_losses"]
    critic_losses = checkpoint["critic_losses"]
    start_epoch = checkpoint["epoch"] + 1
    print(f"✅ Resumed training from epoch {start_epoch}")
else:
    print("🆕 Starting training from scratch.")

✅ Resumed training from epoch 3


  checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)


In [None]:
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import os
from tqdm import tqdm

# ================================
# Setup for FID and IS
# ================================
fid = FrechetInceptionDistance(feature=64).to(DEVICE)  # Keep FID calculation on the same device as your training
inception = InceptionScore().to(DEVICE)  # Initialize InceptionScore

# Create directories to save generated images and metrics
os.makedirs(SAMPLE_DIR, exist_ok=True)
os.makedirs(METRICS_DIR, exist_ok=True)

# ================================
# Training Loop
# ================================
fixed_noise = torch.randn(3, Z_DIM).to(DEVICE)
print("Starting training...")

for epoch in range(start_epoch, EPOCHS):
    for real in tqdm(dataloader):
        real = real.to(DEVICE)
        cur_batch_size = real.size(0)

        # ================================
        # Train Critic
        # ================================
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)
            fake = gen(noise).detach()
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake)
            loss_critic = -(critic_real.mean() - critic_fake.mean()) + LAMBDA_GP * gp

            opt_critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        # ================================
        # Train Generator
        # ================================
        noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)
        fake = gen(noise)
        output = critic(fake).reshape(-1)
        loss_gen = -output.mean()

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

    gen_losses.append(loss_gen.item())
    critic_losses.append(loss_critic.item())

    # ================================
    # Save Generated Images and Metrics
    # ================================
    with torch.no_grad():
        fake_images = gen(fixed_noise).cpu()
        for i in range(fake_images.size(0)):
            save_image(fake_images[i], os.path.join(SAMPLE_DIR, f"epoch_{epoch+1}_sample_{i+1}.png"), normalize=True)

        # Resize real and fake to 299x299
        real_resized = F.interpolate(real, size=(299, 299), mode='bilinear', align_corners=False)
        fake_resized = F.interpolate(fake, size=(299, 299), mode='bilinear', align_corners=False)

        # Convert 1-channel grayscale images to 3-channel by repeating
        if real_resized.shape[1] == 1:
            real_resized = real_resized.repeat(1, 3, 1, 1)
        if fake_resized.shape[1] == 1:
            fake_resized = fake_resized.repeat(1, 3, 1, 1)

        # Convert to uint8 for FID and Inception Score
        device = next(fid.parameters()).device
        real_uint8 = ((real_resized + 1) * 127.5).clamp(0, 255).to(torch.uint8).to(device)
        fake_uint8 = ((fake_resized + 1) * 127.5).clamp(0, 255).to(torch.uint8).to(device)

        # Now it's safe to update
        fid.update(real_uint8, real=True)
        fid.update(fake_uint8, real=False)
        inception.update(fake_uint8)

        # Compute FID and IS scores
        fid_score = fid.compute().item()
        is_score = inception.compute()[0].item()

        # Save the scores to a text file
        with open(os.path.join(METRICS_DIR, "scores.txt"), 'a') as f:
            f.write(f"Epoch {epoch+1}: FID: {fid_score:.4f}, IS: {is_score:.4f}\n")

    # Print the results for the current epoch
    print(f"Epoch [{epoch+1}/{EPOCHS}] Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}, FID: {fid_score:.4f}, IS: {is_score:.4f}")

    # Save the model checkpoints
    torch.save({
        "epoch": epoch,
        "gen_state_dict": gen.state_dict(),
        "critic_state_dict": critic.state_dict(),
        "opt_gen_state_dict": opt_gen.state_dict(),
        "opt_critic_state_dict": opt_critic.state_dict(),
        "gen_losses": gen_losses,
        "critic_losses": critic_losses,
    }, CHECKPOINT_PATH)


In [None]:
# ================================
# Save Loss Graph
# ================================

plt.figure(figsize=(10, 5))
plt.plot(gen_losses, label="Generator Loss")
plt.plot(critic_losses, label="Critic Loss")
plt.title('Training Losses')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig(os.path.join(METRICS_DIR, "loss_plot.png"))
plt.close()