In [87]:
import gym
import numpy as np
import torch
import matplotlib.pyplot as plt
import time

In [88]:
from gym.wrappers import Monitor

In [89]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [90]:
import math
import copy
from torch.distributions import Categorical
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [91]:
#Hyper-parameters
solved_reward = 30           # stop training if avg_reward > solved_reward
log_interval = 100           # print avg reward in the interval
max_episodes = 1000          # max training episodes
max_timesteps = 500          # max timesteps in one episode
n_latent_var = 64            # number of variables in hidden layer
update_timestep = 2000       # update policy every n timesteps
lr = 0.002
betas = (0.9, 0.999)
gamma = 0.99                 # discount factor
K_epochs = 1                 # update policy for K epochs
eps_clip = 0.2               # clip parameter for PPO
random_seed = None
render = False
epsilon = 0.2                #need to change it to max(advantage)
d_kl=1                       #need to change it to KL divergence between old and new policies

In [92]:
class CartPoleAI(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Sequential(
                        nn.Linear(4,128, bias=True),
                        nn.ReLU(),
                        nn.Linear(128,2, bias=True),
                        nn.Softmax(dim=1)
                        )

                
        def forward(self, inputs):
            x = self.fc(inputs)
            return x

In [93]:
def init_weights(m):
    
        # nn.Conv2d weights are of shape [16, 1, 3, 3] i.e. # number of filters, 1, stride, stride
        # nn.Conv2d bias is of shape [16] i.e. # number of filters
        
        # nn.Linear weights are of shape [32, 24336] i.e. # number of input features, number of output features
        # nn.Linear bias is of shape [32] i.e. # number of output features
        
        if ((type(m) == nn.Linear) | (type(m) == nn.Conv2d)):
            torch.nn.init.xavier_uniform(m.weight)
            m.bias.data.fill_(0.00)

In [94]:
def behavioural_policy(agents):
    return agent

In [95]:
def KL_divergence(agent1,agent2):
    return KL

In [96]:
def return_random_agents(num_agents):
    
    agents = []
    for _ in range(num_agents):
        
        agent = CartPoleAI()
        
        for param in agent.parameters():
            param.requires_grad = False
            
        init_weights(agent)
        agents.append(agent)
        
        
    return agents

In [97]:
def run_agents(agents):
    
    reward_agents = []
    env = gym.make("CartPole-v0")
    print("Enter")
    for agent in agents:#There is only one agent in the list. But it is necessary to pass the agents in the form of lists to make it iterable and work on it
        agent.eval()
        #print("HELLLO!!!!!!")
        observation = env.reset()
        
        r=0
        s=0
        
        for _ in range(max_timesteps):
            
            inp = torch.tensor(observation).type('torch.FloatTensor').view(1,-1)
            output_probabilities = agent(inp).detach().numpy()[0]
            print(output_probabilities)
            action = np.random.choice(range(game_actions), 1, p=output_probabilities).item()
            new_observation, reward, done, info = env.step(action)
            r=r+reward
            
            s=s+1
            observation = new_observation

            if(done):
                break

        reward_agents.append(r)        
        #reward_agents.append(s)
    print("Exit")    
    #print(reward_agents)
    return reward_agents

In [98]:
def return_average_score(agent, runs):
    score = 0.
    for i in range(runs):
        score += run_agents([agent])[0]
    return score/runs

In [99]:
def run_agents_n_times(agents, runs):
    avg_score = []
    for agent in agents:
        avg_score.append(return_average_score(agent,runs))
    return avg_score

In [100]:
def mutate(agent):

    child_agent = copy.deepcopy(agent)
    
    mutation_power = 0.02 #hyper-parameter, set from https://arxiv.org/pdf/1712.06567.pdf
            
    for param in child_agent.parameters():
    
        if(len(param.shape)==4): #weights of Conv2D

            for i0 in range(param.shape[0]):
                for i1 in range(param.shape[1]):
                    for i2 in range(param.shape[2]):
                        for i3 in range(param.shape[3]):
                            
                            param[i0][i1][i2][i3]+= mutation_power * np.random.randn()
                                
                                    

        elif(len(param.shape)==2): #weights of linear layer
            for i0 in range(param.shape[0]):
                for i1 in range(param.shape[1]):
                    
                    param[i0][i1]+= mutation_power * np.random.randn()
                        

        elif(len(param.shape)==1): #biases of linear layer or conv layer
            for i0 in range(param.shape[0]):
                
                param[i0]+=mutation_power * np.random.randn()

    return child_agent

In [101]:
def return_children(agents, sorted_parent_indexes, elite_index):
    
    children_agents = []
    
    #first take selected parents from sorted_parent_indexes and generate N-1 children
    for i in range(len(agents)-1):
        
        selected_agent_index = sorted_parent_indexes[np.random.randint(len(sorted_parent_indexes))]
        children_agents.append(mutate(agents[selected_agent_index]))

    #now add one elite
    elite_child = add_elite(agents, sorted_parent_indexes, elite_index)
    children_agents.append(elite_child)
    elite_index=len(children_agents)-1 #it is the last one
    
    return children_agents, elite_index

In [102]:
def add_elite(agents, sorted_parent_indexes, elite_index=None, only_consider_top_n=10):
    
    candidate_elite_index = sorted_parent_indexes[:only_consider_top_n]
    
    if(elite_index is not None):
        candidate_elite_index = np.append(candidate_elite_index,[elite_index])
        
    top_score = None
    top_elite_index = None
    
    for i in candidate_elite_index:
        score = return_average_score(agents[i],runs=5)
        print("Score for elite i ", i, " is ", score)
        
        if(top_score is None):
            top_score = score
            top_elite_index = i
        elif(score > top_score):
            top_score = score
            top_elite_index = i
            
    print("Elite selected with index ",top_elite_index, " and score", top_score)
    
    child_agent = copy.deepcopy(agents[top_elite_index])
    return child_agent

In [103]:
def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    return np.exp(x) / np.sum(np.exp(x), axis=0)

In [104]:
game_actions = 2 #2 actions possible: left or right

#disable gradients as we will not use them
torch.set_grad_enabled(False)

# initialize N number of agents
num_agents = 500
agents = return_random_agents(num_agents)

# How many top agents to consider as parents
top_limit = 20

# run evolution until X generations
generations = 10

elite_index = None
n=[]

for generation in range(generations):

    # return rewards of agents
    rewards = run_agents_n_times(agents, 1) #return average of 3 runs later
    #print(rewards)
    #sort by rewards
    sorted_parent_indexes = np.argsort(rewards)[::-1][:top_limit] #reverses and gives top values (argsort sorts by ascending by default) https://stackoverflow.com/questions/16486252/is-it-possible-to-use-argsort-in-descending-order
    print("")
    print("")
    top_rewards = []
    
    for best_parent in sorted_parent_indexes:
        top_rewards.append(rewards[best_parent])
    
    print("Generation ", generation, " | Mean rewards: ", np.mean(rewards), " | Mean of top 5: ",np.mean(top_rewards[:5]))
    #print(rewards)
    print("Top ",top_limit," scores", sorted_parent_indexes)
    print("Rewards for top: ",top_rewards)
    
    
    n.append(np.mean(rewards))
    # setup an empty list for containing children agents
    children_agents, elite_index = return_children(agents, sorted_parent_indexes, elite_index)

    # kill all agents, and replace them with their children
    agents = children_agents
x=np.arange(generations)
plt.plot(x,n)
plt.title('Improvement of Mean Rewards in increasing Generations(Training)')
plt.ylabel('Mean Rewards')
plt.xlabel('Generations')
plt.show()

Enter
[0.442088   0.55791193]
[0.4451655 0.5548345]
[0.44221035 0.5577896 ]
[0.43333504 0.566665  ]
[0.4420756 0.5579244]
[0.44569322 0.55430675]
[0.43969664 0.56030333]
[0.4456839  0.55431604]
[0.43946537 0.5605346 ]
[0.43080503 0.569195  ]
[0.41976425 0.5802357 ]
[0.4086588 0.5913412]
[0.397064 0.602936]
[0.40543348 0.5945665 ]
[0.39484993 0.6051501 ]
[0.3822929 0.6177071]
Exit
Enter
[0.46212953 0.53787047]
[0.46558043 0.53441954]
[0.47034538 0.5296547 ]
[0.46649668 0.53350335]
[0.47094655 0.52905345]
[0.47561494 0.5243851 ]
[0.4716821 0.5283179]
[0.47740173 0.52259827]
[0.47281578 0.52718425]
[0.4795855 0.5204145]
[0.47460878 0.5253913 ]
[0.4815497  0.51845026]
[0.48621243 0.5137875 ]
Exit
Enter
[0.48535243 0.51464754]
[0.47705656 0.52294344]
[0.46872547 0.5312745 ]
[0.477849 0.522151]
[0.4693218 0.5306782]
[0.47828275 0.52171725]
[0.48536894 0.5146311 ]
[0.4967679 0.5032321]
[0.48539063 0.51460934]
[0.4780185  0.52198154]
[0.48533168 0.51466835]
[0.47801366 0.5219863 ]
[0.4852301 0

[0.3415368  0.65846324]
[0.35585976 0.64414024]
Exit
Enter
[0.5061951  0.49380484]
[0.5014091 0.4985909]
[0.5073219  0.49267814]
[0.5195605  0.48043954]
[0.5285468  0.47145316]
[0.5269742  0.47302577]
[0.52432364 0.4756763 ]
[0.52256316 0.47743687]
[0.5247455  0.47525454]
[0.5191751 0.4808249]
[0.5157421  0.48425788]
[0.5149425  0.48505756]
Exit
Enter
[0.56396604 0.43603396]
[0.5607756  0.43922445]
[0.5641818  0.43581823]
[0.561487   0.43851298]
[0.56396765 0.43603238]
[0.5650616  0.43493834]
[0.5638221  0.43617797]
[0.56210685 0.43789315]
[0.56856453 0.43143547]
[0.56318676 0.4368132 ]
[0.5634052  0.43659484]
[0.56408364 0.43591636]
[0.57402104 0.42597893]
[0.5655817  0.43441832]
[0.5786683  0.42133173]
[0.58773196 0.412268  ]
[0.60039616 0.3996039 ]
[0.59508586 0.40491414]
Exit
Enter
[0.4673183  0.53268164]
[0.45191562 0.5480844 ]
[0.46738443 0.5326155 ]
[0.48577285 0.51422715]
[0.51413625 0.48586372]
[0.4851459  0.51485413]
[0.5141074 0.4858926]
[0.4851138 0.5148862]
[0.46690583 0.5

[0.6162064  0.38379353]
[0.6066288 0.3933712]
[0.6003037  0.39969626]
[0.6058046  0.39419538]
[0.5996711  0.40032887]
[0.5987861  0.40121382]
[0.6047811  0.39521894]
[0.59750694 0.40249303]
[0.60587496 0.3941251 ]
[0.6133445  0.38665554]
[0.62024266 0.37975734]
Exit
Enter
[0.53448415 0.46551582]
[0.537084   0.46291605]
[0.53387696 0.466123  ]
[0.523896 0.476104]
[0.53367954 0.46632043]
[0.53681004 0.46319002]
[0.5346574  0.46534258]
[0.53563064 0.4643694 ]
[0.543948   0.45605198]
[0.5527254 0.4472746]
[0.5645284  0.43547162]
[0.57546175 0.42453822]
[0.5891646  0.41083542]
[0.5803811  0.41961887]
Exit
Enter
[0.42273974 0.57726026]
[0.44995674 0.5500432 ]
[0.4563447  0.54365534]
[0.4598932 0.5401068]
[0.46249968 0.5375003 ]
[0.45937538 0.5406246 ]
[0.45586064 0.54413927]
[0.45009148 0.5499085 ]
[0.4306168 0.5693832]
[0.3984502 0.6015498]
[0.435875 0.564125]
[0.4538143 0.5461857]
[0.4404801 0.5595199]
[0.4101437  0.58985627]
[0.3830249  0.61697507]
[0.3704864  0.62951356]
[0.38697878 0.61

[0.5214057 0.4785943]
[0.53380847 0.46619153]
[0.5159239  0.48407608]
[0.52864313 0.4713568 ]
[0.5107494  0.48925066]
[0.52352685 0.4764731 ]
[0.5044145  0.49558547]
[0.51802635 0.4819737 ]
[0.4967257 0.5032742]
[0.47299138 0.5270086 ]
[0.48908827 0.5109117 ]
Exit
Enter
[0.49404207 0.5059579 ]
[0.50409514 0.49590483]
[0.49821487 0.50178516]
[0.48319733 0.5168026 ]
[0.49831426 0.50168574]
[0.5045151  0.49548486]
[0.49513158 0.50486845]
[0.48081076 0.5191893 ]
[0.4965086  0.50349134]
[0.48316926 0.51683074]
[0.49780637 0.5021937 ]
[0.48514256 0.5148574 ]
[0.4990632  0.50093675]
[0.50659317 0.4934069 ]
[0.4931321 0.5068678]
[0.50615996 0.4938401 ]
[0.50139564 0.4986043 ]
[0.4902117 0.5097883]
[0.503148   0.49685192]
[0.5042598  0.49574023]
[0.5046971  0.49530292]
[0.4961144  0.50388557]
[0.48183793 0.5181621 ]
[0.4989969 0.5010031]
[0.5066598  0.49334025]
[0.4968374 0.5031625]
[0.4815846 0.5184154]
[0.4956622 0.5043378]
[0.5079541 0.4920459]
Exit
Enter
[0.5061089  0.49389112]
[0.52320796 

[0.5258292  0.47417086]
[0.5323917 0.4676083]
[0.5243365  0.47566345]
[0.53113073 0.4688692 ]
[0.5238362 0.4761638]
[0.5306628 0.4693372]
[0.52313143 0.4768685 ]
[0.5311964  0.46880364]
[0.5241174 0.4758826]
[0.5208975  0.47910258]
[0.5227118  0.47728822]
Exit
Enter
[0.5387687  0.46123126]
[0.54400176 0.45599827]
[0.54824674 0.4517533 ]
[0.54324865 0.45675132]
[0.5471478  0.45285213]
[0.55356854 0.4464315 ]
[0.5613771  0.43862292]
[0.5704258  0.42957413]
[0.5601564  0.43984362]
[0.5517547  0.44824535]
[0.54395896 0.45604098]
[0.54167765 0.45832235]
[0.54572433 0.45427573]
[0.54313606 0.4568639 ]
[0.5390458  0.46095416]
[0.5451175  0.45488247]
[0.5519292  0.44807076]
[0.5671935  0.43280646]
Exit
Enter
[0.46111    0.53888994]
[0.4742507 0.5257493]
[0.4934186 0.5065814]
[0.5039803  0.49601966]
[0.5136934  0.48630658]
[0.50598377 0.49401623]
[0.5162616  0.48373842]
[0.50838304 0.4916169 ]
[0.51927423 0.48072574]
[0.5112449 0.4887551]
[0.5228962  0.47710377]
[0.51506597 0.48493403]
[0.50656

[0.4528244 0.5471756]
[0.44998488 0.5500151 ]
[0.45274562 0.5472543 ]
[0.4530628 0.5469372]
[0.45319998 0.5468    ]
[0.4532038 0.5467962]
[0.45359358 0.54640645]
[0.45121863 0.54878134]
[0.4515343 0.5484657]
[0.45150635 0.5484936 ]
[0.4524325 0.5475674]
[0.45083955 0.5491604 ]
[0.45014292 0.54985714]
[0.45236465 0.5476353 ]
[0.45489517 0.54510486]
[0.45407397 0.54592603]
[0.456288 0.543712]
[0.45633316 0.5436669 ]
[0.45601192 0.54398805]
Exit
Enter
[0.5930356  0.40696442]
[0.596385   0.40361497]
[0.60403377 0.39596623]
[0.59746253 0.4025375 ]
[0.595082   0.40491804]
[0.5943806 0.4056194]
[0.592065   0.40793502]
[0.5952726  0.40472737]
[0.59320533 0.40679467]
[0.59624755 0.40375248]
[0.59781003 0.4021899 ]
[0.60084563 0.39915434]
[0.59844035 0.40155962]
[0.601926   0.39807394]
[0.59928244 0.4007176 ]
[0.5989615  0.40103856]
[0.59723425 0.40276578]
[0.5949168 0.4050832]
[0.59093505 0.40906498]
[0.59077    0.40923002]
[0.59301865 0.40698138]
[0.59096205 0.40903795]
[0.5938446  0.40615538]

[0.69167113 0.30832887]
Exit
Enter
[0.46984655 0.53015345]
[0.4637453 0.5362547]
[0.46898282 0.53101724]
[0.47151685 0.52848315]
[0.4671965  0.53280354]
[0.46213526 0.53786474]
[0.4502774  0.54972255]
[0.46417582 0.53582424]
[0.4516354 0.5483646]
[0.46550158 0.5344984 ]
[0.45190135 0.5480986 ]
[0.46586373 0.5341362 ]
[0.45138714 0.54861283]
[0.43656415 0.56343585]
[0.45032018 0.54967976]
[0.46296316 0.53703684]
Exit
Enter
[0.49774644 0.50225353]
[0.48799977 0.5120002 ]
[0.47901663 0.52098334]
[0.48804468 0.51195526]
[0.4966687 0.5033313]
[0.50150657 0.49849343]
[0.4967058 0.5032942]
[0.4874126 0.5125874]
[0.47796366 0.5220363 ]
[0.48708776 0.5129123 ]
[0.49512827 0.5048718 ]
[0.50021684 0.49978316]
[0.4948685  0.50513154]
[0.50012165 0.4998783 ]
[0.5049613 0.4950387]
[0.5005268 0.4994732]
[0.50511086 0.49488914]
[0.5016635  0.49833658]
[0.5049833  0.49501675]
[0.50149727 0.49850273]
[0.5048782  0.49512184]
[0.5017883 0.4982117]
[0.4995985 0.5004015]
[0.5022482  0.49775183]
[0.5002643 0

[0.53918725 0.4608128 ]
[0.54244727 0.45755273]
[0.53992844 0.46007153]
[0.54340565 0.45659435]
[0.56108    0.43892002]
[0.58034056 0.41965947]
[0.5981008  0.40189922]
[0.6154204  0.38457966]
[0.63360786 0.36639214]
Exit
Enter
[0.48415947 0.5158406 ]
[0.46865937 0.5313406 ]
[0.48438334 0.5156166 ]
[0.46867302 0.531327  ]
[0.45421878 0.54578125]
[0.43964365 0.5603563 ]
[0.42430976 0.5756902 ]
[0.4403885 0.5596115]
[0.45527685 0.54472315]
[0.4396557 0.5603443]
[0.4231186 0.5768814]
[0.40694445 0.59305555]
Exit
Enter
[0.447207   0.55279297]
[0.47284496 0.52715504]
[0.49603808 0.5039619 ]
[0.47121385 0.5287861 ]
[0.44462115 0.55537885]
[0.42279336 0.5772066 ]
[0.4450598  0.55494016]
[0.47120342 0.5287966 ]
[0.44458288 0.5554172 ]
[0.47085854 0.5291414 ]
[0.49431378 0.5056863 ]
[0.470126 0.529874]
[0.44393083 0.55606914]
[0.4707462 0.5292538]
[0.44477907 0.5552209 ]
[0.47151393 0.5284861 ]
[0.44588953 0.55411047]
[0.4240055  0.57599455]
[0.4478264 0.5521737]
[0.47419393 0.52580607]
[0.44896

[0.525724   0.47427595]
[0.50515777 0.4948423 ]
[0.5205571  0.47944292]
[0.5395532  0.46044675]
[0.51518786 0.48481217]
[0.49326175 0.5067382 ]
[0.47639334 0.52360666]
Exit
Enter
[0.5815785 0.4184215]
[0.5677999 0.4322001]
[0.58150333 0.41849664]
[0.5683431  0.43165693]
[0.58145106 0.41854897]
[0.5919604  0.40803963]
[0.59237164 0.4076284 ]
[0.592745   0.40725496]
[0.58358073 0.41641924]
[0.57155776 0.42844224]
[0.5498962  0.45010382]
[0.57191455 0.42808548]
[0.58337456 0.41662535]
[0.59085524 0.40914476]
[0.5928136  0.40718636]
[0.5910804  0.40891957]
[0.59287536 0.4071246 ]
[0.5976903  0.40230978]
[0.59261405 0.40738603]
[0.59926784 0.4007322 ]
[0.59281427 0.40718576]
[0.60048765 0.39951235]
Exit
Enter
[0.51216525 0.48783475]
[0.49346614 0.50653386]
[0.47902593 0.5209741 ]
[0.47272924 0.5272707 ]
[0.46624428 0.5337558 ]
[0.46095878 0.5390413 ]
[0.4653145 0.5346855]
[0.47360817 0.5263918 ]
[0.48238724 0.51761276]
[0.47178814 0.5282119 ]
[0.4796469 0.5203531]
[0.49707222 0.5029277 ]
[0

[0.46241528 0.5375847 ]
[0.46386573 0.5361343 ]
[0.46654043 0.53345966]
[0.46265423 0.53734577]
[0.4611933 0.5388067]
[0.4623689 0.5376311]
[0.46560597 0.534394  ]
[0.46120346 0.5387965 ]
[0.4650204 0.5349796]
[0.460119 0.539881]
[0.465063   0.53493696]
[0.45889562 0.5411044 ]
[0.46547887 0.53452104]
[0.45772982 0.5422702 ]
[0.45615014 0.5438499 ]
[0.45406732 0.54593265]
[0.45604503 0.543955  ]
[0.45463538 0.5453646 ]
[0.4502776 0.5497224]
[0.45589507 0.54410493]
[0.4574039 0.5425961]
[0.45621642 0.5437836 ]
[0.45790648 0.54209346]
[0.4562813 0.5437187]
[0.4601207  0.53987926]
[0.47146657 0.5285334 ]
[0.5004851  0.49951482]
[0.4767251  0.52327496]
Exit
Enter
[0.55247855 0.44752142]
[0.5644791  0.43552086]
[0.5767923  0.42320767]
[0.56491256 0.43508747]
[0.5770433  0.42295662]
[0.5654732 0.4345268]
[0.55315906 0.44684094]
[0.56579894 0.43420106]
[0.5531058 0.4468943]
[0.543015 0.456985]
[0.53035694 0.4696431 ]
[0.5289334  0.47106662]
[0.5299221 0.4700779]
[0.52870065 0.47129938]
[0.5283

[0.5044938  0.49550623]
[0.5167841  0.48321587]
[0.5047973  0.49520272]
[0.48751953 0.5124805 ]
[0.46998543 0.5300145 ]
[0.48647115 0.5135288 ]
[0.50296795 0.49703205]
[0.4852163  0.51478374]
[0.5018766  0.49812344]
[0.4837939  0.51620615]
[0.50049305 0.49950692]
[0.5062489  0.49375117]
[0.5047577  0.49524233]
[0.50729007 0.49270996]
[0.50037104 0.499629  ]
[0.48155808 0.518442  ]
[0.46459326 0.53540677]
[0.4468451 0.5531549]
[0.42915764 0.5708424 ]
[0.44259432 0.55740565]
[0.4575557 0.5424443]
[0.43843943 0.56156063]
[0.41903684 0.58096313]
[0.43270513 0.5672949 ]
Exit
Enter
[0.5703537  0.42964634]
[0.5604959 0.4395041]
[0.570604   0.42939603]
[0.5607977  0.43920228]
[0.5456143  0.45438573]
[0.56201524 0.43798476]
[0.5711654 0.4288346]
[0.57795197 0.42204803]
[0.5695933 0.4304067]
[0.56018555 0.4398145 ]
[0.5435436  0.45645642]
[0.52971846 0.47028148]
[0.52078426 0.47921574]
[0.51082796 0.48917207]
[0.5243631  0.47563684]
[0.51483476 0.48516533]
[0.50836253 0.49163747]
Exit
Enter
[0.5

[0.5705527  0.42944726]
[0.5681402 0.4318598]
[0.5663214  0.43367857]
[0.5610977  0.43890226]
[0.54971606 0.45028386]
[0.54537386 0.45462614]
[0.5462751 0.4537249]
[0.55041575 0.44958422]
[0.55299205 0.447008  ]
[0.54741836 0.4525817 ]
[0.55100256 0.44899744]
[0.54538953 0.45461047]
[0.541743 0.458257]
[0.5433765  0.45662352]
[0.5398063  0.46019372]
[0.54174054 0.4582595 ]
[0.53769994 0.46230003]
[0.5402504  0.45974955]
[0.543836   0.45616403]
Exit
Enter
[0.5105084  0.48949158]
[0.4920351 0.5079649]
[0.48151618 0.51848376]
[0.49204227 0.5079578 ]
[0.5105732  0.48942673]
[0.49207884 0.50792116]
[0.5101791  0.48982084]
[0.53662455 0.46337542]
[0.5657337  0.43426636]
[0.5919969  0.40800306]
[0.56332254 0.43667755]
[0.533721   0.46627894]
[0.50890774 0.49109226]
[0.4915732 0.5084267]
[0.4807224  0.51927763]
[0.47023362 0.5297663 ]
[0.46598816 0.5340119 ]
[0.47021955 0.52978045]
[0.48178977 0.5182103 ]
[0.49187288 0.50812715]
[0.48215094 0.5178491 ]
[0.49221992 0.50778   ]
[0.5071251  0.492

[0.49784714 0.5021529 ]
[0.5100179  0.48998207]
[0.49888888 0.50111115]
[0.4821136 0.5178864]
[0.49994266 0.5000574 ]
[0.5067613  0.49323863]
Exit
Enter
[0.48510796 0.51489204]
[0.47387847 0.52612156]
[0.48413098 0.51586896]
[0.4924377 0.5075623]
[0.48870942 0.51129055]
[0.47692758 0.52307236]
[0.48723102 0.51276904]
[0.47601265 0.52398735]
[0.48572308 0.514277  ]
[0.47510853 0.52489144]
[0.48435694 0.51564306]
[0.48489282 0.51510715]
[0.47735557 0.52264446]
[0.46867764 0.5313223 ]
[0.46912673 0.5308733 ]
[0.47005206 0.529948  ]
[0.4806158  0.51938415]
[0.47066808 0.52933186]
[0.4702232 0.5297768]
[0.4716731  0.52832687]
[0.47067082 0.5293291 ]
[0.471482  0.5285179]
[0.47216186 0.5278382 ]
[0.47346684 0.5265331 ]
[0.4766118 0.5233882]
[0.47437721 0.5256227 ]
[0.4756874  0.52431256]
[0.47759786 0.5224021 ]
[0.4742909 0.5257091]
Exit
Enter
[0.46353573 0.5364643 ]
[0.46531203 0.534688  ]
[0.462894 0.537106]
[0.46715903 0.532841  ]
[0.4629464  0.53705364]
[0.46388435 0.5361156 ]
[0.4627238

[0.5441565  0.45584354]
[0.53431463 0.46568537]
[0.5272067  0.47279328]
[0.529025   0.47097498]
[0.52251536 0.47748458]
[0.51704055 0.4829594 ]
[0.50986767 0.49013233]
[0.5008745  0.49912557]
[0.4891987  0.51080126]
Exit
Enter
[0.55392754 0.44607243]
[0.56115323 0.4388467 ]
[0.5652794 0.4347206]
[0.56139845 0.43860152]
[0.5656822 0.4343179]
[0.561607   0.43839297]
[0.5661529  0.43384707]
[0.57528365 0.4247163 ]
[0.58275384 0.41724616]
[0.591479   0.40852106]
[0.6015343  0.39846572]
[0.5941785  0.40582147]
[0.58888185 0.41111812]
Exit
Enter
[0.5167699  0.48323005]
[0.52525675 0.47474325]
[0.51528144 0.4847185 ]
[0.5241775 0.4758226]
[0.5299654  0.47003463]
[0.52325547 0.47674447]
[0.5293903  0.47060975]
[0.5228608  0.47713917]
[0.51365566 0.48634434]
[0.52306247 0.4769375 ]
[0.5140691  0.48593095]
[0.5233811  0.47661892]
[0.530022 0.469978]
[0.5420857 0.4579143]
[0.5464771  0.45352298]
[0.5427327  0.45726728]
[0.54540384 0.4545961 ]
[0.5420427  0.45795736]
[0.5377018  0.46229818]
[0.533

Enter
[0.40478602 0.595214  ]
[0.39231938 0.6076806 ]
[0.3810191  0.61898094]
[0.39320895 0.606791  ]
[0.38137415 0.6186259 ]
[0.39378732 0.6062126 ]
[0.40729782 0.59270215]
[0.39396098 0.60603905]
[0.40750462 0.5924954 ]
[0.42058054 0.5794195 ]
[0.40737924 0.59262073]
[0.39399493 0.6060051 ]
[0.4076102 0.5923898]
[0.394085   0.60591507]
[0.40774962 0.59225035]
[0.42065644 0.57934356]
[0.42383742 0.5761626 ]
[0.4150651 0.5849349]
[0.40506876 0.59493124]
[0.41276032 0.5872396 ]
[0.40389732 0.5961027 ]
[0.410215  0.5897849]
[0.40227792 0.59772205]
[0.4070042 0.5929958]
[0.41121858 0.5887814 ]
[0.4041552 0.5958448]
Exit
Enter
[0.47414333 0.5258567 ]
[0.48177755 0.51822245]
[0.4742659 0.5257341]
[0.4713372 0.5286628]
[0.46249422 0.5375058 ]
[0.47202772 0.52797234]
[0.47497228 0.52502775]
[0.47201678 0.5279831 ]
[0.46420893 0.5357911 ]
[0.4724409 0.5275591]
[0.47549567 0.52450436]
[0.47233263 0.5276674 ]
[0.4651647  0.53483534]
[0.45518062 0.54481936]
[0.443842 0.556158]
[0.45625487 0.54374

KeyboardInterrupt: 

In [None]:
def play_agent(agent):
        env = gym.make("CartPole-v0")
        
        env_record = Monitor(env, './video', force=True)
        observation = env_record.reset()
        last_observation = observation
        r=0
        j=[]
        episode_durations=[]
        timestep=0
        for timestep in range(250):
            env_record.render()
            inp = torch.tensor(observation).type('torch.FloatTensor').view(1,-1)
            output_probabilities = agent(inp).detach().numpy()[0]
            action = np.random.choice(range(game_actions), 1, p=output_probabilities).item()
            new_observation, reward, done, info = env_record.step(action)
            r=r+reward
            j.append(r)
            observation = new_observation

            if(done):
                break

        env_record.close()

        print("Rewards: ",r)

In [None]:
play_agent(agents[0])