In [1]:
import os
import pickle
import sys
from time import time

import numpy as np
import pandas as pd
import yaml

sys.path.append('..')
from src import datagen
from src import engine
from src import utils


In [2]:
# Load the diagnosis table; search_table() reads its 'PTID', 'VISCODE' and 'DX' columns.
table = pd.read_csv('../data/TADPOLE_D1_D2_proc_norm.csv')

In [3]:
def search_table(table,pid,visits):
    '''
    Checks that the patient has no dementia diagnosis at any of the given visits.

    For every visit the row of `table` whose 'PTID' equals `pid` and whose
    'VISCODE' equals convert_viscode(visit) is looked up and its 'DX' value
    is inspected.

    Input:
        table: DataFrame with 'PTID', 'VISCODE' and 'DX' columns
        pid: patient id as stored in the 'PTID' column
        visits: iterable of integer visit indices in the trajectory
                (each element must be accepted by convert_viscode)

    Returns False if any visit is labelled 'Dementia', 'AD' or 'MCI to Dementia'
    Returns True otherwise (including when `visits` is empty)

    Raises IndexError if the table holds no row for one of the (pid, visit) pairs
    '''
    dementia_labels = ('Dementia', 'AD', 'MCI to Dementia')
    # Filter on the patient once instead of re-scanning the whole table per visit.
    patient_rows = table[table['PTID'] == pid]
    for visit in visits:
        viscode = convert_viscode(visit)
        diagnoses = patient_rows.loc[patient_rows['VISCODE'] == viscode, 'DX'].values
        if len(diagnoses) == 0:
            # Same exception type as the bare `.values[0]` lookup, but says what is missing.
            raise IndexError(
                'no row in table for PTID={!r} at VISCODE={!r}'.format(pid, viscode))
        if diagnoses[0] in dementia_labels:
            return False
    return True

def convert_viscode(visit):
    '''
    Translates a visit index into the viscode string used in the table.

    `visit` must expose .item() returning a number. The index is multiplied
    by 6 and rendered as 'm' + that number, zero-padded below 10; index 0
    becomes 'bl' and the value 30 is reported as 'm36'.
    '''
    index = visit.item()
    if index == 0:
        return 'bl'

    months = index * 6
    if months < 10:
        return 'm0' + str(months)
    if months == 30:
        # 30 is deliberately reported as 36, exactly as the lookup expects.
        return 'm36'
    return 'm' + str(months)

def pid_convert(pid):
    '''
    Converts pids that are in the form of integers to strings in the proper format.

    `pid` must expose .item() returning the integer id. The id is read as
    <site digits><4 subject digits>; the site part is left-padded with zeros
    to three characters and joined to the rest with '_S_'.

    Examples: 12345   --> 001_S_2345   (5 digits)
              123456  --> 012_S_3456   (6 digits)
              1370668 --> 137_S_0668   (7 digits)
    '''
    p = pid.item()
    digits = str(p)
    # `<=` so that exactly 100000 is handled as a six-digit id (010_S_0000)
    # rather than falling through to the seven-digit branch.
    if 100000 <= p < 1000000:
        return '0' + digits[:2] + '_S_' + digits[2:]
    if p < 100000:
        return '00' + digits[:1] + '_S_' + digits[1:]
    return digits[:3] + '_S_' + digits[3:]

class Trajectory_Stats:
    '''
    Label, predicted class and identifier of one trajectory of one patient.
    '''
    def __init__(self, pid, y, y_pred, trajectory_id):
        # y_pred.max(0) must yield a (values, indices) pair; only the indices
        # are kept as the predicted class.
        # TODO (carried over): confirm this matches argmax.
        _best_scores, predicted_class = y_pred.max(0)

        self.pid = pid
        self.y = y
        self.y_pred = predicted_class
        self.trajectory_id = trajectory_id

    def print_self(self):
        '''Prints the four fields, one per line.'''
        fields = (
            ('PID: ', self.pid),
            ('label: ', self.y),
            ('prediction: ', self.y_pred),
            ('trajectory_id: ', self.trajectory_id),
        )
        for caption, value in fields:
            print(caption, value)

