In [22]:
# See what response variables we can get for all the newly preprocessed subjects
# Sep 5, 2022

import pickle
import pandas

# Input locations: the pickled time-series dictionary, the three per-paradigm
# metadata tables, and the WRAT score table.
tsfile = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_PowerTS_float2.pkl'
restcsvfile = '/home/anton/Documents/Tulane/Research/PNC/rest_fmri_power264_meta.csv'
nbackcsvfile = '/home/anton/Documents/Tulane/Research/PNC/nback_fmri_power264_meta.csv'
emoidcsvfile = '/home/anton/Documents/Tulane/Research/PNC/emoid_fmri_power264_meta.csv'
wratcsvfile = '/home/anton/Documents/Tulane/Research/PNC_Good/wrat2.csv'

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point tsfile at a trusted, locally produced pickle.
allts = None
with open(tsfile, 'rb') as ts_handle:
    allts = pickle.load(ts_handle)

# Read the four tables in the same order as before: wrat, rest, nback, emoid.
wratcsv, restcsv, nbackcsv, emoidcsv = (
    pandas.read_csv(csv_path)
    for csv_path in (wratcsvfile, restcsvfile, nbackcsvfile, emoidcsvfile)
)

print('Complete')

Complete


In [23]:
# Fill up age and sex dictionaries

# Build ID -> age (AgeInMonths, stored as int) and ID -> Gender lookups from
# the three paradigm metadata tables, printing any conflicting entries.
agedict, sexdict = {}, {}

for table in (restcsv, nbackcsv, emoidcsv):
    for _, rec in table.iterrows():
        pid = rec['ID']

        # The first value seen for an ID is kept; later rows are only compared.
        if pid not in agedict:
            agedict[pid] = int(rec['AgeInMonths'])
        elif rec['AgeInMonths'] != int(agedict[pid]):
            print(f"Disagreement for {pid} {rec['AgeInMonths']} {agedict[pid]}")

        if pid not in sexdict:
            sexdict[pid] = rec['Gender']
        elif rec['Gender'] != sexdict[pid]:
            print(f"Disagreement for {pid} {rec['Gender']} {sexdict[pid]}")

# Number of distinct IDs collected in each lookup.
print(len(agedict))
print(len(sexdict))

944
944


In [24]:
# Fill up wrat dictionary

import numpy as np

# Map PNCID -> WRAT 'Std' score (truncated to int). Rows whose 'Std' is NaN are
# reported and left out of the dictionary.
wratdict = {}

for _, rec in wratcsv.iterrows():
    pncid, std = rec['PNCID'], rec['Std']
    # Each PNCID may contribute at most one stored score.
    assert pncid not in wratdict
    if not np.isnan(std):
        wratdict[pncid] = int(std)
        continue
    print(f"Got NaN for {pncid}")

print(len(wratdict))

Got NaN for 600004612332
Got NaN for 600046680817
Got NaN for 600278442660
Got NaN for 600468711343
Got NaN for 600591534459
Got NaN for 600627279777
Got NaN for 600634603556
Got NaN for 600750122708
Got NaN for 600756699875
Got NaN for 600810936813
Got NaN for 600865134007
Got NaN for 601151511321
Got NaN for 601156461179
Got NaN for 601253987934
Got NaN for 601255916247
Got NaN for 601349286952
Got NaN for 601390969752
Got NaN for 601412894670
Got NaN for 601447373389
Got NaN for 601472532096
Got NaN for 601520556882
Got NaN for 601564392975
Got NaN for 601637656924
Got NaN for 601668778473
Got NaN for 601773061934
Got NaN for 601900816453
Got NaN for 601974482909
Got NaN for 602047598580
Got NaN for 602061271187
Got NaN for 602091103307
Got NaN for 602104475001
Got NaN for 602109848574
Got NaN for 602256274982
Got NaN for 602526269567
Got NaN for 602543787933
Got NaN for 602561508870
Got NaN for 602594906628
Got NaN for 602655651539
Got NaN for 602666956753
Got NaN for 602744559472


In [36]:
# Missing data dict

# For every subject key in the three paradigms, record the keys whose numeric
# id (the key with its first four characters dropped) has no age / sex / WRAT.
missingage, missingsex, missingwrat = set(), set(), set()

lookups = (
    (agedict, missingage),
    (sexdict, missingsex),
    (wratdict, missingwrat),
)

