In [1]:
import gymnasium as gym
import pygame
from gymnasium.utils.play import play
# from gynasium.utils.play import pl
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib
import matplotlib.pyplot as plt
import random
%matplotlib inline

matplotlib.rcParams['figure.facecolor'] = '#ffffff'


In [2]:
env = gym.make("CartPole-v1",render_mode="human")

In [3]:
class Q_Network(nn.Module):
    def __init__(self):
        super(Q_Network, self).__init__()
        self.network = nn.Sequential(nn.Linear(4,6),
                                     nn.ReLU(),
                                     nn.Linear(6,4),
                                     nn.ReLU(),
                                     nn.Linear(4,2))
    def forward(self,state):
            return self.network(state)
            
        

In [4]:
def mse(t1,t2):
    diff = t1-t2
    return torch.sum(diff*diff/diff.numel())

In [5]:
# @torch.no_grad()
def train(epochs,lr,model,rm_size,batch_size,max_steps,gamma,opt_func=torch.optim.SGD):

    max_epsilon = 1.0             # Exploration probability at start
    min_epsilon = 0.05
    decay_rate = 0.0005 
    epsilon = max_epsilon
    env = gym.make("CartPole-v1",render_mode="human")
    optimizer = opt_func(model.parameters(),lr)
    replay_memory = list()
    # replay_memory = replay_memory
    for epoch in range(epochs):
        state , info = env.reset()
        state = torch.from_numpy(state)
        # print(device)
        # state = state.to(device)
        
        # print(state.device)
        total_reward = 0
        for step in range(max_steps):
            
            #Sampling Phase
            exp_exp_tradeoff = random.uniform(0, 1)
            if exp_exp_tradeoff > epsilon:
                action = torch.argmax(model.forward(state)).item()
            else:
                action = env.action_space.sample()

            new_state, reward, done,truncated, info= env.step(action)
            new_state = torch.from_numpy(new_state)
            memory = [state,action,reward,new_state]
            total_reward += reward
            
            if len(replay_memory) == rm_size:
                replay_memory.pop(0)
            replay_memory.append(memory)
            if done or truncated:
                print(epoch,":",total_reward)
                break
            
            #Training Phase
            batch = random.sample(replay_memory,batch_size) if len(replay_memory)>batch_size else []
            for element in batch:
                
                rand_state,rand_action,rand_reward,target_state = replay_memory[random.randint(0,len(replay_memory)-1)]
                # target_state = target_state.to(device)
                target_Q = rand_reward + gamma*torch.max(model.forward(target_state))
                rand_output = model.forward(rand_state)
                # assert rand_output.requires_grad == True, "Model output must have requires_grad=True"
    
                trial_Q = rand_output[rand_action]
                
                # print("hi")
                # print(target_Q)
                # print(trial_Q)
                loss = F.mse_loss(target_Q,trial_Q)
                # print(loss)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            #State-Change
            state = new_state
            
        epsilon = min_epsilon + (max_epsilon - min_epsilon)*np.exp(-decay_rate*epoch) 
        # epsilon = 0.01
        
    env.close()

In [6]:
model = Q_Network()
model

Q_Network(
  (network): Sequential(
    (0): Linear(in_features=4, out_features=6, bias=True)
    (1): ReLU()
    (2): Linear(in_features=6, out_features=4, bias=True)
    (3): ReLU()
    (4): Linear(in_features=4, out_features=2, bias=True)
  )
)

In [7]:
epochs=500
# epsilon=1.0
# max_epsilon = 1.0             # Exploration probability at start
# min_epsilon = 0.2 
max_steps=300
gamma=0.7
lr = 0.003
rm_size=60
batch_size=10

In [8]:

train(epochs,lr,model,rm_size,batch_size,max_steps,gamma)

0 : 27.0
1 : 21.0
2 : 79.0
3 : 51.0
4 : 28.0
5 : 24.0
6 : 23.0
7 : 48.0
8 : 18.0
9 : 39.0
10 : 13.0
11 : 16.0
12 : 48.0
13 : 14.0
14 : 34.0
15 : 19.0
16 : 10.0
17 : 16.0
18 : 10.0
19 : 45.0
20 : 23.0
21 : 13.0
22 : 30.0
23 : 20.0
24 : 18.0
25 : 15.0
26 : 34.0
27 : 22.0
28 : 19.0
29 : 35.0
30 : 17.0
31 : 14.0
32 : 27.0
33 : 28.0
34 : 17.0
35 : 49.0
36 : 28.0
37 : 11.0
38 : 54.0
39 : 28.0
40 : 16.0
41 : 15.0
42 : 16.0
43 : 15.0
44 : 23.0
45 : 13.0
46 : 26.0
47 : 32.0
48 : 29.0
49 : 33.0
50 : 25.0
51 : 43.0
52 : 16.0
53 : 18.0
54 : 18.0
55 : 36.0
56 : 12.0
57 : 27.0
58 : 35.0
59 : 56.0
60 : 28.0
61 : 31.0
62 : 36.0
63 : 14.0
64 : 23.0
65 : 26.0
66 : 12.0
67 : 19.0
68 : 82.0
69 : 15.0
70 : 43.0
71 : 17.0
72 : 33.0
73 : 21.0
74 : 20.0
75 : 30.0
76 : 25.0
77 : 13.0
78 : 15.0
79 : 18.0
80 : 11.0
81 : 29.0
82 : 14.0
83 : 24.0
84 : 18.0
85 : 15.0
86 : 32.0
87 : 26.0
88 : 13.0
89 : 17.0
90 : 21.0
91 : 12.0
92 : 27.0
93 : 11.0
94 : 19.0
95 : 15.0
96 : 15.0
97 : 34.0
98 : 37.0
99 : 18.0
100 : 39.0

In [9]:
def test(model,episodes):
    env = gym.make("CartPole-v1",render_mode="human")
    state,info = env.reset()
    state = torch.from_numpy(state)
    total_reward=0
    for ep in range(episodes):
       
        action = torch.argmax(model.forward(state)).item()
        # print(action)
        new_state, reward, done,truncated, info= env.step(action)
        # print(env.step(action))
        total_reward+=reward
        # print(f"{ep} Total reward:",total_reward)
        if done or truncated:
            print(total_reward)
            break
        state = torch.from_numpy(new_state)    
    env.close()
        

In [10]:
for i in range(20):
    print(i)
    test(model,100)

0
9.0
1
10.0
2
11.0
3
10.0
4
9.0
5
9.0
6
10.0
7
10.0
8
10.0
9
8.0
10
10.0
11
8.0
12
11.0
13
10.0
14
10.0
15
10.0
16
10.0
17
8.0
18
10.0
19
8.0