class Cascade:
    '''
    A cascade is a series of trajectories from a single patient used for tabulating results for
    AD transition detection at the patient level.

    Attributes set by the constructor:
        pid, trajectories, trajectory_ids
        already_ad:       True when the first trajectory is already labelled 2
        transition_true:  last element of the trajectory id of the first trajectory
                          labelled 2, or -1 when none is
        transition_flare: same, for the first trajectory predicted as 2
        type:             'No AD' when transition_true is -1, otherwise 'AD'
        FP:               True only for a 'No AD' cascade with a predicted transition
        FN:               True only for an 'AD' cascade without a predicted transition
        diff:             transition_flare - transition_true when both exist
                          (negative = predicted earlier than ground truth), else None
    '''

    # Patient whose labels are echoed to stdout while its cascade is built (debug trace).
    DEBUG_PID = '137_S_0668'

    def __init__(self, trajectories,pid):
        '''
        Input:
            trajectories: non-empty, ordered list of objects exposing
                          y, y_pred and trajectory_id
            pid: patient id the trajectories belong to
        '''
        self.pid = pid
        self.trajectories = trajectories
        self.trajectory_ids = [t.trajectory_id for t in trajectories]

        #TODO: Check Baseline label

        self.already_ad = True if trajectories[0].y == 2 else False

        trace = (pid == self.DEBUG_PID)

        # Record the index where the transition occurs in ground truth
        self.transition_true = -1
        for t in trajectories:
            if trace:
                print(t.y)
            if t.y == 2:
                if trace:
                    print('finally!!!!')
                self.transition_true = t.trajectory_id[-1]
                break

        # Record the index where the transition occurs in FLARe
        self.transition_flare = -1
        for t in trajectories:
            if t.y_pred == 2:
                self.transition_flare = t.trajectory_id[-1]
                break

        # Give every outcome attribute a value so that readers never hit an
        # AttributeError on the branch that does not apply to this cascade.
        self.FP = False
        self.FN = False
        self.diff = None

        # Check for false positives. If transition true = -1, then the patient never develops AD
        if self.transition_true == -1:
            self.type = 'No AD'
            self.FP = True if self.transition_flare != -1 else False
        else:
            self.type = 'AD'
            if self.transition_flare == -1:
                self.FN = True
            else:
                self.diff = self.transition_flare - self.transition_true

    def print_self(self):
        '''Prints the cascade summary followed by each of its trajectories.'''
        print('PID: ',self.pid)
        print('True Transition: ',self.transition_true)
        print('Predicted Transition: ',self.transition_flare)

        print('Trajectory Type: ', self.type)

        for t in self.trajectories:
            t.print_self()

        # Both transitions positive implies an 'AD' cascade that is not a false
        # negative, i.e. diff has been computed.
        if(self.transition_flare > 0 and self.transition_true > 0):
            print('Difference in prediction time from ground truth: ', self.diff)
        print('Trajectories', self.trajectory_ids)

In [4]:
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load stats.pickle from a trusted source.
# `output` is indexed as output[T-1] by the usage cells further down.
with open('../data/stats.pickle','rb') as f:
    output = pickle.load(f)

In [5]:
def gen_patient_dict(val_dict_T):
    '''
    Groups the per-trajectory results by patient.

    Input:
        val_dict_T: mapping with parallel sequences under the keys
                    'pid', 'y', 'y_pred' and 'trajectory_id'

    Prints the number of entries, then returns a dict mapping every pid read
    from ../data/pids.pickle to the list of Trajectory_Stats built for it
    (empty when the patient has no entry). Raises KeyError when an entry
    belongs to a pid that is not in that pickle.
    '''
    num_entries = len(val_dict_T['pid'])
    print(num_entries)

    # NOTE(review): unpickling executes arbitrary code; the file must be trusted.
    with open('../data/pids.pickle','rb') as f: # use ../data/patient_List_test.txt
        unique_pids = pickle.load(f)

    # initialize patient dict
    patients = {known_pid: [] for known_pid in unique_pids}

    for i, raw_pid in enumerate(val_dict_T['pid']):
        pid = pid_convert(raw_pid)
        stats = Trajectory_Stats(
            pid,
            val_dict_T['y'][i],
            val_dict_T['y_pred'][i],
            val_dict_T['trajectory_id'][i],
        )
        patients[pid].append(stats)
    return patients



