# Import libraries


In [49]:
import os
import sys
import random
import csv
from datetime import datetime
import pickle
import collections
import math
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from symmetrizer.nn.modules import BasisLinear
from symmetrizer.ops import GroupRepresentations
from symmetrizer.groups import MatrixRepresentation
from utils.symmetrizer_utils import create_inverted_pendulum_actor_representations, create_inverted_pendulum_qfunction_representations, actor_equivariance_mae, q_equivariance_mae

In [50]:
# Model-wide switch read by the classes below: True -> symmetrizer BasisLinear
# state/action embeddings and prediction heads; False -> plain nn.Linear layers.
use_emlp: bool = True

In [51]:
# # Define the BasisLinear layer for testing
# state_dim = 4  # Dimensionality of the input state
# h_dim = 128     # Dimensionality of the output representation
# # Test input: batch of states
# batch_size = 64  # Example batch size
# seq_len = 20    # Sequence length (context length)
# in_action_embedding = [
# torch.DoubleTensor(np.eye(4)), 
# torch.DoubleTensor(-1 * np.eye(4))
# ]
# in_group = GroupRepresentations(in_action_embedding, "ActionRepr")

# out_action_embedding =  [
#     torch.DoubleTensor(np.eye(seq_len*2)), 
#     torch.DoubleTensor(-1 * np.eye(seq_len*2))
# ]
# out_group = GroupRepresentations(out_action_embedding, "ActionRepr")

# repr_inter = MatrixRepresentation(in_group, out_group)

# basis_layer = BasisLinear(
#     channels_in=1,  # Input dimensionality
#     channels_out=h_dim,     # Output dimensionality
#     group=repr_inter,          # Group representation for equivariance
# )


# # Create the environment
# env = gym.make("InvertedPendulum-v4")

# # Reset the environment and get the initial state
# x, _ = env.reset()  # Unpack observation and metadata

# # Convert the observation to a PyTorch tensor and add a batch dimension
# t_x = torch.Tensor(x).unsqueeze(0).unsqueeze(0)  # Shape [1, state_dim]
# dummy_input = torch.randn(batch_size, seq_len ,state_dim)
# print("Initial observation:", dummy_input.shape)

# layer = nn.Linear(state_dim, h_dim)
# #output = basis_layer(dummy_input).permute(0, 2, 1) 
# output = layer(dummy_input)

# print("Output shape:", output.shape)      
# # proint the first 10 elements of the output
# # Print the first 10 elements of the output
# print("Output:", output[1, 1, : ].shape)  


# Decision Transformer model

