In [2]:
import os
import csv
import random
import pickle
import numpy as np
import nibabel as nib
from torch.utils.data import Dataset, DataLoader

# Updated dataset CSV path
csv_path = "../data/training_detailed_summary_2020.csv"
save_dir = "data/train_dev_all/"
os.makedirs(save_dir, exist_ok=True)

# Parse the new dataset CSV
data_records = {}
with open(csv_path, 'r') as f:
    reader = csv.DictReader(f)
    for row in reader:
        subj_id = row['Subject ID']
        scan_type = row['Scan Type']
        file_path = row['File Path']
        if subj_id not in data_records:
            data_records[subj_id] = {}
        data_records[subj_id][scan_type] = file_path

# Split data into train and validation sets
subject_ids = list(data_records.keys())
random.seed(42)
random.shuffle(subject_ids)

train_split = int(len(subject_ids) * 0.8)
train_ids = subject_ids[:train_split]
val_ids = subject_ids[train_split:]

# Calculate mean and std incrementally to avoid memory issues
data_types = ['flair', 't1', 't1ce', 't2']
mean_std_dict = {dtype: {'mean': 0.0, 'std': 1.0} for dtype in data_types}

for dtype in data_types:
    sum_vals = 0
    sum_squared_vals = 0
    num_voxels = 0

    for subj_id in train_ids:
        if dtype in data_records[subj_id]:
            img_path = data_records[subj_id][dtype]
            img = nib.load(img_path).get_fdata(dtype=np.float32)
            sum_vals += np.sum(img)
            sum_squared_vals += np.sum(img**2)
            num_voxels += img.size

    mean = sum_vals / num_voxels
    variance = (sum_squared_vals / num_voxels) - (mean**2)
    std = np.sqrt(variance)

    mean_std_dict[dtype]['mean'] = mean
    mean_std_dict[dtype]['std'] = std

# Save the mean and std dictionary
with open(os.path.join(save_dir, 'mean_std_dict.pickle'), 'wb') as f:
    pickle.dump(mean_std_dict, f)

print(f"Data preparation complete. Total subjects: {len(subject_ids)}, Train: {len(train_ids)}, Validation: {len(val_ids)}")


Data preparation complete. Total subjects: 369, Train: 295, Validation: 74


