In [1]:
import gymnasium
import numpy as np

from gran.util.gym_state_control import (
    reset_emulator_state,
    run_emulator_step,
    get_task_info,
    get_task_name,
)
from gran.util.misc import standardize

task = "cart_pole"
emulator = gymnasium.make(get_task_name(task))
x_size, _, _, _ = get_task_info(task)

states = []

for i in range(10000):
    if i == 0 or done:
        obs, done = reset_emulator_state(emulator, 0), False

    states.append(obs)

    obs, rew, done = run_emulator_step(
        emulator, emulator.action_space.sample()
    )

states = standardize(np.array(states))

emulator.close()


ModuleNotFoundError: No module named 'gymnasium'

In [2]:
import torch
import pytorch_lightning as pl

import wandb
from pytorch_lightning.loggers import WandbLogger

with open("../../wandb_key.txt", "r") as f:
    key = f.read()

wandb.login(key=key)

class CartPoleDataset(torch.utils.data.Dataset):
    def __init__(self, data: np.ndarray):
        self.X = torch.tensor(data, dtype=torch.float)
    def __len__(self):
        return self.X.shape[0]
    def __getitem__(self, index: int):
        return self.X[index]

class CartPoleDataModule(pl.LightningDataModule):
    def __init__(self, data: str, batch_size: int):
        super().__init__()
        self.data = data
        self.batch_size = batch_size
    def setup(self, stage: str):
        self.cartpole_train = CartPoleDataset(self.data)
    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.cartpole_train, batch_size=self.batch_size
        )


In [6]:
from gran.grad.model.ae.var.mlp import MLPVAE

wandb.finish()
pl.seed_everything(0)
wandb_logger = WandbLogger()

model = MLPVAE(x_size=x_size, hidden_size=x_size*100, latent_size=x_size*100)
dm = CartPoleDataModule(np.array(states), 10000)

trainer = pl.Trainer(max_epochs=10000, accelerator='gpu', devices=-1, logger=wandb_logger)
trainer.fit(model, dm)
wandb.finish()

INFO:lightning_lite.utilities.seed:Global seed set to 1
INFO:pytorch_lightning.loggers.comet:CometLogger will be initialized in online mode
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 483 K 
1 | decoder | Sequential | 322 K 
---------------------------------------
805 K     Trainable params
0         Non-trainable params
805 K     Total params
3.222     Total estimated model params size (MB)


Epoch 49: 100%|████████| 1/1 [00:00<00:00, 13.50it/s, loss=2.79e+04, v_num=3161]

COMET INFO: Experiment is live on comet.com https://www.comet.com/maximilienlc/general/8839bcfef77d4dc18c2b339175773161



Epoch 999: 100%|███████| 1/1 [00:00<00:00, 12.19it/s, loss=2.49e+04, v_num=3161]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1000` reached.



Epoch 999: 100%|███████| 1/1 [00:00<00:00,  9.79it/s, loss=2.49e+04, v_num=3161]

COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.com/maximilienlc/general/8839bcfef77d4dc18c2b339175773161
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     train/kl_loss [20]    : (19375.384765625, 32517.49609375)
COMET INFO:     train/loss [20]       : (13745.50390625, 42732.2109375)
COMET INFO:     train/recon_loss [20] : (8128.00537109375, 36937.4765625)





COMET INFO:   Others:
COMET INFO:     Created from : pytorch-lightning
COMET INFO:   Uploads:
COMET INFO:     environment details      : 1
COMET INFO:     filename                 : 1
COMET INFO:     git metadata             : 1
COMET INFO:     git-patch (uncompressed) : 1 (14.85 KB)
COMET INFO:     installed packages       : 1
COMET INFO:     notebook                 : 1
COMET INFO:     os packages              : 1
COMET INFO:     source_code              : 1
COMET INFO: ---------------------------
COMET INFO: Uploading 1 metrics, params and output messages
COMET INFO: Waiting for completion of the file uploads (may take several seconds)
COMET INFO: The Python SDK has 10800 seconds to finish before aborting...
COMET INFO: All files uploaded, waiting for confirmation they have been all received


In [5]:
comet_logger.experiment.end()

COMET INFO: Experiment is live on comet.com https://www.comet.com/maximilienlc/general/861d5c80059e4b3992851efbf43a62f0

COMET INFO: -----------------------------------
COMET INFO: Comet.ml ExistingExperiment Summary
COMET INFO: -----------------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.com/maximilienlc/general/861d5c80059e4b3992851efbf43a62f0
COMET INFO:   Others:
COMET INFO:     Created from : pytorch-lightning
COMET INFO: -----------------------------------
COMET INFO: Uploading 1 metrics, params and output messages


In [12]:
model(torch.tensor(states, dtype=torch.float))[0]

tensor([[ 0.2300,  0.4470, -0.1629, -0.4479],
        [ 0.3193,  0.7621, -0.2560, -0.7443],
        [ 0.2627,  0.4685, -0.2122, -0.4776],
        ...,
        [ 0.1802,  0.2392, -0.1190, -0.2576],
        [ 0.2385,  0.4825, -0.1821, -0.4850],
        [ 0.3299,  0.8030, -0.2808, -0.7885]], grad_fn=<AddmmBackward0>)

In [13]:
states

array([[ 0.51926327,  0.24747598, -0.2024285 , -0.09827513],
       [ 0.5141689 ,  0.6195593 , -0.21317889, -0.49938208],
       [ 0.5523916 ,  0.24999224, -0.29214984, -0.13653122],
       ...,
       [ 0.4953347 , -0.11645585, -0.29040217,  0.17917903],
       [ 0.44787222,  0.25581813, -0.25396302, -0.22501059],
       [ 0.44374904,  0.62801325, -0.28626865, -0.6279266 ]],
      dtype=float32)