In [1]:
import torch
import torch.nn as nn
# import torch_geometric.nn as pyg_nn
# from prompt_dt.graph_prompt_decision_transformer import *

In [2]:
hidden_size = 128
num_heads = 1
dropout = 1e-5
n_layers = 3
device = 'cuda:1'

## PyG Graph Transformer

In [3]:
pyg_graph_transformer = nn.ModuleList([Block(hidden_size, hidden_size, num_heads=num_heads, dropout=dropout, layer_norm=True, batch_norm=False) for _ in range(n_layers)])
pyg_graph_transformer.to(device)
x = torch.load('ant_sample_x.pt')
edge = torch.load('ant_sample_edge.pt')
print(f'x shape: {x.shape}')
print(f'edge shape: {edge.shape}')
for block in pyg_graph_transformer:
    x = block(x, edge_index=edge)

print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(1)/1024/1024/1024))
print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(1)/1024/1024/1024))
print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(1)/1024/1024/1024))

x shape: torch.Size([30375, 128])
edge shape: torch.Size([2, 399159])
torch.cuda.memory_allocated: 2.175523GB
torch.cuda.memory_reserved: 2.398438GB
torch.cuda.max_memory_reserved: 2.398438GB


## Torch Transformer with Mask

In [3]:
encoder_layer = nn.TransformerEncoderLayer(hidden_size, num_heads, dim_feedforward=hidden_size*2, dropout=0, layer_norm_eps=1e-05, batch_first=True)
torch_graph_transformer = nn.TransformerEncoder(encoder_layer, 3)
# torch_graph_transformer.to(device)
# x = torch.load('ant_sample_transformer_inputs.pt')
# print(f'x shape: {x.shape}')
# out = torch_graph_transformer(x)

# print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(1)/1024/1024/1024))
# print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(1)/1024/1024/1024))
# print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(1)/1024/1024/1024))

### Study bool tensor for transformer input mask

In [5]:
structural_mask = torch.ones(5, 5).to(torch.bool)
temporal_mask = torch.zeros(5, 5).to(torch.bool)

In [12]:
a = [structural_mask for i in range(10)]
a = torch.stack(a, dim=0)
a.shape

torch.Size([10, 5, 5])

In [18]:
attn_mask = [1, 1, 1]
num_node = 3
a = torch.triu(torch.ones(3, 3), diagonal=0)
a = a.nonzero()*num_node
b = torch.ones(9, 9).to(torch.bool)
temporal_mask = torch.ones(9, 9).to(torch.bool)
a

tensor([[0, 0],
        [0, 3],
        [0, 6],
        [3, 3],
        [3, 6],
        [6, 6]])

In [20]:
a.shape[0]

6

## Check ant_dir direction distribution, cheetah_vel velocity distribution

In [2]:
import pickle
import numpy as np

In [8]:
env_name = 'ant_dir'
tasks = []
for idx in range(50):
    task_paths = f"config/{env_name}/config_{env_name}_task{idx}.pkl"   
    with open(task_paths.format(idx), 'rb') as f:
        task_info = pickle.load(f)
        tasks.append(task_info[0]['goal'])

tasks = np.array(tasks)
tasks
# sorted_tasks = np.sort(tasks)
# np.diff(sorted_tasks)

array([1.20336059, 0.42663053, 4.94477548, 4.12386514, 4.00566193,
       3.61661964, 0.24543954, 2.24820918, 5.9419027 , 0.37727185,
       5.42893665, 5.51217894, 0.32165929, 4.09926706, 3.46675609,
       3.75428649, 3.03809995, 1.77806705, 1.87066586, 3.5280645 ,
       2.48843943, 4.95555271, 2.62941494, 0.904175  , 0.94817473,
       0.34709164, 4.51156075, 1.83668411, 1.24893307, 5.22361358,
       3.56879346, 0.5173559 , 3.42432703, 0.99876725, 4.25222348,
       0.74438319, 2.79599274, 5.57935818, 5.00938089, 0.42712295,
       6.03680751, 4.14190964, 4.51620333, 4.67203061, 5.577812  ,
       0.84160545, 4.88187527, 5.26525596, 5.21067551, 0.18319881])

In [10]:
rng = np.random.default_rng(seed=42)
tasks = 2 * np.pi * rng.random(50)
tasks

array([4.86290927, 2.75755456, 5.39472984, 4.38169255, 0.59173373,
       6.13001603, 4.78238179, 4.93898769, 0.80496169, 2.82985831,
       2.3297927 , 5.82303616, 4.04552386, 5.16956368, 2.78605358,
       1.427783  , 3.48455899, 0.40097565, 5.20016002, 3.96886447,
       4.76320575, 2.22755235, 6.09907556, 5.61164551, 4.89072775,
       1.22295107, 2.93249455, 0.27522718, 0.96942947, 4.29172315,
       4.67947864, 6.07904294, 2.0472211 , 2.32766698, 2.95030617,
       1.19048366, 0.81632089, 2.9889422 , 1.42571349, 4.20856545,
       2.74670651, 5.23187141, 4.3998954 , 1.96265749, 5.22924256,
       5.05648359, 2.43459846, 1.81161891, 4.28824572, 0.87809075])

