In [1]:
import pre_proc_functions as proc
import os
import train_functions as tr
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
from tqdm.notebook import tqdm
from PIL import Image
from torchvision import transforms
import cv2
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import torch
import numpy as np
import segmentation_models_pytorch as smp
import albumentations as albu
from sklearn.metrics import jaccard_score, f1_score
from segmentation_models_pytorch.losses import JaccardLoss
import pandas as pd

Defining Raw Data Paths

In [None]:
# Raw (pre-split) data locations. In this notebook img_dir / msk_dir are referenced only
# by the commented-out fslswapdim cell; the split cell passes its own relative paths.
RAW_DATA_DIR = "/home/zimu/Desktop/Code/Skull_Stripe_dataset/zhao_data"
img_dir = f"{RAW_DATA_DIR}/img"   # image volumes — must not be the mask folder
msk_dir = f"{RAW_DATA_DIR}/mask"  # ground-truth brain masks
images_output_dir = "./split"     # NOTE(review): unused below; the split cell writes to 'data_split'
img_fixed = "_mc_restore"         # presumably the image-filename suffix; the split cell passes the same literal

In [None]:
# Uncomment if the images need fslswapdim applied (outputs are written to '<dir>_sw')
#proc.apply_fslswapdim_to_folder(img_dir, img_dir + '_sw')
#proc.apply_fslswapdim_to_folder(msk_dir, msk_dir + '_sw')

In [None]:
# Split the raw images/masks into train/valid sets under data_split/ (half held out for
# validation). The per-view folders used below are presumably produced here — confirm
# against pre_proc_functions.proc_img_masks.
proc.proc_img_masks(
    'data/img',
    'data/mask',
    out_dir='data_split',
    img_fixed="_mc_restore",
    test_size=0.5,
)

Defining Train/Validation Split Paths

In [None]:

# All split folders live under one root so the image and mask paths cannot drift apart.
SPLIT_DIR = 'data_split'

x_train_dir_sag = f'{SPLIT_DIR}/sag/train_img'
y_train_dir_sag = f'{SPLIT_DIR}/sag/train_masks'

x_valid_dir_sag = f'{SPLIT_DIR}/sag/valid_img'
y_valid_dir_sag = f'{SPLIT_DIR}/sag/valid_masks'

x_train_dir_cor = f'{SPLIT_DIR}/cor/train_img'
y_train_dir_cor = f'{SPLIT_DIR}/cor/train_masks'  # previously pointed at 'zhao_data_split/...'

x_valid_dir_cor = f'{SPLIT_DIR}/cor/valid_img'
y_valid_dir_cor = f'{SPLIT_DIR}/cor/valid_masks'

x_train_dir_ax = f'{SPLIT_DIR}/ax/train_img'
y_train_dir_ax = f'{SPLIT_DIR}/ax/train_masks'

x_valid_dir_ax = f'{SPLIT_DIR}/ax/valid_img'
y_valid_dir_ax = f'{SPLIT_DIR}/ax/valid_masks'

In [None]:
views = ['sag', 'cor', 'ax']
CLASSES = ['brain']
DEVICE = 'cuda'
ENCODER = 'efficientnet-b3'
ENCODER_WEIGHTS = 'imagenet'
PRETRAINED_DIR = './model_checkpoints'
models = {'Unet': {view: {} for view in views}}
preprocessing_fn = {}
lr = 0.0001

# The input normalisation depends only on the encoder, so build it once for every view.
encoder_preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

