 #### Install necessary dependacies

In [1]:
import gymnasium as gym
import gymnasium_robotics
from just_d4rl import d4rl_offline_dataset
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchesn.nn import ESNimport gymnasium as gym
import gymnasium_robotics
from just_d4rl import d4rl_offline_dataset
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchesn.nn import ESN
from tqdm import tqdm
import numpy as np
os.environ["MUJOCO_GL"] = "egl"

#### Create environment and train model

In [2]:
data = d4rl_offline_dataset("hammer-expert-v1")

load datafile: 100%|████████████████████████████| 22/22 [00:04<00:00,  5.31it/s]


Dataset loaded and saved at: /home/credit-research2/.d4rl/datasets/hammer-expert-v1.hdf5


In [3]:
observations = data['observations']  
actions = data['actions']            

# Convert to PyTorch tensors
observations = torch.tensor(observations, dtype=torch.float32)[:-1, :]
actions = torch.tensor(actions, dtype=torch.float32)[:-1, :]

print(observations.shape)
print(actions.shape)

torch.Size([994999, 46])
torch.Size([994999, 26])


In [4]:
seq_len =20# number of past steps fed into ESN

class SequenceRLDataset(Dataset):
    
    def __init__(self, obs, acts, seq_len):
        self.obs = obs[:25000]
        self.acts = acts[:25000]
        self.seq_len = seq_len

    def __len__(self):
        return len(self.obs) - self.seq_len

    def __getitem__(self, idx):
        obs_seq = self.obs[idx:idx+self.seq_len]       # [seq_len, obs_dim]
        target_action = self.acts[idx:idx+self.seq_len]   # next action
        return obs_seq, target_action

offline_dataset = SequenceRLDataset(observations, actions, seq_len)
dataloader = DataLoader(offline_dataset, batch_size=256, shuffle=True)


#### Graident Descent Solver

In [13]:
input_size = observations.shape[1]
hidden_size = 64
output_size = actions.shape[1]
device = 'cpu'
esn = ESN(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    readout_training='gd', 
    nonlinearity = "tanh",
    batch_first = True,
    output_steps = 'all',
    w_io = False
    
).to(device)

In [14]:
optimizer = torch.optim.Adam(esn.parameters(), lr = 0.001)
loss_fn = torch.nn.HuberLoss()
epochs = 10
esn.train()
history = {}

for epoch in range(epochs):
    history[epoch] = []
    sum_loss = 0
    print(f"Epoch:{epoch}")
    for x_batch, y_batch in tqdm(dataloader):
        optimizer.zero_grad()
        washout_batch = [0]*x_batch.shape[0]
        output, _ = esn(x_batch, washout_batch)
    
         
        output = output.reshape(-1, output.shape[-1]) 
        y_batch = y_batch.reshape(-1, y_batch.shape[-1])

        loss = loss_fn(output, y_batch)
        loss.backward()
        
        optimizer.step()

        history[epoch].append(loss.item())
    print(f"avg loss:{np.mean(history[epoch])}\n")
       

Epoch:0


100%|██████████████████████████████████████████| 98/98 [00:00<00:00, 285.42it/s]


avg loss:0.09798277321518684

Epoch:1


100%|██████████████████████████████████████████| 98/98 [00:00<00:00, 327.39it/s]


avg loss:0.045422865762090196

Epoch:2


100%|██████████████████████████████████████████| 98/98 [00:00<00:00, 272.89it/s]


avg loss:0.04132978059351444

Epoch:3


100%|██████████████████████████████████████████| 98/98 [00:00<00:00, 172.33it/s]


avg loss:0.03958547187550944

Epoch:4


100%|██████████████████████████████████████████| 98/98 [00:00<00:00, 239.28it/s]


avg loss:0.03860433661968124

Epoch:5


100%|██████████████████████████████████████████| 98/98 [00:00<00:00, 166.22it/s]


avg loss:0.03794956606413637

Epoch:6


100%|██████████████████████████████████████████| 98/98 [00:00<00:00, 175.76it/s]


avg loss:0.03748039163801135

Epoch:7


100%|██████████████████████████████████████████| 98/98 [00:00<00:00, 170.49it/s]


avg loss:0.03712150393700113

Epoch:8


100%|██████████████████████████████████████████| 98/98 [00:00<00:00, 243.90it/s]


avg loss:0.0368328793924682

Epoch:9


100%|██████████████████████████████████████████| 98/98 [00:00<00:00, 195.72it/s]

avg loss:0.03660863387037297






#### Inv Solver(Closed form solver)

In [16]:
input_size = observations.shape[1]
hidden_size = 64
output_size = actions.shape[1]
device = 'cpu'
esn = ESN(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    readout_training='inv', 
    nonlinearity = "tanh",
    batch_first = True,
    output_steps = 'all',
    w_io = False
    
).to(device)

In [17]:
for x_batch, y_batch in tqdm(dataloader):
    x_batch = x_batch.to(device)  # [batch, seq_len, obs_dim]
    y_batch = y_batch.to(device).reshape(-1, y_batch.shape[-1])  # [batch, action_dim]

    washout_batch = [0]*x_batch.shape[0]  
    esn(x_batch, washout_batch, target=y_batch)  # accumulate stats for ridge regression
esn.fit()  # computes the linear readout weights

100%|██████████████████████████████████████████| 98/98 [00:00<00:00, 246.70it/s]


#### Valdiate model by having it interact with live MuJoCo environment

In [18]:
from gymnasium.wrappers import RecordVideo

env_name = "AdroitHandHammer-v1"
env = gym.make(env_name, render_mode = "rgb_array")
env = RecordVideo(env, "./")

env_data = env.reset()
obs = env_data[0]

episode_reward = 0
max_ep_timesteps = 2000
hidden = None

for t in range(max_ep_timesteps):
    print(f"timestep: {t}")
    obs = torch.tensor(obs).type(torch.float32).reshape(1, 1, -1)
    action, hidden  = esn(obs, washout = [0], h_0 = hidden)
    action = action.detach().numpy().flatten()
    env_data = env.step(action)
    obs = env_data[0]
    reward = env_data[1]
    done = env_data[2]

    episode_reward += reward

    print(f"action: {action}")
    print(f"episode reward: {episode_reward}")
    print(env_data[1:])

    if done: break
env.close()

timestep: 0
action: [ 0.30026206 -0.12186994  0.21562415  0.45231408  0.0800792  -0.11734063
 -0.46647924 -0.53990966  0.14862014 -0.81737274 -0.14479755 -0.09197253
 -0.18690357 -0.7024277  -0.07829814 -0.11201181 -0.42128995 -0.5108693
 -0.5314501  -0.33754835 -0.19214204 -0.4399683   0.28944486  0.7147401
 -0.63612854  0.549347  ]
episode reward: -1.300895768297768
(np.float64(-1.300895768297768), False, False, {'success': np.False_})
timestep: 1
action: [ 0.37497848  0.07417182  0.4912244   0.30155137  0.02892671  0.02585705
 -0.36581737 -0.2817188   0.15492336 -0.4861309  -0.15696114  0.15370737
 -0.20187818 -0.4604655   0.2996749   0.0403638  -0.5442527  -0.5178546
 -0.25780267 -0.26361528 -0.04743642 -0.02576072  0.16844304  0.6522112
 -0.68735904  0.44279844]
episode reward: -2.7169455242592777
(np.float64(-1.4160497559615095), False, False, {'success': np.False_})
timestep: 2
action: [ 0.33536646 -0.02893     0.33640307  0.22186397  0.06418876 -0.07726982
 -0.3533176  -0.37538