# MRI Image Generation with WGAN-GP

This notebook implements a WGAN-GP (Wasserstein GAN with Gradient Penalty) that learns from a folder of 2D MRI slices stored as PNG files and generates new synthetic MRI images.

## Steps:
1. Import the required packages
2. Load and preprocess the 2D MRI slices (PNG)
3. Train the WGAN-GP model
4. Generate new MRI images

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.autograd import Variable, grad
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import nibabel as nib
from PIL import Image
import glob
from tqdm import tqdm
import random

# Select the compute device: the default CUDA device when a GPU is visible,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:

# Directory holding the 2D MRI slices exported as PNG files.
data_dir = "PATH/TO/DATA/FOLDER"
# glob returns matches in arbitrary, filesystem-dependent order; sorting keeps
# the dataset order (and the preview below) reproducible across runs/machines.
png_files = sorted(glob.glob(os.path.join(data_dir, "*.png")))
print(f"Found {len(png_files)} PNG images")

# Preprocessing pipeline applied to every PNG slice.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # single-channel image
    transforms.Resize((128, 128)),                # fixed 128x128 spatial size
    transforms.ToTensor(),                        # float tensor in [0, 1]
    transforms.Normalize([0.5], [0.5])            # (x - 0.5) / 0.5 -> [-1, 1]
])

# Function to load and preprocess a PNG image
def load_and_preprocess_png(file_path):
    """Load one 2D PNG slice and return it as a preprocessed tensor.

    The file is opened in a context manager so its handle is released
    deterministically. The image is converted to RGB first (so palette or
    alpha modes do not reach the transform) and then passed through the
    module-level ``transform`` pipeline.

    Args:
        file_path: Path to a PNG file.

    Returns:
        torch.Tensor: The output of ``transform``. With the pipeline defined in
        this notebook that is a ``(1, 128, 128)`` float tensor in ``[-1, 1]``.
    """
    with Image.open(file_path) as image:
        return transform(image.convert("RGB"))

# Load every PNG into memory. Files that fail to load are reported and skipped
# so that a single unreadable file does not abort the whole run.
print("Loading and preprocessing PNG images...")
slices = []
for file_path in tqdm(png_files):
    try:
        slice_tensor = load_and_preprocess_png(file_path)
    except Exception as e:
        print(f"Error processing {file_path}: {e}")
        continue
    slices.append(slice_tensor)

# torch.stack raises an opaque error on an empty list; fail early with a
# message that points at the likely cause instead.
if not slices:
    raise RuntimeError(
        f"No images could be loaded from {data_dir!r}; "
        "check data_dir and that the folder contains readable PNG files."
    )

# Stack into one tensor; each dataset item is a 1-tuple (image,).
dataset = TensorDataset(torch.stack(slices))
print(f"Dataset created with {len(dataset)} images")

# Preview the first (up to) five slices. The tensors were normalised to
# [-1, 1]; pinning vmin/vmax stops imshow from stretching each slice to its
# own min/max, so the preview shows the intensities the model will train on.
plt.figure(figsize=(15, 5))
for i in range(min(5, len(slices))):
    plt.subplot(1, 5, i+1)
    plt.imshow(slices[i].squeeze().numpy(), cmap='gray', vmin=-1, vmax=1)
    plt.axis('off')
    plt.title(f"Image {i+1}")
plt.tight_layout()
plt.show()

## Define the Generator and Discriminator (Critic) Models

In [None]:
# ----- Configuration -----
# Image geometry: single-channel 128x128 slices.
img_size, channels = 128, 1
img_shape = (channels, img_size, img_size)

# Model size and training schedule.
latent_dim = 128        # length of the noise vector fed to the generator
batch_size = 64
n_epochs = 750
n_critic = 5            # the generator is updated once every n_critic batches
lambda_gp = 10          # weight of the gradient-penalty term in the critic loss

