In [8]:
import argparse
import copy
from itertools import count

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical



class Policy(nn.Module):
    """Two-layer MLP policy that maps observations to action probabilities.

    Architecture: Linear(input_dim, hidden_dim) -> Dropout(prob) -> ReLU ->
    Linear(hidden_dim, output_dim) -> softmax over actions. The default sizes
    (4 inputs, 2 outputs) match the CartPole-v1 environment used in this
    notebook.

    Attributes:
        saved_log_probs: log-probabilities of the actions sampled during the
            current episode; appended to by the agent and cleared after each
            update.
        rewards: rewards collected during the current episode.
    """

    def __init__(self, prob, input_dim=4, hidden_dim=128, output_dim=2):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(p=prob)
        self.affine2 = nn.Linear(hidden_dim, output_dim)

        self.saved_log_probs = []
        self.rewards = []

    def adapt(self, prob):
        """Change the dropout probability in place.

        The existing module's ``p`` is updated instead of constructing a new
        ``nn.Dropout``. A freshly built module always starts in training mode,
        which would silently re-enable dropout on a policy switched to
        ``eval()``.

        Raises:
            ValueError: if ``prob`` is outside [0, 1], the same check
                ``nn.Dropout`` performs at construction.
        """
        if not 0 <= prob <= 1:
            raise ValueError(
                f"dropout probability has to be between 0 and 1, but got {prob}")
        self.dropout.p = prob

    def forward(self, x):
        """Return action probabilities for a batch of observations.

        Args:
            x: float tensor whose last dimension is ``input_dim``, e.g. a
                ``(batch, input_dim)`` batch.

        Returns:
            Tensor with the same leading shape as ``x`` and a softmax over the
            last (action) dimension.
        """
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=-1)