for para in ('emoid', 'nback', 'rest'):
    for key in allts[para].keys():
        subid = int(key[4:])
        for known, missing in lookups:
            if subid not in known:
                # The full key (with prefix) is what gets stored.
                missing.add(key)

for missing in (missingage, missingsex, missingwrat):
    print(len(missing))

4
4
20


In [41]:
# Save response variables to meta dict

# Also save failed QC subjects

# Subject keys that failed quality control (kept with their 'sub-' prefix).
failedqc = [
    'sub-603403163265',
    'sub-605515760919',
    #'sub-600210241146'
]

# Bundle every lookup and bookkeeping set into one dictionary for reuse.
metadict = {
    'age': agedict,
    'sex': sexdict,
    'wrat': wratdict,
    'missingage': missingage,
    'missingsex': missingsex,
    'missingwrat': missingwrat,
    'failedqc': failedqc,
}

savefname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_agesexwrat.pkl'

# Overwrites any existing file at savefname.
with open(savefname, 'wb') as out_handle:
    pickle.dump(metadict, out_handle)

print('Complete')

Complete


In [67]:
'''
Get subjects that have all tasks and paras specified
Functions for creating independent and response variables
'''

import numpy as np

def get_subs(allts, metadict, tasks, paras):
    """Return the subject ids that have every requested paradigm and task.

    Parameters
    ----------
    allts : mapping
        ``allts[para]`` is a mapping keyed by subject string; the first four
        characters of each key are dropped and the rest is parsed as an int.
    metadict : mapping
        ``metadict[task]`` is a mapping keyed by integer subject id, and
        ``metadict['failedqc']`` is an iterable of subject strings to exclude.
    tasks : sequence of str
        Response variables every returned subject must have.
    paras : sequence of str
        Paradigms every returned subject must have.

    Returns
    -------
    set of int
        Ids present in all requested paradigms and all requested tasks, minus
        the QC failures. Empty when neither tasks nor paras are given.
    """
    # One id set per requested paradigm and per requested task.
    id_sets = [{int(sub[4:]) for sub in allts[para].keys()} for para in paras]
    id_sets.extend(set(metadict[task].keys()) for task in tasks)

    if not id_sets:
        return set()

    # Intersect everything at once. The previous version intersected each task
    # with the paradigm set instead of the running task set, so only the last
    # task actually constrained the result.
    allsubs = set.intersection(*id_sets)

    # Remove QC failures
    for badsub in metadict['failedqc']:
        try:
            bad_id = int(badsub[4:])
        except (TypeError, ValueError):
            # Unparseable entries were silently ignored before; keep that
            # tolerance but no longer swallow unrelated errors.
            continue
        # discard() is a no-op when the id is not among the candidates.
        allsubs.discard(bad_id)
    return allsubs

def get_X(allts, paras, subs):
    """Stack the per-subject arrays of each paradigm.

    For every paradigm in ``paras`` the entries ``allts[para]['sub-<id>']`` are
    looked up for each id in ``subs`` (in iteration order) and stacked along a
    new leading axis.

    Returns a list with one stacked array per paradigm, in the order of
    ``paras``.
    """
    return [
        np.stack([allts[para][f'sub-{sub}'] for sub in subs])
        for para in paras
    ]

def get_y(metadict, tasks, subs):
    """Build one response array per recognised task.

    'age' and 'wrat' give a 1-D array of ``metadict[task][sub]`` values.
    'sex' gives a two-column array: column 0 is 1 where the stored value equals
    'M', column 1 is its complement. Any other task name adds nothing to the
    result.

    Returns a list of arrays in the order the recognised tasks were given.
    """
    responses = []
    for task in tasks:
        if task in ('age', 'wrat'):
            values = np.array([metadict[task][sub] for sub in subs])
            responses.append(values)
        elif task == 'sex':
            is_male = np.array([metadict[task][sub] == 'M' for sub in subs])
            responses.append(np.stack([is_male, 1-is_male], axis=1))
    return responses

# Response variables and paradigms used for the whole analysis.
tasks = ['age', 'sex', 'wrat']
paras = ['emoid', 'rest', 'nback']

# subs is a set; get_X and get_y each iterate it, so the row alignment between
# X and y relies on the set not being modified between the two calls.
subs = get_subs(allts, metadict, tasks, paras)
X = get_X(allts, paras, subs)
y = get_y(metadict, tasks, subs)

