In [None]:
import numpy as np
import matplotlib.pyplot as plot
from math import ceil
from scanpath_graph import ScanPathGraph
from iou_graph import IOUGraph
from euclidean_graph import EuclideanGraph
import torch

In [None]:
from metadata import Metadata

# Input data locations (relative to this notebook).
reflacx_dir = "../data/reflacx"
mimic_dir = "../data/mimic/reflacx_imgs"
# Combined metadata cache: if file doesn't exist, it will be created.
full_meta_path = '../reflacx_lib/full_meta.json'

metadata = Metadata(reflacx_dir, mimic_dir, full_meta_path)

In [None]:
# The single (dicom_id, reflacx_id) pair explored in the cells below.
dicom_id, reflacx_id = (
    '0658ad3c-b4f77a56-2ed1609f-ea71a443-d847a975',
    'P109R167865',
)

In [None]:
# Fetch the sample for the chosen (dicom_id, reflacx_id) pair; every graph below is built from it.
sample = metadata.get_sample(dicom_id, reflacx_id)

In [None]:
# Precomputed feature averages stored as a NumPy array (presumably the mean
# DenseNet features over REFLACX, per the file name — TODO confirm).
mean_features_array = np.load('avg_DensNet_REFLACX_features.npy')
# The graph classes receive this as `mean_features`, so expose it as a torch tensor.
mean_features = torch.from_numpy(mean_features_array)

In [None]:
# Scan-path graph for the selected sample, using the mean features loaded above.
g_scan_path = ScanPathGraph(dicom_id,
                            reflacx_id,
                            reflacx_sample=sample,
                            mean_features=mean_features)

In [None]:
# `img` is indexed as [y, x] and shown with imshow in the next cell;
# `bb` provides the xmin/xmax/ymin/ymax keys used to crop it.
img, bb = sample.draw_fixations(), sample.get_chest_bounding_box()
bb

In [None]:
# Show the fixation image cropped to the chest bounding box.
# The +1 treats bb['ymax'] / bb['xmax'] as inclusive bounds.
chest_crop = img[bb['ymin']: bb['ymax'] + 1, bb['xmin']: bb['xmax'] + 1]
fig, ax = plot.subplots()
ax.imshow(chest_crop)
ax.set_title(f'Fixations (chest crop): {dicom_id} / {reflacx_id}')
ax.set_axis_off()
plot.show()

In [None]:
# Text summary of the scan-path graph; print() already applies str().
print(g_scan_path)

In [None]:
# NOTE(review): repeats the exact calls of the earlier fixation/bounding-box cell;
# kept so this section can be re-run on its own, otherwise redundant.
img = sample.draw_fixations()
bb = sample.get_chest_bounding_box()


In [None]:
# Build the three graph variants for the same sample and draw them for comparison.
# Each variable holds the graph type its name refers to
# (g_euc -> EuclideanGraph, g_iou -> IOUGraph).
graph_kwargs = dict(reflacx_sample=sample, mean_features=mean_features)

g_scan_path = ScanPathGraph(dicom_id, reflacx_id, **graph_kwargs)
g_euc = EuclideanGraph(dicom_id, reflacx_id, **graph_kwargs)
g_iou = IOUGraph(dicom_id, reflacx_id, **graph_kwargs)

g_scan_path.draw()
g_euc.draw()
g_iou.draw()


In [None]:
# Inspect the features attached to the first node of the scan-path graph.
first_node = g_scan_path.nodes[0]
first_node.features

In [None]:
from feature_extraction.dense_feature_extraction import DenseFeatureExtractor
from gaze_tracking_graph import GazeTrackingGraph
import os

In [None]:
class ReflacxGraphDataset:
    """One graph per (dicom_id, reflacx_id) pair listed in ``metadata``.

    Parameters
    ----------
    name : str
        Identifier stored as ``self.name``.
    metadata : Metadata
        Must provide ``list_dicom_ids()``, ``list_reflacx_id(dicom_id)`` and
        ``get_sample(dicom_id, reflacx_id)``; it is also forwarded to every graph.
    graph_class : type
        Constructor called once per pair (default ``GazeTrackingGraph``).
    stdevs : int or float
        Forwarded unchanged to ``graph_class``.
    feature_extractor : object or None
        Forwarded to ``graph_class`` and used to compute mean features when needed.
        ``None`` (default) builds a new ``DenseFeatureExtractor`` for this dataset.
    mean_normalize_features : bool
        If True, mean features are loaded or computed and passed to every graph;
        if False, graphs receive ``mean_features=None``.
    mean_features_fpath : str or None
        ``None``: compute mean features without a file.
        Missing path: compute them, passing the path as ``fname`` to the extractor.
        Existing ``.pt`` file: ``torch.load``. Any other existing file: ``np.load``.

    Attributes
    ----------
    name : str
    graphs : list
        Graphs in the iteration order of ``metadata``.
    """

    def __init__(self,
                 name,
                 metadata,
                 graph_class=GazeTrackingGraph,
                 stdevs=1,
                 feature_extractor=None,
                 mean_normalize_features=True,
                 mean_features_fpath=None):
        self.name = name
        # Construct the default extractor here rather than in the signature: a default
        # expression is evaluated once at class definition, which would share a single
        # extractor across every dataset instance.
        if feature_extractor is None:
            feature_extractor = DenseFeatureExtractor()

        mean_features = None
        if mean_normalize_features:
            mean_features = self._resolve_mean_features(metadata,
                                                        feature_extractor,
                                                        mean_features_fpath)

        self.graphs = []
        for dicom_id in metadata.list_dicom_ids():
            for reflacx_id in metadata.list_reflacx_id(dicom_id):
                g = graph_class(dicom_id,
                                reflacx_id,
                                reflacx_sample=metadata.get_sample(dicom_id, reflacx_id),
                                metadata=metadata,
                                stdevs=stdevs,
                                feature_extractor=feature_extractor,
                                mean_features=mean_features)
                self.graphs.append(g)

    @staticmethod
    def _resolve_mean_features(metadata, feature_extractor, mean_features_fpath):
        """Return mean features, loading them from ``mean_features_fpath`` when it exists."""
        if mean_features_fpath is None:
            return feature_extractor.get_reflacx_avg_features(metadata)
        if not os.path.exists(mean_features_fpath):
            # The path is forwarded as `fname`; presumably the extractor caches the
            # result there for later runs — TODO confirm.
            return feature_extractor.get_reflacx_avg_features(metadata,
                                                              fname=mean_features_fpath)
        # splitext only looks at the final path component, so a dotless path or a
        # dot inside a directory name is not mistaken for a '.pt' extension.
        if os.path.splitext(mean_features_fpath)[1].lower() == '.pt':
            # torch.load unpickles the file; only load trusted files.
            return torch.load(mean_features_fpath)
        return torch.from_numpy(np.load(mean_features_fpath))

    def save(self, dst_dir):
        """Persist the dataset to ``dst_dir`` (not implemented yet).

        Raises
        ------
        NotImplementedError
            Always; a silent no-op would let callers believe the data was saved.
        """
        raise NotImplementedError('ReflacxGraphDataset.save is not implemented yet')
            

In [None]:
# get_sample takes the (dicom_id, reflacx_id) pair, as in the sample-loading cell;
# the previous bare call had no arguments and would presumably fail on Run-All.
metadata.get_sample(dicom_id, reflacx_id)