In [2]:
import pickle

# Per-subject metadata; later cells read meta[key]['nback'], meta[key]['emoid'],
# meta[key]['AgeInMonths'], meta[key]['Gender'] and meta[key]['wratStd'].
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
with open('../../PNC_Good/MegaMeta3.pkl', 'rb') as f:
    meta = pickle.load(f)
print('Complete')

Complete


In [3]:
# Subject ordering (`keys`) and the cross-validation folds (`groups`). Later cells
# use groups[g][0] as the train+valid indices and groups[g][1] as the test indices.
with open('../../Work/LatentSim/Splits.pkl', 'rb') as f:
    keys, groups = pickle.load(f)

# Sanity check: fold 0's two index lists together should be as long as `keys`.
print(len(keys))
print(sum(len(part) for part in groups[0][:2]))
print('Complete')

620
620
Complete


In [4]:
import numpy as np

# One array per paradigm with subjects in `keys` order; the printed shapes are
# (620, 264, 231) for nback and (620, 264, 210) for emoid.
nback = np.stack([meta[subj]['nback'] for subj in keys])
emoid = np.stack([meta[subj]['emoid'] for subj in keys])

print(nback.shape)
print(emoid.shape)

(620, 264, 231)
(620, 264, 210)


In [5]:
from nilearn.connectome import ConnectivityMeasure

def getFC(timeSeries, kind='correlation', transpose=True):
    """Compute one connectivity matrix per subject with nilearn's ConnectivityMeasure.

    Args:
        timeSeries: stacked 3-D array, one entry per subject. With transpose=True the
            last two axes are swapped before fitting. In this notebook each subject is
            (264, n_timepoints) -- presumably (ROIs, time) -- and becomes (time, ROIs).
        kind: connectivity kind forwarded to ConnectivityMeasure (default 'correlation').
        transpose: whether to swap the last two axes first.

    Returns:
        The result of ConnectivityMeasure.fit_transform ((620, 264, 264) here).
    """
    series = np.transpose(timeSeries, axes=(0, 2, 1)) if transpose else timeSeries
    measure = ConnectivityMeasure(kind=kind)
    return measure.fit_transform(series)

# Subject-wise functional connectivity for each paradigm.
nback_p, emoid_p = getFC(nback), getFC(emoid)

print(nback_p.shape)

(620, 264, 264)


In [6]:
import torch

def convertTorch(p, device='cuda'):
    """Flatten the strict upper triangle of each subject's connectivity matrix.

    Args:
        p: numpy array of shape (n_subjects, n_rows, n_cols); here (620, 264, 264).
        device: torch device for the result. The default 'cuda' matches the previous
            hard-wired `.cuda()` call.

    Returns:
        float32 tensor of shape (n_subjects, n_upper) on `device`, where n_upper is the
        number of entries strictly above the diagonal (34716 for 264 ROIs). Entries
        are in row-major upper-triangle order.
    """
    t = torch.from_numpy(p).float()
    # Built once and sized from the input rather than rebuilt for every subject
    # with the atlas size (264) hard-coded.
    rows, cols = torch.triu_indices(t.shape[-2], t.shape[-1], offset=1)
    # Advanced indexing gathers the same positions from every subject at once.
    return t[:, rows, cols].to(device)

# Upper-triangle edge vectors for both paradigms (placed on the GPU by convertTorch).
nback_p_t, emoid_p_t = convertTorch(nback_p), convertTorch(emoid_p)

print(nback_p_t.shape)
print(emoid_p_t.shape)
print('Complete')

torch.Size([620, 34716])
torch.Size([620, 34716])
Complete


In [6]:
# Z-score every edge across subjects, separately for each paradigm.
# NOTE(review): mean/std use all subjects, including every fold's test subjects,
# so some test-set information leaks into the features. Consider computing the
# statistics from the training indices inside the fold loop.
mu_nback, std_nback = nback_p_t.mean(dim=0, keepdim=True), nback_p_t.std(dim=0, keepdim=True)
mu_emoid, std_emoid = emoid_p_t.mean(dim=0, keepdim=True), emoid_p_t.std(dim=0, keepdim=True)

nback_p_t = (nback_p_t - mu_nback) / std_nback
emoid_p_t = (emoid_p_t - mu_emoid) / std_emoid

print('Norm complete')

Norm complete


In [7]:
# Spot check: smallest normalised value (and which subject holds it) for edges 420 and 421.
print(nback_p_t[:, 420:422].min(dim=0))
print(emoid_p_t[:, 420:422].min(dim=0))

torch.return_types.min(
values=tensor([-2.2998, -3.0133], device='cuda:0'),
indices=tensor([603, 350], device='cuda:0'))
torch.return_types.min(
values=tensor([-2.7284, -3.0616], device='cuda:0'),
indices=tensor([351,  33], device='cuda:0'))


In [7]:
# Targets in `keys` order: age ('AgeInMonths'), sex one-hot as [M, F], and 'wratStd'.
age = np.stack([meta[key]['AgeInMonths'] for key in keys])
gen = np.stack([np.array([meta[key]['Gender'] == 'M', meta[key]['Gender'] == 'F']) for key in keys]).astype(int)
wrt = np.stack([meta[key]['wratStd'] for key in keys])

# Any Gender code other than 'M'/'F' would give an all-zero row, which argmax
# (used in validate and the logistic-regression cell) silently reads as 'M'.
assert (gen.sum(axis=1) == 1).all(), 'unexpected Gender code in meta'

print(age.shape)
print(gen.shape)
print(wrt.shape)

(620,)
(620, 2)
(620,)


In [8]:
# Move the targets to the GPU as float32 tensors.
age_t, gen_t, wrt_t = (torch.from_numpy(arr).float().cuda() for arr in (age, gen, wrt))

print('Complete')

Complete


In [248]:
import torch.nn as nn
import torch.nn.functional as F
import time

# Losses for 2-D (one-hot) targets and 1-D (regression) targets respectively.
ceLoss, mseLoss = torch.nn.CrossEntropyLoss(), torch.nn.MSELoss()

# Per-fold test metrics and per-fold latents, appended by the training loop.
rmse, Ass = [], []

def allBelowThresh(losses, thresh):
    """Return True unless some loss is strictly greater than its paired threshold.

    Losses and thresholds are paired with zip, so surplus entries in the longer
    argument are ignored. A NaN loss never compares greater and so never blocks.
    """
    return not any(loss > thr for loss, thr in zip(losses, thresh))

def flatten(res):
    """Concatenate an iterable of iterables into one flat list, preserving order."""
    flat = []
    for group in res:
        flat.extend(group)
    return flat

def mask(e):
    """Return the square matrix `e` with its diagonal set to exactly zero.

    The diagonal is subtracted as a detached constant, so the gradient with
    respect to every entry of `e` is passed through unchanged.
    """
    diag_vals = e.detach().diagonal()
    return e - torch.diag(diag_vals)

def arith(n):
    """Return the n-th triangular number n*(n+1)/2 as an int.

    For Python ints, floor division keeps the result exact; the previous float
    division silently lost precision once n*(n+1) exceeded 2**53. Non-int inputs
    keep the original truncating float behaviour.
    """
    if isinstance(n, int):
        # n*(n+1) is always even, so // is exact.
        return n * (n + 1) // 2
    return int(n * (n + 1) / 2)

def getAvg(res):
    """Average predictions across heads for each task.

    Args:
        res: nested list indexed res[para][task]; each entry supports `/` and `+`
            (numbers or tensors).

    Returns:
        List with one entry per task: the mean over `para` of res[para][task].
    """
    n = len(res)
    out = []
    for task in range(len(res[0])):
        acc = 0
        for per_para in res:
            acc += per_para[task] / n
        out.append(acc)
    return out

