# GNN Convergence

In [1]:
import os,sys

# Repository root = two directory levels above the notebook's working directory.
_cwd = os.getcwd()
parent = os.path.dirname(os.path.dirname(_cwd))
# Register it once so the `src.*` imports in later cells resolve.
already_on_path = parent in sys.path
if not already_on_path:
    sys.path.append(parent)

---

In [2]:
# DATA

import ast

src_folder = os.path.join(parent, "data", "processed", "BACH_TRAIN")
graph_split = os.path.join(src_folder, "graph_ind.txt")


def _parse_graph_ids(line):
    """Parse one line of ``graph_ind.txt`` into a list of ints.

    The line is expected to be a bracketed literal such as ``[3, 7, 12]``.
    ``ast.literal_eval`` accepts an empty ``[]`` and trailing commas. The previous
    "strip the brackets and split on commas" approach turned those cases into
    ``int("")`` and raised ValueError.
    """
    return [int(v) for v in ast.literal_eval(line.strip())]


# graph_ind.txt: line 1 lists the training graph ids, line 2 the validation ids.
with open(graph_split, "r") as f:
    train_ind = _parse_graph_ids(f.readline())
    val_ind = _parse_graph_ids(f.readline())
    
from src.datasets.BACH import BACH
from torch_geometric.loader.dataloader import DataLoader

# One BACH dataset per split, both read from the same processed folder in prediction mode.
train_set, val_set = (
    BACH(src_folder, ids=train_ind, pred_mode=True),
    BACH(src_folder, ids=val_ind, pred_mode=True),
)

In [3]:
# PREDICTOR

from src.predict_cancer import predict_cancer
import torch
from src.model.architectures.cancer_prediction.pred_gnn import PredGNN
from torch_geometric.transforms import Compose, KNNGraph, RandomTranslate, Distance
from torch.nn.functional import softmax

# Keyword arguments forwarded to PredGNN.load_from_checkpoint.
# NOTE(review): cell 4 records `KeyError: 'model.6.conv.lin.weight'`; presumably
# these values (e.g. LAYERS) do not match the saved checkpoint. Confirm.
gnn_voter_args = {
    "LAYERS": 8,
    "WIDTH": 4,
    "GLOBAL_POOL": "MEAN",
    "RADIUS_FUNCTION": "NONE",
    "POOL_RATIO": 1,
}
# Rewire each graph as a 6-nearest-neighbour graph and attach un-normalised
# edge distances (cat=False: they replace any existing edge attributes).
graph_trans = Compose([KNNGraph(6), Distance(norm=False, cat=False)])
gnn_voter_loc = os.path.join(parent, "model", "GNN_VOTER_RESNET.ckpt")

def predictor(device="cuda"):
    """Load the GNN voter checkpoint in eval mode and move it to ``device``.

    The checkpoint path is ``gnn_voter_loc``. The architecture keyword arguments
    are ``gnn_voter_args``. The checkpoint is re-read from disk on every call.

    Parameters
    ----------
    device : str or torch.device
        Target device. It defaults to ``"cuda"``, matching the previous
        hard-coded ``.cuda()``.

    Returns
    -------
    PredGNN
        The loaded model, in eval mode, on ``device``.
    """
    gnn_voter = PredGNN.load_from_checkpoint(gnn_voter_loc, **gnn_voter_args)
    return gnn_voter.eval().to(device)

def vote_convergence(graph):
    """Run the GNN voter on ``graph`` and return its intermediate outputs.

    ``graph`` is passed through ``graph_trans`` before inference. It is treated
    as a single graph, so the batch vector is all zeros. Inference runs under
    ``torch.no_grad()``.

    Returns
    -------
    The second element of the voter's ``(output, intermediate)`` return value.
    This is presumably one node-score tensor per layer; confirm against
    ``PredGNN.forward``.
    """
    with torch.no_grad():
        gnn_voter = predictor()
        voting_graph = graph_trans(graph).cuda()
        x = voting_graph.x
        # Every node belongs to the same graph, so every batch index is 0.
        batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        # Unpack the returned pair directly. The previous code called .squeeze()
        # on the pair itself, and a tuple has no .squeeze() method.
        _, intermediate = gnn_voter(x, voting_graph.edge_index, voting_graph.edge_attr, batch)
        return intermediate

