In [1]:
%load_ext autoreload  
%autoreload 2  
!hostname  
!pwd  
import os
import sys

# Show which Python interpreter this kernel is running (printed in the cell output).
print(sys.executable)
# Limit the GPUs visible to this process to index 7. This assignment happens
# before `import torch` in a later cell.
# NOTE(review): presumably this is why the later cells can address the GPU as
# 'cuda:0' — confirm on the target machine; the index is hard-coded.
os.environ['CUDA_VISIBLE_DEVICES'] = "7"

oliva-titanrtx-2.csail.mit.edu
/data/vision/phillipi/akumar01/synthetic-mdps/src
/data/vision/phillipi/akumar01/.virtualenvs/smdps-mujoco/bin/python


In [2]:
import os, sys, glob, pickle
from functools import partial  

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np

# from einops import rearrange, reduce, repeat

In [3]:
import torch
import torch.nn as nn

In [7]:
from torch.distributions.normal import Normal
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise a layer's parameters in place and return the same layer.

    The weight is filled with an orthogonal matrix scaled by ``std`` and the
    bias is filled with ``bias_const``.

    Args:
        layer: module exposing ``weight`` and ``bias`` tensors
            (``nn.Linear`` everywhere it is used in this notebook).
        std: gain handed to the orthogonal initialiser (default ``sqrt(2)``).
        bias_const: constant written into every bias entry.

    Returns:
        The very same ``layer`` object, so the call can be nested directly
        inside ``nn.Sequential(...)``.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer

class Agent(nn.Module):
    """Gaussian actor-critic MLP with an RPO-style perturbation of the action mean.

    Two independent tanh MLPs (64-64 hidden units) map a flat observation to
    a scalar value (``critic``) and to the mean of a diagonal Normal policy
    (``actor_mean``). The log standard deviation is a single learned,
    state-independent parameter (``actor_logstd``).

    Args:
        envs: object exposing ``single_observation_space.shape`` and
            ``single_action_space.shape`` (a gymnasium vector env in this
            notebook).
        rpo_alpha: half-width of the uniform noise added to the action mean
            when an ``action`` is supplied to ``get_action_and_value``.
    """

    def __init__(self, envs, rpo_alpha):
        super().__init__()
        self.rpo_alpha = rpo_alpha
        # Cast to plain Python ints so nn.Linear never receives numpy scalars.
        obs_dim = int(np.prod(envs.single_observation_space.shape))
        act_dim = int(np.prod(envs.single_action_space.shape))
        # Construction order (critic first, then actor) is kept so that the
        # random initialisation stream and state_dict keys are unchanged.
        self.critic = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, act_dim), std=0.01),
        )
        self.actor_logstd = nn.Parameter(torch.zeros(1, act_dim))

    def get_value(self, x):
        """Return the critic output for a batch of observations ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Evaluate the policy and the critic on a batch of observations.

        Args:
            x: batch of observations fed to both networks.
            action: optional batch of actions. When ``None`` an action is
                sampled from the policy. When given, the action mean is first
                shifted by uniform noise in ``[-rpo_alpha, rpo_alpha]`` and
                the supplied action is scored under that perturbed policy.

        Returns:
            Tuple ``(action, log_prob, entropy, value, action_mean)`` where
            ``log_prob`` and ``entropy`` are summed over dimension 1 and
            ``action_mean`` is the (possibly perturbed) mean that was used.
        """
        action_mean = self.actor_mean(x)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        probs = Normal(action_mean, action_std)
        if action is None:
            action = probs.sample()
        else:  # new to RPO
            # Sample the perturbation directly on action_mean's device/dtype.
            # This removes the dependency on a module-level `device` variable
            # (which may not exist yet, or may differ from where the model
            # lives) and avoids a CPU allocation followed by a copy.
            z = torch.empty_like(action_mean).uniform_(-self.rpo_alpha, self.rpo_alpha)
            action_mean = action_mean + z
            probs = Normal(action_mean, action_std)

        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x), action_mean


In [48]:
# Load the pickled dataset from disk; it is indexed below by the keys
# 'obs' and 'act_mean'.
# NOTE: pickle.load can execute arbitrary code — only open trusted files.
with open("/data/vision/phillipi/akumar01/synthetic-mdps-data/datasets/mujoco/HalfCheetah/dataset.pkl", "rb") as f:
    dataset = pickle.load(f)
# Feature widths are read off the trailing axis of each array.
d_obs = dataset['obs'].shape[-1]
d_act = dataset['act_mean'].shape[-1]

