In [None]:
import json
import os
import scipy.ndimage
import numpy as np
import matplotlib.path as mpltPath
from matplotlib.path import Path
from random import shuffle
from openslide import open_slide, ImageSlide
import scipy.io as sio
import pdb
import sys


In [None]:
# File paths. Defaults are the original AWS locations; override with the
# SLIDE_PATH / PATCH_SAVE_PATH environment variables to run elsewhere.
slide_path = os.environ.get('SLIDE_PATH', '/mys3bucket/TCGA_LUSC')
# Slide file names; get_slide_path matches slide IDs against the text before
# the first underscore of each name.
slides = os.listdir(slide_path)
save_path = os.environ.get('PATCH_SAVE_PATH', '/home/ubuntu/codebase/Semi-Supervised-GANs/dataset/patch_data')

# Sampling configuration.
no_patches = 500        # patch locations sampled per annotation polygon
no_train_slides = 200   # max annotations processed for the train split
no_dev_slides = 100     # max annotations processed for the dev split
no_test_slides = 100    # max annotations processed for the test split
chunk_size = 20         # annotations per chunk file written by generate_data

# Seed numpy's global RNG so the patch locations drawn with np.random are
# reproducible across Restart & Run All.
SEED = 42
np.random.seed(SEED)

def split(data, train_frac=0.7, dev_frac=0.2):
    """Split a sequence into contiguous train / dev / test slices.

    Sizes are rounded *up*: train gets ceil(train_frac * N) items, dev gets
    ceil(dev_frac * N) items, and test gets whatever remains (possibly empty).
    Order is preserved; no shuffling is done here.

    Parameters
    ----------
    data : sequence
        Any sliceable sequence (e.g. a list of annotation dicts).
    train_frac, dev_frac : float
        Fractions of ``len(data)`` for the train and dev slices. Both must be
        non-negative and sum to at most 1.

    Returns
    -------
    (train, dev, test) : three slices of ``data``.

    Raises
    ------
    ValueError
        If the fractions are negative or sum to more than 1.
    """
    if train_frac < 0 or dev_frac < 0 or train_frac + dev_frac > 1:
        raise ValueError("train_frac and dev_frac must be >= 0 and sum to <= 1")

    n = len(data)
    n_train = int(np.ceil(train_frac * n))
    n_dev = int(np.ceil(dev_frac * n))

    train = data[:n_train]
    dev = data[n_train:n_train + n_dev]
    test = data[n_train + n_dev:]

    return train, dev, test


def get_mask(coords, n_samples=None, rng=None):
    """Sample random integer pixel locations lying inside a polygon annotation.

    Parameters
    ----------
    coords : sequence of (x, y) pairs
        Polygon vertices in drawing order. The sequence is not modified.
    n_samples : int, optional
        Number of locations to draw (with replacement). Defaults to the
        module-level ``no_patches``.
    rng : numpy Generator / RandomState, optional
        Object providing ``.choice``; defaults to the global ``np.random`` state.

    Returns
    -------
    (xs, ys) : two 1-D integer arrays of length ``n_samples``
        Grid points inside the polygon, in the same coordinate frame as ``coords``.

    Raises
    ------
    ValueError
        If no integer grid point falls inside the polygon (degenerate/tiny shape).
    """
    if n_samples is None:
        n_samples = no_patches
    if rng is None:
        rng = np.random

    vertices = np.asarray(coords, dtype=float)

    # Integer bounding box taken per axis. A square box spanning
    # min(x, y)..max(x, y) can be vastly larger than the annotated region.
    # NOTE(review): very large annotations still produce a large grid here.
    xmin, ymin = np.floor(vertices.min(axis=0)).astype(int)
    xmax, ymax = np.ceil(vertices.max(axis=0)).astype(int)
    grid_x, grid_y = np.meshgrid(np.arange(xmin, xmax), np.arange(ymin, ymax))
    print("Done generating meshgrid!")
    points = np.column_stack((grid_x.ravel(), grid_y.ravel()))

    # Vertex order must be kept as given: sorting the vertices would describe
    # a different (self-intersecting) outline.
    inside = np.flatnonzero(Path(vertices).contains_points(points))
    if inside.size == 0:
        raise ValueError("no grid points fall inside the polygon")

    # Draw flat indices of inside points, then map them back to (x, y).
    picks = rng.choice(inside, n_samples)
    return points[picks, 0], points[picks, 1]
    
def read_patches(x_coords, y_coords, slide_src, label, patch_size=256, level=0):
    """Read square RGB patches from a whole-slide image.

    Parameters
    ----------
    x_coords, y_coords : sequences of int
        Top-left corner of each patch. OpenSlide's ``read_region`` interprets
        this location in the level-0 reference frame.
    slide_src : str
        Path to a slide file readable by ``open_slide``.
    label
        Label attached to every patch produced by this call.
    patch_size : int
        Edge length of each square patch in pixels (default 256).
    level : int
        Pyramid level to read from (default 0).

    Returns
    -------
    (patches, labels) : two lists of equal length
        ``patches`` holds ``np.array`` conversions of RGB PIL images of size
        ``patch_size x patch_size``; ``labels`` repeats ``label``.

    Raises
    ------
    ValueError
        If ``x_coords`` and ``y_coords`` differ in length.
    """
    if len(x_coords) != len(y_coords):
        raise ValueError("x_coords and y_coords must have the same length")

    patches = []
    labels = []
    image = open_slide(slide_src)
    try:
        for x, y in zip(x_coords, y_coords):
            # Plain ints for the location; read_region returns RGBA, so drop alpha.
            region = image.read_region((int(x), int(y)), level, (patch_size, patch_size))
            patches.append(np.array(region.convert("RGB")))
            labels.append(label)
    finally:
        # Release the slide handle even if a read fails part-way.
        image.close()

    print("Generated patches!")
    return patches, labels