In [9]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = self.conv_block(n_channels, 64)
        self.down1 = self.down(64, 128)
        self.down2 = self.down(128, 256)
        self.down3 = self.down(256, 512)
        self.down4 = self.down(512, 1024)
        self.up1 = self.up(1024, 512)
        self.up2 = self.up(512, 256)
        self.up3 = self.up(256, 128)
        self.up4 = self.up(128, 64)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1['upsample'](x5)
        x = torch.cat([x, x4], dim=1)
        x = self.up1['conv'](x)

        x = self.up2['upsample'](x)
        x = torch.cat([x, x3], dim=1)
        x = self.up2['conv'](x)

        x = self.up3['upsample'](x)
        x = torch.cat([x, x2], dim=1)
        x = self.up3['conv'](x)

        x = self.up4['upsample'](x)
        x = torch.cat([x, x1], dim=1)
        x = self.up4['conv'](x)

        logits = self.outc(x)
        return logits

    @staticmethod
    def conv_block(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def down(in_channels, out_channels):
        return nn.Sequential(
            nn.MaxPool2d(2),
            UNet.conv_block(in_channels, out_channels)
        )

    @staticmethod
    def up(in_channels, out_channels):
        return nn.ModuleDict({
            'upsample': nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            'conv': UNet.conv_block(out_channels * 2, out_channels)
        })


In [10]:

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import time

# Updated dataset loader
class BrainSegDataset(Dataset):
    def __init__(self, data_records, mean_std_dict, subject_ids, is_train=True):
        self.data_records = data_records
        self.mean_std_dict = mean_std_dict
        self.subject_ids = subject_ids
        self.is_train = is_train
        self.slices = self._prepare_slices()

    def _prepare_slices(self):
        slices = []
        for subj_id in self.subject_ids:
            if 'seg' not in self.data_records[subj_id]:
                print(f"Warning: Segmentation file missing for subject {subj_id}. Skipping this subject.")
                continue  # Skip subjects without segmentation files
            seg_path = self.data_records[subj_id]['seg']
            seg_img = nib.load(seg_path).get_fdata()
            num_slices = seg_img.shape[-1]
            for i in range(num_slices):
                slices.append((subj_id, i))
        return slices


    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
        subj_id, slice_idx = self.slices[idx]
        images = []
        for dtype in ['flair', 't1', 't1ce', 't2']:
            img_path = self.data_records[subj_id][dtype]
            img = nib.load(img_path).get_fdata()
            img_slice = img[:, :, slice_idx]
            img_slice = (img_slice - self.mean_std_dict[dtype]['mean']) / self.mean_std_dict[dtype]['std']
            images.append(img_slice)
        images = np.stack(images, axis=0)
        images = torch.tensor(images, dtype=torch.float32)

        seg_path = self.data_records[subj_id]['seg']
        seg_img = nib.load(seg_path).get_fdata()
        seg_slice = seg_img[:, :, slice_idx]
        seg_slice = (seg_slice > 0).astype(np.float32)  # Normalize target to binary (0 or 1)
        seg_slice = torch.tensor(seg_slice, dtype=torch.float32).unsqueeze(0)


        return {'image': images, 'target': seg_slice}


# Parameters
batch_size = 10
lr = 0.0001
epochs = 50

# Prepare datasets and loaders
train_dataset = BrainSegDataset(data_records, mean_std_dict, train_ids, is_train=True)
val_dataset = BrainSegDataset(data_records, mean_std_dict, val_ids, is_train=False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Model, loss, optimizer
model = UNet(n_channels=4, n_classes=1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop
for epoch in range(epochs):
    model.train()
    for batch in train_loader:
        images, targets = batch['image'].to(device), batch['target'].to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()
        total_loss += loss.item()
        if batch_idx % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}")
    
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{epochs}] completed in {time.time() - start_time:.2f}s with Avg Loss: {avg_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            images, targets = batch['image'].to(device), batch['target'].to(device)
            outputs = model(images)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
    print(f"Validation Loss after Epoch [{epoch+1}/{epochs}]: {val_loss / len(val_loader):.4f}")

# Save the model
torch.save(model.state_dict(), "checkpoint/unet_brain_seg.pth")
print("Training complete. Model saved to checkpoint/unet_brain_seg.pth")


Epoch [1/50], Batch [0/4557], Loss: 0.7133
Epoch [1/50], Batch [0/4557], Loss: 0.7105
Epoch [1/50], Batch [0/4557], Loss: 0.7069
Epoch [1/50], Batch [0/4557], Loss: 0.7030
Epoch [1/50], Batch [0/4557], Loss: 0.7002
Epoch [1/50], Batch [0/4557], Loss: 0.6975
Epoch [1/50], Batch [0/4557], Loss: 0.6935
Epoch [1/50], Batch [0/4557], Loss: 0.6896
Epoch [1/50], Batch [0/4557], Loss: 0.6856
Epoch [1/50], Batch [0/4557], Loss: 0.6791
Epoch [1/50], Batch [0/4557], Loss: 0.6762
Epoch [1/50], Batch [0/4557], Loss: 0.6709
Epoch [1/50], Batch [0/4557], Loss: 0.6659
Epoch [1/50], Batch [0/4557], Loss: 0.6509
Epoch [1/50], Batch [0/4557], Loss: 0.6393
Epoch [1/50], Batch [0/4557], Loss: 0.6321
Epoch [1/50], Batch [0/4557], Loss: 0.6259
Epoch [1/50], Batch [0/4557], Loss: 0.6061
Epoch [1/50], Batch [0/4557], Loss: 0.5851
Epoch [1/50], Batch [0/4557], Loss: 0.5426
Epoch [1/50], Batch [0/4557], Loss: 0.4930
Epoch [1/50], Batch [0/4557], Loss: 0.4200
Epoch [1/50], Batch [0/4557], Loss: 0.2079
Epoch [1/50

KeyboardInterrupt: 