# Parameters
Edge type mapping: 10_10: 0, 20_20: 1, 40_40: 2, 10_20: 3, 20_40: 4

In [None]:
# Notebook-wide settings: graph flavour, cross-validation fold and data paths.
graph_type = 'heterogeneous' # ['homogeneous', 'heterogeneous']
fold = 'fold1' # ['fold1', 'fold2', 'fold3']

# Fail fast on a mistyped option instead of silently building the wrong graphs.
if graph_type not in ('homogeneous', 'heterogeneous'):
    raise ValueError(
        f'Unknown graph_type {graph_type!r}; choose "homogeneous" or "heterogeneous"'
    )

# Slide IDs used for training / validation / testing in each fold.
if fold == 'fold1':
    train_slides_vpc = [2, 5, 6, 7]
    val_slides_vpc = [3]
    test_slides_vpc = [1]
elif fold == 'fold2':
    train_slides_vpc = [1, 3, 6, 7]
    val_slides_vpc = [2]
    test_slides_vpc = [5]
elif fold == 'fold3':
    train_slides_vpc = [1, 2, 3, 5]
    val_slides_vpc = [7]
    test_slides_vpc = [6]
else:
    # Without this branch an unknown fold would leave the split lists undefined
    # and only surface later as a NameError.
    raise ValueError(
        f'Unknown fold {fold!r}; please choose one of "fold1", "fold2", "fold3"'
    )

model_path = 'model/{}/model_mag_multi'.format(fold)
# Label file of the dataset in use; uncomment the matching line to switch dataset.
# path_outcomes = '../data/VPC-TMA/vpc_cores_ladan.csv'
# path_outcomes = '../data/Zurich_TMA/Zurich_GG.csv'
# path_outcomes = '../data/Colorado/Colorado_GG.csv'
path_outcomes = '../data/PANDA/train.csv'

# path_VPC = '../feature_extractor_6class/VPC_embeddings/'

# Import

In [None]:
# import dgl
# from dgl.data import DGLDataset
from torch_geometric.data import Dataset
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx, to_undirected
import torch
import os
import networkx as nx # graph visualization
import pickle
import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt

# Utils

In [None]:
def dict_to_device(orig, device):
    """Return a new dict whose values are ``value.to(device)`` for each entry.

    Typically used to move a dict of tensors from the CPU to the GPU. The
    input mapping itself is left untouched and the keys are kept as they are.
    """
    return {name: value.to(device) for name, value in orig.items()}

def get_index_by_name(name):
    """Convert a core name into a zero-based flat index (160 cores per slide).

    The name is split on '_'; the last three characters of the first two
    parts are read as the 1-based slide and core numbers.
    Presumably names look like '<...>NNN_<...>NNN' - confirm against the data.
    """
    parts = name.split('_')
    slide_no = int(parts[0][-3:])
    core_no = int(parts[1][-3:])
    return (slide_no - 1) * 160 + (core_no - 1)

def GG_from_core_index(df, name):
    """Look up the grade label of core ``name`` in ``df.GG_TMA``.

    Returns 0 when the entry is 'B' (presumably benign - confirm), the integer
    value when the entry is a numeric string, and -1 for any other entry.
    """
    raw_label = df.GG_TMA[get_index_by_name(name)]
    if raw_label == 'B':
        return 0
    if raw_label.isnumeric():
        return int(raw_label)
    # Anything else is treated as an unknown label.
    return -1

def is_neighbors(key, key_n, patch_size):
    """Tell whether two patches lie within ``patch_size`` of each other.

    ``key`` and ``key_n`` are 'x_y' strings holding integer coordinates. The
    test is on the Euclidean distance (inclusive), so a patch is always a
    neighbour of itself.
    """
    x0, y0 = map(int, key.split('_'))
    x1, y1 = map(int, key_n.split('_'))
    dx, dy = x1 - x0, y1 - y0
    return dx * dx + dy * dy <= patch_size * patch_size

## https://stackoverflow.com/questions/5967500/how-to-correctly-sort-a-string-with-a-number-inside
def atoi(text):
    """Return ``int(text)`` when ``text`` consists only of digits, else ``text`` unchanged."""
    if text.isdigit():
        return int(text)
    return text

