In [1]:
import torch 
from torch.utils.data import DataLoader
from torch.nn import MSELoss
import numpy as np
import pickle
import math
import time

from sys_id.dataset import load_trajectory, WheeledTrajWindowed, PhysProps
from model import GPT2


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
eval_params = {
        'checkpoint_path': './logs/2024-04-10_22-02-09/checkpoint_epoch_940.pth', 
        'dataset_folder_path': '../dataset/eval_model/wheeled_flat', 
        'window_size': 50,
        'batch_size': 1, 
    }

model_params = {
        "n_layer": 2,
        "n_head": 3,
        "pdrop": 0.1,
        "max_seq_length": 1000,
        'position_encoding': 'sine',
        "output_size": 3,
        "input_size": (42 + 12) * eval_params['window_size'], 
        "hidden_size": (42 + 12) * eval_params['window_size'], 
    }

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT2(**model_params).to(device)
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load(eval_params['checkpoint_path'], map_location=device)
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [5]:
import pickle as pkl 
import os
file_path = os.path.join(eval_params["dataset_folder_path"], "traj_0000.pkl")
with open(file_path, "rb") as file:
    traj = pkl.load(file)

print(traj.keys())

dict_keys(['obs', 'act', 'physprops'])


In [10]:
obs_history = np.array(traj['obs'][:50]).flatten()
action_history = np.array(traj['act'][:50]).flatten()
tmp_history = np.concatenate([obs_history, action_history])
history_input = torch.tensor(tmp_history, device = device, dtype=torch.float).unsqueeze(0)
print(history_input.shape)

torch.Size([1, 2700])


In [11]:
model.eval()
estimation, _ = model(history_input, None)

In [7]:
history_input.shape

torch.Size([540])

In [12]:
estimation.shape

torch.Size([3])

In [13]:
estimation

tensor([2.7634, 2.2557, 1.2845], device='cuda:0', grad_fn=<SqueezeBackward0>)

In [19]:
traj['physprops'][50]

[4.388001441955566, 5.940247535705566, 1.285484790802002]