In [1]:
import copy
import random
from collections import deque

import gymnasium as gym
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pygame
import torch
import torch.nn as nn
import torch.nn.functional as F
from gymnasium.utils.play import play
# from gynasium.utils.play import pl
from torch.utils.data import DataLoader
%matplotlib inline

# Default every figure created from here on to a white background.
matplotlib.rcParams["figure.facecolor"] = "#ffffff"


In [2]:
# Module-level CartPole environment with on-screen rendering.
# train() and test() in this notebook construct their own environments.
env = gym.make(
    "CartPole-v1",
    render_mode="human",
)

In [27]:
class Q_Network(nn.Module):
    """Fully connected network: 4 inputs -> 16 -> 8 -> 2 outputs, Tanh in between."""

    def __init__(self):
        super().__init__()
        # Layers are created in input-to-output order and wrapped in one Sequential.
        layers = [
            nn.Linear(4, 16),
            nn.Tanh(),
            nn.Linear(16, 8),
            nn.Tanh(),
            nn.Linear(8, 2),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, state):
        """Feed ``state`` through the layer stack and return the 2-value output."""
        q_values = self.network(state)
        return q_values
            
        

In [28]:
def mse(t1, t2):
    """Mean squared error between two tensors.

    Each squared element-wise difference is divided by the number of
    elements, and the scaled values are then summed into a single tensor.
    """
    delta = t1 - t2
    n_elements = delta.numel()
    return delta.mul(delta).div(n_elements).sum()

In [32]:
def _dqn_update(model, target_model, optimizer, batch, gamma):
    """Run one minibatch Q-learning step on ``model`` using ``target_model`` for targets.

    ``batch`` is a sequence of ``(state, action, reward, next_state, terminated)``
    tuples. Assumes ``model`` accepts a batch of stacked states and returns one
    row of action values per state (true for ``nn.Linear``-based networks).
    """
    states = torch.stack([t[0] for t in batch])
    actions = torch.tensor([t[1] for t in batch], dtype=torch.int64)
    rewards = torch.tensor([t[2] for t in batch], dtype=torch.float32)
    next_states = torch.stack([t[3] for t in batch])
    terminated = torch.tensor([t[4] for t in batch], dtype=torch.bool)

    # The target is a constant for this update: no gradient may flow through it.
    with torch.no_grad():
        next_q = target_model(next_states).max(dim=1).values
        # A terminated transition has no future return to bootstrap from.
        next_q = next_q.masked_fill(terminated, 0.0)
        target_q = rewards + gamma * next_q

    # Q-value the online network assigns to the action that was actually taken.
    predicted_q = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    loss = F.mse_loss(predicted_q, target_q)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def train(epochs,lr,model,rm_size,batch_size,max_steps,gamma,opt_func=torch.optim.Adam,
          target_sync_steps=100,env=None):
    """Train ``model`` with epsilon-greedy deep Q-learning and return per-epoch scores.

    Args:
        epochs: Number of episodes to run.
        lr: Learning rate handed to ``opt_func``.
        model: Online Q-network; updated in place.
        rm_size: Capacity of the replay memory (oldest transitions are evicted).
        batch_size: Number of transitions sampled for each gradient update.
        max_steps: Maximum number of environment steps per episode.
        gamma: Discount factor applied to the bootstrapped next-state value.
        opt_func: Optimizer factory called as ``opt_func(model.parameters(), lr)``.
        target_sync_steps: Number of gradient updates between copies of the
            online weights into the frozen target network.
        env: Optional environment exposing ``reset``, ``step`` and
            ``action_space.sample``. When ``None`` a ``CartPole-v1`` environment
            with ``render_mode="human"`` is created and closed by this function;
            a caller-supplied environment is left open.

    Returns:
        List with the total reward collected in each epoch.

    Raises:
        ValueError: If ``target_sync_steps`` is smaller than 1.
    """
    if target_sync_steps < 1:
        raise ValueError("target_sync_steps must be >= 1")

    max_epsilon = 1.0             # Exploration probability at start
    min_epsilon = 0.05            # Floor the exploration probability decays towards
    decay_rate = 0.0005           # Exponential decay rate, applied once per epoch
    epsilon = max_epsilon

    owns_env = env is None
    if owns_env:
        env = gym.make("CartPole-v1",render_mode="human")

    optimizer = opt_func(model.parameters(),lr)
    # Bounded deque drops the oldest transition automatically once full.
    replay_memory = deque(maxlen=rm_size)

    # A real copy, not an alias: targets must come from weights that lag the
    # online network, otherwise every update also shifts its own target.
    target_model = copy.deepcopy(model)
    for frozen_param in target_model.parameters():
        frozen_param.requires_grad_(False)

    history = []
    updates_done = 0
    try:
        for epoch in range(epochs):
            state, _ = env.reset()
            state = torch.as_tensor(state, dtype=torch.float32)
            total_reward = 0

            for _ in range(max_steps):
                # Sampling phase: exploit with probability (1 - epsilon), else explore.
                if random.uniform(0, 1) > epsilon:
                    with torch.no_grad():
                        action = torch.argmax(model(state)).item()
                else:
                    action = env.action_space.sample()

                new_state, reward, done, truncated, _ = env.step(action)
                new_state = torch.as_tensor(new_state, dtype=torch.float32)
                total_reward += reward

                # `done` is the environment's termination flag; it is stored so
                # the update can skip bootstrapping for terminal transitions.
                # Truncation is not stored: a truncated episode still bootstraps.
                replay_memory.append(
                    (state, int(action), float(reward), new_state, bool(done))
                )

                if done or truncated:
                    print(epoch,":",total_reward)
                    break

                # Training phase: one minibatch update once enough data exists.
                if batch_size > 0 and len(replay_memory) >= batch_size:
                    batch = random.sample(replay_memory, batch_size)
                    _dqn_update(model, target_model, optimizer, batch, gamma)
                    updates_done += 1
                    if updates_done % target_sync_steps == 0:
                        target_model.load_state_dict(model.state_dict())

                # State change
                state = new_state

            history.append(total_reward)
            epsilon = min_epsilon + (max_epsilon - min_epsilon)*np.exp(-decay_rate*epoch)
    finally:
        # Release the window/resources even if training is interrupted.
        if owns_env:
            env.close()

    return history

