In [None]:
import os
import zipfile

import numpy as np
import torch
import torchvision as tv

In [None]:
# constants: paths are relative to the notebook's working directory
MAIN_DIR = 'data/brain_cancer/'            # root folder holding the archive
ZIP_FILE = MAIN_DIR + 'BrainCancer.zip'    # archive to extract
DATA_DIR = MAIN_DIR + 'Brain Cancer/'      # folder expected after extraction (note the space)

In [None]:
# Extract ZIP_FILE into MAIN_DIR unless DATA_DIR already exists and is
# non-empty, so re-running the notebook does not re-extract.
if os.path.isdir(DATA_DIR) and len(os.listdir(DATA_DIR)) > 0:
    print('Data already extracted!\nSkipping ....')
else:
    # Use the stdlib zipfile module instead of shelling out to `unzip`:
    # there is no dependency on an external binary and no shell quoting, and a
    # bad or missing archive raises instead of being silently ignored
    # (os.system's return code was never checked).
    with zipfile.ZipFile(ZIP_FILE) as archive:
        archive.extractall(MAIN_DIR)
    print('\n\nData extracted!')

In [None]:
# Collect every file under DATA_DIR and count the files per folder in a single
# walk. Sub-directories and file names are visited in sorted order so that
# DATA_PATHS (and therefore the order of images and labels built from it) is
# reproducible: os.walk on its own yields entries in arbitrary filesystem order.
DATA_PATHS = []  # paths of all files found under DATA_DIR
nImages = {}     # folder-name suffix after the last '_' -> number of files
for dirname, subdirs, filenames in os.walk(DATA_DIR):
    subdirs.sort()  # in-place sort controls the order os.walk descends in
    # os.walk never lists directories in `filenames`, so no isdir check is needed
    for filename in sorted(filenames):
        DATA_PATHS.append(os.path.join(dirname, filename))
    if filenames:
        # Take the suffix from the folder's own name only. Splitting the full
        # path would let underscores in parent folders (e.g. 'brain_cancer')
        # leak into the key. normpath drops the trailing '/' of DATA_DIR itself.
        name = os.path.basename(os.path.normpath(dirname)).split('_')[-1]
        # accumulate so that folders sharing a suffix do not overwrite each other
        nImages[name] = nImages.get(name, 0) + len(filenames)
print(nImages)

In [None]:
# Decode every file listed in DATA_PATHS with torchvision: one tensor per file,
# kept in the same order as DATA_PATHS (the labels are derived from that same
# order, which keeps image i paired with label i).
images = []
for img_path in DATA_PATHS:
    images.append(tv.io.read_image(img_path))

In [None]:
# Assign a class index to every image. The class name appears in each file's
# path, so the label is derived from the path string.
#
# code:
#  0 --> brain glioma   (path contains 'glioma')
#  1 --> meningioma     (path contains 'menin')
#  2 --> tumor          (anything else)
#
# Labels are kept as a float tensor here; the Dataset class casts them to long.

def _label_from_path(path):
    """Return the class index encoded in ``path`` (see the code table above)."""
    if 'glioma' in path:
        return 0.0
    if 'menin' in path:
        return 1.0
    return 2.0

data_len = len(DATA_PATHS)
labels = torch.tensor([_label_from_path(p) for p in DATA_PATHS])

In [None]:
# torch.tensor cannot build one tensor from a list of multi-element tensors,
# so NumPy first stacks the per-image tensors into a single array whose first
# axis indexes images (torch.stack(images) would be the torch-native equivalent).
images_np = np.asarray(images)

In [None]:
# Wrap the NumPy array without copying it. torch.tensor always copies, which
# would duplicate the whole image array in memory. NOTE: images_th and
# images_np share storage, so an in-place change to one is visible in the other.
images_th = torch.from_numpy(images_np)

In [None]:
# presumably (num_images, channels, height, width) — TODO confirm
images_th.size()

In [None]:
class Dataset(torch.utils.data.Dataset):
    """
    Map-style dataset holding all images and labels in memory.

    Args:
        images: tensor indexed by sample along its first dimension; stored as
            float32 via ``.float()``.
        labels: tensor with one entry per sample; stored as int64 via
            ``.long()``.

    Raises:
        ValueError: if ``images`` and ``labels`` have different lengths.

    Note:
        ``.float()`` copies the images unless they are already float32; if
        they arrive as uint8 this takes four times their original memory.
    """

    def __init__(self, images, labels):
        # Fail early rather than raising IndexError later (labels too short)
        # or silently ignoring extra labels (labels too long).
        if len(images) != len(labels):
            raise ValueError(
                f'images and labels must have the same length, '
                f'got {len(images)} and {len(labels)}'
            )
        self.images = images.float()
        self.labels = labels.long()

    def __len__(self):
        """Number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at position ``idx``."""
        return self.images[idx], self.labels[idx]

In [None]:
# full dataset: every image with its class label
my_dataset = Dataset(images=images_th, labels=labels)

In [None]:
# Persist the full dataset. torch.save pickles the Dataset instance, so loading
# it requires the Dataset class to be importable in the loading process (and
# weights_only=False on torch versions where torch.load defaults to True).
torch.save(obj=my_dataset, f=f'{MAIN_DIR}brain_cancer_dataset.pt')

In [None]:
# Define a toy dataset with at most N_PER_CLASS images from each class.
# Samples are picked by label rather than by contiguous index ranges, because
# the index-range approach had three problems:
#  - the order of DATA_PATHS comes from os.walk, which promises no directory
#    order, so the classes need not be stored as glioma -> meningioma -> tumor;
#  - a class with fewer than N_PER_CLASS images would make the slice spill into
#    the next class;
#  - the nImages keys depend on folder names that may not be exactly
#    'glioma' / 'meningioma'.
# Label codes: 0 = glioma, 1 = meningioma, 2 = tumor.
N_PER_CLASS = 100

idx_glioma = torch.nonzero(labels == 0, as_tuple=True)[0][:N_PER_CLASS]
idx_meningioma = torch.nonzero(labels == 1, as_tuple=True)[0][:N_PER_CLASS]
idx_tumor = torch.nonzero(labels == 2, as_tuple=True)[0][:N_PER_CLASS]

images_th_glioma = images_th[idx_glioma]
images_th_meningioma = images_th[idx_meningioma]
images_th_tumor = images_th[idx_tumor]

labels_glioma = labels[idx_glioma]
labels_meningioma = labels[idx_meningioma]
labels_tumor = labels[idx_tumor]

In [None]:
# Concatenate the per-class toy subsets along the sample dimension (dim 0 is torch.cat's default).
images_combined = torch.cat([images_th_glioma, images_th_meningioma, images_th_tumor])
labels_combined = torch.cat([labels_glioma, labels_meningioma, labels_tumor])

In [None]:
# toy dataset built from the per-class subsets
small_dataset = Dataset(images=images_combined, labels=labels_combined)

In [None]:
# Persist the toy dataset. As with the full dataset, this pickles the Dataset
# instance, so the class must be importable when loading it.
torch.save(obj=small_dataset, f=f'{MAIN_DIR}brain_cancer_dataset_small.pt')