# Adam hyper-parameters shared by both optimizers.
# NOTE(review): the WGAN-GP paper uses lr=1e-4, betas=(0.0, 0.9); these differ.
lr = 0.0002
b1 = 0.1
b2 = 0.999

# A sample grid is saved whenever batches_done is a multiple of this value.
sample_interval = 2000

# Generator: an MLP that maps a latent vector to an image in [-1, 1] (Tanh).
class Generator(nn.Module):
    """Fully connected generator.

    Args:
        z_dim: Length of the input noise vector. Defaults to the notebook-level
            ``latent_dim`` when omitted.
        out_shape: ``(C, H, W)`` shape of each generated image. Defaults to the
            notebook-level ``img_shape`` when omitted.
    """

    def __init__(self, z_dim=None, out_shape=None):
        super().__init__()
        # Fall back to the notebook globals so ``Generator()`` keeps working.
        if z_dim is None:
            z_dim = latent_dim
        if out_shape is None:
            out_shape = img_shape
        self.img_shape = tuple(out_shape)

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BatchNorm1d's second positional argument is ``eps``, not
                # momentum: passing 0.8 there adds a huge constant to the batch
                # variance and nearly disables normalisation. Use the defaults.
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(z_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map ``z`` of shape ``(N, z_dim)`` to images of shape ``(N, *img_shape)``."""
        img = self.model(z)
        return img.view(img.shape[0], *self.img_shape)

# Discriminator (WGAN critic): an MLP whose last layer is linear, so it outputs
# an unbounded score per image (no sigmoid).
class Discriminator(nn.Module):
    """Fully connected critic.

    Args:
        in_shape: ``(C, H, W)`` shape of the input images. Defaults to the
            notebook-level ``img_shape`` when omitted.
    """

    def __init__(self, in_shape=None):
        super().__init__()
        # Fall back to the notebook global so ``Discriminator()`` keeps working.
        if in_shape is None:
            in_shape = img_shape

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(in_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        """Return critic scores of shape ``(N, 1)`` for a batch of shape ``(N, ...)``."""
        # flatten() copies when necessary, so non-contiguous batches (e.g.
        # transposed or strided views) work where ``view`` would raise.
        img_flat = img.flatten(start_dim=1)
        return self.model(img_flat)

# Initialize generator and discriminator on the selected device
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Create output directory for generated images
os.makedirs("generated_images", exist_ok=True)

# DataLoader, reshuffled every epoch. The generator's BatchNorm1d layers raise
# in training mode on a batch containing a single sample, so the final batch
# is dropped only when it would hold exactly one image; otherwise every
# sample is kept.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=len(dataset) % batch_size == 1,
)

# One Adam optimizer per network, sharing the configured hyper-parameters.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Function to compute gradient penalty
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Return the WGAN-GP gradient penalty ``mean((||grad_x D(x_hat)||_2 - 1)^2)``.

    ``x_hat`` is a per-sample random interpolation between ``real_samples`` and
    ``fake_samples``. The autograd graph is kept (``create_graph=True``) so the
    returned penalty can be back-propagated into ``D``'s parameters.

    Args:
        D: Critic module mapping a batch of shape ``(N, ...)`` to scores.
        real_samples: Batch of real inputs, shape ``(N, ...)``.
        fake_samples: Batch of generated inputs with the same shape (the
            training loop in this notebook passes it detached).

    Returns:
        Scalar tensor: the mean penalty over the batch.
    """
    # One mixing weight per sample, broadcast over all remaining dimensions.
    # It is created on the inputs' device/dtype, so no global ``device`` is needed.
    alpha_shape = (real_samples.size(0),) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)
    # Gradient of the summed scores w.r.t. the inputs. This equals the
    # per-sample gradients provided D has no cross-sample interaction (true for
    # the MLP critic in this notebook).
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

In [None]:
# Training (WGAN-GP): one critic update per batch, one generator update every
# `n_critic` batches.
g_losses = []
d_losses = []
batches_done = 0