print([pX.shape for pX in X])
print([var.shape for var in y])

[(830, 264, 210), (830, 264, 124), (830, 264, 231)]
[(830,), (830, 2), (830,)]


In [85]:
# TS to condensed FC

from scipy import signal

def butter_bandpass(cutoff, fs, order=5):
    """Design a digital Butterworth band-pass filter.

    cutoff : indexable with the lower edge at [0] and the upper edge at [1],
             in the same units as ``fs``.
    fs     : sampling rate; edges are normalised by half of it.
    order  : order passed to ``signal.butter``.

    Returns the ``(b, a)`` coefficient pair produced by ``signal.butter``.
    """
    nyquist = 0.5 * fs
    low = cutoff[0] / nyquist
    high = cutoff[1] / nyquist
    return signal.butter(order, [low, high], btype='band', analog=False)

def butter_bandpass_filter(data, cutoff, fs, order=5):
    """Band-pass ``data`` forwards and backwards with a Butterworth filter.

    The coefficients come from ``butter_bandpass(cutoff, fs, order)`` and are
    applied with ``signal.filtfilt`` using its default axis.

    Returns the filtered array.
    """
    numer, denom = butter_bandpass(cutoff, fs, order=order)
    return signal.filtfilt(numer, denom, data)

# Repetition time used to build the band-pass edges in X_to_flat_fc
# (presumably seconds per volume — TODO confirm).
tr = 1.83
# Length of the first axis of the first paradigm's array. X_to_flat_fc scales
# both the cut-offs and the sampling rate by N, so N cancels when
# butter_bandpass normalises by the Nyquist rate.
N = X[0].shape[0]

def X_to_flat_fc(X):
    """Turn each paradigm's time series into flattened correlation vectors.

    For every array in ``X`` and every index along its first axis, the slice is
    band-pass filtered, ``np.corrcoef`` is taken, and the strictly-upper
    triangle of the resulting matrix is kept as a 1-D vector.

    Reads the module-level ``tr`` and ``N``. Because the cut-offs and the
    sampling rate are all multiples of ``N``, the normalised band edges passed
    to the filter design are ``tr/20`` and ``0.8`` regardless of ``N``.

    Returns an array stacked as (first-axis index, paradigm, upper-triangle
    entries).
    """
    per_para = []
    for pX in X:
        corr_mats = []
        for i in range(pX.shape[0]):
            filtered = butter_bandpass_filter(pX[i], [tr/20*N, 0.8*N], 2*N)
            corr_mats.append(np.corrcoef(filtered))
        # Index pairs above the diagonal (k=1 excludes the diagonal itself).
        rows, cols = np.triu_indices(corr_mats[0].shape[0], 1)
        per_para.append(np.stack([mat[rows, cols] for mat in corr_mats]))
    return np.stack(per_para, axis=1)

# Flattened functional-connectivity features for every array in X; later cells
# index p along its second axis to pick a paradigm.
p = X_to_flat_fc(X)
print(p.shape)

(830, 3, 34716)


In [89]:
# Super fast test

import torch
from math import floor

# Least-squares sanity check on age: fit one linear model per paradigm on a
# random 80/20 split, average the three predictions, and report the RMSE.
# y[0] is the 'age' response, which is stored in months.

SEED = 0               # fixed so the random splits can be reproduced
N_SPLITS = 10          # number of random train/test splits
TRAIN_FRAC = 0.8       # fraction of subjects used for fitting
MONTHS_PER_YEAR = 12   # converts the final mean RMSE from months to years

X_t = torch.from_numpy(p[:,0:3]).float().cuda()
y_t = torch.from_numpy(y[0]).float().cuda()

# A dedicated generator keeps the splits repeatable without touching torch's
# global random state.
split_gen = torch.Generator().manual_seed(SEED)

# Training and test set

tot = []