In [None]:
class MaskedCausalAttention(nn.Module):
    """Multi-head causal self-attention over a (B, T, h_dim) token sequence.

    Token t may only attend to tokens 0..t (lower-triangular mask).

    Args:
        h_dim: token embedding width; split evenly across ``n_heads``.
        max_T: longest sequence length the causal mask is built for.
        n_heads: number of attention heads; must divide ``h_dim``.
        drop_p: dropout probability on the attention output and projection.
        context_len: accepted for interface compatibility; not used here.

    Raises:
        ValueError: if ``h_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, h_dim, max_T, n_heads, drop_p, context_len):
        super().__init__()

        if h_dim % n_heads != 0:
            raise ValueError(
                f"h_dim ({h_dim}) must be divisible by n_heads ({n_heads})"
            )

        self.n_heads = n_heads
        self.max_T = max_T

        # q/k/v/proj are plain float64 Linear layers regardless of `use_emlp`
        # (the original's two branches built identical layers).
        self.q_net = nn.Linear(h_dim, h_dim, dtype=torch.float64)
        self.k_net = nn.Linear(h_dim, h_dim, dtype=torch.float64)
        self.v_net = nn.Linear(h_dim, h_dim, dtype=torch.float64)

        self.proj_net = nn.Linear(h_dim, h_dim, dtype=torch.float64)

        self.att_drop = nn.Dropout(drop_p)
        self.proj_drop = nn.Dropout(drop_p)

        ones = torch.ones((max_T, max_T))
        mask = torch.tril(ones).view(1, 1, max_T, max_T)

        # A buffer moves with .to(device) and is saved in state_dict, but is
        # not a parameter, so the optimizer never updates it.
        self.register_buffer('mask', mask)

    def forward(self, x):
        """Apply causal attention; returns a tensor shaped like ``x``."""
        B, T, C = x.shape  # batch size, sequence length, h_dim

        N, D = self.n_heads, C // self.n_heads  # N = num heads, D = per-head dim

        # Linear layers act on the last (h_dim) axis, so x is used as-is;
        # rearrange q, k, v as (B, N, T, D).
        q = self.q_net(x).view(B, T, N, D).transpose(1, 2)
        k = self.k_net(x).view(B, T, N, D).transpose(1, 2)
        v = self.v_net(x).view(B, T, N, D).transpose(1, 2)

        # weights (B, N, T, T)
        weights = q @ k.transpose(2, 3) / math.sqrt(D)
        # causal mask: sequences shorter than max_T use its top-left corner
        weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))
        # masked entries (-inf) become 0 after softmax; the diagonal is never
        # masked, so no row is entirely -inf
        normalized_weights = F.softmax(weights, dim=-1)

        # attention (B, N, T, D)
        attention = self.att_drop(normalized_weights @ v)

        # gather heads and project (B, N, T, D) -> (B, T, N*D)
        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)

        out = self.proj_drop(self.proj_net(attention))
        return out


class Block(nn.Module):
    """One post-norm transformer layer.

    Causal self-attention and a 4x-wide GELU MLP, each wrapped as
    ``LayerNorm(x + sublayer(x))``. All weights are float64.
    """

    def __init__(self, h_dim, max_T, n_heads, drop_p, context_len):
        super().__init__()
        hidden_width = 4 * h_dim

        self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p, context_len)
        self.mlp = nn.Sequential(
            nn.Linear(h_dim, hidden_width, dtype=torch.float64),
            nn.GELU(),
            nn.Linear(hidden_width, h_dim, dtype=torch.float64),
            nn.Dropout(drop_p),
        )
        self.ln1 = nn.LayerNorm(h_dim, dtype=torch.float64)
        self.ln2 = nn.LayerNorm(h_dim, dtype=torch.float64)

    def forward(self, x):
        # residual attention, then normalize
        after_attention = self.ln1(x + self.attention(x))
        # residual MLP, then normalize
        return self.ln2(after_attention + self.mlp(after_attention))


class DecisionTransformer(nn.Module):
    """Decision Transformer over interleaved (return-to-go, state, action) tokens.

    Each timestep contributes three tokens ordered (r_t, s_t, a_t), so a window
    of ``context_len`` steps becomes ``3 * context_len`` tokens processed by a
    stack of causal ``Block``s.

    The module-level flag ``use_emlp`` is read both here and in ``forward``:

    * ``True``: state/action embeddings and prediction heads are symmetrizer
      ``BasisLinear`` layers built from the sign-flip group {I, -I}, and the
      action head ends in ``tanh``.
    * ``False``: those layers are plain ``nn.Linear``; no ``tanh``.

    Args:
        state_dim: size of one observation vector.
        act_dim: size of one action vector.
        n_blocks: number of transformer blocks.
        h_dim: token embedding width.
        context_len: timesteps per input window.
        n_heads: attention heads per block.
        drop_p: dropout probability inside the blocks.
        max_timestep: rows in the timestep embedding table.
    """

    @staticmethod
    def _sign_flip_group(n):
        """Symmetrizer representation of the group {I_n, -I_n} (float64)."""
        return GroupRepresentations(
            [torch.DoubleTensor(np.eye(n)), torch.DoubleTensor(-1 * np.eye(n))],
            "ActionRepr",
        )

    def __init__(self, state_dim, act_dim, n_blocks, h_dim, context_len,
                 n_heads, drop_p, max_timestep=4096):
        super().__init__()

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.h_dim = h_dim

        ## group representations (consumed only by the BasisLinear layers)
        action_group = self._sign_flip_group(act_dim)
        action_time_group = self._sign_flip_group(context_len)
        repr_in = MatrixRepresentation(action_group, action_time_group)
        repr_out = MatrixRepresentation(action_time_group, action_group)

        state_group = self._sign_flip_group(state_dim)
        state_time_group = self._sign_flip_group(context_len)
        repr_in_s = MatrixRepresentation(state_group, state_time_group)
        repr_out_s = MatrixRepresentation(state_time_group, state_group)

        ### transformer blocks
        input_seq_len = 3 * context_len
        blocks = [Block(h_dim, input_seq_len, n_heads, drop_p, context_len) for _ in range(n_blocks)]
        self.transformer = nn.Sequential(*blocks)

        ### projection heads (project to embedding)
        self.embed_ln = nn.LayerNorm(h_dim, dtype=torch.float64)
        # Kept so existing checkpoints/state_dicts still match; forward() does
        # not currently add time embeddings to the tokens.
        self.embed_timestep = nn.Embedding(max_timestep, h_dim)
        self.embed_rtg = torch.nn.Linear(1, h_dim)
        if use_emlp:
            self.embed_state = BasisLinear(1, h_dim, group=repr_in_s)
        else:
            self.embed_state = torch.nn.Linear(state_dim, h_dim)

        # continuous actions (for discrete actions one would use
        # nn.Embedding(act_dim, h_dim) and no tanh)
        if use_emlp:
            self.embed_action = BasisLinear(1, h_dim, group=repr_in)
            use_action_tanh = True
        else:
            self.embed_action = torch.nn.Linear(act_dim, h_dim)
            # NOTE(review): the emlp path squashes actions with tanh but this
            # path does not — confirm this asymmetry is intended.
            use_action_tanh = False

        ### prediction heads
        self.predict_rtg = torch.nn.Linear(h_dim, 1, dtype=torch.float64)

        if use_emlp:
            self.predict_state = BasisLinear(h_dim, 1, group=repr_out_s)
            action_head = [BasisLinear(h_dim, 1, group=repr_out)]
        else:
            self.predict_state = torch.nn.Linear(h_dim, state_dim, dtype=torch.float64)
            action_head = [nn.Linear(h_dim, act_dim, dtype=torch.float64)]
        if use_action_tanh:
            action_head.append(nn.Tanh())
        self.predict_action = nn.Sequential(*action_head)

    def forward(self, timesteps, states, actions, returns_to_go):
        """Return ``(state_preds, action_preds, return_preds)``.

        ``timesteps`` is accepted for interface compatibility but is currently
        unused (no time embedding is added to the tokens).
        """
        B, T, _ = states.shape

        state_embeddings = self.embed_state(states)
        action_embeddings = self.embed_action(actions)
        returns_embeddings = self.embed_rtg(returns_to_go)

        if use_emlp:
            # BasisLinear output is put channel-first; bring it to
            # (B, T, h_dim) so it stacks with the rtg tokens.
            state_embeddings = state_embeddings.permute(0, 2, 1)
            action_embeddings = action_embeddings.permute(0, 2, 1)

        # stack rtg, states and actions and reshape sequence as
        # (r1, s1, a1, r2, s2, a2 ...)
        h = torch.stack(
            (returns_embeddings, state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)

        # The embeddings above may be float32 (nn.Linear defaults) while the
        # transformer is float64; match embed_ln's dtype explicitly.
        h = self.embed_ln(h.to(self.embed_ln.weight.dtype))

        # transformer and prediction
        h = self.transformer(h)

        # get h reshaped such that its size = (B x 3 x T x h_dim) and
        # h[:, 0, t] is conditioned on r_0, s_0, a_0 ... r_t
        # h[:, 1, t] is conditioned on r_0, s_0, a_0 ... r_t, s_t
        # h[:, 2, t] is conditioned on r_0, s_0, a_0 ... r_t, s_t, a_t
        h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3)

        # get predictions
        return_preds = self.predict_rtg(h[:, 2])  # predict next rtg given r, s, a
        if use_emlp:
            # BasisLinear heads take channel-first input: (B, h_dim, T)
            state_preds = self.predict_state(h[:, 2].permute(0, 2, 1))
            action_preds = self.predict_action(h[:, 1].permute(0, 2, 1))
        else:
            # nn.Linear heads act on the last (h_dim) axis: (B, T, h_dim)
            state_preds = self.predict_state(h[:, 2])   # next state given r, s, a
            action_preds = self.predict_action(h[:, 1])  # action given r, s

        return state_preds, action_preds, return_preds

In [53]:
def discount_cumsum(x, gamma):
    """Discounted suffix sums along the first axis.

    out[t] = x[t] + gamma * x[t+1] + gamma**2 * x[t+2] + ...

    Args:
        x: array-like whose first axis is time (e.g. per-step rewards).
        gamma: discount factor; 1.0 gives plain (undiscounted) returns-to-go.

    Returns:
        Array shaped like ``x``. Floating/complex inputs keep their dtype;
        integer or bool inputs produce float64, so discounting is not
        truncated. An empty ``x`` yields an empty array.
    """
    x = np.asarray(x)
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    disc_cumsum = np.zeros_like(x, dtype=out_dtype)
    if x.shape[0] == 0:
        return disc_cumsum
    disc_cumsum[-1] = x[-1]
    for t in range(x.shape[0] - 2, -1, -1):
        disc_cumsum[t] = x[t] + gamma * disc_cumsum[t + 1]
    return disc_cumsum
class D4RLTrajectoryDataset(Dataset):
    """Fixed-length windows sampled from pickled D4RL-style trajectories.

    The pickle must hold a non-empty list of dicts with at least
    ``'observations'``, ``'actions'`` and ``'rewards'`` arrays (optionally
    ``'terminals'`` / ``'timeouts'``). On load, observations are z-normalized
    with statistics pooled over all trajectories, and ``'returns_to_go'``
    (undiscounted suffix sums of rewards divided by ``rtg_scale``) is added.

    Each item is ``(timesteps, states, actions, returns_to_go, traj_mask)``,
    all of length ``context_len``; trajectories shorter than that are
    zero-padded at the end and ``traj_mask`` is 0 on padded steps.
    """

    def __init__(self, dataset_path, context_len, rtg_scale):
        self.context_len = context_len

        # SECURITY: pickle.load can execute arbitrary code — only load
        # dataset files from a trusted source.
        with open(dataset_path, 'rb') as f:
            self.trajectories = pickle.load(f)

        if len(self.trajectories) == 0:
            raise ValueError(f"no trajectories found in {dataset_path!r}")

        for traj in self.trajectories:
            # Handle Gymnasium `timeouts`: fold them into terminals, creating
            # the terminals array if the file does not provide one.
            if 'timeouts' in traj:
                terminals = traj.get('terminals', np.zeros_like(traj['timeouts'], dtype=bool))
                traj['terminals'] = np.logical_or(terminals, traj['timeouts'])
            # calculate returns-to-go and rescale them
            traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale

        # used for input normalization
        states = np.concatenate([traj['observations'] for traj in self.trajectories], axis=0)
        self.state_mean = np.mean(states, axis=0)
        self.state_std = np.std(states, axis=0) + 1e-6  # avoid division by zero

        # normalize states
        for traj in self.trajectories:
            traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std

    def get_state_stats(self):
        """Return ``(state_mean, state_std)`` used to normalize observations."""
        return self.state_mean, self.state_std

    def __len__(self):
        return len(self.trajectories)

    @staticmethod
    def _zero_pad(tensor, length):
        """Append zero rows along dim 0 so ``tensor`` has ``length`` rows."""
        padding = torch.zeros([length - tensor.shape[0]] + list(tensor.shape[1:]),
                              dtype=tensor.dtype)
        return torch.cat([tensor, padding], dim=0)

    def __getitem__(self, idx):
        traj = self.trajectories[idx]
        traj_len = traj['observations'].shape[0]
        window_len = self.context_len

        if traj_len >= window_len:
            # sample a random window start (inclusive upper bound keeps the
            # window inside the trajectory)
            si = random.randint(0, traj_len - window_len)
            window = slice(si, si + window_len)

            states = torch.from_numpy(traj['observations'][window])
            actions = torch.from_numpy(traj['actions'][window])
            returns_to_go = torch.from_numpy(traj['returns_to_go'][window])
            timesteps = torch.arange(start=si, end=si + window_len, step=1)

            # all ones since no padding
            traj_mask = torch.ones(window_len, dtype=torch.long)

        else:
            padding_len = window_len - traj_len

            # padding with zeros at the end
            states = self._zero_pad(torch.from_numpy(traj['observations']), window_len)
            actions = self._zero_pad(torch.from_numpy(traj['actions']), window_len)
            returns_to_go = self._zero_pad(torch.from_numpy(traj['returns_to_go']), window_len)

            timesteps = torch.arange(start=0, end=window_len, step=1)

            traj_mask = torch.cat([torch.ones(traj_len, dtype=torch.long),
                                   torch.zeros(padding_len, dtype=torch.long)],
                                  dim=0)

        return timesteps, states, actions, returns_to_go, traj_mask

In [54]:
def reflect_states(states):
    """Return a new tensor equal to ``states`` with every entry negated.

    This is the sign-flip (-I) group element applied to the full state
    vector; the input tensor is not modified.
    """
    return torch.neg(states)

def reflect_actions(actions):
    """Return a new tensor equal to ``actions`` with every entry negated.

    Mirrors the state reflection; the input tensor is not modified.
    """
    return actions.mul(-1)

def test_equivariance(model, timesteps, states, actions, returns_to_go, traj_mask):
    """Measure how far ``model`` is from sign-flip equivariance on one batch.

    Runs the model on (states, actions) and on their reflections (via
    ``reflect_states`` / ``reflect_actions``), reflects the second set of
    predictions back, and compares. ``timesteps`` and ``returns_to_go`` are
    passed through unchanged.

    Args:
        model: module whose call returns ``(state_preds, action_preds, return_preds)``.
        timesteps, states, actions, returns_to_go: model inputs for one batch.
        traj_mask: accepted for interface compatibility but not used, so every
            prediction element (including any from padded steps) enters the means.

    Returns:
        ``(state_equivariance_loss, action_equivariance_loss)`` as Python floats
        (mean squared differences); 0.0 means exactly equivariant on this batch.

    The model is evaluated in eval mode (dropout off) and its previous
    train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    try:
        # Reflect states and actions
        reflected_states = reflect_states(states)
        reflected_actions = reflect_actions(actions)

        with torch.no_grad():
            # Predictions for the original inputs
            state_preds, action_preds, _ = model(
                timesteps=timesteps,
                states=states,
                actions=actions,
                returns_to_go=returns_to_go,
            )
            # Predictions for the reflected inputs
            reflected_state_preds, reflected_action_preds, _ = model(
                timesteps=timesteps,
                states=reflected_states,
                actions=reflected_actions,
                returns_to_go=returns_to_go,
            )
    finally:
        # Do not leave the caller's model stuck in eval mode.
        model.train(was_training)

    # Map the reflected predictions back; equivariance means they match.
    reflected_state_preds = reflect_states(reflected_state_preds)
    reflected_action_preds = reflect_actions(reflected_action_preds)

    state_equivariance_loss = torch.mean((state_preds - reflected_state_preds) ** 2).item()
    action_equivariance_loss = torch.mean((action_preds - reflected_action_preds) ** 2).item()

    return state_equivariance_loss, action_equivariance_loss

