In [4]:
import torch
import torch.nn as nn
from torch import optim
import numpy as np

import gym
import random

from model import QFunc
from ReplayMemory import ReplayMemory

# ---------------------------------------------------------------------------
# Hyperparameters and counters
# ---------------------------------------------------------------------------
lr = 1e-3                  # learning rate handed to the Adam optimizer
gamma = 0.95               # discount factor applied to the bootstrapped next-state value
epsilon = 0.3              # starting probability of picking a random action
batch_size = 32            # NOTE(review): not referenced in this cell; presumably read by ReplayMemory -- confirm
initial_exploration = 500  # environment steps to collect before the first gradient update
total_step = 0             # environment steps taken so far, summed over all episodes

# ---------------------------------------------------------------------------
# Online Q-network and its target copy
# ---------------------------------------------------------------------------
qf = QFunc()
target_qf = QFunc()
# state_dict() exposes the model's learnable parameters; loading them here
# makes the target network start out with the same weights as the online one.
target_qf.load_state_dict(qf.state_dict())

# Only the online network is optimised; the training loop refreshes the
# target network by copying weights into it.
optimizer = optim.Adam(qf.parameters(), lr=lr)
criterion = nn.MSELoss()

# Storage for transitions gathered while interacting with the environment.
memory = ReplayMemory()

# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
env = gym.make('CartPole-v0')
obs_size = env.observation_space.shape[0]
action_size = env.action_space.n

# DQN training: 150 episodes of epsilon-greedy interaction. Every step is
# stored in the replay memory, and once `initial_exploration` total steps
# have been taken the online network is updated after each step.
for episode in range(150):
    obs = env.reset()
    done = False
    sum_reward = 0
    step = 0

    while not done:
        # Epsilon-greedy: a random action with probability `epsilon`,
        # otherwise the action chosen by the online network.
        action = (
            env.action_space.sample()
            if random.random() < epsilon
            else qf.select_action(obs)
        )

        # Decay epsilon by 1e-4 per step, never letting it drop below zero.
        epsilon = max(epsilon - 1e-4, 0)

        # The reward returned by the environment is discarded on purpose;
        # a shaped reward is computed just below.
        next_obs, _, done, _ = env.step(action)

        # Shaped reward: -1 when the episode ends while `step` (counted
        # before this step's increment) is still below 195, otherwise 0.
        terminal = 1 if done else 0
        reward = -1 if done and step < 195 else 0
        sum_reward += reward

        memory.add(obs, action, reward, next_obs, terminal)
        obs = next_obs.copy()

        step += 1
        total_step += 1

        # No learning until enough total steps have been taken.
        if total_step >= initial_exploration:
            batch = memory.sample()

            # Q-value of the stored action for every sampled transition
            # (gather along dim 1 -- assumes batch['actions'] holds one
            # action index per row; TODO confirm against ReplayMemory).
            q_value = qf(batch['obs']).gather(1, batch['actions'])

            # The TD target is computed without tracking gradients, so the
            # backward pass only reaches the online network.
            with torch.no_grad():
                next_q_value = target_qf(batch['next_obs']).max(dim=1, keepdim=True)[0]
                # (1 - terminates) zeroes the bootstrap term for terminal
                # transitions.
                not_terminal = 1 - batch['terminates']
                target_q_value = batch['rewards'] + gamma * next_q_value * not_terminal

            loss = criterion(q_value, target_q_value)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update the target network: copy the online weights every
            # 10 total steps.
            if not total_step % 10:
                target_qf.load_state_dict(qf.state_dict())

    # Progress report every 10th episode; the value labelled 'return' is
    # the number of steps the episode lasted.
    if not episode % 10:
        print('episode:',episode, 'return:', step, 'epsilon:', epsilon)

episode: 0 return: 16 epsilon: 0.29840000000000017
episode: 10 return: 10 epsilon: 0.28840000000000127
episode: 20 return: 9 epsilon: 0.27580000000000265
episode: 30 return: 9 epsilon: 0.2653000000000038
episode: 40 return: 12 epsilon: 0.25420000000000503
episode: 50 return: 9 epsilon: 0.24250000000000632
episode: 60 return: 24 epsilon: 0.22230000000000855
episode: 70 return: 81 epsilon: 0.1483000000000167
episode: 80 return: 122 epsilon: 0.06300000000001749
episode: 90 return: 200 epsilon: 0
episode: 100 return: 198 epsilon: 0
episode: 110 return: 200 epsilon: 0
episode: 120 return: 200 epsilon: 0
episode: 130 return: 200 epsilon: 0
episode: 140 return: 200 epsilon: 0
