In [86]:
import gymnasium as gym
import pygame
from gymnasium.utils.play import play
# from gynasium.utils.play import pl
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib
import matplotlib.pyplot as plt
import random

%matplotlib inline

matplotlib.rcParams['figure.facecolor'] = '#ffffff'


In [87]:
env = gym.make("CartPole-v1",render_mode="human")

In [88]:
class Q_Network(nn.Module):
    def __init__(self):
        super(Q_Network, self).__init__()
        self.network = nn.Sequential(nn.Linear(4,6),
                                     nn.ReLU(),
                                     nn.Linear(6,4),
                                     nn.ReLU(),
                                     nn.Linear(4,2))
    def forward(self,state):
            return self.network(state)

class Policy_Network(nn.Module):
    def __init__(self):
        super(Policy_Network, self).__init__()
        self.network = nn.Sequential(nn.Linear(4,6),
                                     nn.ReLU(),
                                     nn.Linear(6,4),
                                     nn.ReLU(),
                                     nn.Linear(4,2),
                                     nn.Softmax(dim=0))
    def forward(self,state):
            return self.network(state)

In [89]:
# @torch.no_grad()
def train(epochs,lrQ,lrP,model_Q,model_P,max_steps,gamma,opt_func=torch.optim.SGD):
    epsilon =1e-5
    env = gym.make("CartPole-v1",render_mode="human")
    action_labels = torch.tensor([0,1])
    all_reward = []

    
    for epoch in range(epochs):
        #epoch initialisation
        state , info = env.reset()
        state = torch.from_numpy(state)

        total_reward = 0
        for step in range(max_steps):

            P_actions = model_P.forward(state)
            prob = random.uniform(0, 1)
            if prob > P_actions[1].item():
                action = 0
            else:
                action = 1
            

            #New state and reward for state-action pair
            new_state, reward, done,truncated, info= env.step(action)
            new_state = torch.from_numpy(new_state)

            # State_value = torch.tensordot(P_actions,Q_values,dim=1)

            #Rewards for epoch and loop break
            total_reward += reward
            if done or truncated:
                print(epoch,":",total_reward)
                all_reward.append(total_reward)
                break
            
           

            P_actions_new = model_P.forward(new_state)
            prob = random.uniform(0, 1)
            P_right = P_actions[1].item()
            if prob > P_right:
                new_action = 0
            else:
                new_action = 1

            #advantage
            Advantage = (reward + gamma*model_Q.forward(new_state)[new_action] - model_Q.forward(state)[action])
            
            opt_Q = opt_func(model_Q.parameters(),lrQ)
            Q_loss = (Advantage)**2
            Q_loss.backward()
            opt_Q.step()
            opt_Q.zero_grad()

            opt_P = opt_func(model_P.parameters(),lrP)            
            P_actions = torch.clamp(P_actions,min=epsilon,max=1-epsilon)  #to remove inf edge case                     
            loss_P = -torch.log(P_actions[action])*Advantage.item()
            print(loss_P)
            loss_P.backward()
            opt_P.step()
            opt_P.zero_grad()
            
            
            #State-Change
            state = new_state
            
    env.close()

In [90]:
model_Q = Q_Network()
model_P = Policy_Network()

In [91]:
epochs=500
max_steps=300
gamma=0.99
lrQ=1e-3
lrP=3e-4

Possible reasons:
1. Exploding value due to large value after log
2. Division by 0 somewhere
3. Exploding gradients
4. higher the learning rate faster the error arrives.
5. Maybe instead of using the optimiser iterating over the params and modifying them would be better.

In [92]:

train(epochs,lrQ,lrP,model_Q,model_P,max_steps,gamma)


tensor(0.8589, grad_fn=<MulBackward0>)
tensor(0.3463, grad_fn=<MulBackward0>)
tensor(0.3464, grad_fn=<MulBackward0>)
tensor(0.3395, grad_fn=<MulBackward0>)
tensor(0.5573, grad_fn=<MulBackward0>)
tensor(0.8608, grad_fn=<MulBackward0>)
tensor(0.5590, grad_fn=<MulBackward0>)
tensor(0.5542, grad_fn=<MulBackward0>)
tensor(0.8561, grad_fn=<MulBackward0>)
tensor(0.8611, grad_fn=<MulBackward0>)
tensor(0.8970, grad_fn=<MulBackward0>)
0 : 12.0
tensor(0.8637, grad_fn=<MulBackward0>)
tensor(0.5466, grad_fn=<MulBackward0>)
tensor(0.8898, grad_fn=<MulBackward0>)
tensor(0.8382, grad_fn=<MulBackward0>)
tensor(0.4112, grad_fn=<MulBackward0>)
tensor(0.8802, grad_fn=<MulBackward0>)
tensor(0.5514, grad_fn=<MulBackward0>)
tensor(0.8273, grad_fn=<MulBackward0>)
tensor(0.5488, grad_fn=<MulBackward0>)
tensor(0.9025, grad_fn=<MulBackward0>)
tensor(0.5473, grad_fn=<MulBackward0>)
tensor(0.4521, grad_fn=<MulBackward0>)
tensor(0.7949, grad_fn=<MulBackward0>)
1 : 14.0
tensor(0.5521, grad_fn=<MulBackward0>)
tensor(

In [None]:
def test(model,episodes):
    env = gym.make("CartPole-v1",render_mode="human")
    state,info = env.reset()
    state = torch.from_numpy(state)
    total_reward=0
    for ep in range(episodes):
       
        action = torch.argmax(model.forward(state)).item()
        # print(action)
        new_state, reward, done,truncated, info= env.step(action)
        # print(env.step(action))
        total_reward+=reward
        # print(f"{ep} Total reward:",total_reward)
        if done or truncated:
            print(total_reward)
            break
        state = torch.from_numpy(new_state)    
    env.close()
        

In [66]:
for i in range(20):
    print(i)
    test(model_P,100)

0
9.0
1
10.0
2
10.0
3
8.0
4
9.0
5
8.0
6
8.0
7
9.0
8
10.0
9
9.0
10
8.0
11
8.0
12
8.0
13
10.0
14
10.0
15
10.0
16
10.0
17
9.0
18
8.0
19
8.0