In [55]:
# ---- data ----------------------------------------------------------------
dataset = "stitched"            # dataset variant; part of the pickle file name below
rtg_scale = 1000                # scale to normalize returns to go

env_name = 'InvertedPendulum-v4'
rtg_target = 10000
env_d4rl_name = f'{env_name}-{dataset}'

# load data from this file
dataset_path = f'processed_data/{env_d4rl_name}.pkl'

# ---- evaluation ----------------------------------------------------------
max_eval_ep_len = 1000      # max len of one evaluation episode
num_eval_ep = 10            # num of evaluation episodes per iteration

# ---- optimisation --------------------------------------------------------
batch_size = 64             # training batch size
lr = 4e-4                   # learning rate
wt_decay = 1e-4             # weight decay
warmup_steps = 10000        # warmup steps for lr scheduler

# total updates = max_train_iters x num_updates_per_iter
max_train_iters = 500
num_updates_per_iter = 500

# ---- model ---------------------------------------------------------------
context_len = 5        # K in decision transformer
n_blocks = 2           # num of transformer blocks
embed_dim = 128        # embedding (hidden) dim of transformer
n_heads = 1            # num of transformer heads
dropout_p = 0.1        # dropout probability

# ---- reproducibility -----------------------------------------------------
# Seed the Python (dataset window sampling), NumPy and torch (parameter init,
# dropout) RNGs so that cells run after this one are reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [56]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make(env_name)

