In [None]:
#%matplotlib inline
import os
import random

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.utils as vutils
from IPython.display import HTML
from torch.utils.data import DataLoader, Dataset

import config.unetConfig as cfg
from data import BraTSDataset
from UNet_encoder_decoder import UNetDiscriminator
from UNet_generator import UNetGenerator
from UNet_trainer import GANTrainer
from utils import weights_init_norm, weights_init_ortho, plot_train_metrics


In [None]:
# Dataset that extracts the slice at index 77 (0-based, i.e. the 78th slice) along the third axis of each volume
class BraTSDataset(Dataset):
    """Dataset yielding one 2-D slice from each volume file in ``image_paths``.

    Each item is the slice at index ``slice_idx`` along the third axis of the
    loaded volume. The slice is scaled to 0..255 by dividing by its maximum and
    converted to a uint8 PIL image. It is then passed through ``transforms`` when
    one is given, so the item type is whatever ``transforms`` returns (or a PIL
    image if ``transforms`` is None).

    NOTE(review): this definition shadows the ``BraTSDataset`` imported from
    ``data`` in the import cell. It also relies on ``nib`` (presumably nibabel)
    and ``Image`` (presumably PIL.Image) being in scope, but neither is imported
    in the import cell. Confirm which implementation is intended and add the
    missing imports.

    Args:
        image_paths: iterable of file paths. Paths that do not exist or are
            empty are reported and dropped.
        transforms: optional callable applied to each PIL image.
        slice_idx: index along the third axis to extract. Defaults to 77, the
            value that was previously hard-coded.

    Raises:
        ValueError: if none of ``image_paths`` exists with a non-zero size.
    """

    def __init__(self, image_paths, transforms=None, slice_idx=77) -> None:
        # Keep only paths that exist and are non-empty; report everything else.
        valid_paths = []
        for path in image_paths:
            if os.path.exists(path) and os.path.getsize(path) > 0:
                valid_paths.append(path)
            else:
                print(f"[Warning] Skipping invalid or missing file: {path}")

        if not valid_paths:
            raise ValueError("No valid image files found!")

        self.imagePaths = valid_paths
        self.transforms = transforms
        self.slice_idx = slice_idx

    def __len__(self):
        return len(self.imagePaths)

    @staticmethod
    def _normalize_to_uint8(image_slice):
        """Scale an array to uint8 by dividing by its maximum and multiplying by 255.

        When the maximum is not positive (an all-background slice, or NaN), an
        all-zero uint8 array of the same shape is returned instead of dividing by
        zero and casting NaN to uint8. Negative values are clipped to 0.
        """
        peak = image_slice.max()
        if not peak > 0:
            return np.zeros(image_slice.shape, dtype=np.uint8)
        scaled = image_slice / peak * 255
        # Clip before the cast: negative floats have no defined uint8 value.
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def _load_image(self, image_path):
        """Load ``image_path``, extract the configured slice, and apply ``transforms``."""
        volume = nib.load(image_path)
        image_slice = volume.get_fdata()[:, :, self.slice_idx]
        image = Image.fromarray(self._normalize_to_uint8(image_slice))
        if self.transforms is not None:
            image = self.transforms(image)
        return image

    def __getitem__(self, index):
        imagePath = self.imagePaths[index]
        try:
            return self._load_image(imagePath)
        except Exception as e:
            print(f"[Error] Failed to load or process: {imagePath}\n{e}")
            # A bare raise re-raises the original exception with its traceback intact.
            raise

    def save(self, store_path):
        """Write every item to ``store_path/<i>.png``, skipping items that fail to load.

        Failures are reported with a warning and do not stop the loop.
        """
        os.makedirs(store_path, exist_ok=True)

        for i, impath in enumerate(self.imagePaths):
            out_path = os.path.join(store_path, f"{i}.png")
            try:
                image = self._load_image(impath)
                if isinstance(image, Image.Image):
                    # No tensor-producing transform was applied; vutils.save_image
                    # only accepts tensors, so save the PIL image directly.
                    image.save(out_path)
                else:
                    vutils.save_image(image, out_path)
            except Exception as e:
                print(f"[Warning] Skipped saving image {i} ({impath}) due to error:\n{e}")

