# Parameters

In [None]:
import os

# Pin CUDA device enumeration to PCI bus order and expose only GPU "0" to
# this process. Must run before any CUDA context is created.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [None]:
# --- Experiment configuration -------------------------------------------
num_classes = 6
batch_size = 32
train_on = 'Karolinska'

# Which graph variant to load; the processed file names end with
# "_<magnification>.pt".
# Options: ['10', '20', '40', 'heterogeneous', 'homogeneous']
magnification = 'heterogeneous'
fold = 'fold1'
iterate = 1

# Suffix identifying the architecture in the checkpoint file name.
# model_name = '_2GCN_1GIN_2GCN_concat_deeppool_2linear_layernorm'
# model_name = '_MS_RGCN_4relu'
model_name = '_MS_RGCN'

# NOTE(review): model_path / data_path are only defined for 'Karolinska';
# any other value of train_on makes the print below raise NameError.
if train_on == 'Karolinska':
    model_path = f'models/{fold}/model_mag_{magnification + model_name}_{iterate}.pth'
    data_path = train_on + '_data/'

print(model_path)
# path_VPC = '../feature_extractor_6class/VPC_embeddings/'

# Import

In [None]:
# import for RGATConv
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter, ReLU, Sequential
from torch_scatter import scatter_add
from torch_sparse import SparseTensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, ones, zeros
from torch_geometric.typing import Adj, OptTensor, Size
from torch_geometric.utils import softmax

In [None]:
# import dgl
# from dgl.data import DGLDataset
from torch_geometric.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import ModuleDict
from torch_geometric.nn import GCNConv, RGCNConv, global_mean_pool, to_hetero, GATConv, SAGPooling, BatchNorm, LayerNorm, AGNNConv, ResGatedGraphConv, SGConv, GINConv #, InstanceNorm, GraphNorm, PairNorm
import os
import networkx as nx # graph visualization
import pickle
import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
from IPython import display
from sklearn.metrics import roc_auc_score, balanced_accuracy_score
from sklearn.metrics import cohen_kappa_score
import time
# from pyg_class.RGAT_Conv import RGATConv

# Utils

In [None]:
# a function to move tensors from the Double to Float
def dict_to_float(orig):
    """Return a new dict with ``.float()`` applied to every value of ``orig``.

    Keys are kept as they are and ``orig`` itself is not modified.
    """
    return {key: value.float() for key, value in orig.items()}

def get_edge_index_type(data):
    """Flatten per-relation edge indices into one edge list plus relation ids.

    Args:
        data: Object exposing ``edge_types`` (an iterable of keys) and
            ``edge_index_dict`` (mapping each key to a ``(2, E_t)`` tensor).

    Returns:
        Tuple ``(edge_index, edge_type)``: ``edge_index`` is the column-wise
        concatenation of all per-relation edge indices in ``edge_types``
        order; ``edge_type`` is a ``long`` vector giving, for every column,
        the position of its relation in ``edge_types``.
    """
    index_parts = [data.edge_index_dict[t] for t in data.edge_types]

    if not index_parts:
        # Nothing to infer a device from; keep the historical CUDA default
        # when it is available and fall back to CPU otherwise.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return (torch.empty((2, 0), dtype=torch.long, device=device),
                torch.empty((0,), dtype=torch.long, device=device))

    # Follow the device of the incoming edges instead of forcing `.cuda()`.
    device = index_parts[0].device
    # torch.full keeps the ids integral (ones(...) * i would yield floats).
    type_parts = [
        torch.full((part.shape[1],), i, dtype=torch.long, device=device)
        for i, part in enumerate(index_parts)
    ]
    return torch.cat(index_parts, dim=1), torch.cat(type_parts)

def quadratic_weighted_kappa(y_hat, y):
    """Cohen's kappa between ``y_hat`` and ``y`` with quadratic weights.

    Both arguments are forwarded, in this order, to
    ``sklearn.metrics.cohen_kappa_score``.
    """
    kappa = cohen_kappa_score(y_hat, y, weights="quadratic")
    return kappa

