# Experiment

The main limitation with the previous design was that the batches didn't fit neatly into
memory, and a lot of overhead was required for loading upsampled slices. Beyond the high
memory requirement, fast performance cannot be expected under these conditions.
Anecdotally, some training datasets were around 300 GB, and constantly reading that much
data from disk slows the entire system down.

One proposed solution was to use a 'stochastic caching library' ([linked
here](https://github.com/Charl-AI/stochastic-caching)), which caches some intermediate
results to make them faster to access. However, at sizes like 300 GB the benefit of this
library is drowned out.

In [11]:
# Imports
import re
import os
import cv2
import sys
import json
import torch
import monai
import random
import argparse
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
from time import time, sleep
from datetime import datetime
from matplotlib import pyplot as plt
import SimpleITK as sitk 
from segment_anything import sam_model_registry
from torch.utils.data import Dataset, DataLoader

# Make the directory two levels above the working directory importable, because that is
# where the utils package providing setup_data_vars (needed to locate the training data) lives.
dir1 = os.path.abspath(os.path.join(os.path.abspath(''), '..', '..'))
if dir1 not in sys.path:
    sys.path.append(dir1)

from utils.environment import setup_data_vars
setup_data_vars()

In [3]:
from stocaching import SharedCache

In [7]:
def measure_time_to_load_all(dataloader):
    """Iterate once over everything `dataloader` yields and return the elapsed seconds.

    The yielded items are discarded; only the wall-clock time of the full pass is
    measured. Progress is displayed through tqdm.
    """
    began = time()
    for _ in tqdm(dataloader):
        continue
    finished = time()
    return finished - began

In [9]:
def get_id_from_path(path, needs_num=True, pos = 0):
    """Return the `pos`-th number that appears in the file name of `path`.

    Only the base name is inspected, so digits in parent directories are ignored. The
    file name is expected to contain the image id as its first number.

    Args:
        path (string): Full path (or bare file name) to extract the number from
        needs_num (bool): If True a missing number is an error, otherwise 0 is returned
        pos (int): Index of the number to return among all numbers in the file name

    Returns:
        int: The requested number, or 0 when it is absent and `needs_num` is False

    Raises:
        ValueError: If `needs_num` is True and the file name has no number at `pos`
    """
    # Raw string: '\d' in a normal string literal is an invalid escape sequence
    numbers = re.findall(r'\d+', os.path.basename(path))

    try:
        return int(numbers[pos])
    except IndexError:
        if not needs_num:
            return 0
        if not numbers:
            raise ValueError(f"Could not find a number in {path}") from None
        raise ValueError(f"Could not find a number at position {pos} in {path}") from None
# Sanity checks for get_id_from_path on the file-name layouts used in this project:
# MedSAM image slices, MedSAM ground-truth slices, raw nnUNet volumes and report images.
assert 1 == get_id_from_path(
    'radiotherapy/data/MedSAM_preprocessed/imgs/axis0/CT_zzAMLART_001-071.npy',
    needs_num=False,
)
assert 2 == get_id_from_path(
    '/vol/biomedic3/bglocker/ugproj2324/az620/radiotherapy/data/MedSAM_preprocessed/gts/Bladder/axis0/CT_Bladder_zzAMLART_002-085.npy',
    needs_num=False,
)
assert 3 == get_id_from_path(
    'radiotherapy/data/nnUNet_raw/Dataset001_Anorectum/imagesTr/zzAMLART_003_0000.nii.gz',
    needs_num=False,
)
# For report images the wanted number is the second one in the file name.
assert 1 == get_id_from_path(
    '/vol/biomedic3/bglocker/ugproj2324/az620/radiotherapy/research/source/code/data/reports/reportAnorectum_axis0_1.png',
    needs_num=False,
    pos=1,
)
def get_slice_from_path(path, isMedSAM = False):
    """Return the slice number encoded in a MedSAM pre-processed file name.

    The file name is expected to contain exactly two numbers: the image id first and
    the slice number second (e.g. 'CT_zzAMLART_001-071.npy' -> 71).

    Args:
        path (string): Full path (or bare file name) of a MedSAM pre-processed slice
        isMedSAM (bool): Set to True to vouch that `path` is MedSAM data when the string
            'MedSAM' does not occur in it (e.g. when passing a bare file name)

    Returns:
        int: The slice number

    Raises:
        ValueError: If the path is not recognised as MedSAM data, or if the file name
            does not contain exactly two numbers
    """
    if 'MedSAM' not in path and not isMedSAM:
        raise ValueError('Function intended for getting slice id from a path, which is only characteristic of MedSAM data')

    # Raw string: '\d' in a normal string literal is an invalid escape sequence
    numbers = re.findall(r'\d+', os.path.basename(path))
    if len(numbers) != 2:
        # Raised explicitly rather than asserted so the check survives `python -O`
        raise ValueError(
            f"Expected exactly two numbers (image id and slice id) in the file name of {path}, found {len(numbers)}"
        )
    return int(numbers[1])
# Sanity checks for get_slice_from_path: the slice number is the second number in the
# file name of a MedSAM image slice and of a MedSAM ground-truth slice.
assert 71 == get_slice_from_path(
    'radiotherapy/data/MedSAM_preprocessed/imgs/axis0/CT_zzAMLART_001-071.npy'
)
assert 85 == get_slice_from_path(
    '/vol/biomedic3/bglocker/ugproj2324/az620/radiotherapy/data/MedSAM_preprocessed/gts/Bladder/axis0/CT_Bladder_zzAMLART_002-085.npy'
)
# A raw nnUNet volume path is not MedSAM data and must be rejected with a ValueError.
try:
    get_slice_from_path('radiotherapy/data/nnUNet_raw/Dataset001_Anorectum/imagesTr/zzAMLART_003_0000.nii.gz')
except ValueError as e:
    assert str(e) == 'Function intended for getting slice id from a path, which is only characteristic of MedSAM data'
else:
    assert False, 'Didn\'t raise error as expected'

## Measure the time to load a random slice from each image id using the MedSAM pre-processed data

In [2]:
# Captures the run of digits after the last underscore that is directly followed by a
# digit, e.g. '001' in 'CT_zzAMLART_001-071.npy'. Used to read the image id from a file name.
image_id_from_file_name_regex = r'.*_(\d+).*'
# Captures the run of digits after the last hyphen that is directly followed by a digit,
# e.g. '071' in 'CT_zzAMLART_001-071.npy'. Used to read the slice id from a file name.
slice_id_from_file_name_regex = r'.*-(\d+).*'

In [4]:
class MEDSAM_DATASET_NO_CACHING(Dataset):
    """A torch dataset for delivering slices of any axis to a medsam model.

    Every item is read from disk on each access; nothing is cached. The samples are
    defined by the ground truth mask files found below `gt_path`, and the matching image
    file is derived from the ids in each mask's file name.
    """

    def __init__(self, img_path, gt_path,):
        """
        Args:
            img_path (string): Path to the directory containing the images
            gt_path (string): Path to the directory containing the ground truth masks

        Both directories are expected to contain 'axis0', 'axis1' and 'axis2'
        sub-directories.
        """

        self.root_img_path = img_path
        self.root_gt_path = gt_path

        # Assume that axes 0, 1 and 2 have been processed.
        self.axis0_imgs = self._list_masks(gt_path, 0)
        self.axis1_imgs = self._list_masks(gt_path, 1)
        self.axis2_imgs = self._list_masks(gt_path, 2)

    @staticmethod
    def _list_masks(gt_path, axis):
        """Return the sorted '.npy' file names in the `axis` sub-directory of `gt_path`."""
        # os.listdir returns entries in arbitrary order; sorting keeps the mapping from
        # index to sample reproducible between runs and machines.
        return sorted(
            name for name in os.listdir(os.path.join(gt_path, f'axis{axis}'))
            if name.endswith('.npy')
        )

    def __len__(self):
        return len(self.axis0_imgs) + len(self.axis1_imgs) + len(self.axis2_imgs)

    def __getitem__(self, idx):
        """Load the image and ground truth mask for sample `idx`.

        Returns:
            dict: 'image' (float tensor, channels first), 'gt2D' (long tensor of shape
            1x256x256) and 'image_name' (file name of the image slice)

        Raises:
            IndexError: If `idx` is outside [0, len(self))
        """
        # IndexError (not AssertionError) is what plain `for sample in dataset` iteration
        # relies on to know that the end of the dataset has been reached.
        if not 0 <= idx < len(self):
            raise IndexError(f"Index {idx} is out of range for dataset of size {len(self)}")

        # Fetch the image and ground truth mask. For safety, we index the items around the
        # ground truth masks, so that if for some reason the images are misaligned we will
        # guarantee that we will fetch the correct image

        img_path, gt_path, img_name = self._get_image_and_gt_path(idx)

        # NOTE(review): allow_pickle=True is kept from the original call; it should not be
        # needed for plain numeric arrays and is unsafe on untrusted files.
        img = np.load(img_path, mmap_mode='r', allow_pickle=True) # (H, W, C)
        img = np.transpose(img, (2, 0, 1)) # (C, H, W)
        assert np.max(img) <= 1. and np.min(img) >= 0., 'image should be normalized to [0, 1]'

        img = torch.tensor(img).float()

        # Loading of ground truth shouldn't be the limiting factor
        # presumably (H, W): the mask gains its channel axis only in the return below -- TODO confirm
        gt = np.load(gt_path, mmap_mode='r', allow_pickle=True)

        # Nearest-neighbour interpolation so that no new label values are invented
        gt = cv2.resize(
            gt,
            (256, 256),
            interpolation=cv2.INTER_NEAREST
        )

        return {
            "image": img, # 3x1024x1024
            "gt2D": torch.tensor(gt[None, :,:]).long(), # 1x256x256
            "image_name": img_name
        }

    def _get_image_and_gt_path(self, idx):
        """Map a flat index onto (image path, ground truth path, image file name).

        Indices run through the axis0 masks first, then axis1, then axis2.

        Raises:
            ValueError: If the image or slice id cannot be parsed from the mask file name
        """
        n_axis0 = len(self.axis0_imgs)
        n_axis1 = len(self.axis1_imgs)

        if idx < n_axis0:
            axis, gt_name = 0, self.axis0_imgs[idx]
        elif idx < n_axis0 + n_axis1:
            axis, gt_name = 1, self.axis1_imgs[idx - n_axis0]
        else:
            axis, gt_name = 2, self.axis2_imgs[idx - n_axis0 - n_axis1]

        image_match = re.search(image_id_from_file_name_regex, gt_name)
        slice_match = re.search(slice_id_from_file_name_regex, gt_name)
        if image_match is None or slice_match is None:
            raise ValueError(f"Could not parse the image and slice ids from {gt_name}")

        image_id = int(image_match.group(1))
        slice_id = int(slice_match.group(1))

        # The image slice shares the ids of the mask but not the structure name
        img_name = f'CT_zzAMLART_{image_id:03d}-{slice_id:03d}.npy'

        img_path = os.path.join(self.root_img_path, f'axis{axis}', img_name)
        gt_path = os.path.join(self.root_gt_path, f'axis{axis}', gt_name)

        return img_path, gt_path, img_name

In [5]:
# Build the uncached dataset from the pre-processed images and the CTVn masks.
# The 'MedSAM_preprocessed' environment variable is presumably set by setup_data_vars() -- TODO confirm.
my_dataset = MEDSAM_DATASET_NO_CACHING(
    os.path.join(os.environ.get('MedSAM_preprocessed'), 'imgs'), 
    os.path.join(os.environ.get('MedSAM_preprocessed'), 'gts', 'CTVn')  # anecdotally, the CTVn contains the most masks, and therefore is taken as an upper bound
)

In [6]:
# Serve the dataset in shuffled batches of 16 samples.
my_dataloader = DataLoader(my_dataset, batch_size=16, shuffle=True)

In [8]:
# Time one full pass over the uncached dataset. The recorded run below was interrupted
# by hand after 39 of 2105 batches, at roughly 8 seconds per batch.
measure_time_to_load_all(my_dataloader)

  2%|▏         | 39/2105 [05:17<4:40:23,  8.14s/it]


KeyboardInterrupt: 

## Measure the time to load a random slice from each image id using the .nii.gz raw data-format

In [None]:
class RAW_DATASET_WITH_CACHING(Dataset):
    """A torch dataset that cuts slices of any axis out of the raw .nii.gz volumes.

    The MedSAM pre-processed ground truth file names are only used as an index of which
    (image id, slice, axis) combinations contain a contoured area; the pixel data of
    both the image and the mask is read from the raw volumes.

    NOTE(review): despite the class name no caching is implemented yet -- every item
    re-reads both complete volumes from disk.
    """

    def __init__(self, medsam_gt_path, raw_img_path, raw_gt_path):
        """
        Args:
            medsam_gt_path (string): Path to the directory containing the MedSAM
                pre-processed ground truth masks ('axis0', 'axis1', 'axis2' sub-directories)
            raw_img_path (string): Path to the directory containing the raw image volumes
            raw_gt_path (string): Path to the directory containing the raw ground truth volumes
        """

        self.root_img_path = raw_img_path
        self.root_gt_path = raw_gt_path

        # Read in the ground truths. This has been a pre-processed step so we utilize it.
        # These tell us which slices of the images contain the contoured area. It is a
        # fair comparison because MedSAM operates on pre-processed data, thus we maximise
        # this ability with raw data also.
        self.axis0_imgs = self._list_masks(medsam_gt_path, 0)
        self.axis1_imgs = self._list_masks(medsam_gt_path, 1)
        self.axis2_imgs = self._list_masks(medsam_gt_path, 2)

        # pre-process this to create a dictionary of image ids and the slices that contain
        # the contoured area. Ids 1..100 are always present as keys (possibly with an
        # empty list); ids outside that range are added when they are encountered.
        self.id_to_slice_dict_axis0 = self._index_slices(self.axis0_imgs)
        self.id_to_slice_dict_axis1 = self._index_slices(self.axis1_imgs)
        self.id_to_slice_dict_axis2 = self._index_slices(self.axis2_imgs)

    @staticmethod
    def _list_masks(medsam_gt_path, axis):
        """Return the sorted '.npy' file names in the `axis` sub-directory of `medsam_gt_path`."""
        # os.listdir returns entries in arbitrary order; sorting keeps the mapping from
        # index to sample reproducible.
        return sorted(
            name for name in os.listdir(os.path.join(medsam_gt_path, f'axis{axis}'))
            if name.endswith('.npy')
        )

    @staticmethod
    def _index_slices(axis_imgs):
        """Group the slice ids found in the file names `axis_imgs` by image id."""
        id_to_slice_dict = {image_id: [] for image_id in range(1, 101)}
        for img in axis_imgs:
            image_id = get_id_from_path(img)
            # Bare file names do not contain 'MedSAM', so vouch for them explicitly
            slice_id = get_slice_from_path(img, isMedSAM=True)
            id_to_slice_dict.setdefault(image_id, []).append(slice_id)
        return id_to_slice_dict

    def __len__(self):
        return len(self.axis0_imgs) + len(self.axis1_imgs) + len(self.axis2_imgs)

    def __getitem__(self, idx):
        """Read the raw volumes for sample `idx` and return the requested slice.

        Returns:
            dict: 'image' (float tensor 3x1024x1024 with values in [0, 1]), 'gt2D' (long
            tensor 1x256x256) and 'image_name' (MedSAM-style name of the slice)

        Raises:
            IndexError: If `idx` is outside [0, len(self))
        """
        # IndexError (not AssertionError) is what plain `for sample in dataset` iteration
        # relies on to know that the end of the dataset has been reached.
        if not 0 <= idx < len(self):
            raise IndexError(f"Index {idx} is out of range for dataset of size {len(self)}")

        # Fetch the image and ground truth mask. For safety, we index the items around the
        # ground truth masks, so that if for some reason the images are misaligned we will
        # guarantee that we will fetch the correct image

        img_path, gt_path, slice_index, axis = self._get_image_and_gt_path(idx)

        raw_image_array = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
        raw_gt_array = sitk.GetArrayFromImage(sitk.ReadImage(gt_path))

        # read in the slice for both arrays
        # NOTE(review): assumes the MedSAM axis numbering matches the array axes returned
        # by GetArrayFromImage -- confirm against the pre-processing script.
        plane = [slice(None)] * 3
        plane[axis] = slice_index
        plane = tuple(plane)

        # Contiguous copies, because slicing along axis 1 or 2 yields strided views
        img = np.ascontiguousarray(raw_image_array[plane], dtype=np.float32) # (H, W)
        # NOTE(review): assumes label values fit into uint8 -- confirm
        gt = np.ascontiguousarray(raw_gt_array[plane], dtype=np.uint8) # (H, W)

        img = cv2.resize(
            img,
            (1024, 1024),
            interpolation=cv2.INTER_CUBIC
        )

        # NOTE(review): plain per-slice min-max scaling to [0, 1]; any intensity windowing
        # done by the MedSAM pre-processing is NOT reproduced here -- confirm what is wanted.
        lowest, highest = float(img.min()), float(img.max())
        if highest > lowest:
            img = (img - lowest) / (highest - lowest)
        else:
            img = np.zeros_like(img)

        # Replicate the single channel three times to obtain (C, H, W)
        img = torch.tensor(np.repeat(img[None, :, :], 3, axis=0)).float()

        # Nearest-neighbour interpolation so that no new label values are invented
        gt = cv2.resize(
            gt,
            (256, 256),
            interpolation=cv2.INTER_NEAREST
        )

        # Same naming scheme as the pre-processed MedSAM image slices
        img_name = f'CT_zzAMLART_{get_id_from_path(img_path):03d}-{slice_index:03d}.npy'

        return {
            "image": img, # 3x1024x1024
            "gt2D": torch.tensor(gt[None, :,:]).long(), # 1x256x256
            "image_name": img_name
        }

    def _get_image_and_gt_path(self, idx):
        """Returns the paths for the image and ground truth mask for the given index and
        also the slice index and axis"""

        n_axis0 = len(self.axis0_imgs)
        n_axis1 = len(self.axis1_imgs)

        if idx < n_axis0:
            axis, gt_name = 0, self.axis0_imgs[idx]
        elif idx < n_axis0 + n_axis1:
            axis, gt_name = 1, self.axis1_imgs[idx - n_axis0]
        else:
            axis, gt_name = 2, self.axis2_imgs[idx - n_axis0 - n_axis1]

        image_id = get_id_from_path(gt_name)
        # Bare file names do not contain 'MedSAM', so vouch for them explicitly
        slice_id = get_slice_from_path(gt_name, isMedSAM=True)

        # The attributes set in __init__ are root_img_path / root_gt_path
        image_path = os.path.join(self.root_img_path, 'zzAMLART_{:03d}_0000.nii.gz'.format(image_id))
        gt_path = os.path.join(self.root_gt_path, 'zzAMLART_{:03d}.nii.gz'.format(image_id))

        return image_path, gt_path, slice_id, axis