In [None]:
!pip3 install nibabel

In [None]:
import nibabel as nib

# Example to load a single image: one case's t2f volume, read with nibabel.
GLIOMA_ROOT = '/kaggle/input/brats-africa/PKG - BraTS-Africa/BraTS-Africa/95_Glioma'
case_id = 'BraTS-SSA-00002-000'  # change to inspect a different case folder
image_path = f'{GLIOMA_ROOT}/{case_id}/{case_id}-t2f.nii'
image_data = nib.load(image_path)
# Voxel data as a floating-point NumPy array.
image_array = image_data.get_fdata()

In [None]:
import matplotlib.pyplot as plt

# Display one slice of the loaded volume, taken along its second array axis.
slice_index = 92  # change to view another slice (valid: 0 .. image_array.shape[1] - 1)
fig, ax = plt.subplots()
ax.imshow(image_array[:, slice_index, :], cmap='gray')
ax.set_title(f"Slice {slice_index} along axis 1")
plt.show()


In [None]:
import os
import nibabel as nib
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
import torch
import torch.nn as nn
import numpy as np


class AdaIN(nn.Module):
    """Adaptive instance normalisation for 5-D (B, C, D, H, W) feature maps.

    A linear layer projects the style vector to a per-channel scale (gamma)
    and shift (beta) that modulate the instance-normalised input.

    Args:
        in_channel: channel count C of the feature map given to ``forward``.
        style_dim: length of the style vector given to ``forward``.
    """

    def __init__(self, in_channel, style_dim):
        super().__init__()
        self.norm = nn.InstanceNorm3d(in_channel)
        self.style = nn.Linear(style_dim, in_channel * 2)

    def forward(self, x, style):
        """Return ``gamma * InstanceNorm(x) + beta``.

        Args:
            x: feature map of shape (B, in_channel, D, H, W).
            style: style vectors of shape (B, style_dim).
        """
        params = self.style(style)
        # (B, 2C) -> (B, 2C, 1, 1, 1) so gamma/beta broadcast over D, H, W.
        params = params[:, :, None, None, None]
        gamma, beta = torch.chunk(params, 2, dim=1)
        normalized = self.norm(x)
        return normalized * gamma + beta


class MappingNetwork(nn.Module):
    """Map a latent code and a class label to a style vector.

    The label is one-hot encoded and concatenated to ``z``. The result goes
    through an input projection plus 7 hidden Linear layers (all of width
    ``latent_dim``, each followed by LeakyReLU(0.2)), then a final projection
    to ``style_dim``.

    Args:
        latent_dim: length of the latent code ``z`` and width of the MLP.
        style_dim: length of the returned style vector.
        num_classes: number of label values; labels must lie in
            ``[0, num_classes)``.
    """

    def __init__(self, latent_dim=512, style_dim=512, num_classes=2):
        super().__init__()
        # Kept as attributes so callers can size inputs from the module itself
        # (e.g. ``generator.mapping.latent_dim``).
        self.latent_dim = latent_dim
        self.style_dim = style_dim
        self.num_classes = num_classes
        self.shared = nn.Sequential(
            nn.Linear(latent_dim + num_classes, latent_dim),
            nn.LeakyReLU(0.2),
            *[nn.Sequential(
                nn.Linear(latent_dim, latent_dim),
                nn.LeakyReLU(0.2)
            ) for _ in range(7)]
        )
        self.to_style = nn.Linear(latent_dim, style_dim)

    def forward(self, z, label):
        """Return style vectors of shape (B, style_dim).

        Args:
            z: (B, latent_dim) latent codes.
            label: (B,) integer class indices in ``[0, num_classes)``.
        """
        # The one-hot width must equal num_classes to match the first Linear layer.
        label_onehot = torch.zeros(
            z.size(0), self.num_classes, device=z.device, dtype=z.dtype
        )
        # scatter_ requires int64 indices.
        label_onehot.scatter_(1, label.long().unsqueeze(1), 1)
        x = torch.cat([z, label_onehot], dim=1)
        return self.to_style(self.shared(x))


