In [1]:
# Automatically reload edited local modules (e.g. data_utils, models, train_utils)
# before executing code, so changes to those files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from data_utils import split_data_to_traj_and_control, mat2tracks
import wandb
import scipy.io
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [3]:
# Load the SDRE training and validation exports produced by MATLAB.
# Note that each .mat file stores its data under a different variable name
# ("dataset" for training, "sdreVal" for validation).
train_mat_path = "./matlab/sdreDataset.mat"
val_mat_path = "./matlab/sdreVal.mat"
reshape = True  # forwarded unchanged to mat2tracks for both splits

train = scipy.io.loadmat(train_mat_path)["dataset"]
val = scipy.io.loadmat(val_mat_path)["sdreVal"]

# mat2tracks presumably yields one array per track (TODO confirm); stack them
# row-wise so each split becomes a single 2-D array.
train_tracks = np.vstack(mat2tracks(train, reshape=reshape))
val_tracks = np.vstack(mat2tracks(val, reshape=reshape))

In [4]:
# Turn each split's stacked track array into a dataset via the project helper
# (the result is passed to DataLoader in the next cell).
# NOTE(review): `test_dataset` is built from the *validation* .mat file; the name is
# kept because the loader cell refers to it.
train_dataset = split_data_to_traj_and_control(train_tracks)
test_dataset = split_data_to_traj_and_control(val_tracks)

# Sanity check: number of samples in each dataset the loaders will wrap
# (compare datasets with datasets, not a dataset with the raw stacked array).
len(train_dataset), len(test_dataset)

(20200, 10100)

In [5]:
# Mini-batch loaders sharing one batch-size setting.
# - train: reshuffled every epoch; the last incomplete batch is dropped.
# - test: fixed order, every sample kept (needed for a stable validation loss).
batch_size = 64

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True)

test_loader = DataLoader(test_dataset,
                         batch_size=batch_size,
                         shuffle=False)

In [18]:
from models import DummyModel2
from train_utils import train_epoch, eval_epoch
from models import SplittedModel, SplittedModel2

In [22]:
# Seed torch's default RNG so weight initialisation is reproducible. DataLoader
# shuffling also draws from this default generator when no `generator` is passed,
# so training-batch order becomes repeatable too.
seed = 0
torch.manual_seed(seed)

hidden_dim_1 = 16
hidden_dim_2 = 16
# NOTE(review): the saved repr of this model shows 32-unit hidden layers although
# both widths are 16 here — confirm how SplittedModel2 uses these arguments
# (or whether that output is stale).
model = SplittedModel2(hidden_dim_1=hidden_dim_1, hidden_dim_2=hidden_dim_2)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
criteria = nn.MSELoss()
# lr=1e-3 is Adam's default; spelled out so the value is visible and can be logged.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model

SplittedModel2(
  (block1): SplittedModelBlock2(
    (fc1): Linear(in_features=2, out_features=32, bias=True)
    (fc2): Linear(in_features=32, out_features=32, bias=True)
    (fc3): Linear(in_features=32, out_features=1, bias=True)
    (act): ReLU()
  )
  (block2): SplittedModelBlock2(
    (fc1): Linear(in_features=2, out_features=32, bias=True)
    (fc2): Linear(in_features=32, out_features=32, bias=True)
    (fc3): Linear(in_features=32, out_features=1, bias=True)
    (act): ReLU()
  )
  (block3): SplittedModelBlock2(
    (fc1): Linear(in_features=2, out_features=32, bias=True)
    (fc2): Linear(in_features=32, out_features=32, bias=True)
    (fc3): Linear(in_features=32, out_features=1, bias=True)
    (act): ReLU()
  )
)

In [23]:
# Toggle Weights & Biases tracking. (Variable name kept as-is: the training cell reads it.)
wandb_loggging = True

if wandb_loggging:
    project_name = "SDRE_Approx"
    run_name = f"Orig split; Splitted Model, hidden_dim={hidden_dim_1},{hidden_dim_2}"
    # Store the hyperparameters with the run so it can be reproduced and compared
    # later; values are read back from the objects built in earlier cells.
    run_config = {
        "model": type(model).__name__,
        "hidden_dim_1": hidden_dim_1,
        "hidden_dim_2": hidden_dim_2,
        "batch_size": train_loader.batch_size,
        "optimizer": type(optimizer).__name__,
        "lr": optimizer.defaults["lr"],
        "loss": type(criteria).__name__,
    }
    wandb.login()
    wandb.init(project=project_name,
               name=run_name,
               config=run_config)



0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val_loss,█▅▅▂▂▂▂▂▄▁▂▂▂▂▂▂▂▂▂▁▁▁▂▁▂▂▁▁▁▁▂▁▁▁▂▃▁▁▁▂

0,1
epoch,149.0
val_loss,0.00823


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

In [24]:
# Train for a fixed number of epochs and track the best validation loss.
# Whenever the validation loss improves, the current weights are written to
# `save_path` (overwriting the previous best) unless saving is disabled.
n_epochs = 150
save_best_checkpoint = True  # set False to train without touching the disk
save_path = "splitted_model_best.pth"

best_loss = float("inf")
best_epoch = None

for epoch in range(n_epochs):
    train_epoch(model, device, train_loader, criteria, optimizer)
    # eval_epoch appears to return a 0-dim tensor (it printed as `tensor(...)`);
    # convert to a Python float for comparison, printing and logging.
    val_loss = float(eval_epoch(model, device, test_loader, criteria))

    if val_loss < best_loss:
        best_loss = val_loss
        best_epoch = epoch
        if save_best_checkpoint:
            torch.save(model.state_dict(), save_path)
        print(f"Improved eval loss on epoch {epoch}: {best_loss:.4f}")

    if wandb_loggging:
        wandb.log({
            "val_loss": val_loss,
            "epoch": epoch,
        })

best_epoch, best_loss

Improve eval losss on epoch 0 =  tensor(0.0106)
Improve eval losss on epoch 1 =  tensor(0.0076)
Improve eval losss on epoch 2 =  tensor(0.0067)
Improve eval losss on epoch 3 =  tensor(0.0064)
Improve eval losss on epoch 4 =  tensor(0.0062)
Improve eval losss on epoch 5 =  tensor(0.0060)
Improve eval losss on epoch 7 =  tensor(0.0060)
Improve eval losss on epoch 9 =  tensor(0.0058)
Improve eval losss on epoch 12 =  tensor(0.0057)
Improve eval losss on epoch 20 =  tensor(0.0057)
Improve eval losss on epoch 21 =  tensor(0.0057)
Improve eval losss on epoch 28 =  tensor(0.0056)
Improve eval losss on epoch 43 =  tensor(0.0056)
Improve eval losss on epoch 47 =  tensor(0.0056)
Improve eval losss on epoch 54 =  tensor(0.0056)
Improve eval losss on epoch 87 =  tensor(0.0056)
Improve eval losss on epoch 93 =  tensor(0.0056)
Improve eval losss on epoch 108 =  tensor(0.0056)
Improve eval losss on epoch 127 =  tensor(0.0055)
