# Baseline Static PCA and kSVD

Can't do dynamic PCA/kSVD because each subject has at least 124 points of 34716 features,<br>
and you can't enforce a notion of temporal proximity with these methods

- Rest, nback, emoid individually (train codebook on 400)
- and rest+nback+emoid together (train codebook on 200/200/200, same subjects from each)

Response variables are are 

- age, 
- sex, 
- wrat

Evaluation methods are

- Ridge regression
- Elastic MLP
- LatSim

Number of components are 10,50,100,200,300,400,500,600,800 (+ 1000,1200,1400,1600 for kSVD only)

In [1]:
# Using newly preprocessed subjects

import pickle

metadictname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_agesexwrat.pkl'
alltsname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_PowerTS_float2.pkl'

with open(metadictname, 'rb') as f:
    metadict = pickle.load(f)

with open(alltsname, 'rb') as f:
    allts = pickle.load(f)
    
print(list(metadict.keys()))
print(list(allts.keys()))
print('Complete')

['age', 'sex', 'wrat', 'missingage', 'missingsex', 'missingwrat', 'failedqc']
['emoid', 'nback', 'rest']
Complete


In [2]:
# Using newly preprocessed subjects

import pickle

metadictname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_agesexwrat.pkl'
alltsname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_PowerTS_float2.pkl'

with open(metadictname, 'rb') as f:
    metadict = pickle.load(f)

with open(alltsname, 'rb') as f:
    allts = pickle.load(f)
    
print(list(metadict.keys()))
print(list(allts.keys()))
print('Complete')

['age', 'sex', 'wrat', 'missingage', 'missingsex', 'missingwrat', 'failedqc']
['emoid', 'nback', 'rest']
Complete


In [3]:
'''
Get subjects that have all tasks and paras specified
Functions for creating independent and response variables
'''

import numpy as np

def get_subs(allts, metadict, tasks, paras):
    # Get subs for all paras
    for i,para in enumerate(paras):
        tmpset = set([int(sub[4:]) for sub in allts[para].keys()])
        if i == 0:
            paraset = tmpset
        else:
            paraset = paraset.intersection(tmpset)
    # Get subs for all tasks
    for i,task in enumerate(tasks):
        tmpset = set([sub for sub in metadict[task].keys()])
        if i == 0:
            taskset = tmpset
        else:
            taskset = paraset.intersection(tmpset)
    # Remove QC failures
    allsubs = taskset.intersection(paraset)
    for badsub in metadict['failedqc']:
        try:
            allsubs.remove(int(badsub[4:]))
        except:
            pass
    return allsubs

def get_X(allts, paras, subs):
    X = []
    for para in paras:
        pX = [allts[para][f'sub-{sub}'] for sub in subs]
        pX = np.stack(pX)
        X.append(pX)
    return X

def get_y(metadict, tasks, subs):
    y = []
    for task in tasks:
        if task == 'age' or task == 'wrat':
            var = [metadict[task][sub] for sub in subs]
            var = np.array(var)
            y.append(var)
        if task == 'sex':
            maleness = [metadict[task][sub] == 'M' for sub in subs]
            maleness = np.array(maleness)
            sex = np.stack([maleness, 1-maleness], axis=1)
            y.append(sex)
    return y

subs = get_subs(allts, metadict, ['age'], ['rest', 'nback', 'emoid'])
print(len(subs))

X = get_X(allts, ['rest', 'nback', 'emoid'], subs)
print(X[0].shape)

847
(847, 264, 124)


In [4]:
# TS to condensed FC

from scipy import signal

def butter_bandpass(cutoff, fs, order=5):
    nyq = 0.5 * fs
    normal_cutoff = [cutoff[0] / nyq, cutoff[1] / nyq]
    b, a = signal.butter(order, normal_cutoff, btype='band', analog=False)
    return b, a

def butter_bandpass_filter(data, cutoff, fs, order=5):
    b, a = butter_bandpass(cutoff, fs, order=order)
    y = signal.filtfilt(b, a, data)
    return y

tr = 1.83

def filter_design_ts(X):
    Xs = []
    for i in range(X.shape[0]):
        nX = butter_bandpass_filter(X[i], [0.01, 0.2], 1/tr)
        Xs.append(nX)
    return np.stack(Xs)

def ts_to_flat_fc(X):
    p = np.corrcoef(X)
    a,b = np.triu_indices(p[0].shape[0], 1)
    p = p[a,b]
    return p

p = [np.stack([ts_to_flat_fc(ts) for ts in filter_design_ts(Xp)]) for Xp in X]
print(p[0].shape)

(847, 34716)


In [5]:
import sys

sys.path.append('../../LatentSimilarity')

from latsim import LatSim

print('Complete')

Complete


In [12]:
from dask_ml.decomposition import PCA
import dask.array as da
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, ncodes):
        super(MLP, self).__init__()
        self.l1 = nn.Linear(ncodes, 40).float().cuda()
        self.l2 = nn.Linear(40,1).float().cuda()
        
    def train(self, xtr, ytr, nepochs=1000, lr=1e-1, l1=1e-1, l2=1e-4, pperiod=100, verbose=False):
        optim = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=l2)
        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.75, eps=1e-7)
        
        for epoch in range(nepochs):
            optim.zero_grad()
            yhat = self(xtr)
            loss = mseLoss(yhat, ytr)**0.5
            l1loss = l1*torch.sum(torch.abs(self.l1.weight))
            (loss+l1loss).backward()
            optim.step()
            sched.step(loss)
            if verbose:
                if epoch % pperiod == 0 or epoch == nepochs-1:
                    print(f'{epoch} {[float(l) for l in [loss, l1loss]]} {sched._last_lr}')
                    
    def predict(self, xt, yt):
        with torch.no_grad():
            return mseLoss(self(xt), yt)**0.5
                    
    def forward(self, x):
        x = F.relu(self.l1(x))
        x = self.l2(x).squeeze()
        return x

ncodesall = [10,50,100,200,300,400,500,600,800]
ntrainall = [400,400,400,400,400,400,500,600,800]

modstr = 'nback'
modidx = 1