class SaveBestModel:
    """
    Checkpoint the model whenever the monitored validation score improves.

    The score is maximised: a checkpoint is written only when the value
    passed to ``__call__`` is strictly greater than the best value seen so
    far (initially 0).

    Args:
        path: Optional checkpoint destination. When ``None`` (default) the
            module-level ``model_path`` is looked up at save time.
    """
    def __init__(self, path=None):
        self.best_acc = 0
        self.path = path

    def __call__(
        self, current_acc,
        epoch, model, optimizer
    ):
        """Save model/optimizer state if ``current_acc`` beats the best so far.

        Args:
            current_acc: Validation score of the current epoch (higher is better).
            epoch: Zero-based epoch index; stored and printed as ``epoch + 1``.
            model: Module whose ``state_dict()`` is stored.
            optimizer: Optimizer whose ``state_dict()`` is stored.
        """
        if current_acc > self.best_acc:
            self.best_acc = current_acc
            print(f"\nBest validation AUC: {self.best_acc}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")

            target = model_path if self.path is None else self.path
            # torch.save does not create missing directories; without this
            # the first improvement would crash the training run.
            parent = os.path.dirname(target)
            if parent:
                os.makedirs(parent, exist_ok=True)

            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': current_acc
                }, target)


save_model = SaveBestModel()

# PYG Utils by Roozbeh