class Generator3D(nn.Module):
    """Label-conditioned, style-based 3-D generator.

    Architecture:
      * A learned constant of shape (1, 512, 4, 4, 4) is refined by a stride-1
        conv.
      * Four stride-2 transposed convs upsample it (4 -> 8 -> 16 -> 32 -> 64
        voxels per axis).
      * Each of those five convs is followed by AdaIN (driven by the style
        vector from ``MappingNetwork``), LeakyReLU(0.2) and, optionally,
        additive Gaussian noise with std 0.1.
      * Two final convs reduce to a single channel, and ``tanh`` bounds the
        output to [-1, 1].

    Args:
        latent_dim: length of the input latent code ``z``.
        style_dim: length of the style vector fed to every AdaIN layer.
        num_classes: number of label classes understood by the mapping network.

    ``forward`` returns a tensor of shape (B, 1, 64, 64, 64).
    """

    def __init__(self, latent_dim=512, style_dim=512, num_classes=2):
        super().__init__()
        self.mapping = MappingNetwork(latent_dim, style_dim, num_classes)
        self.const = nn.Parameter(torch.randn(1, 512, 4, 4, 4))
        # Each AdaIN normalises the output of the conv right before it, so its
        # channel count must equal that conv's out_channels.
        self.conv1 = nn.Conv3d(512, 256, kernel_size=3, stride=1, padding=1)
        self.ada1 = AdaIN(256, style_dim)
        self.conv2 = nn.ConvTranspose3d(256, 128, kernel_size=4, stride=2, padding=1)
        self.ada2 = AdaIN(128, style_dim)
        self.conv3 = nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1)
        self.ada3 = AdaIN(64, style_dim)
        self.conv4 = nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1)
        self.ada4 = AdaIN(32, style_dim)
        self.conv5 = nn.ConvTranspose3d(32, 16, kernel_size=4, stride=2, padding=1)
        self.ada5 = AdaIN(16, style_dim)
        self.conv6 = nn.Conv3d(16, 16, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.ConvTranspose3d(16, 1, kernel_size=3, stride=1, padding=1)

        self.activation = nn.LeakyReLU(0.2)
        self.tanh = nn.Tanh()

    def _styled_block(self, x, conv, ada, w, noise_inject):
        """Apply conv -> AdaIN(w) -> LeakyReLU, then optional N(0, 0.1^2) noise."""
        x = self.activation(ada(conv(x), w))
        if noise_inject:
            x = x + torch.randn_like(x) * 0.1
        return x

    def forward(self, z, label, noise_inject=True):
        """Generate volumes from latent codes and class labels.

        Args:
            z: (B, latent_dim) latent codes.
            label: (B,) integer class labels.
            noise_inject: if True, add Gaussian noise after every styled block.

        Returns:
            (B, 1, 64, 64, 64) tensor with values in [-1, 1].
        """
        w = self.mapping(z, label)
        x = self.const.expand(z.size(0), -1, -1, -1, -1)

        stages = (
            (self.conv1, self.ada1),
            (self.conv2, self.ada2),
            (self.conv3, self.ada3),
            (self.conv4, self.ada4),
            (self.conv5, self.ada5),
        )
        for conv, ada in stages:
            x = self._styled_block(x, conv, ada, w, noise_inject)

        x = self.activation(self.conv6(x))
        return self.tanh(self.conv7(x))


In [None]:
def compute_gaussian_parameters(latents):
    """Estimate the mean vector and covariance matrix of a set of latent vectors.

    Args:
        latents: array-like of shape (num_samples, dim), e.g. a list of 1-D
            NumPy arrays with one array per sample.

    Returns:
        (mean_vector, cov_matrix): arrays of shape (dim,) and (dim, dim). The
        covariance uses the unbiased (N - 1) normalisation of ``np.cov``.
        With fewer samples than dimensions the covariance is singular
        (rank <= num_samples - 1).

    Raises:
        ValueError: if ``latents`` is not 2-D, or holds fewer than two samples
            (the sample covariance would be NaN).
    """
    samples = np.asarray(latents, dtype=float)
    if samples.ndim != 2:
        raise ValueError(
            f"latents must be 2-D (num_samples, dim), got shape {samples.shape}"
        )
    if samples.shape[0] < 2:
        raise ValueError("need at least two latent vectors to estimate a covariance")
    mean_vector = samples.mean(axis=0)
    # rowvar=False: rows are observations, columns are variables.
    cov_matrix = np.cov(samples, rowvar=False)
    return mean_vector, cov_matrix


def sample_new_latents(mean, cov, num_samples, rng=None):
    """Draw latent vectors from a multivariate Gaussian.

    Args:
        mean: (dim,) mean vector.
        cov: (dim, dim) covariance matrix.
        num_samples: number of vectors to draw.
        rng: optional ``numpy.random.Generator`` for reproducible sampling.
            When None, the global legacy ``np.random`` state is used.

    Returns:
        Array of shape (num_samples, dim).
    """
    sampler = np.random if rng is None else rng
    return sampler.multivariate_normal(mean, cov, size=num_samples)


In [None]:
import nibabel as nib
from torch.utils.data import Dataset, DataLoader


class GliomaBraTSDataset(Dataset):
    """One volume per patient sub-folder: the file whose name contains 't1c.nii'.

    Each item is a float32 tensor with a leading channel axis, min-max scaled
    to [-1, 1] and optionally passed through ``transform``. Assumes 3-D NIfTI
    volumes, so items are (1, D, H, W). TODO: confirm for all cases.

    Args:
        root_dir: directory whose immediate sub-directories are patient folders.
        transform: optional callable applied to each tensor.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data = []

        # Sorted so the item order is identical on every run and machine
        # (os.listdir order is arbitrary).
        for patient_folder in sorted(os.listdir(root_dir)):
            patient_path = os.path.join(root_dir, patient_folder)
            if not os.path.isdir(patient_path):
                continue
            matches = sorted(
                name for name in os.listdir(patient_path) if 't1c.nii' in name
            )
            if matches:
                self.data.append(os.path.join(patient_path, matches[0]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_path = self.data[idx]
        volume = nib.load(file_path).get_fdata()

        # Min-max scale to [0, 1], then to [-1, 1] to match the generator's tanh
        # output range. A constant volume has zero span; leave it at 0 instead of
        # dividing by zero (which would produce NaNs), giving -1 after rescaling.
        vmin, vmax = volume.min(), volume.max()
        volume = volume - vmin
        span = vmax - vmin
        if span > 0:
            volume = volume / span
        volume = volume * 2 - 1

        volume = torch.tensor(volume, dtype=torch.float32).unsqueeze(0)  # Add channel dim
        if self.transform:
            volume = self.transform(volume)
        return volume

class Discriminator3D(nn.Module):
    """3-D convolutional discriminator producing unbounded realism scores.

    For a (B, 1, 64, 64, 64) input:
      * Four stride-2 convs shrink each axis 64 -> 32 -> 16 -> 8 -> 4.
      * The final 4x4x4 conv (no padding) yields one score per sample, so
        ``forward`` returns shape (B,).
    Larger inputs produce several patch scores per sample, all flattened
    into one vector.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv3d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three downsampling stages with BatchNorm, doubling channels each time.
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv3d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm3d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Conv3d(512, 1, kernel_size=4, stride=1, padding=0))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        # Flatten every score (one per sample for 64^3 inputs) into a 1-D tensor.
        return scores.view(-1)


