In [3]:
import os

import gym
import numpy as np
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Shared CartPole-v0 environment; the training and evaluation cells below both step this instance.
env=gym.make('CartPole-v0')

In [6]:
class PolicyNetwork(nn.Module):
    """Fully connected network with two ReLU hidden layers and a linear output layer.

    The output of ``forward`` is the raw result of the last linear layer
    (no softmax is applied here), i.e. one unnormalized score per action.

    Args:
        lr: learning rate of the Adam optimizer owned by the network.
        input_dims: size of the observation vector (``in_features`` of ``fc1``).
        fc1_dims: width of the first hidden layer.
        fc2_dims: width of the second hidden layer.
        n_actions: number of output units.

    Attributes:
        optimizer: Adam optimizer over this network's parameters.
        device: ``cuda:0`` when CUDA is available, otherwise ``cpu``.
    """

    def __init__(self, lr, input_dims, fc1_dims, fc2_dims, n_actions):
        super(PolicyNetwork, self).__init__()
        self.input_dims = input_dims
        self.lr = lr
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions

        self.fc1 = nn.Linear(input_dims, fc1_dims)
        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
        self.fc3 = nn.Linear(fc2_dims, n_actions)

        # Move the parameters to their final device *before* handing them to
        # the optimizer, as recommended by the torch.optim documentation.
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, observation):
        """Compute the unnormalized action scores for ``observation``.

        Args:
            observation: list, NumPy array or tensor whose last dimension has
                ``input_dims`` entries. Any numeric dtype is accepted; it is
                converted to float32 on ``self.device``.

        Returns:
            Tensor with the same leading dimensions as ``observation`` and a
            last dimension of size ``n_actions``.
        """
        # as_tensor (unlike the legacy T.Tensor constructor) converts inputs of
        # any dtype/device and avoids a copy when none is needed.
        observation = T.as_tensor(observation, dtype=T.float32, device=self.device)
        x = F.relu(self.fc1(observation))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [7]:
class Agent():
    """REINFORCE-style policy-gradient agent.

    Actions are sampled from a categorical distribution over the policy
    network's outputs. The log-probability of every sampled action and the
    reward of every step are buffered until ``learn`` is called, which performs
    one gradient step on ``-sum_t G_t * log pi(a_t | s_t)`` using normalized
    discounted returns ``G_t`` and then empties both buffers.

    Args:
        lr: learning rate forwarded to ``PolicyNetwork``.
        input_dims: observation size forwarded to ``PolicyNetwork``.
        gamma: discount factor applied to future rewards.
        n_actions: number of discrete actions.
        l1_size: width of the first hidden layer.
        l2_size: width of the second hidden layer.
    """

    def __init__(self, lr, input_dims, gamma=0.99, n_actions=4, l1_size=256, l2_size=256):
        self.gamma = gamma
        self.reward_memory = []   # one reward per step of the current episode
        self.action_memory = []   # one log-prob tensor per sampled action
        self.policy = PolicyNetwork(lr, input_dims, l1_size, l2_size, n_actions)

    def act(self, observation):
        """Sample an action for ``observation`` and remember its log-probability.

        Args:
            observation: a single observation accepted by the policy network.

        Returns:
            The sampled action as a Python int.
        """
        # dim=-1 makes the normalization axis explicit; relying on the implicit
        # dimension is deprecated and emits a warning on every call.
        probs = F.softmax(self.policy(observation), dim=-1)
        action_probs = T.distributions.Categorical(probs)
        action = action_probs.sample()
        self.action_memory.append(action_probs.log_prob(action))
        return action.item()

    def store_rewards(self, reward):
        """Append the reward obtained for the most recent step."""
        self.reward_memory.append(reward)

    def _discounted_returns(self):
        """Return ``[G_0, ..., G_{n-1}]`` with ``G_t = r_t + gamma * G_{t+1}``.

        Computed in a single backward pass over ``reward_memory``.
        """
        returns = [0.0] * len(self.reward_memory)
        running = 0.0
        for t in reversed(range(len(self.reward_memory))):
            running = self.reward_memory[t] + self.gamma * running
            returns[t] = running
        return returns

    def learn(self):
        """Run one policy-gradient update from the buffered episode, then reset.

        Returns are standardized (zero mean, unit population std; the division
        is skipped when the std is not positive). If either buffer is empty the
        call only clears the buffers and performs no optimizer step.
        """
        returns = np.asarray(self._discounted_returns(), dtype=np.float64)
        # Pair returns and log-probs position by position; surplus entries on
        # either side are ignored.
        steps = min(len(returns), len(self.action_memory))

        if steps > 0:
            std = returns.std()
            returns = (returns - returns.mean()) / (std if std > 0 else 1.0)

            # float32 keeps the loss in the same dtype as the network outputs.
            weights = T.as_tensor(returns[:steps], dtype=T.float32,
                                  device=self.policy.device)
            log_probs = T.stack(self.action_memory[:steps]).reshape(steps)
            loss = -(weights * log_probs).sum()

            self.policy.optimizer.zero_grad()
            loss.backward()
            self.policy.optimizer.step()

        self.action_memory = []
        self.reward_memory = []
        
        
    

