In [None]:
# Data preprocessing: keep a single slice per volume (index 77 by default,
# described as the mid slice), scaled to 8-bit grayscale.
# NOTE(review): this cell uses Dataset, os, nib, np, Image and vutils, but os/np/vutils
# are only imported in the next cell and Dataset/nib/Image are never imported in this
# notebook; the next cell also re-imports BraTSDataset from `data`, shadowing this
# definition. Presumably this cell mirrors data.py — TODO confirm and keep them in sync.

class BraTSDataset(Dataset):
    """Dataset yielding one 2-D slice from each NIfTI volume.

    Args:
        image_paths: Paths to NIfTI files. Missing or zero-byte files are skipped
            with a printed warning.
        transforms: Optional callable applied to the PIL image (e.g. a
            torchvision ``Compose``).
        slice_index: Index along the volume's third axis to extract (default 77).

    Raises:
        ValueError: If no path in ``image_paths`` points to a non-empty file.
    """

    def __init__(self, image_paths, transforms=None, slice_index=77) -> None:
        # Filter out non-existent or empty files
        valid_paths = []
        for path in image_paths:
            if os.path.exists(path) and os.path.getsize(path) > 0:
                valid_paths.append(path)
            else:
                print(f"[Warning] Skipping invalid or missing file: {path}")

        if len(valid_paths) == 0:
            raise ValueError("No valid image files found!")

        self.imagePaths = valid_paths
        self.transforms = transforms
        self.slice_index = slice_index

    def __len__(self):
        return len(self.imagePaths)

    @staticmethod
    def _to_uint8(slice_):
        """Scale a 2-D array to uint8 in [0, 255] by dividing by its maximum.

        A slice whose maximum is non-positive or non-finite (e.g. all zeros) maps
        to an all-zero image instead of the NaNs a 0/0 division would produce.
        Negative values are clipped to 0 rather than wrapping around in uint8.
        """
        slice_ = np.asarray(slice_, dtype=np.float64)
        peak = slice_.max() if slice_.size else 0.0
        if not np.isfinite(peak) or peak <= 0:
            return np.zeros(slice_.shape, dtype=np.uint8)
        return (np.clip(slice_ / peak, 0.0, 1.0) * 255).astype(np.uint8)

    def _load_image(self, image_path):
        """Load the configured slice of ``image_path`` and apply ``self.transforms``."""
        nii_image = nib.load(image_path)
        # Slicing the array proxy reads just this slice (with intensity scaling
        # applied) instead of materialising the whole volume via get_fdata().
        slice_ = nii_image.dataobj[:, :, self.slice_index]
        image = Image.fromarray(self._to_uint8(slice_))

        if self.transforms is not None:
            image = self.transforms(image)
        return image

    def __getitem__(self, index):
        imagePath = self.imagePaths[index]
        try:
            return self._load_image(imagePath)
        except Exception as e:
            print(f"[Error] Failed to load or process: {imagePath}\n{e}")
            raise

    def save(self, store_path):
        """Write every slice to ``{store_path}/{i}.png``; failures are logged and skipped."""
        os.makedirs(store_path, exist_ok=True)

        for i, impath in enumerate(self.imagePaths):
            try:
                image = self._load_image(impath)
                out_path = f'{store_path}/{i}.png'
                if isinstance(image, Image.Image):
                    # No tensor-producing transform: save_image needs a tensor,
                    # so write the PIL image directly.
                    image.save(out_path)
                else:
                    vutils.save_image(image, out_path)

            except Exception as e:
                print(f"[Warning] Skipped saving image {i} ({impath}) due to error:\n{e}")

In [None]:
import os

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm import tqdm

from data import BraTSDataset

from IPython.display import HTML

import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

In [None]:
def compute_psnr(real_batch: np.ndarray, fake_batch: np.ndarray, data_range: float = 1.0) -> float:
    """Mean per-image PSNR (dB) between two equally shaped batches.

    Args:
        real_batch: Reference images, one per leading index (the training loop
            passes ``(N, C, H, W)`` arrays).
        fake_batch: Images to score, same shape as ``real_batch``.
        data_range: Span of valid pixel values, forwarded to skimage. The default
            1.0 matches ToTensor output in [0, 1]. NOTE(review): the Tanh generator
            emits [-1, 1], for which 2.0 would be the matching range.

    Returns:
        The batch-averaged PSNR. Identical image pairs contribute ``inf``.

    Raises:
        ValueError: If the batch is empty or the two batches differ in shape.
    """
    if real_batch.shape != fake_batch.shape:
        raise ValueError(f"shape mismatch: {real_batch.shape} vs {fake_batch.shape}")
    if real_batch.shape[0] == 0:
        raise ValueError("compute_psnr needs at least one image")
    # PSNR depends only on the element-wise MSE, so the channel layout is irrelevant
    # and no CHW -> HWC transpose is needed.
    scores = [
        psnr(real, fake, data_range=data_range)
        for real, fake in zip(real_batch, fake_batch)
    ]
    return float(np.mean(scores))

