In [1]:
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
from tqdm import tqdm
import numpy as np

In [3]:
import torch as T
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
import torch


In [4]:
import gym
render=True
env = gym.make('MountainCar-v0')
print(env.observation_space.shape, env.action_space.n)


(2,) 3


In [5]:
class BCO(nn.Module):
    def __init__(self, env, policy='mlp'):
        super(BCO, self).__init__()
        
        self.policy = policy
        self.act_n = env.action_space.n
        
        if self.policy=='mlp':
            self.obs_n = env.observation_space.shape[0]
            self.pol = nn.Sequential(*[nn.Linear(self.obs_n, 8), nn.LeakyReLU(),
                                       nn.Linear(8, 8), nn.LeakyReLU(),
                                       nn.Linear(8, self.act_n)])
            self.inv = nn.Sequential(*[nn.Linear(self.obs_n*2, 8), nn.LeakyReLU(),
                                       nn.Linear(8, 8), nn.LeakyReLU(),
                                       nn.Linear(8, self.act_n)])
        
        elif self.policy=='cnn':
            pass
    
    def pred_act(self, obs):
        out = self.pol(obs)
        
        return out
    
    def pred_inv(self, obs1, obs2):
        obs = T.cat([obs1, obs2], dim=1)
        out = self.inv(obs)
        
        return out

POLICY = 'mlp'
model = BCO(env, policy=POLICY).cuda()

In [31]:
from torch.utils.data import Dataset, DataLoader

class DS_Inv(Dataset):
    def __init__(self, trajs):
        self.dat = []
        
        for traj in trajs:
            for dat in traj:
                obs, act, new_obs = dat
                
                self.dat.append([obs, new_obs, act])
    
    def __len__(self):
        return len(self.dat)
    
    def __getitem__(self, idx):
        obs, new_obs, act = self.dat[idx]
        
        return obs, new_obs, np.asarray(act)

class DS_Policy(Dataset):
    def __init__(self, traj):
        self.dat = []
        
        for dat in traj:
            obs, act = dat
                
            self.dat.append([obs, act])
    
    def __len__(self):
        return len(self.dat)
    
    def __getitem__(self, idx):
        obs, act = self.dat[idx]
        
        return obs, np.asarray(act)

In [33]:
import pickle

trajs_demo = pickle.load(open('Demo/demo_mountaincar.pkl', 'rb'))
rew = 0
var = 0
temp = 128.32
print(len(trajs_demo))
for i in range(len(trajs_demo)):
    rew += len(trajs_demo[i])
    var += (len(trajs_demo[i]) - temp)**2
avg = rew/len(trajs_demo)
var = (var/len(trajs_demo))**0.5
print(avg)
print(var)
ld_demo = DataLoader(DS_Inv(trajs_demo), batch_size=100)
# print(len(ld_demo))
# for obs1, obs2, act  in ld_demo:
#     print(obs1.shape,act.shape, obs2.shape)
#     break
    

25
128.32
26.15296541503468


In [22]:
loss_func = nn.CrossEntropyLoss().cuda()
optim = T.optim.Adam(model.parameters(), lr=5e-4)

alpha = 0
M = 2000

EPS = 0.9
DECAY = 0.5
random_seed = 40
epochs = 100
patience = 10

In [24]:
def train_valid_loader(dataset, batch_size, validation_split, shuffle_dataset):
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    if shuffle_dataset :
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)
    train_loader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,
                                               sampler=train_sampler)
    validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                    sampler=valid_sampler)
    return train_loader, validation_loader

