# Calculate low-rank dynamic FC and use it to approximate static FC
We can approximate static FC with \~120 rank-\~120 dictionary entries<br>
We can also approximate dynamic FC with >300 rank-1 dictionary entries<br>
Can the rank-1 entries sum together to approximate a static FC?

In [1]:
# Using newly preprocessed subjects

import pickle

metadictname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_agesexwrat.pkl'
alltsname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_PowerTS_float2.pkl'

with open(metadictname, 'rb') as f:
    metadict = pickle.load(f)

with open(alltsname, 'rb') as f:
    allts = pickle.load(f)
    
print(list(metadict.keys()))
print(list(allts.keys()))
print('Complete')

['age', 'sex', 'wrat', 'missingage', 'missingsex', 'missingwrat', 'failedqc']
['emoid', 'nback', 'rest']
Complete


In [2]:
'''
Get subjects that have all tasks and paras specified
Functions for creating independent and response variables
'''

import numpy as np

def get_subs(allts, metadict, tasks, paras):
    # Get subs for all paras
    for i,para in enumerate(paras):
        tmpset = set([int(sub[4:]) for sub in allts[para].keys()])
        if i == 0:
            paraset = tmpset
        else:
            paraset = paraset.intersection(tmpset)
    # Get subs for all tasks
    for i,task in enumerate(tasks):
        tmpset = set([sub for sub in metadict[task].keys()])
        if i == 0:
            taskset = tmpset
        else:
            taskset = paraset.intersection(tmpset)
    # Remove QC failures
    allsubs = taskset.intersection(paraset)
    for badsub in metadict['failedqc']:
        try:
            allsubs.remove(int(badsub[4:]))
        except:
            pass
    return allsubs

def get_X(allts, paras, subs):
    X = []
    for para in paras:
        pX = [allts[para][f'sub-{sub}'] for sub in subs]
        pX = np.stack(pX)
        X.append(pX)
    return X

def get_y(metadict, tasks, subs):
    y = []
    for task in tasks:
        if task == 'age' or task == 'wrat':
            var = [metadict[task][sub] for sub in subs]
            var = np.array(var)
            y.append(var)
        if task == 'sex':
            maleness = [metadict[task][sub] == 'M' for sub in subs]
            maleness = np.array(maleness)
            sex = np.stack([maleness, 1-maleness], axis=1)
            y.append(sex)
    return y

subs = get_subs(allts, metadict, ['age'], ['rest', 'nback', 'emoid'])
print(len(subs))

X = get_X(allts, ['rest', 'nback', 'emoid'], subs)
print(X[0].shape)

847
(847, 264, 124)


In [3]:
# TS to condensed FC

from scipy import signal

def butter_bandpass(cutoff, fs, order=5):
    nyq = 0.5 * fs
    normal_cutoff = [cutoff[0] / nyq, cutoff[1] / nyq]
    b, a = signal.butter(order, normal_cutoff, btype='band', analog=False)
    return b, a

def butter_bandpass_filter(data, cutoff, fs, order=5):
    b, a = butter_bandpass(cutoff, fs, order=order)
    y = signal.filtfilt(b, a, data)
    return y

tr = 1.83

def filter_design_ts(X):
    Xs = []
    for i in range(X.shape[0]):
        nX = butter_bandpass_filter(X[i], [0.01, 0.2], 1/tr)
        Xs.append(nX)
    return np.stack(Xs)

def ts_to_flat_fc(X):
    p = np.corrcoef(X)
    a,b = np.triu_indices(p[0].shape[0], 1)
    p = p[a,b]
    return p

ts = [np.stack([ts for ts in filter_design_ts(Xp)]) for Xp in X]
ts = [tsmod/np.linalg.norm(tsmod, axis=-1, keepdims=True) for tsmod in ts]
print(ts[0].shape)

(847, 264, 124)


In [72]:
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

class LowRankCodes(nn.Module):
    '''
    ranks: array of rank for each codebook matrix
    '''
    def __init__(self, ranks):
        super(LowRankCodes, self).__init__()
        self.As = []
        for rank in ranks:
            A = nn.Parameter(1e-2*torch.randn(rank,264).float().cuda())
            self.As.append(A)
        self.As = nn.ParameterList(self.As)

    def forward(self):
        book = []
        for A in self.As:
            AA = A.T@A
            book.append(AA)
        return torch.stack(book)
    
