In [1]:
import os
import pickle
import sys
from time import time

import numpy as np
import pandas as pd
import yaml

sys.path.append('..')  # make the project-local `src` package importable
from src import datagen
from src import engine
from src import utils


In [2]:
# Load the TADPOLE CSV into the global `table`. search_table() reads its
# 'PTID', 'VISCODE' and 'DX' columns, and print_counts() passes this global to it.
table = pd.read_csv('../data/TADPOLE_D1_D2_proc_norm.csv')

In [36]:
def search_table(table,pid):
    '''
    Searches the table for the baseline entry corresponding to the pid and
    checks if its diagnosis label is 'Dementia' (or 'AD').

    Parameters
    ----------
    table : pandas.DataFrame
        Must contain the columns 'PTID', 'VISCODE' and 'DX'.
    pid : str
        Patient id, compared against the 'PTID' column.

    Returns True if the DX of the baseline row (VISCODE == 'bl') is
    'Dementia' or 'AD'.
    Returns False if not, including when the patient has no baseline row
    (previously this case raised an IndexError).
    '''
    baseline_dx = table.loc[(table['PTID'] == pid) & (table['VISCODE'] == 'bl'), 'DX']
    if baseline_dx.empty:
        # No baseline visit recorded for this pid: cannot be "already Dementia".
        return False
    # Only the first matching baseline row is consulted, as before.
    return baseline_dx.iloc[0] in ('Dementia', 'AD')
    
def pid_convert(pid):
    '''
    Converts pids that are in the form of integers to strings in the proper
    format: the integer is zero-padded to 7 digits and split as
    <first 3 digits>_S_<remaining digits>.

    Example: 012368 --> 001_S_2368

             12345  --> 001_S_2345

    Parameters
    ----------
    pid : scalar with an ``.item()`` method (e.g. a numpy or torch scalar),
          or a plain Python int.

    Returns
    -------
    str : the formatted pid.
    '''
    # Accept both array/tensor scalars and plain Python numbers.
    value = pid.item() if hasattr(pid, 'item') else pid
    # Zero-padding replaces the former hand-written range checks, which sent
    # exactly 100000 down the 7-digit branch and produced '100_S_000'.
    digits = str(int(value)).zfill(7)
    return digits[:3] + '_S_' + digits[3:]

class Trajectory_Stats:
    '''
    Record of a single trajectory: the patient id, the ground-truth label,
    the predicted label and the trajectory id.
    '''
    def __init__(self, pid, y, y_pred, trajectory_id):
        '''
        pid           : patient id
        y             : ground-truth label
        y_pred        : object whose ``max(0)`` yields a (values, indices) pair;
                        only the indices part is stored as the prediction
        trajectory_id : identifier of the trajectory
        '''
        self.pid = pid
        self.y = y
        # Second element of max(0) is the index of the maximum.
        # check this to make sure it's not different from argmax
        self.y_pred = y_pred.max(0)[1]
        self.trajectory_id = trajectory_id

    def print_self(self):
        '''Print the four stored fields, one per line.'''
        fields = (
            ('PID: ', self.pid),
            ('label: ', self.y),
            ('prediction: ', self.y_pred),
            ('trajectory_id: ', self.trajectory_id),
        )
        for caption, value in fields:
            print(caption, value)

class Cascade:
    '''
    A cascade is a series of trajectories from a single patient used for tabulating results for
    AD transition detection at the patient level.

    Attributes
    ----------
    trajectories     : the trajectories passed to the constructor
    transition_true  : trajectory_id[-1] of the first trajectory with y == 2, or -1 if none
    transition_flare : trajectory_id[-1] of the first trajectory with y_pred == 2, or -1 if none
    type             : 'AD' if a true transition exists, otherwise 'No AD'
    FP               : True when type is 'No AD' but a transition was predicted
    FN               : True when type is 'AD' but no transition was predicted
    diff             : transition_flare - transition_true when both exist, else None
    '''
    def __init__(self, trajectories):
        '''
        trajectories : iterable of objects exposing ``y``, ``y_pred`` and ``trajectory_id``,
                       scanned in the given order.
        '''
        self.trajectories = trajectories

        #TODO: Check Baseline label

        # Record the index where the transition occurs in ground truth / in FLARe
        self.transition_true = self._first_transition(trajectories, 'y')
        self.transition_flare = self._first_transition(trajectories, 'y_pred')

        # Defaults, so that every attribute exists whatever the cascade type is.
        self.FP = False
        self.FN = False
        self.diff = None

        # Check for false positives. If transition true = -1, then the patient never develops AD
        if self.transition_true == -1:
            self.type = 'No AD'
            self.FP = bool(self.transition_flare != -1)
        else:
            self.type = 'AD'
            if self.transition_flare == -1:
                self.FN = True
            else:
                self.diff = self.transition_flare - self.transition_true

    @staticmethod
    def _first_transition(trajectories, label_attr):
        '''Return trajectory_id[-1] of the first trajectory whose `label_attr` equals 2, or -1.'''
        for t in trajectories:
            if getattr(t, label_attr) == 2:
                return t.trajectory_id[-1]
        return -1

    def print_self(self):
        '''Print the true and predicted transitions, the type, the difference (if any) and the trajectories.'''
        print('True Transition: ',self.transition_true)
        print('Predicted Transition: ',self.transition_flare)

        print('Trajectory Type: ', self.type)

        # diff is only available when both a true and a predicted transition exist.
        if self.diff is not None:
            print('Difference in prediction time from ground truth: ', self.diff)
        print('Trajectories', self.trajectories)
        
        
            
        

