In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import numpy as np 
import time 
import numpy as np
import torch
import torch.nn as nn
import time
import gym 
import pickle  
import argparse
import torch.nn.functional as F
from torch.distributions import Normal
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import TanhTransform
from torch.optim import Adam
import tqdm

from torch.utils.data import DataLoader 

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [4]:
env_name = 'Ant-v3'  # gym environment id; matches the task named in the expert data file

In [5]:
# Expert demonstrations. The next cell unpacks each item as (states, actions, rewards).
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
data_path = "expert_data/Ant-v3_10_3765.pkl"

with open(data_path, "rb") as fh:
    data_good = pickle.load(fh)
print('expert data loaded')

expert data loaded


In [6]:
# Flatten the expert trajectories into aligned (state, action) arrays.
# Each trajectory is unpacked as (states, actions, rewards); rewards are not
# needed for behavior cloning.
good_obs = []
good_acts = []

for traj in data_good:
    s, a, r = traj
    # A single misaligned trajectory (e.g. one that also stores the terminal
    # observation) would shift every later state/action pair after vstack,
    # so fail loudly instead of training on mismatched pairs.
    assert len(s) == len(a), f"trajectory has {len(s)} states but {len(a)} actions"
    good_obs.append(s)
    good_acts.append(a)

good_obs = np.vstack(good_obs)
good_acts = np.vstack(good_acts)

In [7]:
# Sanity check: both arrays should have one row per expert transition.
(good_obs.shape, good_acts.shape)

((10000, 111), (10000, 8))

In [8]:
# Pair each expert state with its action.
# TensorDataset raises on a length mismatch, where zip() would silently
# truncate. Converting to float32 once up front avoids re-collating
# per-sample numpy arrays on every batch.
expert_dataset = torch.utils.data.TensorDataset(
    torch.as_tensor(good_obs, dtype=torch.float32),
    torch.as_tensor(good_acts, dtype=torch.float32),
)
data_loader = DataLoader(expert_dataset, batch_size=64, shuffle=True)

# Peek at one batch to confirm the shapes the model will see.
batch = next(iter(data_loader))
states, actions = batch
states.shape, actions.shape

(torch.Size([64, 111]), torch.Size([64, 8]))

In [9]:
# Network input/output widths, read off the trailing axis of the sampled batch.
state_dim, action_dim = states.shape[1], actions.shape[1]
print(state_dim, action_dim)

111 8


