In [1]:
import os 
import time
import torch 
import torch.nn as nn 
import numpy as np
import matplotlib.pyplot as plt 

from train_mlp_ae import * 
from trainer import AE_Trainer
from mlp_ae import MLPVAE, MLP_VQVAE, vae_loss, vqvae_loss

In [5]:
# ---------------------------------------------------------------------------
# Data configuration
# ---------------------------------------------------------------------------
# NOTE(review): machine-specific absolute path; consider a configurable
# DATA_DIR so the notebook runs on other hosts.
src_dir = "/home/horowitz3/latent-diffusion_project/datasets/three_piece_assembly_latent_actions/buf.pkl"
n_test = 2
n_val = 10
ac_chunk = 30
obs_dim = [11, 1]
ac_dim = [30, 7]
batch_size = 50
num_workers = 10


def _make_action_dataset(mode):
    """Build the ``LatentActionBuffer`` for one split.

    Args:
        mode: Split selector forwarded to ``LatentActionBuffer`` -- one of
            ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        A ``LatentActionBuffer`` built from ``src_dir`` with the shared
        ``n_test`` / ``n_val`` / ``obs_dim`` / ``ac_chunk`` / ``ac_dim``
        settings defined above.
    """
    return LatentActionBuffer(src_dir,
                              n_test,
                              n_val,
                              mode=mode,
                              obs_dim=obs_dim,
                              ac_chunk=ac_chunk,
                              ac_dim=ac_dim)


def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a ``DataLoader`` using the shared batch settings.

    Args:
        dataset: Dataset to iterate over.
        shuffle: Whether to reshuffle the samples every epoch.

    Returns:
        A ``DataLoader`` with ``batch_size`` and ``num_workers`` taken from
        the configuration above.
    """
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      num_workers=num_workers)


# ---------------------------------------------------------------------------
# Datasets and loaders (one per split)
# ---------------------------------------------------------------------------
action_dataset = _make_action_dataset("train")
val_action_dataset = _make_action_dataset("val")
test_action_dataset = _make_action_dataset("test")

# Only the training loader is shuffled. Evaluation splits are iterated in a
# fixed order so validation/test batches are the same from run to run.
train_loader = _make_loader(action_dataset, shuffle=True)
val_loader = _make_loader(val_action_dataset, shuffle=False)
test_loader = _make_loader(test_action_dataset, shuffle=False)

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
# obs_dim / ac_dim are reused from the data configuration above so the model
# and the datasets cannot silently disagree.
model_type = "MLPVQVAE"
latent_dim = 55
hidden = [220, 110, 55]
n_embeddings = 256
model = MLP_VQVAE(obs_dim, ac_dim, latent_dim, hidden, n_embeddings=n_embeddings)
loss_fn = vqvae_loss

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
# NOTE(review): save_dir is not passed to the trainer below, which receives
# the literal "tuned_mlp" instead -- confirm which name is intended.
save_dir = "lr_sweep"
optim_params = {"log_freq": 100,
                "save_freq": 1000,
                "epochs": 200,
                "lr": 0.0001}

trainer = AE_Trainer(model,
                     loss_fn,
                     "cuda:0",
                     optim_params,
                     "tuned_mlp")
trainer.train_loop(train_loader, val_loader)

100
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]


100%|██████████| 100/100 [00:00<00:00, 136.30it/s]


100
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]


100%|██████████| 100/100 [00:00<00:00, 199.99it/s]


100
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]


100%|██████████| 100/100 [00:00<00:00, 195.10it/s]


{'log_freq': 100, 'save_freq': 1000, 'epochs': 200, 'lr': 0.0001}
Logged loss: 0.007097862958908081
Logged loss: 0.27896030008792877
Logged loss: 0.15153389543294907
Logged loss: 0.0921760156005621
Logged loss: 0.0851698362082243
Logged loss: 0.08359148748219013
Logged loss: 0.08214808374643326
Logged loss: 0.08129057176411152
Logged loss: 0.07939649477601052
Logged loss: 0.07763247437775135
Logged loss: 0.07868045561015606
Logged loss: 0.07751094229519367
Logged loss: 0.07603641849011183
Logged loss: 0.0766664882004261
Logged loss: 0.07804643921554089
Logged loss: 0.07941940754652023
Logged loss: 0.07772727355360985
Logged loss: 0.08015279285609722
Logged loss: 0.0816723557189107
Logged loss: 0.0816355885565281
Logged loss: 0.08301253162324429
Logged loss: 0.08682212710380555
Logged loss: 0.0798247543349862
Logged loss: 0.08045720353722573
Logged loss: 0.08028021059930325
Logged loss: 0.08143622417002916
Logged loss: 0.07843689642846584
Logged loss: 0.07990831062197686
Logged loss: 0.

In [7]:
# ---------------------------------------------------------------------------
# Evaluate the saved VQ-VAE checkpoint on the test split
# ---------------------------------------------------------------------------
# Use the GPU when present, otherwise fall back to the CPU so the evaluation
# can still run on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed iteration order: shuffling adds nothing to an evaluation pass and
# makes the per-batch statistics vary between runs.
test_loader = DataLoader(test_action_dataset,
                         batch_size=10,
                         shuffle=False,
                         num_workers=10)

# Architecture hyperparameters; they have to match the ones the checkpoint
# was trained with, otherwise load_state_dict() fails on shape mismatches.
model_type = "MLPVQVAE"
obs_dim = [11, 1]
ac_dim = [30, 7]
latent_dim = 55
hidden = [220, 110, 55]
# Declared here explicitly instead of relying on the value leaking in from
# the training cell's kernel state.
n_embeddings = 256
model = MLP_VQVAE(obs_dim, ac_dim, latent_dim, hidden, n_embeddings=n_embeddings)

# map_location lets a checkpoint written from a CUDA device be restored on
# whatever device is available here.
# NOTE(review): machine-specific absolute path; consider making it configurable.
state_dict = torch.load("/home/horowitz3/dit-policy/latent_actions/tuned_mlp/checkpoint.pth",
                        map_location=device)
model.load_state_dict(state_dict["model_state_dict"])
model.eval().to(device)

criterion = nn.MSELoss()
test_losses = []
with torch.no_grad():
    for batch in test_loader:
        states = batch[0].to(device)
        # Merge dims 1 and 2 of the action tensor into a single axis per sample.
        actions = batch[1].flatten(start_dim=1, end_dim=2).to(device)

        # The reconstruction target is the state concatenated with the
        # flattened actions along dim 1.
        gt = torch.cat([states, actions], dim=1)
        pred, z_q, z_e, indices = model(states, actions)
        test_losses.append(criterion(pred, gt).item())

# Statistics are taken over the per-batch MSE values (one entry per batch).
losses = np.array(test_losses)
avg_loss = losses.mean()
std_loss = losses.std()

print(f"Evaluation complete. Average Reconstruction Loss (MSE): {avg_loss:.4f} (Std Dev: {std_loss:.4f})")

Evaluation complete. Average Reconstruction Loss (MSE): 0.0742 (Std Dev: 0.0211)