class LowRankWeights(nn.Module):
    '''
    For a single modality!
    
    nsubs: number of subjects
    ncodes: number of pages in the codebook
    nt: number of timepoints
    '''
    def __init__(self, nsubs, nmods, ncodes, nt):
        super(LowRankWeights, self).__init__()
        self.w = nn.Parameter(1e-2*torch.rand(nsubs, ncodes, nt).float().cuda())

    def forward(self, sub, book):
        w = self.w[sub]
        return torch.einsum('pt,pab->abt', w, book)
    
def get_recon_loss(x, xhat):
    return mseLoss(xhat, x)

def get_smooth_loss_fc(xhat):
    before = xhat[:,:,:-1]
    after = xhat[:,:,1:]
    return torch.mean((before-after)**2)

def get_sub_fc(subts):
    return torch.einsum('at,bt->abt',subts,subts)
    
# Timeseries
x = torch.from_numpy(ts[0]).float().cuda()
    
# Parameters
ntrain = 400
nbatch = 30
smooth_mult = 0.1
nEpochs = 50
pPeriod = 40

mseLoss = nn.MSELoss()
    
# Codebook and weights
lrc = LowRankCodes(100*[1])
ncodes = len(lrc.As)

lrw = LowRankWeights(ntrain, 1, ncodes, x.shape[-1])

# Optimizers
optim = torch.optim.Adam(itertools.chain(lrc.parameters(), lrw.parameters()), lr=1e-2, weight_decay=0)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=int(ntrain/nbatch)+5, factor=0.75, eps=1e-7)
    
for epoch in range(nEpochs):
    suborder = np.arange(ntrain)
#     np.random.shuffle(suborder)
    for bstart in range(0,ntrain,nbatch):
        bend = bstart+nbatch
        if bend > ntrain:
            bend = ntrain
        optim.zero_grad()
        book = lrc()
        recon_loss = 0
        smooth_loss_fc = 0
        for subidx in range(bstart, bend):
            sub = suborder[subidx]
            xsub = get_sub_fc(x[sub])
            xhat = lrw(sub, book)
            recon_loss += get_recon_loss(xsub, xhat)
            smooth_loss_fc += smooth_mult*get_smooth_loss_fc(xhat)
        recon_loss /= (bend-bstart)
        smooth_loss_fc /= (bend-bstart)
        totloss = recon_loss+smooth_loss_fc
        totloss.backward()
        optim.step()
        sched.step(totloss)
        if bstart % nbatch == 0:
            print(f'{epoch} {bstart} recon: {[float(ls)**0.5 for ls in [recon_loss, smooth_loss_fc]]} '
                  f'lr: {sched._last_lr}')

print('Complete')

