In [38]:
import random
from collections import deque

import gymnasium as gym
import pygame
from gymnasium.utils.play import play
# from gynasium.utils.play import pl
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
%matplotlib inline

# Force an opaque white background for every matplotlib figure in this notebook
# (presumably so plots stay legible under dark notebook themes / when exported).
matplotlib.rcParams['figure.facecolor'] = '#ffffff'


In [39]:
# NOTE(review): this global `env` is not used by the cells below -- `train` and
# `test` each build (and close) their own CartPole env. Nothing visible here ever
# calls `env.close()` on it, and with render_mode="human" it presumably keeps a
# render window open; consider deleting this cell.
env = gym.make("CartPole-v1",render_mode="human")

In [44]:
class Q_Network(nn.Module):
    """Small MLP mapping an environment state to one Q-value per discrete action.

    The defaults reproduce the CartPole network used in this notebook
    (Linear 4->6, ReLU, Linear 6->4, ReLU, Linear 4->2). The sizes are parameters
    so the same class can be reused for other observation/action spaces or widths.

    Args:
        state_dim: length of the state vector fed to ``forward``.
        n_actions: number of Q-values produced (one per action).
        hidden_sizes: widths of the hidden layers; each is followed by a ReLU.
    """

    def __init__(self, state_dim=4, n_actions=2, hidden_sizes=(6, 4)):
        super().__init__()
        layers = []
        in_features = state_dim
        for width in hidden_sizes:
            layers += [nn.Linear(in_features, width), nn.ReLU()]
            in_features = width
        # No activation on the output layer: Q-values are unbounded regression targets.
        layers.append(nn.Linear(in_features, n_actions))
        self.network = nn.Sequential(*layers)

    def forward(self, state):
        """Return the Q-values for ``state``; the last dimension has size ``n_actions``."""
        return self.network(state)
            
        

In [45]:
def mse(t1, t2):
    """Mean squared error between two same-shaped tensors.

    Each squared difference is divided by the element count before summing,
    which equals the mean of the squared differences.
    """
    residual = t1 - t2
    n_elements = residual.numel()
    return (residual * residual / n_elements).sum()

In [91]:
def _dqn_step(model, optimizer, state, action, reward, next_state, terminal, gamma):
    """Apply one optimizer step to ``model`` on the squared TD error of a single transition."""
    predicted_q = model(state)[action]
    # The TD target is a constant: building it under no_grad keeps gradients from
    # flowing through the bootstrap term (semi-gradient Q-learning).
    with torch.no_grad():
        target_q = torch.tensor(float(reward), dtype=predicted_q.dtype)
        if not terminal:
            # Only non-terminal transitions bootstrap from the next state's best Q-value;
            # otherwise a failing step would look exactly like a successful one.
            target_q = target_q + gamma * torch.max(model(next_state))
    loss = F.mse_loss(predicted_q, target_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def train(epochs, lr, model, rm_size, batch_size, max_steps, gamma,
          opt_func=torch.optim.SGD, render_mode="human", env=None):
    """Train ``model`` as a Q-function on CartPole with epsilon-greedy exploration and replay.

    Each epoch plays one episode of at most ``max_steps`` steps. Every transition is
    stored in a bounded replay memory; once the memory holds at least ``batch_size``
    transitions, a random minibatch is drawn after every environment step and the model
    takes one optimizer step per sampled transition. The total reward of each episode is
    printed as ``"<epoch> : <reward>"``.

    Args:
        epochs: number of episodes to play.
        lr: learning rate, passed as ``opt_func(model.parameters(), lr)``.
        model: module mapping a state tensor to one Q-value per action; updated in place.
        rm_size: replay-memory capacity; the oldest transition is evicted first.
        batch_size: number of transitions sampled from the replay memory per step.
        max_steps: cap on the number of steps in each episode.
        gamma: discount factor applied to the bootstrapped next-state value.
        opt_func: optimizer constructor (default ``torch.optim.SGD``).
        render_mode: render mode of the CartPole env built when ``env`` is None.
            ``"human"`` (the default) renders every frame and is slow; ``None`` trains headless.
        env: optional pre-built environment with the Gymnasium reset/step API. When
            given it is used as-is and is NOT closed here (the caller owns it).
    """
    max_epsilon = 1.0    # exploration probability at the start
    min_epsilon = 0.05   # exploration floor
    decay_rate = 0.0005  # per-episode exponential decay of epsilon
    epsilon = max_epsilon

    owns_env = env is None
    if owns_env:
        env = gym.make("CartPole-v1", render_mode=render_mode)
    optimizer = opt_func(model.parameters(), lr)
    # deque(maxlen=...) evicts the oldest transition in O(1) and tolerates rm_size == 0.
    replay_memory = deque(maxlen=rm_size)

    try:
        for epoch in range(epochs):
            state, info = env.reset()
            state = torch.from_numpy(state)
            total_reward = 0
            for _ in range(max_steps):
                # Sampling phase: epsilon-greedy action selection (inference only, no graph).
                if random.uniform(0, 1) > epsilon:
                    with torch.no_grad():
                        action = torch.argmax(model(state)).item()
                else:
                    action = env.action_space.sample()

                new_state, reward, done, truncated, info = env.step(action)
                new_state = torch.from_numpy(new_state)
                total_reward += reward
                # Keep `done` (true termination) so the target does not bootstrap past it;
                # `truncated` is a time limit, so those transitions still bootstrap.
                replay_memory.append((state, action, reward, new_state, done))

                # Training phase: one update per transition of a uniformly sampled minibatch.
                if len(replay_memory) >= batch_size:
                    for s, a, r, s_next, terminal in random.sample(replay_memory, batch_size):
                        _dqn_step(model, optimizer, s, a, r, s_next, terminal, gamma)

                if done or truncated:
                    break
                state = new_state

            # Report every episode, including ones that reach max_steps without terminating.
            print(epoch, ":", total_reward)
            epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * epoch)
    finally:
        # Close the env even if training is interrupted (e.g. KeyboardInterrupt).
        if owns_env:
            env.close()