In [49]:
def sample_batch(dataset, batch_size):
    """Draw ``batch_size`` random entries from every array in ``dataset``.

    Two index vectors are drawn uniformly with ``torch.randint``: one over
    axis 0 and one over axis 1 of ``dataset['obs']``. Every value in the
    mapping is then indexed with that same pair of index vectors.

    Args:
        dataset: mapping from key to array; must contain ``'obs'``.
            Presumably axis 0 is the trajectory and axis 1 the timestep —
            TODO confirm against how the dataset was written.
        batch_size: number of (axis-0, axis-1) pairs to draw.

    Returns:
        dict with the same keys as ``dataset``, each holding the selected
        ``batch_size`` entries.
    """
    obs_shape = dataset['obs'].shape
    # Keep this draw order (axis 0 first) so seeded runs sample the same pairs.
    row_idx = torch.randint(0, obs_shape[0], (batch_size,))
    col_idx = torch.randint(0, obs_shape[1], (batch_size,))
    batch = {}
    for key, values in dataset.items():
        batch[key] = values[row_idx, col_idx]
    return batch

In [112]:
# NOTE(review): this cell is written against JAX, but `jax` and `split` are
# not imported or defined anywhere in the visible part of this notebook, and
# its execution count (112) is out of sequence with the surrounding cells.
# It looks like a leftover from an earlier JAX version of the experiment and
# will fail on Restart & Run All — confirm whether it should be removed.
rng = jax.random.PRNGKey(0)
# NOTE(review): the `Agent` class visible in this notebook is a torch module
# whose constructor takes (envs, rpo_alpha) and which defines no `init`
# method, so this call and `agent.init` below do not match it.
agent = Agent(d_obs, d_act)
# NOTE(review): the visible `sample_batch(dataset, batch_size)` accepts two
# arguments; this three-argument call does not match that definition.
batch = sample_batch(rng, dataset, 1)

# Split off a fresh key and initialise parameters from the first element of
# the sampled observations.
rng, _rng = split(rng)
agent_params = agent.init(_rng, jax.tree_map(lambda x: x[0], batch['obs']))

# Forward pass mapped over axis 0 of the second argument (the inputs) with
# the first argument (the parameters) shared, then jit-compiled.
agent_forward = jax.jit(jax.vmap(agent.apply, in_axes=(None, 0)))

def iter_step(state, batch):
    """Take one gradient step on a mean-squared-error regression loss.

    The loss is the mean squared difference between
    ``agent_forward(params, batch['obs'])`` and ``batch['act_mean']``.

    Args:
        state: object exposing ``.params`` and
            ``.apply_gradients(grads=...)``.
        batch: mapping with ``'obs'`` and ``'act_mean'`` entries.

    Returns:
        Tuple ``(new_state, loss)``: the result of ``apply_gradients`` and the
        loss value computed at the parameters held by ``state``.
    """
    def mse_loss(params):
        predicted = agent_forward(params, batch['obs'])
        residual = predicted - batch['act_mean']
        return jnp.mean(jnp.square(residual))

    loss, grad = jax.value_and_grad(mse_loss)(state.params)
    updated_state = state.apply_gradients(grads=grad)
    return updated_state, loss

# NOTE(review): `optax` and `TrainState` are not imported in the visible part
# of this notebook; this block appears to belong to an earlier JAX version of
# the experiment and will fail on a fresh kernel — confirm before keeping it.
# Optimiser chain: clip gradients by global norm (1.0), then AdamW with
# learning rate 3e-4, zero weight decay and eps 1e-8.
tx = optax.chain(optax.clip_by_global_norm(1.),
                 optax.adamw(3e-4, weight_decay=0., eps=1e-8))
train_state = TrainState.create(apply_fn=agent.apply, params=agent_params, tx=tx)

# Run 2000 update steps on batches of 32, showing the latest loss on the bar.
pbar = tqdm(range(2000))
for i in pbar:
    # NOTE(review): `rng` is never reassigned inside this loop, so the same
    # key is passed on every iteration. Also, the `sample_batch` visible in
    # this notebook takes (dataset, batch_size), not three arguments.
    batch = sample_batch(rng, dataset, 32)
    train_state, loss = iter_step(train_state, batch)
    pbar.set_postfix({'loss': loss})

  0%|          | 0/2000 [00:00<?, ?it/s]

In [39]:
import gymnasium as gym
def make_env(env_id):
    """Return a zero-argument factory that builds one wrapped environment.

    Nothing is constructed until the returned callable is invoked, which is
    the form expected by the vector-env constructor used in this notebook.

    Args:
        env_id: environment id passed to ``gym.make``.

    Returns:
        A callable that creates the environment and wraps it, in order, with
        ``FlattenObservation``, ``RecordEpisodeStatistics`` and ``ClipAction``.
    """
    def build():
        env = gym.make(env_id)
        wrappers = (
            gym.wrappers.FlattenObservation,  # deal with dm_control's Dict observation space
            gym.wrappers.RecordEpisodeStatistics,
            gym.wrappers.ClipAction,
        )
        for wrap in wrappers:
            env = wrap(env)
        return env

    return build