In [4]:
# Load the pickled results into `output`; a later cell indexes it as
# output[T-1] and passes that entry to gen_patient_dict().
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('../data/stats.pickle','rb') as f:
    output = pickle.load(f)

In [20]:
def gen_patient_dict(val_dict_T):
    '''
    Groups the trajectories in val_dict_T by patient.

    val_dict_T is indexed by the keys 'pid', 'y', 'y_pred' and 'trajectory_id';
    entry i of each is combined into one Trajectory_Stats object.

    Prints the number of entries, then returns a dict mapping every pid stored
    in ../data/pids.pickle to the list of its Trajectory_Stats (possibly empty).
    A converted pid that is not in that file raises KeyError.
    '''
    entry_count = len(val_dict_T['pid'])
    print(entry_count)

    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    with open('../data/pids.pickle','rb') as handle: # use ../data/patient_List_test.txt
        known_pids = pickle.load(handle)

    # initialize patient dict
    patients = {known: [] for known in known_pids}

    for idx in range(entry_count):
        patient_id = pid_convert(val_dict_T['pid'][idx])
        stats = Trajectory_Stats(
            patient_id,
            val_dict_T['y'][idx],
            val_dict_T['y_pred'][idx],
            val_dict_T['trajectory_id'][idx],
        )
        patients[patient_id].append(stats)
    return patients



In [30]:
def create_cascade(patients, T):
    '''
    Builds one Cascade per patient from a patient dict for a certain value of 'T'.

    Each patient's trajectory list is sorted in place by trajectory id. A
    trajectory is kept when:
      T == 1 : its id starts with 0 and the patient has more than 1 trajectory
      T == 2 : its id starts with 0, 1 and the patient has more than 2 trajectories
    For any other T nothing is kept, so every cascade is built from an empty list.

    Returns a dict mapping each patient key to its Cascade.
    '''
    selected = {}
    for pid, trajs in patients.items():
        # sort Trajectory objects by lexicographical order of their trajectory id's
        trajs.sort(key=lambda traj: traj.trajectory_id.tolist())
        kept = []
        selected[pid] = kept
        for traj in trajs:
            if T == 1:
                if traj.trajectory_id[0] == 0 and len(trajs) > 1:
                    kept.append(traj)
            elif T == 2:
                if traj.trajectory_id[0] == 0 and traj.trajectory_id[1] == 1 and len(trajs) > 2:
                    kept.append(traj)

    return {pid: Cascade(kept) for pid, kept in selected.items()}

In [32]:
def print_counts(cascades, T):
    '''
    Takes a dict of cascades as input and prints out counts.

    Every cascade is printed via its print_self(), then the following
    aggregates are printed:
      * a histogram of (predicted - true) transition differences, where
        index i holds the number of patients with diff == i - 3
        (so index 2 corresponds to -1);
      * false negatives, patients without AD and the false positives among them;
      * patients whose trajectory at index T-1 is labelled 2 and whose
        baseline row in the global `table` is already 'Dementia'/'AD';
      * the number of patients who developed AD and the total number of patients.

    Parameters
    ----------
    cascades : dict mapping patient id -> Cascade
    T        : int, index T-1 selects the trajectory inspected for the
               "already AD at the first visit" check
    '''
    offset = 3                      # counts[diff + offset]
    counts = [0]*8
    out_of_range = 0                # diffs that do not fit in `counts`
    FN = 0
    FP = 0
    total_no_ad = 0
    developed = 0
    already_ad = 0
    for key, cascade in cascades.items():
        # print_self() prints on its own and returns None; wrapping it in
        # print() used to emit a spurious 'None' line per cascade.
        cascade.print_self()
        if cascade.type == 'AD':
            if cascade.FN:
                developed += 1
                FN += 1
            elif cascade.trajectories[T-1].y == 2 and search_table(table,key):
                already_ad += 1
            else:
                developed += 1
                slot = int(cascade.diff) + offset
                if 0 <= slot < len(counts):
                    counts[slot] += 1
                else:
                    # Previously a diff below -3 wrapped around to the end of
                    # the list and a diff above 4 raised IndexError.
                    out_of_range += 1
        if cascade.type == 'No AD':
            total_no_ad += 1
            if cascade.FP:
                FP += 1
    print('Predicted within... index 2 corresponds to -1',counts)
    if out_of_range:
        print('Predictions outside the counted range',out_of_range)
    print('False Negatives',FN)
    print("Total people in test set who didn't develop AD",total_no_ad)
    print('... and of those people the number of False Positives',FP)
    print("Total people who already had AD in the first visit", already_ad)
    # False positives are a subset of total_no_ad, so they must not be added again.
    total = sum(counts) + out_of_range + FN + total_no_ad + already_ad
    print("Total people who developed AD:", developed)
    print("Total people", total)