def train_NN(train_loader, NN):
    
    with tqdm(train_loader) as TQ:
        ls_ep = 0
        correct = 0
        total = 0
        
        if (NN == 'inv'):
            for obs1, obs2, act in TQ:
                out = model.pred_inv(obs1.float().cuda(), obs2.float().cuda())
                out_act = torch.argmax(out.cpu().detach(), axis=1)
                ls_bh = loss_func(out, act.cuda())
                
                optim.zero_grad()
                ls_bh.backward()
                optim.step()

                ls_bh = ls_bh.cpu().detach().numpy()
                TQ.set_postfix(loss_policy='%.3f' % (ls_bh))
                ls_ep += ls_bh
                total += obs1.shape[0]
                correct += (out_act == act).sum().item()
                
        elif(NN == 'pred'):
            for obs, act in TQ:
                out = model.pred_act(torch.tensor(obs).float().cuda())
                out_act = torch.argmax(out.cpu().detach(), axis=1)
                ls_bh = loss_func(out, act.cuda())

                optim.zero_grad()
                ls_bh.backward()
                optim.step()

                ls_bh = ls_bh.cpu().detach().numpy()
                TQ.set_postfix(loss_policy='%.3f' % (ls_bh))
                ls_ep += ls_bh
                total += obs.shape[0]
                correct += (out_act == act).sum().item()
            
        ls_ep /= len(TQ)
        accuracy = 100*correct/total
        
    return ls_ep, accuracy

def validate_NN(validation_loader, NN):
    
    with tqdm(validation_loader) as TQ:
        ls_val_ep = 0
        correct = 0
        total = 0
        
        if (NN == 'inv'):
            for obs1, obs2, act in TQ:
                out = model.pred_inv(obs1.float().cuda(), obs2.float().cuda())
                out_act = torch.argmax(out.cpu().detach(), axis=1)
                ls_bh = loss_func(out, act.cuda())
                ls_bh = ls_bh.cpu().detach().numpy()
                TQ.set_postfix(loss_policy='%.3f' % (ls_bh))
                ls_val_ep += ls_bh
                total += obs1.shape[0]
                correct += (out_act == act).sum().item()
        elif (NN == 'pred'):
            for obs, act in TQ:
                out = model.pred_act(torch.tensor(obs).float().cuda())
                out_act = torch.argmax(out.cpu().detach(), axis=1)
                ls_bh = loss_func(out, act.cuda())
                ls_bh = ls_bh.cpu().detach().numpy()
                TQ.set_postfix(loss_policy='%.3f' % (ls_bh))
                ls_val_ep += ls_bh
                total += obs.shape[0]
                correct += (out_act == act).sum().item()
            
        ls_val_ep /= len(TQ)
        accuracy = 100*correct/total
        
        return ls_val_ep, accuracy

In [25]:
trajs_inv = []

for e in tqdm(range(alpha+1)):
    
    # step1, generate inverse samples
    cnt = 0 #count
    epn = 0 #Episode number
    
    rews = 0 #Rewards
    
    while True:
        traj = []
        rew = 0
        N=0 
        obs = env.reset()
        while True:
            inp = T.from_numpy(obs).view(((1, )+obs.shape)).float().cuda()
            out = model.pred_act(inp).cpu().detach().numpy()
            act = np.argmax(out, axis=1)[0]
            
            if np.random.rand()>=EPS:
                act = np.argmax(out, axis=1)[0]
            else:
                act = env.action_space.sample()
            
            new_obs, r, done, _ = env.step(act)
                
            traj.append([obs, act, new_obs])
            obs = new_obs
            rew += r
            
            cnt += 1
            N+=1   
            if done==True:
                rews += rew
                trajs_inv.append(traj)
                
                epn += 1
                
                break
        
        if cnt >= M:
            break

    rews /= epn
    tqdm.write('BCO_%d: reward=%.2f' % (e+1, rews))
        
    # step2, update inverse model
    
    ls_val_best = 5
    patience_cnt = 0
    for i in  range(epochs):
        dataset=DS_Inv(trajs_inv)
        train_loader, validation_loader = train_valid_loader(dataset, batch_size=64, 
                                                             validation_split=0.3,
                                                             shuffle_dataset=True)
        
        ls_ep, acc_ep = train_NN(train_loader, NN = 'inv')
        ls_val_ep, acc_val_ep = validate_NN(validation_loader, NN = 'inv')
        
        if ls_val_ep < ls_val_best:
            ls_val_best = ls_val_ep
            patience_cnt = 0
    
        else:
            patience_cnt += 1
            if patience_cnt == patience:
