In [1]:
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
from tqdm import tqdm
from tqdm import trange
import numpy as np
import time

In [3]:
import torch as T
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
import torch


In [4]:
import gym
render=True
env = gym.make('CartPole-v0')
print(env.observation_space.shape, env.action_space.n)


(4,) 2


In [5]:
class BCO(nn.Module):
    def __init__(self, env, policy='mlp'):
        super(BCO, self).__init__()
        
        self.policy = policy
        self.act_n = env.action_space.n
        
        if self.policy=='mlp':
            self.obs_n = env.observation_space.shape[0]
            self.pol = nn.Sequential(*[nn.Linear(self.obs_n, 32),
                                       nn.Linear(32, self.act_n)])
            self.inv = nn.Sequential(*[nn.Linear(self.obs_n*2, 32),
                                       nn.Linear(32, self.act_n)])
        
        elif self.policy=='cnn':
            pass
    
    def pred_act(self, obs):
        out = self.pol(obs)
        
        return out
    
    def pred_inv(self, obs1, obs2):
        obs = T.cat([obs1, obs2], dim=1)
        out = self.inv(obs)
        
        return out

POLICY = 'mlp'
model = BCO(env, policy=POLICY).cuda()

In [6]:
from torch.utils.data import Dataset, DataLoader

class DS_Inv(Dataset):
    def __init__(self, trajs):
        self.dat = []
        
        for traj in trajs:
            for dat in traj:
                obs, act, new_obs = dat
                
                self.dat.append([obs, new_obs, act])
    
    def __len__(self):
        return len(self.dat)
    
    def __getitem__(self, idx):
        obs, new_obs, act = self.dat[idx]
        
        return obs, new_obs, np.asarray(act)

class DS_Policy(Dataset):
    def __init__(self, traj):
        self.dat = []
        
        for dat in traj:
            obs, act = dat
                
            self.dat.append([obs, act])
    
    def __len__(self):
        return len(self.dat)
    
    def __getitem__(self, idx):
        obs, act = self.dat[idx]
        
        return obs, np.asarray(act)

In [7]:
import pickle

trajs_demo = pickle.load(open('Demo/demo_cart-pole.pkl', 'rb'))
ld_demo = DataLoader(DS_Inv(trajs_demo), batch_size=100)

for obs1, obs2, act  in ld_demo:
    print(obs1.shape,act.shape, obs2.shape)
    break
    

torch.Size([100, 4]) torch.Size([100]) torch.Size([100, 4])


In [8]:
loss_func = nn.CrossEntropyLoss().cuda()
optim = T.optim.Adam(model.parameters(), lr=5e-3)

alpha = 0
M = 1000

EPS = 0.9
DECAY = 0.5
random_seed = 42
epochs = 20
patience = 5

In [9]:
def train_valid_loader(dataset, batch_size, validation_split, shuffle_dataset):
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    if shuffle_dataset :
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)
    train_loader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,
                                               sampler=train_sampler)
    validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                    sampler=valid_sampler)
    return train_loader, validation_loader

def train_NN(train_loader, NN):
    
    with tqdm(train_loader, desc='Training',  disable = True, position=0, leave=True) as TQ:
        ls_ep = 0
        correct = 0
        total = 0
        
        if (NN == 'inv'):
            for obs1, obs2, act in TQ:
                out = model.pred_inv(obs1.float().cuda(), obs2.float().cuda())
                out_act = torch.argmax(out.cpu().detach(), axis=1)
                ls_bh = loss_func(out, act.cuda())
                
                optim.zero_grad()
                ls_bh.backward()
                optim.step()

                ls_bh = ls_bh.cpu().detach().numpy()
                TQ.set_postfix(loss_policy='%.3f' % (ls_bh))
                ls_ep += ls_bh
                total += obs1.shape[0]
                correct += (out_act == act).sum().item()
                
        elif(NN == 'pred'):
            for obs, act in TQ:
                out = model.pred_act(obs.float().cuda())
                out_act = torch.argmax(out.cpu().detach(), axis=1)
                ls_bh = loss_func(out, act.cuda())

                optim.zero_grad()
                ls_bh.backward()
                optim.step()

                ls_bh = ls_bh.cpu().detach().numpy()
                TQ.set_postfix(loss_policy='%.3f' % (ls_bh))
                ls_ep += ls_bh
                total += obs.shape[0]
                correct += (out_act == act).sum().item()
            
        ls_ep /= len(TQ)
        accuracy = 100*correct/total
        
    return ls_ep, accuracy

def validate_NN(validation_loader, NN):
    
    with tqdm(validation_loader, desc='Validate', disable = True, position=0, leave=True) as TQ:
        ls_val_ep = 0
        correct = 0
        total = 0
        
        if (NN == 'inv'):
            for obs1, obs2, act in TQ:
                out = model.pred_inv(obs1.float().cuda(), obs2.float().cuda())
                out_act = torch.argmax(out.cpu().detach(), axis=1)
                ls_bh = loss_func(out, act.cuda())
                ls_bh = ls_bh.cpu().detach().numpy()
                TQ.set_postfix(loss_policy='%.3f' % (ls_bh))
                ls_val_ep += ls_bh
                total += obs1.shape[0]
                correct += (out_act == act).sum().item()
        elif (NN == 'pred'):
            for obs, act in TQ:
                out = model.pred_act(obs.float().cuda())
                out_act = torch.argmax(out.cpu().detach(), axis=1)
                ls_bh = loss_func(out, act.cuda())
                ls_bh = ls_bh.cpu().detach().numpy()
                TQ.set_postfix(loss_policy='%.3f' % (ls_bh))
                ls_val_ep += ls_bh
                total += obs.shape[0]
                correct += (out_act == act).sum().item()
            
        ls_val_ep /= len(TQ)
        accuracy = 100*correct/total
        
        return ls_val_ep, accuracy

