In [1]:

import torch
import torch.optim as optim
import torch.nn as nn
from parse_data import get_data, get_modified_values, get_binary_values, make_data_scalar
import numpy as np
import random
from data_gen import Datagen
from recognition import Recognition
from generator import Generator
from evaluation import evaluate_model, bin_plot
from time_recognition import TimeRecognition

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device=None
print("Using device: ", device)

import torch
print(torch.__version__)
 

Using device:  cuda
1.12.0+cu116


In [2]:
gen = Datagen(device)

x, y, x_1 = gen.get_generated_data(seq_len=2)

print("x", x[0])
print("y", y[0])
print("x_1", x_1[0])

x tensor([[0.],
        [0.]], device='cuda:0')
y tensor([1.], device='cuda:0')
x_1 tensor([[0.],
        [1.]], device='cuda:0')


In [3]:
import random

# Hyperparameters
sequence_length = [2*i for i in range(4,16)] # 2-20 increments of two
hidden_layers = [1,2] # 1 and 2
hidden_1 = [2**i for i in range(2,7)] # 2^4 to 2^9
hidden_2 =[2**i for i in range(5,10)] # 2^2 to 2^5
variance = [0.001, 0.01, 0.005, 0.05]
lr = [0.001, 0.01, 0.1, 0.005] # stop at 0.005
data_probability = [i/5 for i in range(1,6)]
regularization = [1/i for i in range(1,10)]
for i in range(3):
    regularization.append(0)

epochs = 500
optimizer = [optim.Adam, optim.SGD]

options = []

for seq_len in sequence_length:
    for layers in hidden_layers:
        for h1 in hidden_1:
            for h2 in hidden_2:
                for l in lr:
                    for v in variance:
                        for p in data_probability:
                            for r in regularization:
                                entry = {}
                                entry["seq_len"] = seq_len
                                entry["layers"] = layers
                                entry["latent"] = h1
                                entry["hidden"] = h2
                                entry["l"] = l
                                entry["variance"] = v
                                entry["data_prob"] = p
                                entry["regularization"] = r
                                options.append(entry)
                
                                         
random.shuffle(options)    


In [None]:


import torch.utils.data as data
from itertools import chain
import torch.nn.functional as F

def loss(x, x_hat, mean, R, s, x_1,reg,  device=None, seq_len=1):
    
    mse = nn.MSELoss().to(device)
    l = F.binary_cross_entropy(x_hat, x, reduction='sum')
    amount = mean[0].size()[0]*mean[0].size()[1]
    for m, r in zip(mean, R):
        
        C = r @ r.transpose(-2,-1) + 1e-6
        det = C.det() + 1e-6 
        l += 0.5 * torch.sum(m.pow(2).sum(-1) 
                             + C.diagonal(dim1=-2,dim2=-1).sum(-1)
                            -det.log()  -1)/amount

    count = len(s)*2
    for a, b in zip(s, x_1):
        l += reg*mse(a[0], b[0])/count
        l += reg*mse(a[1], b[1])/count
    
    #print(l, F.binary_cross_entropy(x_hat, x, reduction='sum'))
    return l 

best_model = None
best_score = 10000000000000000
batch_size = 10
best_history= [0,0,0,0,0,0]
for entry in options:

    x_d, y_d, x_1_d = gen.get_generated_data(entry["seq_len"], entry["variance"], entry["data_prob"])
    x_t, y_t, x_t_1 = gen.get_true_data(entry["seq_len"])
    x_val, y_val, x_val_1 = gen.get_test_data(entry["seq_len"])


    model_t = TimeRecognition(input_dim=x_d[0].size()[1],
                              hidden_size=entry["hidden"],
                              seq_len=entry["seq_len"],
                              layers=entry["layers"],
                             device=device)

    model_g = Generator(hidden_size=entry["hidden"],
                        latent_dim=entry["latent"],
                        output_dim=y_d[0].size()[0],
                        layers=entry["layers"],
                        seq_len=entry["seq_len"],
                        device=device)
    model_r = Recognition(input_dim=x_d[0].size()[1],
                          latent_dim=entry["latent"],
                          layers=entry["layers"],
                          device=device)

    loader = data.DataLoader(data.TensorDataset(x_d, y_d, x_1_d), batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(chain(model_r.parameters(), model_g.parameters(), model_t.parameters()), lr=entry["l"])
    #optimizer = optim.Adam(model_r.parameters())
    history = []
    bce = nn.BCELoss().to(device)
    for e in range(epochs):
        model_g.train()
        model_r.train()
        model_t.train()


        for x, y, x_1 in loader:

            x.to(device)
            y.to(device)
            if x.size()[0] < batch_size:
                continue
            if random.random() < 0.5:
                continue

            t = model_t(x)
            t_1 = model_t(x_1)
            model_g.make_internal_state()
            rec = model_r(x_1)
            model_g.set_xi(rec[-1])
            model_g.set_internal_state(t)
            b, s = model_g()

            l = loss(y, b, rec[0], rec[1], s, t_1, entry["regularization"], device, entry["seq_len"])
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
                        

        
        if e % 10 != 0:
            continue
        
        count = 0
        sum_loss = [0, 0]
        for j in range(2):
            for x, y, x_1 in loader:
                model_g.eval()
                model_t.eval()
                model_r.eval()
                model_g.make_internal_state()
                model_g.make_xi()
                with torch.no_grad():
                    model_g.make_internal_state()
                    rec = model_r(x_1)
                    t = model_t(x)
                    t_1 = model_t(x_1)
                    model_g.set_internal_state(t)
                    model_g.set_xi(rec[-1])
                    b,s = model_g()
                    l = loss(y, b, rec[0], rec[1],s,t_1,entry["regularization"], device, entry["seq_len"])
                    res = []
                    
                    sum_loss[j] += l.item()
                    count += 1
                    
        
        
        sum_loss[0] /= count
     
        
        history.append([e, sum_loss[0], sum_loss[1]])
        print(history[-1])

        if len(history) > 15:
            #if no real improvements are being done stop the training. 
            # but keep doing the training if the results without correctly feeding values get better
            if abs(history[-15][1] - history[-1][1]) < 0.0001:
                break
        
    
    if history[-1][1] < best_score:
        print("New best model:\nNew loss: ", history[-1], "\nOld loss:", best_history[-1], "\nHistory:" , history[-10:])
        best_model = model_g
        best_history = history
        best_score = history[-1][1]
        best_config = entry
        with torch.no_grad():
            evaluate_model(best_model,model_r, model_t, x_t, y_t, x_t_1,x_val,y_val, x_val_1, entry)
    else:
        with torch.no_grad():
            evaluate_model(model_g,model_r, model_t, x_t, y_t, x_t_1,x_val,y_val, x_val_1, entry)
        print("Old model still stands:\nCurrent loss: ", history[-1], "\nBest loss:", best_history[-1])
    

[0, 17.547841547674672, 26872.675704956055]
[10, 17.4524155783591, 26762.66370010376]
[20, 17.452003518844087, 26724.762882232666]
[30, 17.493527422374594, 26773.114753723145]
[40, 17.426821513213316, 26701.25640106201]
[50, 17.436093238252884, 26739.085342407227]
[60, 17.45832499262558, 26728.480560302734]
[70, 17.416660119597033, 26690.044540405273]
[80, 17.43880926535584, 26751.465545654297]
[90, 17.503481215036256, 26795.016441345215]
[100, 17.492648851778114, 26794.30743408203]
[110, 17.783321295954853, 27219.39860534668]
[120, 17.515038126753765, 26862.164264678955]
[130, 17.817390835316314, 27266.075119018555]
[140, 17.497040514535133, 26816.080921173096]
[150, 17.550386517228407, 26850.557542800903]
[160, 17.59357397114328, 26925.06930732727]
[170, 17.522343939340455, 26839.775463104248]
[180, 17.4361542883492, 26732.2423286438]
[190, 17.461722988995184, 26746.116275787354]
[200, 17.518507546606635, 26843.555110931396]
[210, 17.527628119247076, 26887.671031951904]
[220, 17.4149

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

[0, 5.530860319461275, 8472.35688495636]
[10, 5.44604346584091, 8345.87527179718]
[20, 5.427022946410017, 8310.75431728363]
[30, 5.459651928968903, 8348.285373687744]
[40, 5.432409644438144, 8322.579416275024]
[50, 5.417523978584429, 8289.800099372864]
[60, 5.4113887607584426, 8291.138054847717]
[70, 5.419271232565141, 8286.727126121521]
[80, 5.48652286753642, 8415.732560157776]
[90, 5.46618869099231, 8390.707265853882]
[100, 5.489182711270086, 8414.326011657715]
[110, 5.440545626159747, 8336.040493011475]
[120, 5.517198285608316, 8442.180633544922]
[130, 5.444980613245354, 8349.87398815155]
[140, 5.521116546488929, 8434.067008018494]
[150, 5.414909547676306, 8303.92007446289]
[160, 5.403094089373596, 8285.860313415527]
[170, 5.415791669028858, 8315.848898887634]
[180, 5.387303869655798, 8225.929094314575]
[190, 5.373007396805068, 8212.149488449097]
[200, 5.3623765461127375, 8215.94179725647]
[210, 5.35926717379697, 8202.090001106262]
[220, 5.34394308543392, 8194.308185577393]
[230, 5.

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

[0, 7.96841549624351, 12197.21286869049]
[10, 7.206181554196709, 11050.161675453186]
[20, 7.184416995658576, 11013.541392326355]
[30, 7.188541863668059, 11002.239295959473]
[40, 7.176464863919091, 11012.164929389954]
[50, 7.188213546344568, 11003.827660560608]
[60, 7.173216185432812, 10988.749794960022]
[70, 7.1759239120832, 10991.37683391571]
[80, 7.184115023276824, 10993.62639427185]
[90, 7.180804225856891, 11004.446343421936]
[100, 7.188011231683875, 10991.07633972168]
[110, 7.175296338980254, 11000.261646270752]
[120, 7.175666458612945, 11011.434679031372]
[130, 7.188980146114259, 11008.568476676941]
[140, 7.185471830417843, 11002.81362915039]
[150, 7.179086179708376, 11011.487866401672]
[160, 7.167616908295036, 11004.941648483276]
[170, 7.175376676703869, 10991.646372795105]
[180, 7.194247925561651, 10985.731075286865]
[190, 7.162592450258937, 10972.44027042389]
[200, 7.157607059254659, 10947.655336380005]
[210, 7.149655057618263, 10941.7962474823]
[220, 7.14762650021685, 10936.30

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Old model still stands:
Current loss:  [490, 7.124709520265266, 10907.518242835999] 
Best loss: [490, 5.350793542189636, 8190.683953285217]
[0, 25.717602460876147, 39399.20559692383]
[10, 25.200444617408372, 38590.00400543213]
[20, 25.203018835879494, 38610.3666267395]
[30, 25.170074540080975, 38555.425521850586]
[40, 25.153607004927593, 38514.25024795532]
[50, 25.146456895235623, 38537.48161697388]
[60, 25.142870581803685, 38517.89920043945]
[70, 25.13329977777546, 38505.713523864746]
[80, 25.121482585179898, 38487.945362091064]
[90, 25.107331773942196, 38471.72447967529]
[100, 25.109799512061375, 38478.10715484619]
[110, 25.099091280845066, 38445.43586349487]
[120, 25.09204064710333, 38430.33927536011]
[130, 25.082398300071297, 38416.985679626465]
[140, 25.062515699521057, 38404.66569900513]
[150, 25.042752263440164, 38359.15546417236]
[160, 25.04493731057986, 38363.72757720947]
[170, 25.01733298090046, 38334.697078704834]
[180, 25.034277395542233, 38325.06895828247]
[190, 25.0216310

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Old model still stands:
Current loss:  [490, 24.73404737676715, 37891.679695129395] 
Best loss: [490, 5.350793542189636, 8190.683953285217]
[0, 181.23328555034905, 277649.44916915894]
[10, 198.4732787030173, 304061.0627822876]
[20, 198.47578788737405, 304064.90701293945]
[30, 198.47728573373342, 304067.2029647827]
[40, 198.476799788114, 304066.45806884766]
[50, 198.68811984348545, 304390.1985168457]
[60, 198.69007398939007, 304393.1930541992]
[70, 198.68832127543715, 304390.50999450684]
[80, 198.6895317595559, 304392.36302948]
[90, 198.6891888601039, 304391.83751678467]
[100, 198.6887069134426, 304391.0995941162]
[110, 198.68902585400613, 304391.58560180664]
[120, 198.6877166060804, 304389.5844268799]
[130, 198.68973349030895, 304392.6719055176]
[140, 198.6905438495367, 304393.9135131836]
[150, 198.68880882960386, 304391.25534820557]
[160, 198.6865417092027, 304387.7861175537]
[170, 198.68892544181168, 304391.4336929321]
[180, 198.93661760972623, 304770.8970336914]
[190, 198.9371028133

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Old model still stands:
Current loss:  [490, 199.67971536820613, 305909.3238143921] 
Best loss: [490, 5.350793542189636, 8190.683953285217]
[0, 9.833517044704825, 15063.292650222778]
[10, 9.434688009107704, 14460.251297950745]
[20, 9.435022554571884, 14458.14135169983]
[30, 9.440172207884627, 14446.262466430664]
[40, 9.423885582010055, 14426.732360839844]
[50, 9.416898736132032, 14414.850706100464]
[60, 9.406905911298708, 14429.100450515747]
[70, 9.405271134239575, 14402.382389068604]
[80, 9.402638638299688, 14409.939205169678]
[90, 9.40358331496037, 14411.55333328247]
[100, 9.398789703379101, 14410.323835372925]
[110, 9.394828528710507, 14402.803073883057]
[120, 9.396754663233347, 14392.819421768188]
[130, 9.392418604918, 14398.914485931396]
[140, 9.389546909780478, 14397.33737564087]
[150, 9.403450149780154, 14405.231876373291]
[160, 9.391980686635947, 14389.410913467407]
[170, 9.395092130018588, 14386.873039245605]
[180, 9.390586361250119, 14393.707202911377]
[190, 9.398921342804911

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Old model still stands:
Current loss:  [490, 9.329511098388592, 14292.073858261108] 
Best loss: [490, 5.350793542189636, 8190.683953285217]
[0, 3.5974931299842057, 5495.7560296058655]
[10, 3.489854812933322, 5392.306377410889]
[20, 3.482085183146106, 5362.90541601181]
[30, 3.4714663411556272, 5309.191036224365]
[40, 3.4455578560617512, 5276.036802768707]
[50, 3.4377027430982565, 5264.957020282745]
[60, 3.4739296784911393, 5303.051542758942]
[70, 3.4479900950240094, 5292.937838554382]
[80, 3.448613315276004, 5277.350896835327]
[90, 3.451701805579133, 5288.452558755875]
[100, 3.448363216987789, 5280.167891263962]
[110, 3.4398520235604466, 5269.988178491592]
[120, 3.4555750241790055, 5297.361393213272]
[130, 3.435421272295262, 5270.21798324585]
[140, 3.4696366777619243, 5290.9004101753235]
[150, 3.4533198797982916, 5302.11617231369]
[160, 3.4476197746341595, 5263.393726825714]
[170, 3.4426160117042284, 5285.9470410346985]
[180, 3.461210384692598, 5311.564577102661]
[190, 3.434719757685151

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

[0, 49.55518285786203, 75910.56774139404]
[10, 49.28475654841092, 75476.95802307129]
[20, 210.9117156793181, 323116.7753601074]
[30, 207.77834737892252, 319041.225692749]
[40, 203.99205764093225, 312294.4493179321]
[50, 214.2306270549565, 328201.3096847534]
[60, 211.97391556821978, 324744.04008483887]
[70, 210.97094279854477, 323207.4809951782]
[80, 208.47156417089715, 319378.43713378906]
[90, 212.9741630504399, 326276.4233779907]
[100, 212.47077480744443, 325505.2279968262]
[110, 210.72084056366205, 322824.3291015625]
[120, 208.12636229014583, 318946.23220825195]
[130, 205.721340538005, 315165.0919189453]
[140, 204.4867848329071, 313449.75466918945]
[150, 212.72113004241226, 325888.77116394043]
[160, 212.47091482575817, 325505.4412689209]
[170, 212.2207472479997, 325122.18452453613]
[180, 211.470715938598, 323973.1390914917]
[190, 211.470810942488, 323973.28302001953]
[200, 206.50262091116247, 316126.00997161865]
[210, 214.7141134745147, 328942.012840271]
[220, 213.7214585894393, 3274

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Old model still stands:
Current loss:  [490, 212.72066403182305, 325888.0604476929] 
Best loss: [490, 3.375583776917221, 5160.119679927826]
[0, 25.396942920535412, 38907.769428253174]
[10, 25.185600133851054, 38580.22511291504]
[20, 25.169403833135302, 38562.66062164307]
[30, 25.164574727688382, 38553.89047241211]
[40, 25.17633491894595, 38559.08006286621]
[50, 25.15542349055915, 38544.04666900635]
[60, 25.152172850566494, 38536.23086166382]
[70, 25.16557128248887, 38561.82769393921]
[80, 25.164360514508836, 38551.68284225464]
[90, 25.13691423082476, 38514.40265274048]
[100, 25.139150537336462, 38512.165004730225]
[110, 25.143462838454283, 38523.56982803345]
[120, 25.14179868922221, 38511.634952545166]
[130, 25.141263555920155, 38523.38391876221]
[140, 25.137719963611573, 38510.6767616272]
[150, 25.136526336869125, 38510.20877075195]
[160, 25.13095399293825, 38509.71998214722]
[170, 25.13850292773533, 38515.92821121216]
[180, 25.134549868013465, 38501.057331085205]
[190, 25.14180038740

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Old model still stands:
Current loss:  [190, 154.7195227728189, 237030.30893325806] 
Best loss: [490, 3.375583776917221, 5160.119679927826]
[0, 7.317424368920587, 11197.352924346924]
[10, 7.282297491093528, 11152.579363822937]
[20, 7.171595282093975, 10993.974789619446]
[30, 7.205651498650135, 11022.207424163818]
[40, 7.176653500325375, 11002.047939300537]
[50, 7.197259653330471, 11008.168234825134]
[60, 7.2133923940185465, 11033.586723327637]
[70, 7.184749464453356, 10996.355233192444]
[80, 7.234658302899752, 11082.619052886963]
[90, 7.178966427596369, 11005.307967185974]
[100, 7.2150033777749885, 11058.715753555298]
[110, 7.179183950000892, 11003.992615699768]
[120, 7.205927556862097, 11048.773699760437]
[130, 7.184120626424685, 10999.451252937317]
[140, 7.189142690314948, 11029.17938709259]
[150, 7.200751844336407, 11036.512878417969]
[160, 7.205445358087126, 11056.363373756409]
[170, 7.195050990301386, 11002.65487575531]
[180, 7.194492004558872, 10998.319628715515]
[190, 7.17012780

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Old model still stands:
Current loss:  [490, 163.66117663047334, 250728.92532348633] 
Best loss: [490, 3.375583776917221, 5160.119679927826]
[0, 17.672824227156276, 27058.662006378174]
[10, 17.43568943561524, 26736.75941848755]
[20, 17.48090230578231, 26764.356357574463]
[30, 17.517374656219083, 26850.97677230835]
[40, 17.512190136523532, 26786.18320465088]
[50, 17.47760398207383, 26757.93119430542]
[60, 17.44028626036084, 26705.27140045166]
[70, 17.436614355281502, 26711.332118988037]
[80, 17.447856596804787, 26708.083974838257]
[90, 17.55319599071931, 26878.74666595459]
[100, 17.47838546215087, 26789.10796737671]
[110, 17.473807883947387, 26777.110095977783]
[120, 17.477880632285974, 26812.75588607788]
[130, 17.44504014007728, 26717.531520843506]
[140, 17.47934552832621, 26763.860832214355]
[150, 17.43720961364069, 26709.231021881104]
[160, 17.60386377588576, 26954.85832977295]
[170, 17.433053808486182, 26729.473121643066]
[180, 17.453377102436036, 26715.406898498535]
[190, 17.462221

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Old model still stands:
Current loss:  [490, 17.486030000930665, 26789.7889251709] 
Best loss: [490, 3.375583776917221, 5160.119679927826]
[0, 40.08581321357747, 61343.977684020996]
[10, 199.31873321533203, 305356.3480606079]
[20, 199.22223430265024, 305274.9369812012]
[30, 199.18595902664543, 305329.2130126953]
[40, 199.47476139516806, 305495.33644104004]
[50, 199.47378315938047, 305659.9463043213]
[60, 199.47378310958027, 305593.83613586426]
[70, 199.47378339842157, 305593.8355560303]
[80, 199.47223937231317, 305691.47091674805]
[90, 199.472239551594, 305591.47093200684]
[100, 199.47223934243306, 305591.4709625244]
[110, 199.47223935239307, 305591.47093200684]
[120, 199.4722392478126, 305591.4708404541]
[130, 199.47111015718227, 305589.74102020264]
[140, 199.47111013726217, 305589.74097442627]
[150, 199.47111033646308, 305589.7408294678]
[160, 199.47111041614343, 305589.7411880493]
[170, 199.4711103464231, 305589.74085235596]
[180, 199.47111025180268, 305589.7409210205]
[190, 199.471

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Old model still stands:
Current loss:  [270, 199.47111021694252, 305589.74072265625] 
Best loss: [490, 3.375583776917221, 5160.119679927826]
[0, 17.801300444739915, 27261.73701095581]
[10, 17.435858572120765, 26712.90725326538]
[20, 17.436646309596128, 26716.289529800415]
[30, 17.441430535079917, 26713.758716583252]
[40, 17.42002980951852, 26699.285469055176]
[50, 17.420573585649695, 26674.010988235474]
[60, 17.41431828018268, 26671.79056930542]
[70, 17.410705678456758, 26665.881393432617]
[80, 17.393913435873724, 26641.96212387085]
[90, 17.383898946697343, 26659.475788116455]
[100, 17.387970523485627, 26613.83727645874]
[110, 17.374667025732933, 26604.655925750732]
[120, 17.363566655091766, 26604.408420562744]
[130, 17.36544048070285, 26587.56692123413]
[140, 17.352433142400596, 26601.042728424072]
[150, 17.33808705825407, 26579.810329437256]
[160, 17.35191143772932, 26580.387702941895]
[170, 17.361035238669373, 26598.307983398438]
[180, 17.335225653088123, 26576.010948181152]
[190, 1

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Old model still stands:
Current loss:  [490, 17.32309045119323, 26535.625785827637] 
Best loss: [490, 3.375583776917221, 5160.119679927826]
[0, 25.67406060241221, 39365.9168510437]
[10, 25.21353031138527, 38604.89810562134]
[20, 25.18489280132961, 38582.57046508789]
[30, 25.206610604926126, 38606.03705596924]
[40, 25.15679805646055, 38559.51389694214]
[50, 25.158715848200625, 38535.603202819824]
[60, 25.140313280469133, 38498.485931396484]
[70, 25.139510167174176, 38512.36696243286]
[80, 25.129858335689217, 38470.58158111572]
[90, 25.154167262443362, 38527.05470275879]
[100, 25.129249119571856, 38480.604358673096]
[110, 25.134062816829033, 38507.91650009155]
[120, 25.115274165380093, 38479.811557769775]
[130, 25.126071153048123, 38496.75256729126]
[140, 25.099592193922238, 38449.1672706604]
[150, 25.09756410215293, 38434.61494064331]
[160, 25.084732979458245, 38423.49110031128]
[170, 25.07735953567545, 38416.73411941528]
[180, 25.065994499246383, 38388.83737564087]
[190, 25.08161787575

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Old model still stands:
Current loss:  [490, 24.9026020089889, 38138.39237213135] 
Best loss: [490, 3.375583776917221, 5160.119679927826]
[0, 378.3216882810269, 579588.9114685059]
[10, 391.48904980691856, 599761.1492614746]
[20, 394.340774346892, 604129.8386688232]
[30, 394.35995793155837, 604159.3505249023]
[40, 394.3834997829507, 604195.9956359863]
[50, 394.3824950932834, 604194.1002197266]
[60, 394.61623932362846, 604551.8927078247]
[70, 394.62003794538134, 604558.0006103516]
[80, 394.5890930614023, 604510.3921661377]
[90, 394.6156558641877, 604551.6469726562]
[100, 394.62596046706716, 604566.4669952393]
[110, 394.7907447117738, 604819.5163879395]
[120, 394.70326378326814, 604685.512512207]
[130, 394.6440122718911, 604595.2500610352]
[140, 394.5857549224136, 604505.7589416504]
[150, 394.5912471004317, 604513.8963470459]
[160, 394.64860997810064, 604602.3137359619]
[170, 394.66571367627336, 604627.6947021484]
[180, 394.6422320731937, 604592.5399780273]
[190, 394.6312028272345, 604574

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Old model still stands:
Current loss:  [490, 394.6353689049305, 604581.1992034912] 
Best loss: [490, 3.375583776917221, 5160.119679927826]
[0, 7.234750917623933, 11060.273988723755]
[10, 10.683057189921486, 16367.64380455017]
[20, 10.652526066135177, 16324.083723068237]
[30, 9.850148085203246, 15063.894432067871]
[40, 9.892031380775391, 15020.14243888855]
[50, 159.67838207922156, 244535.7511358261]
[60, 159.21882506699225, 243923.24005889893]
[70, 159.718729119388, 244689.09290409088]
[80, 161.26833737425642, 247177.09735298157]
[90, 161.19171955815807, 247453.0929260254]
[100, 160.10515202056646, 245945.9506149292]
[110, 159.7187290820378, 244689.09299850464]
[120, 159.71872913681807, 244789.09270191193]
[130, 159.62995604933397, 244689.09295272827]
[140, 159.78400317807112, 244689.09288024902]
[150, 159.7187290347276, 244689.09274291992]
[160, 159.71872906896527, 244719.09285354614]
[170, 159.63598932858858, 244689.092751503]
[180, 159.7187290565152, 244689.0928964615]
[190, 159.7187

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/3000 [00:00<?, ?it/s]

Old model still stands:
Current loss:  [380, 159.7187291038254, 244689.09282684326] 
Best loss: [490, 3.375583776917221, 5160.119679927826]
[0, 3.634951170991046, 5581.94886803627]
[10, 3.560966320510944, 5445.885446548462]
[20, 3.4979582025861617, 5345.657736778259]
[30, 3.477188682742903, 5318.981878757477]
[40, 3.4384055697886815, 5265.795308113098]
[50, 3.4208286866818023, 5229.155207157135]
[60, 3.4146142043270578, 5214.185295581818]
[70, 3.4147542708847602, 5217.859266757965]
[80, 3.3938401059755767, 5217.989878892899]
[90, 3.397914551879345, 5184.279887676239]
[100, 3.381524423083811, 5211.051362514496]
[110, 3.3934390102914356, 5205.57883310318]
[120, 3.407035644309639, 5220.3008761405945]
[130, 3.4070518589517778, 5215.273873090744]
[140, 3.3874574686466246, 5214.564924240112]
[150, 3.4107053475965095, 5217.988342761993]
[160, 3.4176420402900356, 5214.47881436348]
[170, 3.398270099343581, 5201.282989501953]
[180, 3.407874928442056, 5197.761781692505]
[190, 3.3996922558966256, 

In [None]:
a = torch.zeros(10,5,1)

In [None]:
a[:,-1,:].size()

In [None]:
b = torch.zeros(30,1)
c = torch.zeros(1,1)


In [None]:
torch.cat((b[1:],c))

In [None]:
23