def natural_keys(text):
    """Build a sort key that orders digit runs numerically (human order).

    Usage: ``alist.sort(key=natural_keys)``.
    The text is split into digit and non-digit chunks and every digit chunk
    is converted to an int through ``atoi``.
    Reference: http://nedbatchelder.com/blog/200712/human_sorting.html
    (see Toothy's implementation in the comments).
    """
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))

def Biomax_GG_from_core(df1, df3, name):
    """Derive a grade group from the Gleason score of a Biomax core.

    The slide and core numbers are the last three characters of the first two
    '_'-separated parts of ``name``. Slide 1 is looked up in ``df1``, every
    other slide in ``df3``; the row is ``core - 1`` of ``gleason_score``.

    Returns:
        0 when the score is '-',
        -1 when either pattern (first / last character of the score) is 2,
        otherwise 5 for a sum above 8, 4 for a sum of 8, 1 for a sum of 6,
        and for the remaining sums 3 if the first pattern is 4, else 2.
    """
    parts = name.split('_')
    slide_no = int(parts[0][-3:])
    core_no = int(parts[1][-3:])

    sheet = df1 if slide_no == 1 else df3
    score = sheet.gleason_score[core_no - 1]
    if score == '-':
        return 0

    primary, secondary = int(score[0]), int(score[-1])
    if primary == 2 or secondary == 2:  # benign in majors
        return -1

    total = primary + secondary
    if total > 8:
        return 5
    if total == 8:
        return 4
    if total == 6:
        return 1
    # Remaining case is expected to be a sum of 7: 4+3 -> 3, 3+4 -> 2.
    return 3 if primary == 4 else 2

In [None]:
# Toy example: a 7-node star graph (node 0 linked to nodes 1..6) used to
# sanity-check PyG graph construction and the networkx drawing.
n1 = np.repeat(np.array([0,1,2,3,4,5,6]),5)  # not used further in this cell
e = to_undirected(torch.tensor([[0,0,0,0,0,0],[1,2,3,4,5,6]], dtype=torch.long))
edge_index = e.detach()
print(edge_index.t())
# One random attribute row per edge. The row count must follow edge_index
# (12 directed edges after to_undirected), not a hard-coded 35.
edge_attr = torch.tensor(np.random.rand(edge_index.size(1), 1))

