### Notebook for experimenting with model architecture and hyperparameters
GitHub repository link: [SambhawDrag/LBD-neurIPS-2021](https://github.com/SambhawDrag/LBD-neurIPS-2021)

In [None]:
import torch
from typing import Any
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from Traj2Dataset import TrajDataset, DatasetTransform
from torch.utils.data import DataLoader, SubsetRandomSampler
import pytorch_lightning as pl
from pytorch_lightning import loggers
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning import seed_everything

In [None]:
class Net(nn.Module):
    """Plain fully connected network with ReLU activations.

    The layer widths are ``in_dims -> 16 -> 32 -> 64 -> 32 -> out_dims``.
    Every linear layer except the last one is followed by a ReLU, and the
    whole stack is stored as a single ``nn.Sequential`` in ``self.layers``.

    Args:
        in_dims: Size of the last dimension of the input.
        out_dims: Size of the last dimension of the output.
    """

    def __init__(self, in_dims, out_dims):
        super().__init__()

        # Widths of the input and hidden layers; the output layer is
        # appended separately because it has no activation after it.
        widths = [in_dims, 16, 32, 64, 32]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.append(nn.Linear(widths[-1], out_dims))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.layers(x)


In [None]:
class LitMLP(pl.LightningModule):
    """LightningModule wrapper that trains a ``Net`` with an MSE loss.

    Args:
        in_dims: Input width passed to ``Net``.
        out_dims: Output width passed to ``Net``.
        lr: Learning rate handed to the Adam optimizer.
        *args: Forwarded unchanged to ``pl.LightningModule.__init__``.
        **kwargs: Forwarded unchanged to ``pl.LightningModule.__init__``.
    """

    def __init__(
        self,
        in_dims: int,
        out_dims: int,
        lr: float = 1e-3,
        *args: Any,
        **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)

        self.lr = lr
        self.model = Net(in_dims, out_dims)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped network's output for ``x``."""
        return self.model(x)

    def _mse(self, batch):
        """Unpack an ``(input, target)`` batch and return the MSE of the prediction."""
        inputs, expected = batch
        predicted = self.forward(inputs)
        return F.mse_loss(predicted, expected)

    def training_step(self, batch, batch_idx):
        """Compute the batch MSE, log it as ``train_loss`` and return it."""
        loss = self._mse(batch)
        self.log('train_loss', loss, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the batch MSE, log it as ``val_loss`` and return it."""
        loss = self._mse(batch)
        self.log('val_loss', loss, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute the batch MSE, log it as ``test_loss`` and return it."""
        loss = self._mse(batch)
        self.log('test_loss', loss, logger=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped network's parameters."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        return optimizer


### HYPERPARAMETERS

In [None]:
# --- data ---------------------------------------------------------------
root_dir = 'dataset'
system = 'great-wine-beetle'
validation_split = 0.2  # fraction of the training indices held out for validation
shuffle = True          # shuffle the indices before the train/validation split
batch_size = 256
num_workers = 4

# --- optimisation -------------------------------------------------------
learning_rate = 5e-4
max_epochs = 200

# --- logging / reproducibility ------------------------------------------
logdir = './logs/'
SEED = 42

# Seed once, after all settings are defined and before any random work.
seed_everything(SEED)

In [None]:
class TargetCentreTransform:
    """Callable transform that subtracts the first two entries from the last two.

    The class name suggests this re-centres a target on the first two
    coordinates; what the entries physically mean is not visible here, so
    confirm it against the dataset layout.

    The input array is left untouched and the result is returned as a new
    array. Modifying the input in place would change the caller's data, so
    applying the transform to the same sample twice would shift it twice.
    """

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return a copy of ``x`` with ``x[:2]`` subtracted from ``x[-2:]``.

        Args:
            x: Array indexed along its first axis.

        Returns:
            A new array of the same shape and dtype as ``x``.
        """
        centred = x.copy()
        centred[-2:] -= centred[:2]
        return centred

### TRAINING AND VALIDATION DATASET PREPARATION

In [None]:
# Normalisation statistics for the chosen system: one row per dimension,
# with column 0 read as the mean and column 1 as the standard deviation.
state_norms, action_norms = TrajDataset.norms(system)

state_dim = len(state_norms)
action_dim = len(action_norms)

mean = state_norms[:, 0]  # mean
std = state_norms[:, 1]  # std_dev
transform = DatasetTransform(mean, std)

target_mean = action_norms[:, 0]  # mean
target_std = action_norms[:, 1]  # std_dev
target_transform = DatasetTransform(target_mean, target_std)

# The train and test splits share the same input and target transforms.
train_dataset = TrajDataset(system, root_dir, train=True,
                            transform=transform, target_transform=target_transform)

test_dataset = TrajDataset(system, root_dir, train=False,
                           transform=transform, target_transform=target_transform)

# Split the training indices into a validation part (the first `split`
# entries) and a training part (the rest), optionally shuffling first.
indices = np.arange(len(train_dataset))

# Plain truthiness test rather than an identity comparison with True.
if shuffle:
    np.random.shuffle(indices)

split = int(validation_split * len(train_dataset))
train_indices = indices[split:]
valid_indices = indices[:split]

# Each sampler draws only from its own subset of train_dataset.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)

# NOTE: logging may move to wandb later.


In [None]:
# No callbacks are registered for this run.
callbacks = []

# TensorBoard logs go under `logdir`, in the experiment named 'v3'.
tb_logger = loggers.TensorBoardLogger(logdir, name='v3')

# Both loaders read from train_dataset; the sampler decides which
# indices each loader draws.
train_dataloader = DataLoader(
    train_dataset,
    sampler=train_sampler,
    batch_size=batch_size,
    num_workers=num_workers,
)
valid_dataloader = DataLoader(
    train_dataset,
    sampler=valid_sampler,
    batch_size=batch_size,
    num_workers=num_workers,
)

model = LitMLP(state_dim, action_dim, learning_rate)

trainer = pl.Trainer(
    max_epochs=max_epochs,
    callbacks=callbacks,
    logger=tb_logger,
    gpus=1,
    progress_bar_refresh_rate=60,
)

trainer.fit(model, train_dataloader, valid_dataloader)

In [None]:
def upscale(x):
    """Undo the target normalisation: multiply by the std and add the mean.

    Reads the module-level ``target_std`` and ``target_mean`` on every call
    and converts them with ``torch.Tensor`` before applying them to ``x``.
    """
    scale = torch.Tensor(target_std)
    offset = torch.Tensor(target_mean)
    return x * scale + offset

# One sample per batch, in dataset order, loaded in the main process.
test_dataloader = DataLoader(test_dataset,
                             batch_size=1, shuffle=False,
                             num_workers=0)
logs = trainer.test(model, test_dataloader, verbose=False)
# print(logs)

# The predictions below are only printed, never back-propagated, so
# gradient tracking is disabled: no autograd graph is built per sample and
# the printed tensors carry no grad_fn suffix.
with torch.no_grad():
    for i, (x, y) in enumerate(test_dataloader):
        print(f'Time-step: {i}')
        print(f'Target: {upscale(y)}')
        print(f'Prediction: {upscale(model(x))}')

In [None]:
from matplotlib import pyplot as plt

# Re-create the test split, this time without passing transform or
# target_transform, and show its stored states.
test_dataset = TrajDataset(system, root_dir, train=False)
print(test_dataset.states)

# first two states are X,Y for end-effectors
targets = np.array([x[0:2] for x, _ in test_dataset])
plt.plot(targets[:, 0], targets[:, 1])

In [None]:
from itertools import islice

# Re-create the training split without passing transform or target_transform.
train_dataset = TrajDataset(system, root_dir, train=True)

# first two states are X,Y for end-effectors
traj_ID = 3
# each trajectory is a slice of 200 points in dataset
traj_len = 200
traj_start = traj_len * traj_ID

# islice yields the samples at positions [traj_start, traj_start + traj_len)
# and then stops, so the rest of the dataset is not walked and no per-item
# membership test is needed.
targets = np.array([
    x[:2]
    for x, _ in islice(train_dataset, traj_start, traj_start + traj_len)
])
plt.plot(targets[:, 0], targets[:, 1])