# Imports

In [1]:
import monai
from monai.handlers.utils import from_engine
from monai.networks.layers import Norm
from monai.data import CacheDataset, decollate_batch
import numpy as np
from monai.data import Dataset, DataLoader
from monai.transforms import (Transform,AsDiscrete,Activations, Activationsd, Compose, LoadImaged,
                              Transposed, ScaleIntensityd, RandAxisFlipd, RandRotated, RandAxisFlipd,
                              RandBiasFieldd, ScaleIntensityRangePercentilesd, RandAdjustContrastd,
                              RandHistogramShiftd, DivisiblePadd, Orientationd, RandGibbsNoised, Spacingd,
                              RandRicianNoised, AsChannelLastd, RandSpatialCropd,ToNumpyd,EnsureChannelFirstd,
                              RandSpatialCropSamplesd, RandCropByPosNegLabeld)
from monai.data.utils import pad_list_data_collate
import pandas as pd
import random as rd
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchsummary import summary
import nibabel as nib
from torch.optim.lr_scheduler import StepLR
import pytorch_lightning
import wandb

## Wandb login:

In [2]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33malijohnnaqvi6[0m ([33mali-john[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

## Data Loader

In [3]:
train_transforms = Compose(
    [
        
        LoadImaged(keys=["img"],image_only=True),
        EnsureChannelFirstd(keys=["img"]),
        ScaleIntensityd(keys=["img"],
            minv=0.0,
            maxv=1.0),
        RandRotated(keys=["img"],
            range_x=np.pi / 12,
            prob=0.5,
            keep_size=True,
            mode="nearest"),

    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["img"],image_only=True),
        EnsureChannelFirstd(keys=["img"]),
        ScaleIntensityd(keys=["img"],
            minv=0.0,
            maxv=1.0),
        RandRotated(keys=["img"],
            range_x=np.pi / 12,
            prob=0.5,
            keep_size=True,
            mode="nearest"),

    ]
)


def load_data(batch, root_dir):
   
    data = pd.read_csv(os.path.join(root_dir,'participants_csv.csv'))
    imgs_list = list(data['participant_id'])
    age_labels = list(data['age'])
    print(imgs_list[0])
    

    length = len(imgs_list)
    print(f'Total images: {length}')
    # why we are using 85 of dataset?? we can split 100% of dataset
    test = int(0.85*length)

    imgs_list = imgs_list[:test]
    age_labels = age_labels[:test]

    first = int(0.75*length)

    imgs_list_train = imgs_list[0:first]
    imgs_list_val = imgs_list[first:]
    age_labels_train = age_labels[0:first]
    age_labels_val = age_labels[first:]

    print('train set', len(imgs_list_train), len(age_labels_train))
    print('val set', len(imgs_list_val), len(age_labels_val))

    filenames_train = [{"img": x, "age": y} for (x,y) in zip(imgs_list_train, age_labels_train)]


    ds_train = monai.data.Dataset(filenames_train, train_transforms)
    print('ds train type', type(ds_train))
    train_loader = DataLoader(ds_train, batch_size=batch, shuffle = True, num_workers=2, pin_memory=True, collate_fn=pad_list_data_collate)

    filenames_val = [{"img": x, "age": y} for (x, y) in zip(imgs_list_val, age_labels_val)]

    ds_val = monai.data.Dataset(filenames_val, val_transforms)
    print('ds val type', type(ds_val))
    val_loader = DataLoader(ds_val, batch_size=batch, shuffle=True, num_workers=1, pin_memory=True, collate_fn=pad_list_data_collate)
    
    return ds_train, train_loader, ds_val, val_loader

## 3D U-net Model:

In [4]:

class build_seq_sfcn(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(1, 32, kernel_size=3, padding=1), #[2,32,256,256,192]
            nn.BatchNorm3d(32),
            nn.MaxPool3d((2, 2, 2)),#[2,32,128,128,96]
            nn.ReLU(),

            nn.Conv3d(32, 64, kernel_size=3, padding=1),#[2,64,128,128,96]
            nn.BatchNorm3d(64),
            nn.MaxPool3d((2, 2, 2)),#[2,64,64,64,48]
            nn.ReLU(),

            nn.Conv3d(64, 128, kernel_size=3, padding=1),#[2,128,64,64,48]
            nn.BatchNorm3d(128),
            nn.MaxPool3d((2, 2, 2)),#[2,128,32,32,24]
            nn.ReLU(),

            nn.Conv3d(128, 256, kernel_size=3, padding=1),#[2,256,32,32,24]
            nn.BatchNorm3d(256),
            nn.MaxPool3d((2, 2, 2)),#[2,256,16,16,12]
            nn.ReLU(),

            nn.Conv3d(256, 256, kernel_size=3, padding=1),#[2,256,16,16,12]
            nn.BatchNorm3d(256),
            nn.MaxPool3d((2, 2, 2)),#[2,256,8,8,6]
            nn.ReLU(),

            nn.Conv3d(256, 64, kernel_size=1, padding=1),#[2,64,8,8,6]
            nn.BatchNorm3d(64),
            nn.ReLU()
        )

        self.regressor= nn.Sequential(
            nn.Conv3d(64, 1, kernel_size=1, padding=1),#[2,1,8,8,6]
            nn.Flatten(),
            nn.Linear(1440, 1),
            nn.ReLU()
        )


    def forward(self, inputs):

        glob_age_output = self.features(inputs)
        glob_age_output = self.regressor(glob_age_output)


        return glob_age_output

## Train loop:

In [5]:
def train(train_loader, val_loader, model, optimizer, scheduler, max_epochs, root_dir):

    metrices = {}
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print('Device set to Cuda')
    else:
        device = torch.device('cpu')
    model.train()
    loss_object  = nn.MSELoss()
    best_val_loss = 0.0
    best_mae_score = 0.0
    for epoch in range(1,max_epochs +1):
        train_loss = 0.0
        val_loss = 0.0
    
        print("Epoch ", epoch)
        print("Train:", end ="")

        for step, batch in enumerate(train_loader):
            img, age = (batch["img"].cuda(), batch["age"].cuda())
            age = age.unsqueeze(1)
       
            optimizer.zero_grad()

            pred_glob_age = model(img)

            loss = loss_object(pred_glob_age.float(), age.float())
            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            print("=", end = "")

        train_loss = train_loss/(step+1)
        metrices["train_loss"] = train_loss
        

        print()
        print("Val:", end ="")
        with torch.no_grad():
                mae_loss=0.0
                for step, batch in enumerate(val_loader):
                    img, age = (batch["img"].cuda(), batch["age"].cuda())
                    age = age.unsqueeze(1)

                    pred_glob_age = model(img)


                    loss = loss_object(pred_glob_age.float(), age.float())
                    val_loss += loss.item()

                    print("=", end = "")
                print()
                val_loss = val_loss/(step+1)
                metrices["val_loss"] = val_loss


        print("Training epoch ", epoch, ", train loss:", train_loss, ", val loss:", val_loss, " | ", optimizer.param_groups[0]['lr'])
        wandb.log(metrices)
        if epoch == 1:
            best_val_loss = val_loss
        if val_loss < best_val_loss:
            print("Saving model")
            best_val_loss = val_loss
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }

            torch.save(state, '/home/wamika/ml_project/models/global_age/' + f"model_{epoch}.pth")
        scheduler.step()
    return

## Training

In [None]:
if __name__ == "__main__":

    learning_rate = 0.001
    batch_size = 3
    epochs = 100
    root_dir = '/home/wamika/ml_project/anat/'
    


    ds_train, train_loader, ds_val, val_loader = load_data( batch_size, root_dir)
    
    # Building our 3D UNET model
    model = build_seq_sfcn()
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print('device set to Cuda')

    else:
        print ("Cuda not found")
        device= torch.device('cpu')
    model = model.to(device)
    step = 20
    gamma=0.5
    optimizer = torch.optim.Adam(model.parameters(), learning_rate, weight_decay=1e-5, betas=(0.5, 0.999))
    scheduler = StepLR(optimizer, step_size=step, gamma=gamma)
    wandb.init(
    project="Global age prediction",
    config={
    "learning_rate": 0.001,
    "architecture": "SC-FN",
    "dataset": "CAMCAN",
    "epochs": 100,
    }
)
    
    print("Start of training...")
    train(train_loader, val_loader, model, optimizer, scheduler, epochs, root_dir)
    print("End of training...")

/home/wamika/ml_project/anat/sub-CC110033/anat/sub-CC110033_T1w.nii.gz
Total images: 653
train set 489 489
val set 66 66
ds train type <class 'monai.data.dataset.Dataset'>
ds val type <class 'monai.data.dataset.Dataset'>
device set to Cuda


Start of training...
Device set to Cuda
Epoch  1
Training epoch  1 , train loss: 356.1148931526699 , val loss: 369.57370549982244  |  0.001
Epoch  2
Training epoch  2 , train loss: 196.94310204090516 , val loss: 729.8856534090909  |  0.001
Epoch  3
Training epoch  3 , train loss: 183.28752690882772 , val loss: 911.4110107421875  |  0.001
Epoch  4
Training epoch  4 , train loss: 177.65944904491215 , val loss: 644.3483567671342  |  0.001
Epoch  5
Training epoch  5 , train loss: 162.753443238194 , val loss: 389.7922099720348  |  0.001
Epoch  6
Training epoch  6 , train loss: 159.37833428821682 , val loss: 537.5222695090554  |  0.001
Epoch  7
Training epoch  7 , train loss: 146.82957220370054 , val loss: 297.421065937389  |  0.001
Saving model
Epoch  8
Training epoch  8 , train loss: 136.18092932145288 , val loss: 328.9022085016424  |  0.001
Epoch  9
Training epoch  9 , train loss: 124.74123104072056 , val loss: 301.538854078813  |  0.001
Epoch  10
Training epoch  10 , train loss: 117.0723

In [None]:
file_path = "/home/wamika/ml_project/train.txt"
with open(file_path, "w") as file:
    for item in train_losses:
        file.write(str(item) + "\n")


In [None]:
p = "/home/wamika/ml_project/val.txt"
with open(p, "w") as f:
    for item in val_losses:
        f.write(str(item) + "\n")