In [1]:
# imports
import time
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam,SGD
from collections import deque
import random
from matplotlib import pyplot as plt
import copy
import numpy as np
import gym
from torchsummary import summary
import warnings
warnings.filterwarnings('ignore')

# set seed for torch library
torch.manual_seed(33)

<torch._C.Generator at 0x7fb4217c5cf0>

In [2]:
# instantiate DQN
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4,16)
        self.fc2 = nn.Linear(16,16)
        self.fc3 = nn.Linear(16,2)
    
    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [3]:
# test the DQN model
env = gym.make('CartPole-v1')
model_500 = torch.load('cartpoleDQN.pth')
for e in range(10):
    sta = env.reset()[0]
    sta = torch.from_numpy(sta)
    # print(sta)
    done = False
    i = 0
    reward = 0
    while not done:
        env.render()
        action = torch.argmax(model_500(sta))
        new_sta, rew, done, _, _ = env.step(action.item())
        sta = torch.from_numpy(new_sta)
        cart_velocity = new_sta[1]
        # print(cart_velocity)
        i += 1
        reward += rew
        if i == 10000:
            done = True
        if done:
            print(e, i, reward)
    env.close()

0 10000 10000.0
1 10000 10000.0
2 10000 10000.0
3 10000 10000.0
4 10000 10000.0
5 10000 10000.0
6 10000 10000.0
7 10000 10000.0
8 10000 10000.0
9 10000 10000.0


In [5]:
# test PID

env = gym.make('CartPole-v1')

for e in range(10):
    sta = env.reset()
    Kp = 135
    Ki = 96.5
    Kd = 47.5
    action = 0
    integral = 0

    done = False
    i = 0
    rew = 0
    while not done:
        env.render()
        observation, reward, done, info, _ = env.step(action)

        cart_velocity = observation[1]
        pole_angle = observation[2]
        pole_angular_velocity = observation[3]
        # print(cart_velocity)

        integral += pole_angle

        out = Kp*(pole_angle) + Kd*(pole_angular_velocity) + Ki*(integral)

        if out > 0:
            action = 1
        else:
            action = 0
            
        i += 1
        rew += reward
        if i == 10000:
            done = True
        if done:
            observation = env.reset()
            integral = 0
            print(e, i, rew)
    env.close()

0 1812 1812.0
1 3718 3718.0
2 2149 2149.0
3 6082 6082.0
4 4746 4746.0
5 3083 3083.0
6 8414 8414.0
7 10000 10000.0
8 1898 1898.0
9 1534 1534.0