def discriminator_loss(real, fake):
    """LSGAN discriminator loss: pull D(real) towards 1 and D(fake) towards 0."""
    real_term = (real - 1).pow(2).mean()
    fake_term = fake.pow(2).mean()
    return real_term + fake_term

def generator_loss(fake):
    """LSGAN generator loss: pull D(fake) towards 1."""
    return (fake - 1).pow(2).mean()


In [None]:
def train(generator, dataloader, num_epochs, device, num_classes=2):
    """Fit ``generator`` to real volumes with a plain voxel-wise MSE loss.

    NOTE(review): this compares volumes generated from *random* latent codes
    with *random* real volumes. It is a reconstruction objective, not an
    adversarial one, and minimising it pushes the generator towards the mean
    training volume. The function is not called anywhere in this notebook.

    Args:
        generator: model called as ``generator(z, labels)``. Its mapping
            network's final ``Linear(latent_dim, style_dim)`` layer
            (``mapping.to_style``) defines the latent size.
        dataloader: iterable of real-volume batches shaped (B, 1, D, H, W).
        num_epochs: number of passes over ``dataloader``.
        device: torch device to train on.
        num_classes: labels are drawn uniformly from ``[0, num_classes)``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    generator.to(device)
    generator.train()
    optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.0, 0.99))
    loss_fn = nn.MSELoss()
    # The mapping network ends in Linear(latent_dim, style_dim), so its input
    # width is the latent size expected by the generator.
    latent_dim = generator.mapping.to_style.in_features

    for epoch in range(num_epochs):
        running_loss, num_batches = 0.0, 0
        for real_volumes in dataloader:
            real_volumes = real_volumes.to(device)
            batch_size = real_volumes.size(0)

            # Generate fake volumes
            z = torch.randn(batch_size, latent_dim, device=device)
            labels = torch.randint(0, num_classes, (batch_size,), device=device)
            fake_volumes = generator(z, labels)

            # Real scans need not match the generator's output grid; resample
            # them so the voxel-wise loss compares tensors of the same shape.
            if real_volumes.shape[2:] != fake_volumes.shape[2:]:
                real_volumes = torch.nn.functional.interpolate(
                    real_volumes, size=tuple(fake_volumes.shape[2:]),
                    mode="trilinear", align_corners=False,
                )

            # Compute loss
            loss = loss_fn(fake_volumes, real_volumes)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches")
        # Mean over the epoch rather than the (noisy) last-batch value.
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / num_batches:.4f}")


In [None]:
import os
from torchvision.transforms import Compose
from torch.utils.data import DataLoader

# Identity transform placeholder; add tensor-level transforms here if needed.
transform = Compose([])

# Initialize the dataset and dataloader
root_dir = "/kaggle/input/brats-africa/PKG - BraTS-Africa/BraTS-Africa/95_Glioma"  # Replace with your dataset directory
BATCH_SIZE = 4
SHUFFLE_SEED = 0  # fixes the shuffle order so runs are reproducible

dataset = GliomaBraTSDataset(root_dir=root_dir, transform=transform)
# Fail fast on a wrong or empty root_dir instead of silently iterating zero batches.
assert len(dataset) > 0, f"No '*t1c.nii' volumes found under {root_dir!r}"

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

In [None]:
# Shared model configuration for the cells below.
latent_dim = 512
style_dim = 512
# The dataset yields volumes only; class labels are drawn at random with
# torch.randint in later cells, so this sets how many label values exist.
num_classes = 2  # Adjust based on dataset labels

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

generator = Generator3D(
    latent_dim=latent_dim,
    style_dim=style_dim,
    num_classes=num_classes,
).to(device)


In [None]:
# Estimate a Gaussian over the mapping network's style vectors, using one sample
# per dataset item. Only the item count matters here, so the volumes themselves
# are never read from disk or copied to the device.
# NOTE(review): this uses `generator` as it is now (untrained), and the training
# cell below builds a new Generator3D instance. These statistics therefore do
# not describe the trained model; re-run this cell after training if they should.
# NOTE(review): with fewer samples than style_dim, the covariance is singular
# (rank <= num_samples - 1).
num_latent_samples = len(dataset)
with torch.no_grad():
    z = torch.randn(num_latent_samples, latent_dim, device=device)
    labels = torch.randint(0, num_classes, (num_latent_samples,), device=device)
    # A list with one 1-D array per sample.
    latent_space = list(generator.mapping(z, labels).cpu().numpy())

# Compute Gaussian parameters
latent_mean, latent_cov = compute_gaussian_parameters(latent_space)


In [None]:
import torch.optim as optim

# Hyperparameters
latent_dim = 512
style_dim = 512
num_classes = 2
learning_rate = 0.0002
num_epochs = 50

# Define the device before any model is moved onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fresh models for this training run (this replaces any `generator` built earlier).
generator = Generator3D(latent_dim, style_dim, num_classes).to(device)
discriminator = Discriminator3D().to(device)
optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.0, 0.99))
optimizer_d = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# Without batches, d_loss/g_loss would never be defined for the epoch report.
if len(dataloader) == 0:
    raise RuntimeError("dataloader is empty - check root_dir in the dataset cell")

generator.train()
discriminator.train()
for epoch in range(num_epochs):
    for real_volumes in dataloader:
        real_volumes = real_volumes.to(device)
        current_batch = real_volumes.size(0)

        # Generate fake volumes
        z = torch.randn(current_batch, latent_dim, device=device)
        labels = torch.randint(0, num_classes, (current_batch,), device=device)
        fake_volumes = generator(z, labels)

        # Resample real scans onto the generator's output grid (64^3 for
        # Generator3D), rather than a fixed size that may not match it.
        real_volumes_resized = F.interpolate(
            real_volumes, size=tuple(fake_volumes.shape[2:]),
            mode='trilinear', align_corners=False,
        )
        assert real_volumes_resized.shape == fake_volumes.shape, "Shape mismatch between real and fake volumes"

        # Train Discriminator (LSGAN loss); detach so G gets no gradient here.
        optimizer_d.zero_grad()
        real_validity = discriminator(real_volumes_resized)
        fake_validity = discriminator(fake_volumes.detach())
        d_loss = discriminator_loss(real_validity, fake_validity)
        d_loss.backward()
        optimizer_d.step()

        # Train Generator (LSGAN loss)
        optimizer_g.zero_grad()
        g_loss = generator_loss(discriminator(fake_volumes))
        g_loss.backward()
        optimizer_g.step()

    print(f"Epoch {epoch+1}/{num_epochs}, D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")



In [None]:
# Print the generator's module tree (submodules and their constructor arguments).
print(generator)

In [None]:
# Sample new latent variables from Gaussian distribution
num_samples = 10
sampled_latents = sample_new_latents(latent_mean, latent_cov, num_samples)

# NOTE(review): latent_mean / latent_cov describe the mapping network's *output*
# (style vectors), yet generator(...) treats its first argument as a latent code
# z and maps it again. This only runs because latent_dim == style_dim. Injecting
# these vectors directly into the synthesis layers would need a Generator3D
# entry point that accepts a style vector.
sampled_latents = torch.as_tensor(sampled_latents, dtype=torch.float32, device=device)
labels = torch.randint(0, num_classes, (num_samples,), device=device)

# Generate new 3D MRI volumes
with torch.no_grad():
    generated_volumes = generator(sampled_latents, labels)

# NumPy copy for plotting: (num_samples, 1, D, H, W), values in [-1, 1] from tanh.
generated_volumes = generated_volumes.cpu().numpy()


In [None]:
import matplotlib.pyplot as plt

# Visualize slices from a generated 3D volume: for each volume, show the middle
# slice along each of its three spatial axes side by side.
for i, volume in enumerate(generated_volumes):
    grid = volume[0]  # drop the channel axis -> (D0, D1, D2)
    mid0, mid1, mid2 = (size // 2 for size in grid.shape)
    panels = [
        (grid[:, :, mid2], f"Volume {i + 1} - Axial"),
        (grid[:, mid1, :], f"Volume {i + 1} - Coronal"),
        (grid[mid0, :, :], f"Volume {i + 1} - Sagittal"),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)

    plt.tight_layout()
    plt.show()