### Attention mask for torch graph transformer

In [18]:
a = torch.ones(5)
print(int(a.sum()))
print(len(a))
len(a) - int(a.sum())


5
5


0

In [7]:
def get_temporal_mask_list(num_node, max_length, prompt_length):
    temporal_masks = []
    for num_zeros in range(max_length):
        temporal_edge = torch.tril(torch.ones(max_length + prompt_length, max_length + prompt_length), diagonal=0)
        temporal_edge[:, prompt_length:prompt_length+num_zeros] = 0
        temporal_edge[prompt_length:prompt_length+num_zeros, :] = 0
        temporal_edge = temporal_edge.nonzero()*num_node
        full_temporal_edge = torch.cat([temporal_edge+i for i in range(num_node)], dim=0)
        temporal_mask = torch.zeros(num_node * (max_length + prompt_length), num_node * (max_length + prompt_length))
        for edge in full_temporal_edge:
            temporal_mask[edge[0], edge[1]] = 1
        temporal_mask[:num_node * prompt_length, :num_node * prompt_length] = 0
        
        temporal_masks.append(temporal_mask)

    return temporal_masks

In [8]:
num_node = 3
max_length = 3
prompt_length = 2
a = get_temporal_mask_list(num_node, max_length, prompt_length)
for a_ in a:
    print(a_)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0.],
        [1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0.],
        [0., 0., 1.,

In [5]:
def get_agent_config(name):
    if name == 'halfcheetah':
        agent = {
            'state_dim': 17,
            'act_dim': 6,
            'num_node': 7,
            'node_type_state_action_len': {
                'root': 5,
                'thigh': 3,
                'shin': 3,
            },
            'node_type': {
                'root': 'root',
                'back_thigh': 'thigh',
                'back_shin': 'shin',
                'bach_foot': 'foot',
                'front_thigh': 'thigh',
                'front_shin': 'shin',
                'front_foot': 'foot',
            },
            'state_position': {
                'root': [0, 1, 8, 9, 10],
                'back_thigh': [2, 11],
                'back_shin': [3, 12],
                'bach_foot': [4, 13],
                'front_thigh': [5, 14],
                'front_shin': [6, 15],
                'front_foot': [7, 16],
            },
            'action_position': {
                'root': [],
                'back_thigh': [0],
                'back_shin': [1],
                'bach_foot': [2],
                'front_thigh': [3],
                'front_shin': [4],
                'front_foot': [5],
            },
            'edge': [[3, 2], [2, 1], [1, 0], [0, 4], [4, 5], [5, 6]],
        }
    elif name == 'ant': 
        agent = {
            'state_dim': 27,
            'act_dim': 8,
            'num_node': 9,
            'node_type_state_action_len': {
                'root': 11,
                'hip': 3,
                'ankle': 3,
            },
            'node_type': {
                'root': 'root',
                'hip_1': 'hip',
                'ankle_1': 'ankle',
                'hip_2': 'hip',
                'ankle_2': 'ankle',
                'hip_3': 'hip',
                'ankle_3': 'ankle',
                'hip_4': 'hip',
                'ankle_4': 'ankle',
            },
            'state_position': {
                'root': [0, 1, 2, 3, 4, 13, 14, 15, 16, 17, 18],
                'hip_1': [5, 19],
                'ankle_1': [6, 20],
                'hip_2': [7, 21],
                'ankle_2': [8, 22],
                'hip_3': [9, 23],
                'ankle_3': [10, 24],
                'hip_4': [11, 25],
                'ankle_4': [12, 26],
            },
            'action_position': {
                'root': [],
                'hip_1': [0],
                'ankle_1': [1],
                'hip_2': [2],
                'ankle_2': [3],
                'hip_3': [4],
                'ankle_3': [5],
                'hip_4': [6],
                'ankle_4': [7],
            },
            'edge': [[0, 1], [0, 3], [0, 5], [0, 7], [1, 2], [3, 4], [5, 6], [7, 8]],
        }
    else:
        raise NameError('agent name not valid, only "hopper", "halfcheetah" and "walker2d" are implemented')
    
    return agent

In [3]:
agent = get_agent_config('ant')
num_node = agent['num_node']
max_length = 20
prompt_length = 5
structural_mask = get_structural_mask(agent, max_length, prompt_length)
temporal_mask_list = get_temporal_mask_list(num_node, max_length, prompt_length)
mask_store = []
for temporal_mask in temporal_mask_list:
    mask = structural_mask + temporal_mask
    mask_store.append(mask)
mask_store = torch.stack(mask_store, dim=0)
print(mask_store.shape)

torch.Size([20, 225, 225])


In [12]:
a = np.array([1, 0, 1, 1, 0])
a = torch.from_numpy(a).bool()
b = np.array([0, 1, 1, 0, 0])
b = torch.from_numpy(b).bool()
~(~a + ~b)

tensor([False, False,  True, False, False])

### investigate NaN in action
mask is not right...

In [6]:
x = torch.tensor([1, 2, np.nan])
assert not torch.isnan(x).any()

AssertionError: 

### Check model size

In [1]:
import torch
from torchinfo import summary
from prompt_dt.prompt_decision_transformer import PromptDecisionTransformer
from prompt_dt.graph_prompt_decision_transformer import GPDT_V2_Torch
# from prompt_dt.prompt_utils import get_agent_config

In [6]:
b_size = 560
state_dim = 27
act_dim = 8
K = 20
p_len = 5
embed_dim = 128
n_layer = 3
n_head = 1
dropout = 0.1
agent_name = 'ant'
agent = get_agent_config(agent_name)

In [3]:
model = PromptDecisionTransformer(
    state_dim=state_dim,
    act_dim=act_dim,
    max_length=K,
    max_ep_len=1000,
    hidden_size=embed_dim,
    n_layer=n_layer,
    n_head=n_head,
    n_inner=4 * embed_dim,
    activation_function='relu',
    n_positions=1024,
    resid_pdrop=dropout,
    attn_pdrop=dropout,
)
                            # states, actions, rewards, returns_to_go, timesteps, attention_mask, prompt_states, prompt_actions, prompt_rewards, prompt_dones, prompt_returns_to_go, prompt_timesteps, prompt_attention_mask
# summary(model, 
#         input_size = [(b_size, K, state_dim), (b_size, K, act_dim), (b_size, K, 1), (b_size, K, 1), (b_size, K), (b_size, K), (b_size, p_len, state_dim), (b_size, p_len, act_dim), (b_size, p_len, 1), (b_size, p_len, 1), (b_size, p_len, 1), (b_size, p_len), (b_size, p_len)],
#         dtypes = [torch.float32, torch.float32, torch.float32, torch.float32, torch.int64, torch.float64, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.int64, torch.float64],
#         depth = 0,
#         device = 'cpu')
n_params = sum(p.numel() for p in model.parameters())
print(n_params)

866084


In [7]:
model = GPDT_V2_Torch(
    agent=agent,
    max_length=K,
    max_ep_length=1000,
    prompt_length=p_len,
    hidden_size=embed_dim,
    n_layers=n_layer,
    dropout=dropout,
)
# summary(model, 
#         input_size = [(64, 20, 17), (64, 21, 6), (64, 20, 1), (64, 20, 1), (64, 20), (64, 20)],
#         dtypes = [torch.float32, torch.float32, torch.float32, torch.float32, torch.int64, torch.float64],
#         depth = 0,
#         device = 'cpu')
n_params = sum(p.numel() for p in model.parameters())
print(n_params)

856962


### Torch 2.0 compile test

In [1]:
import torch

In [2]:
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))