In [10]:
trajs_inv = []

tqdm_alpha = trange(alpha+1, position=0, desc='alpha:', leave=True)
for e in tqdm_alpha:
    
    # step1, generate inverse samples
    tqdm_alpha.set_description("alpha: %i, Step1: Exploration" % e,refresh=True)
    time.sleep(1)
    cnt = 0 #count
    epn = 0 #Episode number
    
    rews = 0 #Rewards
    
    while True:
        traj = []
        rew = 0
        N=0 
        obs = env.reset()
        while True:
            inp = T.from_numpy(obs).view(((1, )+obs.shape)).float().cuda()
            out = model.pred_act(inp).cpu().detach().numpy()
            act = np.argmax(out, axis=1)[0]
            if np.random.rand()>=EPS:
                act = np.argmax(out, axis=1)[0]
            else:
                act = env.action_space.sample()
        
            new_obs, r, done, _ = env.step(act)
                
            traj.append([obs, act, new_obs])
            obs = new_obs
            rew += r
            
            cnt += 1
            tqdm_alpha.set_description("alpha: %i, Step1: Exploration - %i" % (e,cnt),refresh=True)
            N+=1   
            if done==True:
                rews += rew
                trajs_inv.append(traj)
                
                epn += 1
                
                break
        
        if cnt >= M:
            break
        
    rews /= epn
    tqdm_alpha.set_description("alpha: %i, step1: Exploration, Reward: %.2f" % (e,rews),refresh=True)
    time.sleep(1)
        
    # step2, update inverse model
    
    acc_val_best = 0
    patience_cnt = 0
    tqdm_alpha.set_description("alpha: %i, Step2: Update Inverse Model" % e,refresh=True)
    time.sleep(1)
    tqdm_epoch = trange(epochs, position=0, desc='Epoch:', leave=True)
    for i in  tqdm_epoch:
        dataset=DS_Inv(trajs_inv)
        train_loader, validation_loader = train_valid_loader(dataset, batch_size=32, 
                                                             validation_split=0.3,
                                                             shuffle_dataset=True)
        
        ls_ep, acc_ep = train_NN(train_loader, NN = 'inv')
        ls_val_ep, acc_val_ep = validate_NN(validation_loader, NN = 'inv')
        
        tqdm_epoch.set_description("ID Model Update - Epoch: %i, val loss: %.3f" % (i,ls_val_ep),refresh=True)
        
        if acc_val_ep > acc_val_best:
            acc_val_best = acc_val_ep
            patience_cnt = 0
    
        else:
            patience_cnt += 1
            if patience_cnt == patience:
#                 tqdm.write("break")
                break
        
    
    #step3, predict actions
    tqdm_alpha.set_description("alpha: %i, Step3: Predict most probable actions for expert demos" % e,refresh=True)
    traj_policy = []
    
    for obs1, obs2, _ in ld_demo:
        out = model.pred_inv(obs1.float().cuda(), obs2.float().cuda())
        obs = obs1.cpu().detach().numpy()
        out = out.cpu().detach().numpy()
        out = np.argmax(out, axis=1)
        for i in range(100):
            traj_policy.append([obs[i], out[i]])

    # step4, update policy via demo samples
    tqdm_alpha.set_description("alpha: %i, Step4: Update Policy" % e,refresh=True)
    acc_val_best = 0
    patience_cnt = 0
    tqdm_epoch = trange(epochs, position=0, desc='Epochs', leave=True)
    for i in  tqdm_epoch:
        dataset=DS_Policy(traj_policy)
        train_loader, validation_loader = train_valid_loader(dataset, batch_size=32, 
                                                             validation_split=0.3,
                                                             shuffle_dataset=True)
        
        ls_ep, acc_ep = train_NN(train_loader, NN = 'pred')
        ls_val_ep, acc_val_ep = validate_NN(validation_loader, NN = 'pred')
        
        tqdm_epoch.set_description("Policy Update - Epoch: %i, val loss: %.3f" % (i,ls_val_ep),refresh=True)
        
        if acc_val_ep > acc_val_best:
            acc_val_best = acc_val_ep
            patience_cnt = 0
    
        else:
            patience_cnt += 1
            if patience_cnt == patience:
#                 tqdm.write("break")
                break
        
    time.sleep(1)
    # step5, save model
    T.save(model.state_dict(), 'Model/model_cart-pole_%d.pt' % (e+1))
    
    EPS *= DECAY

ID Model Update - Epoch: 7, val loss: 0.011:  35%|███▌      | 7/20 [00:00<00:00, 17.68it/s]
Policy Update - Epoch: 17, val loss: 0.477:  85%|████████▌ | 17/20 [00:00<00:00, 18.29it/s]           
alpha: 0, Step4: Update Policy: 100%|██████████| 1/1 [00:06<00:00,  6.07s/it]


In [12]:
import time

rews = 0

for i_episode in range(20):
    observation = env.reset()
    rews = 0
    for t in range(200):
        env.render()
        inp = T.from_numpy(observation).view(((1, )+observation.shape)).float().cuda()
        out = model.pred_act(inp).cpu().detach().numpy()
        act = np.argmax(out, axis=1)[0]  ## Take actions predicted by the inverse dynamics model
        observation, reward, done, _ = env.step(act)
        rews += reward
        if done:
            print("Episode finished after {} timesteps".format(t+1))
            print(rews)
            break
env.close()


Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
Episode finished after 200 timesteps
200.0
