In [1]:
# SECTION 1: SETUP & IMPORTS

# Core ML
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import nibabel as nib
import matplotlib.pyplot as plt
import torchio as tio
import warnings
import torchvision.transforms as transforms



# System and Warnings
import warnings
warnings.filterwarnings("ignore")

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [2]:
# Configurations
IMG_SIZE = 128
VOLUME_SLICES = 50
VOLUME_START_AT = 22
BATCH_SIZE = 4
NUM_EPOCHS = 5
NUM_ROUNDS = 2
batch_size = 4  # stable
pin_memory = True
num_workers = 4  # or os.cpu_count() // 2



In [3]:
# CELL: Fixed and Safe BraTSDataset

class BraTSDataset(Dataset):
    def __init__(self, root_dir, transform=None, train=True):
        self.root_dir = root_dir
        self.patient_dirs = sorted(os.listdir(root_dir))
        self.transform = transform
        self.train = train

    def __len__(self):
        return len(self.patient_dirs)

    def __getitem__(self, idx):
        patient_id = self.patient_dirs[idx]
        patient_path = os.path.join(self.root_dir, patient_id)

        # Load MRI modalities
        modalities = ['t1c', 't1n', 't2f', 't2w']
        image_data = []
        for mod in modalities:
            image_path = os.path.join(patient_path, f"{patient_id}-{mod}.nii")
            image = nib.load(image_path).get_fdata()
            image_data.append(image)

        image_np = np.stack(image_data, axis=0).astype(np.float32)
        image_tensor = torch.tensor(image_np, dtype=torch.float32)

        # Load segmentation label if in training mode
        if self.train:
            label_path = os.path.join(patient_path, f"{patient_id}-seg.nii")
            label_np = nib.load(label_path).get_fdata().astype(np.uint8)

            # 🔄 Remap BraTS labels: [0, 1, 2, 4] → [0, 1, 2, 3]
            label_np[label_np == 4] = 3

            # ✅ FIX: Ensure label tensor is long
            label_tensor = torch.from_numpy(label_np).long().unsqueeze(0)
        else:
            label_tensor = None

        # Apply TorchIO preprocessing
        if self.transform:
            subject_dict = {"images": tio.ScalarImage(tensor=image_tensor)}
            if label_tensor is not None:
                subject_dict["label"] = tio.LabelMap(tensor=label_tensor)
            subject = tio.Subject(**subject_dict)
            transformed = self.transform(subject)
            image_tensor = transformed.images.data
            if label_tensor is not None:
                label_tensor = transformed.label.data

        # 🛡 Sanity check
        if self.train:
            if not torch.is_tensor(image_tensor) or not torch.is_tensor(label_tensor):
                raise RuntimeError(f"[ERROR] {patient_id}: invalid tensor types")

        return (image_tensor, label_tensor) if self.train else image_tensor


In [4]:
# CELL 3: TorchIO Preprocessing Transform

transform = tio.Compose([
    tio.RescaleIntensity(out_min_max=(0, 1)),  # Normalize intensities to [0, 1]
    tio.Resize((128, 128, 128)),               # Resize to fixed shape
    tio.ZNormalization()                       # Normalize mean=0, std=1
])


In [5]:
augment_transform = tio.Compose([
    # Intensity-based
    tio.RandomBiasField(p=0.3),
    tio.RandomGamma(p=0.3),
    tio.RandomNoise(p=0.2),
    
    # Spatial-based
    tio.RandomAffine(
        scales=(0.9, 1.1),
        degrees=10,
        translation=5,
        center='image',
        p=0.5
    ),
    tio.RandomElasticDeformation(p=0.2),
    tio.RandomFlip(axes=('LR',), p=0.5),

    # Preprocessing
    tio.RescaleIntensity(out_min_max=(0, 1)),
    tio.Resize((128, 128, 128)),
    tio.ZNormalization()
])


In [6]:
# CELL 4: Client Dataset Paths and Loaders

