In [1]:
import gym
import numpy as np
import torch
from torch.nn import Module
from torch.nn import Linear
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.optim import Adam
import matplotlib.pyplot as plt

In [2]:
# Fix the NumPy and PyTorch global random seeds so runs are repeatable.
RANDOM_SEED = 99
np.random.seed(RANDOM_SEED)
# Last expression of the cell: its return value is what the cell displays.
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x1d450354c50>

In [3]:
# Create the CartPole-v0 environment and seed it with the shared RANDOM_SEED.
env = gym.make('CartPole-v0')
env.seed(RANDOM_SEED)

[99]

In [4]:
# Agent properties: network input size (first dimension of the observation
# space shape) and output size (number of actions in the action space).
STATES_SIZE =env.observation_space.shape[0]
ACTION_SIZE =env.action_space.n

In [5]:
class PolicyNet(Module):
    """Small feed-forward policy that also buffers one episode of experience.

    The network is ``Linear(input_size, 32) -> ReLU -> Linear(32, output_size)``.
    Its outputs are used as logits of a categorical distribution over actions.
    Two plain Python lists, ``rewards`` and ``log_probs``, collect per-step data;
    ``log_probs`` is filled by :meth:`act`, ``rewards`` is filled by the caller.
    """

    def __init__(self, input_size, output_size):
        """Build the two linear layers, switch to train mode, clear the buffers.

        Args:
            input_size: number of features in a state vector.
            output_size: number of scores (one per action) the network emits.
        """
        super().__init__()

        # Hidden layer first, output layer second: the creation order decides
        # the order in which parameter initialisation consumes random numbers.
        self.fc1 = Linear(input_size, 32)
        self.fc2 = Linear(32, output_size)

        self.train()
        self.onpolicy_reset()

    def onpolicy_reset(self):
        """Replace both episode buffers with fresh empty lists."""
        self.rewards, self.log_probs = [], []

    def forward(self, state):
        """Return the raw action scores (logits) for a state tensor."""
        hidden = F.relu(self.fc1(state))
        return self.fc2(hidden)

    def act(self, state):
        """Sample an action for ``state`` and remember its log-probability.

        Args:
            state: NumPy array; it is cast to float32 before entering the net.

        Returns:
            The sampled action as a plain Python number (``Tensor.item()``).
        """
        observation = torch.from_numpy(state.astype(np.float32))

        logits = self.forward(observation)
        policy_dist = Categorical(logits=logits)
        # Draw the action, then store its log-probability for the later update.
        chosen = policy_dist.sample()
        self.log_probs.append(policy_dist.log_prob(chosen))
        return chosen.item()
        

In [6]:
# Discount factor applied to future rewards when train() accumulates returns.
gamma=0.99

In [7]:
def train(net, optimizer, discount=None):
    """Perform one policy-gradient (REINFORCE-style) update from one episode.

    Reads the per-step rewards and action log-probabilities buffered on
    ``net``, computes the discounted return for every step, and minimises
    ``-sum(log_prob_t * return_t)`` with a single optimizer step.

    Args:
        net: object exposing ``rewards`` (sequence of numbers) and
            ``log_probs`` (sequence of scalar tensors of the same length).
        optimizer: the optimizer that owns the parameters to update. It is the
            one that gets ``zero_grad()`` / ``step()`` called on it.
        discount: discount factor for future rewards. When omitted, the
            notebook-level ``gamma`` is used, matching the original behaviour.

    Returns:
        The scalar loss tensor that was back-propagated.
    """
    if discount is None:
        # Backward-compatible default: fall back to the global discount factor.
        discount = gamma

    # Use the optimizer that was passed in, not a notebook global; otherwise
    # the function only works by accident of kernel state.
    optimizer.zero_grad()

    T = len(net.rewards)
    rets = np.empty(T, dtype=np.float32)
    future_ret = 0.0

    # Walk the episode backwards so each return reuses the one after it:
    # G_t = r_t + discount * G_{t+1}.
    for t in reversed(range(T)):
        future_ret = net.rewards[t] + discount * future_ret
        rets[t] = future_ret

    rets = torch.tensor(rets)
    log_probs = torch.stack(net.log_probs)
    # Negated because the optimizer minimises, while the objective is maximised.
    loss = torch.sum(-log_probs * rets)
    print("try backward")
    loss.backward()
    optimizer.step()

    return loss
    

In [8]:
# Policy network sized to the environment: STATES_SIZE inputs, ACTION_SIZE outputs.
policy = PolicyNet(STATES_SIZE,ACTION_SIZE)

In [9]:
# Adam optimizer over the policy's parameters with a learning rate of 0.01.
optimizator = Adam(policy.parameters(),lr=0.01)

