# GANN Synthetic MRI generator

### Loading Preprocessed data

In [None]:
import h5py

# Load the h5py file
file_path = './mri_data_balanced.h5'
with h5py.File(file_path, 'r') as f:
    # Function to recursively print the structure of the file
    def print_structure(name, obj):
        if isinstance(obj, h5py.Group):
            print(f"Group: {name}")
        elif isinstance(obj, h5py.Dataset):
            print(f"Dataset: {name}, Shape: {obj.shape}, Data type: {obj.dtype}")

    # Visit all groups and datasets in the file
    f.visititems(print_structure)


Dataset: X, Shape: (54, 176, 240, 205), Data type: float64
Dataset: Y, Shape: (54,), Data type: |S10


In [None]:
import numpy as np

with h5py.File(file_path, 'r') as f:
    # Load MRI data
    X = np.array(f['X'])  # Shape: (54, 176, 240, 205)

    # Load labels
    Y = np.array(f['Y'])  # Shape: (54,)
    Y = [y.decode('utf-8') for y in Y]  # Decode the byte strings to regular strings

# Check the shapes and labels
print(f"MRI Data Shape: {X.shape}")
print(f"Labels: {Y}")

MRI Data Shape: (54, 176, 240, 205)
Labels: ['PD', 'NORMAL', 'NORMAL', 'NORMAL', 'PD', 'PD', 'NORMAL', 'PD', 'PD', 'PD', 'PD', 'NORMAL', 'PD', 'NORMAL', 'NORMAL', 'NORMAL', 'PD', 'PD', 'NORMAL', 'PD', 'PD', 'NORMAL', 'PD', 'PD', 'NORMAL', 'NORMAL', 'PD', 'NORMAL', 'PD', 'PD', 'PD', 'PD', 'PD', 'PD', 'NORMAL', 'PD', 'NORMAL', 'NORMAL', 'PD', 'PD', 'PD', 'PD', 'PD', 'NORMAL', 'NORMAL', 'NORMAL', 'NORMAL', 'NORMAL', 'NORMAL', 'NORMAL', 'NORMAL', 'NORMAL', 'NORMAL', 'NORMAL']


### Augmenting Data 

In [None]:
!pip install torchio

Collecting torchio
  Downloading torchio-0.20.0-py2.py3-none-any.whl.metadata (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.6/50.6 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting Deprecated (from torchio)
  Downloading Deprecated-1.2.14-py2.py3-none-any.whl.metadata (5.4 kB)
Collecting SimpleITK!=2.0.*,!=2.1.1.1 (from torchio)
  Downloading SimpleITK-2.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.9 kB)
Collecting humanize (from torchio)
  Downloading humanize-4.10.0-py3-none-any.whl.metadata (7.9 kB)
Collecting nibabel (from torchio)
  Downloading nibabel-5.2.1-py3-none-any.whl.metadata (8.8 kB)
