In [1]:
####

In [2]:
import torch
from torch import nn
import gym
import numpy as np
import torch.nn.functional as nnf

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
class Config:
    """Hyper-parameters for cross-entropy-method training.

    Every argument defaults to the value this notebook was tuned with, so
    ``Config()`` behaves exactly as before; pass keywords to override
    individual values (e.g. ``Config(batch_size=16)`` for a quick run).

    Args:
        env_name: gym environment id passed to ``gym.make``.
        batch_size: number of episodes played per training step.
        hidden_dim: width of the policy network's hidden layer.
        target_reward_level: training stops once the batch mean reward is
            strictly greater than this value.
        lr: Adam learning rate.
    """

    def __init__(self,
                 env_name='CartPole-v1',
                 batch_size=200,
                 hidden_dim=128,
                 target_reward_level=199,
                 lr=0.01):
        self.env_name = env_name
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        # NOTE(review): gym registers CartPole-v1 with reward_threshold=475
        # (v0 uses 195); 199 may be carried over from v0 — confirm intent.
        self.target_reward_level = target_reward_level
        self.lr = lr

In [5]:
class Network(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Maps ``in_channels`` input features to ``out_channels`` unnormalised
    scores (logits; the trainer applies softmax / cross-entropy to them).
    The submodule names ``linear1``, ``relu`` and ``linear2`` are part of the
    ``state_dict`` keys, so they are kept stable.
    """

    def __init__(self, in_channels, hidden_dim, out_channels):
        super().__init__()
        # Creation order matters: it fixes parameter order and RNG use at init.
        self.linear1 = nn.Linear(in_channels, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, out_channels)

    def forward(self, x):
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)

In [6]:
class Episode:
    """Trajectory of a single episode.

    Keeps parallel lists of the states visited, actions taken and rewards
    received, plus the running sum of the rewards in ``total_reward``.
    """

    def __init__(self):
        self.total_reward = 0
        self.states, self.actions, self.rewards = [], [], []

    def add_step(self, state, action, reward):
        """Record one (state, action, reward) transition and accumulate the reward."""
        for history, value in zip((self.states, self.actions, self.rewards),
                                  (state, action, reward)):
            history.append(value)
        self.total_reward += reward

In [7]:
class Session:
    """Cross-entropy-method trainer for a discrete-action policy network.

    Each training step plays ``number_of_training_games`` episodes with the
    current stochastic policy, keeps the episodes whose total reward is
    strictly above the batch mean ("elite" episodes), and fits the network to
    those episodes' (state, action) pairs with a cross-entropy loss.

    Args:
        env: gym-style environment with a discrete ``action_space`` (``.n``).
            Both the classic API (``reset() -> obs``, 4-tuple ``step``) and
            the gym>=0.26 API (``reset() -> (obs, info)``, 5-tuple ``step``)
            are accepted.
        network: ``nn.Module`` mapping one observation to one score per action.
        config: a ``Config`` instance, or a class that is instantiated with no
            arguments (defaults to ``Config``).
    """

    def __init__(self,
                 env,
                 network,
                 config = Config):
        self.env = env
        # Use a passed-in Config instance as-is; the default (the class itself)
        # is instantiated so ``Session(env, net)`` keeps working.
        self.config = config() if isinstance(config, type) else config
        self.number_of_training_games = self.config.batch_size
        self.target_reward_level = self.config.target_reward_level
        self.net = network
        # Observations must be sent to wherever the network's weights live
        # (the notebook may move the network to CUDA).
        self.device = next(self.net.parameters()).device

        self.opt = torch.optim.Adam(self.net.parameters(), lr=self.config.lr)
        self.loss_fn = nn.CrossEntropyLoss()

    @staticmethod
    def _unwrap_reset(reset_result):
        """Return the observation from ``env.reset()`` for old and new gym APIs."""
        # gym>=0.26 returns (obs, info); older versions return obs alone.
        if isinstance(reset_result, tuple):
            return reset_result[0]
        return reset_result

    @staticmethod
    def _unwrap_step(step_result):
        """Return ``(new_state, reward, is_finished)`` from ``env.step()`` for old and new gym APIs."""
        if len(step_result) == 5:
            # gym>=0.26: (obs, reward, terminated, truncated, info)
            new_state, reward, terminated, truncated, _ = step_result
            return new_state, reward, bool(terminated or truncated)
        new_state, reward, is_finished, _ = step_result
        return new_state, reward, is_finished

    def _sample_action(self, state, number_of_actions):
        """Sample one action from the softmax of the network's scores for ``state``."""
        # Sampling never backpropagates, so skip building an autograd graph.
        with torch.no_grad():
            state_t = torch.as_tensor(np.asarray(state), dtype=torch.float32, device=self.device)
            actions_prob = nnf.softmax(self.net(state_t), dim=0)
        return np.random.choice(number_of_actions, size=1, p=actions_prob.cpu().numpy())[0]

    def _play_episode(self, number_of_actions):
        """Play one full episode with the current policy and return it as an ``Episode``."""
        episode = Episode()
        state = self._unwrap_reset(self.env.reset())
        is_finished = False
        while not is_finished:
            action = self._sample_action(state, number_of_actions)
            new_state, reward, is_finished = self._unwrap_step(self.env.step(action))
            episode.add_step(state, action, reward)
            state = new_state
        return episode

    def get_batches(self):
        """Play ``number_of_training_games`` complete episodes with the current policy.

        Returns:
            list of ``Episode``: the finished episodes, in the order played.
        """
        number_of_actions = self.env.action_space.n
        batch = []
        # Every played episode is kept; none is discarded at the end.
        while len(batch) < self.number_of_training_games:
            batch.append(self._play_episode(number_of_actions))
        return batch

    def train(self, max_steps=None):
        """Run cross-entropy updates until the batch mean reward exceeds the target.

        Prints the step index (0-based), loss and batch mean reward after each
        step, and prints ``Solved`` when the target is beaten.

        Args:
            max_steps: optional cap on the number of training steps; ``None``
                (the default) trains until solved.
        """
        step = 0
        while max_steps is None or step < max_steps:
            batch = self.get_batches()
            mean_reward = np.array([episode.total_reward for episode in batch]).mean()

            elite_states = []
            elite_actions = []
            for episode in batch:
                # Imitate only episodes strictly better than the batch average.
                if episode.total_reward > mean_reward:
                    elite_states.extend(episode.states)
                    elite_actions.extend(episode.actions)

            loss_value = float('nan')
            # When all episodes tie, no episode is above the mean: skip the
            # update instead of feeding an empty batch to the network.
            if elite_states:
                # np.asarray first: building a tensor from a list of ndarrays is slow.
                states_t = torch.as_tensor(np.asarray(elite_states), dtype=torch.float32, device=self.device)
                actions_t = torch.as_tensor(np.asarray(elite_actions), dtype=torch.long, device=self.device)
                self.opt.zero_grad()
                loss = self.loss_fn(self.net(states_t), actions_t)
                loss.backward()
                self.opt.step()
                loss_value = loss.item()

            print(f'Step {step} , loss {loss_value} , reward {mean_reward} ')
            if mean_reward > self.target_reward_level:
                print('Solved')
                break
            step += 1

In [None]:
# Seed the numpy / torch RNGs used here: torch for the network initialisation,
# numpy for Session's action sampling (np.random.choice).
# NOTE: the environment's own initial-state RNG is not seeded by this cell.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Build env and network from one Config and hand that same Config to Session,
# so edits to `config` also reach the training hyper-parameters.
config = Config()
env = gym.make(config.env_name)
net = Network(env.observation_space.shape[0],
              config.hidden_dim,
              env.action_space.n).to(device)
session = Session(env, net, config)
session.train()