In [6]:
def create_cascade(patients, T, debug_pid='137_S_0668'):
    '''
    Takes a patient dict as input and creates cascades for a certain value of 'T'

    Input:
        patients: dict pid -> list of trajectory objects exposing trajectory_id
                  (which must support slicing and .tolist()). Each list is
                  sorted IN PLACE by trajectory id.
        T: positive integer. A patient's trajectories are kept when their
           trajectory id starts with 0, 1, ..., T-1 and the patient has more
           than T trajectories in total. (T = 1 and T = 2 behave exactly as
           the former hard-coded branches; other values used to select nothing.)
        debug_pid: patient whose observed visits are echoed to stdout; pass
                   None to silence the trace.

    Returns a dict pid -> Cascade for every patient that has kept trajectories
    and for whom search_table() reports no dementia label on the visits that
    precede the last one of the first kept trajectory. Relies on the
    module-level `table`.
    '''
    final = {}
    cascades = {}
    expected_prefix = list(range(T)) if T >= 1 else None

    for key, trajectories in patients.items():
        # sort Trajectory objects by lexicographical order of their trajectory id's
        trajectories.sort(key=lambda k: k.trajectory_id.tolist())
        if expected_prefix is not None and len(trajectories) > T:
            final[key] = [
                t for t in trajectories
                if t.trajectory_id[:T].tolist() == expected_prefix
            ]
        else:
            final[key] = []

    for key, kept in final.items():
        if not kept:
            continue
        # Visits of the first kept trajectory, minus its final visit.
        earlier_visits = kept[0].trajectory_id[:-1]
        if key == debug_pid:
            print('Trajectory -1: ', earlier_visits)
            print('---')
        if search_table(table, key, earlier_visits):
            cascades[key] = Cascade(kept, key)
    return cascades

        

In [7]:
def print_counts(cascades, T):
    '''
    Takes a dict of cascades as input and prints out counts

    Input:
        cascades: dict pid -> cascade object exposing pid, type, and FN/diff
                  (for type 'AD') or FP (for type 'No AD')
        T: accepted for call compatibility; not used by the tabulation

    `counts` has 8 bins; bin i holds the 'AD' cascades whose diff equals
    i - 3 (so index 2 corresponds to -1). Differences outside -3..4 are
    tallied separately instead of wrapping around to the other end of the
    list, and are reported on an extra line only when there are any.
    '''
    offset = 3
    counts = [0]*8
    out_of_range = 0
    FN = 0
    FP = 0
    total_no_ad = 0
    # Never tallied here; kept so the printed report keeps its shape.
    already_ad = 0

    for cascade in cascades.values():
        if(cascade.pid == '137_S_0668'):
            cascade.print_self()
        if cascade.type == 'AD':
            if cascade.FN:
                FN += 1
            else:
                index = int(cascade.diff) + offset
                if 0 <= index < len(counts):
                    counts[index] += 1
                else:
                    # A negative index would silently land in the last bins.
                    out_of_range += 1
        if cascade.type == 'No AD':
            total_no_ad += 1
            if cascade.FP:
                FP += 1

    detected = sum(counts) + out_of_range
    print('Predicted within... index 2 corresponds to -1',counts)
    if out_of_range:
        print('Predictions outside the tabulated range',out_of_range)
    print('False Negatives',FN)
    print("Total people in test set who didn't develop AD",total_no_ad)
    print('number of False Positives',FP)
    print("Total people who already had AD in the first visit", already_ad)
    total = detected + total_no_ad + already_ad + FN
    print("Total people who developed AD during the study:", detected + FN)
    print("Total people", total)