class LatSim(nn.Module):
    """Latent-similarity model: each subject's targets are predicted as a
    softmax-weighted average of the other subjects' targets.

    Pipeline (see forward):
      1. one linear projection per paradigm, A[para], maps x[:, para] to a
         dimA-dimensional latent, followed by ReLU;
      2. the two latents are recombined into their sum and their difference;
      3. for each recombined latent and each task, B[para, task] projects to dimB
         dims and the Gram matrix z @ z.T is used as subject-by-subject similarity;
      4. self-similarity (and optionally test-subject columns) is masked out, and a
         row softmax turns similarities into neighbour weights applied to the targets.

    Args:
        nTasks: number of prediction targets (must equal len(ys) in forward).
        inp: example input; only inp.shape[-1] (feature count) is used.
        dp: dropout probability on the input features.
        edp: dropout probability on similarity-matrix entries.
        wInit: scale of the Gaussian initialisation of A and B.
        dimA: latent width after the first projection.
        dimB: width of the per-task projection used for similarities.
        temp: softmax temperature; a scalar is shared by all tasks, a list gives
            one value per task.

    NOTE(review): parameters are created directly on CUDA, and the code assumes
    exactly two paradigms (the leading 2 in A/B and range(2) in forward).
    """
    def __init__(self, nTasks, inp, dp=0.5, edp=0.1, wInit=1e-4, dimA=10, dimB=2, temp=1):
        super(LatSim, self).__init__()
        self.nTasks = nTasks
        # A: (2 paradigms, n_features, dimA)
        self.A = nn.Parameter(wInit*torch.randn(2,inp.shape[-1],dimA).float().cuda())
        # B: (2 recombined latents, nTasks, dimA, dimB)
        self.B = nn.Parameter(wInit*torch.randn(2,nTasks,dimA,dimB).float().cuda())
        self.dp = nn.Dropout(p=dp)
        self.edp = nn.Dropout(p=edp)
        # One softmax temperature per task.
        self.t = temp if isinstance(temp, list) else nTasks*[temp]
    
    def getLatent(self, x, para):
        """Project paradigm `para` of x into the latent space (ReLU is applied by the caller).

        Assumes x is (n_subjects, 2, n_features), as built by torch.stack(..., dim=1)
        in the training cell -- TODO confirm for other callers. Returns (n_subjects, dimA).
        """
        return x[:,para]@self.A[para]
    
    def getEdges(self, A, para, task):
        """Return the (n_subjects, n_subjects) similarity matrix for one head and one task.

        `A` is a latent matrix (n_subjects, dimA), not the parameter self.A. The result
        is 1e-10 + z @ z.T with z = A @ B[para, task].
        """
        # Tiny offset; presumably keeps genuine similarities from being exactly 0,
        # because forward() treats exact zeros as masked-out entries.
        e = 1e-10
        z = A@self.B[para, task]
        e = e+z@z.T
        return e
        
    def forward(self, x, ys, testIdcs=None):
        """Predict every task for every subject from its neighbours' targets.

        Args:
            x: inputs indexed x[:, para] per paradigm; assumed (n_subjects, 2, n_features).
            ys: list of per-task targets over the same n_subjects, 1-D or 2-D
                (e @ y keeps any trailing dimension).
            testIdcs: optional subject indices that must not serve as neighbours.
                Their similarity columns are zeroed and therefore masked.

        Returns:
            (res, As, es): res[para][task] is the prediction for every subject from
            recombined latent `para` (0 = sum, 1 = difference). As holds those two
            latents. es is always an empty list.
        """
        assert self.B.shape[1] == len(ys), "business end targets dim not same as passed"
        x = self.dp(x)
        res = []
        As = []
        es = []
        # Rectified latent for each paradigm.
        for para in range(2):
            A = self.getLatent(x, para)
            As.append(F.relu(A))
        # Recombine into the sum and the difference of the two paradigm latents.
        As = [As[0] + As[1], As[0] - As[1]]
        for para,A in enumerate(As):
            res.append([])
            for task,y in enumerate(ys):
                e = self.getEdges(A, para, task)
                if testIdcs is not None:
                    # Hide test subjects as neighbours: their columns become 0 and are masked below.
                    e[:,testIdcs] = 0
                # Zero the diagonal so no subject predicts itself.
                e = mask(e)
                # Edge dropout: in training mode randomly zeroes entries (nn.Dropout rescales the rest).
                e = self.edp(e)
                # Every exact zero (diagonal, test columns, dropped edges) -> -inf, i.e. zero
                # softmax weight. A row with no surviving entry would produce NaN.
                e[e == 0] = float('-inf')
                e = F.softmax(e/self.t[task], dim=1)
                # Prediction = weighted average of the other subjects' targets.
                res[-1].append(e@y)
        return res, As, es

def validate(model, X, ys, testIdcs):
    """Score `model` on the subjects in `testIdcs`, using the head-averaged predictions.

    Args:
        model: a LatSim-style module, called as model(X, ys, testIdcs) and expected to
            return (res, As, es) with res[para][task] predictions for every subject.
        X: inputs for all subjects in the current pool.
        ys: list of per-task targets over the same subjects. 1-D targets are scored as
            regression, 2-D (one-hot) targets as classification.
        testIdcs: indices (into X / ys) of the subjects to score.

    Returns:
        One Python float per task: RMSE for 1-D targets, and accuracy (fraction of
        argmax matches) for 2-D targets.

    The model's previous train/eval mode is restored afterwards, even if the forward
    pass raises. The original code always switched to train mode.
    """
    wasTraining = model.training
    model.eval()
    losses = []
    try:
        with torch.no_grad():
            res, _, _ = model(X, ys, testIdcs)
            for r, y in zip(getAvg(res), ys):
                if y.dim() == 1:
                    losses.append(float(mseLoss(r[testIdcs], y[testIdcs]) ** 0.5))
                else:
                    hits = torch.argmax(r, dim=1) == torch.argmax(y, dim=1)
                    losses.append(float(hits[testIdcs].float().mean()))
    finally:
        model.train(wasTraining)
    return losses

def getAs(model, X, ys):
    """Return the latent representations (second forward output) of `model`, without gradients.

    The forward pass runs in eval mode, which disables nn.Dropout layers. The model's
    previous train/eval mode is restored afterwards, even if the forward pass raises.
    """
    wasTraining = model.training
    model.eval()
    try:
        with torch.no_grad():
            _, As, _ = model(X, ys)
    finally:
        model.train(wasTraining)
    return As

# ---- Cross-validated LatSim training on age, sex and WRAT ----------------------
# NOTE(review): no random seed is set, so weights, dropout and results vary run to run.
# Maximum number of full-batch optimisation steps per fold.
nEpochs = 200
# Log and validate every pPeriod epochs.
pPeriod = 5
# Early-stop thresholds, compared with loss[0:nTasks] (the first head's per-task
# losses) AFTER multiplication by regParam.
# NOTE(review): with regParam[1] = 1e3, the sex threshold 0.3 needs a raw cross-entropy
# below 3e-4, so early stopping looks practically unreachable -- confirm intent.
thresh = [20,0.3,10]
# Per-task loss weights, in the order of yy: [age, sex, WRAT].
regParam = [1,1e3,50]

for grp in range(10):
    # The first 496 entries of the fold's train+valid list are used for fitting and the
    # rest for model selection; groups[grp][1] is the held-out test set.
    trainIdcs = groups[grp][0][0:496]
    trainValidIdcs = groups[grp][0]
    # Positions of the validation subjects *within* Xtv / ytv (not global indices).
    validIdcs = np.arange(496,len(trainValidIdcs))
    testIdcs = groups[grp][1]

    X0 = nback_p_t
    X1 = emoid_p_t

    # (n_subjects, 2 paradigms, n_edges)
    X = torch.stack([X0, X1], dim=1)
    Xt = X[trainIdcs]
    Xtv = X[trainValidIdcs]

    yy = [age_t, gen_t, wrt_t]
    yt = [age_t[trainIdcs], gen_t[trainIdcs], wrt_t[trainIdcs]]
    ytv = [age_t[trainValidIdcs], gen_t[trainValidIdcs], wrt_t[trainValidIdcs]]

    nTasks = len(yy)

    sim = LatSim(nTasks, X, dp=0.5, edp=0.1, wInit=1e-4, dimA=40, dimB=2, temp=[1,1,1])
    optim = torch.optim.Adam(sim.parameters(), lr=5e-4, weight_decay=5e-4)

    # Per task, the sequence of improving validation scores; the last entry is the best.
    validLoss = [[] for _ in range(nTasks)]

    for epoch in range(nEpochs):
        optim.zero_grad()
        res, _, _ = sim(Xt, yt)
        loss = []
        # One loss per (head, task) prediction plus one per head-averaged prediction.
        # (X.shape[1]+1)*yt repeats the targets to line up with flatten(res)+getAvg(res),
        # and i % nTasks recovers the task index.
        for i,(r,y) in enumerate(zip(flatten(res)+getAvg(res), (X.shape[1]+1)*yt)):
            # 2-D (one-hot) targets -> cross-entropy; 1-D targets -> MSE.
            # NOTE(review): r is a softmax-weighted average of one-hot labels (already
            # probabilities), while CrossEntropyLoss applies its own log-softmax -- confirm intent.
            if y.dim() > 1:
                loss.append(regParam[i%nTasks]*ceLoss(r, y))
            else:
                loss.append(regParam[i%nTasks]*mseLoss(r, y))
        sum(loss).backward()
        optim.step()
        if epoch % pPeriod == 0 or epoch == nEpochs-1 or allBelowThresh(loss[0:nTasks], thresh):
            print(f'epoch {epoch} loss={loss}')
            losses = validate(sim, Xtv, ytv, validIdcs)
            for i,lss in enumerate(losses):
                # Checkpoint a task when its score improves: lower RMSE for 1-D
                # targets, higher accuracy for 2-D targets.
                if (len(validLoss[i]) == 0 or 
                        (yy[i].dim() == 1 and lss < min(validLoss[i])) or 
                        (yy[i].dim() > 1 and lss > max(validLoss[i]))):
                    print(f'New best validation epoch {epoch} {i} loss={lss}')
                    # One checkpoint file per task, overwritten by every fold.
                    torch.save(sim.state_dict(), f'../../Work/LatentSim/sim{i}.pyt')
                    validLoss[i].append(float(lss))
            if allBelowThresh(loss[0:nTasks], thresh):
                print('Early stopping')
                break

    finalLoss = []

    # Reload each task's best checkpoint and score that task on the test fold.
    # Passing testIdcs lets the model exclude test subjects as neighbours.
    for i in range(nTasks):
        sim.load_state_dict(torch.load(f'../../Work/LatentSim/sim{i}.pyt'))
        loss = validate(sim, X, yy, testIdcs)
        # Latents are kept only from the task-0 (age) checkpoint.
        if i == 0:
            Ass.append(getAs(sim, X, yy))
        finalLoss.append(float(loss[i]))

    # Despite its name, each entry is [age RMSE, sex accuracy, WRAT RMSE] for this fold.
    rmse.append(finalLoss)

    print(f'FINISHED {rmse}')

