### Tools for debugging graphs in reflacx and in DGL

In [None]:
from iou_graph import IOUGraph
from scanpath_graph import ScanPathGraph
from euclidean_graph import EuclideanGraph
import matplotlib.pyplot as plt
import cv2
import dgl
import torch
import networkx as nx
import os
import numpy as np
from feature_extraction.dense_feature_extraction import DenseFeatureExtractor

from metadata import Metadata

In [None]:
class GraphPair:
    """Hold a REFLACX-side graph next to its DGL counterpart for comparison.

    The two ``*_ious`` methods return dictionaries keyed by ordered node-index
    pairs ``(i, j)`` so that the values can be compared entry by entry.
    """

    def __init__(self, reflacx_graph, dgl_graph):
        """Store both graphs.

        Args:
            reflacx_graph: object exposing a ``nodes`` iterable; each node must
                provide ``topleft`` and ``bottomright`` ``(x, y)`` pairs.
            dgl_graph: a 2-item sequence ``(graph, labels)``.
        """
        self.reflacx_graph = reflacx_graph
        self.dgl_graph, self.dgl_labels = dgl_graph

    def draw_dgl(self):
        """Placeholder: drawing the DGL graph is not implemented yet."""
        pass

    def draw_reflacx(self):
        """Placeholder: drawing the REFLACX graph is not implemented yet."""
        pass

    def dgl_ious(self, field='weight'): # TODO review field
        """Read the ``field`` edge feature for every ordered pair of DGL nodes.

        Args:
            field: key into ``dgl_graph.edata`` holding the per-edge value.

        Returns:
            dict mapping ``(i, j)`` to ``np.float32``; pairs for which
            ``edge_ids`` raises ``dgl.DGLError`` map to ``0.0``.
        """
        # Query the node ids once instead of once per outer iteration.
        node_ids = [int(n) for n in self.dgl_graph.nodes()]
        result = {}
        for i in node_ids:
            for j in node_ids:
                try:
                    e_i = self.dgl_graph.edge_ids(i, j)
                except dgl.DGLError:
                    result[(i, j)] = 0.0
                    continue
                result[(i, j)] = np.float32(self.dgl_graph.edata[field][e_i])

        return result

    def reflacx_ious(self, canvas_sz=500):
        """Compute pixel-level IoU for every ordered pair of REFLACX nodes.

        Each node's box is scaled by ``canvas_sz``, truncated to integer pixel
        coordinates and rasterised onto a ``canvas_sz`` x ``canvas_sz`` canvas.
        Both corners are inclusive and pixels outside the canvas are dropped,
        which is what a filled ``cv2.rectangle`` produces.

        Args:
            canvas_sz: side length, in pixels, of the square canvas.

        Returns:
            dict mapping ``(i, j)`` to a float IoU. ``(i, i)`` is always
            ``1.0``; a pair whose union is empty maps to ``0.0``.
        """
        def clamp(value):
            return min(max(value, 0), canvas_sz)

        def rasterize(node):
            tlx, tly = node.topleft
            brx, bry = node.bottomright
            # Sort so that swapped corners still describe the same rectangle.
            xa, xb = sorted((int(tlx * canvas_sz), int(brx * canvas_sz)))
            ya, yb = sorted((int(tly * canvas_sz), int(bry * canvas_sz)))
            mask = np.zeros((canvas_sz, canvas_sz), dtype=bool)
            # "+ 1" makes the far corner inclusive; clamping keeps negative
            # bounds from being read as from-the-end slice indices.
            mask[clamp(ya):clamp(yb + 1), clamp(xa):clamp(xb + 1)] = True
            return mask

        # One mask per node, built once and reused for every pair
        # (costs canvas_sz**2 bytes per node).
        masks = [rasterize(node) for node in self.reflacx_graph.nodes]

        result = {}
        for i, mask_i in enumerate(masks):
            for j, mask_j in enumerate(masks):
                if i == j:
                    result[(i, j)] = 1.0
                    continue
                inter = np.count_nonzero(mask_i & mask_j)
                union = np.count_nonzero(mask_i | mask_j)
                # Two empty boxes have no union; report no overlap rather
                # than dividing by zero.
                result[(i, j)] = inter / union if union else 0.0

        return result

    def get_ious(self):
        """Return both IoU tables as ``{'dgl': ..., 'reflacx': ...}``."""
        return {'dgl': self.dgl_ious(), 'reflacx': self.reflacx_ious()}

