In [None]:
import torch
from typing import Any
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from Traj2Dataset import TrajDataset,DatasetTransform
from torch.utils.data import DataLoader,SubsetRandomSampler
import pytorch_lightning as pl
from pytorch_lightning import loggers
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning import seed_everything


In [None]:
class Net(nn.Module):
    """Small fully connected regressor: two ReLU hidden layers and a linear output.

    Args:
        in_dims: size of the last dimension of the input tensor.
        out_dims: size of the last dimension of the output tensor.
        hidden_dims: width of both hidden layers (default 16).
    """

    def __init__(self, in_dims, out_dims, hidden_dims=16):
        super().__init__()
        # A single `layers` Sequential with Linear modules at indices 0/2/4 keeps
        # the state_dict keys (layers.0.*, layers.2.*, layers.4.*) stable for checkpoints.
        self.layers = nn.Sequential(
            nn.Linear(in_dims, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, out_dims),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.layers(x)


In [None]:
class LitMLP(pl.LightningModule):
    """Lightning wrapper that trains a ``Net`` regressor with an MSE loss and Adam.

    Args:
        in_dims: number of input features passed to ``Net``.
        out_dims: number of output features produced by ``Net``.
        lr: learning rate for the Adam optimizer.
        *args, **kwargs: forwarded unchanged to ``pl.LightningModule.__init__``.
    """

    def __init__(
        self,
        in_dims: int,
        out_dims: int,
        lr: float = 1e-3,
        *args: Any,
        **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)

        self.model = Net(in_dims, out_dims)
        self.lr = lr

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def _shared_step(self, batch) -> torch.Tensor:
        """Return the MSE loss between the prediction for ``x`` and target ``y`` of one batch."""
        x, y = batch
        # Calling self(x) goes through nn.Module.__call__, so registered hooks run.
        y_hat = self(x)
        return F.mse_loss(y_hat, y)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        # log() takes (name, value); a mapping of metrics would need log_dict() instead.
        self.log('loss', loss, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('val_loss', loss, logger=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

### HYPERPARAMETERS

In [None]:
# --- data source: which recorded system to load, and from where ---
system = 'great-piquant-bumblebee'
root_dir = 'dataset'

# --- train/validation split ---
shuffle = True            # shuffle sample indices before splitting
validation_split = 0.2    # fraction of samples held out for validation

# --- optimisation ---
batch_size = 32
learning_rate = 1e-3
max_epochs = 100

# --- runtime / bookkeeping ---
num_workers = 4           # DataLoader worker processes
logdir = './logs/'        # TensorBoard log directory
SEED = 42

# Lightning's global seeding helper, so the index shuffle and weight init are repeatable.
seed_everything(SEED)

### TRAINING AND VALIDATION DATASET PREPARATION

In [None]:
dataset = TrajDataset(system, root_dir)

state_dim = len(dataset.states)
action_dim = len(dataset.actions)

# Per-feature normalisation statistics; each entry of dataset.states / dataset.actions
# is indexed with 'mean' and 'std' keys.
mean = [x['mean'] for x in dataset.states]
std = [x['std'] for x in dataset.states]
transform = DatasetTransform(mean, std)

target_mean = [x['mean'] for x in dataset.actions]
target_std = [x['std'] for x in dataset.actions]
target_transform = DatasetTransform(target_mean, target_std)

# The input transform must be assigned to `transform` (mirroring `target_transform`);
# assumes TrajDataset applies self.transform / self.target_transform in __getitem__ -- TODO confirm.
dataset.transform = transform
dataset.target_transform = target_transform

# Hold out the first `validation_split` fraction of the (optionally shuffled) indices.
indices = np.arange(len(dataset))
if shuffle:
    np.random.shuffle(indices)

split = int(validation_split * len(dataset))
train_indices = indices[split:]
valid_indices = indices[:split]
assert len(train_indices) > 0, "validation_split leaves no training samples"

# Each sampler draws only from its own subset, in random order every epoch.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)

In [None]:
# No callbacks yet; e.g. EarlyStopping(monitor='val_loss') could be appended here.
callbacks = []

# TODO: may switch to a wandb logger later.
tb_logger = loggers.TensorBoardLogger(logdir)

# Both loaders wrap the same dataset; the samplers restrict each one to its split
# (DataLoader does not allow shuffle=True together with a sampler).
train_dataloader = DataLoader(dataset,
            batch_size=batch_size, num_workers=num_workers,
            sampler=train_sampler)
valid_dataloader = DataLoader(dataset,
            batch_size=batch_size, num_workers=num_workers,
            sampler=valid_sampler)

model = LitMLP(state_dim, action_dim, lr=learning_rate)

# Use one GPU when available, otherwise fall back to CPU instead of failing.
# NOTE(review): `gpus` and `progress_bar_refresh_rate` are older Lightning Trainer
# arguments (removed in newer releases in favour of accelerator/devices) -- pin the version.
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    logger=tb_logger,
    callbacks=callbacks,
    progress_bar_refresh_rate=10,
    max_epochs=max_epochs,
)

trainer.fit(model, train_dataloader, valid_dataloader)


In [None]:
# Spot-check a single sample: prediction vs. target. Both are in whatever space the
# dataset's target_transform produces (presumably normalised -- confirm in TrajDataset).
sample_idx = 24
x, y = dataset[sample_idx]

model.eval()
with torch.no_grad():
    # Match the model's device/dtype; after training the model may live on the GPU.
    x = torch.as_tensor(x, dtype=torch.float32, device=model.device)
    y_hat = model(x)

print(y_hat, y)