# Define paths (adjust as needed to match your directory structure)
Hospital1_Train = 'Data/2023GLI/TrainingData'
Hospital1_Val   = 'Data/2023GLI/ValidationData'
Hospital2_Train = 'Data/2023MEN/TrainingData'
Hospital2_Val   = 'Data/2023MEN/ValidationData'
Hospital3_Train = 'Data/2023MET/TrainingData'
Hospital3_Val   = 'Data/2023MET/ValidationData'
Hospital4_Train = 'Data/2023PED/TrainingData'
Hospital4_Val   = 'Data/2023PED/ValidationData'
Hospital5_Train = 'Data/2023SSA/TrainingData'
Hospital5_Val   = 'Data/2023SSA/ValidationData'
Hospital6_Train_Val = 'Data/BraTS2021'
Hospital7_Train = 'Data/BraTS2020/TrainingData'
Hospital7_Val   = 'Data/BraTS2020/ValidationData'
Hospital8_Train_Val = 'Data/BraTS2019/HGG'
Hospital9_Train_Val = 'Data/BraTS2019/LGG'

hospitals = {
    "Hospital1": {"train": Hospital1_Train, "val": Hospital1_Val},
    "Hospital2": {"train": Hospital2_Train, "val": Hospital2_Val},
    "Hospital3": {"train": Hospital3_Train, "val": Hospital3_Val},
    "Hospital4": {"train": Hospital4_Train, "val": Hospital4_Val},
    "Hospital5": {"train": Hospital5_Train, "val": Hospital5_Val},
    "Hospital6": {"combined": Hospital6_Train_Val},
    "Hospital7": {"train": Hospital7_Train, "val": Hospital7_Val},
    "Hospital8": {"combined": Hospital8_Train_Val},
    "Hospital9": {"combined": Hospital9_Train_Val}
}

hospital_loaders = {}
train_ratio = 0.8
batch_size = 8

for hospital, paths in hospitals.items():
    print(f"🔁 Loading {hospital}...")

    if "combined" in paths:
        #full_dataset = BraTSDataset(paths["combined"], transform=transform, train=True)
        full_dataset = BraTSDataset(paths["combined"], transform=augment_transform, train=True)
        train_size = int(train_ratio * len(full_dataset))
        val_size = len(full_dataset) - train_size
        train_set, val_set = random_split(full_dataset, [train_size, val_size])
    else:
        # train_set = BraTSDataset(paths["train"], transform=transform, train=True)
        # val_set = BraTSDataset(paths["val"], transform=transform, train=False)

        train_set = BraTSDataset(paths["train"], transform=augment_transform, train=True)
        val_set = BraTSDataset(paths["val"], transform=transform, train=False)


    train_loader = DataLoader(train_set, batch_size=8, shuffle=True,
                          num_workers=2)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    hospital_loaders[hospital] = {
        "train": train_loader,
        "val": val_loader
    }

print("\n✅ All hospital loaders are ready.")


🔁 Loading Hospital1...
🔁 Loading Hospital2...
🔁 Loading Hospital3...
🔁 Loading Hospital4...
🔁 Loading Hospital5...
🔁 Loading Hospital6...
🔁 Loading Hospital7...
🔁 Loading Hospital8...
🔁 Loading Hospital9...

✅ All hospital loaders are ready.


In [13]:
def visualize_random_sample(hospital="Hospital1", slice_idx=64):
    loader = hospital_loaders[hospital]["train"]
    images, labels = next(iter(loader))

    # Shape: (B, C, D, H, W)
    image = images[0]   # → (4, D, H, W)
    label = labels[0]   # → (1, D, H, W)

    print(f"Image shape: {image.shape}, Label shape: {label.shape}")

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))

    # Show FLAIR image (channel 3)
    axs[0].imshow(image[3, :, :, slice_idx].cpu(), cmap='gray')
    axs[0].set_title(f"{hospital} - FLAIR slice {slice_idx}")
    axs[0].axis('off')

    # Show corresponding segmentation mask
    axs[1].imshow(label[0, :, :, slice_idx].cpu(), cmap='Reds')
    axs[1].set_title(f"{hospital} - Mask slice {slice_idx}")
    axs[1].axis('off')

    plt.tight_layout()
    plt.show()


