In [1]:
import tqdm

import torch
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer

from src.models import CausalBERTFeatureExtractor, CausalBertLMPolicyWrapper, base_config
from src.environments import AddGymEnv
from src.datasets import Environment_Gold_Dataset

from stable_baselines3 import PPO, DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

%load_ext autoreload
%autoreload 2
%load_ext line_profiler
%load_ext memory_profiler

ModuleNotFoundError: No module named 'gym'

In [2]:
# Constructor arguments shared by every AddGymEnv created in this notebook.
env_kwargs = {'max_token_length': 24, 'max_val': 1}

# Vectorised environment: 20 copies, each running in its own subprocess.
env = make_vec_env(
    AddGymEnv,
    n_envs=20,
    vec_env_cls=SubprocVecEnv,
    env_kwargs=env_kwargs,
)

# A plain in-process copy, used below for its tokenizer and for building the gold dataset.
single_env = AddGymEnv(**env_kwargs)

In [3]:
# Token id used for padding (the model summary further down shows padding_idx=0 on the word embeddings).
pad_id=0

# Have the SB3 policy use the causal-BERT feature extractor instead of its default one.
policy_kwargs = dict(
    features_extractor_class=CausalBERTFeatureExtractor,
    features_extractor_kwargs=dict(config=base_config, action_space=env.action_space, pad_id=pad_id),
)

learner = PPO('MlpPolicy', env, policy_kwargs=policy_kwargs, n_steps=256, verbose=10, batch_size=256)

# Replace the action head with a bias-free projection onto 110 outputs
# (110 matches the word-embedding vocabulary size shown in the model summary below).
# The policy has already been moved to `learner.device` by PPO, so the new layer must be
# moved there as well; otherwise `learner.predict` mixes CPU and CUDA tensors.
action_head = nn.Linear(learner.policy.mlp_extractor.latent_dim_pi, 110, bias=False).to(learner.device)
learner.policy.action_net = action_head

# The policy's optimizer was built before the swap and only knows the old head's parameters.
# Register the new ones so that PPO updates actually train the replacement layer.
learner.policy.optimizer.add_param_group({'params': list(action_head.parameters())})

Using cuda device


In [47]:
# Roll the (untrained) policy forward for a few steps in the vectorised environment.
obs = env.reset()
for _step in range(10):
    action, _states = learner.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
    # No manual reset here: `env` is a vectorised environment, so `dones` is an array with one
    # flag per worker (`if dones:` is ambiguous for more than one element), and SB3 VecEnvs
    # reset finished sub-environments automatically, returning the fresh observation in `obs`.

BrokenPipeError: [Errno 32] Broken pipe

## Wrapping a policy as a language model

In [6]:
# Wrap the PPO learner's policy so it can be used like a language model:
# the cells below call `.generate(...)`, read `.logits` from its forward pass,
# and train it with a Lightning Trainer.
lm_model = CausalBertLMPolicyWrapper(learner.policy)

In [7]:
# Smoke test: draw a random batch of 4 prompts x 7 token ids in [0, 100) directly on the GPU
# and let the wrapped policy continue them.
# NOTE(review): device='cuda' is hard-coded, so this cell requires a CUDA device.
ids = torch.randint(0,100, size=(4,7), device='cuda')
lm_model.generate(ids)

tensor([[ 24,  20,  72,  76,  35,  87,  48,  69,  57,  45,  93,  11,  93,  34,
         107,  85,  57,  93,  25,  97],
        [ 71,  87,   0,  42,   5,  63,  18,  29,  40,   1,   8,  34, 107, 107,
         107,  62,  54,  41,  59,  71],
        [ 30,  74,  20,  86,  24,  31,  48,  93,  93, 101,  95,  88,  86,  69,
          34,  52,  91,  83,  86,  80],
        [ 84,  29,  67,  26,  17,  40,  16, 107,  44,  11,  25,  86,  95,  69,
          24,  71,  29,  54,  29,  88]], device='cuda:0')

In [23]:
# Build a dataset from the single (non-vectorised) environment.
# NOTE(review): 1000 is presumably the number of gold examples to generate — confirm against Environment_Gold_Dataset.
env_dataset = Environment_Gold_Dataset(single_env, 1000)

In [24]:
# Turn the dataset into a dataloader.
# NOTE(review): the argument 1 is presumably the batch size (the batch inspected below has a leading dimension of 1) — confirm.
env_dataloader = env_dataset.to_dataloader(1)

In [25]:
# Pull the first batch from the dataloader for inspection in the following cells.
batch = next(iter(env_dataloader))

In [26]:
# Token ids that are fed to the model.
batch['input_ids']

tensor([[  1, 108, 107, 103,  24,  19,  24,  39,   6,  24,  19,  24,   8,  24,
           5,   7,  24]])

In [27]:
# Highest-scoring target token at each position (argmax over the last axis of the target policies).
# In the recorded output these are the input ids shifted one position to the left.
batch['target_policies'].argmax(-1)

tensor([[108, 107, 103,  24,  19,  24,  39,   6,  24,  19,  24,   8,  24,   5,
           7,  24,   2]])

