# Test Env + Model

In [6]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append('../../')

import math
from typing import List, Tuple, Optional, NamedTuple, Dict, Union, Any
from einops import rearrange, repeat
from hydra.utils import instantiate

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.nn import DataParallel
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import lightning as L

from torchrl.envs import EnvBase
from torchrl.envs.utils import step_mdp
from tensordict import TensorDict

from rl4co.data.dataset import TensorDictDataset, TensorDictCollate
from rl4co.models.rl.reinforce import *
from rl4co.models.zoo.am.context import env_context
from rl4co.models.zoo.am.embeddings import env_init_embedding, env_dynamic_embedding
from rl4co.models.zoo.am.encoder import GraphAttentionEncoder
from rl4co.models.zoo.am.decoder import Decoder, decode_probs, PrecomputedCache, LogitAttention
from rl4co.models.zoo.am.policy import get_log_likelihood
from rl4co.models.zoo.am import AttentionModel, AttentionModelPolicy
from rl4co.models.nn.attention import NativeFlashMHA, flash_attn_wrapper
from rl4co.utils.lightning import get_lightning_device

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## New implementation

In [14]:
from rl4co.envs import TSPEnv

In [19]:
def get_env(env_name, size):
    if env_name == "tsp":
        env = TSPEnv(num_loc=size)
    else:
        raise NotImplementedError

    return env.transform()


def generate_env_data(env, size):
    env = get_env(env, size)
    dataset = env.dataset([2])

    dataloader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=False,
        num_workers=0,
        collate_fn=TensorDictCollate(),
    )

    return env, next(iter(dataloader))


env, batch = generate_env_data("tsp", 50)

AttributeError: 'NoneType' object has no attribute 'dataset'

In [None]:
env

In [2]:
a = torch.rand(10000, 50, 2)

In [3]:
data = TensorDict({'a': a}, batch_size=a.shape[0])

dataset = TensorDictDataset(data)
dl = DataLoader(dataset, batch_size=512, shuffle=False, num_workers=0, collate_fn=TensorDictCollate())#, collate_fn=torch.stack)


batch = next(iter(dl))
print(batch['a'].shape)

torch.Size([512, 50, 2])


In [4]:
%timeit for batch in dl: pass

2.21 ms ± 2.85 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Old implementation

In [5]:
class TensorDictDataset2(Dataset):
    """Simple dataset compatible with TensorDicts"""

    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]  # note: use torch.stack to get batch

In [6]:
data = TensorDict({'a': a}, batch_size=a.shape[0])

dataset = TensorDictDataset2(data)
dl = DataLoader(dataset, batch_size=512, shuffle=False, num_workers=0, collate_fn=torch.stack)# collate_fn=TensorDictCollate())#, collate_fn=torch.stack)


batch = next(iter(dl))
print(batch['a'].shape)

torch.Size([512, 50, 2])


In [7]:
%timeit for batch in dl: pass

65.1 ms ± 170 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
data = TensorDict({'a': a}, batch_size=a.shape[0])#.memmap_()

# split the batch inside the dictionary as list of dictionaries
data = [{key: value[i] for key, value in data.items()} for i in range(data.shape[0])]


dataset = TensorDictDataset2(data)

dl = DataLoader(dataset, batch_size=512, shuffle=False, num_workers=0)#, collate_fn=#custom_collate)


batch = next(iter(dl))
print(batch['a'].shape)

torch.Size([512, 50, 2])


In [10]:
%timeit for batch in dl: pass

2.11 ms ± 3.49 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