---

In [4]:
i = 2  # index of the validation graph whose voting is animated below
# Per-layer votes plus the original connectivity/coordinates used for drawing.
sequence, edge_index, pos = (
    vote_convergence(val_set[i]),
    val_set[i].edge_index,
    val_set[i].pos,
)

KeyError: 'model.6.conv.lin.weight'

---
## Graph Visualization

In [None]:
import networkx as nx
from torch_geometric.utils import to_networkx
from src.transforms.graph_augmentation.largest_component import LargestComponent
from torch_geometric.transforms import KNNGraph,Compose,Distance, RandomTranslate
import matplotlib.pyplot as plt

def viz_graph_nx(graph):
    
#"b", "is", "iv", "n"
    c = ["white","green","red","black","blue"]
    colours = list(map(lambda x: c[x],((graph.x.argmax(dim=1)+1)*(graph.x.max(dim=1).values>0.4).int()).tolist()))

    G = to_networkx(graph,to_undirected=True)
    pos = {i:tuple(graph.pos[i]) for i in range(len(graph.pos))}


    nx.draw(G,pos=pos,node_color=colours)

In [None]:
import matplotlib.pyplot as plt

import imageio
import io
import matplotlib.pyplot as plt

from src.utilities.img_utilities import tensor_to_numpy
from src.transforms.graph_construction.hover_maps import hover_map
from tqdm import tqdm
import os
from numpy.ma import masked_where
from src.transforms.graph_construction.hovernet_post_processing import hovernet_post_process
from src.transforms.graph_construction.percolation import hollow
from torch_geometric.data import Data


def graph_vote_convergence_animation(intermediate_x, edge_index, pos, location, fps=10, figsize=(40, 40)):
    """Render one graph frame per entry of ``intermediate_x`` and save them as a GIF.

    Parameters
    ----------
    intermediate_x : sequence of node-score tensors
        Frame ``k`` colours the nodes using ``intermediate_x[k]``. This is
        presumably one entry per GNN layer, as returned by ``vote_convergence``;
        confirm.
    edge_index, pos : tensors
        Connectivity and node coordinates, shared by every frame.
    location : str
        Output path of the GIF.
    fps : int
        Frames per second of the GIF.
    figsize : tuple
        Matplotlib figure size per frame, in inches. The default keeps the
        original 40x40.
    """
    with imageio.get_writer(location, mode='I', fps=fps, format="gif") as writer:
        for frame in tqdm(range(len(intermediate_x)), desc="Generating Voting Graph GIF"):
            fig, ax = plt.subplots(1, 1, figsize=figsize)
            try:
                graph = Data(x=intermediate_x[frame], edge_index=edge_index, pos=pos)
                # plt.subplots made `ax` current, so viz_graph_nx draws onto it.
                viz_graph_nx(graph)
                ax.set_title(f"{frame}")
                with io.BytesIO() as buf:
                    fig.savefig(buf, format='png')
                    # savefig leaves the cursor at the end; rewind before reading back.
                    buf.seek(0)
                    writer.append_data(imageio.imread(buf))
            finally:
                # Without this, every large frame figure stays open, leaks memory,
                # and is dumped inline when the cell finishes.
                plt.close(fig)

In [None]:
# Write the per-layer voting animation next to the repository root.
graph_vote_convergence_animation(
    sequence,
    edge_index=edge_index,
    pos=pos,
    location=os.path.join(parent, "voting_graph.gif"),
    fps=3,
)