In [None]:
#| default_exp cartpole_cross_entropy2

In [None]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import namedtuple
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

# Hyper-parameters
HIDDEN_SIZE = 128   # width of the policy network's hidden layer
BATCH_SIZE = 16     # number of episodes per training batch
PERCENTILE = 70     # reward percentile an episode must reach to count as "elite"

# Record types: one finished episode (total reward + the steps it visited) and
# one step (the observation seen and the action taken from it).
Episode = namedtuple('Episode', 'reward steps')
EpisodeStep = namedtuple('EpisodeStep', 'observation action')

  warn(f"Failed to load image Python extension: {e}")


In [None]:
# NN for our agent
class Net(nn.Module):
    """Two-layer MLP policy: observation -> hidden ReLU layer -> raw action scores.

    No softmax is applied here; the scores are turned into probabilities by the
    caller (softmax during acting) or consumed directly by CrossEntropyLoss.

    Args:
        obs_size: length of the observation vector.
        hidden_size: width of the hidden layer.
        n_actions: number of discrete actions (output width).
    """

    def __init__(self, obs_size, hidden_size, n_actions):
        super().__init__()
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        ]
        # Attribute name and layer order fix the state_dict keys ("net.0.*", "net.2.*"),
        # which saved checkpoints depend on.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.net(x)
        return scores


