In [6]:
import os
import sys
import random
import time
from copy import deepcopy
from pathlib import Path

import h5py
import numpy as np
from monai import data, transforms as mt

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset

from model import Block_encoder_bottleneck, device, dict_args, init_weights

In [2]:
# Train and test data path

train_image_path = "../dataset/train_rosfl"
test_image_path = "../dataset/testing"

# target/crop shape for the images and masks when training
tar_shape = (256, 256)
crop_shape = (224, 224)


def normalize(data):
    data = (data - data.mean()) / data.std()
    return data


In [3]:
class FCT_Head(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        h_attent_head = [2, 2, 2, 2, 2]
        filters = [8, 16, 32, 64, 128,]
        # number of blocks used in the model
        blocks = len(filters)

        stochastic_depth_rate = 0.0

        #probability for each block
        dpr = [x for x in np.linspace(0, stochastic_depth_rate, blocks)]

        # Multi-scale input
        self.scale_img = nn.AvgPool2d(2,2)   

        # model
        self.block_1 = Block_encoder_bottleneck("first", 1, filters[0], h_attent_head[0], dpr[0])
        self.block_2 = Block_encoder_bottleneck("second", filters[0], filters[1], h_attent_head[1], dpr[1])
        self.block_3 = Block_encoder_bottleneck("third", filters[1], filters[2], h_attent_head[2], dpr[2])
        self.block_4 = Block_encoder_bottleneck("fourth", filters[2], filters[3], h_attent_head[3], dpr[3])
        self.block_5 = Block_encoder_bottleneck("bottleneck", filters[3], filters[4], h_attent_head[4], dpr[4])
    
    def forward(self, x):
        # Multi-scale input
        scale_img_2 = self.scale_img(x)
        scale_img_3 = self.scale_img(scale_img_2)
        scale_img_4 = self.scale_img(scale_img_3)  

        x = self.block_1(x)
        print(f"Block 1 out -> {list(x.size())}")
        skip1 = x
        x = self.block_2(x, scale_img_2)
        print(f"Block 2 out -> {list(x.size())}")
        skip2 = x
        x = self.block_3(x, scale_img_3)
        print(f"Block 3 out -> {list(x.size())}")
        skip3 = x
        x = self.block_4(x, scale_img_4)
        print(f"Block 4 out -> {list(x.size())}")
        skip4 = x

        return {
            "skip1": skip1.cpu().detach().numpy(), 
            "skip2": skip2.cpu().detach().numpy(), 
            "skip3": skip3.cpu().detach().numpy(), 
            "skip4": skip4.cpu().detach().numpy(),
           }

In [4]:
class ACDC_Load_Gray(Dataset):
    def __init__(self, source, ind, Transform=None):
        # basic transforms
        self.loader = mt.LoadImaged(keys=["image", "mask"])
        self.add_channel = mt.EnsureChannelFirstd(keys=["image", "mask"])
        self.spatial_pad = mt.SpatialPadD(keys=["image", "mask"], spatial_size=tar_shape, mode="edge")
        self.spacing = mt.Spacingd(keys=["image", "mask"], pixdim=(1.25, 1.25, -1.0), mode=("nearest", "nearest"))
        # index
        self.ind = ind
        # transform
        if Transform is not None:
            self.transform = Transform
        else:
            self.transform = mt.Compose([
                mt.SpatialPadD(keys=["image", "mask"], spatial_size=tar_shape, mode="edge"),
                mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False)
            ])

        # take the images
        source = Path(source)
        # dirs = os.listdir(str(source))  # stores patient name
        all_data_ed = []
        all_data_ed_mask = []
        all_data_es = []
        all_data_es_mask = []
        for filenames in source.iterdir():
            if filenames.is_dir():
                # patient_path = Path(str(source), filenames)  # individual patient path
                patient_info = str(filenames / "Info.cfg")  # patient information
                file = open(patient_info, 'r').readlines()
                ED_frame = int(file[0].split(":")[1])
                ES_frame = int(file[1].split(":")[1])
                ED = (filenames / f"{filenames.name}_frame{ED_frame:02d}.nii.gz")
                ES = (filenames / f"{filenames.name}_frame{ES_frame:02d}.nii.gz")
                ED_gt = (filenames / f"{filenames.name}_frame{ED_frame:02d}_gt.nii.gz")
                ES_gt = (filenames / f"{filenames.name}_frame{ES_frame:02d}_gt.nii.gz")
                all_data_ed.append(ED)
                all_data_ed_mask.append(ED_gt)
                all_data_es.append(ES)
                all_data_es_mask.append(ES_gt)

        if self.ind is not None:
            all_data_ed = [all_data_ed[i] for i in self.ind]
            all_data_ed_mask = [all_data_ed_mask[i] for i in self.ind]
            all_data_es = [all_data_es[i] for i in self.ind]
            all_data_es_mask = [all_data_es_mask[i] for i in self.ind]

        self.data = [all_data_ed, all_data_ed_mask, all_data_es, all_data_es_mask]

    def __len__(self):
        return len(self.data[0])

    def __getitem__(self, idx):
        ED_img, ED_mask, ES_img, ES_mask = self.data[0][idx], self.data[1][idx], self.data[2][idx], self.data[3][idx]
        # data dict
        ED_data_dict = {"image": ED_img,
                        "mask": ED_mask}
        ES_data_dict = {"image": ES_img,
                        "mask": ES_mask}
        # instead of returning both ED and ES, I have to return just a random choice between ED and ES(image and mask)
        datalist = [ED_data_dict, ES_data_dict]
        data_return = np.random.choice(datalist)
        data_return = self.loader(data_return)
        data_return = self.add_channel(data_return)
        data_return = self.spacing(data_return)
        data_return["image"] = normalize(data_return["image"])
        num_slice = data_return["image"].shape[3]
        random_slice = random.randint(0, num_slice - 1)
        data_return["image"] = data_return["image"][:, :, :, random_slice]
        data_return["image"] = normalize(data_return["image"])
        data_return["mask"] = data_return["mask"][:, :, :, random_slice]
        data_return = self.transform(data_return)
        return data_return

