### Tools for debugging graphs in reflacx and in DGL

In [2]:
from iou_graph import IOUGraph
from scanpath_graph import ScanPathGraph
from euclidean_graph import EuclideanGraph
import matplotlib.pyplot as plt
import cv2
import dgl
import torch
import networkx as nx
import os
import numpy as np
from feature_extraction.dense_feature_extraction import DenseFeatureExtractor

from metadata import Metadata

In [3]:
class GraphPair:
    """Bundle a REFLACX-side graph with its DGL counterpart for comparison.

    The drawing and IoU methods are placeholders: they currently do nothing
    and return ``None``.
    """

    def __init__(self, reflacx_graph, dgl_graph):
        """Store both sides of the pair.

        Args:
            reflacx_graph: An iterable of exactly two items. The first is
                kept as ``self.reflacx_graph`` and the second as
                ``self.labels``.
            dgl_graph: Stored unchanged as ``self.dgl_graph``.

        Raises:
            ValueError: If ``reflacx_graph`` does not unpack into two items.
        """
        # NOTE(review): GraphDebugger passes the object returned by
        # ``graph_class(...)`` as ``reflacx_graph`` and a DGL dataset item as
        # ``dgl_graph``. Confirm that the former really unpacks into
        # (graph, labels) -- the unpack may have been meant for ``dgl_graph``.
        graph, labels = reflacx_graph
        self.reflacx_graph = graph
        self.labels = labels
        self.dgl_graph = dgl_graph

    def draw_dgl(self):
        """Placeholder for drawing the DGL graph; not implemented yet."""
        return None

    def draw_reflacx(self):
        """Placeholder for drawing the REFLACX graph; not implemented yet."""
        return None

    def dgl_iou(self):
        """Placeholder for the DGL-side IoU; currently returns ``None``."""
        return None

    def reflacx_iou(self):
        """Placeholder for the REFLACX-side IoU; currently returns ``None``."""
        return None

    def get_ious(self):
        """Return both IoU results in a dict keyed by ``'dgl'`` and ``'reflacx'``."""
        dgl_score = self.dgl_iou()
        reflacx_score = self.reflacx_iou()
        return {'dgl': dgl_score, 'reflacx': reflacx_score}

In [None]:
class GraphDebugger:
    """Fetch a REFLACX graph together with the matching DGL dataset entry.

    An ``index.csv`` file inside the DGL dataset directory maps each dataset
    position to a ``(dicom_id, reflacx_id)`` pair. That mapping is used to
    look samples up in either direction and return them as a ``GraphPair``.
    """

    def __init__(self,
                 dgl_dataset_dir,
                 graph_class,
                 extractor=None,
                 avg_fpath='./avg_DensNet_REFLACX_features.npy',
                 reflacx_dir='../reflacx_lib/full_meta.json',
                 mimic_dir="../data/reflacx",
                 full_meta_path="../data/mimic/reflacx_imgs"):
        """Load the REFLACX metadata, the DGL dataset and the index file.

        Args:
            dgl_dataset_dir: Directory handed to ``dgl.data.CSVDataset``; it
                must also contain ``index.csv``.
            graph_class: Callable used to build the REFLACX-side graph.
            extractor: Feature extractor exposing
                ``get_reflacx_img_features``. When ``None`` a fresh
                ``DenseFeatureExtractor()`` is created, so no extractor is
                built at class-definition time or shared between instances.
            avg_fpath: Path of a ``.npy`` file loaded with ``np.load`` and
                passed to ``graph_class`` as ``mean_features``; ``None``
                skips loading.
            reflacx_dir: First positional argument forwarded to ``Metadata``.
            mimic_dir: Second positional argument forwarded to ``Metadata``.
            full_meta_path: Third positional argument forwarded to
                ``Metadata``.
        """
        # NOTE(review): the default values look attached to the wrong names
        # (``reflacx_dir`` defaults to a .json path, ``full_meta_path`` to an
        # image directory). They are kept as-is; confirm against Metadata's
        # signature.
        self.reflacx = Metadata(reflacx_dir,
                                mimic_dir,
                                full_meta_path,
                                max_dicom_lib_ram_percent=20)
        self.dataset = dgl.data.CSVDataset(dgl_dataset_dir)
        self.graph_class = graph_class
        self.extractor = (extractor if extractor is not None
                          else DenseFeatureExtractor())
        self.index = self._read_index(dgl_dataset_dir)
        # Built once so REFLACX-id lookups do not rescan the index per call.
        self._reverse_index = {ids: i for i, ids in self.index.items()}

        self.mean_feats = np.load(avg_fpath) if avg_fpath is not None else None

    @staticmethod
    def _read_index(dgl_dataset_dir):
        """Parse ``index.csv`` into ``{position: (dicom_id, reflacx_id)}``.

        The first line is treated as a header and skipped; blank lines are
        ignored. Every other line must hold exactly three comma-separated
        fields, the first being the integer dataset position.
        """
        index = {}
        with open(os.path.join(dgl_dataset_dir, 'index.csv')) as f:
            next(f, None)  # header row
            for line in f:
                line = line.strip()
                if not line:
                    continue
                i, did, rid = line.split(',')
                # Integer keys so the same value can index ``self.dataset``;
                # stripping removes the newline that would otherwise stay
                # glued to the last field.
                index[int(i)] = (did.strip(), rid.strip())
        return index

    def _build_reflacx_graph(self, did, rid, sample):
        """Extract image features for ``sample`` and build its REFLACX graph."""
        features = self.extractor.get_reflacx_img_features(sample,
                                                           to_numpy=True)
        return self.graph_class(did,
                                rid,
                                sample,
                                metadata=self.reflacx,
                                stdevs=1,
                                feature_extractor=self.extractor,
                                img_features=features,
                                mean_features=self.mean_feats)

    def fetch_by_dgl_index(self, i):
        """Return the ``GraphPair`` for position ``i`` of the DGL dataset.

        Args:
            i: Dataset position; anything ``int()`` accepts.

        Raises:
            KeyError: If ``i`` is not listed in ``index.csv``.
        """
        i = int(i)
        did, rid = self.index[i]
        sample = self.reflacx.get_sample(did, rid)
        reflacx_g = self._build_reflacx_graph(did, rid, sample)

        return GraphPair(reflacx_g, self.dataset[i])

    def fetch_by_reflacx(self, did=None, rid=None, sample=None):
        """Return the ``GraphPair`` identified by REFLACX ids or by a sample.

        Either ``sample`` or both ``did`` and ``rid`` must be given. When a
        sample is given, any missing id is read from ``sample.dicom_id`` /
        ``sample.reflacx_id``.

        Raises:
            ValueError: If neither a sample nor both ids are supplied.
            KeyError: If the ``(did, rid)`` pair is not in ``index.csv``.
        """
        if sample is None:
            if did is None or rid is None:
                raise ValueError(
                    'provide either `sample` or both `did` and `rid`')
        else:
            if did is None:
                did = sample.dicom_id
            if rid is None:
                rid = sample.reflacx_id

        # Resolve the dataset position first so an unknown pair fails before
        # the comparatively expensive feature extraction runs.
        try:
            i = self._reverse_index[(did, rid)]
        except KeyError:
            raise KeyError(
                f'({did!r}, {rid!r}) is not listed in index.csv') from None

        if sample is None:
            sample = self.reflacx.get_sample(did, rid)

        reflacx_g = self._build_reflacx_graph(did, rid, sample)

        return GraphPair(reflacx_g, self.dataset[i])