In [14]:
visualize_random_sample("Hospital1")

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/awakili/anaconda3/envs/DLEnv/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/awakili/anaconda3/envs/DLEnv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
  File "/home/awakili/anaconda3/envs/DLEnv/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 398, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/home/awakili/anaconda3/envs/DLEnv/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 211, in collate
    return [
  File "/home/awakili/anaconda3/envs/DLEnv/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 212, in <listcomp>
    collate(samples, collate_fn_map=collate_fn_map)
  File "/home/awakili/anaconda3/envs/DLEnv/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 155, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/home/awakili/anaconda3/envs/DLEnv/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 272, in collate_tensor_fn
    return torch.stack(batch, 0, out=out)
RuntimeError: torch.cat(): input types can't be cast to the desired output type Long


In [None]:
print("\n📊 Dataset sizes per hospital:")
for hospital, loaders in hospital_loaders.items():
    train_size = len(loaders['train'].dataset)
    val_size = len(loaders['val'].dataset)
    print(f"{hospital} → Train: {train_size} | Val: {val_size}")


In [None]:
# CELL 7: Fixed TwinSegNet Model (ViT + UNet Hybrid)

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UpBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels):
        super(UpBlock, self).__init__()
        self.up = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = ConvBlock(out_channels + skip_channels, out_channels)

    def forward(self, x, skip):
        x = self.up(x)
        # Resize skip connection if needed
        if x.shape[2:] != skip.shape[2:]:
            skip = F.interpolate(skip, size=x.shape[2:], mode='trilinear', align_corners=False)
        x = torch.cat((x, skip), dim=1)
        return self.conv(x)


class PatchEmbedViT(nn.Module):
    def __init__(self, in_channels=128, embed_dim=256, patch_size=2):
        super(PatchEmbedViT, self).__init__()
        self.patch_embed = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, embed_dim, 1, 1, 1))  # minimal shape

    def forward(self, x):
        x = self.patch_embed(x)
        if x.shape[2:] != self.pos_embed.shape[2:]:
            pos_embed = F.interpolate(self.pos_embed, size=x.shape[2:], mode='trilinear', align_corners=False)
        else:
            pos_embed = self.pos_embed
        return x + pos_embed


class TwinSegNet(nn.Module):
    def __init__(self, in_channels=4, n_classes=3, base_channels=32):
        super(TwinSegNet, self).__init__()

        # Encoder
        self.enc1 = ConvBlock(in_channels, base_channels)              # 128
        self.pool1 = nn.MaxPool3d(2)
        self.enc2 = ConvBlock(base_channels, base_channels * 2)        # 64
        self.pool2 = nn.MaxPool3d(2)
        self.enc3 = ConvBlock(base_channels * 2, base_channels * 4)    # 32
        self.pool3 = nn.MaxPool3d(2)
        self.enc4 = ConvBlock(base_channels * 4, base_channels * 8)    # 16
        self.pool4 = nn.MaxPool3d(2)

        # ViT at bottleneck
        self.vit = PatchEmbedViT(in_channels=base_channels * 8, embed_dim=base_channels * 16, patch_size=2)
        self.vit_proj = nn.Conv3d(base_channels * 16, base_channels * 8, kernel_size=1)

        # Decoder (UNet-style)
        self.up4 = UpBlock(base_channels * 8, base_channels * 4, base_channels * 4)  # match enc3
        self.up3 = UpBlock(base_channels * 4, base_channels * 2, base_channels * 2)  # match enc2
        self.up2 = UpBlock(base_channels * 2, base_channels, base_channels)          # match enc1
        self.final_conv = nn.Conv3d(base_channels, n_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)                # [B, 32, 128,128,128]
        e2 = self.enc2(self.pool1(e1))   # [B, 64, 64,64,64]
        e3 = self.enc3(self.pool2(e2))   # [B, 128, 32,32,32]
        e4 = self.enc4(self.pool3(e3))   # [B, 256, 16,16,16]
        b = self.pool4(e4)               # [B, 256, 8,8,8]

        b = self.vit(b)                  # [B, 512, 4,4,4] → [B, 256, 4,4,4]
        b = self.vit_proj(b)             # [B, 256, 4,4,4]

        d4 = self.up4(b, e3)             # [B, 128, 8,8,8]
        d3 = self.up3(d4, e2)            # [B, 64, 16,16,16]
        d2 = self.up2(d3, e1)            # [B, 32, 32,32,32]
        out = self.final_conv(d2)        # [B, n_classes, 32,32,32]

        return out


