In [1]:
import numpy as np
import matplotlib.pyplot as plot
from math import ceil
from scanpath_graph import ScanPathGraph
from iou_graph import IOUGraph
from euclidean_graph import EuclideanGraph
import torch
from feature_extraction.dense_feature_extraction import DenseFeatureExtractor
from gaze_tracking_graph import GazeTrackingGraph
import os

In [2]:
from metadata import Metadata

# Dataset locations, relative to the notebook's working directory.
reflacx_dir = "../data/reflacx"
mimic_dir = "../data/mimic/reflacx_imgs"
# if file doesn't exist, it will be created
full_meta_path = '../reflacx_lib/full_meta.json'

# Shared metadata handle used by every cell below.
metadata = Metadata(
    reflacx_dir,
    mimic_dir,
    full_meta_path,
    max_dicom_lib_ram_percent=30,
)

loading metadata
metadata loaded from file


In [None]:
# Module-level accumulator: generate_dataset appends a (dicom_id, reflacx_id)
# tuple here for every pair whose graph could not be built or written.
errors = []
from rlogger import RLogger
import os
from fixation_node import FixationNode

def generate_dataset(name,
                     metadata,
                     outdir=None,
                     filenames={'meta': 'meta.yaml',
                                'edges': 'edges.csv',
                                'nodes': 'nodes.csv',
                                'graphs': 'graphs.csv'
                                },
                     g_id = 'graph_id',
                     sep=', ',
                     graph_class=GazeTrackingGraph,
                     stdevs=1,
                     feature_extractor=DenseFeatureExtractor(),
                     mean_normalize_features=True,
                     mean_features_fpath=None,
                     log_dir='.'):
    """Write a graph dataset (meta file + node/edge/graph CSVs) to disk.

    One graph is built per (dicom_id, reflacx_id) pair listed by `metadata`.
    Every CSV row is prefixed with the integer id of the graph it belongs to.
    Pairs whose graph cannot be built or written are logged and appended to
    the module-level `errors` list; processing then continues with the next
    pair.

    Parameters
    ----------
    name : str
        Dataset name written to the meta file; also the default output
        directory (`./<name>`).
    metadata :
        Object providing `list_dicom_ids`, `list_reflacx_ids` and
        `get_sample`.
    outdir : str or None
        Output directory, created if missing. Defaults to `./<name>`.
    filenames : dict
        File names for the 'meta', 'edges', 'nodes' and 'graphs' outputs.
        The default dict is shared between calls; it is only read here.
    g_id : str
        Header of the graph-id column prepended to every CSV.
    sep : str
        Column separator used between the graph id and the rest of a row.
    graph_class :
        Graph type to instantiate; must expose `edge_csv_header()`,
        `graph_csv(labels=...)`, `write_nodes_csv` and `write_edges_csv`.
    stdevs :
        Forwarded unchanged to `graph_class`.
    feature_extractor :
        Forwarded to `graph_class` and used to compute mean features.
        NOTE: the default instance is created once, when this function is
        defined, and is shared by every call that does not pass its own.
    mean_normalize_features : bool
        When False, `mean_features=None` is passed to `graph_class`.
    mean_features_fpath : str or None
        None: compute mean features from `metadata`.
        Missing file: compute them, passing the path as `fname`.
        Existing `.pt` file: `torch.load`. Any other existing file: `np.load`.
    log_dir : str
        Directory for the per-pair log files; created if missing.

    Returns
    -------
    None
    """
    log = RLogger(__name__)

    outdir = './{}'.format(name) if outdir is None else outdir
    os.makedirs(outdir, exist_ok=True)
    if log_dir:
        # presumably RLogger.start opens a file at the given path, which
        # needs its parent directory to exist -- harmless if it does not.
        os.makedirs(log_dir, exist_ok=True)

    mean_features = None
    if mean_normalize_features:
        if mean_features_fpath is None:
            mean_features = feature_extractor.get_reflacx_avg_features(metadata)
        elif not os.path.exists(mean_features_fpath):
            mean_features = feature_extractor.get_reflacx_avg_features(
                metadata, fname=mean_features_fpath)
        elif mean_features_fpath.endswith('.pt'):
            mean_features = torch.load(mean_features_fpath)
        else: # .npy
            mean_features = torch.from_numpy(np.load(mean_features_fpath))

    # `file_name` must be nested under `graph_data`; at column 0 YAML parses
    # it as an unrelated top-level key and `graph_data` ends up null.
    # (Layout presumably follows DGL's CSVDataset meta.yaml -- confirm.)
    meta_lines = ['dataset_name: {}'.format(name),
                  'edge_data:',
                  '- file_name: {}'.format(filenames['edges']),
                  'node_data:',
                  '- file_name: {}'.format(filenames['nodes']),
                  'graph_data:',
                  '  file_name: {}'.format(filenames['graphs'])]
    with open(os.path.join(outdir, filenames['meta']), 'w') as f:
        f.write('\n'.join(meta_lines))

    def csv_line(prefix, line):
        # str(): graph ids are ints and str.join only accepts strings.
        return sep.join([str(prefix), line]) + '\n'

    # Context managers guarantee the three CSVs are flushed and closed even
    # if the run is interrupted or an unexpected error escapes the loop.
    with open(os.path.join(outdir, filenames['edges']), 'w') as e_csv, \
         open(os.path.join(outdir, filenames['nodes']), 'w') as n_csv, \
         open(os.path.join(outdir, filenames['graphs']), 'w') as g_csv:

        n_csv.write(csv_line(g_id, FixationNode.csv_header()))
        e_csv.write(csv_line(g_id, graph_class.edge_csv_header()))
        g_csv.write(csv_line(g_id, 'labels'))

        next_graph_id = 0
        for dicom_id in metadata.list_dicom_ids(n_samples=10): #TODO remove after debug
            for reflacx_id in metadata.list_reflacx_ids(dicom_id):
                RLogger.start(os.path.join(log_dir,
                                           '{}__{}.log'.format(dicom_id, reflacx_id)))
                try:
                    g = graph_class(dicom_id,
                                    reflacx_id,
                                    reflacx_sample=metadata.get_sample(dicom_id, reflacx_id),
                                    metadata=metadata,
                                    stdevs=stdevs,
                                    feature_extractor=feature_extractor,
                                    mean_features=mean_features)

                    # Consume an id only once the graph exists, and consume
                    # it before the first write: ids stay dense when
                    # construction fails and are never reused if a write
                    # fails half-way through.
                    graph_id = next_graph_id
                    next_graph_id += 1
                    curr_line = lambda line, gid=graph_id: csv_line(gid, line)

                    g_csv.write(curr_line(g.graph_csv(labels='common')))
                    g.write_nodes_csv(n_csv, curr_line)
                    g.write_edges_csv(e_csv, curr_line)
                except Exception:
                    # Best-effort: remember the failing pair and move on.
                    # KeyboardInterrupt/SystemExit are deliberately not caught.
                    errors.append((dicom_id, reflacx_id))
                    log('bad graph for pair {} --- {}'.format(dicom_id,
                                                                reflacx_id),
                        exception=True)

        

