In [1]:
import gymnasium
import numpy as np

import sys
# Raise NumPy's summarisation threshold to sys.maxsize so that printing an
# array shows every element instead of an abbreviated "..." view.
np.set_printoptions(threshold=sys.maxsize)

from gran.util.gym_fb_control import (
    reset_env_state,
    run_env_step,
    get_task_info,
    get_task_name,
)


# --- Collect random-policy rollouts -----------------------------------------
# Every visited state contributes one row to each of the four lists below, so
# the resulting arrays are aligned row by row.
task = "cart_pole"
NUM_EPISODES = 300  # number of episodes to record
ACTION_SEED = 0     # seeds the action-space sampler so reruns draw the same actions

env = gymnasium.make(get_task_name(task))
# Only the first value returned by get_task_info is kept; it is later passed
# to the MLPAE constructor.
x_size, _, _, _ = get_task_info(task)

obs_list, rew_list, done_list, action_list = [], [], [], []

try:
    # Without an explicit seed the sampled actions (and therefore the whole
    # dataset) change on every run.
    env.action_space.seed(ACTION_SEED)

    for episode_nb in range(NUM_EPISODES):

        # NOTE(review): the meaning of the second argument (0) is defined by
        # reset_env_state, which is not visible here -- presumably a seed.
        obs = reset_env_state(env, 0)
        # No reward precedes the first observation, hence the NaN placeholder.
        done, rew = False, np.nan

        while True:

            obs_list.append(obs)
            rew_list.append(rew)
            done_list.append(done)

            if done:
                # No action is taken from the terminal state; NaNs keep the
                # action rows aligned with the observation rows.
                action_list.append([np.nan, np.nan])
                break

            action = env.action_space.sample()
            # Two-column one-hot encoding: action 0 -> [1, 0], anything else -> [0, 1].
            action_list.append([1, 0] if action == 0 else [0, 1])

            obs, rew, done = run_env_step(env, action)
finally:
    # Release the environment even if a rollout raises part-way through.
    env.close()

# Everything is stored as float32; the boolean `done` flags become 0.0 / 1.0.
obs_array = np.array(obs_list, dtype=np.float32)
rew_array = np.array(rew_list, dtype=np.float32)
done_array = np.array(done_list, dtype=np.float32)
action_array = np.array(action_list, dtype=np.float32)

In [3]:
import wandb
from pytorch_lightning.loggers import WandbLogger

# Read the Weights & Biases API key from a separate file so the secret is not
# embedded in the notebook itself.
with open("../../wandb_key.txt", "r") as f:
    # Strip surrounding whitespace: a trailing newline left by an editor would
    # otherwise be handed to wandb as part of the key.
    key = f.read().strip()

wandb.login(key=key)