In [None]:
def iterate_batches(env, net, batch_size=16):
    """Play episodes forever with `net` as a stochastic policy and yield them in batches.

    Args:
        env: environment following the classic gym API used here:
            ``reset() -> obs`` and ``step(action) -> (obs, reward, done, info)``.
            NOTE: newer gym/gymnasium releases return different tuples and are not
            handled.
        net: module mapping a ``(1, obs_size)`` float tensor to ``(1, n_actions)``
            action scores; a softmax turns them into sampling probabilities.
        batch_size: number of finished episodes per yielded batch.

    Yields:
        list of ``Episode``: ``batch_size`` episodes, each holding its undiscounted
        total reward and the ``EpisodeStep(observation, action)`` pairs it visited.
        The generator never terminates on its own.
    """
    batch = []
    episode_reward = 0.0
    episode_steps = []
    obs = env.reset()
    sm = nn.Softmax(dim=1)
    # Feed observations to whatever device the network lives on (e.g. a GPU chosen
    # by the Lightning Trainer); fall back to CPU for a parameter-less net.
    first_param = next(iter(net.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    while True:
        # Convert through one float32 ndarray: torch.FloatTensor([ndarray]) takes a
        # slow list-of-arrays path and emits a UserWarning.
        obs_v = torch.as_tensor(np.asarray(obs, dtype=np.float32), device=device).unsqueeze(0)
        # Acting only samples from the policy, so no autograd graph is needed.
        with torch.no_grad():
            act_probs = sm(net(obs_v)).cpu().numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)
        next_obs, reward, is_done, _ = env.step(action)
        episode_reward += reward
        episode_steps.append(EpisodeStep(observation=obs, action=action))
        if is_done:
            batch.append(Episode(reward=episode_reward, steps=episode_steps))
            episode_reward = 0.0
            episode_steps = []
            # The observation stored for the next step must come from the new episode.
            next_obs = env.reset()
            if len(batch) == batch_size:
                yield batch
                batch = []
        obs = next_obs


def filter_batch(batch, percentile=70):
    """Keep the "elite" episodes of a batch and flatten them into training tensors.

    An episode is elite when its total reward is >= the ``percentile``-th percentile
    of the batch's rewards. Because that bound never exceeds the batch maximum, the
    best episode is always kept.

    Args:
        batch: non-empty sequence of ``(reward, steps)`` records (``Episode``), where
            each step exposes ``.observation`` and ``.action``.
        percentile: percentile (0-100) used as the reward cutoff.

    Returns:
        tuple ``(train_obs_v, train_act_v, reward_bound, reward_mean)``:
        float32 observations of all elite steps, int64 actions taken at those steps,
        the reward cutoff, and the mean reward over the whole batch.
    """
    rewards = [episode.reward for episode in batch]
    reward_bound = np.percentile(rewards, percentile)
    reward_mean = float(np.mean(rewards))

    train_obs = []
    train_act = []
    for reward, steps in batch:
        if reward < reward_bound:
            continue
        train_obs.extend(step.observation for step in steps)
        train_act.extend(step.action for step in steps)

    # Stack into single ndarrays first: building a tensor from a list of ndarrays is
    # very slow and triggers a PyTorch UserWarning.
    train_obs_v = torch.as_tensor(np.asarray(train_obs, dtype=np.float32))
    train_act_v = torch.as_tensor(np.asarray(train_act, dtype=np.int64))
    return train_obs_v, train_act_v, reward_bound, reward_mean


In [None]:
class CartPoleModule(LightningModule):
    """Trains `Net` on CartPole with the cross-entropy method, driven by Lightning.

    Each training batch is a list of freshly played episodes produced by
    `iterate_batches`. The elite episodes (see `filter_batch`) become supervised
    targets: the network is trained with cross-entropy to repeat their actions.

    Args:
        obs_size: length of the observation vector.
        n_actions: number of discrete actions.
        batch_size: episodes per training batch (default ``BATCH_SIZE``).
        percentile: elite-episode reward percentile (default ``PERCENTILE``).
        lr: Adam learning rate.
        reward_goal: training stops once a batch's mean reward exceeds this value.
    """

    def __init__(self, obs_size, n_actions, batch_size=BATCH_SIZE,
                 percentile=PERCENTILE, lr=0.01, reward_goal=199):
        super().__init__()
        self.net = Net(obs_size, HIDDEN_SIZE, n_actions)
        self.objective = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(params=self.net.parameters(), lr=lr)
        self.batch_size = batch_size
        self.percentile = percentile
        self.reward_goal = reward_goal

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        """Fit the network to the elite episodes of `batch` and stop once solved."""
        obs_v, acts_v, reward_b, reward_m = filter_batch(batch, self.percentile)
        obs_v = obs_v.to(self.device)
        acts_v = acts_v.to(self.device)
        action_scores_v = self(obs_v)
        loss_v = self.objective(action_scores_v, acts_v)
        # `self.log` is the supported logging API (a returned 'log' dict is ignored).
        # The batch is a list of episodes, not tensors, so pass its size explicitly.
        n_episodes = len(batch)
        self.log('reward_mean', reward_m, prog_bar=True, batch_size=n_episodes)
        self.log('reward_bound', float(reward_b), prog_bar=True, batch_size=n_episodes)
        # The episode source is an infinite generator, so an epoch never ends and an
        # epoch-end hook would never run: check the goal per batch and ask the
        # trainer to stop.
        if reward_m > self.reward_goal:
            print("Solved!")
            self.trainer.should_stop = True
        return loss_v

    def on_train_end(self):
        self.trainer.save_checkpoint('model.ckpt')

    def train_dataloader(self):
        """Return an endless generator of episode batches played by the current net."""
        env = gym.make("CartPole-v0")
        return iterate_batches(env, self.net, self.batch_size)

    def configure_optimizers(self):
        return self.optimizer


In [None]:
# Seed the global RNGs so weight initialisation and action sampling are repeatable.
SEED = 42
np.random.seed(SEED)     # iterate_batches samples actions with np.random
torch.manual_seed(SEED)  # Net's weight initialisation
# NOTE(review): the CartPole env created inside train_dataloader is not seeded,
# so episodes are still not fully reproducible.

# Build an environment just to read the observation/action dimensions,
# then create the PyTorch Lightning module.
env = gym.make("CartPole-v0")
obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n
module = CartPoleModule(obs_size, n_actions)

# Create a Trainer with default settings (default logger, auto-detected accelerator)
trainer = Trainer()

# Start training
trainer.fit(module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Missing logger folder: /home/meherkh/Meher/deep_reinforcement_learning_hands_on/nbs/lightning_logs

  | Name      | Type             | Params
-----------------------------------------------
0 | net       | Net              | 898   
1 | objective | CrossEntropyLoss | 0     
-----------------------------------------------
898       Trainable params
0         Non-trainable params
898       Total params
0.004     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
