In [37]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from scipy.ndimage import zoom
import nibabel as nib
import time

class BrainMRIDataset(Dataset):
    """Map-style dataset yielding ``(volume, age)`` pairs from NIfTI files.

    Each file is loaded with nibabel, resampled to ``target_shape`` with linear
    interpolation, optionally flipped at random, z-score normalised per volume
    and returned as a float32 tensor with a leading channel axis.

    Args:
        files: Sequence of paths to NIfTI files. The subject ID is parsed from
            the basename: the text before the first ``'-'`` with the ``'IXI'``
            prefix removed (``IXI012-...`` -> ``12``).
        labels_dict: Mapping from integer subject ID to age.
        target_shape: Desired size along each image axis after resampling.
        augment: When true, flip the volume along axis 0 and along axis 1,
            each independently with probability 0.5.

    Raises (from ``__getitem__``):
        KeyError: The file's subject ID is not present in ``labels_dict``.
        ValueError: The subject's age in ``labels_dict`` is missing (NaN/None).
    """

    def __init__(self, files, labels_dict, target_shape=(96, 96, 96), augment=False):
        self.files = files
        self.labels = labels_dict
        self.target_shape = target_shape
        self.augment = augment

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.files)

    @staticmethod
    def _subject_id(path):
        """Return the integer subject ID encoded in ``path``'s basename."""
        return int(os.path.basename(path).split('-')[0].replace('IXI', ''))

    def _age_for(self, path):
        """Look up the age for ``path``, refusing to invent a label."""
        subj_id = self._subject_id(path)
        # A silent default (e.g. age 0) would be trained on as if it were real,
        # so a missing or NaN label is reported instead of papered over.
        if subj_id not in self.labels:
            raise KeyError(f"No age label for subject {subj_id} ({path})")
        age = self.labels[subj_id]
        if pd.isna(age):
            raise ValueError(f"Age label for subject {subj_id} ({path}) is missing")
        return age

    def __getitem__(self, idx):
        """Load, resample, augment and normalise the volume at ``idx``.

        Returns:
            A ``(img, age)`` tuple: ``img`` is a float32 tensor with a leading
            channel axis of size 1, ``age`` is a float32 scalar tensor.
        """
        path = self.files[idx]
        # Resolve the label first so a bad entry fails before the expensive load.
        age = self._age_for(path)

        img = nib.load(path).get_fdata()
        # Per-axis zoom factors that bring the volume to target_shape.
        scale = [t / s for t, s in zip(self.target_shape, img.shape)]
        img = zoom(img, scale, order=1)

        if self.augment:
            # np.flip returns a view with negative strides; copy so that
            # torch.tensor receives a contiguous array.
            if np.random.rand() > 0.5:
                img = np.flip(img, axis=0).copy()
            if np.random.rand() > 0.5:
                img = np.flip(img, axis=1).copy()

        # Per-volume z-score; the epsilon guards against a constant image.
        img = (img - img.mean()) / (img.std() + 1e-8)
        img = torch.tensor(img, dtype=torch.float32).unsqueeze(0)

        return img, torch.tensor(age, dtype=torch.float32)

In [38]:
# Directory holding both the scans and the demographics spreadsheet.
data_dir = "IXI-T1_data/"

# Subject ID -> age lookup built from the IXI_ID and AGE spreadsheet columns.
df = pd.read_excel(os.path.join(data_dir, "IXI.xls"))
labels = dict(zip(df['IXI_ID'], df['AGE']))

# os.listdir returns entries in arbitrary order; sort so that the seeded
# train/test split downstream selects the same files on every machine.
all_files = sorted(
    os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.nii.gz')
)

def get_subj_id(filepath):
    """Return the integer subject ID encoded in a scan's file name.

    The ID is the part of the basename before the first '-', with the
    'IXI' prefix removed (e.g. 'IXI012-HH-1211-T1.nii.gz' -> 12).
    """
    filename = os.path.basename(filepath)
    id_token = filename.split('-')[0]
    return int(id_token.replace('IXI', ''))

# Keep only scans whose subject has a non-missing age in the label table
# (labels.get returns None for unknown subjects, which pd.notna rejects).
valid_files = [f for f in all_files if pd.notna(labels.get(get_subj_id(f)))]
print(f"Valid files with age data: {len(valid_files)} / {len(all_files)}")

# Seeded 80/20 split of the labelled scans.
train_files, test_files = train_test_split(valid_files, test_size=0.2, random_state=42)

# Random flips are applied to the training set only.
train_dataset = BrainMRIDataset(files=train_files, labels_dict=labels, augment=True)
test_dataset = BrainMRIDataset(files=test_files, labels_dict=labels, augment=False)

