In [1]:
import torch

import sys

# Setting path so as we can find files in ../src folder
sys.path.append('../src')

from BRATS2013Dataset import BRATS2013Dataset

from tqdm.autonotebook import tqdm

import numpy as np

import json

import skimage

from pprint import pprint

from matplotlib import pyplot as plt

In [2]:
PATCH_SIZE = 65

In [3]:
DATASET_PATH = "../data/brats_2013_obs_path_list_unstacked_resized.txt"

obs_list = []

with open(DATASET_PATH) as file:
  for line in file:
    obs_list.append(line.rstrip())

In [4]:
dataset = BRATS2013Dataset(obs_list=obs_list, stage=None)

In [5]:
segmentation_label_count_dict = {
  "0": 0,
  "1": 0,
  "2": 0,
  "3": 0,
  "4": 0,
  "5": 0,
}

In [6]:
range_elements = range(len(dataset))[:1000]
datasubset = torch.utils.data.Subset(dataset, range_elements)

pbar_epochs = tqdm(range_elements, colour="#9400d3", position=1)

for idx, entry in enumerate(datasubset):
  pbar_epochs.update(1)

  # Handling the very first dimension that is of size 1
  # Yes, we should've squeezed it when we pre-processed the data... 
  # but we are lazy af :)
  label = entry["label"][0, ...].numpy()
  
  label_patches = skimage.util.view_as_windows(
    arr_in=label, window_shape=(PATCH_SIZE, PATCH_SIZE), step=1
  )

  img = entry["img"].numpy()
  img = np.sum(img, axis=0)

  img_patches = skimage.util.view_as_windows(
    arr_in=img, window_shape=(PATCH_SIZE, PATCH_SIZE), step=1
  )

  seg_label_to_patch_id_dict = {
    "0": [],
    "1": [],
    "2": [],
    "3": [],
    "4": [],
    "5": [],
  }

  for i in range(label_patches.shape[0]):
    for j in range(label_patches.shape[1]):
      
      if np.sum(img_patches[i, j, ...]).astype(np.int8) == 0:
        continue
      
      label_value = int(label_patches[i, j, PATCH_SIZE//2, PATCH_SIZE//2])

      seg_label_to_patch_id_dict[str(label_value)].append((i, j))

      segmentation_label_count_dict[str(label_value)] += 1

  seg_label_to_patch_id_path = f"{entry['full_path']}/seg_label_to_patch_id_patch_size_{PATCH_SIZE}.json"
  with open(seg_label_to_patch_id_path, 'w') as fp:
    json.dump(seg_label_to_patch_id_dict, fp)


seg_label_count_path = f"../data/brats_2013_seg_label_count_patch_size_{PATCH_SIZE}.json"
with open(seg_label_count_path, 'w') as fp:
  json.dump(segmentation_label_count_dict, fp)



  0%|          | 0/1000 [00:00<?, ?it/s]

In [7]:
segmentation_label_count_dict

{'0': 13144701, '1': 236350, '2': 150474, '3': 65663, '4': 42702, '5': 20229}