# MRI Slice-based GAN

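This notebook implements a DCGAN-style model that generates synthetic 2D MRI slices. It trains on 2D slice images (PNG/JPG files) that were previously extracted from 3D MRI volumes; the notebook itself does not read 3D volumes.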

## Steps:
1. Import required packages and set random seeds and the compute device
2. Load and preprocess 2D MRI data
3. Train the DCGAN model
4. Generate new MRI images

In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torchvision.utils as vutils

# Seed every random-number source this notebook uses (Python's `random`,
# NumPy, PyTorch on CPU and on the current CUDA device) so that repeated
# runs start from the same state.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Dataset class for MRI slices
class MRISliceDataset(Dataset):
    """Dataset of 2D MRI slice images stored as PNG/JPG files in one directory.

    Each item is the image at that index, converted to single-channel
    grayscale (PIL mode "L") and then passed through ``transform`` if given.

    Args:
        image_dir: Directory scanned (non-recursively) for files ending in
            ``.png``, ``.jpg`` or ``.jpeg``; the extension match is
            case-insensitive.
        transform: Optional callable applied to each PIL image.
    """

    # File-name suffixes treated as slice images (compared in lower case).
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform

        # Collect all PNG/JPG image files in the directory. ``os.listdir``
        # returns entries in an arbitrary, filesystem-dependent order, so the
        # names are sorted to keep index -> file stable across runs and
        # machines; the notebook's RNG seeding is only meaningful if the
        # underlying sample order is fixed as well.
        self.image_paths = [
            os.path.join(image_dir, fname)
            for fname in sorted(os.listdir(image_dir))
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
        ]

        print(f"Loaded {len(self.image_paths)} 2D MRI slice images from: {image_dir}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Open in a context manager so the file handle is released promptly;
        # ``convert`` returns a new, fully loaded image that outlives it.
        with Image.open(img_path) as img:
            image = img.convert("L")  # Read as grayscale

        if self.transform:
            image = self.transform(image)

        return image

In [None]:
# Define weight initialization function
def weights_init_normal(m):
    """Re-initialise one module's parameters in place, DCGAN style.

    Meant to be passed to ``nn.Module.apply``. Dispatch is on the class name:

    * name contains "Conv"        -> weight ~ N(0, 0.02) (bias untouched)
    * name contains "BatchNorm2d" -> weight ~ N(1, 0.02), bias = 0
    * anything else (e.g. Linear) -> left unchanged
    """
    layer_name = type(m).__name__

    if "Conv" in layer_name:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if "BatchNorm2d" not in layer_name:
        return

    torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
    torch.nn.init.constant_(m.bias.data, val=0.0)

# Generator model
class Generator(nn.Module):
    """Map latent noise vectors to ``channels x img_size x img_size`` images in [-1, 1].

    A linear layer projects ``z`` to a 128-channel feature map of side
    ``img_size // 4``. Two stages of 2x upsampling + 3x3 conv bring it back
    to ``img_size``, and a final 3x3 conv + Tanh produces the image.

    Args:
        latent_dim: Length of each input noise vector.
        img_size: Output height/width; must be a positive multiple of 4.
        channels: Number of output image channels.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 4 (the two
            2x upsampling stages could not reproduce it exactly, so the
            output would silently have the wrong size).
    """

    def __init__(self, latent_dim, img_size, channels=1):
        super().__init__()

        if img_size <= 0 or img_size % 4 != 0:
            raise ValueError(f"img_size must be a positive multiple of 4, got {img_size}")

        self.init_size = img_size // 4
        self.latent_dim = latent_dim

        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        # All BatchNorm layers use the default eps (1e-5), like the first one.
        # The second positional argument of nn.BatchNorm2d is ``eps``, not
        # ``momentum``: writing ``nn.BatchNorm2d(C, 0.8)`` (a common slip when
        # porting Keras' momentum=0.8) sets eps=0.8, which dominates the batch
        # variance whenever it is small and weakens the normalisation.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate a batch of images from noise ``z`` of shape (batch, latent_dim)."""
        out = self.l1(z)
        # Reshape the flat projection into a (batch, 128, init_size, init_size) map.
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


# Discriminator model
class Discriminator(nn.Module):
    """Score ``channels x img_size x img_size`` images as real (near 1) or fake (near 0).

    Four 3x3 stride-2 conv blocks (16 -> 32 -> 64 -> 128 filters) downsample
    the input; a linear layer + Sigmoid then outputs one probability per image,
    shaped (batch, 1).

    Args:
        img_size: Input height/width. Any positive size works; the flattened
            feature size is derived exactly from it.
        channels: Number of input image channels.
    """

    # Number of stride-2 downsampling blocks stacked in ``self.model``.
    _NUM_DOWNSAMPLES = 4

    def __init__(self, img_size, channels=1):
        super().__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Return [3x3 stride-2 conv, LeakyReLU, Dropout2d] plus BatchNorm2d if ``bn``."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                # Default eps (1e-5): the second positional argument of
                # nn.BatchNorm2d is eps, so passing 0.8 there would set eps=0.8
                # rather than the momentum it was presumably meant to be.
                block.append(nn.BatchNorm2d(out_filters))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Height/width after downsampling. A 3x3 conv with stride 2 and
        # padding 1 maps a side of length s to ceil(s / 2), so apply that once
        # per block. (``img_size // 2 ** 4`` is only right when img_size is a
        # multiple of 16; otherwise the Linear layer's input size mismatches.)
        ds_size = img_size
        for _ in range(self._NUM_DOWNSAMPLES):
            ds_size = (ds_size + 1) // 2
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        """Return the probability that each image in ``img`` is real, shape (batch, 1)."""
        out = self.model(img)
        # Flatten (batch, 128, ds_size, ds_size) features for the linear head.
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

In [None]:
# Set hyperparameters
# NOTE: statement order in this cell matters for reproducibility -- model
# construction and weights_init_normal both draw from the torch RNG seeded
# in the first cell.
latent_dim = 128  # Length of the generator's input noise vector
img_size = 128  # Resize MRI slices to this size; keep it a multiple of 16 so Generator (img_size // 4 start) and Discriminator (four stride-2 convs) size arithmetic is exact
batch_size = 128
lr = 0.0002  # Adam learning rate, shared by generator and discriminator
b1 = 0.5  # Adam beta1
b2 = 0.999  # Adam beta2
n_epochs = 750
sample_interval = 50  # Save generated images every n epochs

# Data preprocessing: resize to img_size x img_size, force a single channel,
# convert to a float tensor in [0, 1], then map to [-1, 1] so real images
# share the range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.Grayscale(1),  # Ensure single channel
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1]
])

# Load the dataset
# NOTE(review): placeholder path -- point this at the folder of slice images before running.
dataset = MRISliceDataset(image_dir="PATH/TO/DATA/FOLDER", transform=transform)
# Reshuffled every epoch; drop_last is not set, so the final batch of an
# epoch may be smaller than batch_size.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize models
generator = Generator(latent_dim=latent_dim, img_size=img_size).to(device)
discriminator = Discriminator(img_size=img_size).to(device)

# Initialize weights: conv weights ~ N(0, 0.02), BatchNorm2d weights ~ N(1, 0.02)
# with zero bias; other layers keep PyTorch's default init.
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Set up optimizers (one Adam optimizer per network, identical settings)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Loss function: binary cross-entropy on the discriminator's Sigmoid output
adversarial_loss = torch.nn.BCELoss()

# Create directory for sample images written during training
os.makedirs("mri_samples", exist_ok=True)

# For visualizing training progress: per-iteration losses appended by the training loop
G_losses = []
D_losses = []

In [None]:
# Training loop
for epoch in range(n_epochs):
    # Real images from this epoch's FIRST batch, kept for the comparison grid.
    # They must be captured inside the loop: after the loop, ``i`` holds the
    # LAST batch index, so an ``i == 0`` check there only fires when the
    # dataloader yields a single batch.
    reference_real = None

    for i, imgs in enumerate(dataloader):
        # Configure real and fake image batch
        real_imgs = imgs.to(device)
        if i == 0:
            reference_real = real_imgs[:8].detach()

        # Adversarial ground truths: 1 = real, 0 = fake, shaped like the
        # discriminator's (batch, 1) output.
        valid = torch.ones(imgs.size(0), 1, device=device)
        fake = torch.zeros(imgs.size(0), 1, device=device)

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = torch.randn(imgs.shape[0], latent_dim, device=device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator.
        # This backward also fills the discriminator's .grad buffers; they are
        # discarded by optimizer_D.zero_grad() before the discriminator step.
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples.
        # detach() keeps this loss from back-propagating into the generator.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Save losses for plotting
        G_losses.append(g_loss.item())
        D_losses.append(d_loss.item())

    if reference_real is None:
        raise RuntimeError("dataloader yielded no batches; check the dataset directory")

    # Print progress (losses of the epoch's last batch)
    print(
        f"[Epoch {epoch}/{n_epochs}] "
        f"[D loss: {d_loss.item():.4f}] "
        f"[G loss: {g_loss.item():.4f}]"
    )

    # Save generated samples at specified intervals, and always after the
    # final epoch so the finished model's output is written to disk.
    if epoch % sample_interval == 0 or epoch == n_epochs - 1:
        # NOTE(review): the generator is still in train mode here, so this
        # forward pass normalises with batch statistics and updates the
        # BatchNorm running stats even under no_grad().
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)  # Generate 16 samples
            sample_imgs = generator(z)
        vutils.save_image(sample_imgs[:16], f"mri_samples/epoch_{epoch}.png",
                          normalize=True, nrow=4)

        # Real images (first row) next to the same number of generated ones (second row).
        n_compare = reference_real.size(0)
        comparison = torch.cat((reference_real, sample_imgs[:n_compare]), dim=0)
        vutils.save_image(comparison, f"mri_samples/comparison_{epoch}.png",
                          normalize=True, nrow=n_compare)

# Save the trained model
torch.save(generator.state_dict(), "mri_generator.pth")
torch.save(discriminator.state_dict(), "mri_discriminator.pth")

# Plot the training losses (one point per training iteration)
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="Generator")
plt.plot(D_losses, label="Discriminator")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("mri_training_losses.png")
plt.show()

In [None]:
def generate_mri_samples(num_samples=16):
    """Load the saved generator, sample ``num_samples`` images, and plot them.

    Generator weights are read from ``mri_generator.pth`` (written by the
    training cell). The images are drawn in a near-square grid, saved to
    ``mri_generated_samples.png`` at 300 dpi, shown, and returned.

    Args:
        num_samples: Number of images to generate; any positive integer
            (it does not need to be a perfect square).

    Returns:
        NumPy array of shape (num_samples, 1, img_size, img_size) with pixel
        values mapped from the generator's [-1, 1] range to [0, 1].

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    # Load the trained generator. map_location lets weights saved on a GPU
    # be restored onto whatever ``device`` is active now.
    trained_generator = Generator(latent_dim=latent_dim, img_size=img_size).to(device)
    trained_generator.load_state_dict(torch.load("mri_generator.pth", map_location=device))
    trained_generator.eval()

    # Generate random latent vectors
    z = torch.randn(num_samples, latent_dim, device=device)

    # Generate images with the generator
    with torch.no_grad():
        generated_images = trained_generator(z)

    # Undo Normalize([0.5], [0.5]): map [-1, 1] back to [0, 1] for display.
    generated_images = (generated_images.cpu() * 0.5 + 0.5).numpy()

    # Size the grid with ceil so every sample gets a cell; using
    # int(sqrt(n)) per side leaves too few cells when n is not a perfect
    # square. squeeze=False always yields a 2D array of Axes, even for 1x1.
    ncols = int(np.ceil(np.sqrt(num_samples)))
    nrows = int(np.ceil(num_samples / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(10, 10), squeeze=False)
    axes = axes.flatten()

    for ax, img in zip(axes, generated_images):
        ax.imshow(img[0], cmap='gray')  # img[0] to get the first channel (grayscale)
    # Hide tick marks on every cell, including any unused trailing ones.
    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig("mri_generated_samples.png", dpi=300)
    plt.show()

    return generated_images

# Generate and display new MRI samples
generated_samples = generate_mri_samples(16)