In [None]:
class GraphDebugger:
    """Fetch matching REFLACX / DGL graphs from an exported DGL CSV dataset.

    The dataset directory must contain an ``index.csv`` whose rows, after one
    header line, are ``<dgl index>,<dicom id>,<reflacx id>``.
    """

    def __init__(self,
                 dgl_dataset_dir,
                 graph_class,
                 extractor=None,
                 avg_fpath='./avg_DensNet_REFLACX_features.npy',
                 reflacx_dir="../data/reflacx",
                 mimic_dir="../data/mimic/reflacx_imgs",
                 full_meta_path="../reflacx_lib/full_meta.json"):
        """Load the metadata, the DGL dataset and the dataset index.

        Args:
            dgl_dataset_dir: directory handed to ``dgl.data.CSVDataset``; also
                where ``index.csv`` is read from.
            graph_class: callable used to build the REFLACX-side graph.
            extractor: feature extractor; when ``None`` a
                ``DenseFeatureExtractor()`` is created here, so nothing is
                built at class-definition time or when the caller supplies
                their own.
            avg_fpath: path of a ``.npy`` file loaded into ``mean_feats``, or
                ``None`` to leave ``mean_feats`` as ``None``.
            reflacx_dir: forwarded to ``Metadata``.
            mimic_dir: forwarded to ``Metadata``.
            full_meta_path: forwarded to ``Metadata``.
        """
        self.reflacx = Metadata(reflacx_dir,
                                mimic_dir,
                                full_meta_path,
                                max_dicom_lib_ram_percent=20)
        self.dataset = dgl.data.CSVDataset(dgl_dataset_dir)
        self.graph_class = graph_class
        self.extractor = DenseFeatureExtractor() if extractor is None else extractor

        # index: dgl position -> (dicom id, reflacx id); _reverse_index is the
        # inverse, built once here instead of on every lookup.
        self.index = {}
        self._reverse_index = {}
        with open(os.path.join(dgl_dataset_dir, 'index.csv')) as f:
            next(f, None)  # skip the header row
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                i, did, rid = line.split(',')
                self.index[int(i)] = (did, rid)
                self._reverse_index[(did, rid)] = int(i)

        self.mean_feats = np.load(avg_fpath) if avg_fpath is not None else None

    def _fetch(self, i, sample):
        """Build the REFLACX graph for ``sample`` and pair it with ``dataset[i]``.

        Args:
            i: position of the matching entry in the DGL dataset.
            sample: REFLACX sample exposing ``dicom_id`` and ``reflacx_id``.

        Returns:
            GraphPair holding the freshly built graph and ``self.dataset[i]``.
        """
        features = self.extractor.get_reflacx_img_features(sample, to_numpy=True)
        reflacx_g = self.graph_class(sample.dicom_id,
                                     sample.reflacx_id,
                                     sample,
                                     metadata=self.reflacx,
                                     stdevs=1,
                                     feature_extractor=self.extractor,
                                     img_features=features,
                                     mean_features=self.mean_feats)

        return GraphPair(reflacx_g, self.dataset[i])

    def fetch_by_dgl_index(self, i):
        """Return the GraphPair for position ``i`` of the DGL dataset.

        Raises:
            AssertionError: if ``i`` is not listed in ``index.csv``.
        """
        assert i in self.index, f"dgl index {i!r} is not listed in index.csv"
        did, rid = self.index[i]
        sample = self.reflacx.get_sample(did, rid)
        return self._fetch(i, sample)

    def fetch_by_reflacx(self, did=None, rid=None, sample=None):
        """Return the GraphPair for a REFLACX sample.

        Provide either both ids or an already loaded ``sample``. When only
        ``sample`` is given, the ids used for the index lookup are taken from
        ``sample.dicom_id`` and ``sample.reflacx_id``.

        Args:
            did: dicom id of the sample.
            rid: reflacx id of the sample.
            sample: sample object; fetched from the metadata when ``None``.

        Raises:
            AssertionError: if neither ids nor sample are given, or the
                ``(did, rid)`` pair is not listed in ``index.csv``.
        """
        assert sample is not None or (did is not None and rid is not None), \
            "provide either a sample or both did and rid"
        if sample is None:
            sample = self.reflacx.get_sample(did, rid)
        else:
            # Without this the lookup below would search for (None, None).
            if did is None:
                did = sample.dicom_id
            if rid is None:
                rid = sample.reflacx_id
        key = (did, rid)
        assert key in self._reverse_index, f"{key!r} is not listed in index.csv"
        return self._fetch(self._reverse_index[key], sample)

In [None]:
dgl_dataset_dir = 'datasets/reflacx_densnet225_iou'

In [None]:
gdb = GraphDebugger(dgl_dataset_dir, IOUGraph)

In [None]:
pair = gdb.fetch_by_reflacx('1bdf3180-0209f001-967acab6-0b811ea2-3c2e13eb', 'P300R510107')

In [None]:
r_ious = pair.reflacx_ious()

In [None]:
d_ious = pair.dgl_ious()