for i in range(N_SPLITS):
    N = X_t.shape[0]
    a = floor(TRAIN_FRAC*N)
    # (The earlier torch.arange(N) assignment was dead code: it was overwritten
    # by the permutation on the next line.)
    idcs = torch.randperm(N, generator=split_gen)
    trainIdcs = idcs[:a]
    testIdcs = idcs[a:]
    Xtr_t = X_t[trainIdcs]
    ytr_t = y_t[trainIdcs]
    Xt_t = X_t[testIdcs]
    yt_t = y_t[testIdcs]

    # One least-squares weight vector per paradigm (second axis of X_t).
    weights = [
        torch.linalg.lstsq(Xtr_t[:,k], ytr_t).solution
        for k in range(X_t.shape[1])
    ]

    # Average the per-paradigm predictions on the held-out subjects.
    yhat = sum(Xt_t[:,k]@w for k, w in enumerate(weights))/len(weights)
    rmse = torch.mean((yhat-yt_t)**2)**0.5
    tot.append(float(rmse))
    print(tot[-1])

# Mean RMSE over all splits, expressed in years.
print(sum(tot)/len(tot)/MONTHS_PER_YEAR)

23.76226043701172
24.083568572998047
27.265823364257812
25.899738311767578
24.34619903564453
26.30966567993164
25.690616607666016
24.78278923034668
25.24112319946289
22.2159481048584
2.0799811045328775


In [93]:
# Logistic regression on sex

from sklearn.linear_model import LogisticRegression
from math import floor

# Logistic regression on sex from a single paradigm's flattened FC.
# p[:,2] is the third paradigm passed to get_X ('nback'); y[1] is the two-column
# sex array, so argmax gives 0 where the stored value was 'M' and 1 otherwise.

SEED = 0           # fixed so the random splits can be reproduced
N_SPLITS = 10      # number of random train/test splits
TRAIN_FRAC = 0.8   # fraction of subjects used for fitting

X_t = p[:,2]
y_t = np.argmax(y[1], axis=1)

# A dedicated generator keeps the splits repeatable without touching NumPy's
# global random state.
split_rng = np.random.default_rng(SEED)

# Training and test set

tot = []

for i in range(N_SPLITS):
    N = X_t.shape[0]
    a = floor(TRAIN_FRAC*N)
    idcs = split_rng.permutation(N)
    trainIdcs = idcs[:a]
    testIdcs = idcs[a:]
    Xtr_t = X_t[trainIdcs]
    ytr_t = y_t[trainIdcs]
    Xt_t = X_t[testIdcs]
    yt_t = y_t[testIdcs]

    clf = LogisticRegression(max_iter=5000, C=1e1).fit(Xtr_t, ytr_t)
    yhat = clf.predict(Xt_t)
    # Fraction of held-out subjects classified correctly.
    acc = np.sum(yhat == yt_t)/yhat.shape[0]
    tot.append(float(acc))
    print(tot[-1])

# Mean accuracy over all splits.
print(sum(tot)/len(tot))

0.7771084337349398
0.8132530120481928
0.7891566265060241
0.7710843373493976
0.7530120481927711
0.8132530120481928
0.7349397590361446
0.7590361445783133
0.8012048192771084
0.7831325301204819
0.7795180722891566


In [55]:
# Scratch check: first ten scores stored in wratdict.
list(wratdict.values())[0:10]

[80, 93, 104, 108, 96, 119, 114, 114, 78, 99]

In [26]:
# Scratch check: first five subject keys of the 'emoid' paradigm.
list(allts['emoid'].keys())[0:5]

['sub-600009963128',
 'sub-600018902293',
 'sub-600020927179',
 'sub-600031697545',
 'sub-600038720566']

In [5]:
# Scratch check: first rows of the rest-paradigm metadata table.
restcsv.head()

Unnamed: 0,PythonID,ID,AgeInMonths,Gender,Ethnicity,AgeGroupID,AgeGroupEdge1,AgeGroupEdge2
0,0,600009963128,116,F,CAUCASIAN/WHITE,1,103,144
1,1,600018902293,187,F,CAUCASIAN/WHITE,3,180,216
2,2,600020927179,190,F,CAUCASIAN/WHITE,3,180,216
3,3,600031697545,242,M,AFRICAN,4,216,272
4,4,600038720566,137,F,OTHER/MIXED,1,103,144


In [10]:
# Scratch check: first rows of the WRAT score table.
wratcsv.head()

Unnamed: 0,PNCID,Valid,Raw,Std
0,600001676724,V,48.0,80.0
1,600003245643,V,30.0,93.0
2,600004612332,,,
3,600004963801,V,52.0,104.0
4,600005394890,V,52.0,108.0