#                 print("break")
                break
        tqdm.write('Epoch %d: id loss_policy=%.3f' % (i+1, ls_ep))
        
    # step3, predict actions for demo trajectories
    traj_policy = []
    
    for obs1, obs2, _ in ld_demo:
        out = model.pred_inv(obs1.float().cuda(), obs2.float().cuda())
        obs = obs1.cpu().detach().numpy()
        out = out.cpu().detach().numpy()
        out = np.argmax(out, axis=1)
        for i in range(100):
            traj_policy.append([obs[i], out[i]])
            
    # step4, update policy via demo samples
    ls_val_best = 5
    patience_cnt = 0
    for i in  range(epochs):
        dataset=DS_Policy(traj_policy)
        train_loader, validation_loader = train_valid_loader(dataset, batch_size=64, 
                                                             validation_split=0.3,
                                                             shuffle_dataset=True)
        
        ls_ep, acc_ep = train_NN(train_loader, NN = 'pred')
        ls_val_ep, acc_val_ep = validate_NN(validation_loader, NN = 'pred')
        
        if ls_val_ep < ls_val_best:
            ls_val_best = ls_val_ep
            patience_cnt = 0
    
        else:
            patience_cnt += 1
            if patience_cnt == patience:
#                 print("break")
                break
        tqdm.write('Epoch %d: validation id loss policy=%.3f' % (i+1, ls_val_ep))
    # step5, save model
    T.save(model.state_dict(), 'Model/model_mountaincar_%d.pt' % (e+1))
    
    EPS *= DECAY

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.159][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.138][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.178][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.154][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.202][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.214][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.190][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.219][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.183][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.187][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.258][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.146][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.195][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.094][A
 64%|██████▎   | 14/22 [00:00<00:00, 139.25it/s, loss_policy=1.094

BCO_1: reward=-200.00


[A
100%|██████████| 10/10 [00:00<00:00, 214.45it/s, loss_policy=1.222]
  0%|          | 0/1 [00:01<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.140][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.118][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.171][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.187][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.139][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.159][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.181][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.162][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.128][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.193][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.188][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.156][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.141][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.18

Epoch 1: id loss_policy=1.171


[A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.175][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.137][A
100%|██████████| 10/10 [00:00<00:00, 210.57it/s, loss_policy=1.125]
  0%|          | 0/1 [00:01<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.106][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.163][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.170][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.145][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.119][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.173][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.118][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.121][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.162][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.105][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.177][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.14

Epoch 2: id loss_policy=1.160
Epoch 3: id loss_policy=1.149


  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.084][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.120][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.153][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.156][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.145][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.101][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.158][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.105][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.174][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.131][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.161][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.172][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.130][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.117][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.139][A
  0%|          | 0/22 [00:00<?, ?it/s, loss

Epoch 4: id loss_policy=1.138


100%|██████████| 10/10 [00:00<00:00, 217.37it/s, loss_policy=1.181]
  0%|          | 0/1 [00:01<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.151][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.120][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.147][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.108][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.148][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.098][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.129][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.124][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.103][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.125][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.117][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.108][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.118][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.103][

Epoch 5: id loss_policy=1.127
Epoch 6: id loss_policy=1.118



  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.101][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.121][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.099][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.102][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.130][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.117][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.093][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.109][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.099][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.120][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.120][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.115][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.114][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.124][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.106][A
 68%|██████▊   | 15/22 [00:00<00:00, 148.77it/s, loss_policy=1.106][A
 68%|██████▊ 

Epoch 7: id loss_policy=1.110



  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.099][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.092][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.100][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.102][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.098][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.104][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.095][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.100][A
100%|██████████| 10/10 [00:00<00:00, 218.33it/s, loss_policy=1.106]
  0%|          | 0/1 [00:02<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.102][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.105][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.097][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.102][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.100][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.100]

Epoch 8: id loss_policy=1.103



100%|██████████| 10/10 [00:00<00:00, 193.44it/s, loss_policy=1.092]
  0%|          | 0/1 [00:02<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.095][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.092][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.100][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.089][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.087][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.100][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.077][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.089][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.104][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.090][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.089][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.093][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.091][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.092]

Epoch 9: id loss_policy=1.097


  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.104][A
100%|██████████| 10/10 [00:00<00:00, 209.02it/s, loss_policy=1.095]
  0%|          | 0/1 [00:03<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.108][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.082][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.087][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.080][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.087][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.091][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.095][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.086][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.098][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.101][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.085][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.091][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.090][

Epoch 10: id loss_policy=1.093



  0%|          | 0/1 [00:03<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.100][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.103][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.097][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.079][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.090][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.074][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.091][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.081][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.093][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.099][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.087][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.084][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.108][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.076][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.080][A
  0%

Epoch 11: id loss_policy=1.090


[A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.087][A
100%|██████████| 10/10 [00:00<00:00, 179.60it/s, loss_policy=1.095]
  0%|          | 0/1 [00:03<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.073][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.089][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.081][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.074][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.069][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.107][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.094][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.076][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.105][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.103][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.081][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.102][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.08

Epoch 12: id loss_policy=1.088



  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.071][A
100%|██████████| 10/10 [00:00<00:00, 211.00it/s, loss_policy=1.059]
  0%|          | 0/1 [00:03<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.104][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.073][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.087][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.079][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.087][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.090][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.107][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.105][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.061][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.089][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.074][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.067][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.061]

Epoch 13: id loss_policy=1.088



  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.102][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.069][A
100%|██████████| 10/10 [00:00<00:00, 202.99it/s, loss_policy=1.077]
  0%|          | 0/1 [00:03<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.096][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.104][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.088][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.099][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.104][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.055][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.092][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.061][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.062][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.085][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.065][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.095]

Epoch 14: id loss_policy=1.087


[A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.073][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.063][A
100%|██████████| 10/10 [00:00<00:00, 226.05it/s, loss_policy=1.107]
  0%|          | 0/1 [00:04<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.078][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.087][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.089][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.088][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.047][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.090][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.074][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.069][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.052][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.099][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.107][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.10

Epoch 15: id loss_policy=1.087
Epoch 16: id loss_policy=1.087



  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.072][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.103][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.101][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.072][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.063][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.092][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.084][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.090][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.094][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.068][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.062][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.074][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.098][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.119][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.070][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.083][A
 73%|███████▎  | 16/22 

Epoch 17: id loss_policy=1.086


  0%|          | 0/1 [00:04<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.120][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.095][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.101][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.070][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.082][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.070][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.096][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.078][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.071][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.082][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.095][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.076][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.070][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.101][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.070][A
 68%|

Epoch 18: id loss_policy=1.087



  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.087][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.098][A
100%|██████████| 10/10 [00:00<00:00, 202.01it/s, loss_policy=1.094]
  0%|          | 0/1 [00:04<?, ?it/s]
  0%|          | 0/22 [00:00<?, ?it/s][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.087][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.067][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.090][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.093][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.083][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.069][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.085][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.091][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.121][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.092][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.086][A
  0%|          | 0/22 [00:00<?, ?it/s, loss_policy=1.084]

Epoch 19: id loss_policy=1.087


  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.075][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.084][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.085][A
  0%|          | 0/10 [00:00<?, ?it/s, loss_policy=1.062][A
100%|██████████| 10/10 [00:00<00:00, 184.18it/s, loss_policy=1.045]
  0%|          | 0/1 [00:05<?, ?it/s]


Epoch 20: id loss_policy=1.087


IndexError: index 8 is out of bounds for axis 0 with size 8

In [8]:
import time

reward = 0

model = BCO(env, policy=POLICY).cuda()
model.load_state_dict(torch.load('Model/model_mountain_car_best.pt'))

for i_episode in range(20):
    observation = env.reset()
    for t in range(200):
        env.render()
        inp = T.from_numpy(observation).view(((1, )+observation.shape)).float().cuda()
        out = model.pred_act(inp).cpu().detach().numpy()
        act = np.argmax(out, axis=1)[0] ## Take actions predicted by the inverse dynamics model
        print(act)
        observation, rew, done, _ = env.step(act)
        time.sleep(0.02)
        reward += rew
        if done:
            print("Episode finished after {} timesteps".format(t+1))
            print("Reward:", rew)
            break
env.close()


2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
1
1
1
1
1
1
1
1
1
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
1
1
1
1
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
Episode finished after 200 timesteps
Reward: -1.0
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
1
1
1
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
Episode finished after 200 timesteps
Reward: -1.0
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2


KeyboardInterrupt: 