In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset

from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, setup_logging, CfgNode as CN

import math
import random
from create_dataset import create_dataset
from run_shortestroute import StateActionReturnDataset

In [7]:
obss, actions, returns, done_idxs, rtgs, timesteps, env = create_dataset(10*1000)

train_dataset = StateActionReturnDataset(obss, 10 * 3, actions, done_idxs, rtgs, timesteps)

# Set up configs
C = CN()

## system
C.system = CN()
C.system.seed = 3407
C.system.work_dir = './out/decgpt'

# data
C.data = StateActionReturnDataset.get_default_config()

# model
C.model = GPT.get_default_config()
C.model.model_type = 'gpt-mini'
C.model.vocab_size = train_dataset.vocab_size
C.model.block_size = train_dataset.block_size
C.model.max_timestep = max(timesteps)

# trainer
C.trainer = Trainer.get_default_config()
C.trainer.learning_rate = 3e-5
C.trainer.max_epochs = 10_000
C.trainer.num_workers = 4
C.trainer.lr_decay = True
C.trainer.warmup_tokens = 30_000
C.trainer.final_tokens = 100_000_000

model = GPT(C.model)

# initialize a trainer instance and kick off training
trainer = Trainer(model, train_dataset, C.trainer)
trainer.run()

print(env.adj_mat)
env.draw()

max rtg is 0
max timestep is 129
number of parameters: 2.71M
running on device cuda


epoch 1 iter 155: train loss 1.86102. lr 2.999996e-05: 100%|██████████| 156/156 [00:02<00:00, 67.05it/s]
epoch 2 iter 14: train loss 1.91429. lr 2.999995e-05:   9%|▉         | 14/156 [00:00<00:02, 67.51it/s]

In [None]:
model.eval()
with torch.no_grad():
    done = True
    for i in range(1):
        # sample from the model
        state = torch.tensor(env.reset())
        state = state.type(torch.float32).to(trainer.device).unsqueeze(0).unsqueeze(0)
        rtgs = [-2]
        # first state is from env, first rtg is target return, and first timestep is 0
        sampled_action = model.generate(state, 1, temperature=1.0, do_sample=True, actions=None, 
                            rtgs=torch.tensor(rtgs, dtype=torch.long).to(trainer.device).unsqueeze(0).unsqueeze(-1),
                            timesteps=torch.zeros((1, 1, 1), dtype=torch.int64).to(trainer.device))


        j = 0
        all_states = state
        actions = []
        while True:
            if done:
                state, reward_sum, done = torch.tensor(env.reset()), 0, False
            action = sampled_action.cpu().numpy()[0, -1]
            actions += [sampled_action]
            state, reward, done, = env.step(action)
            reward_sum += reward
            j += 1

            if done:
                break

            state = torch.tensor(state).unsqueeze(0).unsqueeze(0).to(trainer.device)

            all_states = torch.cat([all_states, state], dim=0)

            rtgs += [rtgs[-1] + reward]
            print(all_states.shape)
            # print((rtgs[-1], action, state))
            # all_states has all previous states and rtgs has all revious rtgs (will be cut to block_size in trainer.generate)
            # timestep is just the current timestep
            sampled_action = model.generate(all_states.unsqueeze(0), 1, temperature=1.0, do_sample=True, 
                            actions=torch.tensor(actions, dtype=torch.long).to(trainer.device).unsqueeze(-1).unsqueeze(0),
                            rtgs=torch.tensor(rtgs, dtype=torch.long).to(trainer.device).unsqueeze(0).unsqueeze(-1),
                            timesteps=(min(j, model.config.max_timestep) * torch.ones((1, 1, 1), dtype=torch.int64).to(trainer.device)))
print((rtgs[-1], action, state))

(-3, 0, tensor([[0]], device='cuda:0'))
(-4, 4, tensor([[0]], device='cuda:0'))
(-5, 10, tensor([[0]], device='cuda:0'))
(-6, 0, tensor([[0]], device='cuda:0'))
(-7, 6, tensor([[0]], device='cuda:0'))
(-8, 4, tensor([[0]], device='cuda:0'))
(-9, 1, tensor([[0]], device='cuda:0'))
(-10, 8, tensor([[0]], device='cuda:0'))
(-11, 1, tensor([[0]], device='cuda:0'))
(-12, 0, tensor([[0]], device='cuda:0'))
(-13, 4, tensor([[0]], device='cuda:0'))
(-14, 8, tensor([[0]], device='cuda:0'))
(-15, 4, tensor([[0]], device='cuda:0'))
(-16, 0, tensor([[0]], device='cuda:0'))
(-17, 0, tensor([[0]], device='cuda:0'))
(-18, 1, tensor([[0]], device='cuda:0'))
(-19, 0, tensor([[0]], device='cuda:0'))
(-20, 4, tensor([[0]], device='cuda:0'))
(-21, 0, tensor([[0]], device='cuda:0'))
(-22, 1, tensor([[0]], device='cuda:0'))
(-23, 6, tensor([[0]], device='cuda:0'))
(-24, 4, tensor([[0]], device='cuda:0'))
(-25, 10, tensor([[0]], device='cuda:0'))
(-26, 1, tensor([[0]], device='cuda:0'))
(-27, 8, tensor([[0]]

KeyboardInterrupt: 

In [None]:
state = torch.tensor(env.reset())
state = state.type(torch.float32).to(trainer.device).unsqueeze(0).unsqueeze(0)
# print(state)
rtgs = torch.tensor(-2, dtype=torch.long).to(trainer.device).unsqueeze(0).unsqueeze(-1)
model.generate(state, 1, temperature=0.1, do_sample=True, rtgs=rtgs, timesteps=torch.zeros((1, 1, 1), dtype=torch.int64).to(trainer.device))

tensor([[5]], device='cuda:0')