In [39]:

# Agent for CartPole: 4 observation inputs, 2 actions, two hidden layers of 128 units.
agent = Agent(lr=0.001, input_dims=4, gamma=0.99, n_actions=2, l1_size=128, l2_size=128)

score_history = []
score = 0
num_ep = 2500

for i in range(num_ep):
    # Roll out one complete episode with the current policy.
    observation = env.reset()
    score = 0
    done = False
    while True:
        action = agent.act(observation)
        observation, reward, done, _ = env.step(action)
        agent.store_rewards(reward)
        score += reward
        if done:
            break

    score_history.append(score)
    print('episode', i, 'score %.3f' % score)
    # One policy update per finished episode.
    agent.learn()

  # Remove the CWD from sys.path while we load stuff.


episode 0 score 80.000
episode 1 score 40.000
episode 2 score 13.000
episode 3 score 21.000
episode 4 score 18.000
episode 5 score 30.000
episode 6 score 25.000
episode 7 score 14.000
episode 8 score 16.000
episode 9 score 17.000
episode 10 score 22.000
episode 11 score 28.000
episode 12 score 16.000
episode 13 score 19.000
episode 14 score 21.000
episode 15 score 35.000
episode 16 score 133.000
episode 17 score 24.000
episode 18 score 33.000
episode 19 score 43.000
episode 20 score 20.000
episode 21 score 16.000
episode 22 score 58.000
episode 23 score 29.000
episode 24 score 17.000
episode 25 score 26.000
episode 26 score 25.000
episode 27 score 32.000
episode 28 score 22.000
episode 29 score 29.000
episode 30 score 61.000
episode 31 score 22.000
episode 32 score 50.000
episode 33 score 34.000
episode 34 score 13.000
episode 35 score 18.000
episode 36 score 14.000
episode 37 score 29.000
episode 38 score 15.000
episode 39 score 42.000
episode 40 score 16.000
episode 41 score 71.000
e

episode 324 score 200.000
episode 325 score 200.000
episode 326 score 200.000
episode 327 score 200.000
episode 328 score 200.000
episode 329 score 200.000
episode 330 score 187.000
episode 331 score 200.000
episode 332 score 176.000
episode 333 score 200.000
episode 334 score 186.000
episode 335 score 200.000
episode 336 score 200.000
episode 337 score 200.000
episode 338 score 200.000
episode 339 score 107.000
episode 340 score 200.000
episode 341 score 200.000
episode 342 score 200.000
episode 343 score 200.000
episode 344 score 200.000
episode 345 score 200.000
episode 346 score 200.000
episode 347 score 200.000
episode 348 score 200.000
episode 349 score 200.000
episode 350 score 200.000
episode 351 score 200.000
episode 352 score 200.000
episode 353 score 200.000
episode 354 score 200.000
episode 355 score 200.000
episode 356 score 53.000
episode 357 score 144.000
episode 358 score 100.000
episode 359 score 200.000
episode 360 score 153.000
episode 361 score 103.000
episode 362 s

