In [2]:
import os
import time
import datetime
import numpy as np
import nibabel as nib
import pandas as pd
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from collections import defaultdict
from numpy import nan
import math
import pickle
import cv2
import shutil
from sklearn.model_selection import train_test_split

from opfython.models.unsupervised import UnsupervisedOPF

In [3]:
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Location of the YAML configuration file on the mounted Google Drive.
config_path = '/content/drive/MyDrive/data/config.yml'

### Import config file

In [5]:
!pip install pyyaml

import yaml

# Parse the YAML configuration once; the resulting object is read by later cells.
with open(config_path, mode='r') as config_file:
    read_configuration = yaml.safe_load(config_file)



In [6]:
# Display the configured `remove_slice_num` value as this cell's output.
read_configuration['param_setting']['remove_slice_num']

0

### Preprocessing Pipeline

In [7]:
class Preprocess(object):
    def __init__(self, transform=None):

        self.transform = transform

    def read_config(self):
        """Copy the run settings onto the instance.

        Sets ``remove_slice_num`` (slices dropped at the head and tail of every
        volume) from the module-level ``read_configuration`` mapping, and fixes the
        padded slice size to 208 x 240.
        """
        param_setting = read_configuration['param_setting']
        self.remove_slice_num = param_setting['remove_slice_num']
        self.max_img_height, self.max_img_width = 208, 240



    def get_folder_path(self, dataset='train'):
        self.curr_folder_path = '/content/drive/MyDrive/data'
        # self.data_folder_name = read_configuration['folder_name']['data_folder_name']
        self.data_folder_path = '/content/drive/MyDrive/data'

        train_subfolder_name = 'train'
        test_subfolder_name = 'test'
        img_subfolder_name = 'img'
        gt_subfolder_name = 'gt'

        if dataset == 'train':
            self.img_folder_path = os.path.join(os.path.join(self.data_folder_path, train_subfolder_name), img_subfolder_name)
            self.gt_folder_path = os.path.join(os.path.join(self.data_folder_path, train_subfolder_name), gt_subfolder_name)
        elif dataset == 'test':
            self.img_folder_path = os.path.join(os.path.join(self.data_folder_path, test_subfolder_name), img_subfolder_name)
            self.gt_folder_path = os.path.join(os.path.join(self.data_folder_path, test_subfolder_name), gt_subfolder_name)
        else:
            # Default to train if not specified properly
            self.img_folder_path = os.path.join(os.path.join(self.data_folder_path, train_subfolder_name), img_subfolder_name)
            self.gt_folder_path = os.path.join(os.path.join(self.data_folder_path, train_subfolder_name), gt_subfolder_name)


    def get_file_list(self):
        # Get list of file id names
        self.img_file_list = os.listdir(self.img_folder_path)
        self.gt_file_list = os.listdir(self.gt_folder_path)
        self.gt_id_list = []
        for gt_fname in self.gt_file_list:
            if gt_fname != '.DS_Store':
                gt_id = gt_fname.split('_')[0].split('-')[1]
                self.gt_id_list.append(gt_id)


    def read_bbox_coord(self):
        bbox_coord_fname = 'bbox_coordinate.csv'
        self.bbox_coord_fpath = os.path.join(self.data_folder_path, bbox_coord_fname)
        self.bbox_coord_df = pd.read_csv(self.bbox_coord_fpath)
        self.bbox_img_slice_id_list = self.bbox_coord_df['img_slice_id'].values.tolist() #change variable name


    def match_img_and_gt(self, img_id):
        # Check if the corresponding ground truth exists
        if img_id not in self.gt_id_list:
            print("We cannot find the corresponding ground truth of img %s,"
                  "Continue to Check the Next Image File!" % (img_id))
            return False
        else:
            return True


    def read_nii_data(self, file_id):
        """Load the T1w volume and its lesion mask for one subject id.

        File names are built as ``sub-<file_id><suffix>`` inside the current image
        and ground-truth folders. Both volumes are read with nibabel and their
        voxel arrays are also stored on ``self.img_array_3d`` / ``self.gt_array_3d``.

        Returns:
            tuple: ``(img_array_3d, gt_array_3d)``.
        """
        img_fname = 'sub-' + file_id + '_ses-1_space-MNI152NLin2009aSym_T1w.nii.gz'
        gt_fname = 'sub-' + file_id + '_ses-1_space-MNI152NLin2009aSym_label-L_desc-T1lesion_mask.nii.gz'

        # Open both files before pulling any voxel data out of them.
        nib_img = nib.load(os.path.join(self.img_folder_path, img_fname))
        nib_gt = nib.load(os.path.join(self.gt_folder_path, gt_fname))

        img_voxels = nib_img.get_fdata()
        gt_voxels = nib_gt.get_fdata()
        self.img_array_3d, self.gt_array_3d = img_voxels, gt_voxels

        return img_voxels, gt_voxels

    def pad_box_coordinates(self, box_coordinates, image_height, image_width, target_height, target_width):
        padded_box_coordinates = []
        for box in box_coordinates:
            x, y, width, height = box
            x_padding = (target_width - image_width) // 2
            y_padding = (target_height - image_height) // 2
            # Adjusts the box’s (x, y) position to match its new location in the padded image.
            padded_x = x + x_padding
            padded_y = y + y_padding
            padded_box = (padded_x, padded_y, width, height)
            padded_box_coordinates.append(padded_box)

        return padded_box_coordinates


    def create_mask(self, bbox_coords, img_height, img_width):
        # It returns a mask of the same size as the input image, with pixels inside each bounding box set to 1 (indicating lesion area) and all other pixels set to 0 (background).
        mask = np.zeros((img_height, img_width))
        for box in bbox_coords:
            if any(math.isnan(value) for box in bbox_coords for value in box):
                mask = np.zeros((img_height, img_width))
            else:
                x, y, w, h = box
                mask[y:y+h, x:x+w] = 1
        return mask


    def get_bbox_area(self, img_data, bbox_coords):
        # The method get_bbox_area extracts the regions of an image that fall within each bounding box and returns them as a list of 2D arrays.
        bbox_area_array_2d_list = []

        for box in bbox_coords:
            if any(math.isnan(value) for box in bbox_coords for value in box):
                return []
            else:
                x, y, w, h = box
                bbox_area = img_data[y:y+h, x:x+w]
                bbox_area_array_2d_list.append(bbox_area)
        return bbox_area_array_2d_list


    def get_bbox_label(self, img_id, slice_id, img_height, img_width, img_2d):
        x_bbox_coord_list = [int(x) if not math.isnan(x) else float('nan') for x in self.bbox_coord_df['x']]
        y_bbox_coord_list = [int(x) if not math.isnan(x) else float('nan') for x in self.bbox_coord_df['y']]
        w_bbox_coord_list = [int(x) if not math.isnan(x) else float('nan') for x in self.bbox_coord_df['w']]
        h_bbox_coord_list = [int(x) if not math.isnan(x) else float('nan') for x in self.bbox_coord_df['h']]

        # Builds a list of (x, y, w, h) tuples from the four coordinate lists.
        all_bbox_coord_list = []
        for i in range(len(x_bbox_coord_list)):
            all_bbox_coord_list.append((x_bbox_coord_list[i], y_bbox_coord_list[i], w_bbox_coord_list[i], h_bbox_coord_list[i]))

        all_bbox_coord_dict = defaultdict(list)
        for i, value in enumerate(self.bbox_img_slice_id_list):
            all_bbox_coord_dict[value].append(all_bbox_coord_list[i])

        img_slice_id = img_id + '_' + str(slice_id)
        if img_slice_id in all_bbox_coord_dict.keys():
            temp_slice_bbox_coord = all_bbox_coord_dict[img_slice_id]

        slice_bbox_coord = self.pad_box_coordinates(temp_slice_bbox_coord, 197, 233, self.max_img_height, self.max_img_width)

        print("img_slice_id:", img_slice_id)
        ws_label = self.create_mask(slice_bbox_coord, self.max_img_height, self.max_img_width)
        bbox_area = self.get_bbox_area(img_2d, slice_bbox_coord)

        return ws_label, bbox_area, slice_bbox_coord


    def delete_nan(self, bbox_coords):
        for box in bbox_coords:
            if any(math.isnan(value) for box in bbox_coords for value in box):
                return []

            return bbox_coords

    def jigsaw(self, jigsaw_method, input_area_list, box_coord, image_shape):
        # Choose segmentation method
        if jigsaw_method == 'kmeans':
            box_area_list = self.kmeans(input_area_list)
        elif jigsaw_method == 'threshold':
            box_area_list = self.threshold(input_area_list)
        elif jigsaw_method == 'opf':
            box_area_list = self.opf(input_area_list)

        # Initialize an empty label mask
        label = np.zeros(image_shape)
        # Skip if no bounding boxes
        if [] in box_coord:
            return label
        # Paste weak labels back into the full image
        else:

            for i, coord in enumerate(box_coord):
                x, y, w, h = coord
                box_area = box_area_list[i]
                # Calculate boundaries
                x1, y1 = x, y
                x2, y2 = x + w, y
                x3, y3 = x + w, y + h
                x4, y4 = x, y + h
                # Ensures the box stays within the image (avoids index out-of-bounds errors)
                x1, y1 = max(0, min(x1, image_shape[1] - 1)), max(0, min(y1, image_shape[0] - 1))
                x2, y2 = max(0, min(x2, image_shape[1] - 1)), max(0, min(y2, image_shape[0] - 1))
                x3, y3 = max(0, min(x3, image_shape[1] - 1)), max(0, min(y3, image_shape[0] - 1))
                x4, y4 = max(0, min(x4, image_shape[1] - 1)), max(0, min(y4, image_shape[0] - 1))
                # Add box mask into the full label mask
                label[y1:y3, x1:x3] = np.maximum(label[y1:y3, x1:x3], box_area)

            return label

    def adjust_mask(self, pixel_mask):

        return pixel_mask

    def opf(self,box_area_list):
        """
        Function that performs the unsupervised OPF to provide the baseline ground-truth images.

        Args:
            box_area_list (list): List of the bounding boxes containing the stroke lesion areas.

        Returns:
            opf_area_list (list): List with the OPF results for each area.
        """

        opf_area_list = []

        # Most of the instruction below are similar to your k-means implementation. The difference is in the OPF clustering for segmentation.
        def preprocess_image(image):

            normalized_image = image.astype(np.float32) / 255.0
            resized_image = normalized_image
            blurred_image = cv2.GaussianBlur(resized_image, (5, 5), 0)

            return blurred_image
        # Loop through each bounding box
        for i in range(len(box_area_list)):
            # if the patch is empty, return a blank mask
            if not box_area_list[i].any():
                box_area_value = np.zeros(np.shape(box_area_list[i]), dtype = np.float32)
                opf_area_list.append(box_area_value)
                continue

            image = np.stack((box_area_list[i],) * 3, axis=-1)

            processed_image = preprocess_image(image)
            height = np.shape(processed_image)[0]
            width = np.shape(processed_image)[1]

            pixel_mask = np.zeros((height, width), dtype = np.float32)

            if len(image) == 1:
                opf_area_list.append(pixel_mask)
            else:
                # Creates an instance of the OPF clustering class. I suggest using different values for max_k to check the ones that are proper for your images.
                unsupervised_opf = UnsupervisedOPF(min_k=1,max_k=20)
                processed_image_flatten = processed_image.flatten().reshape(-1,3).astype(np.float32)
                # Creates the clusters based on the OPF
                # We don't need to provide the training ground data. Actually, the best standard is not use the ground-truth images
                unsupervised_opf.fit(processed_image_flatten)
                # Assigns the cluster to each image pixel, i.e., the cluster each pixel belongs to
                unsupervised_opf.propagate_labels()
                # Gets the clusters labels for each pixel in the image
                _, clusters = unsupervised_opf.predict(processed_image_flatten)

                # Gets the indices of the prototypes, i.e., the centers of the clusters
                proto = []
                for i in range(unsupervised_opf.subgraph.n_nodes):
                    if unsupervised_opf.subgraph.nodes[i].idx==unsupervised_opf.subgraph.nodes[i].root:
                        proto.append(unsupervised_opf.subgraph.nodes[i].root)

                # Gets the prototypes labels, i.e., the labels assigned to the clusters' centers
                proto = np.array(proto)
                proto_clusters = list(np.array(clusters)[proto])
                # Gets the prototypes values sorted by their labels. It tries to sort the prototypes from the black to the white colors.
                prototypes = processed_image_flatten[proto][np.argsort(proto_clusters)]
                # Gets the segmented image according to its original resolution
                segmented_image = prototypes[clusters].reshape(processed_image.shape)
                # Gets the unique clusters' labels
                segmented_image_squeeze = np.mean(segmented_image, axis=2)
                segmented_image_around = np.around(segmented_image_squeeze, 2)
                segmented_image_unique = np.sort(np.unique(segmented_image_around))

                # Prints some important information about the clusters' labels
                print('processed image resolution: {} {}'.format(processed_image.shape[0],processed_image.shape[1]))
                print('clusters: ',np.unique(clusters))
                print('segmented image unique: ',segmented_image_unique)

                if len(segmented_image_unique) <= 2:
                    norm_lesion_value = segmented_image_unique[0]
                else:
                    # Since OPF produces a different number of clusters for each image, the following instructions
                    # will get the prototypes, i.e., the centers of the clusters in the first half of the list.
                    # For example, if OPF produces six clusters, the instructions below will get the first three centers, which usually represent the black and grey pixels.
                    # However, the segmentation quality depends on the k_max value of the OPF clustering.
                    median_idx = round(len(segmented_image_unique)/2)
                    norm_lesion_value = segmented_image_unique[0:median_idx]

                # The instructions below are similar to your k-means implementation
                for i in range(len(segmented_image_around)):
                    for j in range(len(segmented_image_around[i])):
                        temp_pixel_value = segmented_image_around[i][j]
                        if np.isscalar(norm_lesion_value ):
                            if temp_pixel_value == norm_lesion_value:
                                pixel_mask[i][j] = 1
                        else:
                            if temp_pixel_value in norm_lesion_value:
                                pixel_mask[i][j] = 1

                pixel_unique = np.unique(pixel_mask)
                if list(pixel_unique) == [0.] or list(pixel_unique) == [1.]:
                    result = pixel_mask
                else:
                    result = self.adjust_mask(pixel_mask)

                opf_area_list.append(result)

        return opf_area_list

    def threshold(self, box_area_list):

        threshold_1 = 190 #white
        threshold_2 = 36 #black
        threshold_area_list = []

        for i in range(len(box_area_list)):
            if not box_area_list[i].any(): # check box area if full of zero
                box_area_value = np.zeros(np.shape(box_area_list[i]), dtype = np.float32)
                threshold_area_list.append(box_area_value)
                continue

            data_2d = box_area_list[i]
            height = np.shape(data_2d)[0]
            width = np.shape(data_2d)[1]

            pixel_mask = np.zeros((height, width))

            for i in range(height):
                for j in range(width):
                    pixel_value = data_2d[i][j]
                    if pixel_value >= threshold_1:
                        pixel_mask[i][j] = 2
                    if threshold_2 <= pixel_value < threshold_1:
                        pixel_mask[i][j] = 1

            pixel_mask[pixel_mask == 2] = 0


            pixel_unique = np.unique(pixel_mask)

            if list(pixel_unique) == [0.] or list(pixel_unique) == [1.]:
                result = pixel_mask
            else:
                result = self.adjust_mask(pixel_mask)
            threshold_area_list.append(result)


        self.threshold_list = threshold_area_list

        return threshold_area_list


    def kmeans(self, box_area_list):
        kmeans_area_list = []

        def preprocess_image(image):

            normalized_image = image.astype(np.float32) / 255.0
            resized_image = normalized_image
            blurred_image = cv2.GaussianBlur(resized_image, (5, 5), 0)

            return blurred_image

        def kmeans_segmentation(image, num_clusters):
            pixel_values = image.reshape(-1, 3).astype(np.float32)
            criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.1)
            _, labels, centers = cv2.kmeans(pixel_values, num_clusters, None, criteria, 30, cv2.KMEANS_RANDOM_CENTERS)
            segmented_image = centers[labels.flatten()].reshape(image.shape)

            return segmented_image

        for i in range(len(box_area_list)):

            if not box_area_list[i].any():
                box_area_value = np.zeros(np.shape(box_area_list[i]), dtype = np.float32)
                kmeans_area_list.append(box_area_value)
                continue

            image = np.stack((box_area_list[i],) * 3, axis=-1)

            processed_image = preprocess_image(image)
            height = np.shape(processed_image)[0]
            width = np.shape(processed_image)[1]

            pixel_mask = np.zeros((height, width), dtype = np.float32)

            if len(image) == 1:
                result = pixel_mask
                kmeans_area_list.append(result)
            elif len(image) == 2:
                num_clusters = 2
                segmented_image = kmeans_segmentation(processed_image, num_clusters)
                segmented_image_squeeze = np.mean(segmented_image, axis=2)
                segmented_image_around = np.around(segmented_image_squeeze, 2)
                segmented_image_unique = np.sort(np.unique(segmented_image_around))

                if len(segmented_image_unique) == 1:
                    norm_lesion_value = segmented_image_unique[0]

                else:
                    norm_lesion_value = segmented_image_unique[1]

                for i in range(len(segmented_image_around)):
                    for j in range(len(segmented_image_around[i])):
                        temp_pixel_value = segmented_image_around[i][j]
                        if temp_pixel_value == norm_lesion_value:
                            pixel_mask[i][j] = 1

                pixel_unique = np.unique(pixel_mask)
                if list(pixel_unique) == [0.] or list(pixel_unique) == [1.]:
                    result = pixel_mask
                else:
                    result = self.adjust_mask(pixel_mask)

                kmeans_area_list.append(result)
            else:
                num_clusters = 3
                segmented_image = kmeans_segmentation(processed_image, num_clusters)
                segmented_image_squeeze = np.mean(segmented_image, axis=2)
                segmented_image_around = np.around(segmented_image_squeeze, 10)

                segmented_image_unique = np.sort(np.unique(segmented_image_around))

                if len(segmented_image_unique) == 1:
                    norm_lesion_value = segmented_image_unique[0]
                else:
                    norm_lesion_value = segmented_image_unique[1]


                for i in range(len(segmented_image_around)):
                    for j in range(len(segmented_image_around[i])):
                        temp_pixel_value = segmented_image_around[i][j]
                        if temp_pixel_value == norm_lesion_value:
                            pixel_mask[i][j] = 1

                pixel_unique = np.unique(pixel_mask)

                if list(pixel_unique) == [0.] or list(pixel_unique) == [1.]:
                    result = pixel_mask
                else:
                    result = self.adjust_mask(pixel_mask)

                kmeans_area_list.append(result)

        self.k_list = kmeans_area_list

        return kmeans_area_list

    def concat_slices(self, dataset):
        self.load_checkpoint(dataset)
        count = 0

        for img_nii_name in self.img_file_list:

            print("img_name:", img_nii_name)
            if img_nii_name != '.DS_Store':
                img_id = img_nii_name.split('_')[0].split('-')[1]

                is_matched = self.match_img_and_gt(img_id)
                if not is_matched:
                    continue

                (img_array_3d, gt_array_3d) = self.read_nii_data(img_id)

                img_array_3d_pad = self.pad_img(img_array_3d)
                gt_array_3d_pad = self.pad_img(gt_array_3d)

                img_array_4d_pad_trans = self.img_transform(img_array_3d_pad, True)
                gt_array_4d_pad_trans = self.img_transform(gt_array_3d_pad, False)

                img_height = len(img_array_3d_pad)
                img_width = len(img_array_3d_pad[0])
                total_num_slice = len(img_array_3d_pad[0][0])

                #print("img_id = %s, total_num_slice = %d" % (img_id, total_num_slice))
                start_slice_id = 0 + self.remove_slice_num
                end_slice_id = total_num_slice - self.remove_slice_num
                count = count+1

                for slice_id in range(start_slice_id, end_slice_id):
                    slice_key = f"{img_id}_{slice_id}"
                    if slice_key in self.processed_slices:
                        print(f"Skipping already processed: {slice_key}")
                        continue
                    img_array_2d = img_array_4d_pad_trans[slice_id,:,:,:]
                    gt_array_2d = gt_array_4d_pad_trans[slice_id,:,:,:]

                    img_array_2d_for_box_area = img_array_3d_pad[:, :, slice_id]

                    ws_label, bbox_area, bbox_coord = self.get_bbox_label(img_id, slice_id, img_height, img_width, img_array_2d_for_box_area)

                    ws_label_3d = np.expand_dims(ws_label, axis=0)

                    slice_bbox_coord = self.delete_nan(bbox_coord)

                    kmeans_2d_array = self.jigsaw('kmeans', bbox_area, slice_bbox_coord, (img_height, img_width))
                    kmeans_3d_array = np.expand_dims(kmeans_2d_array, axis=0)

                    threshold_2d_array = self.jigsaw('threshold', bbox_area, slice_bbox_coord, (img_height, img_width))
                    threshold_3d_array = np.expand_dims(threshold_2d_array, axis=0)

                    # OPF segmentation
                    opf_2d_array = self.jigsaw('opf',bbox_area, slice_bbox_coord, (img_height, img_width))
                    opf_3d_array = np.expand_dims(opf_2d_array, axis=0)

                    ###########################################################################################################
                    # Just uncomment the instructions below if you want to see the plot with the segmentation results
                    # You have to install matplotlib to run these instructions
                    '''
                    f, axarr = plt.subplots(3,2)
                    axarr[0,0].imshow(img_array_2d.T.astype('uint8'))
                    axarr[0,0].set_title('Original image')
                    axarr[0,1].imshow(gt_array_2d.T)
                    axarr[0,1].set_title('Ground-truth')
                    axarr[1,0].imshow(kmeans_2d_array.T)
                    axarr[1,0].set_title('K-Means')
                    axarr[1,1].imshow(opf_2d_array.T)
                    axarr[1,1].set_title('OPF')
                    axarr[2,0].imshow(threshold_2d_array.T)
                    axarr[2,0].set_title('Threshold')
                    axarr[2,1].imshow(ws_label.T)
                    axarr[2,1].set_title('WS Label')
                    f.tight_layout(pad=1.0)
                    # Just uncomment the instruction below if you want to save the PNG image in the processed_data folder
                    plt.savefig('processed_data/{}.png'.format(img_nii_name),dpi=300,bbox_inches='tight')
                    #plt.show()
                    '''
                    ###########################################################################################################

                    # I concatenated the 'opf_3d_array' in the results list
                    # plt.savefig('processed_data/{}.png'.format(img_nii_name),dpi=300,bbox_inches='tight')
                    print(f"Saving: {slice_key}")
                    self.save_proc_data(dataset, img_id, slice_id, [img_array_2d, gt_array_2d, ws_label_3d, kmeans_3d_array, threshold_3d_array, opf_3d_array])

                print('****** file count =', count, '****** ')


    def save_proc_data(self, dataset, img_id, slice_id, data):

        saved_time = self.proc_time
        saved_folder_name = 'processed_data'
        saved_folder_path = os.path.join(self.curr_folder_path, saved_folder_name)
        if not os.path.exists(saved_folder_path):
            os.makedirs(saved_folder_path)


        subfolder_name = saved_time

        subfolder_path = os.path.join(saved_folder_path, subfolder_name)
        if not os.path.exists(subfolder_path):
            os.makedirs(subfolder_path)

        subsubfolder_name = dataset

        subsubfolder_path = os.path.join(subfolder_path, subsubfolder_name)
        if not os.path.exists(subsubfolder_path):
            os.makedirs(subsubfolder_path)

        saved_data_name = img_id+ '_' + str(slice_id)
        saved_data_path = os.path.join(subsubfolder_path, saved_data_name)

        # np.save(saved_data_path, data)
        np.save(saved_data_path, np.array(data, dtype=object), allow_pickle=True)

        # Save log
        with open(os.path.join(saved_folder_path, 'checkpoint.txt'), 'a') as log_file:
            log_file.write(f"{img_id}_{slice_id}\n")

    def pad_img(self, input_img):
        target_height = self.max_img_height
        target_width = self.max_img_width
        input_img_height, input_img_width, _ = input_img.shape

        pad_height = target_height - input_img_height
        upper_padding_rows = pad_height // 2
        lower_padding_rows = pad_height - upper_padding_rows

        pad_width = target_width - input_img_width
        left_padding_columns = pad_width // 2
        right_padding_columns = pad_width - left_padding_columns

        padded_img = np.pad(
            input_img,
            ((upper_padding_rows, lower_padding_rows),
             (left_padding_columns, right_padding_columns),
             (0, 0)),
            'constant',
            constant_values=(0, 0)
        )

        return padded_img


    def img_transform(self, input_img, is_img):
        # Transpose and expand dims, a 3D image (with shape H × W × S) into a 4D array (with shape (S, C, H, W))
        # The function returns:
        #  (S, 3, H, W) for images
        #  (S, 1, H, W) for ground truths
        #  This is exactly how PyTorch likes to receive batches of image data for training
        output_img = input_img.transpose(2, 0, 1)
        output_img = np.expand_dims(output_img, axis = 1)
        if is_img:
            output_img = np.repeat(output_img, 3, 1)
        return output_img


    def process(self):
        """Run the whole pipeline for the train split and then the test split.

        Reads the settings, stamps the run with the current local time, and for
        each split resolves folders, lists files, loads the box table, loads the
        checkpoint and processes every slice.
        """
        self.read_config()

        # With no time tuple, strftime formats the current local time.
        self.proc_time = time.strftime('%Y%m%d_%H-%M-%S')

        for dataset in ('train', 'test'):
            self.get_folder_path(dataset)
            self.get_file_list()
            self.read_bbox_coord()
            self.load_checkpoint(dataset)
            self.concat_slices(dataset)

    # Load checkpoint for colab to resume processing
    def load_checkpoint(self, dataset):
        self.processed_slices = set()
        processed_root = os.path.join(self.curr_folder_path, 'processed_data')

        if not os.path.exists(processed_root):
            return

        timestamp_dirs = [d for d in os.listdir(processed_root) if os.path.isdir(os.path.join(processed_root, d))]
        if not timestamp_dirs:
            return

        latest_timestamp = sorted(timestamp_dirs)[-1]
        self.proc_time = latest_timestamp  # Reuse this folder for saving

        checkpoint_file = os.path.join(processed_root, 'checkpoint.txt')
        if os.path.exists(checkpoint_file):
            with open(checkpoint_file, 'r') as f:
                for line in f:
                    self.processed_slices.add(line.strip())

        print(f"Loaded {len(self.processed_slices)} slices from checkpoint.")

In [None]:
# Run the pipeline for the train and test splits; slices already recorded in
# the checkpoint file are skipped.
proc = Preprocess()
proc.process()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Skipping already processed: r009s030_130
Skipping already processed: r009s030_131
Skipping already processed: r009s030_132
Skipping already processed: r009s030_133
Skipping already processed: r009s030_134
Skipping already processed: r009s030_135
Skipping already processed: r009s030_136
Skipping already processed: r009s030_137
Skipping already processed: r009s030_138
Skipping already processed: r009s030_139
Skipping already processed: r009s030_140
Skipping already processed: r009s030_141
Skipping already processed: r009s030_142
Skipping already processed: r009s030_143
Skipping already processed: r009s030_144
Skipping already processed: r009s030_145
Skipping already processed: r009s030_146
Skipping already processed: r009s030_147
Skipping already processed: r009s030_148
Skipping already processed: r009s030_149
Skipping already processed: r009s030_150
Skipping already processed: r009s030_151
Skipping already processed: r009s