In [95]:
# Build a fresh, untrained Q-network; the bare `model` expression below makes the
# cell display its layer stack.
model = Q_Network()
model

Q_Network(
  (network): Sequential(
    (0): Linear(in_features=4, out_features=6, bias=True)
    (1): ReLU()
    (2): Linear(in_features=6, out_features=4, bias=True)
    (3): ReLU()
    (4): Linear(in_features=4, out_features=2, bias=True)
  )
)

In [100]:
# Training hyper-parameters (the epsilon-greedy exploration schedule is hard-coded inside `train`).
epochs, max_steps = 500, 300   # episodes to play; step cap per episode
gamma = 0.7                    # discount factor for the bootstrapped Q-target
lr = 0.003                     # optimizer learning rate
rm_size, batch_size = 60, 10   # replay-memory capacity; transitions sampled per training step

In [97]:

# Train `model` in place with the hyper-parameters above; `train` reports per-episode rewards as it goes.
train(
    epochs=epochs,
    lr=lr,
    model=model,
    rm_size=rm_size,
    batch_size=batch_size,
    max_steps=max_steps,
    gamma=gamma,
)

0 : 12.0
1 : 19.0
2 : 29.0
3 : 32.0
4 : 11.0
5 : 14.0
6 : 35.0
7 : 16.0
8 : 25.0
9 : 9.0
10 : 20.0
11 : 58.0
12 : 13.0
13 : 18.0
14 : 13.0
15 : 30.0
16 : 16.0
17 : 93.0
18 : 12.0
19 : 18.0
20 : 35.0
21 : 30.0
22 : 15.0
23 : 16.0
24 : 25.0
25 : 16.0
26 : 19.0
27 : 25.0
28 : 14.0
29 : 24.0
30 : 12.0
31 : 30.0
32 : 33.0
33 : 37.0
34 : 19.0
35 : 20.0
36 : 9.0
37 : 11.0
38 : 28.0
39 : 21.0
40 : 21.0
41 : 18.0
42 : 14.0
43 : 11.0
44 : 24.0
45 : 18.0
46 : 12.0
47 : 18.0
48 : 13.0
49 : 51.0
50 : 47.0
51 : 10.0
52 : 12.0
53 : 22.0
54 : 14.0
55 : 26.0
56 : 26.0
57 : 28.0
58 : 13.0
59 : 22.0
60 : 71.0
61 : 13.0
62 : 8.0
63 : 11.0
64 : 24.0
65 : 29.0
66 : 11.0
67 : 9.0
68 : 31.0
69 : 14.0
70 : 15.0
71 : 14.0
72 : 24.0
73 : 23.0
74 : 14.0
75 : 10.0
76 : 32.0
77 : 12.0
78 : 12.0
79 : 20.0
80 : 18.0
81 : 13.0
82 : 9.0
83 : 23.0
84 : 30.0
85 : 21.0
86 : 16.0
87 : 26.0
88 : 29.0
89 : 22.0
90 : 24.0
91 : 18.0
92 : 35.0
93 : 24.0
94 : 17.0
95 : 21.0
96 : 11.0
97 : 21.0
98 : 10.0
99 : 28.0
100 : 17.0
101 

In [65]:
def test(model,episodes):
    env = gym.make("CartPole-v1",render_mode="human")
    state,info = env.reset()
    state = torch.from_numpy(state)
    total_reward=0
    for ep in range(episodes):
       
        action = torch.argmax(model.forward(state)).item()
        # print(action)
        new_state, reward, done,truncated, info= env.step(action)
        # print(env.step(action))
        total_reward+=reward
        # print(f"{ep} Total reward:",total_reward)
        if done or truncated:
            print(total_reward)
            break
        state = torch.from_numpy(new_state)    
    env.close()
        

In [99]:
# Run 20 independent greedy evaluation rollouts (100-step cap each); each run prints
# its index followed by whatever `test` prints for that rollout.
for run_idx in range(20):
    print(run_idx)
    test(model, episodes=100)

0
9.0
1
8.0
2
10.0
3
10.0
4
10.0
5
9.0
6
10.0
7
10.0
8
8.0
9
10.0
10
9.0
11
8.0
12
9.0
13
10.0
14
9.0
15
10.0
16
10.0
17
10.0
18
10.0
19
9.0