In [40]:
# 64 HalfCheetah-v4 copies stepped synchronously in this process; each list
# entry is a factory, not an already-built environment.
envs = gym.vector.SyncVectorEnv(
    [make_env("HalfCheetah-v4") for _ in range(64)]
)

In [41]:
# Load the pretrained agent and the observation statistics saved next to it.
device = 'cuda:0'
agent = Agent(envs, 0.).to(device)
load_dir = '/data/vision/phillipi/akumar01/synthetic-mdps-data/datasets/mujoco/HalfCheetah'
# map_location remaps the checkpoint's tensors onto `device`, so loading also
# works when the file was saved from a different GPU index than the one
# visible to this process.
agent.load_state_dict(torch.load(f"{load_dir}/model.pth", map_location=device))
# NOTE: pickle.load can execute arbitrary code — only open trusted files.
with open(f"{load_dir}/env_obs_rms.pkl", "rb") as f:
    env_obs_rms = pickle.load(f)
# Mean / variance used to normalise observations before they are fed to `agent`.
obs_mean = torch.as_tensor(env_obs_rms["mean"], dtype=torch.float32, device=device)
obs_var = torch.as_tensor(env_obs_rms["var"], dtype=torch.float32, device=device)


In [42]:
# Roll out the loaded agent, acting with the policy mean (no sampling noise),
# and collect (episode return, episode length) for every finished episode.
stats = []
obs, info = envs.reset()
# NOTE(review): 1005 vector steps — presumably just over one full episode per
# environment so each copy reports at least one finished episode; confirm.
for i in tqdm(range(1005)):
    # Pure inference: disable autograd so no graph is built on every step.
    with torch.no_grad():
        # Normalise with the saved running statistics before querying the agent.
        obs_agent = (torch.tensor(obs, dtype=torch.float32).to(device) - obs_mean) / torch.sqrt(obs_var + 1e-8)
        act, _, _, _, act_mean = agent.get_action_and_value(obs_agent)
    # The sampled action is discarded; the environment is stepped with the mean.
    act = act_mean.cpu().numpy()
    obs, rew, term, trunc, infos = envs.step(act)
    if "final_info" in infos:
        for info in infos["final_info"]:
            if info and "episode" in info:
                stats.append((info["episode"]["r"], info["episode"]["l"]))

  0%|          | 0/1005 [00:00<?, ?it/s]

In [43]:
# Average of column 0 of the collected (return, length) pairs, i.e. the mean
# episode return; left as the last expression so the cell displays it.
np.asarray(stats)[:, 0].mean()

5397.595832824707

In [59]:
# Behaviour cloning: regress a freshly initialised agent's action mean onto
# the dataset's 'act_mean' targets with a mean-squared-error loss.
agent2 = Agent(envs, 0.).to(device)
opt = torch.optim.Adam(agent2.parameters(), lr=3e-4)

pbar = tqdm(range(50000))
for i in pbar:
    batch = sample_batch(dataset, 32)
    # Inputs are the raw dataset observations; targets are the stored action means.
    x = torch.tensor(batch['obs'], dtype=torch.float32).to(device)
    y = torch.tensor(batch['act_mean'], dtype=torch.float32).to(device)

    # Only the last element of the returned tuple (the action mean) is trained on.
    act_mean = agent2.get_action_and_value(x)[-1]
    loss = (act_mean - y).pow(2).mean()

    opt.zero_grad()
    loss.backward()
    opt.step()

    pbar.set_postfix(loss=loss.item())
    


  0%|          | 0/50000 [00:00<?, ?it/s]

In [60]:
# Roll out the cloned agent (`agent2`), acting with the policy mean, and
# collect (episode return, episode length) for every finished episode.
stats = []
obs, info = envs.reset()
# NOTE(review): 1005 vector steps — presumably just over one full episode per
# environment so each copy reports at least one finished episode; confirm.
for i in tqdm(range(1005)):
    # Pure inference: disable autograd so no graph is built on every step.
    with torch.no_grad():
        # agent2 was fitted on the dataset's observations as stored, so the
        # observations are passed through without the obs_mean/obs_var scaling.
        obs_agent = torch.tensor(obs, dtype=torch.float32).to(device)
        act, _, _, _, act_mean = agent2.get_action_and_value(obs_agent)
    # The sampled action is discarded; the environment is stepped with the mean.
    act = act_mean.cpu().numpy()
    obs, rew, term, trunc, infos = envs.step(act)
    if "final_info" in infos:
        for info in infos["final_info"]:
            if info and "episode" in info:
                stats.append((info["episode"]["r"], info["episode"]["l"]))

  0%|          | 0/1005 [00:00<?, ?it/s]

In [61]:
# Average of column 0 of the collected (return, length) pairs, i.e. the mean
# episode return; left as the last expression so the cell displays it.
np.asarray(stats)[:, 0].mean()

5271.7492961883545