In [None]:
class GlobalAttentionPooling(torch.nn.Module):
    """Graph-level readout: self-attention over each graph's nodes, then mean.

    Args:
        embed_dim: Width of the node embeddings.
        num_heads: Number of attention heads (passed to ``MultiheadAttention``).
        dropout: Attention dropout probability.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.att = torch.nn.MultiheadAttention(embed_dim, num_heads, dropout, batch_first=True)

    def forward(self, x, batch):
        """Pool node embeddings into one vector per graph.

        Args:
            x: ``(num_nodes, embed_dim)`` node embeddings.
            batch: ``(num_nodes,)`` integer graph id of every node.

        Returns:
            ``(num_graphs, embed_dim)`` tensor on the same device as ``x``,
            where ``num_graphs`` is ``batch.max() + 1``.
        """
        embed_dim = x.shape[-1]
        if batch.numel() == 0:
            # No nodes at all: nothing to pool.
            return torch.zeros((0, embed_dim), device=x.device)

        num_graphs = int(batch.max()) + 1
        # `x.device` also works on CPU, where `x.get_device()` returns -1.
        output = torch.zeros((num_graphs, embed_dim), device=x.device)
        for b in range(num_graphs):
            # Graphs have different node counts, so attend to each separately.
            x_batch = x[batch == b].unsqueeze(0)
            attn_output, _ = self.att(x_batch, x_batch, x_batch)
            output[b] = torch.mean(attn_output[0], dim=0)

        return output
        

# Dataset

In [None]:
class VPCDataset(Dataset):
    """Pre-processed graph dataset with a built-in 80/20 train/validation split.

    Graph files are read from ``<root><fold_dir>/processed`` and selected by
    the suffix ``_<magnification>.pt``. ``fold4`` and ``fold5`` reuse the
    directories of ``fold1`` and ``fold2`` but hold out the *first* fifth of
    the files for validation instead of the last fifth.

    Args:
        root: Path prefix; the fold directory name is appended to it.
        fold: Fold identifier, e.g. ``'fold1'``.
        magnification: Suffix selecting which graph files to use.
        train: ``True`` for the training part, ``False`` for validation.
        transform, pre_transform, pre_filter: Forwarded to ``Dataset``.
    """
    def __init__(self, root, fold, magnification, train, transform=None, pre_transform=None, pre_filter=None):
        self.fold = fold
        # fold4 / fold5 share the data directories of fold1 / fold2.
        self.fold_temp = {'fold4': 'fold1', 'fold5': 'fold2'}.get(fold, fold)
        self.magnification = magnification
        self.train = train
        super().__init__(root + self.fold_temp + '/', transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return '../../../../feature_extractor_6class/VPC_embeddings/{}'.format(self.fold_temp)

    @property
    def processed_file_names(self):
        """File names belonging to this split, in sorted order."""
        suffix = f'{self.magnification}.pt'
        # os.listdir returns entries in arbitrary order; sort so the
        # train/validation membership is identical on every run and machine.
        graphs = sorted(f for f in os.listdir(self.root + '/processed')
                        if f.split('_')[-1] == suffix)
        if self.fold in ['fold1', 'fold2', 'fold3']:
            # Last fifth is held out for validation.
            cut = len(graphs)*4//5
            return graphs[:cut] if self.train else graphs[cut:]
        # Other folds: first fifth is held out for validation.
        cut = len(graphs)//5
        return graphs[cut:] if self.train else graphs[:cut]

    def download(self):
        # Download to `self.raw_dir`.
        # Explicit raise (not `assert`) so the guard survives `python -O`.
        raise AssertionError('went to download')

    def process(self):
        # Graphs are expected to be pre-processed already.
        raise AssertionError('went to process')

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        """Return ``(graph, name)`` where ``name`` is the first 16 characters of the file name."""
        # Resolve the name once: each property access lists the directory.
        file_name = self.processed_file_names[idx]
        data = torch.load(os.path.join(self.processed_dir, file_name))
        return data, file_name[:16]
    

## Zurich dataset
class ZurichDataset(Dataset):
    """Pre-processed graphs restricted to a given collection of slides.

    A file in ``<root><fold>/processed`` is used when the part after its last
    underscore equals ``<magnification>.pt`` and the part before its first
    underscore is contained in ``slides``.
    """
    def __init__(self, root, slides, fold, magnification, transform=None, pre_transform=None, pre_filter=None):
        self.slides = slides
        self.magnification = magnification
        self.fold = fold
        super().__init__(root + fold + '/', transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return f'../../../../feature_extractor_6class/VPC_embeddings/{self.fold}'

    @property
    def processed_file_names(self):
        wanted_suffix = f'{self.magnification}.pt'
        selected = []
        for file_name in os.listdir(self.root + '/processed'):
            pieces = file_name.split('_')
            if pieces[-1] != wanted_suffix:
                continue
            if pieces[0] in self.slides:
                selected.append(file_name)
        return selected

    def download(self):
        # Download to `self.raw_dir`.
        assert False, 'went to download'

    def process(self):
        # Graphs are expected to exist already; never rebuild them here.
        assert False, 'went to process'

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        # Returns the loaded graph together with its full file name.
        file_name = self.processed_file_names[idx]
        graph = torch.load(os.path.join(self.processed_dir, file_name))
        return graph, file_name

In [None]:
if train_on == 'Karolinska':
    # Two views of the same fold directory: training part and held-out part.
    dataset_train, dataset_val = (
        VPCDataset(data_path, fold, magnification, is_train)
        for is_train in (True, False)
    )

    # Both loaders use the same batch size and shuffle their samples.
    loader_train, loader_val = (
        DataLoader(split, batch_size=batch_size, shuffle=True)
        for split in (dataset_train, dataset_val)
    )

In [None]:
def _inverse_frequency_weights(loader, n_classes):
    """Return normalised inverse-frequency class weights for ``loader``.

    Labels are read from ``data['0'].y`` when ``magnification`` is
    ``'heterogeneous'`` and from ``data.y`` otherwise. Classes that never
    occur get weight 0 (instead of ``1/0 = inf``, which would turn the
    normalised weights into NaN/0); the remaining weights sum to 1.
    """
    counts = torch.zeros((n_classes,), dtype=torch.long)
    for data, core_name in loader:
        y = data['0'].y if magnification == 'heterogeneous' else data.y
        counts += torch.bincount(y.reshape(-1).long().cpu(), minlength=n_classes)

    inverse = torch.zeros((n_classes,), dtype=torch.float32)
    present = counts > 0
    inverse[present] = 1.0 / counts[present].float()
    total = inverse.sum()
    return inverse / total if total > 0 else inverse


# setting training weights
weights = _inverse_frequency_weights(loader_train, num_classes)
print(weights)

### setting validation weights
weights_val = _inverse_frequency_weights(loader_val, num_classes)
print(weights_val)

# Model

In [None]:
class DeepMIL(nn.Module):
    """Attention-weighted pooling of a bag of instances plus a classifier head.

    Args:
        feature_size: Width ``L`` of every instance embedding.
        num_classes: Number of output scores.
    """
    def __init__(self, feature_size, num_classes):
        super().__init__()
        self.L = feature_size  # instance embedding width
        self.D = 128           # hidden width of both sub-networks
        self.K = 1             # attention scores produced per instance

        attention_layers = [
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K),
        ]
        self.attention = nn.Sequential(*attention_layers)

        classifier_layers = [
            nn.Linear(self.L*self.K, self.D),
            nn.ReLU(),
            nn.Linear(self.D, num_classes),
        ]
        self.classifier = nn.Sequential(*classifier_layers)

    def forward(self, H):
        """Pool the ``N x L`` bag ``H``.

        Returns:
            ``(A, Y_prob)``: the ``K x N`` attention weights (softmax over the
            N instances) and the ``K x num_classes`` classifier output, which
            is not passed through any final activation.
        """
        scores = self.attention(H)            # N x K
        A = torch.softmax(scores.t(), dim=1)  # K x N, normalised over N
        M = A.mm(H)                           # K x L weighted bag embedding
        return A, self.classifier(M)

class MIL(nn.Module):
    """Batched wrapper that applies a MIL head to every graph of a batch.

    Args:
        feature_size: Width of the instance embeddings.
        num_classes: Number of output scores per bag.
        MIL_type: Which head to build; only ``'deep'`` is implemented.

    Raises:
        ValueError: If ``MIL_type`` is not supported.
    """
    def __init__(self, feature_size, num_classes, MIL_type):
        super().__init__()
        self.num_classes = num_classes
        if MIL_type == 'deep':
            self.model = DeepMIL(feature_size, num_classes)
        else:
            # Explicit raise: an `assert` disappears under `python -O`.
            raise ValueError(
                f"Unsupported MIL_type {MIL_type!r}; only 'deep' is implemented")

    def forward(self, x, batch):
        """Return one row of class scores per bag.

        Args:
            x: ``(num_instances, feature_size)`` instance embeddings.
            batch: ``(num_instances,)`` integer bag id of every instance.

        Returns:
            ``(num_bags, num_classes)`` tensor on the device of ``x``, with
            ``num_bags = batch.max() + 1``.
        """
        if batch.numel() == 0:
            return torch.zeros((0, self.num_classes), device=x.device)

        num_bags = int(batch.max()) + 1
        # Use the instance's own class count (not a module-level global) and
        # `x.device`, which unlike `x.get_device()` is also valid on CPU.
        output = torch.zeros((num_bags, self.num_classes), device=x.device)
        for b in range(num_bags):
            # Bags differ in size, so each one is pooled on its own.
            x_batch = x[batch == b]
            _, output[b] = self.model(x_batch)

        return output

    def prediction(self, data, batch=None):
        """Return the arg-max class index of every bag.

        Args:
            data: Instance embeddings, as accepted by ``forward``.
            batch: Bag id per instance. When omitted, all instances are
                treated as a single bag.
        """
        if batch is None:
            batch = torch.zeros(data.shape[0], dtype=torch.long, device=data.device)
        return torch.argmax(self.forward(data, batch), dim=1)

In [None]:
class Rel_GNN(torch.nn.Module):
    """Relational GCN over multi-magnification graphs.

    Pipeline: two neighbour RGCN layers -> one "magnification" RGCN layer
    linking the three nodes of every triplet -> two more neighbour RGCN
    layers (with LayerNorm + ReLU in between) -> MIL pooling.

    Node features are read from ``data['0'].x`` and the neighbour edges from
    ``data.edge_index_dict[('0', r, '0')]`` for ``r`` in ``'0', '1', '2'``.
    The magnification block and the final concatenation assume that nodes
    ``3k, 3k+1, 3k+2`` form one triplet (per the original comment: 10x, 20x
    and 40x) — NOTE(review): confirm against the graph construction code.

    Args:
        feature_size: Width of the input node features.
        num_classes: Number of output scores per graph.
    """
    def __init__(self, feature_size, num_classes):
        super().__init__()
        embedding_size = 16

        # Neighbour layers use 3 relations (one per neighbour edge type);
        # the magnification layer uses 6 (one per ordered pair of distinct
        # triplet slots).
        self.neigh1 = RGCNConv(feature_size, feature_size, 3, is_sorted=True)
        self.neigh2 = RGCNConv(feature_size, feature_size, 3, is_sorted=True)
        self.mag = RGCNConv(feature_size, feature_size, 6, is_sorted=True)
        self.neigh3 = RGCNConv(feature_size, embedding_size*8, 3, is_sorted=True)
        self.neigh4 = RGCNConv(embedding_size*8, embedding_size*2, 3, is_sorted=True)

        # normalization (applied to the inputs of neigh3 / neigh4)
        self.norm_neigh3 = LayerNorm(feature_size)
        self.norm_neigh4 = LayerNorm(embedding_size*8)

        # pooling over the concatenated triplet embeddings
        self.pool = MIL(embedding_size*2*3, num_classes, 'deep')

    @staticmethod
    def _neighbor_edges(data, device):
        """Concatenate the three neighbour relations; ids 0, 1, 2 in that order."""
        parts = [data.edge_index_dict[('0', rel, '0')] for rel in ('0', '1', '2')]
        edge_index = torch.cat(parts, dim=1)
        # Built on `device` so that edges and their types live together.
        edge_type = torch.cat([
            torch.full((part.shape[1],), rel_id, dtype=torch.long, device=device)
            for rel_id, part in enumerate(parts)
        ])
        return edge_index, edge_type

    @staticmethod
    def _magnification_edges(n, device):
        """Fully connect the three nodes of each of the ``n`` triplets.

        Ordered slot pairs (0,1), (0,2), (1,0), (1,2), (2,0), (2,1) become
        relations 0..5, each contributing ``n`` edges, already sorted by
        relation id.
        """
        base = torch.arange(0, n*3, 3, dtype=torch.long, device=device)
        pairs = [(i, j) for i in range(3) for j in range(3) if i != j]
        src = torch.cat([base + i for i, _ in pairs])
        dst = torch.cat([base + j for _, j in pairs])
        edge_type = torch.arange(len(pairs), dtype=torch.long, device=device).repeat_interleave(n)
        return torch.stack((src, dst)), edge_type

    def forward(self, data):
        """Return ``(num_graphs, num_classes)`` scores for a batch of graphs."""
        x, batch = data['0'].x.float(), data['0'].batch
        # Follow the input's device rather than hard-coding `.cuda()`.
        device = x.device

        neigh_edge_index, neigh_edge_type = self._neighbor_edges(data, device)

        ### neighbor block
        x = self.neigh1(x, neigh_edge_index, neigh_edge_type)
        x = self.neigh2(x, neigh_edge_index, neigh_edge_type)

        ### magnification block # 10:0, 20:1, 40:2
        n = x.shape[0] // 3
        mag_edge_index, mag_edge_type = self._magnification_edges(n, device)

        x = self.mag(x, mag_edge_index, mag_edge_type)
        x = self.norm_neigh3(x, batch)
        x = F.relu(x)
        ### neighbor block
        x = self.neigh3(x, neigh_edge_index, neigh_edge_type)
        x = self.norm_neigh4(x, batch)
        x = F.relu(x)
        x = self.neigh4(x, neigh_edge_index, neigh_edge_type)

        # Collapse every triplet into one row by concatenating its three
        # node embeddings; keep one batch id per triplet.
        batch = batch[0::3]
        x = torch.cat((x[0::3], x[1::3], x[2::3]), dim=1)

        ### aggregation
        x = self.pool(x, batch=batch)

        return x

    def prediction(self, data):
        """Return the arg-max class index of every graph in ``data``."""
        return torch.argmax(self.forward(data), dim=1)
    
class simple_GNN(torch.nn.Module):
    """Non-relational (GCN) counterpart of ``Rel_GNN``.

    Same layer layout — two neighbour layers, one magnification layer, two
    more neighbour layers, MIL pooling — but every layer is a plain
    ``GCNConv``, so edge types are ignored.

    NOTE(review): ``forward`` reads ``data['0']`` and ``data.edge_index_dict``
    exactly like ``Rel_GNN``, although this class is selected for the
    non-heterogeneous settings; confirm those graphs expose that layout.

    Args:
        feature_size: Width of the input node features.
        num_classes: Number of output scores per graph.
    """
    def __init__(self, feature_size, num_classes):
        super().__init__()
        embedding_size = 16

        # GCNConv has no `is_sorted` argument (that option belongs to
        # RGCNConv), so it is not passed here.
        self.neigh1 = GCNConv(feature_size, feature_size)
        self.neigh2 = GCNConv(feature_size, feature_size)
        self.mag = GCNConv(feature_size, feature_size)
        self.neigh3 = GCNConv(feature_size, embedding_size*8)
        self.neigh4 = GCNConv(embedding_size*8, embedding_size*2)

        # normalization (applied to the inputs of neigh3 / neigh4)
        self.norm_neigh3 = LayerNorm(feature_size)
        self.norm_neigh4 = LayerNorm(embedding_size*8)

        # pooling over the concatenated triplet embeddings
        self.pool = MIL(embedding_size*2*3, num_classes, 'deep')

    @staticmethod
    def _magnification_edges(n, device):
        """Fully connect nodes ``3k, 3k+1, 3k+2`` for each of the ``n`` triplets."""
        base = torch.arange(0, n*3, 3, dtype=torch.long, device=device)
        pairs = [(i, j) for i in range(3) for j in range(3) if i != j]
        src = torch.cat([base + i for i, _ in pairs])
        dst = torch.cat([base + j for _, j in pairs])
        return torch.stack((src, dst))

    def forward(self, data):
        """Return ``(num_graphs, num_classes)`` scores for a batch of graphs."""
        x, batch = data['0'].x.float(), data['0'].batch

        # All neighbour relations are merged; GCNConv takes no edge types.
        neigh_edge_index = torch.cat((data.edge_index_dict[('0','0','0')],
                                      data.edge_index_dict[('0','1','0')],
                                      data.edge_index_dict[('0','2','0')]), dim=1)

        ### neighbor block
        x = self.neigh1(x, neigh_edge_index)
        x = self.neigh2(x, neigh_edge_index)

        ### magnification block # 10:0, 20:1, 40:2
        n = x.shape[0] // 3
        # Built on the input's device instead of hard-coding `.cuda()`.
        mag_edge_index = self._magnification_edges(n, x.device)

        x = self.mag(x, mag_edge_index)
        x = self.norm_neigh3(x, batch)
        x = F.relu(x)
        ### neighbor block
        x = self.neigh3(x, neigh_edge_index)
        x = self.norm_neigh4(x, batch)
        x = F.relu(x)
        x = self.neigh4(x, neigh_edge_index)

        # Collapse every triplet into one row; keep one batch id per triplet.
        batch = batch[0::3]
        x = torch.cat((x[0::3], x[1::3], x[2::3]), dim=1)

        ### aggregation
        x = self.pool(x, batch=batch)

        return x

    def prediction(self, data):
        """Return the arg-max class index of every graph in ``data``."""
        return torch.argmax(self.forward(data), dim=1)

In [None]:
# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# The relational model is used for heterogeneous graphs only; its input
# width is taken from the '0' node store of the first training graph.
if magnification != 'heterogeneous':
    model = simple_GNN(dataset_train.num_node_features, num_classes).to(device)
else:
    model = Rel_GNN(dataset_train[0][0]['0'].num_node_features, num_classes).to(device)

# Report the number of trainable scalars, then the architecture.
model_parameters = (p for p in model.parameters() if p.requires_grad)
params = sum(np.prod(p.size()) for p in model_parameters)
print(params)
print(model)

# Training

In [None]:
num_epochs = 100
# prepare plotting: training loss | validation loss | validation accuracy & AUC
fig = plt.figure(figsize=(20, 5), dpi= 80, facecolor='w', edgecolor='k')
axes = fig.subplots(1,3)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# The learning rate is reduced when the validation loss stops decreasing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor = 0.3, patience=10, verbose=True)

sm = torch.nn.Softmax(dim=1)

losses = []      # one entry per gradient step
val_losses = []  # one entry per epoch
val_accs = []
val_auc = []
for epoch in range(num_epochs):
    model.train()
    for i, (data, core_name) in enumerate(loader_train):
        # Labels are read before `data.to(device)` so they stay on the CPU,
        # where the loss is computed.
        if magnification == 'heterogeneous':
            y = data['0'].y
        else:
            y = data.y
        optimizer.zero_grad()
        data.to(device)
        pred = model(data)
        loss = F.cross_entropy(pred.to('cpu'), F.one_hot(y, num_classes=num_classes).double(), weight = weights)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    model.eval()
    pred_val = np.zeros(len(dataset_val))
    pred_val_loss = torch.zeros((len(dataset_val), num_classes))
    label = torch.zeros(len(dataset_val), dtype=torch.long).detach()
    # No gradients are needed for validation; skipping the autograd graph
    # saves memory and time without changing the predictions.
    with torch.no_grad():
        for i, (data, core_name) in enumerate(loader_val):
            # Every batch except possibly the last one holds `batch_size` graphs.
            index = i*batch_size
            if magnification == 'heterogeneous':
                y = data['0'].y
            else:
                y = data.y
            label[index:index + y.shape[0]] = y
            data.to(device)
            pred = model(data)
            pred_val[index:index + y.shape[0]] = torch.argmax(pred, dim=1).cpu().detach().numpy()

            pred_val_loss[index:index + y.shape[0], :] = pred.cpu().detach()

    loss = F.cross_entropy(pred_val_loss, F.one_hot(label, num_classes=num_classes).double(), weight = weights_val)
    auc = roc_auc_score(label, sm(pred_val_loss), average ='macro', multi_class='ovr')
    label = label.numpy()
    val_acc = np.mean(label == pred_val)
    val_accs.append(val_acc)
    val_losses.append(loss.item())
    val_auc.append(auc)

    for ax in axes:
        ax.cla()
    # plot the training loss on a log plot
    axes[0].plot(losses, label='loss')
    axes[0].set_yscale('log')
    axes[0].set_title('Training loss')
    axes[0].set_xlabel('number of gradient iterations')
    axes[0].legend()

    # plot the validation loss on a log plot
    axes[1].plot(val_losses, label='loss')
    axes[1].set_yscale('log')
    axes[1].set_title('Validation loss')
    axes[1].set_xlabel('number of epochs')
    axes[1].legend()

    # plot the validation accuracy and AUC (linear scale)
    axes[2].plot(val_accs, label='val_acc')
    axes[2].plot(val_auc, label='val_auc')
    axes[2].set_title('Validation Accuracy')
    axes[2].set_xlabel('number of epochs')
    axes[2].legend()

    # clear output window and display updated figure
    display.clear_output(wait=True)
    display.display(plt.gcf())
    print(f'Epoch{epoch + 1} of {num_epochs} ({100*(epoch + 1)/num_epochs})%, val_acc = {val_acc}, val_auc = {auc}')

    ## saving the model: the checkpoint is selected on validation AUC
    save_model(auc, epoch, model, optimizer)

    scheduler.step(loss)

plt.close('all')

# Test Data

In [None]:
class testDataset(Dataset):
    """All pre-processed graphs of one fold directory, without any split.

    ``fold4`` and ``fold5`` are read from the directories of ``fold1`` and
    ``fold2`` respectively. A file is used when the part after its last
    underscore equals ``<magnification>.pt``.
    """
    def __init__(self, root, fold, magnification, transform=None, pre_transform=None, pre_filter=None):
        self.fold = fold
        if fold == 'fold4':
            directory = 'fold1'
        elif fold == 'fold5':
            directory = 'fold2'
        else:
            directory = fold
        self.fold_temp = directory
        self.magnification = magnification
        super().__init__(root + self.fold_temp + '/', transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return f'../../../../feature_extractor_6class/VPC_embeddings/{self.fold_temp}'

    @property
    def processed_file_names(self):
        wanted_suffix = f'{self.magnification}.pt'
        selected = []
        for file_name in os.listdir(self.root + '/processed'):
            if file_name.split('_')[-1] == wanted_suffix:
                selected.append(file_name)
        return selected

    def download(self):
        # Download to `self.raw_dir`.
        assert False, 'went to download'

    def process(self):
        # Graphs are expected to exist already; never rebuild them here.
        assert False, 'went to process'

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        # Returns the loaded graph and the first 16 characters of its file name.
        file_name = self.processed_file_names[idx]
        graph = torch.load(os.path.join(self.processed_dir, file_name))
        return graph, file_name[:16]
    

# Test

In [None]:
# Pick the external test set: models trained on Karolinska are tested on
# Radboud, anything else on Zurich.
if train_on == 'Karolinska':
    data_path = 'Radboud_data/'
    dataset_test = testDataset(data_path, fold, magnification)
else:
    data_path = 'Zurich_data/'
    # NOTE(review): `test_slides_zurich` is not defined in the visible part
    # of this notebook; it must be provided before taking this branch.
    dataset_test = ZurichDataset(data_path, test_slides_zurich, fold, magnification)
loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)

# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
sm = torch.nn.Softmax(dim=1)

# `map_location` lets a checkpoint saved from the GPU be restored on
# whatever device the model currently lives on (including CPU-only hosts).
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

pred_file = f'results/{fold}/ms_rgcn_pred.npy'
label_file = f'results/{fold}/ms_rgcn_label.npy'

# 'val_acc' holds the validation score the checkpoint was selected on.
print(checkpoint['val_acc'])
print(model_path)
print(pred_file)
print(label_file)

In [None]:
# Run the restored model over the test set and time the whole pass.
tic = time.time()
model.eval()
pred_test = np.zeros((len(dataset_test), num_classes))  # softmax probabilities
label = np.zeros(len(dataset_test))
# Inference only: disable autograd so no graph is built per batch.
with torch.no_grad():
    for i, (data, core_name) in enumerate(loader_test):
        # Every batch except possibly the last one holds `batch_size` graphs.
        index = i*batch_size
        if magnification == 'heterogeneous':
            y = data['0'].y
        else:
            y = data.y
        label[index:index + y.shape[0]] = y
        data.to(device)
        pred = sm(model(data).cpu())
        pred_test[index:index + y.shape[0], :] = pred.detach().numpy()
toc = time.time()
# Average wall-clock time per test graph, in milliseconds.
print(f'duration: {(toc - tic)*1000 / len(dataset_test):.2f} ms')

In [None]:
# Persist the predicted probabilities first, then the matching labels.
for _target, _array in ((pred_file, pred_test), (label_file, label)):
    with open(_target, 'wb') as f:
        np.save(f, _array)

In [None]:
# Macro-averaged one-vs-rest ROC AUC of the test predictions.
print(roc_auc_score(y_true=label, y_score=pred_test, multi_class='ovr', average='macro'))

In [None]:
# Quadratic-weighted kappa between the arg-max predictions and the labels.
score = quadratic_weighted_kappa(pred_test.argmax(axis=1), label)

print(f"kappa panda: {score}")

In [None]:
# Balanced accuracy of the arg-max predictions on the test set.
print(f'Balanced accuracy: {balanced_accuracy_score(y_true=label, y_pred=pred_test.argmax(axis=1))}')