In [87]:
import gym
from PPO import PPO
from PIL import Image
import torch
import torch.nn as nn
from torch.distributions import Categorical
import time
import numpy as np

In [2]:
env_name = "LunarLander-v2"
# Create the Gym environment that the expert policy is rolled out in.
env = gym.make(env_name)
# Length of the observation vector, read from the environment's observation space.
state_dim = env.observation_space.shape[0]
# Number of discrete actions, read from the environment rather than hard-coded
# (this notebook previously fixed it to 4), so the networks and one-hot encodings
# stay correctly sized if env_name is changed. int() keeps it a plain Python int.
action_dim = int(env.action_space.n)
# Device on which the notebook places its tensors and models.
device = "cpu"

In [3]:
class ActorCritic(nn.Module):
    """Actor and critic implemented as two independent MLPs.

    Args:
        state_dim: size of the observation vector fed to both networks.
        action_dim: number of discrete actions (size of the actor's output).
        n_latent_var: width of each of the two hidden layers.

    Attributes:
        action_layer: maps a state to a probability vector over actions
            (final layer is a softmax over the last dimension).
        value_layer: maps a state to a single scalar output.
    """

    def __init__(self, state_dim, action_dim, n_latent_var):
        super(ActorCritic, self).__init__()

        # actor
        self.action_layer = nn.Sequential(
                nn.Linear(state_dim, n_latent_var),
                nn.Tanh(),
                nn.Linear(n_latent_var, n_latent_var),
                nn.Tanh(),
                nn.Linear(n_latent_var, action_dim),
                nn.Softmax(dim=-1)
                )

        # critic
        self.value_layer = nn.Sequential(
                nn.Linear(state_dim, n_latent_var),
                nn.Tanh(),
                nn.Linear(n_latent_var, n_latent_var),
                nn.Tanh(),
                nn.Linear(n_latent_var, 1)
                )

    def forward(self):
        """Not used: call `act`, or the `action_layer` / `value_layer` heads directly."""
        raise NotImplementedError

    def act(self, state):
        """Sample one action for `state` from the actor's output distribution.

        Args:
            state: a single observation (NumPy array, sequence of numbers,
                or tensor) of length `state_dim`.

        Returns:
            The sampled action index as a Python int.
        """
        # Use the device the parameters actually live on instead of a notebook
        # global, so the model keeps working after being moved with `.to(...)`.
        model_device = next(self.parameters()).device
        state = torch.as_tensor(state, dtype=torch.float32, device=model_device)
        # Acting is inference only; skip building an autograd graph.
        with torch.no_grad():
            action_probs = self.action_layer(state)
        dist = Categorical(action_probs)
        action = dist.sample()
        return action.item()

    
# Expert network with 64 hidden units per layer.
ppo = ActorCritic(state_dim=state_dim, action_dim=action_dim, n_latent_var=64)

In [4]:
# Restore the expert's weights from disk. map_location remaps the stored tensors
# onto `device`, so a checkpoint saved from a GPU run still loads on this machine.
ppo.load_state_dict(torch.load("PPO_LunarLander-v2.pth", map_location=device))

<All keys matched successfully>

In [5]:
## This is only for testing if it works
# Roll out one rendered episode with the loaded expert policy.
state = env.reset()
try:
    while True:
        action = ppo.act(state)
        env.render()
        next_state, reward, done, _ = env.step(action)
        state = next_state
        time.sleep(0.01)  # pace the loop so the rendering is watchable
        if done: break
finally:
    # Release the render window even if the rollout raises or is interrupted.
    env.close()

#### Start of Behaviour Cloning

In [85]:
class Storage:
    """Buffer of collected states and actions, kept as two parallel lists of tensors."""

    def __init__(self):
        self.states, self.actions = [], []

    def reset(self):
        """Empty both buffers in place."""
        self.states.clear()
        self.actions.clear()

    def return_onehot(self, data, size = action_dim):
        """Encode index `data` as a float one-hot row of shape (1, size) on `device`."""
        encoding = [0] * size
        encoding[data] = 1
        row = torch.tensor(encoding).float()
        return row.reshape(1, -1).to(device)

    def stack(self, data):
        """Concatenate a list of tensors along dim 0 and detach the result."""
        joined = torch.cat(data, dim=0)
        return joined.detach()

    def sample(self):
        """Return `(states, actions)`, each buffer concatenated into one tensor."""
        stacked_states = self.stack(self.states)
        stacked_actions = self.stack(self.actions)
        return stacked_states, stacked_actions