In [28]:
# Boolean mask with one entry per token position.
# NOTE(review): judging by its name and the "trainable tokens" cell below, it presumably selects
# the positions that contribute to the training loss — confirm in the dataset code.
batch['grad_mask']

tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True]])

In [19]:
# Move the wrapper and its sub-modules to the CPU; the call returns the module,
# so its repr is displayed as the cell output.
lm_model.to('cpu')

CausalBertLMPolicyWrapper(
  (policy): ActorCriticPolicy(
    (features_extractor): CausalBERTFeatureExtractor(
      (transformer): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(110, 256, padding_idx=0)
          (position_embeddings): Embedding(512, 256)
          (token_type_embeddings): Embedding(2, 256)
          (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=256, out_features=256, bias=True)
                  (key): Linear(in_features=256, out_features=256, bias=True)
                  (value): Linear(in_features=256, out_features=256, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                  (distance_embedding

In [29]:
# Supervised training of the wrapped policy on the gold dataset:
# one GPU, gradient clipping at 0.5, amp_level 'O1', up to 200 epochs.
trainer_options = dict(
    gpus=1,
    gradient_clip_val=0.5,
    amp_level='O1',
    max_epochs=200,
)
pl_trainer = Trainer(**trainer_options)
pl_trainer.fit(lm_model, env_dataloader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type              | Params
--------------------------------------------------
0 | policy      | ActorCriticPolicy | 10.1 M
1 | action_head | Linear            | 28.2 K
2 | value_head  | Linear            | 257   
3 | dropout     | Dropout           | 0     
--------------------------------------------------
10.1 M    Trainable params
0         Non-trainable params
10.1 M    Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

In [16]:
# Forward pass on the inspected batch; keep the `.logits` field of the returned output object.
logits = lm_model(batch['input_ids']).logits

In [17]:
# First (and only) sequence of the batch and its gradient mask.
sample_ids = batch['input_ids'][0]
sample_mask = batch['grad_mask'][0]

# Full input text, then only the tokens selected by the mask.
print(single_env.tokenizer.decode(sample_ids.tolist(), skip_special_tokens=False))
print('trainable tokens')
single_env.tokenizer.decode(sample_ids[sample_mask].tolist(), skip_special_tokens=False)

[BOS]What is 0+0?[SP]0+0>>>0[NL][ESP]0
trainable tokens


'?[SP]0+0[NL][ESP]0'

In [12]:
# Argmax target token per position for the first sequence, restricted to the masked positions.
target_ids = batch['target_policies'].argmax(-1)[0]
print('target_tokens')
single_env.tokenizer.decode(target_ids[batch['grad_mask'][0]].tolist(), skip_special_tokens=False)

target_tokens


'[SP]0+0>>>[ESP]0[EOS]'

In [110]:
# Decode the 10 highest-scoring tokens at the third-from-last position of the first sequence,
# ordered from most to least likely.
single_env.tokenizer.decode(logits[0,-3].argsort(descending=True)[:10].tolist(), skip_special_tokens=False)

'1+[SP]5234976'

In [30]:
# Print the model's ten most likely next tokens after a hand-written prompt.
_ = lm_model.eval()  # evaluation mode; bound to `_` to suppress the module repr
# NOTE(review): the trailing '[]' in the prompt looks like a leftover placeholder — confirm it is intended.
ids = single_env.tokenizer.encode('[BOS]What is 0+0?[]').ids
input_ids = torch.tensor([ids], dtype=torch.long)

# Inference only: skip building the autograd graph.
with torch.no_grad():
    logits = lm_model(input_ids).logits

# Despite the name, `last_logits` holds probabilities (softmax over the vocabulary at the last position).
last_logits = logits[0][-1].softmax(dim=-1)
indices = torch.argsort(last_logits, descending=True, dim=0)

# Convert to plain Python ints so the tokenizer and the indexing receive ordinary ids
# rather than 1-element tensors.
for idx in indices[:10].tolist():
    print(f"{single_env.tokenizer.decode([idx], skip_special_tokens=False)} -> {float(100*last_logits[idx]):0.1f}%")

+ -> 41.2%
? -> 23.9%
[SP] -> 18.7%
0 -> 14.4%
  -> 0.5%
>>> -> 0.2%
[NL] -> 0.1%
 is -> 0.0%
[ESP] -> 0.0%
a -> 0.0%


In [12]:
# Ask the PPO learner for an action given the latest observation.
# Relies on `obs` left over from the rollout cell above.
action, _states = learner.predict(obs)

In [13]:
# Decode the predicted action back into its token string.
# `env` is the SubprocVecEnv built above: it does not expose the underlying environment's
# tokenizer, and `learner.predict` returns one action per worker. Use the in-process
# `single_env` tokenizer and decode the first action (this also works for a scalar action).
first_action = torch.as_tensor(action).reshape(-1)[0].item()
single_env.tokenizer.decode([first_action], skip_special_tokens=False)

'[SP]'

In [51]:
# Display the environment's `gold_state` attribute (a list of token ids in the recorded output).
single_env.gold_state

[1, 108, 107, 103, 31, 19, 24, 39, 6, 31, 19, 24, 8, 31, 5, 7, 31, 2]