# Partial-correlation weights for linear prediction of age, sex, and intelligence (WRAT)

In [1]:
# Using newly preprocessed subjects

import pickle

metadictname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_agesexwrat.pkl'
alltsname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_PowerTS_float2.pkl'


def _load_pickle(path):
    """Deserialize and return the single object stored in the pickle at `path`.

    NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# metadict: subject metadata; allts: presumably time series keyed by paradigm
# (the printed keys below list what each actually contains).
metadict = _load_pickle(metadictname)
allts = _load_pickle(alltsname)

for loaded in (metadict, allts):
    print(list(loaded.keys()))
print('Complete')

['age', 'sex', 'wrat', 'missingage', 'missingsex', 'missingwrat', 'failedqc']
['emoid', 'nback', 'rest']
Complete


In [5]:
'''
Get subjects that have all tasks and paras specified
Functions for creating independent and response variables
'''

import numpy as np

def get_subs(allts, metadict, tasks, paras):
    """Return subject IDs present in every paradigm and every task, minus QC failures.

    Parameters
    ----------
    allts : dict
        Maps paradigm name -> dict keyed by subject strings of the form
        'sub-<int>' (the first four characters are stripped and the rest
        parsed as an int).
    metadict : dict
        Maps task name -> dict keyed by integer subject IDs. It also holds a
        'failedqc' iterable whose entries are expected to look like 'sub-<int>'.
    tasks : sequence of str
        Keys of `metadict`; a subject must appear under every one of them.
    paras : sequence of str
        Keys of `allts`; a subject must appear under every one of them.

    Returns
    -------
    list of int
        Sorted subject IDs in the intersection of all paradigms and all tasks,
        with QC failures removed. Sorting makes the order deterministic, which
        keeps the downstream X / y row order reproducible.

    Raises
    ------
    ValueError
        If `tasks` or `paras` is empty, because the intersection is undefined.
    """
    if not paras or not tasks:
        raise ValueError('get_subs needs at least one task and one paradigm')

    # Subjects with data for every paradigm ('sub-123' -> 123).
    paraset = set.intersection(
        *({int(sub[4:]) for sub in allts[para]} for para in paras))
    # Subjects with metadata for every task. Previously each task after the
    # first was intersected with `paraset` rather than the running task set,
    # so only the last task was actually enforced.
    taskset = set.intersection(*(set(metadict[task]) for task in tasks))

    allsubs = taskset & paraset
    # Drop QC failures. Entries that don't parse as 'sub-<int>' are skipped
    # (best-effort, as before); IDs that aren't present are ignored.
    for badsub in metadict['failedqc']:
        try:
            allsubs.discard(int(badsub[4:]))
        except (TypeError, ValueError):
            continue
    return sorted(allsubs)

def get_X(allts, paras, subs):
    """Stack each paradigm's per-subject arrays into one array per paradigm.

    Returns a list aligned with `paras`. Element k is
    np.stack([allts[paras[k]]['sub-<id>'] for id in subs]), so its rows
    follow the order of `subs`.
    """
    return [np.stack([allts[para][f'sub-{sub}'] for sub in subs]) for para in paras]

def get_y(metadict, tasks, subs):
    """Build one response array per task, aligned with `subs`.

    Parameters
    ----------
    metadict : dict
        Maps task name -> dict keyed by subject ID.
    tasks : sequence of str
        Each one of 'age', 'wrat' (raw values) or 'sex' (two-column encoding).
    subs : sequence
        Subject IDs; output rows follow this order.

    Returns
    -------
    list of np.ndarray
        One entry per task, in the order of `tasks`:
        - 'age' / 'wrat': shape (len(subs),), the stored values.
        - 'sex': shape (len(subs), 2). Column 0 is 1 where the stored value
          is 'M', and column 1 is its complement.

    Raises
    ------
    ValueError
        If a task name is not recognised. Previously such tasks were skipped
        silently, which misaligned every later entry with `tasks`.
    """
    y = []
    for task in tasks:
        if task in ('age', 'wrat'):
            y.append(np.array([metadict[task][sub] for sub in subs]))
        elif task == 'sex':
            maleness = np.array([metadict[task][sub] == 'M' for sub in subs])
            # column 0 = male indicator, column 1 = its complement
            y.append(np.stack([maleness, 1 - maleness], axis=1))
        else:
            raise ValueError(f'Unknown task {task!r}; expected age, wrat or sex')
    return y

PARAS = ['rest', 'nback', 'emoid']  # paradigms every subject must have

subs = get_subs(allts, metadict, ['wrat'], PARAS)
print(len(subs))

X = get_X(allts, PARAS, subs)
print(X[0].shape)

830
(830, 264, 124)


In [6]:
# TS to condensed FC

from scipy import signal

def butter_bandpass(cutoff, fs, order=5):
    """Design a digital Butterworth band-pass filter.

    cutoff : indexable (low, high) pass-band edges, in the same units as `fs`.
    fs     : sampling rate.
    order  : filter order passed to scipy.signal.butter.
    Returns the (b, a) numerator/denominator coefficients.
    """
    nyquist = fs * 0.5
    low, high = cutoff[0], cutoff[1]
    # scipy.signal.butter expects edges normalised to the Nyquist frequency.
    return signal.butter(order, [low / nyquist, high / nyquist], btype='band', analog=False)

def butter_bandpass_filter(data, cutoff, fs, order=5):
    """Apply a zero-phase Butterworth band-pass filter to `data`.

    The coefficients come from butter_bandpass. scipy.signal.filtfilt runs the
    filter forward and backward (no phase shift) along its default axis, the
    last one.
    """
    coeffs = butter_bandpass(cutoff, fs, order=order)
    return signal.filtfilt(*coeffs, data)

tr = 3  # repetition time; the default sampling rate is 1/tr (presumably s -> Hz, confirm)


def filter_design_ts(X, cutoff=(0.01, 0.15), fs=None):
    """Band-pass filter every subject's time series.

    Parameters
    ----------
    X : array-like
        Iterable over subjects. Each element is filtered along its last axis
        (presumably (regions, timepoints); confirm against the data).
    cutoff : (float, float), default (0.01, 0.15)
        Pass-band edges, in the same units as `fs`.
    fs : float, optional
        Sampling rate. Defaults to 1 / tr, using the module-level `tr` read at
        call time.

    Returns
    -------
    np.ndarray
        The filtered subjects stacked along a new first axis.
    """
    if fs is None:
        fs = 1 / tr
    return np.stack([butter_bandpass_filter(subject_ts, cutoff, fs) for subject_ts in X])

def ts_to_flat_fc(X):
    """Condense time series into the upper triangle of their correlation matrix.

    X is passed to np.corrcoef, which treats each row as a variable and each
    column as an observation. Returns the strictly upper-triangular entries
    (diagonal excluded) in row-major order.
    """
    corr = np.corrcoef(X)
    n_vars = corr.shape[0]
    rows, cols = np.triu_indices(n_vars, k=1)
    return corr[rows, cols]

# Band-pass filter each paradigm's subject time series. filter_design_ts
# already returns a stacked ndarray, so re-stacking it was a redundant copy.
# NOTE: this rebinds X, so re-running this cell on its own filters the data a
# second time. Re-run the cell that builds X from allts first.
X = [filter_design_ts(Xp) for Xp in X]
print(X[0].shape)

(830, 264, 124)


In [8]:
# Get all partial correlations

from nilearn.connectome import ConnectivityMeasure

cm = ConnectivityMeasure(kind='partial correlation')

# Strict upper triangle of the (regions x regions) matrix. The region count is
# taken from the data instead of hard-coding 264; X[k] appears to be
# (subjects, regions, timepoints) per the shape printed above.
n_regions = X[0].shape[1]
a, b = np.triu_indices(n_regions, 1)

allp = []
for Xp in X:  # one entry per paradigm, instead of a hard-coded range(3)
    # nilearn expects each subject as (timepoints, features), hence the transpose.
    partials = cm.fit_transform(Xp.transpose(0, 2, 1))
    allp.append(partials[:, a, b])
    print('Done partials')

allp = np.stack(allp)
print(allp.shape)

Done partials
Done partials
Done partials
(3, 830, 34716)


In [157]:
# Check prediction using partials

import torch
import torch.nn as nn

mseLoss = nn.MSELoss()

modidx = 0
mod = 'rest'
task = "age"
n_runs = 20    # number of random train/test splits
ntrain = 700   # subjects per training split; the rest are held out
seed = 0       # fixes the splits so the saved weights are reproducible
# Fall back to CPU so the cell still runs without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
base = "/home/anton/Documents/Tulane/Research/ImageNomer/data/anton/cohorts/test/weights/partial"

torch.manual_seed(seed)

# Loop-invariant inputs: build the feature matrix and target vector once and
# move them to the device, rather than on every split.
x_all = torch.from_numpy(allp[modidx]).float().to(device)
y_all = torch.from_numpy(get_y(metadict, [task], subs)[0]).float().to(device)


def toDict(w, rmse, idcs):
    """Bundle one split's weights, subject lists and description for saving.

    w    : weight tensor from lstsq (converted to a CPU NumPy array here).
    rmse : test-set root-mean-square error, embedded in `desc`.
    idcs : permutation of subject indices; the first `ntrain` are training.
    """
    return dict(w=w.detach().cpu().numpy(),
                trsubs=sorted([subs[i] for i in idcs[:ntrain]]),
                tsubs=sorted([subs[i] for i in idcs[ntrain:]]),
                desc=f"Least squares partial corr {task} {mod} rmse: {float(rmse)}")


def save(dct, dr, idx):
    """Pickle `dct` to {base}/{dr}/{mod}{idx}.pkl; the directory must already exist."""
    with open(f"{base}/{dr}/{mod}{idx}.pkl", 'wb') as f:
        pickle.dump(dct, f)


sm = 0
for ii in range(n_runs):
    idcs = torch.randperm(x_all.shape[0])
    x = x_all[idcs]
    xtr, xt = x[:ntrain], x[ntrain:]

    y = y_all[idcs]
    ytr, yt = y[:ntrain], y[ntrain:]
    # Center both splits on the TRAINING mean; the lstsq fit has no intercept column.
    mu = torch.mean(ytr)
    ytr = ytr - mu
    yt = yt - mu

    # NOTE(review): on CUDA torch.linalg.lstsq only offers the 'gels' driver,
    # which assumes a full-rank A. Here there are more edges than training
    # subjects, so confirm the returned solution is the intended one.
    w, _, _, _ = torch.linalg.lstsq(xtr, ytr)
    yhat = xt @ w
    rmse = mseLoss(yhat, yt) ** 0.5

    print(rmse)
    sm += rmse / n_runs
    save(toDict(w, rmse, idcs), f'{task}_mean_zero', ii)
    print(f'Done {ii}')

print('---')
print(sm)

tensor(29.0752, device='cuda:0')
Done 0
tensor(28.9338, device='cuda:0')
Done 1
tensor(28.3771, device='cuda:0')
Done 2
tensor(26.8056, device='cuda:0')
Done 3
tensor(29.4164, device='cuda:0')
Done 4
tensor(29.7039, device='cuda:0')
Done 5
tensor(26.5510, device='cuda:0')
Done 6
tensor(31.6080, device='cuda:0')
Done 7
tensor(28.3975, device='cuda:0')
Done 8
tensor(29.7382, device='cuda:0')
Done 9
tensor(30.5065, device='cuda:0')
Done 10
tensor(27.5603, device='cuda:0')
Done 11
tensor(29.1412, device='cuda:0')
Done 12
tensor(27.1840, device='cuda:0')
Done 13
tensor(28.2827, device='cuda:0')
Done 14
tensor(31.6508, device='cuda:0')
Done 15
tensor(28.2971, device='cuda:0')
Done 16
tensor(27.7216, device='cuda:0')
Done 17
tensor(29.5121, device='cuda:0')
Done 18
tensor(28.1083, device='cuda:0')
Done 19
---
tensor(28.8286, device='cuda:0')


In [144]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix

modidx = 0
mod = 'rest'
task = "sex"
n_runs = 20    # number of random train/test splits
ntrain = 700   # subjects per training split; the rest are held out
seed = 0       # fixes the splits so the saved weights are reproducible
base = "/home/anton/Documents/Tulane/Research/ImageNomer/data/anton/cohorts/test/weights/partial"

rng = np.random.default_rng(seed)

# Loop-invariant inputs. Column 0 of get_y's sex encoding is 1 where the
# stored value is 'M', so labels are 0/1 with 1 = male.
x_all = allp[modidx]
y_all = get_y(metadict, [task], subs)[0][:, 0]


def toDict(w, acc, idcs):
    """Bundle one split's weights, subject lists and description for saving.

    w    : NumPy coefficient vector from the fitted classifier.
    acc  : test-set accuracy, embedded in `desc`.
    idcs : permutation of subject indices; the first `ntrain` are training.
    """
    return dict(w=w,
                trsubs=sorted([subs[i] for i in idcs[:ntrain]]),
                tsubs=sorted([subs[i] for i in idcs[ntrain:]]),
                desc=f"Logistic regression partial corr {task} {mod} acc: {float(acc)}")


def save(dct, dr, idx):
    """Pickle `dct` to {base}/{dr}/{mod}{idx}.pkl; the directory must already exist."""
    with open(f"{base}/{dr}/{mod}{idx}.pkl", 'wb') as f:
        pickle.dump(dct, f)


for ii in range(n_runs):
    idcs = rng.permutation(x_all.shape[0])
    train_idx, test_idx = idcs[:ntrain], idcs[ntrain:]
    xtr, xt = x_all[train_idx], x_all[test_idx]
    ytr, yt = y_all[train_idx], y_all[test_idx]

    clf = LogisticRegression(max_iter=1000, penalty='l2', C=1, solver='lbfgs').fit(xtr, ytr)
    yhat = clf.predict(xt)
    acc = np.mean(yhat == yt)
    print(acc)

    # Rows = true class (0, 1), columns = predicted; each row sums to 1.
    mat = confusion_matrix(yt, yhat, normalize='true', labels=[0, 1])
    print(mat)

    # Save THIS classifier's weights. The original passed an undefined `w`,
    # which silently picked up the least-squares age weights left in the
    # kernel by the previous cell. For a binary problem coef_ has shape
    # (1, n_features); flatten it to a 1-D vector. Positive weights favour class 1.
    w = clf.coef_.ravel()
    save(toDict(w, acc, idcs), task, ii)
    print(f'Done {ii}')

0.7923076923076923
[[0.89552239 0.10447761]
 [0.31746032 0.68253968]]
Done 0
0.7384615384615385
[[0.7037037  0.2962963 ]
 [0.20408163 0.79591837]]
Done 1
0.7076923076923077
[[0.72       0.28      ]
 [0.30909091 0.69090909]]
Done 2
0.7615384615384615
[[0.81333333 0.18666667]
 [0.30909091 0.69090909]]
Done 3
0.7692307692307693
[[0.80327869 0.19672131]
 [0.26086957 0.73913043]]
Done 4
0.8
[[0.8125 0.1875]
 [0.22   0.78  ]]
Done 5
0.7769230769230769
[[0.84057971 0.15942029]
 [0.29508197 0.70491803]]
Done 6
0.7692307692307693
[[0.89393939 0.10606061]
 [0.359375   0.640625  ]]
Done 7
0.7615384615384615
[[0.775 0.225]
 [0.26  0.74 ]]
Done 8
0.7461538461538462
[[0.859375   0.140625  ]
 [0.36363636 0.63636364]]
Done 9
0.8076923076923077
[[0.82352941 0.17647059]
 [0.20967742 0.79032258]]
Done 10
0.823076923076923
[[0.84615385 0.15384615]
 [0.2        0.8       ]]
Done 11
0.6846153846153846
[[0.80555556 0.19444444]
 [0.46551724 0.53448276]]
Done 12
0.8
[[0.82894737 0.17105263]
 [0.24074074 0.7592