In [None]:
def compute_ssim(real_batch: np.ndarray, fake_batch: np.ndarray, data_range: float = 1.0) -> float:
    """Mean per-image SSIM between two equally shaped channel-first batches.

    Args:
        real_batch: Reference images indexed ``[i, C, H, W]``; each image is passed
            to skimage with ``channel_axis=0``.
        fake_batch: Images to score, same shape as ``real_batch``.
        data_range: Span of valid pixel values, forwarded to skimage. The default
            1.0 matches ToTensor output in [0, 1]. NOTE(review): the Tanh generator
            emits [-1, 1], for which 2.0 would be the matching range.

    Returns:
        The batch-averaged SSIM.

    Raises:
        ValueError: If the batch is empty or the two batches differ in shape.
    """
    if real_batch.shape != fake_batch.shape:
        raise ValueError(f"shape mismatch: {real_batch.shape} vs {fake_batch.shape}")
    if real_batch.shape[0] == 0:
        raise ValueError("compute_ssim needs at least one image")
    scores = [
        ssim(real, fake, channel_axis=0, data_range=data_range)
        for real, fake in zip(real_batch, fake_batch)
    ]
    return float(np.mean(scores))

In [None]:
# Config for training
dataset_root = "dataset"
t1_train_data = "T1c_BraTS_2023"
image_size = 64
batch_size = 8  # 128 was used previously
num_workers = 16
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (weight init, noise, DataLoader shuffling) for reproducible runs.
seed = 999
torch.manual_seed(seed)

# The Generator upsamples 1 -> 4 -> 8 -> 16 -> 32 -> 64 and the Discriminator's
# final 4x4 valid conv only reaches 1x1 for 64x64 inputs, so other sizes break training.
assert image_size == 64, "Generator/Discriminator architectures assume 64x64 images"

# Generator model configuration
latent_size = 128
feature_map_size = image_size // 2  # base channel width; 32 for 64x64 images

# Learning rate
lr = 0.0002
# beta 1 for Adam
beta1 = 0.5

# Number of training epochs
num_epochs = 100


In [None]:
# Collect the NIfTI volumes, skipping "._"-prefixed files (presumably macOS metadata).
# Sorted so the dataset index -> file mapping does not depend on os.listdir order.
image_paths = [
    os.path.join(t1_train_data, f)
    for f in sorted(os.listdir(t1_train_data))
    if f.endswith((".nii", ".nii.gz")) and not f.startswith("._")
]

# NOTE(review): ToTensor yields values in [0, 1] and no Normalize is applied, while
# the Generator ends in Tanh ([-1, 1]); confirm this range mismatch is intended.
tf = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
])
dataset = BraTSDataset(image_paths, tf)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [None]:
def weight_init(m):
    """DCGAN-style initialisation; intended for ``netG.apply`` / ``netD.apply``.

    Conv and ConvTranspose weights are drawn from N(0, 0.02). BatchNorm weights are
    drawn from N(1, 0.02) and their biases are set to 0. Other modules are left
    untouched.
    """
    classname = m.__class__.__name__
    if "Conv" in classname:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        nn.init.normal_(m.weight, 1.0, 0.02)
        # constant_, not normal_: normal_(bias, 0) samples biases from N(0, 1).
        nn.init.constant_(m.bias, 0.0)


In [None]:
# Generator code
class Generator(nn.Module):
    """DCGAN generator: latent (N, latent_size, 1, 1) -> image (N, 1, 64, 64).

    The latent depth comes from the global ``latent_size`` and the channel widths
    from the global ``feature_map_size`` (x8, x4, x2, x1); both are read at
    construction time. Kernel sizes and strides are fixed, so the spatial size is
    always 1 -> 4 -> 8 -> 16 -> 32 -> 64, whatever ``image_size`` is.
    NOTE(review): the final Tanh yields values in [-1, 1], while real batches built
    with ToTensor in this notebook lie in [0, 1].
    """

    def __init__(self) -> None:
        super().__init__()
        widths = [feature_map_size * mult for mult in (8, 4, 2, 1)]

        # latent_size x 1 x 1 -> widths[0] x 4 x 4
        layers = [
            nn.ConvTranspose2d(latent_size, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Each stage doubles the spatial size and halves the width: 4 -> 8 -> 16 -> 32.
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        # widths[-1] x 32 x 32 -> 1 x 64 x 64, squashed into [-1, 1]
        layers += [
            nn.ConvTranspose2d(widths[-1], 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [None]:
# Build the generator on the target device with the DCGAN init (Module.to/apply return the module).
netG = Generator().to(device).apply(weight_init)
print(netG)

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: image (N, 1, 64, 64) -> score (N, 1, 1, 1) in (0, 1).

    Four stride-2 convs halve the spatial size (64 -> 32 -> 16 -> 8 -> 4) while the
    width grows as ``feature_map_size`` x {1, 2, 4, 8}; each is followed by BatchNorm
    and LeakyReLU(0.2). A final 4x4 valid conv and a Sigmoid give one score per image.
    """

    def __init__(self):
        super().__init__()
        layers = []
        in_ch = 1
        for mult in (1, 2, 4, 8):
            out_ch = feature_map_size * mult
            layers.extend([
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, True),
            ])
            in_ch = out_ch
        # (feature_map_size * 8) x 4 x 4 -> 1 x 1 x 1
        layers.extend([
            nn.Conv2d(in_ch, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)


In [None]:
# Build the discriminator on the target device with the DCGAN init (Module.to/apply return the module).
netD = Discriminator().to(device).apply(weight_init)
print(netD)

In [None]:
# Binary cross-entropy on the discriminator's Sigmoid outputs.
# (The misspelled name `critetion` is the one the training loop refers to.)
critetion = nn.BCELoss()

# Target values for real and generated images.
real_label, fake_label = 1., 0.

# Adam for both networks, using lr and beta1 from the config cell.
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
# Training loop

g_loss_hist = list()
d_loss_hist = list()
d_loss_real_hist = list()
d_loss_fake_hist = list()
psnr_hist = list()
ssim_hist = list()

best_psnr = -float("inf")
best_ssim = -float("inf")

best_g_weights = None
best_d_weights = None


def _snapshot_state(model):
    """Return a detached copy of ``model.state_dict()``.

    ``state_dict()`` hands back references to the live parameter/buffer tensors, so
    storing it directly lets later optimizer steps overwrite the "best" weights.
    """
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


print("Starting training loop...")
print(f"Number of mini batch iterations per epoch: {len(dataloader)}")
for epoch in range(num_epochs):
    # tqdm over the DataLoader directly so the bar knows len(dataloader).
    for batch in tqdm(dataloader):
        real = batch.to(device)
        b_size = real.size(0)

        # --- Update D: maximize log(D(x)) + log(1 - D(G(z))) ---
        netD.zero_grad()
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real).view(-1)
        errD_real = critetion(output, label)
        errD_real.backward()

        noise = torch.randn(b_size, latent_size, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # detach(): the discriminator step must not backpropagate into G.
        output = netD(fake.detach()).view(-1)
        errD_fake = critetion(output, label)
        errD_fake.backward()

        errD = errD_real + errD_fake
        d_loss_real_hist.append(errD_real.item())
        d_loss_fake_hist.append(errD_fake.item())
        d_loss_hist.append(errD.item())
        optimizerD.step()

        # --- Update G: maximize log(D(G(z))) by scoring fakes against real labels ---
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = critetion(output, label)
        errG.backward()

        g_loss_hist.append(errG.item())
        optimizerG.step()

        # Image metrics every iteration; the best checkpoint is chosen by SSIM
        # (the original author notes this choice follows "their code").
        # NOTE(review): `fake` is generated from random noise and is not paired with
        # `real`, so these scores only loosely reflect sample quality.
        real_np = real.detach().cpu().numpy()
        fake_np = fake.detach().cpu().numpy()
        psnr_hist.append(compute_psnr(real_np, fake_np))
        ssim_hist.append(compute_ssim(real_np, fake_np))
        if ssim_hist[-1] > best_ssim:
            best_psnr = psnr_hist[-1]
            best_ssim = ssim_hist[-1]
            best_g_weights = _snapshot_state(netG)
            best_d_weights = _snapshot_state(netD)

    # Summarize the last iteration of every epoch
    tqdm.write("\n".join((
        f"epoch: {epoch}, Loss D: {d_loss_hist[-1]:.4f}, Loss G: {g_loss_hist[-1]:.4f} ",
        f"Loss D real: {d_loss_real_hist[-1]:.4f} ",
        f"Loss D fake: {d_loss_fake_hist[-1]:.4f} ",
        f"psnr: {psnr_hist[-1]:.4f}, best_psnr: {best_psnr:.4f} ",
        f"ssim: {ssim_hist[-1]:.4f}, best_ssim: {best_ssim:.4f} ",
    )))