In [3]:
# Fetch a single (dicom_id, reflacx_id) sample and display its image.
x = metadata.get_sample(
    '34cedb74-d0996b40-6d218312-a9174bea-d48dc033',
    'P102R108387',
)
img = x.get_dicom_img()
img

In [None]:
# Build a scan-path graph for the same pair, passing its sample positionally.
g = ScanPathGraph(
    '34cedb74-d0996b40-6d218312-a9174bea-d48dc033',
    'P102R108387',
    metadata.get_sample('34cedb74-d0996b40-6d218312-a9174bea-d48dc033',
                        'P102R108387'),
)

In [None]:
# Export a dataset named 'test' built from ScanPathGraph instances.
# The meta file and the three CSVs are written under ./dataset; the mean
# features are read from (or, if the file is missing, computed with) the
# given .npy path; per-pair log paths are placed under ./log.
generate_dataset('test',
                 metadata,
                 outdir='./dataset',
                 graph_class=ScanPathGraph,
                 mean_features_fpath='avg_DensNet_REFLACX_features.npy',
                 log_dir='./log')

In [None]:
# NOTE(review): ReflacxGraphDataset is neither imported nor defined in any
# cell visible here, so this raises NameError on a fresh kernel -- add the
# import to the top import cell once its module is confirmed.
ds = ReflacxGraphDataset('',
                         metadata,
                         ScanPathGraph,
                         mean_features_fpath='avg_DensNet_REFLACX_features.npy')

In [None]:
import logging

# Route records of level DEBUG and above to example.log, then keep a handle
# on the root logger. (basicConfig is a no-op if the root logger already has
# handlers.)
logging.basicConfig(level=logging.DEBUG, filename='example.log')
logger = logging.getLogger()

In [None]:
# Hard-coded (dicom_id, reflacx_id) pairs to investigate; this rebinds the
# global `errors` list (presumably copied from an earlier generate_dataset
# run -- confirm).
errors = [
    ('9678dc02-54a05e84-f5efffa5-bc62e0a2-83dac014', 'P300R050750'),
    ('7a1165df-fc1f2e7a-f901fa11-d73f3ee4-91abd6ae', 'P300R430029'),
    ('767626c8-a068ea4b-578a5042-7bbdaec4-efc21ef2', 'P300R038991'),
]

In [None]:
# Re-load the first failing pair so it can be inspected in the cells below.
# (The leftover debug print that used to follow was removed.)
did, rid = errors[0]
sample = metadata.get_sample(did, rid)

In [None]:
# `mean_features` is otherwise only a local variable inside generate_dataset,
# so it has to be materialised here for this cell to run on a fresh kernel.
# Load it the same way generate_dataset does for an existing .npy file; the
# file is the one passed as mean_features_fpath in the export cell above.
mean_features = torch.from_numpy(np.load('avg_DensNet_REFLACX_features.npy'))