for ncodes,ntrain in zip(ncodesall, ntrainall):
    for split in range(3):
        idcs = np.arange(p[modidx].shape[0])
        np.random.shuffle(idcs)
        t0 = time.time()
        dx = da.from_array(p[modidx], chunks=(1000,1000))
        pca = PCA(n_components=ncodes)
        pca.fit(dx[idcs[:ntrain]])
        z = pca.transform(dx)
        z = z.persist()
        t1 = time.time()
        print(z.shape)
        print(f'pca time {t1-t0}')

        nreps = 10
        trainsizes = [30,50,100,200,300,400,500,600,700,800]
        res = np.zeros((nreps,len(trainsizes)))
        l2 = 1e0

        for rep in range(nreps):
            losses = []

            idcs = np.arange(z.shape[0])
            np.random.shuffle(idcs)

            for ntrain in trainsizes:
                xps = torch.from_numpy(np.asarray(z)).float().cuda()
                xps = torch.cat([xps, torch.ones(xps.shape[0], 1).float().cuda()], dim=1)
                xtr = xps[idcs[:ntrain]]
                xt = xps[idcs[ntrain:]]

                y = get_y(metadict, ['age'], subs)[0]
                y_t = torch.from_numpy(y).float().cuda()
                ytr = y_t[idcs[:ntrain]]
                yt = y_t[idcs[ntrain:]]

                # REDUCE THIS TO GET GOOD RESULTS WITH SPARSITY 0.01->0.001 or 0.0001
                w, _, _, _ = torch.linalg.lstsq(xtr.T@xtr + l2*torch.eye(ncodes+1).float().cuda(), xtr.T@ytr)

                print(torch.mean((yt-xt@w)**2)**0.5)
                losses.append(float(torch.mean((yt-xt@w)**2)**0.5))

            print(f'Finished {rep}')
            res[rep,:] = losses

        print(np.mean(res, axis=0))
        print(np.std(res, axis=0))
            
        print('Finished lstsq')

        with open(f'/home/anton/Documents/Tulane/Research/Work/LatSimEC2/DictEst/PCA/{modstr}-lstsq-mean.csv', 'a') as f:
            f.write(f'{ncodes},{split},{",".join([str(val) for val in np.mean(res, axis=0)])}\n')

        with open(f'/home/anton/Documents/Tulane/Research/Work/LatSimEC2/DictEst/PCA/{modstr}-lstsq-std.csv', 'a') as f:
            f.write(f'{ncodes},{split},{",".join([str(val) for val in np.std(res, axis=0)])}\n')
            
        mseLoss = nn.MSELoss()

        nreps = 10
        trainsizes = [30,50,100,200,300,400,500,600,700,800]
        res = np.zeros((nreps,len(trainsizes)))

        for rep in range(nreps):

            idcs = np.arange(z.shape[0])
            np.random.shuffle(idcs)

            losses = []

            for ntrain in trainsizes:
                xps = torch.from_numpy(np.asarray(z)).float().cuda()
                xtr = xps[idcs[:ntrain]]
                xt = xps[idcs[ntrain:]]

                y = get_y(metadict, ['age'], subs)[0]
                y_t = torch.from_numpy(y).float().cuda()
                ytr = y_t[idcs[:ntrain]]
                yt = y_t[idcs[ntrain:]]

                mlp = MLP(ncodes)
                # 1e-3 good for age 1e-2 good for wrat
                mlp.train(xtr, ytr, lr=1e-2, nepochs=1000, l1=1e0, l2=1e-3)
                loss = mlp.predict(xt, yt)

                losses.append(float(loss))
                print(float(loss))

            res[rep,:] = losses
            print(f'Finished {rep}')

        print(np.mean(res, axis=0))
        print(np.std(res, axis=0))
        
        with open(f'/home/anton/Documents/Tulane/Research/Work/LatSimEC2/DictEst/PCA/{modstr}-mlp-mean.csv', 'a') as f:
            f.write(f'{ncodes},{split},{",".join([str(val) for val in np.mean(res, axis=0)])}\n')

        with open(f'/home/anton/Documents/Tulane/Research/Work/LatSimEC2/DictEst/PCA/{modstr}-mlp-std.csv', 'a') as f:
            f.write(f'{ncodes},{split},{",".join([str(val) for val in np.std(res, axis=0)])}\n')
            
        print('Finished MLP')
        
        nreps = 10
        trainsizes = [30,50,100,200,300,400,500,600,700,800]
        res = np.zeros((nreps,len(trainsizes)))

        nepochs = 500
        pperiod = 100
        verbose = False

        for rep in range(nreps):

            idcs = np.arange(z.shape[0])
            np.random.shuffle(idcs)

            losses = []

            for ntrain in trainsizes:
                xps = torch.from_numpy(np.asarray(z))[idcs].unsqueeze(1).float().cuda()

                mu = torch.mean(xps[:ntrain], dim=0, keepdims=True)
                std = torch.std(xps[:ntrain], dim=0, keepdims=True)
                xps = (xps-mu)/std

                xtr = xps[:ntrain]
                xt = xps[ntrain:]

                y = get_y(metadict, ['age'], subs)[0]
                y_t = torch.from_numpy(y[idcs]).float().cuda()
                ytr = y_t[:ntrain]
                yt = y_t[ntrain:]

                # dp=0.5 for wrat
                sim = LatSim(1, xps, dp=0.5, edp=0.1, wInit=1e-4, dim=10, temp=1)
                optim = torch.optim.Adam(sim.parameters(), lr=1e-3, weight_decay=1e-3)
                sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.75, eps=1e-7)

                for epoch in range(nepochs):
                    optim.zero_grad()
                    yhat = sim(xtr, [ytr])[0][0]
                    loss = mseLoss(yhat, ytr)**0.5
                    loss.backward()
                    optim.step()
                    sched.step(loss)
                    if verbose:
                        if epoch % pperiod == 0 or epoch == nepochs-1:
                            print(f'{epoch} {float(loss)} {sched._last_lr}')

                sim.eval()
                yhat = sim(xps, [y_t], np.arange(ntrain,idcs.shape[0]))[0][0][ntrain:]
                loss = mseLoss(yhat, yt)**0.5
                losses.append(float(loss))

                print(float(loss))

            res[rep,:] = losses
            print(f'Finished {rep}')

        print(np.mean(res, axis=0))
        print(np.std(res, axis=0))
        
        print('Finished latsim')
        
        with open(f'/home/anton/Documents/Tulane/Research/Work/LatSimEC2/DictEst/PCA/{modstr}-latsim-mean.csv', 'a') as f:
            f.write(f'{ncodes},{split},{",".join([str(val) for val in np.mean(res, axis=0)])}\n')

        with open(f'/home/anton/Documents/Tulane/Research/Work/LatSimEC2/DictEst/PCA/{modstr}-latsim-std.csv', 'a') as f:
            f.write(f'{ncodes},{split},{",".join([str(val) for val in np.std(res, axis=0)])}\n')