# Model sizes are taken from the environment's observation/action spaces.
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
model = DecisionTransformer(
    state_dim=state_dim,
    act_dim=act_dim,
    n_blocks=n_blocks,        # Number of transformer blocks
    h_dim=embed_dim,          # Hidden dimension of embeddings
    context_len=context_len,  # Context length
    n_heads=n_heads,          # Number of attention heads
    drop_p=dropout_p,         # Dropout probability
).to(device)

print("Model loaded successfully.")

traj_dataset = D4RLTrajectoryDataset(dataset_path, context_len, rtg_scale)

traj_data_loader = DataLoader(
    traj_dataset,
    batch_size=1,                       # one window is enough for this check
    shuffle=True,
    pin_memory=device.type == "cuda",   # pinning only helps host->GPU copies
    drop_last=True,
)

# Get a batch of test data
data_iter = iter(traj_data_loader)
timesteps, states, actions, returns_to_go, traj_mask = next(data_iter)

# Move data to device; rtg gets a trailing feature axis for embed_rtg
timesteps = timesteps.to(device)
states = states.to(device)
actions = actions.to(device)
returns_to_go = returns_to_go.to(device).unsqueeze(dim=-1)
traj_mask = traj_mask.to(device)

# Test equivariance
state_eq_loss, action_eq_loss = test_equivariance(
    model, timesteps, states, actions, returns_to_go, traj_mask
)

print(f"State Equivariance Loss: {state_eq_loss}")
print(f"Action Equivariance Loss: {action_eq_loss}")


Model loaded successfully.


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x15 and 128x128)

In [None]:
# Draw a fresh batch from the loader and repeat the equivariance check.
data_iter = iter(traj_data_loader)
timesteps, states, actions, returns_to_go, traj_mask = (
    tensor.to(device) for tensor in next(data_iter)
)
# embed_rtg expects a trailing feature axis of size 1
returns_to_go = returns_to_go.unsqueeze(dim=-1)

state_eq_loss, action_eq_loss = test_equivariance(
    model, timesteps, states, actions, returns_to_go, traj_mask
)

print(f"State Equivariance Loss: {state_eq_loss}")
print(f"Action Equivariance Loss: {action_eq_loss}")

State Equivariance Loss: 2.5476254518182375
Action Equivariance Loss: 1.061262476728831e-24