Downloading torchio-0.20.0-py2.py3-none-any.whl (175 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m175.3/175.3 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading SimpleITK-2.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import torchio as tio


# Define the augmentation transforms
transforms = tio.Compose([
    tio.RandomFlip(axes=(0, 1, 2)),  # Flip along different axes
    tio.RandomAffine(scales=(0.9, 1.1), degrees=10),  # Random scaling and rotation
    tio.RandomNoise(mean=0, std=0.1),  # Adding Gaussian noise
    tio.RandomElasticDeformation()  # Elastic deformation
])

# Create lists to hold the augmented data and labels
augmented_data = []
augmented_labels = []

# Augment both PD and NORMAL labeled MRIs
for i in range(len(X)):
    # Add channel dimension to make the data 4D: (1, H, W, D)
    mri_sample = X[i][np.newaxis, ...]

    # Convert the MRI data to a format for TorchIO
    subject = tio.Subject(mri=tio.ScalarImage(tensor=mri_sample.astype(np.float32)))

    # Apply the augmentation
    augmented_subject = transforms(subject)

    # Get augmented MRI and add to the augmented data list
    augmented_mri = augmented_subject['mri'].data.numpy()
    augmented_data.append(augmented_mri)

    # Append the corresponding label for the augmented data
    augmented_labels.append(Y[i])

# Convert the augmented data and labels to NumPy arrays
augmented_data = np.array(augmented_data)
augmented_labels = np.array(augmented_labels)

# Print the new augmented data shape
print(f"Augmented MRI Data Shape: {augmented_data.shape}")
print(f"Augmented Labels: {augmented_labels}")


Augmented MRI Data Shape: (54, 1, 176, 240, 205)
Augmented Labels: ['PD' 'NORMAL' 'NORMAL' 'NORMAL' 'PD' 'PD' 'NORMAL' 'PD' 'PD' 'PD' 'PD'
 'NORMAL' 'PD' 'NORMAL' 'NORMAL' 'NORMAL' 'PD' 'PD' 'NORMAL' 'PD' 'PD'
 'NORMAL' 'PD' 'PD' 'NORMAL' 'NORMAL' 'PD' 'NORMAL' 'PD' 'PD' 'PD' 'PD'
 'PD' 'PD' 'NORMAL' 'PD' 'NORMAL' 'NORMAL' 'PD' 'PD' 'PD' 'PD' 'PD'
 'NORMAL' 'NORMAL' 'NORMAL' 'NORMAL' 'NORMAL' 'NORMAL' 'NORMAL' 'NORMAL'
 'NORMAL' 'NORMAL' 'NORMAL']


### Encoding Y labels and saving the data

In [None]:
from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
encoded_labels = label_encoder.fit_transform(augmented_labels)

output_file_path = './mri_data_augmented.h5'

# Open a new h5py file to write the augmented data
with h5py.File(output_file_path, 'w') as f:
    # Save the augmented MRI data
    f.create_dataset('X_augmented', data=augmented_data, compression="gzip")

    # Save the encoded labels
    f.create_dataset('Y_augmented', data=encoded_labels, compression="gzip")

print(f"Augmented data saved to {output_file_path}")


Augmented data saved to /content/drive/MyDrive/mri_data_augmented.h5


### load augmented data

In [4]:
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np


In [None]:
# Custom Dataset for loading MRI data
class MRIDataset(Dataset):
    def __init__(self, h5_file):
        with h5py.File(h5_file, 'r') as f:
            self.X = f['X_augmented'][:]  # Assuming your augmented MRI images are stored in 'X'
            self.Y = f['Y_augmented'][:]  # Assuming labels are stored in 'Y_augmented', not 'X_augmented'
        self.Y = self.Y.astype(np.float32) #  Casting to float32 for compatibility with PyTorch

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        image = self.X[idx]
        label = self.Y[idx]
        # Removing the extra dimension addition, allowing the Dataloader to handle batching
        # image = np.expand_dims(image, axis=0)
        return torch.tensor(image, dtype=torch.float32), torch.tensor(label)

# Initialize dataset and dataloader
h5_file = './mri_data_augmented.h5'  # Replace with your augmented h5 file path
dataset = MRIDataset(h5_file)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

### Define GAN architecture 

In [5]:
# Define the GAN architecture
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose3d(100, 256, kernel_size=4, stride=1, padding=0),
            nn.ReLU(),
            nn.ConvTranspose3d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # Assuming input images are normalized between -1 and 1
        )

    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv3d(1, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv3d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv3d(256, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),  # Output probability for real or fake
        )

    def forward(self, x):
        return self.model(x)


In [6]:
# Instantiate models, loss function, and optimizers
generator = Generator()
discriminator = Discriminator()
criterion = nn.BCELoss()  # Binary Cross Entropy loss
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

### Run Training

#### code to load checkpoint model


In [None]:
# Load the saved checkpoint
checkpoint = torch.load('checkpoint_path_of_GANN_model')

# Load the models
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

# Load the optimizers
optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])

# Load the epoch to resume training from
start_epoch = checkpoint['epoch']

print(f"Resuming training from epoch {start_epoch + 1}")

