HParameters

In [1]:
PAR_seed=123
#PAR_context_length=30
#PAR_epochs=5
PAR_epochs=1
PAR_model_type='reward_conditioned'
#PAR_num_steps=500000
PAR_num_steps=1
PAR_num_buffers=1
#PAR_num_buffers=50
PAR_game='Breakout'
#PAR_batch_size=128

PAR_trajectories_per_buffer=1 # help='Number of trajectories to sample from each of the buffers.')
#PAR_trajectories_per_buffer=10 # help='Number of trajectories to sample from each of the buffers.')
PAR_data_dir_prefix='/content/drive/MyDrive/Deep/UPC/Projecte/datasets/'

# default GPT2 values?
hparams = {
    'batch_size': 2,
    'context_length':30,
    'path':'/content/drive/MyDrive/Deep/UPC/Projecte/datasets/Pong',
}

#MODEL

In [15]:

"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
    - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
    - all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""

import math
import logging

import torch
import torch.nn as nn
from torch.nn import functional as F

logger = logging.getLogger(__name__)

import numpy as np

class GELU(nn.Module):
    def forward(self, input):
        return F.gelu(input)

class GPTConfig:
    """ base GPT config, params common to all GPT versions """
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for k,v in kwargs.items():
            setattr(self, k, v)

class GPT1Config(GPTConfig):
    """ GPT-1 like network roughly 125M params """
    n_layer = 12
    n_head = 12
    n_embd = 768

class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        # self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
        #                              .view(1, 1, config.block_size, config.block_size))
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size + 1, config.block_size + 1))
                                     .view(1, 1, config.block_size + 1, config.block_size + 1))
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y