def random_sample(indices, batch_size):
    """Yield shuffled mini-batches of `indices`.

    The indices are permuted with NumPy's global RNG, then emitted as
    consecutive batches of `batch_size`. If the total is not a multiple of
    `batch_size`, the leftover indices are yielded last as a shorter batch.
    """
    shuffled = np.asarray(np.random.permutation(indices))
    total = len(shuffled)
    full_count = total // batch_size
    # Rows of this 2-D view are the full-size batches.
    yield from shuffled[:full_count * batch_size].reshape(-1, batch_size)
    leftover = total % batch_size
    if leftover:
        yield shuffled[-leftover:]


In [7]:
# Data collection: step cap per episode, and number of expert episodes to record.
max_len, initial_episode = 200, 20
# Adam settings for the behaviour-cloning optimizer.
lr = 3e-4
betas = (0.9, 0.999)
# When True, the environment is rendered while expert data is collected.
initial_collection_visizliztion = False

In [8]:
## Data Collection
# Roll out the expert for `initial_episode` episodes (at most `max_len` steps each)
# and record every visited state together with the one-hot action the expert chose.
storage = Storage()
for k in range(initial_episode):
    state = env.reset()
    for i in range(max_len):
        action = ppo.act(state)
        if initial_collection_visizliztion: env.render()
        # Store states as float32 tensors on `device`, so they match the action
        # tensors and can be fed to the network without a dtype/device mismatch.
        storage.states.append(torch.tensor(state, dtype=torch.float32, device=device).reshape(1,-1))
        storage.actions.append(storage.return_onehot(action))
        next_state, reward, done, _ = env.step(action)
        state = next_state
        # Pace the loop only when rendering; otherwise the sleep just slows collection.
        if initial_collection_visizliztion: time.sleep(0.01)
        if done: break
    print("\rCurrent Iteration {} ".format(k), end = " ")
# Close the render window once, after all episodes have been collected.
if initial_collection_visizliztion: env.close()

Current Iteration 19

In [60]:
# Concatenate the collected expert data into two tensors (one row per recorded step).
states, actions = storage.sample()
# Fresh network that will be trained to imitate the expert.
behavior_cloniing = ActorCritic(state_dim, action_dim, n_latent_var=64)
behavior_cloniing = behavior_cloniing.to(device)
optimizer = torch.optim.Adam(params=behavior_cloniing.parameters(), lr=lr, betas=betas)
# optimizer = torch.optim.SGD(behavior_cloniing.parameters(), lr = 0.01)
# Loss between the predicted action probabilities and the one-hot expert actions.
MseLoss = nn.MSELoss()

In [83]:
# One full-batch update of the actor head towards the expert's actions.
optimizer.zero_grad()
x_data = behavior_cloniing.action_layer(states)
loss = MseLoss(x_data, actions)
print(loss)
loss.backward()
optimizer.step()

tensor(0.1762, grad_fn=<MseLossBackward>)


In [90]:
# Print the shuffled index mini-batches (64 indices each) drawn over all stored states.
for doc in random_sample(np.arange(states.size(0)), batch_size=64):
    print(doc)

[1846  898 1347  810 2847 3132  576 2931  344  340 1327 3309 1516 2436
 1527  773 2162  435 3585 1746 1004 2252 1222  695 3364   70 1869  926
 2979 2192  176 3406 3073  841 3157 3480 3317 1994 1919 2321 3397 3684
 1073 2114 3469 1258 2104 1894 1326  520 1816 3083 1408  737 3053 3232
 1660 3189  338 2205 3351 3051 2050 1978]
[2663 2178 3539  900 1480 2670 1707  826  676  414 3398 2057 3565 3363
 3587 3408 3786  967 1604 1926 3164  146  345 3037 2987 2547 3431 1611
  953  620  134 1944 1390 1889 3430 3240 2893 1682 3548 2019 3453 3429
 1849  430 1773 1814 1883 2606 3758 2013   50 1239  210 1478  612 1381
 2268  586 3218  624  208 2179 3117 1809]
[1207 3417 1063 3675 1643 3203  952  594  809 2233 3372 1284 1219 1562
 1017 1429 2518  566 3436 2201 2461 3180 3022 1501 1606 1134 1398  436
 2052 1195 2157  139 1345 2673  426   82 3691  260 3043 1288 2564  154
 3034  709 1010 2623 1953 2628  485  818 3359 2210 3168  906  214 1340
 2545 3027  815 2982 1521 1864 2714 2046]
[2791 1412 1437 3493 2