In [8]:
def print_counts_T2(cascades):
    '''
    Takes a dict of cascades as input and prints out counts

    Input:
        cascades: dict pid -> cascade object exposing pid, type, and FN/diff
                  (for type 'AD') or FP (for type 'No AD')

    `counts` has 8 bins; bin i holds the 'AD' cascades whose diff equals
    i - 3 (so index 2 corresponds to -1). Differences outside -3..4 are
    tallied separately instead of wrapping around to the other end of the
    list, and are reported on an extra line only when there are any.
    '''
    offset = 3
    counts = [0]*8
    out_of_range = 0
    FN = 0
    FP = 0
    total_no_ad = 0
    # Never tallied here; kept so the printed report keeps its shape.
    already_ad = 0

    for cascade in cascades.values():
        if(cascade.pid == '137_S_0668'):
            cascade.print_self()
        if cascade.type == 'AD':
            if cascade.FN:
                FN += 1
            else:
                index = int(cascade.diff) + offset
                if 0 <= index < len(counts):
                    counts[index] += 1
                else:
                    # A negative index would silently land in the last bins.
                    out_of_range += 1
        if cascade.type == 'No AD':
            total_no_ad += 1
            if cascade.FP:
                FP += 1

    detected = sum(counts) + out_of_range
    print('Predicted within... index 2 corresponds to -1',counts)
    if out_of_range:
        print('Predictions outside the tabulated range',out_of_range)
    print('False Negatives',FN)
    print("Total people in test set who didn't develop AD",total_no_ad)
    print('... the number of False Positives',FP)
    print("Total people who already had AD in the first visit", already_ad)
    total = detected + total_no_ad + already_ad + FN
    print("Total people who developed AD during the study:", detected + FN)
    print("Total people", total)

In [None]:
#Usage example


In [9]:
# Build the per-patient trajectory lists from output[T-1], turn them into
# cascades for T = 1 and print the tallies.
T = 1
patients = gen_patient_dict(output[T-1])
#patients[list(patients.keys())[0]][0].print_self()
cascades_T1 = create_cascade(patients,T)
print_counts(cascades_T1,T)

3489
Trajectory -1:  tensor([0])
---
tensor(1.)
tensor(1.)
tensor(0.)
tensor(1.)
tensor(1.)
PID:  137_S_0668
True Transition:  -1
Predicted Transition:  -1
Trajectory Type:  No AD
PID:  137_S_0668
label:  tensor(1.)
prediction:  tensor(1)
trajectory_id:  tensor([0, 1])
PID:  137_S_0668
label:  tensor(1.)
prediction:  tensor(1)
trajectory_id:  tensor([0, 2])
PID:  137_S_0668
label:  tensor(0.)
prediction:  tensor(1)
trajectory_id:  tensor([0, 3])
PID:  137_S_0668
label:  tensor(1.)
prediction:  tensor(1)
trajectory_id:  tensor([0, 4])
PID:  137_S_0668
label:  tensor(1.)
prediction:  tensor(1)
trajectory_id:  tensor([0, 5])
Trajectories [tensor([0, 1]), tensor([0, 2]), tensor([0, 3]), tensor([0, 4]), tensor([0, 5])]
Predicted within... index 2 corresponds to -1 [0, 0, 18, 25, 0, 1, 1, 0]
False Negatives 6
Total people in test set who didn't develop AD 202
number of False Positives 6
Total people who already had AD in the first visit 0
Total people who developed AD during the study: 51
To

In [10]:
#Usage example
# Same pipeline as the T = 1 cell, for T = 2. Note that `T` and `patients`
# are rebound here, so later cells see the T = 2 values.
T = 2
patients = gen_patient_dict(output[T-1])

#patients[list(patients.keys())[0]][0].print_self()
cascades_T2 = create_cascade(patients,T)
print_counts_T2(cascades_T2)