x = torch.tensor([[0], [0], [0], [0], [0], [1], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index, edge_attr = edge_attr)
print(data)
# e.g. Data(x=[7, 1], edge_index=[2, 12], edge_attr=[12, 1])

networkX_graph = to_networkx(data, node_attrs=["x"], edge_attrs=["edge_attr"])
nx.draw_networkx(networkX_graph)

### DGL ###
# g = dgl.graph(([0,0,0,0,0],[1,2,3,4,5]))
# print(g.edges())
# gx = dgl.to_networkx(g)
# nx.draw_networkx(gx)

# Dataset

In [None]:
class VPCDataset(Dataset):
    """Build and serve one multi-magnification patch graph per core/slide.

    For every sub-directory of the raw embedding folder, ``process`` loads one
    pickled ``{patch_key: embedding}`` dict per magnification, creates a node
    for every (patch, magnification) pair and saves the resulting graph as
    ``<name>_<graph_type>.pt`` inside ``processed_dir``.

    Node layout: patch ``i`` at magnification index ``k`` is node ``i*m + k``
    with ``m = len(magnifications)``.
    Edge types: 0 / 1 / 2 connect spatially neighbouring patches within the
    first / second / third magnification, 3 links the first and second
    magnification of the same patch, 4 links the second and third.

    The class relies on the module-level ``graph_type`` setting and on the
    helpers ``is_neighbors`` and ``natural_keys`` defined in this file.

    Args:
        root: Dataset root prefix; ``fold + '/'`` is appended to it.
        fold: Fold name, also used to locate the raw embeddings.
        magnifications: Magnification sub-folder names, e.g. ``[10, 20, 40]``.
            NOTE(review): the cross-magnification edges assume exactly three
            entries - confirm before using another number.
        path_outcomes: CSV with the labels. ``None`` loads the two Biomax
            sheets instead (``self.df`` is then not defined, so the label
            lookup in ``process`` has to be switched accordingly).
        transform, pre_transform, pre_filter: Forwarded to the PyG ``Dataset``.
    """

    def __init__(self, root, fold, magnifications, path_outcomes=None, transform=None, pre_transform=None, pre_filter=None):
        self.fold = fold
        self.magnifications = magnifications
        if path_outcomes is None: # Biomax dataset
            self.df1 = pd.read_csv('../data/Biomax_TMA/PR1921a.csv').reset_index(drop = True)
            self.df3 = pd.read_csv('../data/Biomax_TMA/PR807b.csv').reset_index(drop = True)
        else: # Zurich or VPC data
            self.df = pd.read_csv(path_outcomes).reset_index(drop = True)
#         self.df.columns = [c.replace(' ', '_').replace('/','_') for c in df.columns]
        super().__init__(root + fold + '/', transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        # Folder holding the embeddings of this fold (presumably resolved
        # relative to the dataset's raw directory by PyG - confirm).
        return '../../../../feature_extractor_6class_PANDA/256_Radboud_embeddings/{}'.format(self.fold)
        # return '../../../../feature_extractor_6class/VPC_embeddings/{}'.format(self.fold)

    @property
    def processed_file_names(self):
        # Sentinel name that never exists on disk, so the "already processed"
        # check fails and process() runs on every instantiation. The files
        # that were really written are listed by _graph_files().
        return ['not_implemented']
#         return [f for f in os.listdir(self.root + '/processed') if f.split('_')[-1] == '10.pt']

    def download(self):
        # Download to `self.raw_dir`.
        print('download')
        pass

    def process(self):
        """Create one graph per raw sample and save it to ``processed_dir``."""
        patch_size = 256 #512
        embedding_dim = 512 # 512 is hard coded!
        raw_root = self.raw_paths[0]
        sample_names = os.listdir(raw_root)
        if '.ipynb_checkpoints' in sample_names: sample_names.remove('.ipynb_checkpoints')
        m = len(self.magnifications)
        for num, raw_path in enumerate(sample_names):
            print(f'{raw_path}, {num}/{len(sample_names)}')
            ## if VPC
            # y = GG_from_core_index(self.df, raw_path) # For VPC
            # y = Biomax_GG_from_core(self.df1, self.df3, raw_path) # for Biomax
            # if y == -1: # label is unknown: 'S' or 'X' ... or '-' in Biomax
            #     continue
            ## elif Zurich
            # y = int(self.df.GG[self.df.name == raw_path]) # This one line for Zurich
            ## PANDA: exactly one row is expected per image id. int() on a
            ## Series is deprecated in pandas and gives an obscure error when
            ## the id is missing or duplicated, so check explicitly.
            grades = self.df.isup_grade[self.df.image_id == raw_path]
            if len(grades) != 1:
                raise ValueError(
                    f'Expected exactly one isup_grade row for {raw_path!r}, found {len(grades)}'
                )
            y = int(grades.iloc[0])
            ## end if

            # One embedding dict per magnification.
            # NOTE: pickle executes arbitrary code on load; only use
            # embedding files produced by a trusted pipeline.
            embd_dicts = []
            for magnification in self.magnifications:
                embd_path = os.path.join(
                    raw_root, raw_path, str(patch_size), str(magnification),
                    '256_aug_model_avgpool.pkl',
                )
                with open(embd_path, 'rb') as f:
                    embd_dicts.append(pickle.load(f))

            # Patch keys come from the first magnification, in human order.
            dict_keys = list(embd_dicts[0].keys())
            dict_keys.sort(key=natural_keys)
            if '.ipynb_checkpo' in dict_keys: dict_keys.remove('.ipynb_checkpo')
            num_patches = len(dict_keys)

            x = np.zeros((m, num_patches, embedding_dim))
            # edge and edge_type are filled in lockstep, so entry n of
            # edge_type always describes edge n.
            edge = []
            edge_type = []
            ## 10_10, 20_20, 40_40 -> types 0, 1, 2
            for k, embd_dict in enumerate(embd_dicts):
                intra_type = min(k, 2) # any level past the third is counted as 40_40
                for i, key in enumerate(dict_keys):
                    x[k, i] = embd_dict[key]
                    for j, key_n in enumerate(dict_keys):
                        if is_neighbors(key, key_n, patch_size): # self-loops (i == j) are kept
                            edge.append([i*m+k, j*m+k]) # [node0_mag10, node0_mag20, node0_mag40, node1_mag10, ...]
                            edge_type.append(intra_type)
            # 10_20 -> type 3 (both directions for an undirected graph)
            for i in range(num_patches):
                edge.append([i*m, i*m+1])
                edge.append([i*m+1, i*m])
                edge_type.extend([3, 3])
            # 20_40 -> type 4 (both directions for an undirected graph)
            for i in range(num_patches):
                edge.append([i*m+1, i*m+2])
                edge.append([i*m+2, i*m+1])
                edge_type.extend([4, 4])

            # change x from 3d to 2d in the correct order:
            # order='F' makes row i*m + k hold patch i at magnification k
            x = np.reshape(x, (x.shape[0]*x.shape[1], x.shape[2]), order='F')
            # construct the edges
            edge_index = torch.tensor(edge, dtype=torch.long).t()
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            data = Data(x=torch.from_numpy(x), edge_index=edge_index, y=torch.tensor([y]))
            if graph_type == 'heterogeneous':
                data = data.to_heterogeneous(edge_type=edge_type)
            # print(raw_path)

#             networkX_graph = to_networkx(data, node_attrs=["x"]) # change the order of edges :((
#             node_pos = [[int(k.split('_')[0]) + (i%3)*150, int(k.split('_')[1]) + (i%3)*150] for k in dict_keys for i in range(m)]
#             nx.draw_networkx(networkX_graph, pos=node_pos)

#             G = nx.Graph()
#             G.add_nodes_from(range(m*len(dict_keys)))
#             G.add_edges_from([(f[0], f[1], {'color':edge_type[i]}) for i, f in enumerate(edge)])

#             plt.figure(1,figsize=(16,16))
#             nx.draw_networkx(G, pos=node_pos, edge_color=[G[u][v]['color'] for u,v in G.edges()])
#             plt.show()

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, os.path.join(self.processed_dir, f'{raw_path}_{graph_type}.pt'))

        # The set of saved graphs may have changed; drop any cached listing.
        self._graph_files_cache = None

    def _graph_files(self):
        """Return the sorted names of the saved graphs for the current graph_type."""
        files = getattr(self, '_graph_files_cache', None)
        if files is None:
            suffix = f'_{graph_type}.pt'
            files = sorted(
                name for name in os.listdir(self.processed_dir)
                if name.endswith(suffix)
            )
            # Cache the listing so get() does not rescan the folder per item.
            self._graph_files_cache = files
        return files

    def len(self):
        """Number of graphs saved in ``processed_dir`` for the current graph_type."""
        return len(self._graph_files())

    def get(self, idx):
        """Load the ``idx``-th saved graph (files are taken in sorted name order)."""
        # NOTE(review): recent PyTorch releases may need weights_only=False
        # here to unpickle Data objects - confirm for the installed version.
        data = torch.load(os.path.join(self.processed_dir, self._graph_files()[idx]))
        return data

In [None]:
# Build the graph dataset for the selected fold. The commented lines are the
# equivalent calls for the other datasets; enable exactly one of them.
# dataset = VPCDataset('Biomax_data/', fold, [10, 20, 40])
# dataset = VPCDataset('VPC_Zurich_data/', fold, [10, 20, 40], path_outcomes)
# dataset = VPCDataset('Zurich_data/', fold, [10, 20, 40], path_outcomes)
dataset = VPCDataset(
    root='Radboud_data/',
    fold=fold,
    magnifications=[10, 20, 40],
    path_outcomes=path_outcomes,
)
# dataset = VPCDataset('data/', fold, [10, 20, 40], path_outcomes)

In [None]:
# Show how many graphs the dataset reports.
print(len(dataset))

In [None]:
# 3x2x2 toy array holding 1..12, used to inspect how an order='F' reshape
# interleaves the first axis.
a = np.arange(1, 13).reshape(3, 2, 2)
print(a)

In [None]:
# Merge the first two axes in Fortran order: row r of the result is a[r % 3, r // 3].
print(a.reshape((6, 2), order='F'))