# Imports

In [2]:
import monai
from monai.handlers.utils import from_engine
from monai.networks.layers import Norm
from monai.data import CacheDataset, decollate_batch
import numpy as np
from monai.data import Dataset, DataLoader
from monai.transforms import (Transform,AsDiscrete,Activations, Activationsd, Compose, LoadImaged,
                              Transposed, ScaleIntensityd, RandAxisFlipd, RandRotated, RandAxisFlipd,
                              RandBiasFieldd, ScaleIntensityRangePercentilesd, RandAdjustContrastd,
                              RandHistogramShiftd, DivisiblePadd, Orientationd, RandGibbsNoised, Spacingd,
                              RandRicianNoised, AsChannelLastd, RandSpatialCropd,ToNumpyd,EnsureChannelFirstd,
                              RandSpatialCropSamplesd, RandCropByPosNegLabeld)
from monai.data.utils import pad_list_data_collate
import pandas as pd
import random as rd
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchsummary import summary
import nibabel as nib
from torch.optim.lr_scheduler import StepLR
import pytorch_lightning
import wandb

## Wandb login:

In [2]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33malijohnnaqvi6[0m ([33mali-john[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

## Data Loader

In [3]:
train_transforms = Compose(
    [
        
        LoadImaged(keys=["img"],image_only=True),
        EnsureChannelFirstd(keys=["img"]),
        ScaleIntensityd(keys=["img"],
            minv=0.0,
            maxv=1.0),
        RandRotated(keys=["img"],
            range_x=np.pi / 12,
            prob=0.5,
            keep_size=True,
            mode="nearest"),

    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["img"],image_only=True),
        EnsureChannelFirstd(keys=["img"]),
        ScaleIntensityd(keys=["img"],
            minv=0.0,
            maxv=1.0),
        RandRotated(keys=["img"],
            range_x=np.pi / 12,
            prob=0.5,
            keep_size=True,
            mode="nearest"),

    ]
)


def load_data(batch, root_dir):
   
    data = pd.read_csv(os.path.join(root_dir,'gloabl_ageT2.csv'))
    imgs_list = list(data['participant_id'])
    age_list = list(data['age'])
    length = len(imgs_list)
    print(f'Total images: {length}') 
    test = int(0.90*length)
    first = int(0.75*length)

    imgs_list_train = imgs_list[0:first]
    imgs_list_val = imgs_list[first:test]
    imgs_list_test = imgs_list[test:]
    age_labels_train = age_list[0:first]
    age_labels_val = age_list[first:test]
    age_labels_test = age_list[test:]

    print('train set', len(imgs_list_train), len(age_labels_train))
    print('val set', len(imgs_list_val), len(age_labels_val))

    filenames_train = [{"img": x, "age": y} for (x,y) in zip(imgs_list_train, age_labels_train)]
    ds_train = monai.data.Dataset(filenames_train, train_transforms)
    train_loader = DataLoader(ds_train, batch_size=batch, shuffle = True, num_workers=2, pin_memory=True, collate_fn=pad_list_data_collate)

    filenames_val = [{"img": x, "age": y} for (x, y) in zip(imgs_list_val, age_labels_val)]
    ds_val = monai.data.Dataset(filenames_val, val_transforms)
    val_loader = DataLoader(ds_val, batch_size=batch, shuffle=True, num_workers=1, pin_memory=True, collate_fn=pad_list_data_collate)

    filenames_test = [{"img":x, "age":y} for (x,y) in zip(imgs_list_test,age_labels_test)]
    ds_test = monai.data.Dataset(filenames_test,val_transforms)
    test_loader = DataLoader(ds_test,batch_size=batch,shuffle=True,num_workers=1,pin_memory=True,collate_fn=pad_list_data_collate)
    
    return ds_train, train_loader, ds_val, val_loader, ds_test,test_loader

## 3D U-net Model:

In [4]:

class build_seq_sfcn(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(1, 32, kernel_size=3, padding=1), #[2,32,256,256,192]
            nn.BatchNorm3d(32),
            nn.MaxPool3d((2, 2, 2)),#[2,32,128,128,96]
            nn.Dropout(0.2),
            nn.ReLU(),

            nn.Conv3d(32, 64, kernel_size=3, padding=1),#[2,64,128,128,96]
            nn.BatchNorm3d(64),
            nn.MaxPool3d((2, 2, 2)),#[2,64,64,64,48]
            nn.Dropout(0.2),
            nn.ReLU(),

            nn.Conv3d(64, 128, kernel_size=3, padding=1),#[2,128,64,64,48]
            nn.BatchNorm3d(128),
            nn.MaxPool3d((2, 2, 2)),#[2,128,32,32,24]
            nn.Dropout(0.2),
            nn.ReLU(),

            nn.Conv3d(128, 256, kernel_size=3, padding=1),#[2,256,32,32,24]
            nn.BatchNorm3d(256),
            nn.MaxPool3d((2, 2, 2)),#[2,256,16,16,12]
            nn.Dropout(0.2),
            nn.ReLU(),

            nn.Conv3d(256, 256, kernel_size=3, padding=1),#[2,256,16,16,12]
            nn.BatchNorm3d(256),
            nn.MaxPool3d((2, 2, 2)),#[2,256,8,8,6]
            nn.Dropout(0.2),
            nn.ReLU(),

            nn.Conv3d(256, 64, kernel_size=1, padding=1),#[2,64,8,8,6]
            nn.BatchNorm3d(64),
            nn.Dropout(0.2),
            nn.ReLU()
        )

        self.regressor= nn.Sequential(
            nn.Conv3d(64, 1, kernel_size=1, padding=1),#[2,1,8,8,6]
            nn.Flatten(),
            nn.Linear(1440, 1),
            nn.ReLU()
        )


    def forward(self, inputs):

        glob_age_output = self.features(inputs)
        glob_age_output = self.regressor(glob_age_output)


        return glob_age_output

## Train loop:

In [5]:
def train(train_loader, val_loader, model, optimizer, scheduler, max_epochs, root_dir):

    metrices = {}
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print('Device set to Cuda')
    else:
        device = torch.device('cpu')
    model.train()
    loss_object  = nn.MSELoss()
    best_val_loss = 0.0
    best_mae_score = 0.0
    for epoch in range(1,max_epochs +1):
        train_loss = 0.0
        val_loss = 0.0
    
        print("Epoch ", epoch)
        print("Train:", end ="")

        for step, batch in enumerate(train_loader):
            img, age = (batch["img"].cuda(), batch["age"].cuda())
            age = age.unsqueeze(1)
       
            optimizer.zero_grad()

            pred_glob_age = model(img)

            loss = loss_object(pred_glob_age.float(), age.float())
            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            print("=", end = "")

        train_loss = train_loss/(step+1)
        metrices["train_loss"] = train_loss
        

        print()
        print("Val:", end ="")
        with torch.no_grad():
                mae_loss=0.0
                for step, batch in enumerate(val_loader):
                    img, age = (batch["img"].cuda(), batch["age"].cuda())
                    age = age.unsqueeze(1)

                    pred_glob_age = model(img)


                    loss = loss_object(pred_glob_age.float(), age.float())
                    val_loss += loss.item()

                    print("=", end = "")
                print()
                val_loss = val_loss/(step+1)
                metrices["val_loss"] = val_loss


        print("Training epoch ", epoch, ", train loss:", train_loss, ", val loss:", val_loss, " | ", optimizer.param_groups[0]['lr'])
        wandb.log(metrices)
        if epoch == 1:
            best_val_loss = val_loss
        if val_loss < best_val_loss:
            print("Saving model")
            best_val_loss = val_loss
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }

            torch.save(state, 'PATH/TO/WHERE/TO/SAVE/MODEL' + f"model_{epoch}.pth")
        scheduler.step()
    return

## Training

In [None]:
if __name__ == "__main__":

    learning_rate = 0.001
    batch_size = 3
    epochs = 60
    root_dir = 'PATH/TO/DATASET/DIRECTORY'
    


    ds_train, train_loader, ds_val, val_loader = load_data( batch_size, root_dir)
    
    # Building our 3D UNET model
    model = build_seq_sfcn()
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print('device set to Cuda')

    else:
        print ("Cuda not found")
        device= torch.device('cpu')
    model = model.to(device)
    step = 20
    gamma=0.5
    optimizer = torch.optim.Adam(model.parameters(), learning_rate, weight_decay=1e-5, betas=(0.5, 0.999))
    scheduler = StepLR(optimizer, step_size=step, gamma=gamma)
    wandb.init(
    project="T2 Global age",
    config={
    "learning_rate": 0.001,
    "architecture": "SC-FN",
    "dataset": "CAMCAN",
    "epochs": 70,
    }
)
    
    print("Start of training...")
    train(train_loader, val_loader, model, optimizer, scheduler, epochs, root_dir)
    print("End of training...")

# Inference:

In [5]:
# load the best saved model
state = torch.load('/home/wamika/ml_project/models/global_age/T2/model_59.pth')
model = build_seq_sfcn()
model.load_state_dict(state['state_dict'])
device = torch.device("cuda")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5,betas=(0.5, 0.999))
optimizer.load_state_dict(state['optimizer'])

In [6]:
batch_size = 1
root_dir = '/home/wamika/ml_project/anat/'
ds_train, train_loader, ds_val, val_loader, ds_test, test_loader = load_data( batch_size, root_dir)
differences = []
i=0
with torch.no_grad():
    for step, batch in enumerate(test_loader):
                img, age = (batch["img"].cuda(), batch["age"].cuda())
                age = age.cpu().numpy()
                pred_glob_age = model(img)
                pred_glob_age = pred_glob_age.cpu().numpy().squeeze(0)
                diff = abs(pred_glob_age-age)
                differences.append(diff)
                i=i+1
                if i==50:
                    break

Total images: 653
train set 489 489
val set 98 98


In [9]:
print(f'Test Set MAE: {sum(differences)/len(differences)}')

Test Set MAE: [7.55813675]
