In [11]:
from collections import deque
import gym
import matplotlib.pyplot as plt
from IPython import display
import time
import torch
import torch.nn as nn
import numpy as np
import random

In [12]:
env = gym.make("Acrobot-v1")
state = env.reset()
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
batch_size = 64
epochs = 1000
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [13]:
class NNet(nn.Module):
    def __init__(self):
        super(NNet,self).__init__()
        self.lin1 = nn.Linear(state_size,16)
        self.lin2 = nn.Linear(16,32)
        self.lin3 = nn.Linear(32,action_size)
        self.relu = nn.ReLU()

    def forward(self,x):
#         pred = torch.from_numpy(x.reshape(1,state_size)).float().to(device)
        pred = self.lin1(x)
        pred = self.relu(pred)
        pred = self.lin2(pred)
        pred = self.relu(pred)
        pred = self.lin3(pred)

        return pred

In [14]:
class dqn_agent():
    def __init__(self,state_size,action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        self.gamma = 0.95
        self.epsilon = 1
        self.epsilon_rate = 0.995
        self.epsilon_lim = 0.01
        self.lr = 0.001
        self.model = NNet().to(device)
        self.target_model = NNet().to(device)
        self.criterion = nn.MSELoss()
        self.optim = torch.optim.Adam(self.model.parameters(),lr=self.lr)
    def train(self,batch_size):
        batch = random.sample(self.memory,batch_size)
        self.target_model.load_state_dict(self.model.state_dict())
        for state,action,reward,next_state,done in batch:
            state = torch.from_numpy(state.reshape(1,state_size)).float().to(device)
            next_state = torch.from_numpy(next_state.reshape(1,state_size)).float().to(device)
            pred = self.model(state)[0][action]
            target = torch.tensor(reward,dtype=torch.float32).to(device)
            if not done:
                action_index = torch.argmax(self.model(next_state)[0])
                target = target+self.target_model(next_state)[0][action_index]
            loss = self.criterion(pred,target.to(device))
            loss.backward()
            self.optim.step()
            self.optim.zero_grad()
        if self.epsilon>self.epsilon_lim:
            self.epsilon*=self.epsilon_rate
    def action(self,state):
        if np.random.rand()<=self.epsilon:
            return torch.tensor(random.randrange(self.action_size))
        else:
            state = torch.from_numpy(state.reshape(1,state_size)).float().to(device)
            return torch.argmax(self.model(state)[0])
    def remember(self,state,action,reward,next_state,done):
        self.memory.append((state,action,reward,next_state,done))

In [15]:
agent = dqn_agent(state_size,action_size)

for epoch in range(epochs):
    state = env.reset()
    
    done = False
    times = 0
    while not done:
        action = agent.action(state)
        if epoch%20 == 0:
            env.render()
        next_state, reward, done, info = env.step(action.to('cpu').numpy())
        if done:
            if reward==0:
                reward = 3000
            else:
                reward =-3000
        else:
            reward = -5+abs(next_state[4])
                
        agent.remember(state,action,reward,next_state,done)
        state = next_state
        if done:
            print(f'epoch = {epoch}/{epochs}, score = {times}, epsilon = {agent.epsilon:.4f}')
        times+=1
        if not done and len(agent.memory)>=batch_size and times%32==0:
            agent.train(batch_size)
    if epoch%20 == 0:
        env.close()
    if times <= 90:
        break

epoch = 0/1000, score = 499, epsilon = 0.9322
epoch = 1/1000, score = 499, epsilon = 0.8647
epoch = 2/1000, score = 499, epsilon = 0.8021
epoch = 3/1000, score = 499, epsilon = 0.7440
epoch = 4/1000, score = 499, epsilon = 0.6901
epoch = 5/1000, score = 499, epsilon = 0.6401
epoch = 6/1000, score = 499, epsilon = 0.5937
epoch = 7/1000, score = 459, epsilon = 0.5535
epoch = 8/1000, score = 499, epsilon = 0.5134
epoch = 9/1000, score = 499, epsilon = 0.4762
epoch = 10/1000, score = 499, epsilon = 0.4417
epoch = 11/1000, score = 499, epsilon = 0.4097
epoch = 12/1000, score = 499, epsilon = 0.3801
epoch = 13/1000, score = 499, epsilon = 0.3525
epoch = 14/1000, score = 499, epsilon = 0.3270
epoch = 15/1000, score = 499, epsilon = 0.3033
epoch = 16/1000, score = 499, epsilon = 0.2813
epoch = 17/1000, score = 499, epsilon = 0.2610
epoch = 18/1000, score = 499, epsilon = 0.2421
epoch = 19/1000, score = 499, epsilon = 0.2245
epoch = 20/1000, score = 499, epsilon = 0.2083
epoch = 21/1000, score 

In [16]:
# play test:
for epoch in range(100):
    state = env.reset()   
    done = False
    times = 0
    while not done:
        action = agent.action(state)
        env.render()
        next_state, reward, done, info = env.step(action.to('cpu').numpy())
        state = next_state
        if done:
            print(f'episode = {epoch}/{100}, score = {times}')
        times+=1
    env.close()

episode = 0/100, score = 141
episode = 1/100, score = 112
episode = 2/100, score = 92
episode = 3/100, score = 85
episode = 4/100, score = 92
episode = 5/100, score = 102
episode = 6/100, score = 91
episode = 7/100, score = 91
episode = 8/100, score = 91
episode = 9/100, score = 124
episode = 10/100, score = 114
episode = 11/100, score = 91
episode = 12/100, score = 98
episode = 13/100, score = 95
episode = 14/100, score = 101
episode = 15/100, score = 102
episode = 16/100, score = 141
episode = 17/100, score = 118
episode = 18/100, score = 91
episode = 19/100, score = 101
episode = 20/100, score = 232
episode = 21/100, score = 93
episode = 22/100, score = 90
episode = 23/100, score = 91
episode = 24/100, score = 110
episode = 25/100, score = 133
episode = 26/100, score = 105
episode = 27/100, score = 97
episode = 28/100, score = 91
episode = 29/100, score = 174
episode = 30/100, score = 109
episode = 31/100, score = 101
episode = 32/100, score = 101
episode = 33/100, score = 89
episod

In [17]:
env.close()