In [None]:
# Install required libraries
!pip install torch monai nilearn torchio scikit-learn numpy nibabel

import os

import nilearn.image as nilimage
import nilearn.plotting as nilplot
import numpy as np
import torch
import torchio
# ImageDataset is the successor of the removed monai.data.NiftiDataset.
from monai.data import ImageDataset
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
# NOTE: AddChannel has been removed from MONAI (EnsureChannelFirst replaces it)
# and CropOrPad is a TorchIO transform, not a MONAI one -- importing either from
# monai.transforms fails. ResizeWithPadOrCrop is the MONAI pad/crop equivalent,
# and Spacing (not Resample) resamples a volume to a target voxel size.
from monai.transforms import (
    AsDiscrete,
    Compose,
    EnsureChannelFirst,
    RandSpatialCrop,
    Resample,
    ResizeWithPadOrCrop,
    ScaleIntensity,
    Spacing,
    ToTensor,
)
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch.utils.data import DataLoader

# Define the path to your BRATS data directory and lesion segmentations
brats_data_dir = '/path/to/brats_data'
segmentations_dir = '/path/to/lesion_segmentations'

# Build matched (image, label) path lists, one pair per subject directory.
# Expected layout (adjust to your data):
#   <brats_data_dir>/<subject_id>/T1.nii.gz
#   <segmentations_dir>/<subject_id>_lesion.nii.gz
# Names are sorted because os.listdir() returns them in arbitrary,
# filesystem-dependent order, which would make the seeded split below
# non-reproducible. Non-directory entries (e.g. metadata CSVs) are ignored, and
# subjects missing either file are reported and skipped up front instead of
# crashing a DataLoader worker mid-epoch.
subject_files = []
label_files = []
skipped_subjects = []
for subject_id in sorted(os.listdir(brats_data_dir)):
    subject_dir = os.path.join(brats_data_dir, subject_id)
    if not os.path.isdir(subject_dir):
        continue
    image_path = os.path.join(subject_dir, 'T1.nii.gz')
    label_path = os.path.join(segmentations_dir, f'{subject_id}_lesion.nii.gz')
    if not (os.path.isfile(image_path) and os.path.isfile(label_path)):
        skipped_subjects.append(subject_id)
        continue
    subject_files.append(image_path)
    label_files.append(label_path)

if skipped_subjects:
    print(f"Skipped {len(skipped_subjects)} subject(s) with a missing image or label: {skipped_subjects}")
if not subject_files:
    raise RuntimeError(
        f"No (image, label) pairs found under {brats_data_dir!r} and {segmentations_dir!r}"
    )

# Split the data into training and validation sets (80/20, fixed seed)
train_subjects, val_subjects, train_labels, val_labels = train_test_split(
    subject_files, label_files, test_size=0.2, random_state=42
)

# Define MONAI transforms for preprocessing.
#
# Images and labels need different pipelines. Labels are resampled with
# nearest-neighbour interpolation, so no new label values are invented, and are
# binarized instead of intensity-scaled. Both pipelines must still apply the
# same geometric ops. MONAI's ImageDataset seeds `transform` and `seg_transform`
# with the same per-item seed, so the random crop selects the same region in
# image and label. This holds as long as both pipelines contain the same random
# transforms in the same order. RandSpatialCrop is the only random step here.


def _build_transforms(*, is_label, random_crop):
    """Return the preprocessing pipeline for one volume type.

    is_label: build the label pipeline (nearest-neighbour resampling, then
        binarization) instead of the image pipeline (intensity scaling to
        [0, 1], bilinear resampling).
    random_crop: append a random 96^3 patch crop (training only).
    """
    steps = [
        # Volumes are assumed to be single-channel 3D NIfTI (no channel axis);
        # add the channel axis exactly once.
        EnsureChannelFirst(channel_dim="no_channel"),
    ]
    if not is_label:
        steps.append(ScaleIntensity())
    steps += [
        # Resample to 1 mm isotropic voxels (reads the affine from the loaded image).
        Spacing(pixdim=(1.0, 1.0, 1.0), mode="nearest" if is_label else "bilinear"),
        ResizeWithPadOrCrop(spatial_size=(128, 128, 128)),
    ]
    if is_label:
        # Single-channel sigmoid model: any non-zero label voxel counts as lesion.
        steps.append(AsDiscrete(threshold=0.5))
    if random_crop:
        steps.append(RandSpatialCrop((96, 96, 96), random_size=False))
    # Emit plain tensors so torch's default collate works with num_workers > 0.
    steps.append(ToTensor(track_meta=False))
    return Compose(steps)


train_transforms = _build_transforms(is_label=False, random_crop=True)
train_seg_transforms = _build_transforms(is_label=True, random_crop=True)
val_transforms = _build_transforms(is_label=False, random_crop=False)
val_seg_transforms = _build_transforms(is_label=True, random_crop=False)

# Create datasets for training and validation. Labels go in `seg_files` with
# their own `seg_transform`; each item is an (image, label) tuple.
train_ds = ImageDataset(
    image_files=train_subjects,
    seg_files=train_labels,
    transform=train_transforms,
    seg_transform=train_seg_transforms,
)

val_ds = ImageDataset(
    image_files=val_subjects,
    seg_files=val_labels,
    transform=val_transforms,
    seg_transform=val_seg_transforms,
)

# Create data loaders
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=2, num_workers=4)

# Create a 3D U-Net with one input channel (T1) and one output channel of
# lesion logits. `spatial_dims` replaces the old `dimensions` argument, which
# MONAI deprecated and then removed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    # Four stride-2 downsamplings: keep input sizes divisible by 2**4 = 16
    # (the 96^3 and 128^3 crops are) so the skip connections line up.
    strides=(2, 2, 2, 2),
).to(device)

# Define loss function and optimizer. The model outputs raw logits, so
# DiceLoss applies the sigmoid itself.
loss_function = DiceLoss(sigmoid=True)
optimizer = Adam(model.parameters(), lr=1e-4)

# Create a Dice metric for validation. It expects binarized predictions.
# include_background=True because the single output channel is the lesion
# channel; excluding it would leave nothing to score.
dice_metric = DiceMetric(include_background=True, reduction='mean')

# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        # The dataset yields (image, label) tuples, not dicts.
        inputs, targets = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    epoch_loss /= max(num_batches, 1)

    # Validation on the held-out loader (not on the last training batch).
    model.eval()
    with torch.no_grad():
        for val_batch in val_loader:
            val_inputs, val_targets = val_batch[0].to(device), val_batch[1].to(device)
            val_outputs = model(val_inputs)
            # DiceMetric needs binary masks: threshold sigmoid probabilities at 0.5.
            val_preds = (torch.sigmoid(val_outputs) > 0.5).float()
            dice_metric(y_pred=val_preds, y=val_targets)
        # Reduce over every validation case, then clear the buffer so epochs
        # don't accumulate into each other.
        metric = dice_metric.aggregate().item()
        dice_metric.reset()
    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss: {epoch_loss:.4f} Dice: {metric:.4f}")

# Save the trained model
torch.save(model.state_dict(), 'unet_brats_model.pt')