episode 642 score 114.000
episode 643 score 103.000
episode 644 score 107.000
episode 645 score 113.000
episode 646 score 57.000
episode 647 score 46.000
episode 648 score 35.000
episode 649 score 60.000
episode 650 score 33.000
episode 651 score 83.000
episode 652 score 36.000
episode 653 score 73.000
episode 654 score 99.000
episode 655 score 94.000
episode 656 score 88.000
episode 657 score 115.000
episode 658 score 118.000
episode 659 score 117.000
episode 660 score 114.000
episode 661 score 135.000
episode 662 score 160.000
episode 663 score 142.000
episode 664 score 170.000
episode 665 score 183.000
episode 666 score 200.000
episode 667 score 170.000
episode 668 score 165.000
episode 669 score 185.000
episode 670 score 187.000
episode 671 score 175.000
episode 672 score 143.000
episode 673 score 180.000
episode 674 score 172.000
episode 675 score 200.000
episode 676 score 200.000
episode 677 score 200.000
episode 678 score 200.000
episode 679 score 200.000
episode 680 score 200.0

episode 958 score 200.000
episode 959 score 200.000
episode 960 score 200.000
episode 961 score 200.000
episode 962 score 200.000
episode 963 score 200.000
episode 964 score 200.000
episode 965 score 200.000
episode 966 score 200.000
episode 967 score 200.000
episode 968 score 200.000
episode 969 score 200.000
episode 970 score 200.000
episode 971 score 200.000
episode 972 score 200.000
episode 973 score 200.000
episode 974 score 200.000
episode 975 score 200.000
episode 976 score 200.000
episode 977 score 200.000
episode 978 score 200.000
episode 979 score 200.000
episode 980 score 200.000
episode 981 score 200.000
episode 982 score 200.000
episode 983 score 200.000
episode 984 score 200.000
episode 985 score 200.000
episode 986 score 200.000
episode 987 score 200.000
episode 988 score 200.000
episode 989 score 200.000
episode 990 score 200.000
episode 991 score 200.000
episode 992 score 200.000
episode 993 score 200.000
episode 994 score 200.000
episode 995 score 200.000
episode 996 

episode 1264 score 200.000
episode 1265 score 200.000
episode 1266 score 200.000
episode 1267 score 200.000
episode 1268 score 200.000
episode 1269 score 200.000
episode 1270 score 200.000
episode 1271 score 200.000
episode 1272 score 200.000
episode 1273 score 60.000
episode 1274 score 200.000
episode 1275 score 200.000
episode 1276 score 200.000
episode 1277 score 200.000
episode 1278 score 200.000
episode 1279 score 200.000
episode 1280 score 200.000
episode 1281 score 200.000
episode 1282 score 200.000
episode 1283 score 200.000
episode 1284 score 200.000
episode 1285 score 200.000
episode 1286 score 200.000
episode 1287 score 200.000
episode 1288 score 22.000
episode 1289 score 66.000
episode 1290 score 200.000
episode 1291 score 200.000
episode 1292 score 20.000
episode 1293 score 200.000
episode 1294 score 200.000
episode 1295 score 200.000
episode 1296 score 200.000
episode 1297 score 21.000
episode 1298 score 19.000
episode 1299 score 22.000
episode 1300 score 21.000
episode 1

episode 1569 score 200.000
episode 1570 score 200.000
episode 1571 score 200.000
episode 1572 score 200.000
episode 1573 score 200.000
episode 1574 score 200.000
episode 1575 score 200.000
episode 1576 score 200.000
episode 1577 score 200.000
episode 1578 score 200.000
episode 1579 score 200.000
episode 1580 score 200.000
episode 1581 score 200.000
episode 1582 score 200.000
episode 1583 score 200.000
episode 1584 score 200.000
episode 1585 score 200.000
episode 1586 score 200.000
episode 1587 score 200.000
episode 1588 score 200.000
episode 1589 score 200.000
episode 1590 score 200.000
episode 1591 score 200.000
episode 1592 score 200.000
episode 1593 score 200.000
episode 1594 score 200.000
episode 1595 score 200.000
episode 1596 score 200.000
episode 1597 score 200.000
episode 1598 score 200.000
episode 1599 score 200.000
episode 1600 score 200.000
episode 1601 score 200.000
episode 1602 score 200.000
episode 1603 score 200.000
episode 1604 score 200.000
episode 1605 score 200.000
e