# Rebuild the graph of the first failing pair to reproduce the error.
g = ScanPathGraph(errors[0][0],
                  errors[0][1],
                  metadata=metadata,
                  mean_features=mean_features)

In [None]:
import inspect

# Name of the code object currently executing: FrameInfo's 4th field
# (index 3) is `function`.
inspect.stack()[0].function

In [None]:
# Switch to the second failing pair; note this rebinds the global `sample`
# that earlier cells set to the first pair.
sample = metadata.get_sample(*errors[1])
fixations = sample.get_fixations()

In [None]:
# Fresh extractor and the feature map of the current sample's image,
# returned as a numpy array (to_numpy=True).
# NOTE(review): `mean_features` must already exist as a global here; it is
# not defined by this cell.
extractor = DenseFeatureExtractor()
img_features = extractor.get_img_features(sample.get_dicom_img(),
                                          mean_features=mean_features,
                                          to_numpy=True)
# Fixation picked for inspection; index 42 is hard-coded.
fix = fixations[42]
# Show the fixation record next to the image shape.
fix, sample.get_dicom_img().shape

In [None]:
# Transform the fixation using its per-axis angular resolution and the image
# size; the result is unpacked as (position, crop).
trans_fix = extractor.transform_fixation(
    (fix['x_position'], fix['y_position']),
    fix['angular_resolution_x_pixels_per_degree'],
    fix['angular_resolution_y_pixels_per_degree'],
    img_size=sample.get_dicom_img().shape,
    normalize=True,
)
fixation_pos, fixation_crop = trans_fix

# The crop is used as (top-left, bottom-right); derive the other two corners.
tl, br = fixation_crop[0], fixation_crop[1]
tr, bl = (br[0], tl[1]), (tl[0], br[1])


def adjustpos(point):
    """Scale a point: x by img_features.shape[2], y by img_features.shape[1]."""
    return (point[0] * img_features.shape[2],
            point[1] * img_features.shape[1])


# Rescale all four corners in one go.
tl, bl, tr, br = (adjustpos(corner) for corner in (tl, bl, tr, br))

(fixation_pos, fixation_crop), (tl, bl, tr, br)

In [None]:
from math import ceil, floor

In [None]:
# How many whole feature-map cells the scaled crop spans along x and along y.
h_region_count, v_region_count = (
    ceil(tr[0]) - floor(tl[0]),
    ceil(bl[1]) - floor(tl[1]),
)

h_region_count, v_region_count

In [None]:
# Cell indices spanned by the scaled crop: first along x, then along y.
list(range(floor(tl[0]), ceil(tr[0]))), list(range(floor(tl[1]), ceil(bl[1])))

In [None]:
# Area of the scaled fixation crop; each cell's overlap is divided by it so
# the weights are fractions of the crop.
crop_area = (tr[0] - tl[0]) * (bl[1] - tl[1])

# calculate intersection between fixation crop
# and each of the feature regions
result = np.zeros(img_features.shape[0], dtype=img_features.dtype)

# adjustpos scaled x by img_features.shape[2] and y by img_features.shape[1],
# so x indexes axis 2 (columns) and y indexes axis 1 (rows).
n_rows, n_cols = img_features.shape[1], img_features.shape[2]

# floor/ceil give the cells touched by the crop. Clamp to the feature map so
# a crop reaching outside it neither wraps around through negative indices
# nor raises IndexError; the part outside simply contributes nothing.
first_col = max(floor(tl[0]), 0)
last_col = min(ceil(br[0]), n_cols)
first_row = max(floor(tl[1]), 0)
last_row = min(ceil(br[1]), n_rows)

for col in range(first_col, last_col):
    for row in range(first_row, last_row):
        # Overlap rectangle between the crop and the unit cell (col, row).
        xmin = max(tl[0], col)
        ymin = max(tl[1], row)
        xmax = min(br[0], col + 1)
        ymax = min(br[1], row + 1)

        coef = (xmax - xmin) * (ymax - ymin) / crop_area
        # Row (y) first, column (x) second -- matches the axis convention
        # used by adjustpos.
        result = result + img_features[:, row, col] * coef

In [None]:
# Same fixation through the extractor's own implementation, for comparison
# with the manual computation above.
# NOTE(review): img_size=None presumably makes the extractor fall back to the
# size it recorded earlier (see extractor.last_img_size) -- confirm.
ff = extractor.get_fixation_features((fix['x_position'], fix['y_position']),
                                     fix['angular_resolution_x_pixels_per_degree'],
                                     fix['angular_resolution_y_pixels_per_degree'],
                                     img_features=img_features,
                                     mean_features=mean_features,
                                     img_size=None)  # alternative tried: sample.get_dicom_img().shape

In [None]:
# Compare the size stored on the extractor with the actual image shape.
extractor.last_img_size, sample.get_dicom_img().shape