In [10]:
class MLP(nn.Module):
    """Two-layer ReLU feature extractor.

    Applies Linear(input_dim, size) -> ReLU -> Linear(size, size) -> ReLU to
    the last dimension of its input.
    """

    def __init__(self, input_dim, size=32):
        super().__init__()
        layers = [
            nn.Linear(input_dim, size),
            nn.ReLU(),
            nn.Linear(size, size),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class RegNet(MLP):
    """Regression head on top of the MLP trunk.

    Outputs an unbounded vector of width ``action_dim`` (no squashing is
    applied), trained here by MSE against expert actions.
    """

    def __init__(self, input_dim, size, action_dim):
        super().__init__(input_dim, size)
        # Linear head mapping the hidden features to the action vector.
        self.decoder = nn.Linear(size, action_dim)

    def forward(self, x):
        features = super().forward(x)
        return self.decoder(features)

In [11]:
# Behavior-cloning setup: regress expert actions from states under an MSE loss.
learning_rate = 1e-4

bc = RegNet(input_dim=state_dim, size=64, action_dim=action_dim)
criterion = nn.MSELoss()
optimizer = Adam(params=bc.parameters(), lr=learning_rate)

In [12]:
# Behavior-cloning training loop: minimise MSE between predicted and expert actions.
# Runs n_epoch + 1 passes (epochs 0..n_epoch inclusive) so the final epoch is logged.
loss_list = []  # per-batch training loss
test_loss = []  # not populated in the visible cells; kept for any downstream use
n_epoch = 1_000
log_every = max(1, n_epoch // 20)  # avoid modulo-by-zero when n_epoch < 20

# Send batches to wherever the model's parameters live, so inputs and weights
# always agree whether or not `bc` was moved to `device`.
model_device = next(bc.parameters()).device

for itr in range(0, n_epoch + 1):
    total_loss = 0
    b = 0
    for batch_states, batch_actions in data_loader:
        batch_states = batch_states.float().to(model_device)
        # Cast targets too, so MSELoss never sees mixed float32/float64 operands.
        batch_actions = batch_actions.float().to(model_device)

        y_pred = bc(batch_states)
        loss = criterion(y_pred, batch_actions)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # single host sync per batch
        total_loss += batch_loss
        loss_list.append(batch_loss)
        b += 1

    if itr % log_every == 0:
        print(f'Epoch {itr} Loss: {total_loss / max(b, 1):.4f}')

Epoch 0 Loss: 0.0869
Epoch 50 Loss: 0.0070
Epoch 100 Loss: 0.0055
Epoch 150 Loss: 0.0049
Epoch 200 Loss: 0.0045
Epoch 250 Loss: 0.0042
Epoch 300 Loss: 0.0040
Epoch 350 Loss: 0.0039
Epoch 400 Loss: 0.0037
Epoch 450 Loss: 0.0036
Epoch 500 Loss: 0.0036
Epoch 550 Loss: 0.0035
Epoch 600 Loss: 0.0034
Epoch 650 Loss: 0.0034
Epoch 700 Loss: 0.0033
Epoch 750 Loss: 0.0033
Epoch 800 Loss: 0.0032
Epoch 850 Loss: 0.0032
Epoch 900 Loss: 0.0032
Epoch 950 Loss: 0.0031
Epoch 1000 Loss: 0.0031


In [13]:
def _policy_device(policy):
    """Return the device of `policy`'s first parameter, or CPU if it has none.

    Lets `play` build observation tensors where the policy actually lives,
    so a CPU model is never fed CUDA tensors (or vice versa).
    """
    parameters = getattr(policy, "parameters", None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    return torch.device("cpu")


def play(env, policy, is_close=True, is_render=True, max_step=1000):
    """Roll out one episode of `policy` in `env` and report its return.

    Args:
        env: environment with ``reset() -> (obs, info)`` and the 5-tuple
            ``step(action) -> (obs, reward, terminated, truncated, info)`` API.
        policy: callable taking a float32 tensor (the observation with a
            leading batch axis of size 1 added) and returning an action tensor.
            If it has parameters (e.g. an ``nn.Module``), observations are
            placed on the device of those parameters.
        is_close: close ``env`` once the episode ends.
        is_render: currently unused; kept for backward compatibility.
        max_step: maximum number of environment steps to take.

    Returns:
        dict with ``'reward'`` (sum of rewards) and ``'step'`` (number of
        environment steps actually taken).
    """
    policy_device = _policy_device(policy)
    obs, _ = env.reset()
    total_r = 0
    step = 0
    with torch.no_grad():  # pure inference; no autograd graph needed
        while step < max_step:
            obs_t = torch.as_tensor(obs[None], dtype=torch.float32, device=policy_device)
            action = policy(obs_t).detach().cpu().numpy()

            obs, reward, terminated, truncated, _ = env.step(action.ravel())
            total_r += reward
            step += 1
            # Stop on termination OR time-limit truncation. Stepping again
            # after truncation would run past the end of the episode and add
            # an extra, invalid reward.
            if terminated or truncated:
                break
    if is_close:
        env.close()
    return {'reward': total_r, 'step': step}

In [14]:
env = gym.make(env_name)
# Leave the env open: the evaluation cell that follows reuses the same `env`,
# and stepping a closed environment is not part of the gym API contract.
play(env, bc, is_close=False, is_render=False)

{'reward': 4048.60897787418, 'step': 1000}

In [15]:
# Evaluate the cloned policy over several episodes, reusing a single env.
scores = []
n_trajectory = 20
for i in range(n_trajectory):
    # Keep the env open between episodes; it is closed once after the loop.
    stats = play(env, bc, is_close=False, is_render=False)
    rewards = stats['reward']
    print(f'episode #{i} reward: {rewards:0.2f}')
    scores.append(rewards)
env.close()

print(f'\n score: {np.mean(scores):0.2f} +- {np.std(scores):0.2f}')

episode #0 reward: 4137.56
episode #1 reward: 4169.83
episode #2 reward: 4121.55
episode #3 reward: 4164.70
episode #4 reward: 3919.30
episode #5 reward: 3715.23
episode #6 reward: 4086.71
episode #7 reward: 4140.05
episode #8 reward: 4035.92
episode #9 reward: 4138.06
episode #10 reward: 4012.36
episode #11 reward: 2208.48
episode #12 reward: 4056.03
episode #13 reward: 4132.13
episode #14 reward: 4008.43
episode #15 reward: 4096.09
episode #16 reward: 4094.08
episode #17 reward: 4177.33
episode #18 reward: 4035.93
episode #19 reward: 4138.87

 score: 3979.43 +- 419.33
