In [10]:
import torch
import torch.nn.functional as F

import time
from IPython.display import clear_output

from example_src.tinyhome import TinyHomeEngineV1, print_grid, print_act
from example_src.buffer import ReplayBuffer

from mamba_lm import MambaLM, MambaLMConfig

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# Environment / vocabulary dimensions.
L = 5             # grid side length (used as both height and width of the env)
num_actions = 5   # actions are drawn from [0, num_actions)
num_obs_type = 4  # contributes to the model vocabulary size alongside num_actions

# Data-collection settings.
nb_instances = 512  # number of environments stepped in parallel
steps = 10000       # environment steps taken per instance

envs = TinyHomeEngineV1(B=nb_instances, h=L, w=L)
buffer = ReplayBuffer(
    num_envs=nb_instances,
    capacity=int(1e6),
    obs_dim=L * L,
    act_dim=num_actions,
)

# Fill the replay buffer by rolling out a uniformly random policy:
# each stored entry pairs the observation seen *before* the step with the
# action taken and the reward received.
obs = envs.reset()
for _step in range(steps):
    actions = torch.randint(0, num_actions, (nb_instances,))
    next_obs, rewards = envs.step(actions)

    flat_obs = obs.view(-1, L * L)  # flatten each grid to a length-L*L vector
    buffer.store(flat_obs, actions, rewards.squeeze(1))

    obs = next_obs

In [5]:
# Observation-type ids and action ids share a single token vocabulary.
vocab_size = num_actions + num_obs_type

config = MambaLMConfig(
    d_model=16,
    n_layers=4,
    vocab_size=vocab_size,
    # Set equal to the vocab size — presumably so the vocabulary is not padded
    # beyond its true size; confirm against MambaLMConfig.
    pad_vocab_size_multiple=vocab_size,
)
model = MambaLM(config).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=3e-3)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Total number of scalar parameters in the model (displayed as the cell output).
sum(param.numel() for param in model.parameters())

13664

In [7]:
# Train the model as an autoregressive next-token predictor over sequences in
# which every timestep is laid out as [L*L observation tokens, 1 action token].
B, T = 64, 10           # batch size and number of timesteps per sampled sequence
obs_len = L * L         # tokens per flattened grid observation
step_len = obs_len + 1  # observation tokens plus one trailing action token

for i in range(1000):
    batch = buffer.sample(B, T)

    # The cat/view below require obs to be (B, T, L*L) and act to be (B, T).
    obs = torch.tensor(batch['obs']).long().to(device)
    act = torch.tensor(batch['act']).long().to(device)

    # Reserve one slot per timestep for the action token. It is created on the
    # same device and with the same dtype as `obs` so the concat needs neither
    # a CUDA device nor implicit type promotion.
    action_slot = torch.zeros(B, T, 1, dtype=torch.long, device=device)
    tokens = torch.cat([obs, action_slot], dim=2).view(B, step_len * T)  # (B, step_len*T)

    # Action ids are shifted past the observation ids so both share one vocabulary.
    tokens[:, obs_len::step_len] = act + num_obs_type

    # Next-token prediction: inputs are tokens[:-1], targets are tokens[1:].
    targets = tokens[:, 1:].reshape(-1)

    logits = model(tokens[:, :-1])  # (B, step_len*T - 1, vocab_size)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets)

    optim.zero_grad()
    loss.backward()
    optim.step()

    if i % 100 == 0:
        print(loss.item())

6.719587326049805
0.38883084058761597
0.26524049043655396
0.22898846864700317
0.21574825048446655
0.1874855011701584
0.1682835817337036
0.14247804880142212
0.1338169276714325
0.10260577499866486


In [15]:
# Autoregressively sample a rollout of T timesteps from the trained model.
T = 20                # number of timesteps to generate
step_len = L * L + 1  # tokens per timestep: L*L observation tokens + 1 action token

# Seed prompt: two tokens with id 1, placed on the same device as the model.
tokens = torch.ones(1, 2, dtype=torch.long, device=device)  # (B=1, 2)

# Sampling needs no gradients; disabling autograd avoids building a graph on
# every forward pass.
with torch.no_grad():
    # Generate exactly enough tokens to complete T full timesteps.
    for _ in range(step_len * T - tokens.size(1)):
        logits = model(tokens)[0, -1]     # logits for the next token
        probs = F.softmax(logits, dim=0)
        sampled = torch.multinomial(probs, num_samples=1, replacement=True)
        tokens = torch.cat([tokens, sampled.view(1, 1)], dim=1)

tokens = tokens.view(T, step_len)  # (T, step_len): one row per timestep

In [16]:
# Replay the generated rollout as a short animation, one timestep per frame.
for timestep in tokens:
    # All tokens but the last form the flattened L x L grid; the last one is
    # the action token.
    grid = timestep[:-1].view(1, L, L)
    # Undo the offset that was added to action ids when building training tokens.
    a = timestep[-1] - num_obs_type

    clear_output(wait=True)  # overwrite the previous frame instead of appending
    print_grid(grid)
    print_act(a.item())
    time.sleep(0.1)

#####
#   #
#   #
#G@ #
#####


E