In [29]:
#Usage example
# T picks the entry of `output` to analyse (output[T-1]); the same T is later
# passed to create_cascade() and print_counts().
T = 1
patients = gen_patient_dict(output[T-1])

# Sanity check: print the first stored trajectory of the first patient in the dict.
patients[list(patients.keys())[0]][0].print_self()


3489
PID:  002_S_0729
label:  tensor(2.)
prediction:  tensor(2)
trajectory_id:  tensor([4, 5])


In [37]:
# Build one Cascade per patient for the chosen T, then print every cascade
# followed by the aggregate counts.
cascades = create_cascade(patients,T)
print_counts(cascades,T)

True Transition:  tensor(2)
Predicted Transition:  tensor(2)
Trajectory Type:  AD
Difference in prediction time from ground truth:  tensor(0)
Trajectories [<__main__.Trajectory_Stats object at 0x7f7c3e06d7f0>, <__main__.Trajectory_Stats object at 0x7f7c3d3dd550>, <__main__.Trajectory_Stats object at 0x7f7c41ee49b0>, <__main__.Trajectory_Stats object at 0x7f7c41f448d0>, <__main__.Trajectory_Stats object at 0x7f7c3ce997b8>]
None
True Transition:  tensor(3)
Predicted Transition:  tensor(2)
Trajectory Type:  AD
Difference in prediction time from ground truth:  tensor(-1)
Trajectories [<__main__.Trajectory_Stats object at 0x7f7c41b97198>, <__main__.Trajectory_Stats object at 0x7f7c3d62e160>, <__main__.Trajectory_Stats object at 0x7f7c3cde3ba8>, <__main__.Trajectory_Stats object at 0x7f7c3ce99668>, <__main__.Trajectory_Stats object at 0x7f7c3d716be0>]
None
True Transition:  -1
Predicted Transition:  -1
Trajectory Type:  No AD
Trajectories [<__main__.Trajectory_Stats object at 0x7f7c41c3ba58>

True Transition:  tensor(1)
Predicted Transition:  tensor(1)
Trajectory Type:  AD
Difference in prediction time from ground truth:  tensor(0)
Trajectories [<__main__.Trajectory_Stats object at 0x7f7c401ca9b0>, <__main__.Trajectory_Stats object at 0x7f7c40150be0>, <__main__.Trajectory_Stats object at 0x7f7c3d794ac8>, <__main__.Trajectory_Stats object at 0x7f7c42228a20>, <__main__.Trajectory_Stats object at 0x7f7c41b97860>]
None
True Transition:  tensor(2)
Predicted Transition:  tensor(2)
Trajectory Type:  AD
Difference in prediction time from ground truth:  tensor(0)
Trajectories [<__main__.Trajectory_Stats object at 0x7f7c3d2e5cf8>, <__main__.Trajectory_Stats object at 0x7f7c41eb0d30>, <__main__.Trajectory_Stats object at 0x7f7c41873898>, <__main__.Trajectory_Stats object at 0x7f7c3cde37b8>]
None
True Transition:  tensor(1)
Predicted Transition:  tensor(1)
Trajectory Type:  AD
Difference in prediction time from ground truth:  tensor(0)
Trajectories [<__main__.Trajectory_Stats object at

True Transition:  -1
Predicted Transition:  tensor(5)
Trajectory Type:  No AD
Trajectories [<__main__.Trajectory_Stats object at 0x7f7c3dc4c3c8>, <__main__.Trajectory_Stats object at 0x7f7c4004c4a8>, <__main__.Trajectory_Stats object at 0x7f7c3dd3ceb8>, <__main__.Trajectory_Stats object at 0x7f7c41ed8438>, <__main__.Trajectory_Stats object at 0x7f7c3fdf8e10>]
None
True Transition:  -1
Predicted Transition:  -1
Trajectory Type:  No AD
Trajectories [<__main__.Trajectory_Stats object at 0x7f7c3d3dd358>, <__main__.Trajectory_Stats object at 0x7f7c3dd9d710>]
None
True Transition:  -1
Predicted Transition:  -1
Trajectory Type:  No AD
Trajectories []
None
True Transition:  tensor(1)
Predicted Transition:  tensor(4)
Trajectory Type:  AD
Difference in prediction time from ground truth:  tensor(3)
Trajectories [<__main__.Trajectory_Stats object at 0x7f7c41acdba8>, <__main__.Trajectory_Stats object at 0x7f7c3dd3c710>, <__main__.Trajectory_Stats object at 0x7f7c3d5739b0>, <__main__.Trajectory_Stat