(847, 10)
pca time 0.6283316612243652
tensor(37.9263, device='cuda:0')
tensor(37.8268, device='cuda:0')
tensor(35.8202, device='cuda:0')
tensor(34.8596, device='cuda:0')
tensor(34.8266, device='cuda:0')
tensor(34.6793, device='cuda:0')
tensor(34.3288, device='cuda:0')
tensor(33.5522, device='cuda:0')
tensor(33.8064, device='cuda:0')
tensor(34.5195, device='cuda:0')
Finished 0
tensor(57.2830, device='cuda:0')
tensor(38.2530, device='cuda:0')
tensor(35.8462, device='cuda:0')
tensor(35.1462, device='cuda:0')
tensor(34.2441, device='cuda:0')
tensor(34.2634, device='cuda:0')
tensor(33.9425, device='cuda:0')
tensor(32.7267, device='cuda:0')
tensor(33.0702, device='cuda:0')
tensor(28.8153, device='cuda:0')
Finished 1
tensor(42.7748, device='cuda:0')
tensor(39.2168, device='cuda:0')
tensor(36.4993, device='cuda:0')
tensor(35.1403, device='cuda:0')
tensor(33.4259, device='cuda:0')
tensor(33.5454, device='cuda:0')
tensor(34.6391, device='cuda:0')
tensor(35.0229, device='cuda:0')
tensor(33.4904, 

42.37179946899414
41.40070343017578
37.66204071044922
36.220703125
36.4393310546875
36.61768341064453
37.49641799926758
37.19482421875
37.121681213378906
36.76382827758789
Finished 0
38.39251708984375
37.31061553955078
36.86189651489258
36.147090911865234
36.37052917480469
35.72047424316406
36.28553009033203
37.24943923950195
37.414024353027344
33.61748504638672
Finished 1
43.413692474365234
38.27040100097656
36.46181106567383
35.86273956298828
36.39316177368164
36.21809768676758
36.24336242675781
36.09786605834961
35.11429214477539
34.602081298828125
Finished 2
41.285221099853516
38.635467529296875
37.48316192626953
36.175750732421875
36.420433044433594
36.54480743408203
36.198123931884766
35.578773498535156
34.740596771240234
29.22987174987793
Finished 3
44.63951110839844
40.391258239746094
35.609153747558594
35.99089431762695
35.832794189453125
35.793663024902344
37.24181365966797
37.12947463989258
37.473819732666016
38.11723709106445
Finished 4
40.82905197143555
40.16706466674805
3

33.90705108642578
37.36651611328125
Finished 0
37.76063537597656
35.34185791015625
34.92770004272461
34.56296157836914
34.27622604370117
33.94938659667969
33.474700927734375
33.9802360534668
33.997318267822266
33.896419525146484
Finished 1
35.8976936340332
35.279170989990234
35.374000549316406
35.54692840576172
35.70965576171875
35.77009963989258
36.255699157714844
37.36005783081055
37.84690475463867
39.72151184082031
Finished 2
39.39600372314453
36.441226959228516
36.28809356689453
35.0756721496582
34.48692321777344
34.52702331542969
34.21171951293945
34.096229553222656
34.01908493041992
36.278236389160156
Finished 3
39.54011917114258
39.564361572265625
35.20541763305664
34.689537048339844
35.55907440185547
35.64155197143555
35.209625244140625
34.99641799926758
33.86220169067383
35.20431137084961
Finished 4
43.777652740478516
35.4852409362793
35.831504821777344
34.58084487915039
35.65517044067383
35.257286071777344
35.10378646850586
35.13479995727539
34.33916091918945
36.4752197265625

33.95950698852539
34.987762451171875
34.90336227416992
34.16618347167969
32.781028747558594
33.64877700805664
36.3570442199707
Finished 1
42.737518310546875
47.825016021728516
36.63062286376953
34.15713882446289
34.4922981262207
34.5458984375
35.387630462646484
36.32225799560547
35.88361358642578
35.49723434448242
Finished 2
36.74299621582031
47.059085845947266
35.561824798583984
35.34597396850586
34.683006286621094
34.33130645751953
34.667152404785156
34.73260498046875
37.22721862792969
33.322471618652344
Finished 3
37.86787033081055
35.24289321899414
36.428138732910156
35.639068603515625
35.791805267333984
35.91935729980469
36.07212448120117
36.27923583984375
35.27388381958008
35.99860382080078
Finished 4
44.28801345825195
44.83610153198242
34.84576416015625
34.773399353027344
34.63298797607422
34.04798126220703
34.47692108154297
34.386741638183594
33.50815200805664
31.82976722717285
Finished 5
48.75919723510742
40.7325553894043
34.07148742675781
34.48476791381836
34.37910461425781
3

31.16087532043457
Finished 1
48.99507141113281
56.14186477661133
35.45137405395508
34.450111389160156
32.606632232666016
33.35152816772461
33.63376235961914
33.84450912475586
32.98088455200195
31.290971755981445
Finished 2
49.862606048583984
47.69585037231445
36.18583297729492
35.013641357421875
34.96858596801758
34.880821228027344
35.06842041015625
34.13801193237305
34.23061752319336
34.191402435302734
Finished 3
55.078060150146484
45.41610336303711
35.897212982177734
34.725120544433594
34.61608123779297
34.77071762084961
35.045658111572266
35.78412628173828
36.498817443847656
38.30773162841797
Finished 4
47.21026611328125
51.253231048583984
35.859676361083984
35.10388946533203
35.37834548950195
35.156349182128906
36.39100646972656
35.07577133178711
34.90083312988281
33.621429443359375
Finished 5
48.382904052734375
37.149078369140625
36.27615737915039
35.5344352722168
35.21177673339844
34.73312759399414
35.08377456665039
35.6439208984375
35.632545471191406
39.39619064331055
Finished 6

34.497249603271484
34.649658203125
34.168704986572266
34.56904220581055
35.845394134521484
35.788475036621094
Finished 2
51.106380462646484
50.71052932739258
34.921295166015625
34.02751922607422
34.513423919677734
34.254207611083984
33.854549407958984
32.968345642089844
33.226131439208984
33.46921157836914
Finished 3
52.8828125
41.70844650268555
36.29589080810547
35.302406311035156
34.673805236816406
34.97096252441406
35.04801940917969
33.5625114440918
34.326717376708984
35.10524368286133
Finished 4
39.1769905090332
34.54621887207031
33.8044548034668
33.53656768798828
33.530452728271484
33.510494232177734
32.74972152709961
32.20481872558594
32.562225341796875
32.51015090942383
Finished 5
47.194740295410156
35.4810791015625
36.07158279418945
34.851402282714844
34.52019500732422
34.371315002441406
34.26441192626953
34.22169876098633
33.053321838378906
34.09465408325195
Finished 6
44.24319076538086
51.478111267089844
35.20779800415039
34.886993408203125
34.96384811401367
34.37453079223633

33.8420524597168
Finished 2
47.95964431762695
50.15607833862305
38.140079498291016
35.80754852294922
34.989341735839844
34.95716094970703
35.750816345214844
36.0515022277832
38.05439758300781
41.6879997253418
Finished 3
43.55841827392578
36.49129867553711
35.157020568847656
35.20283508300781
34.73504638671875
35.244773864746094
33.505428314208984
33.19963455200195
33.02406311035156
33.048587799072266
Finished 4
50.26545333862305
47.949851989746094
37.62173843383789
35.468048095703125
35.329917907714844
34.61383819580078
35.62621307373047
36.547882080078125
37.99433517456055
39.19020462036133
Finished 5
58.22795867919922
67.9831771850586
40.28550338745117
34.419490814208984
34.43601608276367
34.57197952270508
35.02754592895508
34.135276794433594
32.9620475769043
37.37095642089844
Finished 6
44.19641876220703
38.406856536865234
35.75517272949219
35.22365951538086
35.10178756713867
35.957008361816406
35.657806396484375
34.27890396118164
33.656463623046875
38.87022399902344
Finished 7
47.9

33.34037780761719
33.52396011352539
33.652069091796875
33.715999603271484
34.11920928955078
31.581008911132812
32.26871109008789
Finished 3
43.01911544799805
42.23754119873047
33.31081008911133
33.30213928222656
34.13033676147461
34.660667419433594
34.52827835083008
35.6636848449707
36.7892951965332
36.91288757324219
Finished 4
40.27326965332031
40.79747772216797
38.68064880371094
33.239891052246094
33.19936752319336
34.2304801940918
33.495601654052734
33.307674407958984
34.14413070678711
37.07020950317383
Finished 5
45.204627990722656
45.72498321533203
34.61996841430664
34.6031379699707
34.19184875488281
33.07834243774414
32.13080978393555
33.487117767333984
33.93149185180664
30.66365623474121
Finished 6
43.15323257446289
48.3699951171875
36.16868591308594
32.351192474365234
32.71440505981445
32.58079147338867
32.36561965942383
32.27978515625
30.988662719726562
27.619403839111328
Finished 7
38.46474838256836
40.765663146972656
33.7487678527832
32.42277908325195
34.21402359008789
33.57

34.664119720458984
38.874473571777344
Finished 3
57.57467269897461
48.475223541259766
41.524024963378906
32.45717239379883
34.436119079589844
33.830814361572266
34.07083511352539
35.53324508666992
33.91462707519531
31.703378677368164
Finished 4
41.71391296386719
37.27018356323242
36.16114044189453
34.69053268432617
34.959320068359375
35.256439208984375
36.06332778930664
35.419498443603516
35.80261993408203
32.67677688598633
Finished 5
42.84511184692383
44.767982482910156
38.53746032714844
34.85902404785156
33.48070526123047
33.7623291015625
34.6710205078125
34.38882064819336
35.0301399230957
34.172607421875
Finished 6
46.82253646850586
42.551239013671875
35.69328308105469
34.47521209716797
34.171451568603516
34.73397445678711
35.5194091796875
35.59177017211914
34.982208251953125
34.53798294067383
Finished 7
45.18680953979492
49.283267974853516
39.31130599975586
32.811851501464844
33.680938720703125
34.33216857910156
34.04603576660156
33.423126220703125
32.005104064941406
32.27006530761

36.39619445800781
33.54108810424805
33.59571075439453
33.72767639160156
33.90525436401367
34.20966339111328
34.33885955810547
37.00609588623047
Finished 4
49.05540466308594
41.41573715209961
43.97600555419922
32.370506286621094
34.0421028137207
34.67045974731445
35.96200180053711
36.89606475830078
36.235748291015625
35.75200653076172
Finished 5
48.94295883178711
37.13628387451172
34.761592864990234
33.88209915161133
32.69281768798828
33.54323196411133
32.631309509277344
31.86821746826172
31.158119201660156
33.29753494262695
Finished 6
44.90218734741211
47.10581588745117
35.63600540161133
35.03871154785156
33.808074951171875
35.18147659301758
35.62070083618164
35.590972900390625
33.08467102050781
32.60896682739258
Finished 7
41.824615478515625
40.720619201660156
38.289756774902344
34.2048454284668
34.17760467529297
34.58477783203125
35.46774673461914
35.9019660949707
36.82875442504883
39.91950607299805
Finished 8
44.37348937988281
49.18656539916992
38.66670608520508
34.23732376098633
34

36.19140625
33.00492858886719
Finished 4
58.188785552978516
47.30956268310547
34.41309356689453
34.146728515625
34.40607833862305
34.46413803100586
34.07291793823242
33.40025329589844
33.67901611328125
34.88969039916992
Finished 5
44.76280212402344
40.7982063293457
36.65896987915039
34.486167907714844
33.985984802246094
33.99770736694336
34.82622528076172
34.64149856567383
34.58564376831055
37.22893524169922
Finished 6
41.29690933227539
40.65308380126953
39.91870880126953
31.925045013427734
33.73993682861328
34.52546310424805
34.6710319519043
34.421539306640625
34.58018112182617
37.008609771728516
Finished 7
39.724998474121094
39.937538146972656
38.328609466552734
33.468238830566406
33.37332534790039
33.069580078125
33.59404754638672
34.05422592163086
32.783782958984375
36.181060791015625
Finished 8
41.6854362487793
43.358280181884766
40.35258865356445
34.44801712036133
34.52473831176758
33.98212814331055
33.966949462890625
33.01339340209961
32.0806999206543
31.860034942626953
Finished

33.509010314941406
32.68869400024414
32.7352294921875
33.00082015991211
33.04204559326172
33.244869232177734
35.61787414550781
34.34548568725586
Finished 5
47.944671630859375
41.24806213378906
33.965145111083984
33.398719787597656
34.03096008300781
34.154109954833984
34.34754943847656
34.5290412902832
34.95173263549805
33.6468391418457
Finished 6
42.081214904785156
39.83231735229492
40.172298431396484
35.63367462158203
34.45420837402344
34.38787078857422
35.37411117553711
36.34553909301758
37.2285041809082
37.66805648803711
Finished 7
42.43012619018555
41.44504928588867
41.43362808227539
33.107276916503906
33.253726959228516
34.062705993652344
33.823280334472656
33.95964050292969
33.74875259399414
35.316864013671875
Finished 8
46.561912536621094
49.06578063964844
45.37581253051758
32.5267219543457
32.964317321777344
32.906524658203125
32.85932159423828
33.237449645996094
31.818899154663086
33.51815414428711
Finished 9
[43.5351593  42.47887344 40.37158737 33.59476547 33.52929211 33.6755

32.69374084472656
32.70207977294922
30.883874893188477
Finished 5
49.03779602050781
35.670127868652344
41.25163269042969
34.26533126831055
34.10920333862305
33.91902160644531
34.996742248535156
35.40044021606445
36.071712493896484
37.81269454956055
Finished 6
53.95391082763672
45.99790573120117
40.873573303222656
33.79482650756836
33.8044319152832
34.03425979614258
33.34614562988281
32.3956413269043
33.57987976074219
37.701988220214844
Finished 7
41.8440055847168
41.687644958496094
40.31101608276367
34.08973693847656
34.59157180786133
34.32600402832031
33.18307113647461
33.90525817871094
34.01231384277344
31.496440887451172
Finished 8
45.04408645629883
42.805660247802734
34.709468841552734
34.382568359375
34.2983283996582
32.776123046875
32.90098190307617
33.04914474487305
33.591026306152344
32.23508071899414
Finished 9
[45.77844429 41.81259422 39.73242722 34.57977104 34.29527855 33.9526638
 34.13504295 34.03705368 34.23556633 33.98041382]
[4.80630677 3.22741644 2.57140545 0.5525006  0

43.92875289916992
47.4815673828125
40.68048095703125
32.7515754699707
30.361352920532227
34.37781524658203
35.17963409423828
36.501190185546875
36.360843658447266
38.857337951660156
Finished 6
43.73379135131836
42.36741638183594
39.51113510131836
33.76205825805664
32.80643081665039
33.437686920166016
32.20158004760742
32.507843017578125
34.321964263916016
33.18642807006836
Finished 7
44.38993453979492
43.54878234863281
42.8502197265625
32.62663269042969
31.832120895385742
33.09843444824219
33.203067779541016
32.83356857299805
33.25349044799805
32.24238586425781
Finished 8
40.268733978271484
39.86431121826172
39.082401275634766
33.48435974121094
33.31235122680664
34.58121871948242
34.42387390136719
34.811893463134766
33.82853317260742
31.901123046875
Finished 9
[44.50365295 42.50738411 41.22549286 34.38313446 33.37681122 33.92482185
 33.86376343 33.93822384 33.83342819 33.53013344]
[3.78066473 3.37206183 1.4532633  1.17647078 1.43400486 0.81968923
 1.00414072 1.31478105 1.53629615 3.344

34.89873504638672
34.17369079589844
34.041141510009766
34.28544998168945
32.0697021484375
36.4505729675293
Finished 6
40.78738021850586
44.28385925292969
35.28148651123047
33.95536422729492
34.16139602661133
33.658199310302734
33.163822174072266
34.30077362060547
32.120399475097656
33.27275085449219
Finished 7
39.998653411865234
41.12141799926758
39.81615447998047
33.62470626831055
32.90882873535156
32.63364028930664
32.39119338989258
31.658323287963867
32.01554870605469
30.8543758392334
Finished 8
37.80876541137695
39.448482513427734
36.64208221435547
34.1132926940918
33.93772506713867
35.07978439331055
34.25518035888672
34.77827072143555
35.1386833190918
39.992984771728516
Finished 9
[42.21937599 42.57530174 38.85190811 33.75412903 33.73156395 33.43362923
 33.14350586 33.5500555  33.10949154 33.20844231]
[2.64863332 2.80153849 2.04217801 0.60808446 0.69195339 0.76927129
 0.67499781 1.17531071 1.53303635 3.1254567 ]
Finished MLP
42.369510650634766
39.91960906982422
38.14728927612305
3

33.91923141479492
31.97928237915039
Finished 6
44.606956481933594
40.025516510009766
35.435752868652344
34.43252182006836
34.831878662109375
34.922061920166016
34.285396575927734
36.450721740722656
36.84454345703125
32.59349822998047
Finished 7
52.85157012939453
46.33379364013672
45.67482376098633
31.459516525268555
31.650728225708008
34.258338928222656
34.449554443359375
33.6801872253418
33.98410415649414
32.456974029541016
Finished 8
45.29705810546875
44.01853561401367
35.46821594238281
34.788108825683594
35.047122955322266
34.288631439208984
34.68778991699219
34.31039047241211
33.91433334350586
31.527095794677734
Finished 9
[45.62778702 41.64525871 39.59869461 33.49842777 34.41294003 34.91211205
 34.88251762 34.9774826  34.70240173 34.39229946]
[5.56565177 3.0396028  3.08575751 1.12509297 1.003486   0.74343381
 0.74808508 1.22526899 1.46869842 2.74862444]
Finished MLP
40.19254684448242
39.24846267700195
37.19613265991211
36.34425735473633
35.45944595336914
34.95927429199219
34.67208

45.0047607421875
39.97941589355469
34.03811264038086
34.35540008544922
33.622352600097656
34.374664306640625
35.65298080444336
34.68865203857422
36.65830612182617
Finished 7
44.76507568359375
45.86113357543945
46.4566764831543
33.86445236206055
33.979557037353516
32.68596649169922
33.16619110107422
32.506927490234375
32.345130920410156
34.413108825683594
Finished 8
42.21293640136719
41.60140609741211
41.09318542480469
33.22378921508789
32.93465805053711
31.030622482299805
32.502071380615234
33.54019546508789
35.20923614501953
36.25143051147461
Finished 9
[43.5288475  42.35197678 39.84258347 33.58195648 33.69072762 33.37538624
 33.6229393  33.50112228 32.91285877 33.42777843]
[1.61354644 2.79299604 3.17883611 0.61694554 0.63642678 1.00980672
 0.80945836 1.2538091  1.72412217 2.56805239]
Finished MLP
43.0723876953125
43.1589469909668
39.85762023925781
36.79575729370117
35.43568801879883
34.21585464477539
33.472599029541016
32.129642486572266
30.849334716796875
31.458251953125
Finished 0


33.43013381958008
33.9395866394043
34.812278747558594
35.76193618774414
38.80079650878906
Finished 7
38.20772933959961
38.374271392822266
37.79648208618164
34.06003189086914
34.063720703125
34.014801025390625
33.36299133300781
33.483253479003906
33.756797790527344
35.51823806762695
Finished 8
43.91407012939453
41.43669891357422
41.142478942871094
33.5634880065918
34.15868377685547
34.08155822753906
34.061397552490234
33.4439811706543
33.831993103027344
35.74134826660156
Finished 9
[42.51937523 40.53689651 39.14910774 33.92012482 34.19345703 34.10846252
 34.18639183 34.19539604 34.20868034 35.63352547]
[2.06636459 1.23520794 2.31506591 1.37388353 0.72220381 0.90892997
 0.97346424 1.25156006 1.51765914 1.97857979]
Finished MLP
39.97694778442383
39.36086654663086
38.10399627685547
36.7647590637207
35.41175842285156
33.907981872558594
32.94721603393555
34.312503814697266
34.25588607788086
34.273834228515625
Finished 0
39.9274787902832
40.18303680419922
38.32710647583008
37.152469635009766


33.86110305786133
29.99921989440918
Finished 7
40.29362106323242
44.60660171508789
33.767120361328125
33.569557189941406
35.16606521606445
35.98835754394531
35.03307342529297
34.04999542236328
31.335044860839844
35.042667388916016
Finished 8
47.32845687866211
41.997867584228516
40.20719528198242
35.285484313964844
33.87422561645508
34.06294250488281
33.28615951538086
33.84896469116211
33.919715881347656
35.83525085449219
Finished 9
[46.21347084 42.94637947 39.41930199 33.90206757 34.38952904 34.18235302
 34.7533432  34.60317268 34.56369629 33.60576363]
[3.84724734 2.83804152 3.58584893 1.33669422 0.52682958 1.5241584
 0.83148151 1.11562734 1.44637419 3.30175453]
Finished MLP
40.53325271606445
41.30839157104492
41.480003356933594
38.20183181762695
37.377662658691406
36.82781982421875
35.79134750366211
34.74542236328125
35.76817321777344
33.28573226928711
Finished 0
41.53349685668945
40.22056198120117
39.847835540771484
38.506690979003906
36.83087921142578
36.06647491455078
32.9729576110

39.576446533203125
41.59062576293945
33.023048400878906
32.819950103759766
33.064510345458984
33.22883224487305
32.97724151611328
33.9731330871582
34.0910758972168
Finished 8
44.63565444946289
39.2675666809082
33.99875259399414
33.79231262207031
32.55868911743164
33.43788146972656
34.26923370361328
34.8909912109375
36.10038757324219
30.426837921142578
Finished 9
[42.38408546 39.50643501 37.84558487 33.72365913 33.56589661 33.45252552
 33.82858887 34.15804367 34.07233238 33.51960144]
[3.81870211 0.98267682 3.01842061 0.64337412 1.11901929 1.22864024
 1.5260433  1.31570779 1.97658737 3.41825601]
Finished MLP
41.1314582824707
40.151611328125
39.328834533691406
37.03300094604492
36.43537521362305
34.82377243041992
34.4574089050293
33.675086975097656
35.13917922973633
36.52313995361328
Finished 0
41.24106216430664
41.00826644897461
39.44301986694336
37.20012664794922
35.491600036621094
34.50559616088867
33.352352142333984
34.2827262878418
31.804960250854492
32.66358184814453
Finished 1
42.4

34.79386520385742
35.060123443603516
34.99713897705078
36.863067626953125
32.89801788330078
Finished 8
40.21291732788086
38.03261947631836
39.800575256347656
33.729251861572266
32.90285873413086
33.63847351074219
34.41672897338867
35.134769439697266
32.89362335205078
32.58964538574219
Finished 9
[40.83210526 38.9891922  37.13533554 33.12051735 33.37533302 33.46194458
 33.44198418 33.91303482 33.80562286 32.78705368]
[1.55885665 1.6329911  1.77795559 0.54074949 0.52367486 0.64126289
 0.88689772 0.96857054 1.93656403 2.77985849]
Finished MLP
41.00113296508789
40.891990661621094
39.79893493652344
38.190101623535156
37.02465057373047
36.51123809814453
36.27400207519531
34.42499923706055
34.87268829345703
35.34653091430664
Finished 0
40.53355407714844
40.182674407958984
38.702449798583984
37.64766311645508
36.35106658935547
35.575626373291016
33.485233306884766
33.509674072265625
32.10565185546875
30.39910316467285
Finished 1
41.006431579589844
40.941959381103516
41.75075912475586
38.348514

34.09823989868164
33.5446891784668
35.003971099853516
37.88899612426758
Finished 8
45.867557525634766
40.8614501953125
41.83061218261719
34.04310607910156
32.82078170776367
33.10490417480469
33.5704231262207
31.46776580810547
33.4849739074707
35.28125
Finished 9
[43.42029305 40.01527557 38.95476265 33.47526627 33.26760731 33.70052757
 34.00045815 33.46052971 34.15787182 35.31364746]
[2.28639038 1.95936166 2.63464868 0.57305297 1.09703359 0.97979107
 1.10273641 1.60564928 1.65095186 2.33391175]
Finished MLP
41.66904830932617
42.913291931152344
40.87590789794922
40.35380935668945
37.9906120300293
37.7404670715332
37.73771667480469
37.65650177001953
36.77684783935547
31.644615173339844
Finished 0
42.34939193725586
40.866207122802734
39.97996139526367
39.112857818603516
37.49754333496094
38.094966888427734
37.68278121948242
37.725189208984375
37.77935791015625
34.473426818847656
Finished 1
41.4256706237793
40.4686279296875
41.1486701965332
38.70391845703125
37.66375732421875
38.10313415527

45.6566047668457
42.22523880004883
42.48656463623047
35.06238555908203
33.765052795410156
32.64841842651367
33.09898376464844
31.19461441040039
33.83585739135742
34.24714660644531
Finished 9
[43.03368034 40.99939003 37.44253922 33.64237022 33.05905304 33.14444466
 32.86989746 32.97947826 33.02269497 33.24976101]
[3.85694381 3.62833801 2.76819907 1.27405438 1.11445026 0.40461068
 1.57297395 1.97894875 1.49691911 2.21445435]
Finished MLP
43.5981559753418
44.270503997802734
41.47275924682617
38.63943862915039
38.22318649291992
37.08938217163086
35.65105438232422
34.68366241455078
34.43337631225586
36.07014083862305
Finished 0
39.80668640136719
40.331214904785156
39.28978729248047
39.24348449707031
38.39070129394531
38.00065612792969
36.58225631713867
36.295013427734375
36.86893844604492
39.18180465698242
Finished 1
41.4872932434082
39.97265625
39.37944030761719
39.27244186401367
38.15134048461914
36.50900650024414
35.18560791015625
34.71820068359375
33.704559326171875
31.207639694213867
F

33.11174774169922
33.45364761352539
32.913612365722656
32.96014404296875
32.45866775512695
31.094507217407227
29.707761764526367
Finished 9
[43.11848488 41.11992607 38.69965248 33.23265648 33.19840736 33.34099426
 33.35421009 33.12567768 32.8829977  32.27597809]
[4.18216176 2.50436149 2.44670915 0.58326873 0.50400538 0.45571024
 0.96616556 0.85144387 1.00143242 2.00187716]
Finished MLP
40.07090377807617
41.88783264160156
40.901512145996094
39.403926849365234
37.11531448364258
37.237205505371094
35.540794372558594
33.93573760986328
32.69953918457031
29.938541412353516
Finished 0
40.58204650878906
40.845359802246094
39.85862350463867
38.84035110473633
37.93650436401367
37.337467193603516
36.93373107910156
36.87440490722656
35.93442916870117
32.79664611816406
Finished 1
44.30358123779297
40.72856140136719
39.483272552490234
38.64803695678711
37.22831344604492
37.976436614990234
36.639991760253906
36.344200134277344
35.57826232910156
37.34098434448242
Finished 2
40.50767517089844
40.094135

32.9483528137207
33.48239517211914
33.658660888671875
34.668277740478516
35.19059753417969
37.51276779174805
33.28475570678711
Finished 9
[42.08094826 40.43197861 37.21483955 33.98474236 33.13902512 32.69228363
 33.83581581 34.15932713 34.98590775 35.18082142]
[2.93768493 2.0369953  2.5720643  1.06949115 1.15971916 1.58466549
 0.92505652 1.19139186 1.72561201 2.56441397]
Finished MLP
42.61964797973633
42.18342208862305
40.174808502197266
39.914276123046875
39.40140151977539
39.271244049072266
39.1735725402832
38.227996826171875
37.80756378173828
42.00990676879883
Finished 0
43.0195426940918
41.727691650390625
41.65251541137695
39.49130630493164
39.77397155761719
39.181541442871094
39.1422004699707
39.02241134643555
39.46809387207031
38.818115234375
Finished 1
42.052642822265625
43.40633010864258
40.471378326416016
39.621910095214844
39.25933837890625
38.363861083984375
38.57570266723633
38.44768524169922
37.91409683227539
38.135704040527344
Finished 2
43.596920013427734
41.626121520996

33.38493347167969
30.214365005493164
32.203548431396484
32.114280700683594
Finished 9
[41.82484169 39.32133446 37.47514305 33.7649025  33.35112495 33.55593967
 33.5764122  33.57436047 33.61183414 32.78741322]
[3.10444832 2.12031923 1.70932402 0.66240624 1.34412784 1.24293391
 0.51234622 1.46635417 1.85475213 1.55658849]
Finished MLP
43.206947326660156
44.43486404418945
40.11293411254883
39.52316665649414
39.179386138916016
39.59016036987305
40.08366775512695
40.02985382080078
39.67938995361328
40.8920783996582
Finished 0
42.054115295410156
39.701454162597656
39.81737518310547
38.89049530029297
38.91826248168945
38.4442253112793
37.69044876098633
37.77715301513672
36.551246643066406
35.041969299316406
Finished 1
45.212669372558594
41.98430633544922
40.234798431396484
39.15851593017578
39.420223236083984
38.89805221557617
38.96889877319336
38.81611633300781
36.13554763793945
35.23234176635742
Finished 2
41.7333869934082
42.1378059387207
40.42927169799805
40.452171325683594
40.31092834472

35.75988006591797
34.14713668823242
Finished 9
[42.63692322 40.14744911 38.15573654 34.19704437 33.52155094 33.80009918
 34.0032589  33.74730873 33.70248547 34.03371525]
[3.22666319 3.79794473 1.64952376 0.91097624 1.53098624 0.72394156
 0.96325037 1.75355351 1.39451194 1.83429106]
Finished MLP
44.05780029296875
43.7926025390625
40.421321868896484
41.379608154296875
41.11030960083008
40.2154541015625
40.54966735839844
40.10356903076172
38.566444396972656
38.127769470214844
Finished 0
45.31583023071289
43.026885986328125
41.18156814575195
40.37916946411133
40.90436935424805
40.8245735168457
39.36403274536133
40.40104675292969
39.49541091918945
43.42218017578125
Finished 1
41.41217803955078
40.94091796875
40.46879577636719
39.573848724365234
40.02115249633789
40.410640716552734
39.45848846435547
39.06258010864258
40.12784957885742
41.621219635009766
Finished 2
41.450496673583984
40.28133773803711
40.748905181884766
40.294830322265625
38.9629020690918
38.699928283691406
38.961612701416016

In [33]:
import torch

nreps = 10
trainsizes = [30,50,100,200,300,400,500,600,700,800]
res = np.zeros((nreps,len(trainsizes)))
l2 = 1e0

for rep in range(nreps):
    losses = []

    idcs = np.arange(z.shape[0])
    np.random.shuffle(idcs)

    for ntrain in trainsizes:
        xps = torch.from_numpy(np.asarray(z)).float().cuda()
        xps = torch.cat([xps, torch.ones(xps.shape[0], 1).float().cuda()], dim=1)
        xtr = xps[idcs[:ntrain]]
        xt = xps[idcs[ntrain:]]

        y = get_y(metadict, ['age'], subs)[0]
        y_t = torch.from_numpy(y).float().cuda()
        ytr = y_t[idcs[:ntrain]]
        yt = y_t[idcs[ntrain:]]

        # REDUCE THIS TO GET GOOD RESULTS WITH SPARSITY 0.01->0.001 or 0.0001
        w, _, _, _ = torch.linalg.lstsq(xtr.T@xtr + l2*torch.eye(ncodes+1).float().cuda(), xtr.T@ytr)

        print(torch.mean((yt-xt@w)**2)**0.5)
        losses.append(float(torch.mean((yt-xt@w)**2)**0.5))
            
    print(f'Finished {rep}')
    res[rep,:] = losses
    
print(np.mean(res, axis=0))
print(np.std(res, axis=0))

tensor(192.5408, device='cuda:0')
tensor(195.0819, device='cuda:0')
tensor(198.7012, device='cuda:0')
tensor(205.3038, device='cuda:0')
tensor(201.2672, device='cuda:0')
tensor(185.4756, device='cuda:0')
tensor(171.2626, device='cuda:0')
tensor(147.4945, device='cuda:0')
tensor(138.3634, device='cuda:0')
tensor(185.4845, device='cuda:0')
Finished 0
tensor(193.3747, device='cuda:0')
tensor(195.6248, device='cuda:0')
tensor(201.3620, device='cuda:0')
tensor(203.9545, device='cuda:0')
tensor(196.3075, device='cuda:0')
tensor(189.9360, device='cuda:0')
tensor(176.0122, device='cuda:0')
tensor(165.2872, device='cuda:0')
tensor(135.5327, device='cuda:0')
tensor(211.7391, device='cuda:0')
Finished 1
tensor(190.9630, device='cuda:0')
tensor(196.9132, device='cuda:0')
tensor(198.1187, device='cuda:0')
tensor(199.1819, device='cuda:0')
tensor(188.6399, device='cuda:0')
tensor(180.3952, device='cuda:0')
tensor(170.5749, device='cuda:0')
tensor(160.5892, device='cuda:0')
tensor(150.4752, device='c

In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, ncodes):
        super(MLP, self).__init__()
        self.l1 = nn.Linear(ncodes, 40).float().cuda()
        self.l2 = nn.Linear(40,1).float().cuda()
        
    def train(self, xtr, ytr, nepochs=1000, lr=1e-1, l1=1e-1, l2=1e-4, pperiod=100, verbose=False):
        optim = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=l2)
        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.75, eps=1e-7)
        
        for epoch in range(nepochs):
            optim.zero_grad()
            yhat = self(xtr)
            loss = mseLoss(yhat, ytr)**0.5
            l1loss = l1*torch.sum(torch.abs(self.l1.weight))
            (loss+l1loss).backward()
            optim.step()
            sched.step(loss)
            if verbose:
                if epoch % pperiod == 0 or epoch == nepochs-1:
                    print(f'{epoch} {[float(l) for l in [loss, l1loss]]} {sched._last_lr}')
                    
    def predict(self, xt, yt):
        with torch.no_grad():
            return mseLoss(self(xt), yt)**0.5
                    
    def forward(self, x):
        x = F.relu(self.l1(x))
        x = self.l2(x).squeeze()
        return x
    
mseLoss = nn.MSELoss()

nreps = 10
trainsizes = [30,50,100,200,300,400,500,600,700,800]
res = np.zeros((nreps,len(trainsizes)))

for rep in range(nreps):

    idcs = np.arange(z.shape[0])
    np.random.shuffle(idcs)
    
    losses = []

    for ntrain in trainsizes:
        xps = torch.from_numpy(np.asarray(z)).float().cuda()
        xtr = xps[idcs[:ntrain]]
        xt = xps[idcs[ntrain:]]

        y = get_y(metadict, ['age'], subs)[0]
        y_t = torch.from_numpy(y).float().cuda()
        ytr = y_t[idcs[:ntrain]]
        yt = y_t[idcs[ntrain:]]

        mlp = MLP(ncodes)
        # 1e-3 good for age 1e-2 good for wrat
        mlp.train(xtr, ytr, lr=1e-2, nepochs=1000, l1=1e0, l2=1e-3)
        loss = mlp.predict(xt, yt)

        losses.append(float(loss))
        print(float(loss))
    
    res[rep,:] = losses
    print(f'Finished {rep}')
    
print(np.mean(res, axis=0))
print(np.std(res, axis=0))

40.60023880004883
38.32298278808594
39.800872802734375
34.742130279541016
34.226829528808594
34.612998962402344
34.246883392333984
32.39553451538086
32.25904846191406
33.38909912109375
Finished 0
44.191650390625
41.63274002075195
42.74654769897461
34.2310676574707
33.71150588989258
33.66259002685547
34.705528259277344
33.93981170654297
32.8906364440918
33.65440368652344
Finished 1
44.42479705810547
40.79983139038086
34.38186264038086
34.54288101196289
33.875152587890625
34.27886962890625
35.36629104614258
36.01179122924805
35.041481018066406
33.839202880859375
Finished 2
39.102012634277344
38.587257385253906
39.901031494140625
34.06235122680664
33.89718246459961
33.865177154541016
35.077754974365234
33.87831115722656
32.8411750793457
31.112560272216797
Finished 3
40.39668273925781
40.18560028076172
41.08665084838867
34.97288131713867
35.36097717285156
34.74446487426758
35.354835510253906
34.93059539794922
34.95436477661133
39.72813034057617
Finished 4
51.14301300048828
45.7835502624511

In [34]:
nreps = 10
trainsizes = [30,50,100,200,300,400,500,600,700,800]
res = np.zeros((nreps,len(trainsizes)))

nepochs = 500
pperiod = 100
verbose = False

for rep in range(nreps):

    idcs = np.arange(z.shape[0])
    np.random.shuffle(idcs)
    
    losses = []

    for ntrain in trainsizes:
        xps = torch.from_numpy(np.asarray(z))[idcs].unsqueeze(1).float().cuda()
        
        mu = torch.mean(xps[:ntrain], dim=0, keepdims=True)
        std = torch.std(xps[:ntrain], dim=0, keepdims=True)
        xps = (xps-mu)/std
    
        xtr = xps[:ntrain]
        xt = xps[ntrain:]

        y = get_y(metadict, ['age'], subs)[0]
        y_t = torch.from_numpy(y[idcs]).float().cuda()
        ytr = y_t[:ntrain]
        yt = y_t[ntrain:]

        # dp=0.5 for wrat
        sim = LatSim(1, xps, dp=0.5, edp=0.1, wInit=1e-4, dim=10, temp=1)
        optim = torch.optim.Adam(sim.parameters(), lr=1e-3, weight_decay=1e-3)
        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.75, eps=1e-7)
        
        for epoch in range(nepochs):
            optim.zero_grad()
            yhat = sim(xtr, [ytr])[0][0]
            loss = mseLoss(yhat, ytr)**0.5
            loss.backward()
            optim.step()
            sched.step(loss)
            if verbose:
                if epoch % pperiod == 0 or epoch == nepochs-1:
                    print(f'{epoch} {float(loss)} {sched._last_lr}')
                    
        sim.eval()
        yhat = sim(xps, [y_t], np.arange(ntrain,idcs.shape[0]))[0][0][ntrain:]
        loss = mseLoss(yhat, yt)**0.5
        losses.append(float(loss))
        
        print(float(loss))
    
    res[rep,:] = losses
    print(f'Finished {rep}')
    
print(np.mean(res, axis=0))
print(np.std(res, axis=0))

41.362892150878906
44.08783721923828
40.386775970458984
39.970916748046875
39.931514739990234
39.34151077270508
39.55750274658203
40.0809440612793
38.64273452758789
40.01464080810547
Finished 0
44.839176177978516
41.674415588378906
40.362648010253906
40.96070098876953
40.21275329589844
39.955753326416016
39.49416732788086
41.14554214477539
41.157047271728516
38.1065788269043
Finished 1
43.523868560791016
40.72843933105469
41.81019592285156
40.60601043701172
40.59677505493164
39.57193374633789
40.452720642089844
40.665374755859375
38.41239929199219
36.44091033935547
Finished 2
44.02684020996094
42.21769332885742
40.772911071777344
39.84599685668945
39.833011627197266
40.07791519165039
40.349483489990234
39.99162673950195
40.3017578125
39.48772430419922
Finished 3
41.49198532104492
40.80345916748047
41.2951545715332
40.10832977294922
39.39784240722656
38.248497009277344
37.999969482421875
38.43232727050781
37.57269287109375
38.61994171142578
Finished 4
41.14785385131836
40.57553482055664