2150
Trajectory -1:  tensor([0, 1])
---
tensor(1.)
tensor(0.)
tensor(1.)
tensor(1.)
PID:  137_S_0668
True Transition:  -1
Predicted Transition:  tensor(2)
Trajectory Type:  No AD
PID:  137_S_0668
label:  tensor(1.)
prediction:  tensor(2)
trajectory_id:  tensor([0, 1, 2])
PID:  137_S_0668
label:  tensor(0.)
prediction:  tensor(1)
trajectory_id:  tensor([0, 1, 3])
PID:  137_S_0668
label:  tensor(1.)
prediction:  tensor(1)
trajectory_id:  tensor([0, 1, 4])
PID:  137_S_0668
label:  tensor(1.)
prediction:  tensor(1)
trajectory_id:  tensor([0, 1, 5])
Trajectories [tensor([0, 1, 2]), tensor([0, 1, 3]), tensor([0, 1, 4]), tensor([0, 1, 5])]
Predicted within... index 2 corresponds to -1 [0, 0, 9, 25, 0, 0, 0, 0]
False Negatives 5
Total people in test set who didn't develop AD 175
... the number of False Positives 7
Total people who already had AD in the first visit 0
Total people who developed AD during the study: 39
Total people 214


In [20]:
# Print every T = 1 cascade that has a ground-truth transition (transition_true != -1).
# NOTE(review): `count` is assigned but never updated or read in this cell.
count = 0
for key in cascades_T1:
    if(cascades_T1[key].transition_true != -1):
        cascades_T1[key].print_self()
        print('--------------')

PID:  002_S_0729
True Transition:  tensor(2)
Predicted Transition:  tensor(2)
Trajectory Type:  AD
PID:  002_S_0729
label:  tensor(1.)
prediction:  tensor(1)
trajectory_id:  tensor([0, 1])
PID:  002_S_0729
label:  tensor(2.)
prediction:  tensor(2)
trajectory_id:  tensor([0, 2])
PID:  002_S_0729
label:  tensor(2.)
prediction:  tensor(2)
trajectory_id:  tensor([0, 3])
PID:  002_S_0729
label:  tensor(2.)
prediction:  tensor(2)
trajectory_id:  tensor([0, 4])
PID:  002_S_0729
label:  tensor(2.)
prediction:  tensor(2)
trajectory_id:  tensor([0, 5])
Difference in prediction time from ground truth:  tensor(0)
Trajectories [tensor([0, 1]), tensor([0, 2]), tensor([0, 3]), tensor([0, 4]), tensor([0, 5])]
--------------
PID:  002_S_1070
True Transition:  tensor(3)
Predicted Transition:  tensor(2)
Trajectory Type:  AD
PID:  002_S_1070
label:  tensor(1.)
prediction:  tensor(1)
trajectory_id:  tensor([0, 1])
PID:  002_S_1070
label:  tensor(1.)
prediction:  tensor(2)
trajectory_id:  tensor([0, 2])
PID

prediction:  tensor(2)
trajectory_id:  tensor([0, 1])
PID:  116_S_4167
label:  tensor(2.)
prediction:  tensor(2)
trajectory_id:  tensor([0, 2])
PID:  116_S_4167
label:  tensor(2.)
prediction:  tensor(2)
trajectory_id:  tensor([0, 3])
PID:  116_S_4167
label:  tensor(2.)
prediction:  tensor(2)
trajectory_id:  tensor([0, 4])
PID:  116_S_4167
label:  tensor(2.)
prediction:  tensor(2)
trajectory_id:  tensor([0, 5])
Difference in prediction time from ground truth:  tensor(-1)
Trajectories [tensor([0, 1]), tensor([0, 2]), tensor([0, 3]), tensor([0, 4]), tensor([0, 5])]
--------------
PID:  123_S_0390
True Transition:  tensor(4)
Predicted Transition:  tensor(3)
Trajectory Type:  AD
PID:  123_S_0390
label:  tensor(1.)
prediction:  tensor(1)
trajectory_id:  tensor([0, 1])
PID:  123_S_0390
label:  tensor(1.)
prediction:  tensor(1)
trajectory_id:  tensor([0, 2])
PID:  123_S_0390
label:  tensor(1.)
prediction:  tensor(2)
trajectory_id:  tensor([0, 3])
PID:  123_S_0390
label:  tensor(2.)
prediction: