In [1]:
import os 

from AEMG.data_utils import DynamicsDataset
from AEMG.systems.utils import get_system

import matplotlib.pyplot as plt
import numpy as np 
%matplotlib inline

import torch 
from torch.utils.data import DataLoader
from AEMG.models import *

from tqdm.notebook import tqdm 
import pickle

In [2]:
system = get_system("pendulum")
# config_fname = "config/pendulum_lqr_1K.txt"
config_fname = "config/physics_pendulum.txt"

with open(config_fname, 'r') as f:
    config = eval(f.read())

dataset = DynamicsDataset(config)

Getting data for:  physics_pendulum


100%|██████████| 1000/1000 [00:45<00:00, 22.22it/s]


In [3]:
# Split into train and test
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=True)


In [4]:
encoder  = Encoder(config["high_dims"],config["low_dims"])
dynamics = LatentDynamics(config["low_dims"])
decoder  = Decoder(config["low_dims"],config["high_dims"])

In [5]:
criterion = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(set(list(encoder.parameters()) + list(dynamics.parameters()) + list(decoder.parameters())), 
    lr=config["learning_rate"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, threshold=0.001, patience=5, verbose=True)

In [6]:
train_losses = {'loss_ae1': [], 'loss_ae2': [], 'loss_dyn': [], 'loss_total': []}
test_losses = {'loss_ae1': [], 'loss_ae2': [], 'loss_dyn': [], 'loss_total': []}

patience = config["patience"]

for epoch in tqdm(range(config["epochs"])):
    loss_ae1_train = 0
    loss_ae2_train = 0
    loss_dyn_train = 0

    current_train_loss = 0
    epoch_train_loss = 0

    warmup = 1 if epoch <= config["warmup"] else 0

    encoder.train()
    dynamics.train()
    decoder.train()

    counter = 0
    for i, (x_t, x_tau) in enumerate(train_loader,0):
        optimizer.zero_grad()

        # Forward pass
        z_t = encoder(x_t)
        x_t_pred = decoder(z_t)

        z_tau_pred = dynamics(z_t)
        z_tau = encoder(x_tau)
        x_tau_pred = decoder(z_tau_pred)

        # Compute losses
        loss_ae1 = criterion(x_t, x_t_pred)
        loss_ae2 = criterion(x_tau, x_tau_pred)
        loss_dyn = criterion(z_tau, z_tau_pred)

        loss_total = loss_ae1 + loss_ae2 + warmup * loss_dyn

        # Backward pass
        loss_total.backward()
        optimizer.step()

        current_train_loss += loss_total.item()
        epoch_train_loss += loss_total.item()

        loss_ae1_train += loss_ae1.item()
        loss_ae2_train += loss_ae2.item()
        loss_dyn_train += loss_dyn.item() * warmup
        counter += 1

        if (i+1) % 100 == 0:
            print("Epoch: {}, Iteration: {}, Loss: {}".format(epoch, i+1, current_train_loss))
            current_train_loss = 0
        
    train_losses['loss_ae1'].append(loss_ae1_train / counter)
    train_losses['loss_ae2'].append(loss_ae2_train / counter)
    train_losses['loss_dyn'].append(loss_dyn_train / counter)
    train_losses['loss_total'].append(epoch_train_loss / counter)

    with torch.no_grad():
        loss_ae1_test = 0
        loss_ae2_test = 0
        loss_dyn_test = 0
        epoch_test_loss = 0

        encoder.eval()
        dynamics.eval()
        decoder.eval()

        counter = 0
        for i, (x_t, x_tau) in enumerate(test_loader,0):
            # Forward pass
            z_t = encoder(x_t)
            x_t_pred = decoder(z_t)

            z_tau_pred = dynamics(z_t)
            z_tau = encoder(x_tau)
            x_tau_pred = decoder(z_tau_pred)

            # Compute losses
            loss_ae1 = criterion(x_t, x_t_pred)
            loss_ae2 = criterion(x_tau, x_tau_pred)
            loss_dyn = criterion(z_tau, z_tau_pred)

            loss_total = loss_ae1 + loss_ae2 + warmup * loss_dyn

            epoch_test_loss += loss_total.item()

            loss_ae1_test += loss_ae1.item()
            loss_ae2_test += loss_ae2.item()
            loss_dyn_test += loss_dyn.item() * warmup
            counter += 1

        test_losses['loss_ae1'].append(loss_ae1_test / counter)
        test_losses['loss_ae2'].append(loss_ae2_test / counter)
        test_losses['loss_dyn'].append(loss_dyn_test / counter)
        test_losses['loss_total'].append(epoch_test_loss / counter)

        if epoch >= patience and np.mean(test_losses['loss_total'][-patience:]) >= np.mean(test_losses['loss_total'][-patience:-1]):
            break

        if epoch >= config["warmup"]:
            scheduler.step(epoch_test_loss / counter)
        
    print("Epoch: {}, Train Loss: {}, Test Loss: {}".format(epoch, epoch_train_loss / counter, epoch_test_loss / counter))

  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Iteration: 100, Loss: 7.273164948448539
Epoch: 0, Iteration: 200, Loss: 2.9534839782863855
Epoch: 0, Iteration: 300, Loss: 2.3458500783890486
Epoch: 0, Iteration: 400, Loss: 1.713400412350893
Epoch: 0, Iteration: 500, Loss: 0.9405544474720955
Epoch: 0, Iteration: 600, Loss: 0.6114334119483829
Epoch: 0, Iteration: 700, Loss: 0.5618428718298674
Epoch: 0, Iteration: 800, Loss: 0.538589644478634
Epoch: 0, Train Loss: 0.07971856984189682, Test Loss: 0.005303990248213029
Epoch: 1, Iteration: 100, Loss: 0.4962015822529793
Epoch: 1, Iteration: 200, Loss: 0.4459876911714673
Epoch: 1, Iteration: 300, Loss: 0.35431585228070617
Epoch: 1, Iteration: 400, Loss: 0.2798199006356299
Epoch: 1, Iteration: 500, Loss: 0.2338814929826185
Epoch: 1, Iteration: 600, Loss: 0.2113340785726905
Epoch: 1, Iteration: 700, Loss: 0.197131528344471
Epoch: 1, Iteration: 800, Loss: 0.1902435659430921
Epoch: 1, Train Loss: 0.01168235778150874, Test Loss: 0.0017935771217209198
Epoch: 2, Iteration: 100, Loss: 0.17

Epoch: 16, Iteration: 600, Loss: 0.04144890184397809
Epoch: 16, Iteration: 700, Loss: 0.04374198050936684
Epoch: 16, Iteration: 800, Loss: 0.04256481758784503
Epoch: 16, Train Loss: 0.0016968153421938728, Test Loss: 0.00041014975709905034
Epoch: 17, Iteration: 100, Loss: 0.04186595280771144
Epoch: 17, Iteration: 200, Loss: 0.04186607027077116
Epoch: 17, Iteration: 300, Loss: 0.04057490234845318
Epoch: 17, Iteration: 400, Loss: 0.04273464865400456
Epoch: 17, Iteration: 500, Loss: 0.0411005723872222
Epoch: 17, Iteration: 600, Loss: 0.042252741259289905
Epoch: 17, Iteration: 700, Loss: 0.04097498196642846
Epoch: 17, Iteration: 800, Loss: 0.03951478152885102
Epoch: 17, Train Loss: 0.001656765508125343, Test Loss: 0.00041173019863274547
Epoch: 18, Iteration: 100, Loss: 0.04020504659274593
Epoch: 18, Iteration: 200, Loss: 0.040955312084406614
Epoch: 18, Iteration: 300, Loss: 0.0400980809063185
Epoch: 18, Iteration: 400, Loss: 0.0394077833625488
Epoch: 18, Iteration: 500, Loss: 0.040450469416

Epoch: 32, Train Loss: 0.0013995157147828523, Test Loss: 0.00035551542921575467
Epoch: 33, Iteration: 100, Loss: 0.03388402357813902
Epoch: 33, Iteration: 200, Loss: 0.03590705612441525
Epoch: 33, Iteration: 300, Loss: 0.03410803625592962
Epoch: 33, Iteration: 400, Loss: 0.035110166456433944
Epoch: 33, Iteration: 500, Loss: 0.03457008187251631
Epoch: 33, Iteration: 600, Loss: 0.03354499945999123
Epoch: 33, Iteration: 700, Loss: 0.03441454845597036
Epoch: 33, Iteration: 800, Loss: 0.03452479632687755
Epoch: 33, Train Loss: 0.0013802588690922076, Test Loss: 0.0003785430853392431
Epoch: 34, Iteration: 100, Loss: 0.03483213212166447
Epoch: 34, Iteration: 200, Loss: 0.03726212809851859
Epoch: 34, Iteration: 300, Loss: 0.034021387866232544
Epoch: 34, Iteration: 400, Loss: 0.03650591065525077
Epoch: 34, Iteration: 500, Loss: 0.03429651670739986
Epoch: 34, Iteration: 600, Loss: 0.03474533182452433
Epoch: 34, Iteration: 700, Loss: 0.0343802310526371
Epoch: 34, Iteration: 800, Loss: 0.0339682286

Epoch: 49, Iteration: 100, Loss: 0.028112271073041484
Epoch: 49, Iteration: 200, Loss: 0.028034971823217347
Epoch: 49, Iteration: 300, Loss: 0.0288456350681372
Epoch: 49, Iteration: 400, Loss: 0.028379967421642505
Epoch: 49, Iteration: 500, Loss: 0.0287107680196641
Epoch: 49, Iteration: 600, Loss: 0.028640187927521765
Epoch: 49, Iteration: 700, Loss: 0.029007053453824483
Epoch: 49, Iteration: 800, Loss: 0.0285148919938365
Epoch: 49, Train Loss: 0.001141996280285358, Test Loss: 0.0002880741789516422
Epoch: 50, Iteration: 100, Loss: 0.0285873246320989
Epoch: 50, Iteration: 200, Loss: 0.029067060924717225
Epoch: 50, Iteration: 300, Loss: 0.02839294400473591
Epoch: 50, Iteration: 400, Loss: 0.02857309217506554
Epoch: 50, Iteration: 500, Loss: 0.027671404081047513
Epoch: 50, Iteration: 600, Loss: 0.028435753454687074
Epoch: 50, Iteration: 700, Loss: 0.028575116797583178
Epoch: 50, Iteration: 800, Loss: 0.028241214648005553
Epoch: 50, Train Loss: 0.0011384704229792082, Test Loss: 0.000284077

Epoch: 65, Iteration: 300, Loss: 0.024621151242172346
Epoch: 65, Iteration: 400, Loss: 0.025211987071088515
Epoch: 65, Iteration: 500, Loss: 0.02539442510169465
Epoch: 65, Iteration: 600, Loss: 0.024914497189456597
Epoch: 65, Iteration: 700, Loss: 0.02537338639376685
Epoch: 65, Iteration: 800, Loss: 0.024731724566663615
Epoch: 65, Train Loss: 0.0009987398180418567, Test Loss: 0.00024626171772962666
Epoch: 66, Iteration: 100, Loss: 0.024338172355783172
Epoch: 66, Iteration: 200, Loss: 0.025850593781797215
Epoch: 66, Iteration: 300, Loss: 0.025177846764563583
Epoch: 66, Iteration: 400, Loss: 0.0250244585913606
Epoch: 66, Iteration: 500, Loss: 0.02402762095152866
Epoch: 66, Iteration: 600, Loss: 0.024454357975628227
Epoch: 66, Iteration: 700, Loss: 0.024304263584781438
Epoch: 66, Iteration: 800, Loss: 0.024599743279395625
Epoch: 66, Train Loss: 0.0009883789294388593, Test Loss: 0.00024361256476489542
Epoch: 67, Iteration: 100, Loss: 0.02412051825376693
Epoch: 67, Iteration: 200, Loss: 0.0

Epoch: 81, Iteration: 500, Loss: 0.021171055108425207
Epoch: 81, Iteration: 600, Loss: 0.02101589234371204
Epoch: 81, Iteration: 700, Loss: 0.021150466141989455
Epoch: 81, Iteration: 800, Loss: 0.02190155464631971
Epoch: 81, Train Loss: 0.0008590485779128642, Test Loss: 0.0002181987918045346
Epoch: 82, Iteration: 100, Loss: 0.02110339084174484
Epoch: 82, Iteration: 200, Loss: 0.021074074393254705
Epoch: 82, Iteration: 300, Loss: 0.02158471454458777
Epoch: 82, Iteration: 400, Loss: 0.020950140606146306
Epoch: 82, Iteration: 500, Loss: 0.021369026828324422
Epoch: 82, Iteration: 600, Loss: 0.02091290373937227
Epoch: 82, Iteration: 700, Loss: 0.021240654416033067
Epoch: 82, Iteration: 800, Loss: 0.021196158093516715
Epoch: 82, Train Loss: 0.000849407178112134, Test Loss: 0.00021584957627759826
Epoch: 83, Iteration: 100, Loss: 0.02164810913382098
Epoch: 83, Iteration: 200, Loss: 0.02103813945723232
Epoch: 83, Iteration: 300, Loss: 0.021115199881023727
Epoch: 83, Iteration: 400, Loss: 0.0213

Epoch: 97, Iteration: 700, Loss: 0.018830656583304517
Epoch: 97, Iteration: 800, Loss: 0.01907644262246322
Epoch: 97, Train Loss: 0.0007742641514868901, Test Loss: 0.0001913558359527188
Epoch: 98, Iteration: 100, Loss: 0.019161990188877098
Epoch: 98, Iteration: 200, Loss: 0.01955060221371241
Epoch: 98, Iteration: 300, Loss: 0.019013718992937356
Epoch: 98, Iteration: 400, Loss: 0.01950902762473561
Epoch: 98, Iteration: 500, Loss: 0.01910083316033706
Epoch: 98, Iteration: 600, Loss: 0.01864267108612694
Epoch: 98, Iteration: 700, Loss: 0.01952875305141788
Epoch: 98, Iteration: 800, Loss: 0.018834663103916682
Epoch: 98, Train Loss: 0.0007670576953622062, Test Loss: 0.00019098224268833492
Epoch: 99, Iteration: 100, Loss: 0.01889586563629564
Epoch: 99, Iteration: 200, Loss: 0.018935617335955612
Epoch: 99, Iteration: 300, Loss: 0.019798479086603038
Epoch: 99, Iteration: 400, Loss: 0.01899462702567689
Epoch: 99, Iteration: 500, Loss: 0.01932631977251731
Epoch: 99, Iteration: 600, Loss: 0.01907

Epoch: 113, Iteration: 600, Loss: 0.019219859619624913
Epoch: 113, Iteration: 700, Loss: 0.018436193677189294
Epoch: 113, Iteration: 800, Loss: 0.017291250704147387
Epoch: 113, Train Loss: 0.0007262350472998226, Test Loss: 0.0001829627002449921
Epoch: 114, Iteration: 100, Loss: 0.017374147100781556
Epoch: 114, Iteration: 200, Loss: 0.018560764292487875
Epoch: 114, Iteration: 300, Loss: 0.018428877519909292
Epoch: 114, Iteration: 400, Loss: 0.018796727061271667
Epoch: 114, Iteration: 500, Loss: 0.01886997983092442
Epoch: 114, Iteration: 600, Loss: 0.017369550056173466
Epoch: 114, Iteration: 700, Loss: 0.01767247996031074
Epoch: 114, Iteration: 800, Loss: 0.017874223383842036
Epoch: 114, Train Loss: 0.0007244056294707598, Test Loss: 0.00018225929311965915
Epoch: 115, Iteration: 100, Loss: 0.017622818602831103
Epoch: 115, Iteration: 200, Loss: 0.017404731690476183
Epoch: 115, Iteration: 300, Loss: 0.01796006129734451
Epoch: 115, Iteration: 400, Loss: 0.0182699997239979
Epoch: 115, Iterati

Epoch: 129, Iteration: 500, Loss: 0.017234546110557858
Epoch: 129, Iteration: 600, Loss: 0.017700475378660485
Epoch: 129, Iteration: 700, Loss: 0.0168336369752069
Epoch: 129, Iteration: 800, Loss: 0.016847785329446197
Epoch: 129, Train Loss: 0.0006914126588107293, Test Loss: 0.00017203277703599945
Epoch: 130, Iteration: 100, Loss: 0.017752974046743475
Epoch: 130, Iteration: 200, Loss: 0.01645668534183642
Epoch: 130, Iteration: 300, Loss: 0.016889858670765534
Epoch: 130, Iteration: 400, Loss: 0.017772088925994467
Epoch: 130, Iteration: 500, Loss: 0.01692995796474861
Epoch: 130, Iteration: 600, Loss: 0.018014148765360005
Epoch: 130, Iteration: 700, Loss: 0.017675443712505512
Epoch: 130, Iteration: 800, Loss: 0.016949675453361124
Epoch: 130, Train Loss: 0.0006910775950948884, Test Loss: 0.00017108196344506639
Epoch: 131, Iteration: 100, Loss: 0.017117115698056296
Epoch: 131, Iteration: 200, Loss: 0.017643254577706102
Epoch: 131, Iteration: 300, Loss: 0.01721615593851311
Epoch: 131, Iterat

Epoch: 145, Iteration: 400, Loss: 0.016156600242538843
Epoch: 145, Iteration: 500, Loss: 0.01651166811643634
Epoch: 145, Iteration: 600, Loss: 0.01556746980349999
Epoch: 145, Iteration: 700, Loss: 0.01703412036295049
Epoch: 145, Iteration: 800, Loss: 0.016965566952421796
Epoch: 145, Train Loss: 0.000666679876588536, Test Loss: 0.00016840197366430566
Epoch: 146, Iteration: 100, Loss: 0.016725277331715915
Epoch: 146, Iteration: 200, Loss: 0.01686370250536129
Epoch: 146, Iteration: 300, Loss: 0.016089295022538863
Epoch: 146, Iteration: 400, Loss: 0.017570193027495407
Epoch: 146, Iteration: 500, Loss: 0.017820553614001255
Epoch: 146, Iteration: 600, Loss: 0.01686684470041655
Epoch: 146, Iteration: 700, Loss: 0.01670794029632816
Epoch: 146, Iteration: 800, Loss: 0.016054179475759156
Epoch: 146, Train Loss: 0.0006720811086838253, Test Loss: 0.0001664683695540287
Epoch: 147, Iteration: 100, Loss: 0.016142978027346544
Epoch: 147, Iteration: 200, Loss: 0.016927750475588255
Epoch: 147, Iteration

Epoch: 161, Iteration: 200, Loss: 0.015496400708798319
Epoch: 161, Iteration: 300, Loss: 0.01595251762046246
Epoch: 161, Iteration: 400, Loss: 0.015940347824653145
Epoch: 161, Iteration: 500, Loss: 0.01610591961798491
Epoch: 161, Iteration: 600, Loss: 0.0160677276289789
Epoch: 161, Iteration: 700, Loss: 0.015981481890776195
Epoch: 161, Iteration: 800, Loss: 0.01608067711640615
Epoch: 161, Train Loss: 0.0006397757347949999, Test Loss: 0.0001617618267797792
Epoch: 162, Iteration: 100, Loss: 0.01581269098096527
Epoch: 162, Iteration: 200, Loss: 0.015723940799944103
Epoch: 162, Iteration: 300, Loss: 0.016590500388701912
Epoch: 162, Iteration: 400, Loss: 0.016394787606259342
Epoch: 162, Iteration: 500, Loss: 0.015871082308876794
Epoch: 162, Iteration: 600, Loss: 0.015909268535324372
Epoch: 162, Iteration: 700, Loss: 0.015673963505832944
Epoch: 162, Iteration: 800, Loss: 0.015587420159135945
Epoch: 162, Train Loss: 0.0006393311213777606, Test Loss: 0.0001613004952572578
Epoch: 163, Iteration

Epoch: 176, Iteration: 800, Loss: 0.01567584742588224
Epoch: 176, Train Loss: 0.0006363093561958522, Test Loss: 0.00016097559763010257
Epoch: 177, Iteration: 100, Loss: 0.015899833153525833
Epoch: 177, Iteration: 200, Loss: 0.015937228294205852
Epoch: 177, Iteration: 300, Loss: 0.01599939670995809
Epoch: 177, Iteration: 400, Loss: 0.01612798189307796
Epoch: 177, Iteration: 500, Loss: 0.016078150256362278
Epoch: 177, Iteration: 600, Loss: 0.015844756111619063
Epoch: 177, Iteration: 700, Loss: 0.015636980773706455
Epoch: 177, Iteration: 800, Loss: 0.015830409771297127
Epoch   178: reducing learning rate of group 0 to 1.0000e-08.
Epoch: 177, Train Loss: 0.0006363285865853562, Test Loss: 0.00016097980615128183
Epoch: 178, Iteration: 100, Loss: 0.015827053270186298
Epoch: 178, Iteration: 200, Loss: 0.01613191895012278
Epoch: 178, Iteration: 300, Loss: 0.015654255541448947
Epoch: 178, Iteration: 400, Loss: 0.016051508267992176
Epoch: 178, Iteration: 500, Loss: 0.015911364294879604
Epoch: 178

Epoch: 192, Iteration: 600, Loss: 0.015993817360140383
Epoch: 192, Iteration: 700, Loss: 0.0158130782074295
Epoch: 192, Iteration: 800, Loss: 0.015483097951801028
Epoch: 192, Train Loss: 0.0006362608074860257, Test Loss: 0.00016097757283465138
Epoch: 193, Iteration: 100, Loss: 0.01490407691017026
Epoch: 193, Iteration: 200, Loss: 0.015763145398523193
Epoch: 193, Iteration: 300, Loss: 0.016279649760690518
Epoch: 193, Iteration: 400, Loss: 0.015674829526687972
Epoch: 193, Iteration: 500, Loss: 0.01600625655555632
Epoch: 193, Iteration: 600, Loss: 0.016027871053665876
Epoch: 193, Iteration: 700, Loss: 0.01616201049182564
Epoch: 193, Iteration: 800, Loss: 0.016337992172339
Epoch: 193, Train Loss: 0.0006362665028508199, Test Loss: 0.00016097486356309775
Epoch: 194, Iteration: 100, Loss: 0.015726775767689105
Epoch: 194, Iteration: 200, Loss: 0.016108747950056568
Epoch: 194, Iteration: 300, Loss: 0.016048023877374362
Epoch: 194, Iteration: 400, Loss: 0.01613346751400968
Epoch: 194, Iteration:

Epoch: 208, Iteration: 500, Loss: 0.015776794389239512
Epoch: 208, Iteration: 600, Loss: 0.015820921449630987
Epoch: 208, Iteration: 700, Loss: 0.01610967308806721
Epoch: 208, Iteration: 800, Loss: 0.015552422155451495
Epoch: 208, Train Loss: 0.0006362689778499503, Test Loss: 0.00016097380948807693
Epoch: 209, Iteration: 100, Loss: 0.016323363925039303
Epoch: 209, Iteration: 200, Loss: 0.015830812706553843
Epoch: 209, Iteration: 300, Loss: 0.016261633376416285
Epoch: 209, Iteration: 400, Loss: 0.015136707843339536
Epoch: 209, Iteration: 500, Loss: 0.015531583572737873
Epoch: 209, Iteration: 600, Loss: 0.01631188752799062
Epoch: 209, Iteration: 700, Loss: 0.016191501934372354
Epoch: 209, Iteration: 800, Loss: 0.015581030784233008
Epoch: 209, Train Loss: 0.0006362699584612701, Test Loss: 0.0001609748046513119
Epoch: 210, Iteration: 100, Loss: 0.016038696790928952
Epoch: 210, Iteration: 200, Loss: 0.015365369203209411
Epoch: 210, Iteration: 300, Loss: 0.01572177767957328
Epoch: 210, Itera

In [7]:
# Save the models
torch.save(encoder, os.path.join(config["model_dir"], "encoder.pt"))
torch.save(dynamics, os.path.join(config["model_dir"], "dynamics.pt"))
torch.save(decoder, os.path.join(config["model_dir"], "decoder.pt"))

In [8]:
# if log_dir doesn't exist, create it
if not os.path.exists(config["log_dir"]):
    os.makedirs(config["log_dir"])

# Save the losses as a pickle file
with open(os.path.join(config["log_dir"], "losses.pkl"), "wb") as f:
    pickle.dump({"train_losses": train_losses, "test_losses": test_losses}, f)