In [5]:
def train_loader_ACDC(train_index, data_path=train_image_path, transform=None):
    train_loader = ACDC_Load_Gray(source=data_path, Transform=transform, ind=train_index)
    return train_loader


def val_loader_ACDC(val_index, data_path=train_image_path, transform=None):
    val_loader = ACDC_Load_Gray(source=data_path, Transform=transform, ind=val_index)
    return val_loader


def test_loader_ACDC(test_index, data_path=test_image_path, transform=None):
    test_loader = ACDC_Load_Gray(source=data_path, Transform=transform, ind=test_index)
    return test_loader

In [7]:
""" To load the training and validation dataloader works """

train_compose = mt.Compose(
    [mt.SpatialPadD(keys=["image", "mask"], spatial_size=tar_shape, mode="edge"),
     mt.RandSpatialCropD(keys=["image", "mask"], roi_size=crop_shape, random_center=True, random_size=False),
     mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False),
     ]
)

val_compose = mt.Compose(
    [   mt.SpatialPadD(keys=["image", "mask"], spatial_size=tar_shape, mode="edge"),
        mt.RandSpatialCropD(keys=["image", "mask"], roi_size=crop_shape, random_center=True, random_size=False),
        mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False),
    ]
)

splits = KFold(n_splits=3, shuffle=True, random_state=42)

concatenated_dataset = train_loader_ACDC(transform=None, train_index=None)

In [None]:
""" To load the testing dataloader works """
test_compose = mt.Compose(
    [
        mt.SpatialPadD(keys=["image", "mask"], spatial_size=tar_shape, mode="edge"),
        mt.RandSpatialCropD(keys=["image", "mask"], roi_size=crop_shape, random_center=True, random_size=False),
        mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False),
    ]
)

test_data = DataLoader(test_loader_ACDC(transform=test_compose, test_index=None), batch_size=1, shuffle=False)