class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    """  the full GPT language model, with a context size of block_size """

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.model_type = config.model_type

        # input embedding stem
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        # self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size + 1, config.n_embd))
        self.global_pos_emb = nn.Parameter(torch.zeros(1, config.max_timestep+1, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)

        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size
        self.apply(self._init_weights)


        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))


        self.state_encoder = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(),
                                 nn.Conv2d(32, 64, 4, stride=2, padding=0), nn.ReLU(),
                                 nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(),
                                 nn.Flatten(), nn.Linear(3136, config.n_embd), nn.Tanh())

        self.ret_emb = nn.Sequential(nn.Linear(1, config.n_embd), nn.Tanh())

        self.action_embeddings = nn.Sequential(nn.Embedding(config.vocab_size, config.n_embd), nn.Tanh())
        nn.init.normal_(self.action_embeddings[0].weight, mean=0.0, std=0.02)

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        # whitelist_weight_modules = (torch.nn.Linear, )
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')
        no_decay.add('global_pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    # state, action, and return
    def forward(self, states, actions, targets=None, rtgs=None, timesteps=None):
        # states: (batch, block_size, 4*84*84)
        # actions: (batch, block_size, 1)
        # targets: (batch, block_size, 1)
        # rtgs: (batch, block_size, 1)
        # timesteps: (batch, 1, 1)

        print('states ',states.shape)
        print('actions ',actions.shape)
        print('rtgs ',rtgs.shape)
        print('timesteps ',timesteps.shape)

        print('states_reshape ', states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous().shape)
        state_embeddings = self.state_encoder(states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous()) # (batch * block_size, n_embd)
        print('state_embeddings ', state_embeddings.shape)
        state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.config.n_embd) # (batch, block_size, n_embd)

        if actions is not None and self.model_type == 'reward_conditioned':
            print('1_rtg_', rtgs.shape)
            rtg_embeddings = self.ret_emb(rtgs.type(torch.float32))
            action_embeddings = self.action_embeddings(actions.type(torch.long).squeeze(-1)) # (batch, block_size, n_embd)

            token_embeddings = torch.zeros((states.shape[0], states.shape[1]*3 - int(targets is None), self.config.n_embd), dtype=torch.float32, device=state_embeddings.device)
            print('token_embeddings ', token_embeddings.shape)
            token_embeddings[:,::3,:] = rtg_embeddings
            token_embeddings[:,1::3,:] = state_embeddings
            token_embeddings[:,2::3,:] = action_embeddings[:,-states.shape[1] + int(targets is None):,:]
        elif actions is None and self.model_type == 'reward_conditioned': # only happens at very first timestep of evaluation
            rtg_embeddings = self.ret_emb(rtgs.type(torch.float32))

            token_embeddings = torch.zeros((states.shape[0], states.shape[1]*2, self.config.n_embd), dtype=torch.float32, device=state_embeddings.device)
            token_embeddings[:,::2,:] = rtg_embeddings # really just [:,0,:]
            token_embeddings[:,1::2,:] = state_embeddings # really just [:,1,:]
        elif actions is not None and self.model_type == 'naive':
            action_embeddings = self.action_embeddings(actions.type(torch.long).squeeze(-1)) # (batch, block_size, n_embd)

            token_embeddings = torch.zeros((states.shape[0], states.shape[1]*2 - int(targets is None), self.config.n_embd), dtype=torch.float32, device=state_embeddings.device)
            token_embeddings[:,::2,:] = state_embeddings
            token_embeddings[:,1::2,:] = action_embeddings[:,-states.shape[1] + int(targets is None):,:]
        elif actions is None and self.model_type == 'naive': # only happens at very first timestep of evaluation
            token_embeddings = state_embeddings
        else:
            raise NotImplementedError()

        batch_size = states.shape[0]
        print('self.global_pos_emb ', self.global_pos_emb.shape)
        all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, batch_size, dim=0) # batch_size, traj_length, n_embd

        print('all_global_pos_emb ', all_global_pos_emb.shape)
        print('2_', torch.repeat_interleave(timesteps, self.config.n_embd, dim=-1).shape)
        print('self.pos_emb ', self.pos_emb.shape)
        print('3_', self.pos_emb[:, :token_embeddings.shape[1], :].shape)
        position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.config.n_embd, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :]

        print('token_embeddings ', token_embeddings.shape)
        print('position_embeddings ', position_embeddings.shape)
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        if actions is not None and self.model_type == 'reward_conditioned':
            logits = logits[:, 1::3, :] # only keep predictions from state_embeddings
        elif actions is None and self.model_type == 'reward_conditioned':
            logits = logits[:, 1:, :]
        elif actions is not None and self.model_type == 'naive':
            logits = logits[:, ::2, :] # only keep predictions from state_embeddings
        elif actions is None and self.model_type == 'naive':
            logits = logits # for completeness
        else:
            raise NotImplementedError()

        print('logits_type ', logits.dtype)
        print('targets_type ', targets.dtype)
        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        return logits, loss


#UTILS

#TRAINING

In [3]:

"""
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""

import math
import logging

from tqdm import tqdm
import numpy as np

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader

logger = logging.getLogger(__name__)


from collections import deque
import random
import cv2
import torch
from PIL import Image

class TrainerConfig:
    # optimization parameters
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1 # only applied on matmul weights
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    lr_decay = False
    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
    final_tokens = 260e9 # (at what point we reach 10% of original LR)
    # checkpoint settings
    ckpt_path = None
    num_workers = 0 # for DataLoader

    def __init__(self, **kwargs):
        for k,v in kwargs.items():
            setattr(self, k, v)

class Trainer:

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config

        # take over whatever gpus are on the system
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def save_checkpoint(self):
        # DataParallel wrappers keep raw model object in .module attribute
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        logger.info("saving %s", self.config.ckpt_path)
        # torch.save(raw_model.state_dict(), self.config.ckpt_path)

    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split, epoch_num=0):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset

            loader = DataLoader(data, shuffle=True, pin_memory=True,
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

            losses = []
            pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
            for it, (x, y, r, t) in pbar:

                # place data on the correct device
                x = x.to(self.device)
                y = y.to(self.device)
                r = r.to(self.device)
                t = t.to(self.device)

                # forward the model
                with torch.set_grad_enabled(is_train):
                    # logits, loss = model(x, y, r)
                    #print("in_model",y.shape)
                    logits, loss = model(x, y, y, r, t)
                    loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
                    losses.append(loss.item())

                if is_train:

                    # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                    optimizer.step()

                    # decay the learning rate based on our progress
                    if config.lr_decay:
                        self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate

                    # report progress
                    pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")

            if not is_train:
                test_loss = float(np.mean(losses))
                logger.info("test loss: %f", test_loss)
                return test_loss

        # best_loss = float('inf')

        best_return = -float('inf')

        self.tokens = 0 # counter used for learning rate decay

        for epoch in range(config.max_epochs):

            run_epoch('train', epoch_num=epoch)
            # if self.test_dataset is not None:
            #     test_loss = run_epoch('test')

            # # supports early stopping based on the test loss, or just save always if no test set is provided
            # good_model = self.test_dataset is None or test_loss < best_loss
            # if self.config.ckpt_path is not None and good_model:
            #     best_loss = test_loss
            #     self.save_checkpoint()

            # -- pass in target returns
            if self.config.model_type == 'naive':
                eval_return = self.get_returns(0)
            elif self.config.model_type == 'reward_conditioned':
                if self.config.game == 'Breakout':
                    eval_return = self.get_returns(90)
                elif self.config.game == 'Seaquest':
                    eval_return = self.get_returns(1150)
                elif self.config.game == 'Qbert':
                    eval_return = self.get_returns(14000)
                elif self.config.game == 'Pong':
                    eval_return = self.get_returns(20)
                else:
                    raise NotImplementedError()
            else:
                raise NotImplementedError()

    def get_returns(self, ret):
        self.model.train(False)
        args=Args(self.config.game.lower(), self.config.seed)
        print("ENV***** ",args)
        env = Env(args)
        env.eval()

        T_rewards, T_Qs = [], []
        done = True
        for i in range(10):
            state = env.reset()
            state = state.type(torch.float32).to(self.device).unsqueeze(0).unsqueeze(0)
            rtgs = [ret]
            # first state is from env, first rtg is target return, and first timestep is 0
            sampled_action = sample(self.model.module, state, 1, temperature=1.0, sample=True, actions=None,
                rtgs=torch.tensor(rtgs, dtype=torch.long).to(self.device).unsqueeze(0).unsqueeze(-1),
                timesteps=torch.zeros((1, 1, 1), dtype=torch.int64).to(self.device))

            j = 0
            all_states = state
            actions = []
            while True:
                if done:
                    state, reward_sum, done = env.reset(), 0, False
                action = sampled_action.cpu().numpy()[0,-1]
                actions += [sampled_action]
                state, reward, done = env.step(action)
                reward_sum += reward
                j += 1

                if done:
                    T_rewards.append(reward_sum)
                    break

                state = state.unsqueeze(0).unsqueeze(0).to(self.device)

                all_states = torch.cat([all_states, state], dim=0)

                rtgs += [rtgs[-1] - reward]
                # all_states has all previous states and rtgs has all previous rtgs (will be cut to block_size in utils.sample)
                # timestep is just current timestep
                sampled_action = sample(self.model.module, all_states.unsqueeze(0), 1, temperature=1.0, sample=True,
                    actions=torch.tensor(actions, dtype=torch.long).to(self.device).unsqueeze(1).unsqueeze(0),
                    rtgs=torch.tensor(rtgs, dtype=torch.long).to(self.device).unsqueeze(0).unsqueeze(-1),
                    timesteps=(min(j, self.config.max_timestep) * torch.ones((1, 1, 1), dtype=torch.int64).to(self.device)))
        env.close()
        eval_return = sum(T_rewards)/10.
        print("target return: %d, eval return: %d" % (ret, eval_return))
        self.model.train(True)
        return eval_return


class Args:
    def __init__(self, game, seed):
        self.device = torch.device('cuda')
        self.seed = seed
        self.max_episode_length = 108e3
        self.game = game
        self.history_length = 4


# CREATE DATASET

Test load files

In [4]:
# Atari dataset from https://console.cloud.google.com/storage/browser/atari-replay-datasets/dqn downloable with the command:
#!gsutil -m cp -R gs://atari-replay-datasets/dqn/Pong/1/replay_logs .

# But we get the dataset from my drive (Alex)
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Fixed replay buffer class ??

In [5]:
# source: https://github.com/google-research/batch_rl/blob/master/batch_rl/fixed_replay/replay_memory/fixed_replay_buffer.py

import collections
from concurrent import futures
from dopamine.replay_memory import circular_replay_buffer
import numpy as np
import tensorflow.compat.v1 as tf
import gin

gfile = tf.gfile

STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX

class FixedReplayBuffer(object):
  """Object composed of a list of OutofGraphReplayBuffers."""

  def __init__(self, data_dir, replay_suffix, *args, **kwargs):  # pylint: disable=keyword-arg-before-vararg
    """Initialize the FixedReplayBuffer class.
    Args:
      data_dir: str, log Directory from which to load the replay buffer.
      replay_suffix: int, If not None, then only load the replay buffer
        corresponding to the specific suffix in data directory.
      *args: Arbitrary extra arguments.
      **kwargs: Arbitrary keyword arguments.
    """
    self._args = args
    self._kwargs = kwargs
    self._data_dir = data_dir
    self._loaded_buffers = False
    self.add_count = np.array(0)
    self._replay_suffix = replay_suffix
    if not self._loaded_buffers:
      if replay_suffix is not None:
        assert replay_suffix >= 0, 'Please pass a non-negative replay suffix'
        self.load_single_buffer(replay_suffix)
      else:
        self._load_replay_buffers(num_buffers=50)

  def load_single_buffer(self, suffix):
    """Load a single replay buffer."""
    replay_buffer = self._load_buffer(suffix)
    if replay_buffer is not None:
      self._replay_buffers = [replay_buffer]
      self.add_count = replay_buffer.add_count
      self._num_replay_buffers = 1
      self._loaded_buffers = True

  def _load_buffer(self, suffix):
    """Loads a OutOfGraphReplayBuffer replay buffer."""
    try:
      # pytype: disable=attribute-error
      replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(
          *self._args, **self._kwargs)
      replay_buffer.load(self._data_dir, suffix)
      tf.logging.info('Loaded replay buffer ckpt {} from {}'.format(
          suffix, self._data_dir))
      # pytype: enable=attribute-error
      return replay_buffer
    except tf.errors.NotFoundError:
      return None

  def _load_replay_buffers(self, num_buffers=None):
    """Loads multiple checkpoints into a list of replay buffers."""
    if not self._loaded_buffers:  # pytype: disable=attribute-error
      ckpts = gfile.ListDirectory(self._data_dir)  # pytype: disable=attribute-error
      # Assumes that the checkpoints are saved in a format CKPT_NAME.{SUFFIX}.gz
      ckpt_counters = collections.Counter(
          [name.split('.')[-2] for name in ckpts])
      # Should contain the files for add_count, action, observation, reward,
      # terminal and invalid_range
      ckpt_suffixes = [x for x in ckpt_counters if ckpt_counters[x] in [6, 7]]
      if num_buffers is not None:
        ckpt_suffixes = np.random.choice(
            ckpt_suffixes, num_buffers, replace=False)
      self._replay_buffers = []
      # Load the replay buffers in parallel
      with futures.ThreadPoolExecutor(
          max_workers=num_buffers) as thread_pool_executor:
        replay_futures = [thread_pool_executor.submit(
            self._load_buffer, suffix) for suffix in ckpt_suffixes]
      for f in replay_futures:
        replay_buffer = f.result()
        if replay_buffer is not None:
          self._replay_buffers.append(replay_buffer)
          self.add_count = max(replay_buffer.add_count, self.add_count)
      self._num_replay_buffers = len(self._replay_buffers)
      if self._num_replay_buffers:
        self._loaded_buffers = True

  def get_transition_elements(self):
    return self._replay_buffers[0].get_transition_elements()

  def sample_transition_batch(self, batch_size=None, indices=None):
    buffer_index = np.random.randint(self._num_replay_buffers)
    return self._replay_buffers[buffer_index].sample_transition_batch(
        batch_size=batch_size, indices=indices)

  def load(self, *args, **kwargs):  # pylint: disable=unused-argument
    pass

  def reload_buffer(self, num_buffers=None):
    self._loaded_buffers = False
    self._load_replay_buffers(num_buffers)

  def save(self, *args, **kwargs):  # pylint: disable=unused-argument
    pass

  def add(self, *args, **kwargs):  # pylint: disable=unused-argument
    pass

Create Dataset Function

In [6]:
import csv
import logging
# make deterministic

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from torch.utils.data import Dataset

from collections import deque
import random
import torch
import pickle
#import blosc
import argparse
#from fixed_replay_buffer import FixedReplayBuffer

def create_dataset(num_buffers, num_steps, game, data_dir_prefix, trajectories_per_buffer):
    # -- load data from memory (make more efficient)
    obss = []
    actions = []
    returns = [0]
    done_idxs = []
    stepwise_returns = []

    transitions_per_buffer = np.zeros(50, dtype=int)
    num_trajectories = 0
    while len(obss) < num_steps:
        buffer_num = np.random.choice(np.arange(50 - num_buffers, 50), 1)[0]
        i = transitions_per_buffer[buffer_num]
        print('loading from buffer %d which has %d already loaded' % (buffer_num, i))
        print(data_dir_prefix + game + '/1/replay_logs')
        frb = FixedReplayBuffer(
            data_dir=data_dir_prefix + game + '/1/replay_logs',
            replay_suffix=buffer_num,
            observation_shape=(84, 84),
            stack_size=4,
            update_horizon=1,
            gamma=0.99,
            observation_dtype=np.uint8,
            batch_size=32,
            replay_capacity=100000)
        if frb._loaded_buffers:
            done = False
            curr_num_transitions = len(obss) # quants timesteps
            trajectories_to_load = trajectories_per_buffer
            while not done:
                states, ac, ret, next_states, next_action, next_reward, terminal, indices = frb.sample_transition_batch(batch_size=1, indices=[i])

                #print("states ",states.shape)
               # plt.imshow(states[0])
                states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84)
                #print("ac ",ac.shape)
                #print("next_states ",next_states.shape)
                #print("next_action ",next_action.shape)
                #print("ret ",ret.shape)

                obss += [states]
                actions += [ac[0]]
                stepwise_returns += [ret[0]]
                if terminal[0]:
                    done_idxs += [len(obss)]
                    returns += [0]
                    if trajectories_to_load == 0:
                        done = True
                    else:
                        trajectories_to_load -= 1
                returns[-1] += ret[0]
                i += 1
                if i >= 100000:
                    obss = obss[:curr_num_transitions]
                    actions = actions[:curr_num_transitions]
                    stepwise_returns = stepwise_returns[:curr_num_transitions]
                    returns[-1] = 0
                    i = transitions_per_buffer[buffer_num]
                    done = True
            num_trajectories += (trajectories_per_buffer - trajectories_to_load)
            transitions_per_buffer[buffer_num] = i
        print('this buffer has %d loaded transitions and there are now %d transitions total divided into %d trajectories' % (i, len(obss), num_trajectories))

    actions = np.array(actions)
    returns = np.array(returns)
    stepwise_returns = np.array(stepwise_returns)
    done_idxs = np.array(done_idxs)

    # -- create reward-to-go dataset
    start_index = 0
    rtg = np.zeros_like(stepwise_returns)
    for i in done_idxs:
        i = int(i)
        curr_traj_returns = stepwise_returns[start_index:i]
        for j in range(i-1, start_index-1, -1): # start from i-1
            rtg_j = curr_traj_returns[j-start_index:i-start_index]
            rtg[j] = sum(rtg_j)
        start_index = i
    print('max rtg is %d' % max(rtg))

    # -- create timestep dataset
    start_index = 0
    timesteps = np.zeros(len(actions)+1, dtype=int)
    for i in done_idxs:
        i = int(i)
        timesteps[start_index:i+1] = np.arange(i+1 - start_index)
        start_index = i+1
    print('max timestep is %d' % max(timesteps))

    return obss, actions, returns, done_idxs, rtg, timesteps


#Environtment

In [7]:


class Env():
    def __init__(self, args):
        self.device = args.device
        self.ale = atari_py.ALEInterface()
        self.ale.setInt('random_seed', args.seed)
        self.ale.setInt('max_num_frames_per_episode', args.max_episode_length)
        self.ale.setFloat('repeat_action_probability', 0)  # Disable sticky actions
        self.ale.setInt('frame_skip', 0)
        self.ale.setBool('color_averaging', False)
        self.ale.loadROM(atari_py.get_game_path(args.game))  # ROM loading must be done after setting options
        actions = self.ale.getMinimalActionSet()
        self.actions = dict([i, e] for i, e in zip(range(len(actions)), actions))
        self.lives = 0  # Life counter (used in DeepMind training)
        self.life_termination = False  # Used to check if resetting only from loss of life
        self.window = args.history_length  # Number of frames to concatenate
        self.state_buffer = deque([], maxlen=args.history_length)
        self.training = True  # Consistent with model training mode

    def _get_state(self):
        state = cv2.resize(self.ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_LINEAR)
        return torch.tensor(state, dtype=torch.float32, device=self.device).div_(255)

    def _reset_buffer(self):
        for _ in range(self.window):
            self.state_buffer.append(torch.zeros(84, 84, device=self.device))

    def reset(self):
        if self.life_termination:
            self.life_termination = False  # Reset flag
            self.ale.act(0)  # Use a no-op after loss of life
        else:
            # Reset internals
            self._reset_buffer()
            self.ale.reset_game()
            # Perform up to 30 random no-ops before starting
            for _ in range(random.randrange(30)):
                self.ale.act(0)  # Assumes raw action 0 is always no-op
                if self.ale.game_over():
                    self.ale.reset_game()
        # Process and return "initial" state
        observation = self._get_state()
        self.state_buffer.append(observation)
        self.lives = self.ale.lives()
        return torch.stack(list(self.state_buffer), 0)

    def step(self, action):
        # Repeat action 4 times, max pool over last 2 frames
        frame_buffer = torch.zeros(2, 84, 84, device=self.device)
        reward, done = 0, False
        for t in range(4):
            reward += self.ale.act(self.actions.get(action))
            if t == 2:
                frame_buffer[0] = self._get_state()
            elif t == 3:
                frame_buffer[1] = self._get_state()
            done = self.ale.game_over()
            if done:
                break
        observation = frame_buffer.max(0)[0]
        self.state_buffer.append(observation)
        # Detect loss of life as terminal in training mode
        if self.training:
            lives = self.ale.lives()
            if lives < self.lives and lives > 0:  # Lives > 0 for Q*bert
                self.life_termination = not done  # Only set flag when not truly done
                done = True
            self.lives = lives
        # Return state, reward, done
        return torch.stack(list(self.state_buffer), 0), reward, done

    # Uses loss of life as terminal signal
    def train(self):
        self.training = True

    # Uses standard terminal signal
    def eval(self):
        self.training = False

    def action_space(self):
        return len(self.actions)

    def render(self):
        cv2.imshow('screen', self.ale.getScreenRGB()[:, :, ::-1])
        cv2.waitKey(1)

    def close(self):
        cv2.destroyAllWindows()

#->GO

In [8]:
import csv
import logging
# make deterministic
#from mingpt.utils import set_seed
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from torch.utils.data import Dataset
#from mingpt.model_atari import GPT, GPTConfig
#from mingpt.trainer_atari import Trainer, TrainerConfig
#from mingpt.utils import sample
from collections import deque
import random
import torch
import pickle
#import blosc
import argparse
#from create_dataset import create_dataset
import matplotlib.pyplot as plt
import matplotlib.image as mpimg



class StateActionReturnDataset(Dataset):

    def __init__(self, data, block_size, actions, done_idxs, rtgs, timesteps):
        self.block_size = block_size
        self.vocab_size = max(actions) + 1
        self.data = data
        self.actions = actions
        self.done_idxs = done_idxs
        self.rtgs = rtgs
        self.timesteps = timesteps

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        block_size = self.block_size // 3
        done_idx = idx + block_size
        for i in self.done_idxs:
            if i > idx: # first done_idx greater than idx
                done_idx = min(int(i), done_idx)
                break
        idx = done_idx - block_size
        states = torch.tensor(np.array(self.data[idx:done_idx]), dtype=torch.float32).reshape(block_size, -1) # (block_size, 4*84*84)
        states = states / 255.
        actions = torch.tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1)
        rtgs = torch.tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1)
        timesteps = torch.tensor(self.timesteps[idx:idx+1], dtype=torch.int64).unsqueeze(1)

        return states, actions, rtgs, timesteps



In [9]:
obss, actions, returns, done_idxs, rtgs, timesteps = create_dataset(PAR_num_buffers, PAR_num_steps, PAR_game, PAR_data_dir_prefix, PAR_trajectories_per_buffer)

# set up logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

train_dataset = StateActionReturnDataset(obss, hparams['context_length']*3, actions, done_idxs, rtgs, timesteps)


loading from buffer 49 which has 0 already loaded
/content/drive/MyDrive/Deep/UPC/Projecte/datasets/Breakout/1/replay_logs




this buffer has 2006 loaded transitions and there are now 2006 transitions total divided into 1 trajectories
max rtg is 67
max timestep is 1842


In [16]:
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,
                  n_layer=6, n_head=8, n_embd=128, model_type=PAR_model_type, max_timestep=max(timesteps))
model = GPT(mconf)


In [17]:

# initialize a trainer instance and kick off training
epochs = PAR_epochs
tconf = TrainerConfig(max_epochs=epochs, batch_size=hparams['batch_size'], learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=512*20, final_tokens=2*len(train_dataset)*hparams['context_length']*3,
                      num_workers=4, seed=PAR_seed, model_type=PAR_model_type, game=PAR_game, max_timestep=max(timesteps))
trainer = Trainer(model, train_dataset, None, tconf)

trainer.train()

epoch 1 iter 2: train loss 1.86223. lr 1.054687e-05:   0%|          | 3/958 [00:00<01:31, 10.41it/s]

states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_

epoch 1 iter 7: train loss 1.41208. lr 2.812500e-05:   1%|          | 6/958 [00:00<01:01, 15.37it/s]

logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 12

epoch 1 iter 11: train loss 1.50917. lr 4.218750e-05:   1%|▏         | 12/958 [00:00<00:50, 18.81it/s]

 torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30,

epoch 1 iter 16: train loss 1.40896. lr 5.976562e-05:   2%|▏         | 15/958 [00:00<00:46, 20.15it/s]

logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 12

epoch 1 iter 20: train loss 1.45893. lr 7.382812e-05:   2%|▏         | 21/958 [00:01<00:44, 21.00it/s]

logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 12

epoch 1 iter 25: train loss 1.52335. lr 9.140625e-05:   3%|▎         | 24/958 [00:01<00:44, 20.99it/s]

states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_

epoch 1 iter 29: train loss 1.06452. lr 1.054687e-04:   3%|▎         | 30/958 [00:01<00:43, 21.12it/s]

logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 12

epoch 1 iter 34: train loss 1.47586. lr 1.230469e-04:   3%|▎         | 33/958 [00:01<00:43, 21.33it/s]

states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_

epoch 1 iter 39: train loss 1.32099. lr 1.406250e-04:   4%|▍         | 39/958 [00:02<00:42, 21.75it/s]

logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 12

epoch 1 iter 43: train loss 1.19452. lr 1.546875e-04:   4%|▍         | 42/958 [00:02<00:44, 20.79it/s]

states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_

epoch 1 iter 48: train loss 1.08526. lr 1.722656e-04:   5%|▌         | 48/958 [00:02<00:42, 21.52it/s]

logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 12

epoch 1 iter 52: train loss 1.10582. lr 1.863281e-04:   5%|▌         | 51/958 [00:02<00:41, 21.66it/s]

state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90

epoch 1 iter 57: train loss 1.70658. lr 2.039062e-04:   6%|▌         | 57/958 [00:02<00:42, 21.37it/s]

logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 12

epoch 1 iter 61: train loss 1.52241. lr 2.179687e-04:   6%|▋         | 60/958 [00:03<00:41, 21.61it/s]

states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_

epoch 1 iter 65: train loss 1.20162. lr 2.320312e-04:   7%|▋         | 66/958 [00:03<00:42, 21.17it/s]

logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 12

epoch 1 iter 70: train loss 1.12776. lr 2.496094e-04:   7%|▋         | 69/958 [00:03<00:41, 21.17it/s]

states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_

epoch 1 iter 74: train loss 1.31431. lr 2.636719e-04:   8%|▊         | 75/958 [00:03<00:41, 21.17it/s]

logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 12

epoch 1 iter 79: train loss 1.15102. lr 2.812500e-04:   8%|▊         | 78/958 [00:03<00:41, 21.17it/s]

states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_

epoch 1 iter 83: train loss 1.31355. lr 2.953125e-04:   9%|▉         | 84/958 [00:04<00:42, 20.40it/s]

 torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_ty

epoch 1 iter 87: train loss 1.17361. lr 3.093750e-04:   9%|▉         | 87/958 [00:04<00:42, 20.63it/s]

token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  t

epoch 1 iter 92: train loss 1.03117. lr 3.269531e-04:  10%|▉         | 93/958 [00:04<00:40, 21.28it/s]

logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 12

epoch 1 iter 96: train loss 0.95165. lr 3.410156e-04:  10%|█         | 96/958 [00:04<00:40, 21.04it/s]

torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.

epoch 1 iter 101: train loss 1.08825. lr 3.585937e-04:  10%|█         | 99/958 [00:04<00:40, 21.02it/s]

logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 12

epoch 1 iter 105: train loss 0.77126. lr 3.726562e-04:  11%|█         | 105/958 [00:05<00:40, 21.12it/s]

states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_

epoch 1 iter 107: train loss 0.98302. lr 3.796875e-04:  11%|█▏        | 108/958 [00:05<00:41, 20.34it/s]


logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 128])
all_global_pos_emb  torch.Size([2, 1843, 128])
2_ torch.Size([2, 1, 128])
self.pos_emb  torch.Size([1, 91, 128])
3_ torch.Size([1, 90, 128])
token_embeddings  torch.Size([2, 90, 128])
position_embeddings  torch.Size([2, 90, 128])
logits_type  torch.float32
targets_type  torch.int64
states  torch.Size([2, 30, 28224])
actions  torch.Size([2, 30, 1])
rtgs  torch.Size([2, 30, 1])
timesteps  torch.Size([2, 1, 1])
states_reshape  torch.Size([60, 4, 84, 84])
state_embeddings  torch.Size([60, 128])
1_rtg_ torch.Size([2, 30, 1])
token_embeddings  torch.Size([2, 90, 128])
self.global_pos_emb  torch.Size([1, 1843, 12

Exception in thread Thread-14 (_pin_memory_loop):
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/pin_memory.py", line 54, in _pin_memory_loop
    do_one_step()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/pin_memory.py", line 31, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/usr/local/lib/python3.10/dist-packages/torch/multiprocessing/reductions.py", line 355, in rebuild_storage_fd
    fd = df.detach()
  File "/usr/lib/python3.10/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/usr/lib/python3

KeyboardInterrupt: 

    buf = self._recv_bytes(maxlength)
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 414, in _recv_bytes
    buf = self._recv(4)
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer


In [4]:
import random
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def top_k_logits(logits, k):
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None, actions=None, rtgs=None, timesteps=None):
    """
    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Clearly the sampling
    has quadratic complexity unlike an RNN that is only linear, and has a finite context window
    of block_size, unlike an RNN that has an infinite context window.
    """
    block_size = model.get_block_size()
    model.eval()
    for k in range(steps):
        # x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
        x_cond = x if x.size(1) <= block_size//3 else x[:, -block_size//3:] # crop context if needed
        if actions is not None:
            actions = actions if actions.size(1) <= block_size//3 else actions[:, -block_size//3:] # crop context if needed
        rtgs = rtgs if rtgs.size(1) <= block_size//3 else rtgs[:, -block_size//3:] # crop context if needed
        logits, _ = model(x_cond, actions=actions, targets=None, rtgs=rtgs, timesteps=timesteps)
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        # x = torch.cat((x, ix), dim=1)
        x = ix

    return x


In [None]:
a=np.array([[1, 2, 3], [4, 5, 6]])
b=np.array([1, 2])
np.savez('/tmp/123.npz', a=a, b=b)
data = np.load('/tmp/123.npz')
data['a']
data['b']
data.close()