In [1]:
import os 

from AEMG.data_utils import DynamicsDataset
from AEMG.systems.utils import get_system

import matplotlib.pyplot as plt
import numpy as np 
%matplotlib inline

import torch 
from torch.utils.data import DataLoader
from AEMG.models import *

from tqdm.notebook import tqdm 
import pickle

In [2]:
# config_fname = "config/pendulum_lqr_1K.txt"
config_fname = "config/physics_pendulum.txt"
# config_fname = "config/hopper.txt"

with open(config_fname, 'r') as f:
    config = eval(f.read())

dataset = DynamicsDataset(config)

Getting data for:  physics_pendulum


100%|██████████| 1000/1000 [00:30<00:00, 32.40it/s]


In [3]:
# Split into train and test
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=True)


In [4]:
encoder  = Encoder(config["high_dims"],config["low_dims"])
dynamics = LatentDynamics(config["low_dims"])
decoder  = Decoder(config["low_dims"],config["high_dims"])

In [5]:
criterion = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(set(list(encoder.parameters()) + list(dynamics.parameters()) + list(decoder.parameters())), 
    lr=config["learning_rate"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, threshold=0.001, patience=5, verbose=True)

In [6]:
train_losses = {'loss_ae1': [], 'loss_ae2': [], 'loss_dyn': [], 'loss_total': []}
test_losses = {'loss_ae1': [], 'loss_ae2': [], 'loss_dyn': [], 'loss_total': []}

patience = config["patience"]

for epoch in tqdm(range(config["epochs"])):
    loss_ae1_train = 0
    loss_ae2_train = 0
    loss_dyn_train = 0

    current_train_loss = 0
    epoch_train_loss = 0

    warmup = 1 if epoch <= config["warmup"] else 0

    encoder.train()
    dynamics.train()
    decoder.train()

    counter = 0
    for i, (x_t, x_tau) in enumerate(train_loader,0):
        optimizer.zero_grad()

        # Forward pass
        z_t = encoder(x_t)
        x_t_pred = decoder(z_t)

        z_tau_pred = dynamics(z_t)
        z_tau = encoder(x_tau)
        x_tau_pred = decoder(z_tau_pred)

        # Compute losses
        loss_ae1 = criterion(x_t, x_t_pred)
        loss_ae2 = criterion(x_tau, x_tau_pred)
        loss_dyn = criterion(z_tau, z_tau_pred)

        loss_total = loss_ae1 + loss_ae2 + warmup * loss_dyn

        # Backward pass
        loss_total.backward()
        optimizer.step()

        current_train_loss += loss_total.item()
        epoch_train_loss += loss_total.item()

        loss_ae1_train += loss_ae1.item()
        loss_ae2_train += loss_ae2.item()
        loss_dyn_train += loss_dyn.item() * warmup
        counter += 1

        if (i+1) % 100 == 0:
            print("Epoch: {}, Iteration: {}, Loss: {}".format(epoch, i+1, current_train_loss))
            current_train_loss = 0
        
    train_losses['loss_ae1'].append(loss_ae1_train / counter)
    train_losses['loss_ae2'].append(loss_ae2_train / counter)
    train_losses['loss_dyn'].append(loss_dyn_train / counter)
    train_losses['loss_total'].append(epoch_train_loss / counter)

    with torch.no_grad():
        loss_ae1_test = 0
        loss_ae2_test = 0
        loss_dyn_test = 0
        epoch_test_loss = 0

        encoder.eval()
        dynamics.eval()
        decoder.eval()

        counter = 0
        for i, (x_t, x_tau) in enumerate(test_loader,0):
            # Forward pass
            z_t = encoder(x_t)
            x_t_pred = decoder(z_t)

            z_tau_pred = dynamics(z_t)
            z_tau = encoder(x_tau)
            x_tau_pred = decoder(z_tau_pred)

            # Compute losses
            loss_ae1 = criterion(x_t, x_t_pred)
            loss_ae2 = criterion(x_tau, x_tau_pred)
            loss_dyn = criterion(z_tau, z_tau_pred)

            loss_total = loss_ae1 + loss_ae2 + warmup * loss_dyn

            epoch_test_loss += loss_total.item()

            loss_ae1_test += loss_ae1.item()
            loss_ae2_test += loss_ae2.item()
            loss_dyn_test += loss_dyn.item() * warmup
            counter += 1

        test_losses['loss_ae1'].append(loss_ae1_test / counter)
        test_losses['loss_ae2'].append(loss_ae2_test / counter)
        test_losses['loss_dyn'].append(loss_dyn_test / counter)
        test_losses['loss_total'].append(epoch_test_loss / counter)

        if epoch >= patience and np.mean(test_losses['loss_total'][-patience:]) >= np.mean(test_losses['loss_total'][-patience:-1]):
            break

        if epoch >= config["warmup"]:
            scheduler.step(epoch_test_loss / counter)
        
    print("Epoch: {}, Train Loss: {}, Test Loss: {}".format(epoch, epoch_train_loss / counter, epoch_test_loss / counter))

  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Iteration: 100, Loss: 5.799330146983266
Epoch: 0, Iteration: 200, Loss: 3.116762340068817
Epoch: 0, Iteration: 300, Loss: 2.7566737243905663
Epoch: 0, Iteration: 400, Loss: 1.4036371521651745
Epoch: 0, Iteration: 500, Loss: 1.2406577551737428
Epoch: 0, Iteration: 600, Loss: 1.2367299851030111
Epoch: 0, Iteration: 700, Loss: 1.181992569938302
Epoch: 0, Iteration: 800, Loss: 1.0595833156257868
Epoch: 0, Train Loss: 0.08387403534385099, Test Loss: 0.0069988203728273855
Epoch: 1, Iteration: 100, Loss: 0.643575509544462
Epoch: 1, Iteration: 200, Loss: 0.600460265763104
Epoch: 1, Iteration: 300, Loss: 0.5905587156303227
Epoch: 1, Iteration: 400, Loss: 0.5720366602763534
Epoch: 1, Iteration: 500, Loss: 0.5700590014457703
Epoch: 1, Iteration: 600, Loss: 0.5494118130300194
Epoch: 1, Iteration: 700, Loss: 0.5351232746616006
Epoch: 1, Iteration: 800, Loss: 0.45590485306456685
Epoch: 1, Train Loss: 0.021903899781082727, Test Loss: 0.0034672205313002547
Epoch: 2, Iteration: 100, Loss: 0.3

Epoch: 16, Iteration: 600, Loss: 0.03025722385791596
Epoch: 16, Iteration: 700, Loss: 0.029314720261027105
Epoch: 16, Iteration: 800, Loss: 0.028882417202112265
Epoch: 16, Train Loss: 0.0011743670542414522, Test Loss: 0.00029051863979367477
Epoch: 17, Iteration: 100, Loss: 0.030886536871548742
Epoch: 17, Iteration: 200, Loss: 0.028583197825355455
Epoch: 17, Iteration: 300, Loss: 0.03033670663717203
Epoch: 17, Iteration: 400, Loss: 0.02642940721125342
Epoch: 17, Iteration: 500, Loss: 0.027079086561570875
Epoch: 17, Iteration: 600, Loss: 0.028912921610753983
Epoch: 17, Iteration: 700, Loss: 0.03139105210721027
Epoch: 17, Iteration: 800, Loss: 0.026880713805439882
Epoch: 17, Train Loss: 0.0011608554329404174, Test Loss: 0.0002864723480662722
Epoch: 18, Iteration: 100, Loss: 0.0284871712065069
Epoch: 18, Iteration: 200, Loss: 0.028279514619498514
Epoch: 18, Iteration: 300, Loss: 0.028185363349621184
Epoch: 18, Iteration: 400, Loss: 0.02975408366182819
Epoch: 18, Iteration: 500, Loss: 0.028

Epoch: 32, Iteration: 800, Loss: 0.02058110455982387
Epoch: 32, Train Loss: 0.000833499206341144, Test Loss: 0.0002077306089797865
Epoch: 33, Iteration: 100, Loss: 0.019259098618931603
Epoch: 33, Iteration: 200, Loss: 0.021226897813903634
Epoch: 33, Iteration: 300, Loss: 0.01972089896298712
Epoch: 33, Iteration: 400, Loss: 0.020209313413943164
Epoch: 33, Iteration: 500, Loss: 0.021129453009052668
Epoch: 33, Iteration: 600, Loss: 0.023470743748475797
Epoch: 33, Iteration: 700, Loss: 0.019415783441218082
Epoch: 33, Iteration: 800, Loss: 0.02139683380664792
Epoch: 33, Train Loss: 0.0008261591215595629, Test Loss: 0.00019200016260965722
Epoch: 34, Iteration: 100, Loss: 0.02187681778013939
Epoch: 34, Iteration: 200, Loss: 0.020316156980697997
Epoch: 34, Iteration: 300, Loss: 0.019875047641107813
Epoch: 34, Iteration: 400, Loss: 0.02113784826360643
Epoch: 34, Iteration: 500, Loss: 0.022089364611019846
Epoch: 34, Iteration: 600, Loss: 0.019592848992033396
Epoch: 34, Iteration: 700, Loss: 0.01

Epoch: 49, Iteration: 100, Loss: 0.018628460325999185
Epoch: 49, Iteration: 200, Loss: 0.01767322904197499
Epoch: 49, Iteration: 300, Loss: 0.0185088714279118
Epoch: 49, Iteration: 400, Loss: 0.018419019666907843
Epoch: 49, Iteration: 500, Loss: 0.016159952552698087
Epoch: 49, Iteration: 600, Loss: 0.02030187139462214
Epoch: 49, Iteration: 700, Loss: 0.019512542436132208
Epoch: 49, Iteration: 800, Loss: 0.02038514856394613
Epoch: 49, Train Loss: 0.0007483908826443813, Test Loss: 0.00018072506084216368
Epoch: 50, Iteration: 100, Loss: 0.020223633153364062
Epoch: 50, Iteration: 200, Loss: 0.017343435596558265
Epoch: 50, Iteration: 300, Loss: 0.01868723022926133
Epoch: 50, Iteration: 400, Loss: 0.018490725829906296
Epoch: 50, Iteration: 500, Loss: 0.020391230187669862
Epoch: 50, Iteration: 600, Loss: 0.01670876242860686
Epoch: 50, Iteration: 700, Loss: 0.018780280035571195
Epoch: 50, Iteration: 800, Loss: 0.018485749707906507
Epoch: 50, Train Loss: 0.0007433769343031218, Test Loss: 0.0001

Epoch: 65, Iteration: 300, Loss: 0.017165501005365513
Epoch: 65, Iteration: 400, Loss: 0.015072130532644223
Epoch: 65, Iteration: 500, Loss: 0.016628758690785617
Epoch: 65, Iteration: 600, Loss: 0.01749813689093571
Epoch: 65, Iteration: 700, Loss: 0.016196038341149688
Epoch: 65, Iteration: 800, Loss: 0.01685696837375872
Epoch: 65, Train Loss: 0.00066649339658306, Test Loss: 0.00016035198934507007
Epoch: 66, Iteration: 100, Loss: 0.017576710219145752
Epoch: 66, Iteration: 200, Loss: 0.015363815771706868
Epoch: 66, Iteration: 300, Loss: 0.015159894872340374
Epoch: 66, Iteration: 400, Loss: 0.016038665802625474
Epoch: 66, Iteration: 500, Loss: 0.01768802206061082
Epoch: 66, Iteration: 600, Loss: 0.01721964144235244
Epoch: 66, Iteration: 700, Loss: 0.017881407591630705
Epoch: 66, Iteration: 800, Loss: 0.015834810750675388
Epoch: 66, Train Loss: 0.0006631053846619691, Test Loss: 0.00015522531957602655
Epoch: 67, Iteration: 100, Loss: 0.017456575609685387
Epoch: 67, Iteration: 200, Loss: 0.0

Epoch: 81, Iteration: 400, Loss: 0.015270852491084952
Epoch: 81, Iteration: 500, Loss: 0.016256662616797257
Epoch: 81, Iteration: 600, Loss: 0.014161326449539047
Epoch: 81, Iteration: 700, Loss: 0.014758786724996753
Epoch: 81, Iteration: 800, Loss: 0.014727625824889401
Epoch: 81, Train Loss: 0.0006150761918056082, Test Loss: 0.00016939030755146463
Epoch: 82, Iteration: 100, Loss: 0.015031199596705846
Epoch: 82, Iteration: 200, Loss: 0.01690816315385746
Epoch: 82, Iteration: 300, Loss: 0.014378939977177652
Epoch: 82, Iteration: 400, Loss: 0.014575467292161193
Epoch: 82, Iteration: 500, Loss: 0.015761073354951805
Epoch: 82, Iteration: 600, Loss: 0.0155514779035002
Epoch: 82, Iteration: 700, Loss: 0.013730639093409991
Epoch: 82, Iteration: 800, Loss: 0.016495707390276948
Epoch: 82, Train Loss: 0.0006075214504433868, Test Loss: 0.00014772468907867567
Epoch: 83, Iteration: 100, Loss: 0.01394737531154533
Epoch: 83, Iteration: 200, Loss: 0.016522937970876228
Epoch: 83, Iteration: 300, Loss: 0

Epoch: 97, Iteration: 400, Loss: 0.014482011039945064
Epoch: 97, Iteration: 500, Loss: 0.013140515147824772
Epoch: 97, Iteration: 600, Loss: 0.013397999075095868
Epoch: 97, Iteration: 700, Loss: 0.012265705361642176
Epoch: 97, Iteration: 800, Loss: 0.013985218192829052
Epoch: 97, Train Loss: 0.00054068210474351, Test Loss: 0.00013186830355688508
Epoch: 98, Iteration: 100, Loss: 0.01314552612893749
Epoch: 98, Iteration: 200, Loss: 0.015044760308228433
Epoch: 98, Iteration: 300, Loss: 0.012541441174107604
Epoch: 98, Iteration: 400, Loss: 0.012157595709140878
Epoch: 98, Iteration: 500, Loss: 0.012428040430677356
Epoch: 98, Iteration: 600, Loss: 0.01518566283994005
Epoch: 98, Iteration: 700, Loss: 0.014100105043326039
Epoch: 98, Iteration: 800, Loss: 0.012863520467362832
Epoch: 98, Train Loss: 0.000541578999197708, Test Loss: 0.00013202746717294372
Epoch: 99, Iteration: 100, Loss: 0.01389289947837824
Epoch: 99, Iteration: 200, Loss: 0.013453738974931184
Epoch: 99, Iteration: 300, Loss: 0.0

Epoch: 113, Iteration: 300, Loss: 0.011656517599476501
Epoch: 113, Iteration: 400, Loss: 0.012983818360225996
Epoch: 113, Iteration: 500, Loss: 0.012796100621926598
Epoch: 113, Iteration: 600, Loss: 0.013786638028250309
Epoch: 113, Iteration: 700, Loss: 0.013536218899389496
Epoch: 113, Iteration: 800, Loss: 0.012916660161863547
Epoch: 113, Train Loss: 0.000532154703051512, Test Loss: 0.00012969320418566177
Epoch: 114, Iteration: 100, Loss: 0.013883559968235204
Epoch: 114, Iteration: 200, Loss: 0.01382082091004122
Epoch: 114, Iteration: 300, Loss: 0.013501982197340112
Epoch: 114, Iteration: 400, Loss: 0.013649825177708408
Epoch: 114, Iteration: 500, Loss: 0.012269269082025858
Epoch: 114, Iteration: 600, Loss: 0.01220192057735403
Epoch: 114, Iteration: 700, Loss: 0.014578625046851812
Epoch: 114, Iteration: 800, Loss: 0.012862928881077096
Epoch: 114, Train Loss: 0.00053205822976734, Test Loss: 0.00013046542337851613
Epoch: 115, Iteration: 100, Loss: 0.01411118379837717
Epoch: 115, Iterati

Epoch: 128, Train Loss: 0.0005216788880026805, Test Loss: 0.00012757266848308213
Epoch: 129, Iteration: 100, Loss: 0.01274483920133207
Epoch: 129, Iteration: 200, Loss: 0.013479872803145554
Epoch: 129, Iteration: 300, Loss: 0.012609029705345165
Epoch: 129, Iteration: 400, Loss: 0.012302871087740641
Epoch: 129, Iteration: 500, Loss: 0.012751831327477703
Epoch: 129, Iteration: 600, Loss: 0.011854612159368116
Epoch: 129, Iteration: 700, Loss: 0.014327201421110658
Epoch: 129, Iteration: 800, Loss: 0.01310827118504676
Epoch: 129, Train Loss: 0.0005214557714103009, Test Loss: 0.0001279646643209133
Epoch: 130, Iteration: 100, Loss: 0.0136176140185853
Epoch: 130, Iteration: 200, Loss: 0.01323226081149187
Epoch: 130, Iteration: 300, Loss: 0.013143218420736957
Epoch: 130, Iteration: 400, Loss: 0.013117450427671429
Epoch: 130, Iteration: 500, Loss: 0.01336491726760869
Epoch: 130, Iteration: 600, Loss: 0.012966617148777004
Epoch: 130, Iteration: 700, Loss: 0.01223997716078884
Epoch: 130, Iteration

Epoch: 144, Iteration: 500, Loss: 0.013304330153914634
Epoch: 144, Iteration: 600, Loss: 0.013254876685095951
Epoch: 144, Iteration: 700, Loss: 0.012844438144384185
Epoch: 144, Iteration: 800, Loss: 0.012244326080690371
Epoch: 144, Train Loss: 0.0005212777171490744, Test Loss: 0.00012753724388870705
Epoch: 145, Iteration: 100, Loss: 0.01290389370842604
Epoch: 145, Iteration: 200, Loss: 0.012256680543941911
Epoch: 145, Iteration: 300, Loss: 0.011319936824293109
Epoch: 145, Iteration: 400, Loss: 0.012522560489742318
Epoch: 145, Iteration: 500, Loss: 0.013148885722330306
Epoch: 145, Iteration: 600, Loss: 0.013671619613887742
Epoch: 145, Iteration: 700, Loss: 0.015110422951693181
Epoch: 145, Iteration: 800, Loss: 0.013583983323769644
Epoch: 145, Train Loss: 0.0005212776542901163, Test Loss: 0.0001275697008724703
Epoch: 146, Iteration: 100, Loss: 0.014437292909860844
Epoch: 146, Iteration: 200, Loss: 0.013505887221981538
Epoch: 146, Iteration: 300, Loss: 0.01262833871805924
Epoch: 146, Iter

Epoch: 160, Iteration: 300, Loss: 0.01521415208117105
Epoch: 160, Iteration: 400, Loss: 0.013676128728548065
Epoch: 160, Iteration: 500, Loss: 0.012042864229442785
Epoch: 160, Iteration: 600, Loss: 0.014157570836687228
Epoch: 160, Iteration: 700, Loss: 0.013043976156041026
Epoch: 160, Iteration: 800, Loss: 0.012492770347307669
Epoch: 160, Train Loss: 0.0005212774281041829, Test Loss: 0.0001276905796786698
Epoch: 161, Iteration: 100, Loss: 0.013407214239123277
Epoch: 161, Iteration: 200, Loss: 0.013465807580359979
Epoch: 161, Iteration: 300, Loss: 0.012664420806686394
Epoch: 161, Iteration: 400, Loss: 0.012152014754974516
Epoch: 161, Iteration: 500, Loss: 0.012048799053445691
Epoch: 161, Iteration: 600, Loss: 0.014236483926652
Epoch: 161, Iteration: 700, Loss: 0.013938435760792345
Epoch: 161, Iteration: 800, Loss: 0.011977619342360413
Epoch: 161, Train Loss: 0.000521277618009999, Test Loss: 0.0001275863120498212
Epoch: 162, Iteration: 100, Loss: 0.011755124342016643
Epoch: 162, Iteratio

In [7]:
# Save the models
torch.save(encoder, os.path.join(config["model_dir"], "encoder.pt"))
torch.save(dynamics, os.path.join(config["model_dir"], "dynamics.pt"))
torch.save(decoder, os.path.join(config["model_dir"], "decoder.pt"))

In [8]:
# if log_dir doesn't exist, create it
if not os.path.exists(config["log_dir"]):
    os.makedirs(config["log_dir"])

# Save the losses as a pickle file
with open(os.path.join(config["log_dir"], "losses.pkl"), "wb") as f:
    pickle.dump({"train_losses": train_losses, "test_losses": test_losses}, f)