In [10]:
# Main training loop: play one episode per epoch, then update the policy from it.
for epoch in range(300):
    state = env.reset()
    
    # Roll out a single episode of at most 200 steps.
    for t in range(200):
        action = policy.act(state)
        state, reward, done, _ = env.step(action)
        # act() stored the log-probability; the matching reward is stored here.
        policy.rewards.append(reward)
        
        # NOTE(review): rendering on every step is likely to slow training — confirm it is needed.
        env.render()
        # Stop as soon as the environment reports the episode has finished
        if done:
            break
    
    # One gradient update from the episode just collected.
    loss =train(policy,optimizator)
    total_reward=sum(policy.rewards)
    # The episode counts as solved when its total reward exceeds 195.0.
    solved = total_reward>195.0
    # Clear the episode buffers only after train() and the reward sum have read them.
    policy.onpolicy_reset()
    
    print("Epoch:{} loss:{} total_reward:{} solved:{}".format(epoch,loss,total_reward,solved))

try backward
Epoch:0 loss:122.52767181396484 total_reward:19.0 solved:False
try backward
Epoch:1 loss:603.6463623046875 total_reward:44.0 solved:False
try backward
Epoch:2 loss:47.78059005737305 total_reward:11.0 solved:False
try backward
Epoch:3 loss:165.68043518066406 total_reward:22.0 solved:False
try backward
Epoch:4 loss:99.9844970703125 total_reward:17.0 solved:False
try backward
Epoch:5 loss:88.56130981445312 total_reward:16.0 solved:False
try backward
Epoch:6 loss:274.0126953125 total_reward:29.0 solved:False
try backward
Epoch:7 loss:66.74525451660156 total_reward:14.0 solved:False
try backward
Epoch:8 loss:97.55488586425781 total_reward:17.0 solved:False
try backward
Epoch:9 loss:173.27239990234375 total_reward:23.0 solved:False
try backward
Epoch:10 loss:285.6756286621094 total_reward:30.0 solved:False
try backward
Epoch:11 loss:39.125038146972656 total_reward:11.0 solved:False
try backward
Epoch:12 loss:219.47703552246094 total_reward:26.0 solved:False
try backward
Epoch:13

try backward
Epoch:109 loss:6124.7666015625 total_reward:196.0 solved:True
try backward
Epoch:110 loss:1324.68896484375 total_reward:77.0 solved:False
try backward
Epoch:111 loss:1599.9058837890625 total_reward:88.0 solved:False
try backward
Epoch:112 loss:2873.476806640625 total_reward:113.0 solved:False
try backward
Epoch:113 loss:6564.56982421875 total_reward:200.0 solved:True
try backward
Epoch:114 loss:1417.6817626953125 total_reward:80.0 solved:False
try backward
Epoch:115 loss:2647.389892578125 total_reward:120.0 solved:False
try backward
Epoch:116 loss:2749.26123046875 total_reward:122.0 solved:False
try backward
Epoch:117 loss:2238.53369140625 total_reward:102.0 solved:False
try backward
Epoch:118 loss:6121.6015625 total_reward:200.0 solved:True
try backward
Epoch:119 loss:2246.934326171875 total_reward:102.0 solved:False
try backward
Epoch:120 loss:6354.03515625 total_reward:200.0 solved:True
try backward
Epoch:121 loss:6269.6328125 total_reward:200.0 solved:True
try backward

try backward
Epoch:217 loss:6451.96728515625 total_reward:200.0 solved:True
try backward
Epoch:218 loss:6402.37060546875 total_reward:200.0 solved:True
try backward
Epoch:219 loss:6237.98388671875 total_reward:200.0 solved:True
try backward
Epoch:220 loss:6555.62353515625 total_reward:200.0 solved:True
try backward
Epoch:221 loss:6508.12890625 total_reward:200.0 solved:True
try backward
Epoch:222 loss:6548.57373046875 total_reward:200.0 solved:True
try backward
Epoch:223 loss:2397.49365234375 total_reward:106.0 solved:False
try backward
Epoch:224 loss:6435.0458984375 total_reward:200.0 solved:True
try backward
Epoch:225 loss:6173.96630859375 total_reward:200.0 solved:True
try backward
Epoch:226 loss:6415.63232421875 total_reward:200.0 solved:True
try backward
Epoch:227 loss:5997.91162109375 total_reward:200.0 solved:True
try backward
Epoch:228 loss:6012.62158203125 total_reward:200.0 solved:True
try backward
Epoch:229 loss:6287.56005859375 total_reward:200.0 solved:True
try backward
Ep