 #### Import necessary dependencies

In [None]:
import os

# Select the "egl" backend for MuJoCo rendering. This is set before the
# simulation-related imports below in case the backend is chosen at import
# time -- NOTE(review): confirm against the installed mujoco version.
os.environ["MUJOCO_GL"] = "egl"

import gymnasium as gym
import gymnasium_robotics
import numpy as np
import torch
from just_d4rl import d4rl_offline_dataset
from torch.utils.data import Dataset, DataLoader
from torchesn.nn import ESN
from tqdm import tqdm

#### Create environment and train model

In [None]:
# Load the offline dataset; later cells read its 'observations' and
# 'actions' entries.
dataset_name = "hammer-expert-v1"
data = d4rl_offline_dataset(dataset_name)

In [None]:
# Convert the dataset's observation and action arrays to float32 tensors,
# dropping the final row of each (the `[:-1, :]` slice).
observations = torch.tensor(data['observations'], dtype=torch.float32)[:-1, :]
actions = torch.tensor(data['actions'], dtype=torch.float32)[:-1, :]

# Sanity-check the resulting shapes.
for tensor in (observations, actions):
    print(tensor.shape)

In [None]:
seq_len = 20  # number of consecutive timesteps in each window fed into the ESN


class SequenceRLDataset(Dataset):
    """Sliding-window dataset over aligned observation/action sequences.

    Item ``idx`` is the pair ``(obs[idx:idx + seq_len], acts[idx:idx + seq_len])``,
    i.e. a window of observations together with the actions recorded at the
    same timesteps.

    Args:
        obs: Indexable sequence of observations; sliced along its first axis.
        acts: Indexable sequence of actions aligned with ``obs``.
        seq_len: Number of timesteps per window.
        max_samples: Only the first ``max_samples`` rows of ``obs`` and
            ``acts`` are kept. ``None`` keeps everything. Defaults to 25000.
    """

    def __init__(self, obs, acts, seq_len, max_samples=25000):
        # Slicing with ``None`` as the stop keeps the full sequence.
        self.obs = obs[:max_samples]
        self.acts = acts[:max_samples]
        self.seq_len = seq_len

    def __len__(self):
        # A sequence of length N holds N - seq_len + 1 full windows; clamp at
        # zero so data shorter than one window yields an empty dataset
        # instead of a negative length.
        return max(0, len(self.obs) - self.seq_len + 1)

    def __getitem__(self, idx):
        # Slices never raise on their own, so reject out-of-range indices
        # explicitly; this also lets plain iteration over the dataset stop.
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range for {len(self)} windows")
        window_end = idx + self.seq_len
        obs_seq = self.obs[idx:window_end]         # seq_len rows of observations
        target_action = self.acts[idx:window_end]  # actions at the same timesteps
        return obs_seq, target_action

# Wrap the offline tensors in the windowed dataset and serve it in shuffled
# mini-batches of 256 windows.
offline_dataset = SequenceRLDataset(obs=observations, acts=actions, seq_len=seq_len)
dataloader = DataLoader(
    dataset=offline_dataset,
    batch_size=256,
    shuffle=True,
)


#### Gradient Descent Solver

In [None]:
# ESN whose readout is trained with gradient descent ('gd').
# Input/output widths come from the second axis of the offline tensors.
device = 'cpu'
hidden_size = 64
input_size = observations.shape[1]
output_size = actions.shape[1]

esn = ESN(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    readout_training='gd',
    nonlinearity="tanh",
    batch_first=True,
    output_steps='all',
    w_io=False,
)
esn = esn.to(device)

In [None]:
# Train the ESN with Adam on a Huber loss between predicted and recorded
# actions, keeping every per-batch loss for later inspection.
optimizer = torch.optim.Adam(esn.parameters(), lr=0.001)
loss_fn = torch.nn.HuberLoss()
epochs = 10
esn.train()
history = {}  # epoch index -> list of per-batch loss values

for epoch in range(epochs):
    history[epoch] = []
    print(f"Epoch:{epoch}")
    for x_batch, y_batch in tqdm(dataloader):
        # Keep the batch on the same device as the model, as the
        # closed-form training loop in this file does.
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        # One washout entry per sequence in the batch; 0 discards no steps.
        washout_batch = [0] * x_batch.shape[0]
        output, _ = esn(x_batch, washout_batch)

        # Collapse all leading axes so each row is a single prediction and
        # lines up with the matching target row.
        # Assumes output is (batch, seq_len, action_dim) -- TODO confirm.
        output = output.reshape(-1, output.shape[-1])
        y_batch = y_batch.reshape(-1, y_batch.shape[-1])

        loss = loss_fn(output, y_batch)
        loss.backward()
        optimizer.step()

        history[epoch].append(loss.item())
    print(f"avg loss:{np.mean(history[epoch])}\n")
       

#### Inv Solver (closed-form solver)

In [None]:
# Fresh ESN whose readout is fitted with the 'inv' solver instead of
# gradient descent. Input/output widths come from the offline tensors.
device = 'cpu'
hidden_size = 64
input_size = observations.shape[1]
output_size = actions.shape[1]

esn = ESN(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    readout_training='inv',
    nonlinearity="tanh",
    batch_first=True,
    output_steps='all',
    w_io=False,
)
esn = esn.to(device)

In [None]:
# Closed-form readout training: pass every batch through the ESN together
# with its targets, then call fit() once at the end.
for x_batch, y_batch in tqdm(dataloader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)
    # Flatten all leading axes so each target row is one action vector.
    y_batch = y_batch.reshape(-1, y_batch.shape[-1])

    # One washout entry per sequence in the batch; 0 discards no steps.
    washout_batch = [0 for _ in range(x_batch.shape[0])]
    esn(x_batch, washout_batch, target=y_batch)

# Compute the readout weights from everything passed in above.
esn.fit()

#### Validate model by having it interact with a live MuJoCo environment

In [None]:
from gymnasium.wrappers import RecordVideo

# Roll the trained ESN out in the live environment for one episode,
# recording video into the current directory.
env_name = "AdroitHandHammer-v1"
env = gym.make(env_name, render_mode="rgb_array")
env = RecordVideo(env, "./")

episode_reward = 0
max_ep_timesteps = 2000
hidden = None  # reservoir state carried across steps; None on the first call

try:
    env_data = env.reset()
    obs = env_data[0]

    for t in range(max_ep_timesteps):
        print(f"timestep: {t}")
        # Shape the observation as a batch of one sequence of one step.
        obs = torch.as_tensor(obs, dtype=torch.float32).reshape(1, 1, -1)
        action, hidden = esn(obs, washout=[0], h_0=hidden)
        action = action.detach().numpy().flatten()

        env_data = env.step(action)
        obs, reward, terminated, truncated = env_data[:4]
        # The episode is over on either signal; checking only `terminated`
        # would keep stepping an environment that has hit its time limit.
        done = terminated or truncated

        episode_reward += reward

        print(f"action: {action}")
        print(f"episode reward: {episode_reward}")
        print(env_data[1:])

        if done:
            break
finally:
    # Always close the environment, even if the rollout raises.
    env.close()