In [33]:
# Build a fresh Q-network; the bare expression below shows its layer structure
# as the cell output.
model = Q_Network()
model

Q_Network(
  (network): Sequential(
    (0): Linear(in_features=4, out_features=16, bias=True)
    (1): Tanh()
    (2): Linear(in_features=16, out_features=8, bias=True)
    (3): Tanh()
    (4): Linear(in_features=8, out_features=2, bias=True)
  )
)

In [37]:
# Hyperparameters passed to train(...).
epochs, max_steps = 3000, 300     # number of episodes, step cap per episode
gamma, lr = 0.99, 1e-4            # discount factor, optimizer learning rate
rm_size, batch_size = 60, 10      # replay-memory capacity, transitions sampled per step

In [38]:
# Collects the per-epoch scores returned by each train(...) call.
history = list()

In [None]:
# Append this run's per-epoch scores to the running history (in place).
history.extend(
    train(epochs, lr, model, rm_size, batch_size, max_steps, gamma)
)

0 : 16.0
1 : 15.0
2 : 13.0
3 : 14.0
4 : 23.0
5 : 37.0
6 : 15.0
7 : 14.0
8 : 21.0
9 : 20.0
10 : 13.0
11 : 31.0
12 : 10.0
13 : 42.0
14 : 15.0
15 : 23.0
16 : 19.0
17 : 13.0
18 : 28.0
19 : 35.0
20 : 16.0
21 : 13.0
22 : 12.0
23 : 14.0
24 : 14.0
25 : 15.0
26 : 20.0
27 : 26.0
28 : 12.0
29 : 16.0
30 : 16.0
31 : 35.0
32 : 12.0
33 : 21.0
34 : 48.0
35 : 19.0
36 : 24.0
37 : 9.0
38 : 17.0
39 : 16.0
40 : 37.0
41 : 33.0
42 : 57.0
43 : 27.0
44 : 13.0
45 : 13.0
46 : 18.0
47 : 12.0
48 : 19.0
49 : 36.0
50 : 12.0
51 : 15.0
52 : 15.0
53 : 35.0
54 : 12.0
55 : 21.0
56 : 10.0
57 : 14.0
58 : 23.0
59 : 16.0
60 : 15.0
61 : 10.0
62 : 39.0
63 : 16.0
64 : 16.0
65 : 36.0
66 : 13.0
67 : 26.0
68 : 12.0
69 : 11.0
70 : 17.0
71 : 17.0
72 : 13.0
73 : 21.0
74 : 11.0
75 : 34.0
76 : 22.0
77 : 15.0
78 : 22.0
79 : 17.0
80 : 23.0
81 : 25.0
82 : 10.0
83 : 14.0
84 : 34.0
85 : 13.0
86 : 14.0
87 : 13.0
88 : 23.0
89 : 31.0
90 : 18.0
91 : 31.0
92 : 35.0
93 : 17.0
94 : 27.0
95 : 14.0
96 : 10.0
97 : 20.0
98 : 18.0
99 : 10.0
100 : 28.0


In [None]:
# Plot the score of every epoch together with a moving average of it.
fig, ax = plt.subplots(figsize=(14, 6))
ax.plot(history, marker='o', linestyle='-', color='b', label='Episode score')

window_size = 4
# mode='valid' only yields one value per full window, so the moving average is
# defined only when there are at least `window_size` points. With fewer points
# the x/y lengths would not match (and an empty history makes convolve raise).
if len(history) >= window_size:
    smoothed_data = np.convolve(history, np.ones(window_size)/window_size, mode='valid')
    # Each average is drawn at the index of the last point in its window.
    smoothed_indices = range(window_size - 1, len(history))
    ax.plot(smoothed_indices, smoothed_data, marker='', linestyle='-', color='r', label='Smoothed Data')

ax.set_title("Model Training")
ax.set_xlabel("Epochs")
# The axis carries both the raw and the averaged series, so it is labelled
# generically; the legend tells the two apart.
ax.set_ylabel("Score")
ax.legend()
ax.grid(True)
plt.show()

In [25]:
def test(model,episodes):
    env = gym.make("CartPole-v1",render_mode="human")
    state,info = env.reset()
    state = torch.from_numpy(state)
    total_reward=0
    for ep in range(episodes):
       
        action = torch.argmax(model.forward(state)).item()
        new_state, reward, done,truncated, info= env.step(action)
        total_reward+=reward
        
        if done or truncated:
            print(total_reward)
            break
        state = torch.from_numpy(new_state)    
    env.close()
    

In [26]:
# Twenty evaluation rollouts of at most 100 steps each; the run index is
# printed before each rollout.
for i in range(20):
    print(i)
    test(model, episodes=100)


0
54.0
1
68.0
2
58.0
3
68.0
4
82.0
5
67.0
6
60.0
7
58.0
8
54.0
9
45.0
10
65.0
11
74.0
12
58.0
13
59.0
14
91.0
15
46.0
16
72.0
17
18
68.0
19
66.0