print("Starting training...")
for epoch in range(n_epochs):
    for i, (real_imgs,) in enumerate(dataloader):

        # Configure input
        real_imgs = real_imgs.to(device)

        # -----------------
        # Train Discriminator (critic)
        # -----------------

        optimizer_D.zero_grad()

        # Sample noise as generator input
        z = torch.randn(real_imgs.shape[0], latent_dim).to(device)

        # Generate a batch of images; only detached copies reach the critic
        # step, so its backward pass does not touch the generator.
        fake_imgs = generator(z)

        # Real images
        real_validity = discriminator(real_imgs)
        # Fake images
        fake_validity = discriminator(fake_imgs.detach())
        # Gradient penalty
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs.detach())
        # Critic loss: -(mean real score - mean fake score) plus the weighted penalty
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

        d_loss.backward()
        optimizer_D.step()

        # -----------------
        # Train Generator
        # -----------------

        if i % n_critic == 0:
            optimizer_G.zero_grad()

            # Generate a batch of images from the same noise z
            fake_imgs = generator(z)
            # The generator is rewarded when the critic scores its samples highly.
            # This backward also writes .grad on the critic's parameters; those
            # are cleared by optimizer_D.zero_grad() at the next critic step.
            fake_validity = discriminator(fake_imgs)
            g_loss = -torch.mean(fake_validity)

            g_loss.backward()
            optimizer_G.step()

            # Save losses for plotting (d_loss is the most recent critic loss)
            g_losses.append(g_loss.item())
            d_losses.append(d_loss.item())

            # Print progress
            if i % 10 == 0:
                print(
                    f"[Epoch {epoch}/{n_epochs}] "
                    f"[Batch {i}/{len(dataloader)}] "
                    f"[D loss: {d_loss.item():.4f}] "
                    f"[G loss: {g_loss.item():.4f}]"
                )

            # Save a 5x5 grid of the current samples. value_range pins the
            # mapping to the generator's [-1, 1] Tanh range; with normalize=True
            # alone each grid would be stretched to its own min/max, making
            # brightness and contrast incomparable between snapshots.
            if batches_done % sample_interval == 0:
                save_image(fake_imgs.detach()[:25], f"generated_images/{batches_done}.png",
                           nrow=5, normalize=True, value_range=(-1, 1))

            # Approximate global batch counter: advanced by n_critic per generator step.
            batches_done += n_critic

# Save the model
torch.save(generator.state_dict(), "mri_wgan_gp_generator.pth")
torch.save(discriminator.state_dict(), "mri_wgan_gp_discriminator.pth")
print("Training completed!")

In [None]:
# Plot the loss curves recorded during training (one point per generator
# update) and save the figure next to the notebook.
plt.figure(figsize=(10, 5))
for series, series_label in ((g_losses, "Generator loss"), (d_losses, "Discriminator loss")):
    plt.plot(series, label=series_label)
plt.title("Training Losses")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("loss_plot.png")
plt.show()

In [None]:
# Load the trained generator. map_location lets a checkpoint saved on a GPU be
# loaded on a CPU-only machine.
generator = Generator().to(device)
generator.load_state_dict(torch.load("mri_wgan_gp_generator.pth", map_location=device))
generator.eval()  # BatchNorm layers use their running statistics

# Generate new images from random latent vectors
num_images = 16
z = torch.randn(num_images, latent_dim).to(device)
with torch.no_grad():
    generated_images = generator(z)

# Display generated images in a 4x4 grid
plt.figure(figsize=(12, 12))
for i in range(num_images):
    plt.subplot(4, 4, i+1)
    # Convert from tensor (-1,1) range to numpy (0,1) for display. vmin/vmax
    # stop imshow from re-stretching each image to its own min/max.
    img = (generated_images[i].cpu().squeeze().numpy() + 1) / 2
    plt.imshow(img, cmap='gray', vmin=0, vmax=1)
    plt.axis('off')
plt.tight_layout()
plt.savefig("generated_mri_samples.png")
plt.show()

print("New MRI images have been successfully generated!")