In [None]:
import glob
import os
import shutil
import re
from PIL import Image 
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import nibabel as nib

from modules.scandata import MriScan, MriSlice, TumourSegmentation, ScanType, ScanPlane

import PIL
import PIL.Image
import tensorflow as tf
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split
from tensorflow import keras
#from tensorflow.keras import layers
#from tensorflow.keras.models import Sequential

In [None]:
# Load every training slice and its segmentation map from disk.
# Each image file is paired with the map file whose name has 'allseq' replaced
# by 'map'; a missing partner is a hard error rather than a silent skip.
images = []
maps = []

train_image_dir = os.path.join('data','UPENN-GBM','slice_segmentation_stratify','train','image_data')
train_map_dir = os.path.join('data','UPENN-GBM','slice_segmentation_stratify','train','map_data')

# os.listdir returns entries in arbitrary, platform-dependent order; sorting
# makes the sample order (and hence the train/val split below) reproducible.
for image_file in sorted(os.listdir(train_image_dir)):
    map_file = image_file.replace('allseq', 'map')
    map_path = os.path.join(train_map_dir, map_file)
    if not os.path.exists(map_path):
        raise FileNotFoundError((image_file, map_file))

    image = tf.io.decode_png(tf.io.read_file(os.path.join(train_image_dir, image_file)), channels=4)
    # Named `seg_map` (not `map`) so the builtin map() is not shadowed in the kernel.
    # .numpy() gives a writable copy, so the in-place remap below is safe.
    seg_map = tf.io.decode_png(tf.io.read_file(map_path), channels=1).numpy()
    # Remap label 4 -> 3 so class ids are contiguous.
    # NOTE(review): presumably the raw labels are {0, 1, 2, 4}; confirm no other values occur.
    seg_map[seg_map == 4] = 3

    images.append(image)
    maps.append(tf.convert_to_tensor(seg_map))

# Fail here with a clear message instead of an opaque error from train_test_split.
if not images:
    raise FileNotFoundError(f"no training images found in {train_image_dir}")

In [None]:
# Hold out 20% of the slices for validation. A fixed random_state makes the split
# identical across kernel restarts (an unseeded call reshuffles on every run).
# NOTE(review): the split is per slice, so slices from the same scan may land in both
# train and val. If file names encode a patient/scan id, consider a grouped split.
SPLIT_SEED = 42
train_images, val_images, train_maps, val_maps = train_test_split(
    images, maps, test_size=0.2, random_state=SPLIT_SEED
)


In [None]:
# Visual sanity check: columns are train image / train map / val image / val map.
N_PREVIEW_ROWS = 10
PREVIEW_SKIP = 0   # offset into each split; change it to page through other samples
MAP_VMAX = 3       # highest class id after the 4 -> 3 remap in the loading cell

n_rows = min(N_PREVIEW_ROWS, len(train_images) - PREVIEW_SKIP, len(val_images) - PREVIEW_SKIP)
if n_rows < 1:
    raise ValueError(f"PREVIEW_SKIP={PREVIEW_SKIP} leaves no samples to preview")

# squeeze=False keeps `axs` 2-D even when n_rows == 1.
fig, axs = plt.subplots(n_rows, 4, squeeze=False)
fig.set_size_inches(12, 3.6 * n_rows)
for col, title in enumerate(['train image', 'train map', 'val image', 'val map']):
    axs[0][col].set_title(title)

for row in range(n_rows):
    idx = PREVIEW_SKIP + row
    # decode_png(channels=4) yields 4-channel arrays, and imshow treats those as RGBA.
    # Show only the first three channels so the 4th is not rendered as transparency.
    axs[row][0].imshow(train_images[idx][..., :3])
    # A fixed vmin/vmax gives each class the same colour in every map panel;
    # otherwise each panel is normalised to its own min/max.
    axs[row][1].imshow(train_maps[idx][..., 0], vmin=0, vmax=MAP_VMAX)
    axs[row][2].imshow(val_images[idx][..., :3])
    axs[row][3].imshow(val_maps[idx][..., 0], vmin=0, vmax=MAP_VMAX)
    for ax in axs[row]:
        ax.axis('off')

plt.show()

In [None]:

# Pack each split's Python list of per-slice tensors into one stacked tensor,
# converting in the order train_images, train_maps, val_images, val_maps.
train_images, train_maps, val_images, val_maps = [
    tf.convert_to_tensor(split)
    for split in (train_images, train_maps, val_images, val_maps)
]


In [None]:
def scaler_0_1(x):
    """Divide by 255 so 8-bit pixel values in [0, 255] map onto [0, 1]."""
    max_pixel = 255.0
    return x / max_pixel

def scaler_neg1_1(x):
    """Map 8-bit pixel values in [0, 255] onto [-1, 1] (divide by 127.5, then shift by -1)."""
    half_range = 127.5
    shifted = x / half_range
    return shifted - 1

def create_dataset(img, map, scaler):
    """Cast images to float32, rescale them with `scaler`, and pair them with the maps.

    Args:
        img: image tensor (anything tf.cast accepts).
        map: segmentation map(s), returned unchanged. The name is kept for caller
            compatibility; it shadows the builtin only inside this function.
        scaler: callable applied to the float32 images, e.g. scaler_neg1_1.

    Returns:
        A (scaled_images, map) tuple, in the form Dataset.from_tensor_slices expects.
    """
    as_float = tf.cast(img, tf.float32)
    scaled = scaler(as_float)
    return scaled, map
    

In [None]:
# Wrap each split as a tf.data pipeline of (image scaled to [-1, 1], map) pairs.
# The train dataset is built first, then the val dataset.
train_data, val_data = [
    tf.data.Dataset.from_tensor_slices(
        create_dataset(split_images, split_maps, scaler_neg1_1)
    )
    for split_images, split_maps in ((train_images, train_maps), (val_images, val_maps))
]


In [None]:
# Sanity check: print the shape of one unbatched training image.
first_pair = train_data.take(1)
for pic, seg in first_pair:
    print(pic.shape)

In [None]:
# Input-pipeline knobs: shuffle window size and mini-batch size.
BUFFER_SIZE, BATCH_SIZE = 1000, 64

In [None]:
# Training input pipeline, built one stage at a time:
# cache -> shuffle within a BUFFER_SIZE window -> batch -> repeat forever -> prefetch.
# Because of repeat(), the dataset is infinite; model.fit will need steps_per_epoch.
_train_pipeline = train_data.cache()
_train_pipeline = _train_pipeline.shuffle(BUFFER_SIZE)
_train_pipeline = _train_pipeline.batch(BATCH_SIZE)
_train_pipeline = _train_pipeline.repeat()
train_batch = _train_pipeline.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
# Sanity check: print the image and map shapes of one training batch.
first_batch = train_batch.take(1)
for im, se in first_batch:
    print(im.shape, se.shape)