epoch 0 loss=[tensor(1511.8212, device='cuda:0', grad_fn=<MulBackward0>), tensor(693.0249, device='cuda:0', grad_fn=<MulBackward0>), tensor(13331.0625, device='cuda:0', grad_fn=<MulBackward0>), tensor(1514.2000, device='cuda:0', grad_fn=<MulBackward0>), tensor(692.8686, device='cuda:0', grad_fn=<MulBackward0>), tensor(13371.7617, device='cuda:0', grad_fn=<MulBackward0>), tensor(1512.8429, device='cuda:0', grad_fn=<MulBackward0>), tensor(692.9332, device='cuda:0', grad_fn=<MulBackward0>), tensor(13349.9355, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 0 0 loss=44.566139221191406
New best validation epoch 0 1 loss=0.5
New best validation epoch 0 2 loss=12.045412063598633
epoch 5 loss=[tensor(1386.7703, device='cuda:0', grad_fn=<MulBackward0>), tensor(691.8045, device='cuda:0', grad_fn=<MulBackward0>), tensor(13087.1738, device='cuda:0', grad_fn=<MulBackward0>), tensor(1354.0339, device='cuda:0', grad_fn=<MulBackward0>), tensor(688.3138, device='cuda:0', grad_fn=<Mu

epoch 90 loss=[tensor(68.5770, device='cuda:0', grad_fn=<MulBackward0>), tensor(499.2239, device='cuda:0', grad_fn=<MulBackward0>), tensor(130.4671, device='cuda:0', grad_fn=<MulBackward0>), tensor(57.8530, device='cuda:0', grad_fn=<MulBackward0>), tensor(438.6774, device='cuda:0', grad_fn=<MulBackward0>), tensor(236.1685, device='cuda:0', grad_fn=<MulBackward0>), tensor(45.7233, device='cuda:0', grad_fn=<MulBackward0>), tensor(459.9781, device='cuda:0', grad_fn=<MulBackward0>), tensor(142.1627, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 95 loss=[tensor(56.1050, device='cuda:0', grad_fn=<MulBackward0>), tensor(455.3521, device='cuda:0', grad_fn=<MulBackward0>), tensor(153.4505, device='cuda:0', grad_fn=<MulBackward0>), tensor(49.9775, device='cuda:0', grad_fn=<MulBackward0>), tensor(389.6986, device='cuda:0', grad_fn=<MulBackward0>), tensor(190.4976, device='cuda:0', grad_fn=<MulBackward0>), tensor(39.4640, device='cuda:0', grad_fn=<MulBackward0>), tensor(414.3337, device='cuda:0'

epoch 170 loss=[tensor(14.5850, device='cuda:0', grad_fn=<MulBackward0>), tensor(319.4002, device='cuda:0', grad_fn=<MulBackward0>), tensor(170.1331, device='cuda:0', grad_fn=<MulBackward0>), tensor(13.9270, device='cuda:0', grad_fn=<MulBackward0>), tensor(334.9871, device='cuda:0', grad_fn=<MulBackward0>), tensor(153.8521, device='cuda:0', grad_fn=<MulBackward0>), tensor(11.3128, device='cuda:0', grad_fn=<MulBackward0>), tensor(324.5306, device='cuda:0', grad_fn=<MulBackward0>), tensor(141.0717, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 175 loss=[tensor(15.0217, device='cuda:0', grad_fn=<MulBackward0>), tensor(319.4757, device='cuda:0', grad_fn=<MulBackward0>), tensor(200.0239, device='cuda:0', grad_fn=<MulBackward0>), tensor(13.2757, device='cuda:0', grad_fn=<MulBackward0>), tensor(334.6930, device='cuda:0', grad_fn=<MulBackward0>), tensor(151.9115, device='cuda:0', grad_fn=<MulBackward0>), tensor(10.8017, device='cuda:0', grad_fn=<MulBackward0>), tensor(324.4205, device='cuda:

epoch 45 loss=[tensor(665.2437, device='cuda:0', grad_fn=<MulBackward0>), tensor(661.0502, device='cuda:0', grad_fn=<MulBackward0>), tensor(308.1974, device='cuda:0', grad_fn=<MulBackward0>), tensor(351.2656, device='cuda:0', grad_fn=<MulBackward0>), tensor(653.8471, device='cuda:0', grad_fn=<MulBackward0>), tensor(322.0486, device='cuda:0', grad_fn=<MulBackward0>), tensor(390.0180, device='cuda:0', grad_fn=<MulBackward0>), tensor(656.0318, device='cuda:0', grad_fn=<MulBackward0>), tensor(240.1098, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 45 0 loss=24.615942001342773
epoch 50 loss=[tensor(512.7016, device='cuda:0', grad_fn=<MulBackward0>), tensor(656.4925, device='cuda:0', grad_fn=<MulBackward0>), tensor(329.4971, device='cuda:0', grad_fn=<MulBackward0>), tensor(240.1629, device='cuda:0', grad_fn=<MulBackward0>), tensor(648.9856, device='cuda:0', grad_fn=<MulBackward0>), tensor(327.0161, device='cuda:0', grad_fn=<MulBackward0>), tensor(301.4815, device='cuda:

New best validation epoch 145 0 loss=22.336044311523438
epoch 150 loss=[tensor(14.4984, device='cuda:0', grad_fn=<MulBackward0>), tensor(327.4037, device='cuda:0', grad_fn=<MulBackward0>), tensor(148.8966, device='cuda:0', grad_fn=<MulBackward0>), tensor(12.1171, device='cuda:0', grad_fn=<MulBackward0>), tensor(325.3510, device='cuda:0', grad_fn=<MulBackward0>), tensor(205.5032, device='cuda:0', grad_fn=<MulBackward0>), tensor(9.8945, device='cuda:0', grad_fn=<MulBackward0>), tensor(323.8295, device='cuda:0', grad_fn=<MulBackward0>), tensor(149.9833, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 155 loss=[tensor(13.1488, device='cuda:0', grad_fn=<MulBackward0>), tensor(328.6966, device='cuda:0', grad_fn=<MulBackward0>), tensor(154.2778, device='cuda:0', grad_fn=<MulBackward0>), tensor(12.6719, device='cuda:0', grad_fn=<MulBackward0>), tensor(324.7178, device='cuda:0', grad_fn=<MulBackward0>), tensor(119.5815, device='cuda:0', grad_fn=<MulBackward0>), tensor(10.3454, device='cuda:0', 

epoch 15 loss=[tensor(1351.4723, device='cuda:0', grad_fn=<MulBackward0>), tensor(683.2127, device='cuda:0', grad_fn=<MulBackward0>), tensor(8632.9902, device='cuda:0', grad_fn=<MulBackward0>), tensor(1250.7533, device='cuda:0', grad_fn=<MulBackward0>), tensor(677.9613, device='cuda:0', grad_fn=<MulBackward0>), tensor(4993.8989, device='cuda:0', grad_fn=<MulBackward0>), tensor(1244.9813, device='cuda:0', grad_fn=<MulBackward0>), tensor(677.5326, device='cuda:0', grad_fn=<MulBackward0>), tensor(5320.2959, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 15 0 loss=38.49183654785156
New best validation epoch 15 1 loss=0.6129032373428345
New best validation epoch 15 2 loss=11.438858985900879
epoch 20 loss=[tensor(1326.0773, device='cuda:0', grad_fn=<MulBackward0>), tensor(679.1190, device='cuda:0', grad_fn=<MulBackward0>), tensor(4938.6587, device='cuda:0', grad_fn=<MulBackward0>), tensor(1171.4486, device='cuda:0', grad_fn=<MulBackward0>), tensor(671.7844, device='cuda:

epoch 85 loss=[tensor(285.7677, device='cuda:0', grad_fn=<MulBackward0>), tensor(624.3425, device='cuda:0', grad_fn=<MulBackward0>), tensor(238.4480, device='cuda:0', grad_fn=<MulBackward0>), tensor(84.3329, device='cuda:0', grad_fn=<MulBackward0>), tensor(570.8414, device='cuda:0', grad_fn=<MulBackward0>), tensor(280.2079, device='cuda:0', grad_fn=<MulBackward0>), tensor(110.7842, device='cuda:0', grad_fn=<MulBackward0>), tensor(587.8194, device='cuda:0', grad_fn=<MulBackward0>), tensor(241.1593, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 85 0 loss=23.577411651611328
epoch 90 loss=[tensor(189.9427, device='cuda:0', grad_fn=<MulBackward0>), tensor(610.7554, device='cuda:0', grad_fn=<MulBackward0>), tensor(290.8705, device='cuda:0', grad_fn=<MulBackward0>), tensor(65.7361, device='cuda:0', grad_fn=<MulBackward0>), tensor(541.0490, device='cuda:0', grad_fn=<MulBackward0>), tensor(222.8315, device='cuda:0', grad_fn=<MulBackward0>), tensor(75.7212, device='cuda:0',

epoch 160 loss=[tensor(15.3024, device='cuda:0', grad_fn=<MulBackward0>), tensor(339.6233, device='cuda:0', grad_fn=<MulBackward0>), tensor(146.5788, device='cuda:0', grad_fn=<MulBackward0>), tensor(16.6105, device='cuda:0', grad_fn=<MulBackward0>), tensor(346.5083, device='cuda:0', grad_fn=<MulBackward0>), tensor(145.9081, device='cuda:0', grad_fn=<MulBackward0>), tensor(12.2758, device='cuda:0', grad_fn=<MulBackward0>), tensor(338.5226, device='cuda:0', grad_fn=<MulBackward0>), tensor(125.5492, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 165 loss=[tensor(15.9489, device='cuda:0', grad_fn=<MulBackward0>), tensor(340.2068, device='cuda:0', grad_fn=<MulBackward0>), tensor(150.4522, device='cuda:0', grad_fn=<MulBackward0>), tensor(15.3274, device='cuda:0', grad_fn=<MulBackward0>), tensor(346.2586, device='cuda:0', grad_fn=<MulBackward0>), tensor(145.1817, device='cuda:0', grad_fn=<MulBackward0>), tensor(12.0344, device='cuda:0', grad_fn=<MulBackward0>), tensor(338.6612, device='cuda:

epoch 35 loss=[tensor(774.5756, device='cuda:0', grad_fn=<MulBackward0>), tensor(662.7573, device='cuda:0', grad_fn=<MulBackward0>), tensor(1054.1033, device='cuda:0', grad_fn=<MulBackward0>), tensor(548.1241, device='cuda:0', grad_fn=<MulBackward0>), tensor(664.0789, device='cuda:0', grad_fn=<MulBackward0>), tensor(564.7421, device='cuda:0', grad_fn=<MulBackward0>), tensor(571.8776, device='cuda:0', grad_fn=<MulBackward0>), tensor(661.2207, device='cuda:0', grad_fn=<MulBackward0>), tensor(487.0883, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 35 0 loss=27.208972930908203
epoch 40 loss=[tensor(594.6190, device='cuda:0', grad_fn=<MulBackward0>), tensor(659.5771, device='cuda:0', grad_fn=<MulBackward0>), tensor(367.1229, device='cuda:0', grad_fn=<MulBackward0>), tensor(382.3677, device='cuda:0', grad_fn=<MulBackward0>), tensor(657.8151, device='cuda:0', grad_fn=<MulBackward0>), tensor(382.1979, device='cuda:0', grad_fn=<MulBackward0>), tensor(385.6225, device='cuda

epoch 125 loss=[tensor(18.1010, device='cuda:0', grad_fn=<MulBackward0>), tensor(341.7888, device='cuda:0', grad_fn=<MulBackward0>), tensor(334.2834, device='cuda:0', grad_fn=<MulBackward0>), tensor(14.8899, device='cuda:0', grad_fn=<MulBackward0>), tensor(317.5509, device='cuda:0', grad_fn=<MulBackward0>), tensor(140.0147, device='cuda:0', grad_fn=<MulBackward0>), tensor(12.6158, device='cuda:0', grad_fn=<MulBackward0>), tensor(326.8306, device='cuda:0', grad_fn=<MulBackward0>), tensor(145.0491, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 125 2 loss=11.320365905761719
epoch 130 loss=[tensor(17.8166, device='cuda:0', grad_fn=<MulBackward0>), tensor(337.8684, device='cuda:0', grad_fn=<MulBackward0>), tensor(441.2894, device='cuda:0', grad_fn=<MulBackward0>), tensor(13.2056, device='cuda:0', grad_fn=<MulBackward0>), tensor(317.5150, device='cuda:0', grad_fn=<MulBackward0>), tensor(150.4157, device='cuda:0', grad_fn=<MulBackward0>), tensor(11.5278, device='cuda:0',

New best validation epoch 5 1 loss=0.5322580337524414
epoch 10 loss=[tensor(1272.9548, device='cuda:0', grad_fn=<MulBackward0>), tensor(680.2745, device='cuda:0', grad_fn=<MulBackward0>), tensor(12330.8896, device='cuda:0', grad_fn=<MulBackward0>), tensor(1178.4143, device='cuda:0', grad_fn=<MulBackward0>), tensor(679.2714, device='cuda:0', grad_fn=<MulBackward0>), tensor(8660.9941, device='cuda:0', grad_fn=<MulBackward0>), tensor(1222.4081, device='cuda:0', grad_fn=<MulBackward0>), tensor(679.2499, device='cuda:0', grad_fn=<MulBackward0>), tensor(10085.1299, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 10 1 loss=0.5483870506286621
epoch 15 loss=[tensor(1204.3971, device='cuda:0', grad_fn=<MulBackward0>), tensor(678.6893, device='cuda:0', grad_fn=<MulBackward0>), tensor(9254.0273, device='cuda:0', grad_fn=<MulBackward0>), tensor(1101.8768, device='cuda:0', grad_fn=<MulBackward0>), tensor(671.8207, device='cuda:0', grad_fn=<MulBackward0>), tensor(5361.2104, device

epoch 95 loss=[tensor(20.8826, device='cuda:0', grad_fn=<MulBackward0>), tensor(478.3948, device='cuda:0', grad_fn=<MulBackward0>), tensor(121.4329, device='cuda:0', grad_fn=<MulBackward0>), tensor(35.5984, device='cuda:0', grad_fn=<MulBackward0>), tensor(458.1515, device='cuda:0', grad_fn=<MulBackward0>), tensor(143.5188, device='cuda:0', grad_fn=<MulBackward0>), tensor(19.9690, device='cuda:0', grad_fn=<MulBackward0>), tensor(459.6592, device='cuda:0', grad_fn=<MulBackward0>), tensor(110.1654, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 100 loss=[tensor(20.9023, device='cuda:0', grad_fn=<MulBackward0>), tensor(430.9246, device='cuda:0', grad_fn=<MulBackward0>), tensor(149.8590, device='cuda:0', grad_fn=<MulBackward0>), tensor(27.7427, device='cuda:0', grad_fn=<MulBackward0>), tensor(420.5434, device='cuda:0', grad_fn=<MulBackward0>), tensor(141.1926, device='cuda:0', grad_fn=<MulBackward0>), tensor(16.2273, device='cuda:0', grad_fn=<MulBackward0>), tensor(417.3853, device='cuda:0

epoch 185 loss=[tensor(15.6664, device='cuda:0', grad_fn=<MulBackward0>), tensor(333.8156, device='cuda:0', grad_fn=<MulBackward0>), tensor(105.4496, device='cuda:0', grad_fn=<MulBackward0>), tensor(10.4132, device='cuda:0', grad_fn=<MulBackward0>), tensor(319.9431, device='cuda:0', grad_fn=<MulBackward0>), tensor(148.4850, device='cuda:0', grad_fn=<MulBackward0>), tensor(10.0496, device='cuda:0', grad_fn=<MulBackward0>), tensor(324.2123, device='cuda:0', grad_fn=<MulBackward0>), tensor(108.5724, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 190 loss=[tensor(14.3295, device='cuda:0', grad_fn=<MulBackward0>), tensor(333.7338, device='cuda:0', grad_fn=<MulBackward0>), tensor(139.6151, device='cuda:0', grad_fn=<MulBackward0>), tensor(12.8149, device='cuda:0', grad_fn=<MulBackward0>), tensor(319.9296, device='cuda:0', grad_fn=<MulBackward0>), tensor(134.5793, device='cuda:0', grad_fn=<MulBackward0>), tensor(10.5532, device='cuda:0', grad_fn=<MulBackward0>), tensor(324.1597, device='cuda:

epoch 55 loss=[tensor(175.2136, device='cuda:0', grad_fn=<MulBackward0>), tensor(666.6788, device='cuda:0', grad_fn=<MulBackward0>), tensor(280.2811, device='cuda:0', grad_fn=<MulBackward0>), tensor(167.5721, device='cuda:0', grad_fn=<MulBackward0>), tensor(648.7516, device='cuda:0', grad_fn=<MulBackward0>), tensor(229.5706, device='cuda:0', grad_fn=<MulBackward0>), tensor(150.7833, device='cuda:0', grad_fn=<MulBackward0>), tensor(655.7319, device='cuda:0', grad_fn=<MulBackward0>), tensor(206.8485, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 60 loss=[tensor(127.4232, device='cuda:0', grad_fn=<MulBackward0>), tensor(660.9889, device='cuda:0', grad_fn=<MulBackward0>), tensor(228.8494, device='cuda:0', grad_fn=<MulBackward0>), tensor(127.5320, device='cuda:0', grad_fn=<MulBackward0>), tensor(639.4371, device='cuda:0', grad_fn=<MulBackward0>), tensor(184.7549, device='cuda:0', grad_fn=<MulBackward0>), tensor(106.4457, device='cuda:0', grad_fn=<MulBackward0>), tensor(647.3659, device='c

epoch 140 loss=[tensor(11.6712, device='cuda:0', grad_fn=<MulBackward0>), tensor(331.8296, device='cuda:0', grad_fn=<MulBackward0>), tensor(124.3185, device='cuda:0', grad_fn=<MulBackward0>), tensor(12.2673, device='cuda:0', grad_fn=<MulBackward0>), tensor(316.8590, device='cuda:0', grad_fn=<MulBackward0>), tensor(129.0490, device='cuda:0', grad_fn=<MulBackward0>), tensor(9.3643, device='cuda:0', grad_fn=<MulBackward0>), tensor(322.4659, device='cuda:0', grad_fn=<MulBackward0>), tensor(104.3844, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 145 loss=[tensor(14.2227, device='cuda:0', grad_fn=<MulBackward0>), tensor(331.4860, device='cuda:0', grad_fn=<MulBackward0>), tensor(162.9690, device='cuda:0', grad_fn=<MulBackward0>), tensor(16.4584, device='cuda:0', grad_fn=<MulBackward0>), tensor(316.0858, device='cuda:0', grad_fn=<MulBackward0>), tensor(206.6347, device='cuda:0', grad_fn=<MulBackward0>), tensor(12.5964, device='cuda:0', grad_fn=<MulBackward0>), tensor(321.9026, device='cuda:0

epoch 15 loss=[tensor(943.1495, device='cuda:0', grad_fn=<MulBackward0>), tensor(665.8813, device='cuda:0', grad_fn=<MulBackward0>), tensor(5646.8970, device='cuda:0', grad_fn=<MulBackward0>), tensor(1084.5143, device='cuda:0', grad_fn=<MulBackward0>), tensor(670.1281, device='cuda:0', grad_fn=<MulBackward0>), tensor(8415.6094, device='cuda:0', grad_fn=<MulBackward0>), tensor(997.7222, device='cuda:0', grad_fn=<MulBackward0>), tensor(666.9930, device='cuda:0', grad_fn=<MulBackward0>), tensor(6146.5054, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 15 0 loss=36.473140716552734
New best validation epoch 15 1 loss=0.5806451439857483
New best validation epoch 15 2 loss=12.477664947509766
epoch 20 loss=[tensor(802.6125, device='cuda:0', grad_fn=<MulBackward0>), tensor(663.6039, device='cuda:0', grad_fn=<MulBackward0>), tensor(2817.4377, device='cuda:0', grad_fn=<MulBackward0>), tensor(950.5367, device='cuda:0', grad_fn=<MulBackward0>), tensor(670.0418, device='cuda:0',

epoch 100 loss=[tensor(35.1981, device='cuda:0', grad_fn=<MulBackward0>), tensor(516.2342, device='cuda:0', grad_fn=<MulBackward0>), tensor(123.9644, device='cuda:0', grad_fn=<MulBackward0>), tensor(21.2813, device='cuda:0', grad_fn=<MulBackward0>), tensor(356.1556, device='cuda:0', grad_fn=<MulBackward0>), tensor(126.5497, device='cuda:0', grad_fn=<MulBackward0>), tensor(18.8750, device='cuda:0', grad_fn=<MulBackward0>), tensor(422.3814, device='cuda:0', grad_fn=<MulBackward0>), tensor(107.6611, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 100 1 loss=0.6935483813285828
epoch 105 loss=[tensor(35.3870, device='cuda:0', grad_fn=<MulBackward0>), tensor(464.5385, device='cuda:0', grad_fn=<MulBackward0>), tensor(190.5192, device='cuda:0', grad_fn=<MulBackward0>), tensor(21.4915, device='cuda:0', grad_fn=<MulBackward0>), tensor(338.0663, device='cuda:0', grad_fn=<MulBackward0>), tensor(288.9641, device='cuda:0', grad_fn=<MulBackward0>), tensor(19.4302, device='cuda:0',

epoch 175 loss=[tensor(15.0802, device='cuda:0', grad_fn=<MulBackward0>), tensor(321.4807, device='cuda:0', grad_fn=<MulBackward0>), tensor(185.4872, device='cuda:0', grad_fn=<MulBackward0>), tensor(13.6370, device='cuda:0', grad_fn=<MulBackward0>), tensor(318.3303, device='cuda:0', grad_fn=<MulBackward0>), tensor(172.8368, device='cuda:0', grad_fn=<MulBackward0>), tensor(9.5636, device='cuda:0', grad_fn=<MulBackward0>), tensor(318.3979, device='cuda:0', grad_fn=<MulBackward0>), tensor(158.2493, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 180 loss=[tensor(13.7932, device='cuda:0', grad_fn=<MulBackward0>), tensor(321.4913, device='cuda:0', grad_fn=<MulBackward0>), tensor(134.4905, device='cuda:0', grad_fn=<MulBackward0>), tensor(12.7978, device='cuda:0', grad_fn=<MulBackward0>), tensor(317.4648, device='cuda:0', grad_fn=<MulBackward0>), tensor(129.7250, device='cuda:0', grad_fn=<MulBackward0>), tensor(9.5643, device='cuda:0', grad_fn=<MulBackward0>), tensor(318.0255, device='cuda:0'

epoch 35 loss=[tensor(593.5204, device='cuda:0', grad_fn=<MulBackward0>), tensor(668.3511, device='cuda:0', grad_fn=<MulBackward0>), tensor(631.2611, device='cuda:0', grad_fn=<MulBackward0>), tensor(380.0844, device='cuda:0', grad_fn=<MulBackward0>), tensor(658.6888, device='cuda:0', grad_fn=<MulBackward0>), tensor(649.4153, device='cuda:0', grad_fn=<MulBackward0>), tensor(407.7576, device='cuda:0', grad_fn=<MulBackward0>), tensor(661.6776, device='cuda:0', grad_fn=<MulBackward0>), tensor(469.5570, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 35 0 loss=23.173587799072266
epoch 40 loss=[tensor(479.1500, device='cuda:0', grad_fn=<MulBackward0>), tensor(664.3994, device='cuda:0', grad_fn=<MulBackward0>), tensor(705.8476, device='cuda:0', grad_fn=<MulBackward0>), tensor(282.3532, device='cuda:0', grad_fn=<MulBackward0>), tensor(656.0772, device='cuda:0', grad_fn=<MulBackward0>), tensor(360.7833, device='cuda:0', grad_fn=<MulBackward0>), tensor(320.2164, device='cuda:

epoch 135 loss=[tensor(15.1609, device='cuda:0', grad_fn=<MulBackward0>), tensor(335.2334, device='cuda:0', grad_fn=<MulBackward0>), tensor(123.0781, device='cuda:0', grad_fn=<MulBackward0>), tensor(14.2531, device='cuda:0', grad_fn=<MulBackward0>), tensor(325.1230, device='cuda:0', grad_fn=<MulBackward0>), tensor(150.9857, device='cuda:0', grad_fn=<MulBackward0>), tensor(11.6527, device='cuda:0', grad_fn=<MulBackward0>), tensor(326.8553, device='cuda:0', grad_fn=<MulBackward0>), tensor(120.5876, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 140 loss=[tensor(15.8432, device='cuda:0', grad_fn=<MulBackward0>), tensor(332.7984, device='cuda:0', grad_fn=<MulBackward0>), tensor(149.8267, device='cuda:0', grad_fn=<MulBackward0>), tensor(13.5365, device='cuda:0', grad_fn=<MulBackward0>), tensor(323.9374, device='cuda:0', grad_fn=<MulBackward0>), tensor(140.1153, device='cuda:0', grad_fn=<MulBackward0>), tensor(11.2480, device='cuda:0', grad_fn=<MulBackward0>), tensor(325.1768, device='cuda:

New best validation epoch 0 1 loss=0.5483870506286621
New best validation epoch 0 2 loss=11.968938827514648
epoch 5 loss=[tensor(1511.7396, device='cuda:0', grad_fn=<MulBackward0>), tensor(688.2916, device='cuda:0', grad_fn=<MulBackward0>), tensor(12449.3223, device='cuda:0', grad_fn=<MulBackward0>), tensor(1470.4745, device='cuda:0', grad_fn=<MulBackward0>), tensor(688.5327, device='cuda:0', grad_fn=<MulBackward0>), tensor(12800.2197, device='cuda:0', grad_fn=<MulBackward0>), tensor(1486.3617, device='cuda:0', grad_fn=<MulBackward0>), tensor(688.3569, device='cuda:0', grad_fn=<MulBackward0>), tensor(12613.7188, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 5 0 loss=38.590667724609375
New best validation epoch 5 1 loss=0.5645161271095276
epoch 10 loss=[tensor(1389.2368, device='cuda:0', grad_fn=<MulBackward0>), tensor(683.3027, device='cuda:0', grad_fn=<MulBackward0>), tensor(7867.7368, device='cuda:0', grad_fn=<MulBackward0>), tensor(1287.1338, device='cuda:0', g

epoch 75 loss=[tensor(57.8797, device='cuda:0', grad_fn=<MulBackward0>), tensor(640.7081, device='cuda:0', grad_fn=<MulBackward0>), tensor(211.0664, device='cuda:0', grad_fn=<MulBackward0>), tensor(48.6555, device='cuda:0', grad_fn=<MulBackward0>), tensor(619.4106, device='cuda:0', grad_fn=<MulBackward0>), tensor(167.2256, device='cuda:0', grad_fn=<MulBackward0>), tensor(43.0111, device='cuda:0', grad_fn=<MulBackward0>), tensor(626.4453, device='cuda:0', grad_fn=<MulBackward0>), tensor(151.8203, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 80 loss=[tensor(44.8359, device='cuda:0', grad_fn=<MulBackward0>), tensor(630.6481, device='cuda:0', grad_fn=<MulBackward0>), tensor(197.2856, device='cuda:0', grad_fn=<MulBackward0>), tensor(38.1175, device='cuda:0', grad_fn=<MulBackward0>), tensor(596.0482, device='cuda:0', grad_fn=<MulBackward0>), tensor(137.6207, device='cuda:0', grad_fn=<MulBackward0>), tensor(33.5988, device='cuda:0', grad_fn=<MulBackward0>), tensor(608.2642, device='cuda:0'

New best validation epoch 165 0 loss=20.981998443603516
epoch 170 loss=[tensor(12.5365, device='cuda:0', grad_fn=<MulBackward0>), tensor(333.6563, device='cuda:0', grad_fn=<MulBackward0>), tensor(132.7324, device='cuda:0', grad_fn=<MulBackward0>), tensor(13.0014, device='cuda:0', grad_fn=<MulBackward0>), tensor(335.7894, device='cuda:0', grad_fn=<MulBackward0>), tensor(204.7424, device='cuda:0', grad_fn=<MulBackward0>), tensor(10.7119, device='cuda:0', grad_fn=<MulBackward0>), tensor(330.6035, device='cuda:0', grad_fn=<MulBackward0>), tensor(149.4544, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 175 loss=[tensor(15.0885, device='cuda:0', grad_fn=<MulBackward0>), tensor(333.6821, device='cuda:0', grad_fn=<MulBackward0>), tensor(250.0696, device='cuda:0', grad_fn=<MulBackward0>), tensor(10.7676, device='cuda:0', grad_fn=<MulBackward0>), tensor(335.6351, device='cuda:0', grad_fn=<MulBackward0>), tensor(129.7762, device='cuda:0', grad_fn=<MulBackward0>), tensor(10.4269, device='cuda:0',

epoch 40 loss=[tensor(683.6617, device='cuda:0', grad_fn=<MulBackward0>), tensor(657.8170, device='cuda:0', grad_fn=<MulBackward0>), tensor(556.7497, device='cuda:0', grad_fn=<MulBackward0>), tensor(384.3727, device='cuda:0', grad_fn=<MulBackward0>), tensor(655.6841, device='cuda:0', grad_fn=<MulBackward0>), tensor(518.9015, device='cuda:0', grad_fn=<MulBackward0>), tensor(372.0129, device='cuda:0', grad_fn=<MulBackward0>), tensor(651.7503, device='cuda:0', grad_fn=<MulBackward0>), tensor(431.7140, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 40 0 loss=25.099721908569336
New best validation epoch 40 2 loss=11.74926471710205
epoch 45 loss=[tensor(528.9890, device='cuda:0', grad_fn=<MulBackward0>), tensor(654.5575, device='cuda:0', grad_fn=<MulBackward0>), tensor(287.9587, device='cuda:0', grad_fn=<MulBackward0>), tensor(253.4801, device='cuda:0', grad_fn=<MulBackward0>), tensor(647.4658, device='cuda:0', grad_fn=<MulBackward0>), tensor(294.6591, device='cuda:0', g

epoch 140 loss=[tensor(16.0201, device='cuda:0', grad_fn=<MulBackward0>), tensor(338.8936, device='cuda:0', grad_fn=<MulBackward0>), tensor(339.1200, device='cuda:0', grad_fn=<MulBackward0>), tensor(14.8135, device='cuda:0', grad_fn=<MulBackward0>), tensor(339.8829, device='cuda:0', grad_fn=<MulBackward0>), tensor(163.7143, device='cuda:0', grad_fn=<MulBackward0>), tensor(12.4802, device='cuda:0', grad_fn=<MulBackward0>), tensor(334.6819, device='cuda:0', grad_fn=<MulBackward0>), tensor(149.9017, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 145 loss=[tensor(15.3918, device='cuda:0', grad_fn=<MulBackward0>), tensor(337.8127, device='cuda:0', grad_fn=<MulBackward0>), tensor(158.9030, device='cuda:0', grad_fn=<MulBackward0>), tensor(26.2576, device='cuda:0', grad_fn=<MulBackward0>), tensor(337.9777, device='cuda:0', grad_fn=<MulBackward0>), tensor(165.4268, device='cuda:0', grad_fn=<MulBackward0>), tensor(16.9893, device='cuda:0', grad_fn=<MulBackward0>), tensor(333.1351, device='cuda:

In [249]:
# Mean test metrics over the folds collected in `rmse`: age RMSE (months -> years
# via /12), sex accuracy, and WRAT RMSE. The divisor is the actual number of folds
# recorded rather than a hard-coded 10.
nFolds = len(rmse)
print(sum(f[0] for f in rmse) / nFolds / 12)
print(sum(f[1] for f in rmse) / nFolds)
print(sum(f[2] for f in rmse) / nFolds)

2.146106179555257
0.7258064270019531
14.51638765335083


In [167]:
# Baseline: least squares from the raw edge features of both paradigms straight
# to the target (age), fitted on each fold's training subjects.
# NOTE(review): X is the stacked feature tensor left behind by the training cell,
# so that cell must have run first.

mod = age_t
modErr = []

nSubj = X.shape[0]
# The design matrix is identical for every fold, so build it once instead of per fold.
# permute(0, 2, 1) interleaves the two paradigms edge by edge. The 2000 constant
# columns act as a heavily duplicated intercept, presumably to shrink the min-norm
# penalty on the bias -- TODO confirm intent.
# NOTE(review): duplicated columns make the matrix rank-deficient, which torch's
# CUDA lstsq driver ('gels') assumes away; worth cross-checking on CPU with 'gelsd'.
as3 = torch.cat([X.permute(0,2,1).reshape(nSubj,-1), torch.ones(nSubj,2000).float().cuda()], dim=1)

for grp in range(10):
    trainIdcs = groups[grp][0][0:496]
    testIdcs = groups[grp][1]
    w, res, _, _ = torch.linalg.lstsq(as3[trainIdcs], mod[trainIdcs])
    modErr.append(float(mseLoss(as3[testIdcs]@w, mod[testIdcs])**0.5))
    print(modErr[-1])

print(sum(modErr)/len(modErr))

26.360719680786133
30.614137649536133
25.495744705200195
23.039260864257812
24.457931518554688
24.21812629699707
25.936012268066406
27.349708557128906
24.71824073791504
24.69891357421875
25.688879585266115


In [44]:
# Predict test-fold age from each fold's saved LatSim latents (Ass[grp], taken from
# the task-0 checkpoint) using least squares fitted on that fold's training subjects.
# NOTE(review): requires Ass to have been filled by the training cell for all 10 folds.

mod = age_t
modErr = []

for grp in range(10):
    # Both recombined latents side by side, plus a constant intercept column.
    as3 = torch.cat(Ass[grp] + [torch.ones(Ass[grp][0].shape[0],1).float().cuda()], dim=1)
    print(as3.shape)
    # Small jitter, presumably to keep lstsq well-conditioned when latent columns are
    # all zero. NOTE(review): unseeded, so the printed errors vary run to run.
    as3 = as3 + 0.01*torch.randn(as3.shape[0],as3.shape[1]).float().cuda()
    trainIdcs = groups[grp][0][0:496]
    testIdcs = groups[grp][1]
    w, res, _, _ = torch.linalg.lstsq(as3[trainIdcs], mod[trainIdcs])
    modErr.append(float(mseLoss(as3[testIdcs]@w, mod[testIdcs])**0.5))
    print(modErr[-1])

print(sum(modErr)/len(modErr))

torch.Size([620, 201])
28.8840274810791
torch.Size([620, 201])
33.90989303588867
torch.Size([620, 201])
34.3302001953125
torch.Size([620, 201])
27.656007766723633
torch.Size([620, 201])
29.042627334594727
torch.Size([620, 201])
28.122554779052734
torch.Size([620, 201])
26.24472427368164
torch.Size([620, 201])
32.64718246459961
torch.Size([620, 201])
30.897991180419922
torch.Size([620, 201])
29.122114181518555
30.08573226928711


In [49]:
from sklearn.linear_model import LogisticRegression

# Classification from the learned MLP latents: per fold, fit a logistic regression on the
# fold's first 496 training indices and report test accuracy, then the pooled accuracy.
# NOTE: despite its name, `genErr` holds the number of CORRECT test predictions per fold.
# NOTE(review): the extra all-ones column duplicates LogisticRegression's own intercept.

genErr = []
nTested = 0

for grp in range(10):
    lat0 = Ass[grp][0]
    as3 = torch.cat(Ass[grp] + [lat0.new_ones((lat0.shape[0], 1))], dim=1)
    as3 = as3 + 0.01*torch.randn_like(as3)   # jitter sized from as3, not a hard-coded 620
    as3np = as3.detach().cpu().numpy()
    trainIdcs = groups[grp][0][0:496]
    testIdcs = groups[grp][1]

    # Labels are argmax along axis 1 of `gen` (presumably one-hot — confirm).
    clf = LogisticRegression(max_iter=5000).fit(as3np[trainIdcs], np.argmax(gen[trainIdcs], axis=1))
    genErr.append(np.sum(clf.predict(as3np[testIdcs]) == np.argmax(gen[testIdcs], axis=1)))
    nTested += len(testIdcs)

    # Per-fold accuracy over the actual fold size (was hard-coded to 62).
    print(genErr[-1]/len(testIdcs))

# Pooled accuracy over every tested subject (was hard-coded to 620).
print(sum(genErr)/nTested)

0.6612903225806451
0.532258064516129
0.5
0.5645161290322581
0.5967741935483871
0.45161290322580644
0.5645161290322581
0.532258064516129
0.5483870967741935
0.532258064516129
0.5483870967741935


In [46]:
import torch.nn as nn
import torch.nn.functional as F
import time

# Shared loss modules used by the training loop and validate().
ceLoss, mseLoss = nn.CrossEntropyLoss(), nn.MSELoss()

# Per-fold result accumulators, reset every time this cell runs:
# rmse collects each fold's final test metrics, Ass each fold's latents from getAs.
rmse, Ass = [], []

def allBelowThresh(losses, thresh):
    """Return True unless some loss is strictly greater than its paired threshold.

    `losses` and `thresh` are paired positionally; surplus entries in the longer one are
    ignored. A loss equal to its threshold passes. Because the test is `>`, a NaN loss
    never counts as exceeding its threshold.
    """
    return not any(loss > limit for loss, limit in zip(losses, thresh))

def flatten(res):
    """Concatenate a sequence of sequences into one flat list, preserving order."""
    flat = []
    for group in res:
        flat.extend(group)
    return flat

def getAvg(res):
    """Average per-task outputs across paradigms.

    `res[para][task]` is the output of paradigm `para` for task `task`. Returns one entry
    per task (the task count is read from `res[0]`), each being the sum over paradigms of
    `res[para][task] / len(res)`, accumulated in paradigm order starting from 0.
    """
    nPara = len(res)
    return [sum(outputs[task] / nPara for outputs in res) for task in range(len(res[0]))]

class MLP(nn.Module):
    """Per-paradigm encoders with per-(paradigm, task) linear heads.

    Paradigm p has an encoder A[p] = Linear(nFeatures -> dimA) followed by ReLU, and one
    linear head per task, B[p*nTasks + t] = Linear(dimA -> dimB[t]).

    Args:
        inp: example input used only for its shape; assumes (nSubjects, nParadigms, nFeatures),
            as built by stacking the paradigms along dim 1.
        dp: dropout probability applied to the input in forward().
        dimA: latent width of each paradigm encoder.
        dimB: output width of each task head (1 for a scalar regression target).
        device: device the layers are moved to (default 'cuda', as before).
    """
    def __init__(self, inp, dp=0.1, dimA=10, dimB=[1,2,1], device='cuda'):
        super().__init__()
        nPara = inp.shape[1]
        self.A = nn.ModuleList([nn.Linear(inp.shape[2], dimA).float().to(device)
                                for _ in range(nPara)])
        # Heads are laid out paradigm-major (all of paradigm 0's tasks, then paradigm 1's, ...).
        # For two paradigms this is the same order as the former `2*dimB`, so state_dict keys match.
        self.B = nn.ModuleList([nn.Linear(dimA, b).float().to(device)
                                for _ in range(nPara) for b in dimB])
        self.dp = nn.Dropout(p=dp)

    def getLatent(self, x, para):
        """ReLU latent of paradigm `para`: x[:, para] -> A[para] -> ReLU, shape (nSubjects, dimA)."""
        # No squeeze here: the old `.squeeze()` also dropped the subject dim for a single subject
        # and the latent dim when dimA == 1 (which then broke the heads).
        return F.relu(self.A[para](x[:, para]))

    def forward(self, x):
        """Return res[para][task], the output of task head `task` on paradigm `para`'s latent.

        Width-1 heads yield shape (nSubjects,); wider heads yield (nSubjects, width).
        """
        x = self.dp(x)
        nPara = len(self.A)
        nTasks = len(self.B) // nPara
        res = []
        for para in range(nPara):
            A = self.getLatent(x, para)
            # squeeze(-1) only removes a width-1 output dim, never the subject dim.
            res.append([self.B[para*nTasks + task](A).squeeze(-1) for task in range(nTasks)])
        return res

def validate(model, X, ys, testIdcs):
    """Score the paradigm-averaged prediction of each task on `testIdcs`.

    Args:
        model: network whose forward returns res[para][task]; outputs are averaged over
            paradigms with getAvg before scoring.
        X: model input for all subjects.
        ys: one target per task. 1-D targets are scored by RMSE; other targets by accuracy of
            argmax along dim 1 (presumably one-hot labels — confirm).
        testIdcs: indices into the first dim of X / ys that are scored.

    Returns:
        One value per task: RMSE (a NumPy value, lower is better) for 1-D targets, accuracy
        (a 0-dim tensor, higher is better) otherwise.

    Runs in eval mode without gradients and restores the model's previous train/eval mode
    afterwards, even if evaluation raises.
    """
    wasTraining = model.training
    model.eval()
    losses = []
    try:
        with torch.no_grad():
            res = model(X)
            for r,y in zip(getAvg(res), ys):
                if y.dim() == 1:
                    loss = mseLoss(r[testIdcs], y[testIdcs]).cpu().numpy()**0.5
                    losses.append(loss)
                else:
                    corr = (torch.argmax(r, dim=1) == torch.argmax(y, dim=1))[testIdcs]
                    loss = torch.sum(corr)/len(testIdcs)
                    losses.append(loss)
    finally:
        # Restore the caller's mode instead of unconditionally switching to train mode.
        model.train(wasTraining)
    return losses

def getAs(model, X):
    """Return the per-paradigm latents of `model` on `X`, one entry per encoder in `model.A`.

    Latents come from model.getLatent, computed in eval mode without gradients. The model's
    previous train/eval mode is restored afterwards, even if evaluation raises.
    """
    wasTraining = model.training
    model.eval()
    try:
        with torch.no_grad():
            # One latent per encoder rather than a hard-coded two paradigms.
            return [model.getLatent(X, para) for para in range(len(model.A))]
    finally:
        model.train(wasTraining)

# Training schedule.
nEpochs = 5_000     # maximum number of full-batch optimizer steps per fold
pPeriod = 500       # log and validate every pPeriod epochs
# Early-stopping thresholds, compared with the first nTasks weighted training losses
# (zip truncates, so surplus entries are unused).
thresh = [30, 0.01, 10]
# Per-task loss weights, looked up as regParam[i % nTasks].
regParam = [1, 50]  # earlier (presumably three-task) setting: [1, 1e3, 50]

# Cross-validated training: per fold, train a fresh MLP on both paradigms, checkpoint the
# best-validation model for each task, then score each task's checkpoint on the fold's test set.

# Loop-invariant inputs, built once instead of once per fold.
X0 = nback_p_t
X1 = emoid_p_t
X = torch.stack([X0, X1], dim=1)   # paradigms stacked along dim 1
yy = [age_t, wrt_t]                 # one target per task
nTasks = len(yy)
# Weight of an L1 penalty on the encoder weights. The previous loop computed
# 1000*sum|A.weight| every epoch but never added it to the loss, so 0 (disabled) keeps the
# same training while skipping that wasted pass; set > 0 to actually apply the penalty.
l1Param = 0

for grp in range(10):
    trainIdcs = groups[grp][0][0:496]
    trainValidIdcs = groups[grp][0]
    # Validation subjects are positions 496.. of the fold's train/valid split (rows of Xtv).
    validIdcs = np.arange(496,len(trainValidIdcs))
    testIdcs = groups[grp][1]

    Xt = X[trainIdcs]
    Xtv = X[trainValidIdcs]
    yt = [y[trainIdcs] for y in yy]
    ytv = [y[trainValidIdcs] for y in yy]

    mlp = MLP(X, dp=0.2, dimA=100, dimB=[1,1])
    optim = torch.optim.Adam(mlp.parameters(), lr=5e-4, weight_decay=5e-4)

    # Best validation score per task so far (appended only when it improves).
    validLoss = [[] for _ in range(nTasks)]

    for epoch in range(nEpochs):
        optim.zero_grad()
        res = mlp(Xt)
        loss = []
        # Outputs: every (paradigm, task) head, then the paradigm-averaged output per task;
        # the targets are repeated once per paradigm plus once for the average.
        for i,(r,y) in enumerate(zip(flatten(res)+getAvg(res), (X.shape[1]+1)*yt)):
            if y.dim() > 1:
                loss.append(regParam[i%nTasks]*ceLoss(r, y))
            else:
                loss.append(regParam[i%nTasks]*mseLoss(r, y))
        total = sum(loss)
        if l1Param:
            total = total + l1Param*sum(torch.sum(torch.abs(enc.weight)) for enc in mlp.A)
        total.backward()
        optim.step()
        # Early stopping looks only at paradigm 0's weighted training losses (loss[0:nTasks]).
        converged = allBelowThresh(loss[0:nTasks], thresh)
        if epoch % pPeriod == 0 or epoch == nEpochs-1 or converged:
            print(f'epoch {epoch} loss={loss}')
            losses = validate(mlp, Xtv, ytv, validIdcs)
            for i,lss in enumerate(losses):
                # Lower is better for 1-D (regression) targets, higher for classification.
                if (len(validLoss[i]) == 0 or 
                        (yy[i].dim() == 1 and lss < min(validLoss[i])) or 
                        (yy[i].dim() > 1 and lss > max(validLoss[i]))):
                    print(f'New best validation epoch {epoch} {i} loss={lss}')
                    # The first validation of every fold always saves, so checkpoints never
                    # leak across folds even though the paths are shared.
                    torch.save(mlp.state_dict(), f'../../Work/LatentSim/mlp{i}.pyt')
                    validLoss[i].append(float(lss))
            if converged:
                print('Early stopping')
                break

    # Score each task with the checkpoint that was best on validation for that task.
    finalLoss = []

    for i in range(nTasks):
        mlp.load_state_dict(torch.load(f'../../Work/LatentSim/mlp{i}.pyt'))
        loss = validate(mlp, X, yy, testIdcs)
        if i == 0:
            # Keep the latents of the task-0 (age) checkpoint for downstream analysis.
            Ass.append(getAs(mlp, X))
        finalLoss.append(float(loss[i]))

    rmse.append(finalLoss)

    print(f'FINISHED {rmse}')

epoch 0 loss=[tensor(34422.2031, device='cuda:0', grad_fn=<MulBackward0>), tensor(544889.6250, device='cuda:0', grad_fn=<MulBackward0>), tensor(34405.6211, device='cuda:0', grad_fn=<MulBackward0>), tensor(545956.6875, device='cuda:0', grad_fn=<MulBackward0>), tensor(34413.9102, device='cuda:0', grad_fn=<MulBackward0>), tensor(545422.9375, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 0 0 loss=167.27210998535156
New best validation epoch 0 1 loss=90.43282318115234
epoch 500 loss=[tensor(1436.8684, device='cuda:0', grad_fn=<MulBackward0>), tensor(7555.4214, device='cuda:0', grad_fn=<MulBackward0>), tensor(1492.4949, device='cuda:0', grad_fn=<MulBackward0>), tensor(7740.5166, device='cuda:0', grad_fn=<MulBackward0>), tensor(1410.9790, device='cuda:0', grad_fn=<MulBackward0>), tensor(6872.2358, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 500 0 loss=41.21884536743164
New best validation epoch 500 1 loss=12.83427619934082
epoch 1000 loss=[tensor(

epoch 4000 loss=[tensor(49.3135, device='cuda:0', grad_fn=<MulBackward0>), tensor(394.3701, device='cuda:0', grad_fn=<MulBackward0>), tensor(33.9465, device='cuda:0', grad_fn=<MulBackward0>), tensor(375.2084, device='cuda:0', grad_fn=<MulBackward0>), tensor(19.1354, device='cuda:0', grad_fn=<MulBackward0>), tensor(196.3895, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 4000 0 loss=27.44513511657715
epoch 4500 loss=[tensor(33.6170, device='cuda:0', grad_fn=<MulBackward0>), tensor(426.4153, device='cuda:0', grad_fn=<MulBackward0>), tensor(30.0117, device='cuda:0', grad_fn=<MulBackward0>), tensor(350.7631, device='cuda:0', grad_fn=<MulBackward0>), tensor(16.2084, device='cuda:0', grad_fn=<MulBackward0>), tensor(203.5190, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 4500 0 loss=27.413555145263672
epoch 4999 loss=[tensor(40.0446, device='cuda:0', grad_fn=<MulBackward0>), tensor(367.3333, device='cuda:0', grad_fn=<MulBackward0>), tensor(26.4257, d

epoch 2500 loss=[tensor(332.5319, device='cuda:0', grad_fn=<MulBackward0>), tensor(423.3511, device='cuda:0', grad_fn=<MulBackward0>), tensor(134.7227, device='cuda:0', grad_fn=<MulBackward0>), tensor(412.7199, device='cuda:0', grad_fn=<MulBackward0>), tensor(155.0220, device='cuda:0', grad_fn=<MulBackward0>), tensor(200.2215, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 2500 0 loss=28.99574089050293
epoch 3000 loss=[tensor(220.4681, device='cuda:0', grad_fn=<MulBackward0>), tensor(412.0721, device='cuda:0', grad_fn=<MulBackward0>), tensor(65.2248, device='cuda:0', grad_fn=<MulBackward0>), tensor(353.0078, device='cuda:0', grad_fn=<MulBackward0>), tensor(78.2517, device='cuda:0', grad_fn=<MulBackward0>), tensor(182.9882, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 3000 0 loss=28.414283752441406
epoch 3500 loss=[tensor(136.0849, device='cuda:0', grad_fn=<MulBackward0>), tensor(411.5482, device='cuda:0', grad_fn=<MulBackward0>), tensor(43.46

epoch 500 loss=[tensor(1622.8766, device='cuda:0', grad_fn=<MulBackward0>), tensor(7243.0967, device='cuda:0', grad_fn=<MulBackward0>), tensor(1472.9882, device='cuda:0', grad_fn=<MulBackward0>), tensor(6813.2607, device='cuda:0', grad_fn=<MulBackward0>), tensor(1485.3036, device='cuda:0', grad_fn=<MulBackward0>), tensor(6293.4419, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 500 0 loss=40.5368537902832
New best validation epoch 500 1 loss=13.371190071105957
epoch 1000 loss=[tensor(1006.9297, device='cuda:0', grad_fn=<MulBackward0>), tensor(1497.4073, device='cuda:0', grad_fn=<MulBackward0>), tensor(953.6862, device='cuda:0', grad_fn=<MulBackward0>), tensor(1220.6130, device='cuda:0', grad_fn=<MulBackward0>), tensor(898.8024, device='cuda:0', grad_fn=<MulBackward0>), tensor(845.0059, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 1000 0 loss=32.21788024902344
New best validation epoch 1000 1 loss=13.181123733520508
epoch 1500 loss=[tensor(610

epoch 4500 loss=[tensor(56.5196, device='cuda:0', grad_fn=<MulBackward0>), tensor(377.1762, device='cuda:0', grad_fn=<MulBackward0>), tensor(37.1112, device='cuda:0', grad_fn=<MulBackward0>), tensor(399.8521, device='cuda:0', grad_fn=<MulBackward0>), tensor(19.2382, device='cuda:0', grad_fn=<MulBackward0>), tensor(204.0666, device='cuda:0', grad_fn=<MulBackward0>)]
epoch 4999 loss=[tensor(50.2741, device='cuda:0', grad_fn=<MulBackward0>), tensor(429.9974, device='cuda:0', grad_fn=<MulBackward0>), tensor(28.0491, device='cuda:0', grad_fn=<MulBackward0>), tensor(306.7668, device='cuda:0', grad_fn=<MulBackward0>), tensor(16.3926, device='cuda:0', grad_fn=<MulBackward0>), tensor(190.8331, device='cuda:0', grad_fn=<MulBackward0>)]
FINISHED [[28.788951873779297, 14.245462417602539], [33.96131896972656, 15.25642204284668], [25.1175479888916, 13.195879936218262], [27.136314392089844, 16.137489318847656], [28.36180877685547, 15.172502517700195], [27.52297019958496, 15.063119888305664], [27.3611

epoch 2500 loss=[tensor(246.9523, device='cuda:0', grad_fn=<MulBackward0>), tensor(442.9783, device='cuda:0', grad_fn=<MulBackward0>), tensor(94.4824, device='cuda:0', grad_fn=<MulBackward0>), tensor(397.7973, device='cuda:0', grad_fn=<MulBackward0>), tensor(107.8034, device='cuda:0', grad_fn=<MulBackward0>), tensor(207.0249, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 2500 0 loss=27.17934226989746
epoch 3000 loss=[tensor(124.1713, device='cuda:0', grad_fn=<MulBackward0>), tensor(412.6176, device='cuda:0', grad_fn=<MulBackward0>), tensor(52.1737, device='cuda:0', grad_fn=<MulBackward0>), tensor(453.3909, device='cuda:0', grad_fn=<MulBackward0>), tensor(45.3048, device='cuda:0', grad_fn=<MulBackward0>), tensor(218.2160, device='cuda:0', grad_fn=<MulBackward0>)]
New best validation epoch 3000 0 loss=26.718700408935547
epoch 3500 loss=[tensor(69.0751, device='cuda:0', grad_fn=<MulBackward0>), tensor(466.5592, device='cuda:0', grad_fn=<MulBackward0>), tensor(38.9257

In [47]:
# Mean final test metric per task over the recorded folds (task 0: age, task 1: wrt).
# Divides by the number of recorded folds instead of a hard-coded 10.
print(sum([f[0] for f in rmse])/len(rmse))
print(sum([f[1] for f in rmse])/len(rmse))

28.948826408386232
15.538841915130615


In [18]:
# Shape of fold 0's paradigm-0 latents: (subjects, dimA).
Ass[0][0].size()

torch.Size([620, 100])

In [36]:
# Per fold: how many subjects have more than 80 of their paradigm-1 (emoid) latent units
# exactly zero after the ReLU.
for i in range(10):
    print(((Ass[i][1] == 0).sum(dim=1) > 80).sum())

tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