def get_slide_path(slideID):
    """Locate the slide file for ``slideID`` among the entries in ``slides``.

    An entry matches when its text before the first underscore (or the whole
    name, if it has no underscore) equals ``str(slideID)``. The first match
    is returned joined onto ``slide_path``. If nothing matches, the sentinel
    -1 is returned (not None).
    """
    wanted = str(slideID)
    matches = (name for name in slides if str(name.split('_')[0]) == wanted)
    first = next(matches, None)
    if first is None:
        return -1
    return os.path.join(slide_path, first)

def get_random_polygon(shape):
    """Return ``shape`` unchanged when it has more than one vertex, else -1.

    Despite the name, no random selection happens: this only filters out
    single-point (or empty) annotations. A two-vertex shape is passed through
    even though it encloses no area.
    """
    is_point_like = len(shape) <= 1
    return -1 if is_point_like else shape

def _save_chunk(patches, labels, mode, chunk_no):
    """Write one chunk to save_path/<mode>/<chunk_no>.npz (arr_0=patches, arr_1=labels)."""
    out_dir = os.path.join(save_path, mode)
    os.makedirs(out_dir, exist_ok=True)
    outfile = os.path.join(out_dir, str(chunk_no))
    np.savez(outfile, np.asarray(patches), np.asarray(labels))


def generate_data(data, slide_threshold, mode):
    """Sample patches for a list of annotations and save them in chunks.

    For each annotation, this resolves its slide file, samples patch locations
    inside its polygon, and reads the patches. Every ``chunk_size`` processed
    annotations, the patches gathered since the previous chunk are written to
    ``save_path/<mode>/<chunk_no>.npz``. Any remainder is flushed at the end.

    Parameters
    ----------
    data : iterable of dict
        Annotations with keys 'slideId', 'shape' (list of vertices) and
        'annotationSubstanceId' (used as the label).
    slide_threshold : int
        Stop after this many annotations have been processed successfully.
        Skipped annotations do not count.
    mode : str
        Split name, used as the sub-directory of ``save_path``.

    Returns
    -------
    (patches, labels) : the lists of all collected patches and labels,
        each converted with ``np.asarray``.
    """
    patches = []
    labels = []
    count = 0
    start = 0  # index in `patches` of the first patch not yet written to a chunk
    chunk_no = 1
    for annotation in data:
        slide = annotation['slideId']
        shape = annotation['shape']
        label = annotation['annotationSubstanceId']
        polygon = get_random_polygon(shape)
        slide_src = get_slide_path(slide)

        if slide_src == -1:
            print(str(slide)+" file does not exist")
            sys.stdout.flush()
            continue
        if polygon == -1:
            print(str(slide)+" has only point annotation")
            sys.stdout.flush()
            continue

        coords = [tuple(x) for x in polygon]
        try:
            x_coords, y_coords = get_mask(coords)
        except ValueError as err:
            # Degenerate polygon (e.g. no interior grid point): skip it rather
            # than abort the whole split.
            print(str(slide)+" annotation could not be sampled: "+str(err))
            sys.stdout.flush()
            continue

        X, Y = read_patches(x_coords, y_coords, slide_src, label)
        patches.extend(X)
        labels.extend(Y)
        count += 1
        print(">>>>"+str(count))

        # Every chunk_size annotations, save the patches gathered since the last chunk.
        if count % chunk_size == 0:
            _save_chunk(patches[start:], labels[start:], mode, chunk_no)
            start = len(patches)
            chunk_no += 1

        if count == slide_threshold:  # only process slide_threshold annotations
            break
        print("*****************************************************")

    # Flush a trailing partial chunk (data ran out before a chunk boundary).
    if start < len(patches):
        _save_chunk(patches[start:], labels[start:], mode, chunk_no)

    return np.asarray(patches), np.asarray(labels)

In [None]:
# Load the annotation list and split it into train / dev / test.
# Shuffling is currently disabled, so the split follows the order of
# annotations.txt. To randomize it, call shuffle(data) before split().
# NOTE(review): the split is per annotation, not per slide, so annotations
# from one slide can land in more than one split (possible leakage).
with open("/mys3bucket/Annotations/annotations.txt", encoding="utf-8") as f:
    data = json.load(f)

train, dev, test = split(data)

In [None]:
# Generate training patches, capped at no_train_slides annotations.
trainX, trainY = generate_data(data=train, slide_threshold=no_train_slides, mode='train')
print("Train data generated!")

In [None]:
# Generate dev patches, capped at no_dev_slides annotations.
devX, devY = generate_data(data=dev, slide_threshold=no_dev_slides, mode='dev')
print("Dev data generated!")

In [None]:
# Generate test patches, capped at no_test_slides annotations.
testX, testY = generate_data(data=test, slide_threshold=no_test_slides, mode='test')
print("Test data generated! ")

In [None]:
# Save each split's patches and labels. The arrays are passed positionally,
# so np.savez stores them as arr_0 (patches) and arr_1 (labels). This is the
# same layout as the per-chunk files written by generate_data.
# (The previous version referenced undefined *_dataset names and dropped labels.)
train_outfile = "train.npz"
dev_outfile = "dev.npz"
test_outfile = "test.npz"

np.savez(train_outfile, trainX, trainY)
np.savez(dev_outfile, devX, devY)
np.savez(test_outfile, testX, testY)

            