In [8]:
for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(concatenated_dataset)))):

    print("--------------------------", "Fold", fold + 1, "--------------------------")

    # training dataset
    training_data = DataLoader(train_loader_ACDC(transform=train_compose, train_index=train_idx), batch_size=2,
                               shuffle=False)
    print("train from here", len(training_data))
    # for dic in training_data:
    #     images = dic["image"]
    #     masks = dic["mask"]
    #     print(images.shape, masks.shape)
    #     image, label = dic["image"], dic["mask"]
    #     plt.figure("visualise", (8, 4))
    #     plt.subplot(1, 2, 1)
    #     plt.title("image")
    #     plt.imshow(image[0, 0, :, :], cmap="gray")
    #     plt.subplot(1, 2, 2)
    #     plt.title("mask")
    #     plt.imshow(label[0, 0, :, :], cmap="gray")
    #     plt.show()
    #     break

    # validation dataset
    validation_data = DataLoader(val_loader_ACDC(transform=val_compose, val_index=val_idx), batch_size=1,
                                 shuffle=False)
    print("val from here", len(validation_data))
    # for dic in validation_data:
    #     images = dic["image"]
    #     masks = dic["mask"]
    #     print(images.shape, masks.shape)
    #     image, label = dic["image"], dic["mask"]
    #     plt.figure("visualise", (8, 4))
    #     plt.subplot(1, 2, 1)
    #     plt.title("image")
    #     plt.imshow(image[0, 0, :, :], cmap="gray")
    #     plt.subplot(1, 2, 2)
    #     plt.title("mask")
    #     plt.imshow(label[0, 0, :, :], cmap="gray")
    #     plt.show()
    #     break

    # test dataset
    # ========================== TEST Data ===================
    # test_data = DataLoader(test_loader_ACDC(transform=test_compose, test_index=None), batch_size=1, shuffle=False)
    # ========================== TEST Data ===================
    # print("test from here")
    # for dic in test_data:
    #     images = dic["image"]
    #     masks = dic["mask"]
    #     print(images.shape, masks.shape)
    #     image, label = dic["image"], dic["mask"]
    #     plt.figure("visualise", (8, 4))
    #     plt.subplot(1, 2, 1)
    #     plt.title("image")
    #     plt.imshow(image[0, 0, :, :], cmap="gray")
    #     plt.subplot(1, 2, 2)
    #     plt.title("mask")
    #     plt.imshow(label[0, 0, :, :], cmap="gray")
    #     plt.show()
    #     break

-------------------------- Fold 1 --------------------------
train from here 20
val from here 20
-------------------------- Fold 2 --------------------------
train from here 20
val from here 20
-------------------------- Fold 3 --------------------------
train from here 20
val from here 20


In [9]:
# =======================================================================
#                                HEAD
# =======================================================================

model_head = FCT_Head()
model_head.apply(init_weights)

optimizer_head = torch.optim.Adam(model_head.parameters(), lr=dict_args['lr'],weight_decay=dict_args['decay'])

scheduler_head = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_head,
            mode='min',
            factor=dict_args['lr_factor'],
            verbose=True,
            threshold=1e-6,
            patience=10,
            min_lr=dict_args['min_lr'])

model_head.to(device)

FCT_Head(
  (scale_img): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (block_1): Block_encoder_bottleneck(
    (layernorm): LayerNorm((1,), eps=1e-05, elementwise_affine=True)
    (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (trans): Transformer(
      (attention_output): Attention(
        (conv_q): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same, groups=8)
        (layernorm_q): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        (conv_k): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8)
        (layernorm_k): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        (conv_v): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8)
        (layernorm_v): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=8,

In [10]:
model_head.train()

try:
    head_fwd = h5py.File('params_and_grads/head_forward_pass.hdf5', 'w') 
    train_label = h5py.File('params_and_grads/train_values.hdf5', 'w')
    for index, train_dict in enumerate(training_data):
        print("index value is ", index)
        X_train = train_dict["image"]
        y_train = train_dict["mask"]
        X_train = X_train.to(device)

        layer_data = model_head(X_train)

        grp_head = head_fwd.create_group(f'IterKey_{index}')
        for k, v in layer_data.items():
            grp_head.create_dataset(k, data=v)
        
        grp_label = train_label.create_group(f'IterKey_{index}')
        grp_label.create_dataset("tlabel", data=y_train.cpu().detach().numpy())
except Exception as ex:
    import traceback
    print("+=" * 25)
    print("Error encountered as :", ex)
    print("+=" * 25)
    traceback.print_exc()

finally:
    head_fwd.close()
    train_label.close()

index value is  0
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]
index value is  1
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]
index value is  2
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]
index value is  3
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]
index value is  4
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]
index value is  5
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]
index value is  6
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]