In [None]:
# Training loop
num_epochs = 100  # Set the number of epochs
for epoch in range(num_epochs):
    for real_images, _ in dataloader:
        batch_size = real_images.size(0)

        # Create labels for real and fake images
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Train the Discriminator
        optimizer_d.zero_grad()
        outputs = discriminator(real_images)  # Use real images
        # Reshape the discriminator output to match the label shape
        outputs = outputs.view(batch_size, -1)  # Flatten the output
        #before applying sigmoid function.
        outputs = torch.sigmoid(outputs)
        # Calculate loss after reshaping
        d_loss_real = criterion(outputs[:,0], real_labels.view(-1))  # Calculate loss
                                                           #after flattening outputs.


        noise = torch.randn(batch_size, 100, 1, 1, 1)  # Random noise for the generator
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())
        # Reshape the discriminator output to match the label shape
        outputs = outputs.view(batch_size, -1) # Flatten the output
                                              # before applying sigmoid
                                              #function.
        outputs = torch.sigmoid(outputs)
        # Calculate loss after reshaping
        d_loss_fake = criterion(outputs[:,0], fake_labels.view(-1))   #Calculate loss
                                                            #after flattening outputs.

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # Train the Generator
        optimizer_g.zero_grad()
        outputs = discriminator(fake_images)
        # Reshape the discriminator output to match the label shape
        outputs = outputs.view(batch_size, -1) # Flatten the output before
                                              #applying sigmoid function.
        outputs = torch.sigmoid(outputs)
        # Calculate loss after reshaping
        g_loss = criterion(outputs[:,0], real_labels.view(-1)) #Calculate loss
                                                             #after flattening outputs.
        g_loss.backward()
        optimizer_g.step()

    print(f"Epoch [{epoch + 1}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

# Save the models if needed
# torch.save(generator.state_dict(), 'generator.pth')
# torch.save(discriminator.state_dict(), 'discriminator.pth')

Epoch [1/100], D Loss: 1.0258734226226807, G Loss: 0.6862360835075378
Epoch [2/100], D Loss: 1.0170859098434448, G Loss: 0.6889603137969971
Epoch [3/100], D Loss: 1.0130282640457153, G Loss: 0.6904316544532776
Epoch [4/100], D Loss: 1.0113495588302612, G Loss: 0.6908804774284363
Epoch [5/100], D Loss: 1.009940505027771, G Loss: 0.6916881203651428
Epoch [6/100], D Loss: 1.009946584701538, G Loss: 0.6916550993919373
Epoch [7/100], D Loss: 1.0090763568878174, G Loss: 0.692183792591095
Epoch [8/100], D Loss: 1.009307622909546, G Loss: 0.6920391917228699
Epoch [9/100], D Loss: 1.0090856552124023, G Loss: 0.692150890827179
Epoch [10/100], D Loss: 1.0078866481781006, G Loss: 0.6925312876701355
Epoch [11/100], D Loss: 1.0077693462371826, G Loss: 0.6925809383392334
Epoch [12/100], D Loss: 1.007596731185913, G Loss: 0.6926838755607605
Epoch [13/100], D Loss: 1.0073738098144531, G Loss: 0.6927151083946228
Epoch [14/100], D Loss: 1.0073097944259644, G Loss: 0.6927604675292969
Epoch [15/100], D Los

### Use if u want to train model later using checkpoint

In [None]:
# Define a path to save the model checkpoints
checkpoint_path = "./gan_checkpoint.pth"

# Save the generator, discriminator, and optimizer states
torch.save({
    'epoch': epoch,
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    'optimizer_g_state_dict': optimizer_g.state_dict(),
    'optimizer_d_state_dict': optimizer_d.state_dict(),
    'd_loss': d_loss.item(),
    'g_loss': g_loss.item(),
}, checkpoint_path)

print(f"Model saved after {epoch + 1} epochs.")

### Saving the model

In [None]:
# Save the models if needed
torch.save(generator.state_dict(), '/content/drive/MyDrive/generator.pth')
torch.save(discriminator.state_dict(), '/content/drive/MyDrive/discriminator.pth')

### Loading the model

In [8]:
import torch
import nibabel as nib


# Load the state dicts from the saved .pth files
generator.load_state_dict(torch.load('/content/drive/MyDrive/generator.pth'))
discriminator.load_state_dict(torch.load('/content/drive/MyDrive/discriminator.pth'))

# Set the models to evaluation mode
generator.eval()
discriminator.eval()

print("Generator and Discriminator loaded successfully.")


  generator.load_state_dict(torch.load('/content/drive/MyDrive/generator.pth'))
  discriminator.load_state_dict(torch.load('/content/drive/MyDrive/discriminator.pth'))


Generator and Discriminator loaded successfully.


In [9]:
import os
import torch
import nibabel as nib
import numpy as np
from datetime import datetime

# Generate synthetic MRI images using the loaded generator
def generate_synthetic_images(generator, num_images=1, save_as_nii=True):
    generator.eval()  # Set generator to evaluation mode

    # Generate random noise for input to the generator
    noise = torch.randn(num_images, 100, 1, 1, 1)  # Adjust latent vector size if needed

    with torch.no_grad():  # Disable gradients for generation
        generated_images = generator(noise).squeeze(1)  # Squeeze to remove channel dimension

    # Create output directory with date and time
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f"output/generated_images_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)  # Create directory if it doesn't exist

    # Optionally, save generated images as .nii files
    if save_as_nii:
        for i in range(num_images):
            img_np = generated_images[i].cpu().numpy()
            nii_img = nib.Nifti1Image(img_np, affine=np.eye(4))  # Convert numpy array to NIfTI format
            nii_file_path = os.path.join(output_dir, f"generated_image_{i+1}.nii")
            nib.save(nii_img, nii_file_path)
            print(f"Generated and saved: {nii_file_path}")

# Generate and save synthetic MRI images
generate_synthetic_images(generator, num_images=3)


Generated and saved: generated_image_1.nii
Generated and saved: generated_image_2.nii
Generated and saved: generated_image_3.nii