for view in views:
    # Load the pickled Unet for this view, then restore its weights and a fresh Adam
    # optimizer's state from the matching "_state" checkpoint.
    # weights_only=False: unpickling a whole nn.Module is rejected by torch.load's
    # weights_only=True default (PyTorch >= 2.6). Unpickling can execute arbitrary
    # code, so only load checkpoints you trust.
    model = torch.load(
        f'{PRETRAINED_DIR}/Unet_{ENCODER}_{view}.pth',
        map_location=DEVICE,
        weights_only=False,
    )
    optimizer = optim.Adam(model.parameters(), lr=lr)
    checkpoint = torch.load(f'{PRETRAINED_DIR}/Unet_{ENCODER}_{view}_state.pth', map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    models['Unet'][view]['model'] = model
    models['Unet'][view]['optimizer'] = optimizer
    # NOTE(review): the checkpoint's epoch is read into start_epoch but 0 is stored, so
    # epoch numbering restarts; confirm this is intended rather than resuming.
    models['Unet'][view]['start_epoch'] = 0
    models['Unet'][view]['loss'] = loss
    preprocessing_fn[view] = encoder_preprocessing_fn

In [None]:
batchsize = 16
NUM_WORKERS = 12


def _build_view_loaders(view, x_train_dir, y_train_dir, x_valid_dir, y_valid_dir,
                        batch_size=batchsize, num_workers=NUM_WORKERS):
    """Create the train/valid datasets and DataLoaders for one view ('sag', 'cor' or 'ax').

    Both datasets use ``preprocessing_fn[view]`` and no augmentation; only the training
    loader shuffles.

    Returns:
        (train_dataset, valid_dataset, train_loader, valid_loader)
    """
    train_dataset = tr.Dataset(
        x_train_dir,
        y_train_dir,
        augmentation=None,  # get_training_augmentation() is currently disabled
        preprocessing=tr.get_preprocessing(preprocessing_fn[view]),
        classes=CLASSES,
    )
    valid_dataset = tr.Dataset(
        x_valid_dir,
        y_valid_dir,
        augmentation=None,
        preprocessing=tr.get_preprocessing(preprocessing_fn[view]),
        classes=CLASSES,
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_dataset, valid_dataset, train_loader, valid_loader


train_dataset_sag, valid_dataset_sag, train_loader_sag, valid_loader_sag = _build_view_loaders(
    'sag', x_train_dir_sag, y_train_dir_sag, x_valid_dir_sag, y_valid_dir_sag)
train_dataset_cor, valid_dataset_cor, train_loader_cor, valid_loader_cor = _build_view_loaders(
    'cor', x_train_dir_cor, y_train_dir_cor, x_valid_dir_cor, y_valid_dir_cor)
train_dataset_ax, valid_dataset_ax, train_loader_ax, valid_loader_ax = _build_view_loaders(
    'ax', x_train_dir_ax, y_train_dir_ax, x_valid_dir_ax, y_valid_dir_ax)

# Per-view loaders consumed by the sanity-check and training cells.
train_loaders = {'sag': train_loader_sag, 'cor': train_loader_cor, 'ax': train_loader_ax}
valid_loaders = {'sag': valid_loader_sag, 'cor': valid_loader_cor, 'ax': valid_loader_ax}

In [None]:
#Test the initial model here
sample_ex = next(iter(train_loaders[views[0]]))
tr.test_model(models['Unet'][views[0]]['model'], 1, sample_ex)

Training (one model per view)

In [None]:
CHECKPOINT_DIR = './checkpoints'
EPOCHS = 30
device = DEVICE  # configured in the model-loading cell


def _batch_iou_dice(y_pred, mask, threshold=0.5):
    """Return (IoU, Dice) for one batch after binarising predictions at `threshold`.

    Predictions are squeezed on dim 1 (assumes a single output channel) and both arrays
    are flattened, so each score is computed over every pixel of the batch.
    """
    # NOTE(review): this thresholds the raw model output, while JaccardLoss('binary')
    # defaults to from_logits=True (it applies a sigmoid itself). If the model has no
    # final sigmoid, 0.5 here is a cut on logits (~0.62 in probability) — confirm.
    y_pred_bin = (y_pred.squeeze(1) > threshold).cpu().numpy().astype(np.uint8).flatten()
    mask_bin = mask.cpu().numpy().astype(np.uint8).flatten()
    return jaccard_score(mask_bin, y_pred_bin), f1_score(mask_bin, y_pred_bin)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`; return the per-batch mean (loss, IoU, Dice)."""
    model.train()
    total_loss = total_iou = total_dice = 0.0
    for img, mask in tqdm(loader, desc="training", leave=False, colour='red'):
        img, mask = img.to(device), mask.to(device)
        optimizer.zero_grad()
        y_pred = model(img)
        loss = criterion(y_pred, mask)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        # Metrics use the predictions made before this optimizer step.
        iou, dice = _batch_iou_dice(y_pred, mask)
        total_iou += iou
        total_dice += dice
    n_batches = len(loader)
    return total_loss / n_batches, total_iou / n_batches, total_dice / n_batches


def _evaluate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without gradients; return the per-batch mean (loss, IoU, Dice)."""
    model.eval()
    total_loss = total_iou = total_dice = 0.0
    with torch.no_grad():
        for img, mask in tqdm(loader, desc="validation", leave=False, colour='blue'):
            img, mask = img.to(device), mask.to(device)
            y_pred = model(img)
            total_loss += criterion(y_pred, mask).item()
            iou, dice = _batch_iou_dice(y_pred, mask)
            total_iou += iou
            total_dice += dice
    n_batches = len(loader)
    return total_loss / n_batches, total_iou / n_batches, total_dice / n_batches


def _save_checkpoint(model, optimizer, epoch, view, train_loss):
    """Save the whole pickled model plus a resumable state dict for this epoch/view."""
    torch.save(model, f'{CHECKPOINT_DIR}/{epoch}_{view}.pth')
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': train_loss,
        # any other metrics you might find useful
    }, f'{CHECKPOINT_DIR}/{epoch}_{view}_state.pth')


os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # torch.save / to_csv do not create missing dirs

for view in views:
    # The first view reuses `sample_ex` from the sanity-check cell so its per-epoch
    # visualisations show the same batch as the baseline; other views draw a new batch.
    # The globals() guard keeps this cell runnable on a fresh kernel.
    if view != views[0] or 'sample_ex' not in globals():
        sample_ex = next(iter(train_loaders[view]))

    train_loss_list, val_loss_list = [], []
    train_iou_list, val_iou_list = [], []
    train_dice_list, val_dice_list = [], []
    model = models['Unet'][view]['model']
    print(f'training {view} __________________ 2024')

    criterion = JaccardLoss('binary')
    optimizer = models['Unet'][view]['optimizer']
    model.to(device)

    for epoch in tqdm(range(0, EPOCHS), desc="epoch", leave=False, colour='green'):
        train_loss, train_iou, train_dice = _train_one_epoch(
            model, train_loaders[view], criterion, optimizer, device)
        train_loss_list.append(train_loss)
        train_iou_list.append(train_iou)
        train_dice_list.append(train_dice)

        val_loss, val_iou, val_dice = _evaluate(model, valid_loaders[view], criterion, device)
        val_loss_list.append(val_loss)
        val_iou_list.append(val_iou)
        val_dice_list.append(val_dice)

        tr.test_model(model, epoch, sample_ex)
        print(f'{epoch}_{view}, metrics at: \ntrain loss - {train_loss} \ntrain IOU - {train_iou} \ntrain dice - {train_dice} \nval loss - {val_loss} \nval IOU - {val_iou} \nval_dice - {val_dice}')
        _save_checkpoint(model, optimizer, epoch, view, train_loss)

    # Per-epoch metrics for this view ('Epoch' is 1-based; checkpoint names are 0-based).
    metrics_df = pd.DataFrame({
        'Epoch': list(range(1, EPOCHS + 1)),
        'Train_Loss': train_loss_list,
        'Validation_Loss': val_loss_list,
        'Train_IOU': train_iou_list,
        'Validation_IOU': val_iou_list,
        'Train_Dice': train_dice_list,
        'Validation_Dice': val_dice_list,
    })
    # Named after the last epoch index (as before) without relying on the leaked loop variable.
    metrics_df.to_csv(f'{CHECKPOINT_DIR}/{EPOCHS - 1}_{view}_training_metrics.csv', index=False)