[34m[1mwandb[0m: Currently logged in as: [33mmaximilienlc[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [5]:
from gran.bprop.model.ae.mlp import MLPAE

# --- Train the MLP autoencoder on the collected observations ----------------
wandb.finish()          # finish() is called both before and after the fit
pl.seed_everything(0)   # fixed global seed for this training run
wandb_logger = WandbLogger()

dm = AEDataModule(data=obs_array, batch_size=10000)
model = MLPAE(x_size)

trainer = pl.Trainer(
    max_epochs=1000,
    accelerator="gpu",
    devices=1,
    logger=wandb_logger,
    enable_progress_bar=False,
)
trainer.fit(model, dm)
wandb.finish()

INFO:lightning_lite.utilities.seed:Global seed set to 0


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 10.6 K
1 | decoder | Sequential | 10.5 K
---------------------------------------
21.1 K    Trainable params
0         Non-trainable params
21.1 K    Total params
0.084     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1000` reached.


0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,999.0
train_loss,5e-05
trainer/global_step,999.0
val_loss,6e-05


In [6]:
# Pass every collected observation through `model` and display the result.
# NOTE(review): `model` is expected to be the MLPAE trained in the cell above;
# the name is rebound to an MDNRNN further down, so cell order matters.
# The call runs with autograd enabled, which is why the displayed tensor
# carries a grad_fn.
model(torch.tensor(obs_array, dtype=torch.float))

tensor([[ 0.0287, -0.0252, -0.0534, -0.0450],
        [ 0.0195,  0.1994, -0.0551, -0.3548],
        [ 0.0311, -0.0204, -0.0608, -0.0722],
        ...,
        [-0.0134,  0.4108, -0.1734, -1.2046],
        [ 0.0014,  0.2301, -0.1913, -0.9925],
        [ 0.0226,  0.0544, -0.2043, -0.7890]], grad_fn=<AddmmBackward0>)

In [7]:
# Display the raw float32 observation array built in the data-collection cell.
obs_array

array([[ 0.01369617, -0.02302133, -0.04590265, -0.04834723],
       [ 0.01323574,  0.17272775, -0.04686959, -0.3551522 ],
       [ 0.0166903 , -0.02169755, -0.05397264, -0.07760915],
       ...,
       [-0.00378081,  0.3945922 , -0.17737961, -1.2545196 ],
       [ 0.00411104,  0.20212865, -0.20247   , -1.0222306 ],
       [ 0.00815361,  0.0101946 , -0.22291462, -0.7993309 ]],
      dtype=float32)

In [None]:
# Encode every collected observation with the model's encoder.
# The tensor is built from `obs_array` (the float32 array made from `obs_list`
# in the data-collection cell) rather than from the Python list itself:
# converting a list of per-step observations element by element is much slower
# and yields the same values.
latents = model.encoder(torch.tensor(obs_array, dtype=torch.float))

# AR

In [5]:
from gran.bprop.model.ar.mdnrnn import MDNRNN

# Restore the MDN-RNN from a saved checkpoint. `load_from_checkpoint` is a
# classmethod, so it is called on the class directly instead of on a throwaway
# instance that would be built and immediately discarded.
model = MDNRNN.load_from_checkpoint("lightning_logs/vrwd4hg8/checkpoints/epoch=49999-step=50000.ckpt")

# The data module receives the four aligned arrays in this fixed order.
data = [obs_array, action_array, rew_array, done_array]
dm = ARDataModule(data=data, batch_size=300)
# Called by hand because no Trainer drives the data module in this cell.
# NOTE(review): the stage string "train" is interpreted by ARDataModule.setup,
# which is not visible here -- confirm it is the value that method expects.
dm.setup("train")

In [6]:
from gran.bprop.model.ar.mdnrnn import MDNRNN

# --- Train the MDN-RNN on the collected rollouts ----------------------------
# NOTE(review): the stored output of this cell ends with
# "NameError: name 'q' is not defined"; `q` does not appear in this cell, so
# the error presumably originates in code called from here -- investigate.
wandb.finish()          # finish() is called both before and after the fit
pl.seed_everything(0)   # fixed global seed for this training run
wandb_logger = WandbLogger()

# The data module receives the four aligned arrays in this fixed order.
data = [obs_array, action_array, rew_array, done_array]
dm = ARDataModule(data=data, batch_size=6)
model = MDNRNN()

trainer = pl.Trainer(
    max_epochs=50000,
    accelerator="gpu",
    devices=1,
    logger=wandb_logger,
    enable_progress_bar=False,
)
trainer.fit(model, dm)
wandb.finish()

INFO:lightning_lite.utilities.seed:Global seed set to 0


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name | Type   | Params
--------------------------------
0 | lstm | LSTM   | 43.2 K
1 | fc   | Linear | 4.7 K 
--------------------------------
47.9 K    Trainable params
0         Non-trainable params
47.9 K    Total params
0.192     Total estimated model params size (MB)
  rank_zero_warn(


tensor([[0, 0, 0],
        [0, 0, 1],
        [0, 0, 2],
        ...,
        [5, 9, 2],
        [5, 9, 3],
        [5, 9, 4]], device='cuda:0')


NameError: name 'q' is not defined

In [14]:
# Debugging aid: list the attribute names available on `model`.
dir(model)

['CHECKPOINT_HYPER_PARAMS_KEY',
 'CHECKPOINT_HYPER_PARAMS_NAME',
 'CHECKPOINT_HYPER_PARAMS_TYPE',
 'T_destination',
 '_DeviceDtypeModuleMixin__update_properties',
 '_LightningModule__check_allowed',
 '_LightningModule__check_not_nested',
 '_LightningModule__to_tensor',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__jit_unused_properties__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_apply_batch_transfer_handler',
 '_automatic_optimization',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_batch_hook',
 '_call_impl',
 '_compiler_ctx',
 '_current_fx_name',
 '_device',
 '_dtype',
 '_example_input_ar

In [26]:
# Display element 0 of dataset item 1.
# NOTE(review): the stored output is a six-column tensor whose trailing rows
# are all zeros -- presumably 4 observation values + 2 one-hot action values,
# zero-padded to a fixed length; confirm against ARDataModule.
dm.dataset[1][0]

tensor([[ 0.4825,  0.1949, -0.1110, -0.0339,  0.0000,  1.0000],
        [ 0.4772,  0.5611, -0.1210, -0.4043,  0.0000,  1.0000],
        [ 0.5166,  0.9272, -0.1940, -0.7749,  0.0000,  1.0000],
        [ 0.6005,  1.2935, -0.3301, -1.1481,  0.0000,  1.0000],
        [ 0.7291,  1.6601, -0.5298, -1.5259,  0.0000,  1.0000],
        [ 0.9023,  2.0269, -0.7939, -1.9103,  1.0000,  0.0000],
        [ 1.1203,  1.6647, -1.1234, -1.6018,  0.0000,  1.0000],
        [ 1.2940,  2.0325, -1.4004, -2.0051,  0.0000,  1.0000],
        [ 1.5126,  2.4003, -1.7461, -2.4164,  0.0000,  1.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0

In [31]:
# Call the model on element 0 of the first two dataset items together with a
# second tensor argument [16, 8], then show the shape of output element 1.
# NOTE(review): [16, 8] presumably gives per-sequence lengths -- confirm
# against MDNRNN.forward.
model(dm.dataset[0:2][0], torch.tensor([16,8]))[1].shape

torch.Size([24, 5, 4])

In [8]:
# Scratch tensor of 20 (index, value) pairs; the first column counts 0..19.
# Requires a CUDA device because it is created on 'cuda:0'.
# NOTE(review): looks pasted from debug output; the second column presumably
# holds per-sequence lengths -- confirm.
a = torch.tensor([[ 0, 18],
        [ 1, 12],
        [ 2,  8],
        [ 3, 19],
        [ 4, 24],
        [ 5,  9],
        [ 6, 45],
        [ 7,  7],
        [ 8, 20],
        [ 9,  8],
        [10, 24],
        [11,  9],
        [12,  8],
        [13, 25],
        [14,  9],
        [15, 20],
        [16, 33],
        [17, 19],
        [18, 11],
        [19,  9]], device='cuda:0')

In [11]:
# Single-row scratch tensor with the same values as the first row of `a`;
# also created on 'cuda:0'.
b = torch.tensor([[ 0, 18]], device='cuda:0')

In [15]:
# Select the second column of `a` (all rows, column index 1).
a[:, 1]

tensor([18, 12,  8, 19, 24,  9, 45,  7, 20,  8, 24,  9,  8, 25,  9, 20, 33, 19,
        11,  9], device='cuda:0')