In [1]:
import torch
from typing import Any
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from Traj2Dataset import TrajDataset,DatasetTransform
from torch.utils.data import DataLoader,SubsetRandomSampler
import pytorch_lightning as pl
from pytorch_lightning import loggers
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning import seed_everything


In [14]:
class Net(nn.Module):
    def __init__(self, in_dims, out_dims):
        super(Net, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dims,16),
            nn.ReLU(),
            nn.Linear(16,16),
            nn.ReLU(),
            nn.Linear(16, out_dims)
        )

    def forward(self, x):
        return self.layers(x)


In [3]:
class LitMLP(pl.LightningModule):
    def __init__(
        self,
        in_dims: int,
        out_dims: int,
        lr: float = 1e-3,
        *args: Any,
        **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)
        
        self.model = Net(in_dims, out_dims)
        self.lr = lr
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
    
    def training_step(self,batch,batch_idx):
        x,y = batch
        y_hat = self.forward(x)
        loss = F.mse_loss(y_hat, y)
        self.log('train_loss',loss,logger = True)
        return loss

    def validation_step(self,batch,batch_idx):
        x,y = batch
        y_hat = self.forward(x)
        loss = F.mse_loss(y_hat, y)
        self.log('val_loss', loss,logger = True)
        return loss


    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(),lr = self.lr)

### HYPERPARAMETERS

In [4]:
root_dir = 'dataset'
system = 'great-piquant-bumblebee'
validation_split = 0.2
learning_rate = 1e-3
batch_size = 32
max_epochs = 100
shuffle = True
SEED = 42
logdir = './logs/'
num_workers = 4

seed_everything(SEED)

Global seed set to 42


42

### TRAINING AND VALIDATION DATASET PREPARATION

In [5]:
dataset = TrajDataset(system, root_dir)

state_dim = len(dataset.states)
action_dim = len(dataset.actions)

mean = [x['mean'] for x in dataset.states]  # mean
std = [x['std'] for x in dataset.states]   # std_dev
transform = DatasetTransform(mean, std)

target_mean = [x['mean'] for x in dataset.actions]  # mean
target_std = [x['std'] for x in dataset.actions]   # std_dev
target_transform = DatasetTransform(target_mean, target_std)

dataset.tranform = transform
dataset.target_transform = target_transform


indices = np.arange(len(dataset))

if shuffle is True:
    np.random.shuffle(indices)

split = int(validation_split * len(dataset))
train_indices = indices[split:]
valid_indices = indices[:split]

  # may use wandb later
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)

In [6]:
callbacks = []

tb_logger = loggers.TensorBoardLogger(logdir)

train_dataloader = DataLoader(dataset,
            batch_size=batch_size,num_workers=num_workers,
            sampler=train_sampler)
valid_dataloader = DataLoader(dataset,
            batch_size=batch_size,num_workers=num_workers,
            sampler=valid_sampler)

model = LitMLP(state_dim, action_dim,learning_rate)
trainer = pl.Trainer(
    gpus=1, logger=tb_logger,
    callbacks=callbacks,
    progress_bar_refresh_rate=60,
    max_epochs= max_epochs,)

trainer.fit(model, train_dataloader, valid_dataloader)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 482   
-------------------------------
482       Trainable params
0         Non-trainable params
482       Total params
0.002     Total estimated model params size (MB)


                                                              

Global seed set to 42


Epoch 99: 100%|██████████| 125/125 [00:01<00:00, 90.70it/s, loss=2.31, v_num=11]


In [18]:
x,y = dataset[0]
x = torch.Tensor(x)
print(model(x),y)

tensor([-5.6737,  3.7289], grad_fn=<AddBackward0>) [-21.243172  15.410698]


In [15]:
model_checkpoint = LitMLP.load_from_checkpoint("./logs/default/version_11/checkpoints/epoch=99-step=9999.ckpt",
                                               in_dims=10, out_dims=2, lr=1e-3)
model_checkpoint.eval()

In [17]:
print(x,model(x),y)

tensor([ 1.2797, -0.9566,  0.0000, -0.9566, -4.0413,  4.0828,  0.0000,  4.0828,
         1.2387, -0.9151]) tensor([-4.9117,  4.6574], grad_fn=<AddBackward0>) [-3.8265183  4.1371217]