tensor([[ 0.0167,  1.0197,  1.1520,  0.5055,  1.1966, -1.3498,  0.0839,  1.3925,
          1.4011,  1.3631],
        [ 1.3073,  0.7412,  1.4141, -0.8988, -0.3289,  1.3362,  1.0690,  0.8997,
          1.3311,  0.8283],
        [ 1.2202,  0.6734,  0.6878,  0.2398, -1.0635,  1.3986,  0.9566,  0.9573,
         -0.4506,  1.3072],
        [ 1.4010, -0.1403,  0.6883,  0.7119,  0.0699,  0.8176,  1.3714,  1.2906,
          1.4136, -0.3084],
        [ 1.3427,  0.9271, -1.1760, -0.8434,  1.1494,  0.1401,  1.3093, -0.0803,
         -0.5412,  1.0048],
        [ 1.3966, -0.1594,  0.1448,  0.5719,  1.3931,  1.4101,  1.3929,  0.4363,
         -0.3037,  0.4207],
        [-0.5789, -0.0328,  1.2149,  0.7540,  1.4142,  0.8863,  0.9898,  0.5243,
          1.2279, -1.1461],
        [ 0.1016,  1.3345,  0.8952,  0.3761,  0.4474,  1.3854,  0.3412, -1.0334,
          0.7773,  0.8762],
        [ 1.3877, -0.2859,  1.3510,  0.2604,  1.3313,  1.4089, -1.4055,  0.3332,
          1.3932,  0.4427],
        [ 1.3954,  

### test wandb (weight and biases)

In [1]:
import wandb
import random

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="my-awesome-project",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": 0.02,
    "architecture": "CNN",
    "dataset": "CIFAR-100",
    "epochs": 10,
    }
)

# simulate training
epochs = 10
offset = random.random() / 5
for epoch in range(2, epochs):
    acc = 1 - 2 ** -epoch - random.random() / epoch - offset
    loss = 2 ** -epoch + random.random() / epoch + offset
    
    # log metrics to wandb
    wandb.log({"acc": acc, "loss": loss})

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myihung[0m. Use [1m`wandb login --relogin`[0m to force relogin