train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=4, shuffle=False)

print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")

Valid files with age data: 563 / 581
Train: 450, Test: 113


In [39]:
class BrainAgeCNN(torch.nn.Module):
    """Small 3D CNN that regresses one scalar per single-channel volume.

    ``features`` is three Conv3d -> BatchNorm3d -> ReLU stages (each conv has
    stride 2; the first two stages are followed by 2x max pooling) ending in a
    global average pool, which leaves 128 features per sample. ``regressor``
    maps those through a 64-unit hidden layer with dropout to one output.

    ``forward`` takes a batch shaped ``(N, 1, D, H, W)`` and returns a tensor
    of shape ``(N,)``.
    """

    def __init__(self):
        super().__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv3d(1, 32, kernel_size=3, stride=2, padding=1),
            torch.nn.BatchNorm3d(32),
            torch.nn.ReLU(),
            torch.nn.MaxPool3d(2),

            torch.nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1),
            torch.nn.BatchNorm3d(64),
            torch.nn.ReLU(),
            torch.nn.MaxPool3d(2),

            torch.nn.Conv3d(64, 128, kernel_size=3, stride=2, padding=1),
            torch.nn.BatchNorm3d(128),
            torch.nn.ReLU(),
            # Global pooling makes the head independent of the input size.
            torch.nn.AdaptiveAvgPool3d(1)
        )
        self.regressor = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(64, 1)
        )
        self._init_weights()

    def _init_weights(self):
        """Apply Kaiming-normal init to conv/linear weights and reset BatchNorm.

        Linear biases are zeroed; Conv3d biases keep PyTorch's default init.
        """
        for m in self.modules():
            if isinstance(m, torch.nn.Conv3d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
            elif isinstance(m, torch.nn.BatchNorm3d):
                torch.nn.init.ones_(m.weight)
                torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return one prediction per sample, shaped ``(N,)``."""
        x = self.features(x)
        x = self.regressor(x)
        # Drop only the trailing size-1 feature axis. A bare squeeze() would
        # also remove the batch axis for a single-sample batch and return a
        # 0-d tensor that no longer lines up with a (1,)-shaped target.
        return x.squeeze(-1)

# Instantiate the network once just to report its total parameter count.
model = BrainAgeCNN()
n_params = sum(param.numel() for param in model.parameters())
print(f"Parameters: {n_params:,}")

Parameters: 286,337


In [40]:
device = 'cpu'
model = BrainAgeCNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# Halve the learning rate after 5 epochs without improvement in the monitored loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
criterion = torch.nn.MSELoss()

print(f"Training on {device}")
n_epochs = 50
best_loss = float('inf')

for epoch in range(n_epochs):
    t1 = time.time()
    model.train()
    total_loss = 0.0
    n_samples = 0
    for img, age in train_loader:
        img, age = img.to(device), age.to(device)
        optimizer.zero_grad()
        # Flatten both sides so prediction and target are always 1-D with equal
        # length, even for a single-sample batch (MSELoss would otherwise broadcast).
        pred = model(img).reshape(-1)
        age = age.reshape(-1)
        loss = criterion(pred, age)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by the batch size so a
        # short final batch does not count as much as a full one.
        batch_n = age.numel()
        total_loss += loss.item() * batch_n
        n_samples += batch_n
    avg_loss = total_loss / n_samples

    # NOTE(review): the scheduler and the checkpoint below both track the
    # *training* loss; no validation loader is defined in this notebook, so
    # 'best_model.pth' is not selected on held-out data.
    scheduler.step(avg_loss)

    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), 'best_model.pth')

    t2 = time.time()
    lr = optimizer.param_groups[0]['lr']
    # sqrt(MSE) is the root-mean-square error, not the mean absolute error.
    print(f"Epoch {epoch+1}/{n_epochs}: Loss = {avg_loss:.2f}, RMSE ~ {avg_loss**0.5:.1f}y, LR = {lr:.1e}, Time = {t2-t1:.1f}s")

TypeError: ReduceLROnPlateau.__init__() got an unexpected keyword argument 'verbose'

In [36]:
# Inference mode: dropout is disabled and batch norm uses its running statistics.
model.eval()
preds, actuals = [], []
with torch.no_grad():  # gradients are not needed for evaluation
    for img, age in test_loader:
        pred = model(img.to(device))
        preds.extend(pred.cpu().numpy().flatten())
        actuals.extend(age.numpy().flatten())

# Mean absolute difference between predicted and true ages over the test set.
abs_errors = np.abs(np.asarray(preds) - np.asarray(actuals))
mae = abs_errors.mean()
print(f"Test MAE: {mae:.2f} years")

Test MAE: 23.00 years