In [9]:
class ReinforcementLearning():
    """REINFORCE agent on CartPole-v1 whose hyperparameters can be swapped.

    ``config`` keys read by this class:
        prob          dropout probability of the Policy
        lr            Adam learning rate
        b1, b2        Adam betas (defaults 0.999 / 0.9999)
        eps1          Adam eps (default 1e-08)
        weight_decay  Adam weight decay (default 0)
        eps2          added to the std when normalising returns (default 1e-08)
        gamma         discount factor for returns
        exploration   weight of the newest episode reward in the running-reward
                      exponential moving average
    """

    def __init__(self, config):
        self.env = gym.make('CartPole-v1')
        self.policy = Policy(config.get("prob"))
        self.optimizer = self._build_optimizer(config)
        self._apply_scalar_hparams(config)

        self.running_reward = 10
        self.env.seed(543)

    def _build_optimizer(self, config):
        """Create an Adam optimizer over *this* instance's policy parameters."""
        return optim.Adam(self.policy.parameters(), lr=config.get("lr"),
                          betas=(config.get("b1", 0.999), config.get("b2", 0.9999)),
                          eps=config.get("eps1", 1e-08),
                          weight_decay=config.get("weight_decay", 0))

    def _apply_scalar_hparams(self, config):
        """Copy the non-optimizer hyperparameters from ``config`` onto self."""
        self.eps = config.get("eps2", 1e-08)
        self.gamma = config.get("gamma")
        self.exploration = config.get("exploration")

    def select_action(self, state):
        """Sample an action for ``state`` and record its log-probability.

        Args:
            state: numpy observation; it is converted to float32 and given a
                leading batch dimension of 1.

        Returns:
            The sampled action as a Python int.
        """
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.policy.saved_log_probs.append(m.log_prob(action))
        return action.item()

    def adapt(self, config):
        """Return a deep copy of this agent with hyperparameters from ``config``.

        Weights, environment and running reward are copied; ``self`` is left
        unchanged. The copy gets a freshly built optimizer, so its Adam moment
        estimates start from zero.
        """
        self_copy = copy.deepcopy(self)
        self_copy.policy.adapt(config.get("prob"))
        # The optimizer must own the *copy's* parameters. Otherwise the copy's
        # gradients land on tensors that no optimizer ever steps.
        self_copy.optimizer = self_copy._build_optimizer(config)
        self_copy._apply_scalar_hparams(config)
        return self_copy

    def finish_episode(self):
        """Apply one REINFORCE update from the stored episode, then clear it."""
        rewards = self.policy.rewards
        log_probs = self.policy.saved_log_probs
        if not rewards:
            # Nothing to learn from; torch.cat([]) would raise.
            del log_probs[:]
            return

        # Discounted returns, computed back-to-front and then restored to
        # forward order (O(n), unlike repeated insert(0, ...)).
        R = 0
        returns = []
        for r in reversed(rewards):
            R = r + self.gamma * R
            returns.append(R)
        returns.reverse()
        returns = torch.tensor(returns)
        # std() of a single element is NaN and would poison every weight,
        # so only normalise when there are at least two returns.
        if len(returns) > 1:
            returns = (returns - returns.mean()) / (returns.std() + self.eps)

        policy_loss = torch.cat(
            [-log_prob * ret for log_prob, ret in zip(log_probs, returns)]).sum()
        self.optimizer.zero_grad()
        policy_loss.backward()
        self.optimizer.step()
        del rewards[:]
        del log_probs[:]

    def step(self, n_episodes=10, max_steps=9999):
        """Train for ``n_episodes`` episodes and return the running reward.

        Each episode runs until the environment reports ``done`` or
        ``max_steps`` actions have been taken, and is followed by one policy
        update. Assumes the classic gym API: ``reset()`` returns only the
        observation and ``step()`` returns a 4-tuple.
        """
        for _ in range(n_episodes):
            state, ep_reward = self.env.reset(), 0
            for _t in range(max_steps):  # Don't infinite loop while learning
                action = self.select_action(state)
                state, reward, done, _info = self.env.step(action)
                self.policy.rewards.append(reward)
                ep_reward += reward
                if done:
                    break

            # Exponential moving average of episode rewards.
            self.running_reward = (self.exploration * ep_reward
                                   + (1 - self.exploration) * self.running_reward)
            self.finish_episode()

        return self.running_reward

    

torch.manual_seed(543)

# Hyperparameters consumed by ReinforcementLearning / its adapt() method.
config = dict(
    prob=0.7,           # dropout probability of the Policy
    lr=0.01,            # Adam learning rate
    b1=0.999,           # Adam beta1
    b2=0.9999,          # Adam beta2
    eps=1e-08,          # NOTE(review): not read by ReinforcementLearning (it reads eps1/eps2)
    weight_decay=0,     # Adam weight decay
    eps2=1e-08,         # added to the returns' std when normalising them
    eps1=1e-08,         # Adam eps
    gamma=0.99,         # discount factor
    exploration=0.05,   # weight of the newest episode in the running-reward average
)


    

In [11]:
# `RL` is built but not trained in this cell. Building it draws from torch's
# global RNG (Policy weight init), so it also shifts RL1's initial weights.
# Removing it changes the reproduced numbers below.
RL = ReinforcementLearning(config)
RL1 = ReinforcementLearning(config)

# 20 training rounds of RL1.step(). Each round's running reward is kept for
# later inspection/plotting and printed as training progresses. A named loop
# variable avoids overwriting IPython's `_` (last-output) variable.
rewards_history = []
for round_idx in range(20):
    result = RL1.step()
    rewards_history.append(result)
    print(result)



17.121750099806345
18.635854642973584
24.20474515199965
28.147690838309018
39.02793893504587
40.84662878441295
52.18124915455773
66.85996455705856
70.14965854777962
66.67082286496249
62.186129417873985
59.3535076878001
58.014783113753055
63.70149607263607
64.60618971728609
85.11384765768422
105.2349279993666
130.28302168310526
150.77848943327317
165.30483193004332
