In [None]:
import torch
from typing import Any
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from Traj2Dataset import TrajDataset, DatasetTransform
from torch.utils.data import DataLoader, SubsetRandomSampler
import pytorch_lightning as pl
from pytorch_lightning import loggers
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning import seed_everything
from model import LitMLP as Net

In [None]:
# --- Data source -------------------------------------------------------------
root_dir = 'dataset'                # directory passed to TrajDataset
system = 'great-piquant-bumblebee'  # system identifier passed to TrajDataset

# --- Training / data-loading hyper-parameters --------------------------------
learning_rate = 0.001
batch_size = 32
max_epochs = 100
validation_split = 0.2  # fraction of samples intended for validation
shuffle = True          # shuffle sample indices before splitting
num_workers = 4

# --- Logging & reproducibility -----------------------------------------------
logdir = './logs/'
SEED = 42
seed_everything(SEED)  # seed every RNG Lightning knows about, before any sampling

In [None]:

train_dataset = TrajDataset(system, root_dir)

state_dim = len(train_dataset.states)
action_dim = len(train_dataset.actions)

# Per-feature normalization statistics: each entry of `.states` / `.actions`
# is a dict carrying 'mean' and 'std' for one feature.
mean = [x['mean'] for x in train_dataset.states]
std = [x['std'] for x in train_dataset.states]
transform = DatasetTransform(mean, std)

target_mean = [x['mean'] for x in train_dataset.actions]
target_std = [x['std'] for x in train_dataset.actions]
target_transform = DatasetTransform(target_mean, target_std)

# The statistics come from the dataset itself, so the transforms can only be
# attached after construction. The attribute must be spelled `transform` (the
# same name as the constructor keyword used for the test split below); a
# misspelled attribute would just be ignored and leave inputs un-normalized.
train_dataset.transform = transform
train_dataset.target_transform = target_transform

# Test split reuses the *training* statistics so both splits share one scale.
test_dataset = TrajDataset(system, root_dir, train=False,
                           transform=transform,
                           target_transform=target_transform)

# Sample indices for a later train/validation split; shuffled in place using
# NumPy's global RNG (seeded via seed_everything in the config cell).
indices = np.arange(len(train_dataset))
if shuffle:
    np.random.shuffle(indices)

In [None]:
# Load the trained submission checkpoint for evaluation. This is a plain
# notebook cell (there is no `self` here), so bind the result to `model`,
# which the evaluation cell below uses.
path = f"./submission/models/{system}-v1.ckpt"
# NOTE(review): in_dims/out_dims must match the architecture the checkpoint was
# trained with; confirm whether they should track state_dim / action_dim.
# `.eval()` switches to inference mode (affects layers such as dropout or
# batch-norm, if the network has any) and returns the module itself.
model = Net.load_from_checkpoint(path, in_dims=10, out_dims=2).eval()

In [None]:
def upscale(x):
    """Map normalized action values back to their original scale.

    Computes ``x * target_std + target_mean`` element-wise. The per-action
    statistics broadcast over the last dimension of ``x``.
    NOTE(review): this is the inverse of ``target_transform`` only if
    DatasetTransform standardizes as ``(x - mean) / std``; confirm.

    Args:
        x: tensor whose last dimension has one entry per action.

    Returns:
        Tensor of the same shape as ``x``, in the original action units.
    """
    return x * torch.Tensor(target_std) + torch.Tensor(target_mean)


test_dataloader = DataLoader(test_dataset,
                             batch_size=1, shuffle=False,
                             num_workers=0)

# Predictions must come from inference mode, and no autograd graph is needed.
# no_grad also keeps `grad_fn=...` out of the printed prediction tensors.
model.eval()
with torch.no_grad():
    for i, (x, y) in enumerate(test_dataloader):
        print(f'Time-step: {i}')
        print(f'Target: {upscale(y)}')
        print(f'Prediction: {upscale(model(x))}')

In [None]:
def test():
    """Smoke-test the submission controller on random inputs.

    Builds a ``controller`` for the 'great-piquant-bumblebee' system with a
    2-dimensional control. It then prints what ``get_input`` returns for three
    random column vectors of 8, 2 and 2 rows.
    """
    ctrl = controller(system='great-piquant-bumblebee', d_control=2)
    # input provided as a vector of shape (X,1); drawn in argument order
    random_inputs = [np.random.randn(rows, 1) for rows in (8, 2, 2)]
    result = ctrl.get_input(*random_inputs)
    print(result)

In [None]:
from matplotlib import pyplot as plt

# Reload the test split with no transform/target_transform, so values are
# presumably in raw (un-normalized) units. Use a distinct name: reassigning
# `test_dataset` would silently replace the normalized dataset the model
# evaluation cell depends on.
raw_test_dataset = TrajDataset(system, root_dir, train=False)
print(raw_test_dataset.states)

# first two states are X,Y for end-effectors
targets = np.array([x[0:2] for x, _ in raw_test_dataset])

fig, ax = plt.subplots()
ax.plot(targets[:, 0], targets[:, 1])
ax.set(title="Test set: end-effector path", xlabel="X", ylabel="Y")
plt.show()

In [None]:
TRAJ_LEN = 200  # each trajectory is a contiguous slice of 200 points in the dataset
traj_ID = 3

# Raw (no transform) training data under its own name, so the normalized
# `train_dataset` prepared earlier is not overwritten.
raw_train_dataset = TrajDataset(system, root_dir, train=True)

# Index the requested trajectory directly instead of loading every sample.
start = TRAJ_LEN * traj_ID
stop = min(TRAJ_LEN * (traj_ID + 1), len(raw_train_dataset))
assert start < stop, (
    f"trajectory {traj_ID} is out of range for a dataset of "
    f"{len(raw_train_dataset)} samples"
)

# first two states are X,Y for end-effectors
targets = np.array([raw_train_dataset[j][0][:2] for j in range(start, stop)])

fig, ax = plt.subplots()
ax.plot(targets[:, 0], targets[:, 1])
ax.set(title=f"Train trajectory {traj_ID}: end-effector path", xlabel="X", ylabel="Y")
plt.show()