0 0 recon: [0.011949059217210473, 1.296255454723844e-06] lr: [0.01]
0 30 recon: [0.011827931935633564, 1.305218103184779e-06] lr: [0.01]
0 60 recon: [0.012338877569601071, 1.310397918217625e-06] lr: [0.01]
0 90 recon: [0.01201671417963237, 1.3160995516054654e-06] lr: [0.01]
0 120 recon: [0.013905218289878982, 1.3256684555004354e-06] lr: [0.01]
0 150 recon: [0.012890260571147387, 1.34038213368106e-06] lr: [0.01]
0 180 recon: [0.011782472308924237, 1.3496743974229167e-06] lr: [0.01]
0 210 recon: [0.012283084701174056, 1.3680592101576008e-06] lr: [0.01]
0 240 recon: [0.012493818424467389, 1.3807224667126702e-06] lr: [0.01]
0 270 recon: [0.012844406885266967, 1.4054761585890162e-06] lr: [0.01]
0 300 recon: [0.011863007478762753, 1.4331407570075352e-06] lr: [0.01]
0 330 recon: [0.012544688073430494, 1.4632368459082526e-06] lr: [0.01]
0 360 recon: [0.012000584093608447, 1.5065754739418786e-06] lr: [0.01]
0 390 recon: [0.014794181932698325, 1.5549148592322939e-06] lr: [0.01]
1 0 recon: [0.011

8 150 recon: [0.010473221296171106, 0.0019231177189245553] lr: [0.005625]
8 180 recon: [0.009426010807417038, 0.001954720984064582] lr: [0.005625]
8 210 recon: [0.010256674653951452, 0.0019814079792828446] lr: [0.005625]
8 240 recon: [0.010650519996177723, 0.0021208612581366565] lr: [0.005625]
8 270 recon: [0.010598657016099109, 0.0024617710162817053] lr: [0.005625]
8 300 recon: [0.00994010343544658, 0.0019291546180557739] lr: [0.005625]
8 330 recon: [0.009740898000619119, 0.002014837932159085] lr: [0.005625]
8 360 recon: [0.009920665731103109, 0.0018705796009021178] lr: [0.005625]
8 390 recon: [0.010862697599120709, 0.0028154822447333107] lr: [0.005625]
9 0 recon: [0.009713608446876548, 0.0017150687178671957] lr: [0.005625]
9 30 recon: [0.009361293904608863, 0.0017865230248876624] lr: [0.005625]
9 60 recon: [0.010139238571719842, 0.0017284424198117142] lr: [0.005625]
9 90 recon: [0.00988760516926212, 0.001505458431929508] lr: [0.005625]
9 120 recon: [0.011537790098929905, 0.0017785205

16 120 recon: [0.010627894458714118, 0.0023767125306497303] lr: [0.005625]
16 150 recon: [0.009433159769458674, 0.002122758423223212] lr: [0.005625]
16 180 recon: [0.008634013408259048, 0.0019123916119617311] lr: [0.005625]
16 210 recon: [0.009371311035033298, 0.0018375829212847808] lr: [0.005625]
16 240 recon: [0.009755961658778803, 0.0018863775611755922] lr: [0.005625]
16 270 recon: [0.009389711688880515, 0.00223474799861571] lr: [0.005625]
16 300 recon: [0.009133093112546523, 0.0018848953078343645] lr: [0.005625]
16 330 recon: [0.00860980088377323, 0.00230946546296658] lr: [0.005625]
16 360 recon: [0.009028104683594705, 0.0020363934764430665] lr: [0.005625]
16 390 recon: [0.01002667506373554, 0.003625560082060172] lr: [0.005625]
17 0 recon: [0.008885947327615103, 0.0020164182226179335] lr: [0.005625]
17 30 recon: [0.0086428497428023, 0.0020321007268690526] lr: [0.005625]
17 60 recon: [0.00931821778517073, 0.002057749306720201] lr: [0.005625]
17 90 recon: [0.008971623400945207, 0.002

24 60 recon: [0.00887424815900264, 0.002244461714777596] lr: [0.00421875]
24 90 recon: [0.008565994233326449, 0.002260351377687996] lr: [0.00421875]
24 120 recon: [0.009768605522460906, 0.002569373566632375] lr: [0.00421875]
24 150 recon: [0.00884975276472306, 0.0023606676785168476] lr: [0.00421875]
24 180 recon: [0.008149394293304564, 0.002087435579174767] lr: [0.00421875]
24 210 recon: [0.008779130231031027, 0.002048497759423174] lr: [0.00421875]
24 240 recon: [0.009220423158907382, 0.0021072077671190187] lr: [0.00421875]
24 270 recon: [0.008848291245625056, 0.002458616778651414] lr: [0.00421875]
24 300 recon: [0.008638144641664626, 0.0020651352149834833] lr: [0.00421875]
24 330 recon: [0.008217384317974948, 0.0024629454862340303] lr: [0.00421875]
24 360 recon: [0.008603577166406487, 0.002196092466252661] lr: [0.00421875]
24 390 recon: [0.008950790251787862, 0.003272502810926918] lr: [0.00421875]
25 0 recon: [0.008465254171747982, 0.002186348029501398] lr: [0.00421875]
25 30 recon: [

31 360 recon: [0.008382633411598908, 0.0022321980435597373] lr: [0.00421875]
31 390 recon: [0.010237470579090508, 0.004810495850297865] lr: [0.00421875]
32 0 recon: [0.008247629562261398, 0.002208457386348197] lr: [0.00421875]
32 30 recon: [0.008034529471790295, 0.002213570742001266] lr: [0.00421875]
32 60 recon: [0.0085821397066534, 0.002282563544138559] lr: [0.00421875]
32 90 recon: [0.00828699543190989, 0.002272246388771357] lr: [0.00421875]
32 120 recon: [0.009157568723331002, 0.0026810732238102364] lr: [0.00421875]
32 150 recon: [0.008546176886356853, 0.0023998606393880527] lr: [0.00421875]
32 180 recon: [0.007893801912840354, 0.0021083259911232914] lr: [0.00421875]
32 210 recon: [0.008453626060713043, 0.002107753145334019] lr: [0.00421875]
32 240 recon: [0.008880065461802356, 0.0021757029792932744] lr: [0.00421875]
32 270 recon: [0.00856990440431228, 0.0024815157313157364] lr: [0.00421875]
32 300 recon: [0.008366518630206867, 0.002097828481088084] lr: [0.00421875]
32 330 recon: [

39 240 recon: [0.008660823850006053, 0.0022746407974913413] lr: [0.00421875]
39 270 recon: [0.008380808285093194, 0.0025650532945497018] lr: [0.00421875]
39 300 recon: [0.008183395677385248, 0.002183189021824688] lr: [0.00421875]
39 330 recon: [0.007808985329339746, 0.0025258518417996915] lr: [0.00421875]
39 360 recon: [0.008204707510328663, 0.0022890798101507077] lr: [0.00421875]
39 390 recon: [0.009389489293919497, 0.0043208288025763765] lr: [0.00421875]
40 0 recon: [0.008065975084742834, 0.0022784414099456954] lr: [0.00421875]
40 30 recon: [0.007863661855767567, 0.0022900464816141023] lr: [0.00421875]
40 60 recon: [0.008380007795082661, 0.0023760573106490035] lr: [0.00421875]
40 90 recon: [0.00809867461752167, 0.002372981255633659] lr: [0.00421875]
40 120 recon: [0.008780307432486905, 0.002881175748427339] lr: [0.00421875]
40 150 recon: [0.008312906759901465, 0.0025126544828302346] lr: [0.00421875]
40 180 recon: [0.007700975081303649, 0.0022084984653625893] lr: [0.00421875]
40 210 r

47 120 recon: [0.00856072335724399, 0.0029884160401372115] lr: [0.00421875]
47 150 recon: [0.008167575971088713, 0.0025770259127511] lr: [0.00421875]
47 180 recon: [0.007583980820272096, 0.002268273513791824] lr: [0.00421875]
47 210 recon: [0.008100088596741346, 0.002288364723073645] lr: [0.00421875]
47 240 recon: [0.008485411431501638, 0.0023534831124202805] lr: [0.00421875]
47 270 recon: [0.008231429919000936, 0.002630338621781104] lr: [0.00421875]
47 300 recon: [0.00802739572044126, 0.0022495674197912066] lr: [0.00421875]
47 330 recon: [0.007680812898744038, 0.0025698525390223303] lr: [0.00421875]
47 360 recon: [0.008079754841554208, 0.002336965631705273] lr: [0.00421875]
47 390 recon: [0.009117314803600245, 0.0043726611158651195] lr: [0.00421875]
48 0 recon: [0.007941587261438177, 0.002319978265063999] lr: [0.00421875]
48 30 recon: [0.007752984847825033, 0.002312309199170985] lr: [0.00421875]
48 60 recon: [0.0082566582228058, 0.0023957505127427626] lr: [0.00421875]
48 90 recon: [0.

In [68]:
# Fast weight estimation for all subjects

book = lrc()

A = book.reshape(book.shape[0], -1).permute(1,0)
AA = A.T@A
codes = []

for sub in range(x.shape[0]):
    B = get_sub_fc(x[sub]).reshape(-1, x.shape[-1])
    AB = A.T@B
    C,_,_,_ = torch.linalg.lstsq(AA+0.1*torch.eye(AA.shape[0]).float().cuda(),AB)
    codes.append(torch.from_numpy(C.detach().cpu().numpy()))
    if sub % 100 == 0:
        loss = mseLoss(A@C,B)**0.5
        print(f'Finished {sub} {loss}')
    
codes = torch.stack(codes)
print(codes.shape)

Finished 0 0.0072065494023263454
Finished 100 0.008164320141077042
Finished 200 0.005845096427947283
Finished 300 0.0074473656713962555
Finished 400 0.007260077632963657
Finished 500 0.006784871686249971
Finished 600 0.00850341934710741
Finished 700 0.007079832721501589
Finished 800 0.007483322639018297
torch.Size([847, 200, 124])


In [61]:
# Reconstruct static FC from dynamic FC

import matplotlib.pyplot as plt

sub = 398

dynfc = get_sub_fc(x[sub])
dynfc = torch.mean(dynfc, dim=-1).detach().cpu().numpy()

statfc = np.corrcoef(ts[0][sub])

reconfc = lrw(sub, book)
reconfc = torch.mean(reconfc, dim=-1).detach().cpu().numpy()

A = book.reshape(book.shape[0], -1).permute(1,0)
reconfc2 = torch.mean((A.cpu()@codes[sub]), dim=1).reshape(264,264).numpy()

fig, ax = plt.subplots(1,4,figsize=(10,3))

print(np.mean((statfc-124*dynfc)**2)**0.5)
print(np.mean((statfc-124*reconfc)**2)**0.5)
print(np.mean((statfc-124*reconfc2)**2)**0.5)

ax[0].imshow(dynfc)
ax[1].imshow(statfc)
ax[2].imshow(reconfc)
ax[3].imshow(reconfc2)
fig.show()

RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

In [70]:
ntrain = 30

losses = []

for ntrain in [30,50,100,200,300,400,500,600,700,800]:

    codescuda = codes.float().cuda()
    xps = torch.mean(codescuda, dim=-1) #ps.reshape(ps.shape[0],-1)
    xps = torch.cat([xps, torch.ones(xps.shape[0], 1).float().cuda()], dim=1)
    xtr = xps[:ntrain]
    xt = xps[ntrain:]

    y = get_y(metadict, ['age'], subs)[0]
    y_t = torch.from_numpy(y).float().cuda()
    ytr = y_t[:ntrain]
    yt = y_t[ntrain:]

    w, _, _, _ = torch.linalg.lstsq(xtr.T@xtr + 0.01*torch.eye(201).float().cuda(), xtr.T@ytr)

#     print(torch.mean((ytr-xtr@w)**2)**0.5)
    losses.append(float(torch.mean((yt-xt@w)**2)**0.5))
    
print(losses)

[37.303897857666016, 36.489013671875, 33.918373107910156, 33.12736892700195, 32.28167724609375, 31.161489486694336, 29.461423873901367, 30.928251266479492, 31.118194580078125, 26.9007511138916]


In [71]:
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
    
class CNN(nn.Module):
    def __init__(self, x, dp=0.1):
        super(CNN, self).__init__()
        self.cnn1 = torch.nn.Conv2d(1,10,(x.shape[-2],4)).float().cuda()
        self.ap1 = torch.nn.AvgPool2d((1,x.shape[-1]-3))
        self.lin1 = torch.nn.Linear(10,1).float().cuda()
        self.dp = nn.Dropout(p=dp)
        
    def forward(self, ts):
        ts = self.dp(ts)
        y = F.relu(self.cnn1(ts))
        y = self.ap1(y)
        z = y.reshape(y.shape[0], -1).squeeze()
        y = self.lin1(z)
        return z, y.squeeze()
    
nEpochs = 1000
pPeriod = 200

ntrain = 30

losses = []

for ntrain in [30,50,100,200,300,400,500,600,700,800]:

    codescuda = codes.float().cuda()
    xps = codescuda.unsqueeze(1)
    xtr = xps[:ntrain]
    xt = xps[ntrain:]

    y = get_y(metadict, ['age'], subs)[0]
    y_t = torch.from_numpy(y).float().cuda()
    ytr = y_t[:ntrain]
    yt = y_t[ntrain:]

    cnn = CNN(xtr, dp=0)

    optim = torch.optim.Adam(cnn.parameters(), lr=1e-1, weight_decay=1e-1)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.9, eps=1e-7)

    for epoch in range(nEpochs):
        optim.zero_grad()
        z, yhat1 = cnn(xtr)
        loss = mseLoss(yhat1, ytr)
        loss.backward()
        optim.step()
        sched.step(loss)
#         if epoch % pPeriod == 0 or epoch == nEpochs-1:
#             print(f'{epoch} recon: {loss**0.5} {sched._last_lr}')

#     print('Complete')

    cnn.eval()
    z, yhat1 = cnn(xt)
    loss = mseLoss(yhat1, yt)**0.5
    losses.append(float(loss))
    
print(losses)

[37.07823944091797, 36.521366119384766, 33.62803268432617, 33.55487060546875, 32.20626449584961, 31.09649085998535, 29.927452087402344, 31.040786743164062, 31.461437225341797, 27.856002807617188]


In [10]:
import sys

sys.path.append('../../LatentSimilarity')

from latsim import LatSim

print('Complete')

Complete


In [64]:
ntrain = 800

losses = []

for ntrain in [30,50,100,200,300,400,500,600,700,800]:
    codescuda = codes.float().cuda()
    xps = codescuda.unsqueeze(1)
    xtr = xps[:ntrain]
    xt = xps[ntrain:]

    y = get_y(metadict, ['age'], subs)[0]
    y_t = torch.from_numpy(y).float().cuda()
    ytr = y_t[:ntrain]
    yt = y_t[ntrain:]

    mseLoss = nn.MSELoss()

    nEpochs = 500
    pPeriod = 100

    cnn = CNN(xtr, dp=0.1)
    sim = LatSim(1, torch.zeros(1,1,10), dp=0, edp=0.1, wInit=1e-4, dim=20, temp=1)

    optim = torch.optim.Adam(itertools.chain(cnn.parameters(), sim.parameters()), lr=1e-1, weight_decay=1e-2)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.9, eps=1e-7)

    for epoch in range(nEpochs):
        optim.zero_grad()
        z, yhat1 = cnn(xtr)
        yhat2 = sim(z.unsqueeze(1), [ytr])[0][0]
        loss = mseLoss(yhat2, ytr)
        loss.backward()
        optim.step()
        sched.step(loss)
    #     if epoch % pPeriod == 0 or epoch == nEpochs-1:
    #         print(f'{epoch} recon: {loss**0.5} {sched._last_lr}')

    # print('Complete')

    cnn.eval()
    sim.eval()
    z, yhat1 = cnn(xps)
    yhat2 = sim(z.unsqueeze(1), [y_t])[0][0][ntrain:]
    loss = mseLoss(yhat2, yt)**0.5
    losses.append(float(loss))
    
print(losses)

[39.11385726928711, 39.22460174560547, 37.47481918334961, 35.456817626953125, 34.131099700927734, 33.35258102416992, 31.382802963256836, 30.60487937927246, 30.524438858032227, 31.71627426147461]


In [65]:
ntrain = 700

losses = []

for ntrain in [30,50,100,200,300,400,500,600,700,800]:
    codescuda = codes.float().cuda()
    xps = torch.mean(codescuda, dim=-1).unsqueeze(1)
    xtr = xps[:ntrain]
    xt = xps[ntrain:]

    y = get_y(metadict, ['age'], subs)[0]
    y_t = torch.from_numpy(y).float().cuda()
    ytr = y_t[:ntrain]
    yt = y_t[ntrain:]

    mseLoss = nn.MSELoss()

    nEpochs = 500
    pPeriod = 100

    sim = LatSim(1, torch.zeros(1,1,ncodes), dp=0.1, edp=0.1, wInit=1e-4, dim=2, temp=1)

    optim = torch.optim.Adam(sim.parameters(), lr=1e-1, weight_decay=1e-2)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.9, eps=1e-7)

    for epoch in range(nEpochs):
        optim.zero_grad()
        yhat = sim(xtr, [ytr])[0][0]
        loss = mseLoss(yhat, ytr)
        loss.backward()
        optim.step()
        sched.step(loss)
#         if epoch % pPeriod == 0 or epoch == nEpochs-1:
#             print(f'{epoch} recon: {loss**0.5} {sched._last_lr}')

#     print('Complete')

    sim.eval()
    yhat = sim(xps, [y_t])[0][0][ntrain:]
    loss = mseLoss(yhat, yt)**0.5
    losses.append(float(loss))
print(losses)

[39.33172607421875, 39.20526123046875, 38.03037643432617, 35.90704345703125, 34.469017028808594, 34.209110260009766, 32.38001251220703, 32.26662826538086, 32.01296615600586, 32.053043365478516]