In [None]:
def dice_loss(pred, target, epsilon=1e-6):
    """
    pred: (B, C, D, H, W)
    target: (B, 1, D', H', W') → label mask
    """
    pred = torch.softmax(pred, dim=1)

    # Resize target to match pred spatial shape
    if target.shape[2:] != pred.shape[2:]:
        target = F.interpolate(target.float(), size=pred.shape[2:], mode='nearest')

    # One-hot encode
    target = target.squeeze(1).long()  # (B, D, H, W)
    target = F.one_hot(target, num_classes=pred.shape[1])  # (B, D, H, W, C)
    target = target.permute(0, 4, 1, 2, 3).float()          # (B, C, D, H, W)

    # Dice computation
    intersection = (pred * target).sum(dim=(2, 3, 4))
    union = pred.sum(dim=(2, 3, 4)) + target.sum(dim=(2, 3, 4))
    dice = (2 * intersection + epsilon) / (union + epsilon)

    return 1 - dice.mean()


def dice_coefficient(pred, target, epsilon=1e-6):
    """
    Computes Dice coefficient for evaluation (not differentiable).
    Assumes pred is logits and target is class index (not one-hot).
    """
    pred = torch.argmax(torch.softmax(pred, dim=1), dim=1)
    target = target.squeeze(1).long()

    dice_scores = []
    for class_id in range(1, pred.shape[1] if pred.ndim == 5 else 2):  # skip background
        pred_class = (pred == class_id).float()
        target_class = (target == class_id).float()

        intersection = (pred_class * target_class).sum()
        union = pred_class.sum() + target_class.sum()

        dice = (2 * intersection + epsilon) / (union + epsilon)
        dice_scores.append(dice.item())

    return np.mean(dice_scores)


In [None]:
from torch.cuda.amp import autocast, GradScaler
import gc

scaler = GradScaler()

def train_one_client(model, dataloader, optimizer, epochs=1, client_name="", device="cuda:0"):
    model.to(device)
    model.train()
    total_loss = 0

    scaler = GradScaler()
    for epoch in range(epochs):
        epoch_loss = 0
        print(f"\n🚀 [{client_name}] Epoch {epoch+1}/{epochs} — Training {len(dataloader)} batches")

        for i, (images, masks) in enumerate(dataloader):
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()

            with autocast():
                outputs = model(images)
                loss = dice_loss(outputs, masks)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            print(f"   📦 Batch {i+1}/{len(dataloader)} — Loss: {loss.item():.4f} — "
                  f"Mem: {torch.cuda.memory_allocated(device) / 1e6:.1f} MB")

            epoch_loss += loss.item()

        print(f"✅ [{client_name}] Epoch {epoch+1} Complete — Avg Loss: {epoch_loss / len(dataloader):.4f}")
        torch.cuda.empty_cache()
        gc.collect()

    return model.state_dict()



In [None]:
import gc
import torch

def clear_cuda():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    print(f"🧹 GPU memory cleared. Available: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")


In [None]:
!nvidia-smi
!fuser -v /dev/nvidia0  # see PIDs using the GPU


In [None]:
!kill -9 <3013>


In [None]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


In [None]:
client_devices = {
    "Hospital1": "cuda:0",
    "Hospital2": "cuda:1",
    # add more mappings as needed, alternating between 0 and 1
}


In [None]:
from torch import optim

client_name = "Hospital1"
device = client_devices[client_name]
model = TwinSegNet(in_channels=4, n_classes=4).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
train_loader = hospital_loaders[client_name]["train"]

new_weights = train_one_client(model, train_loader, optimizer, epochs=2, client_name=client_name, device=device)


In [None]:
import torch
torch.cuda.empty_cache()
torch.cuda.ipc_collect()


In [None]:

import gc
gc.collect()
torch.cuda.empty_cache()


In [None]:
torch.cuda.memory_summary()