episode 1873 score 200.000
episode 1874 score 200.000
episode 1875 score 200.000
episode 1876 score 200.000
episode 1877 score 200.000
episode 1878 score 200.000
episode 1879 score 200.000
episode 1880 score 200.000
episode 1881 score 200.000
episode 1882 score 200.000
episode 1883 score 200.000
episode 1884 score 200.000
episode 1885 score 200.000
episode 1886 score 200.000
episode 1887 score 200.000
episode 1888 score 200.000
episode 1889 score 200.000
episode 1890 score 200.000
episode 1891 score 200.000
episode 1892 score 200.000
episode 1893 score 200.000
episode 1894 score 200.000
episode 1895 score 200.000
episode 1896 score 200.000
episode 1897 score 200.000
episode 1898 score 200.000
episode 1899 score 200.000
episode 1900 score 200.000
episode 1901 score 200.000
episode 1902 score 200.000
episode 1903 score 200.000
episode 1904 score 200.000
episode 1905 score 200.000
episode 1906 score 200.000
episode 1907 score 200.000
episode 1908 score 200.000
episode 1909 score 200.000
e

episode 2177 score 200.000
episode 2178 score 200.000
episode 2179 score 200.000
episode 2180 score 200.000
episode 2181 score 200.000
episode 2182 score 200.000
episode 2183 score 200.000
episode 2184 score 200.000
episode 2185 score 200.000
episode 2186 score 200.000
episode 2187 score 200.000
episode 2188 score 200.000
episode 2189 score 200.000
episode 2190 score 200.000
episode 2191 score 200.000
episode 2192 score 200.000
episode 2193 score 200.000
episode 2194 score 200.000
episode 2195 score 200.000
episode 2196 score 200.000
episode 2197 score 200.000
episode 2198 score 200.000
episode 2199 score 200.000
episode 2200 score 200.000
episode 2201 score 200.000
episode 2202 score 200.000
episode 2203 score 200.000
episode 2204 score 200.000
episode 2205 score 200.000
episode 2206 score 200.000
episode 2207 score 200.000
episode 2208 score 200.000
episode 2209 score 200.000
episode 2210 score 200.000
episode 2211 score 200.000
episode 2212 score 200.000
episode 2213 score 200.000
e

episode 2481 score 200.000
episode 2482 score 200.000
episode 2483 score 200.000
episode 2484 score 200.000
episode 2485 score 200.000
episode 2486 score 200.000
episode 2487 score 200.000
episode 2488 score 200.000
episode 2489 score 200.000
episode 2490 score 200.000
episode 2491 score 200.000
episode 2492 score 200.000
episode 2493 score 200.000
episode 2494 score 200.000
episode 2495 score 200.000
episode 2496 score 200.000
episode 2497 score 200.000
episode 2498 score 200.000
episode 2499 score 200.000


In [9]:
PATH="CartPole_agent_weights/agent1.pt"
# Persist the trained weights only when no checkpoint exists yet: the restore
# cell below needs the file, while re-running this cell must never overwrite a
# checkpoint that is already on disk.
if not os.path.exists(PATH):
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    T.save(agent.policy.state_dict(), PATH)

In [10]:
# Rebuild an agent with the same constructor arguments as the trained one so
# the layer shapes match the stored state dict.
new_agent = Agent(0.001,4,0.99,2,128,128)
# map_location remaps the stored tensors onto this machine's device, so a
# checkpoint written on a GPU still loads on a CPU-only host.
new_agent.policy.load_state_dict(T.load(PATH, map_location=new_agent.policy.device))
new_agent.policy.eval()

PolicyNetwork(
  (fc1): Linear(in_features=4, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=2, bias=True)
)

In [11]:

# Play a single evaluation episode with the restored policy and report its score.
done=False
score=0
observation=env.reset()
# No update follows this rollout, so skip autograd bookkeeping while acting.
with T.no_grad():
    while not done:
        action=new_agent.act(observation)
        observation,reward,done,_ =env.step(action)
        score+=reward
# act() buffers one log-prob per step for learn(); discard them so they cannot
# leak into a later update or accumulate across repeated evaluations.
new_agent.action_memory.clear()
print(score)

200.0


  # Remove the CWD from sys.path while we load stuff.