In [None]:
!jupyter nbextension enable --py widgetsnbextension

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
dpath = cfg.DSET_CPATHS[f"{cfg.BT_CLASS}"]

In [None]:
# Load the (T1) volume file paths in sorted, hence deterministic, order. Only
# regular files are kept. A subdirectory or other non-file entry in `dpath` would
# usually pass the dataset's exists/size check and then fail when loaded.
image_paths = [
    os.path.join(dpath, name)
    for name in sorted(os.listdir(dpath))
    if os.path.isfile(os.path.join(dpath, name))
]

# Resize to an explicit (H, W) already produces exactly that size, so the
# CenterCrop to the same (H, W) is a no-op safeguard. ToTensor converts the
# uint8 PIL image to a float tensor.
tf = transforms.Compose([
    transforms.Resize((cfg.INPUT_IMAGE_HEIGHT, cfg.INPUT_IMAGE_WIDTH)),
    transforms.CenterCrop((cfg.INPUT_IMAGE_HEIGHT, cfg.INPUT_IMAGE_WIDTH)),
    transforms.ToTensor(),
])

# Create the dataset
dataset = BraTSDataset(image_paths, tf)

# Create the dataloader (shuffled batches of cfg.BATCH_SIZE)
dataloader = DataLoader(dataset, batch_size=cfg.BATCH_SIZE,
                        shuffle=True, num_workers=cfg.NUM_WORKERS)

In [None]:
# Plot one shuffled batch of training slices as a visual sanity check.
# make_grid runs on the CPU batch as returned by the DataLoader. Moving it to the
# training device first would only add a round-trip copy. normalize=True
# rescales the grid for display.
real_batch = next(iter(dataloader))
grid = vutils.make_grid(real_batch, normalize=True)
fig, ax = plt.subplots(figsize=(15, 15))
ax.axis("off")
ax.set_title(f"Training Images ({cfg.BT_CLASS})")
# make_grid returns (C, H, W); imshow expects (H, W, C).
ax.imshow(grid.permute(1, 2, 0).numpy())

# Create the generator and move it to the training device.
# NOTE(review): presumably the arguments are latent size, feature width and
# output channels, in that order. Confirm against UNetGenerator's signature.
netG = UNetGenerator(cfg.LATENT_SZ, cfg.NGF, cfg.NGC).to(cfg.DEVICE)

# Initialise the generator's weights with weights_init_norm. nn.Module.apply
# calls it on every submodule. The discriminator cell uses weights_init_ortho
# instead.
# NOTE(review): a comment here previously named weights_init_ortho. Confirm
# which initialiser the generator is meant to use.
netG.apply(weights_init_norm)

# Print the model architecture
print(netG)

In [None]:
# Build the discriminator, then place it on the training device.
netD = UNetDiscriminator(ns=cfg.NEG_SLOPE)
netD = netD.to(cfg.DEVICE)

# Initialise its weights with weights_init_ortho (nn.Module.apply visits every submodule).
netD.apply(weights_init_ortho)

# Show the architecture.
print(netD)

# Collect everything the trainer needs: epoch count, per-network learning rates
# and beta1 values (presumably for Adam; confirm in GANTrainer), data, both
# networks, and the device. Then run training.
trainer_config = dict(
    num_epochs=cfg.NUM_EPOCHS,
    glr=cfg.GLR,
    dlr=cfg.DLR,
    gbeta1=cfg.GBETA1,
    dbeta1=cfg.DBETA1,
    dataloader=dataloader,
    netG=netG,
    netD=netD,
    device=cfg.DEVICE,
)
trainer = GANTrainer(**trainer_config)
trainer.train(nz=cfg.LATENT_SZ, batch_sz=cfg.BATCH_SIZE)