In [None]:
'''
kaggle - https://www.kaggle.com/code/leejin11/deepsarsa-code/edit
'''

In [None]:
# SARSA => s, a, r, s', a'
# Q(s, a) <- Q(s, a) + α(r + γQ(s', a') - Q(s, a))
#                       (           td           )
# on-policy: a' is the action actually chosen by the current (epsilon-greedy) policy
# DeepSARSA => q_table -> nn

In [None]:
import numpy as np
import random

import torch, torch.nn as nn
import gym

import matplotlib.pyplot as plt
import os, sys


In [None]:
### 시각화

In [None]:
def make_plot(scores, episodes):
    '''Draw score versus episode and write the figure to deep_sarsa_plot.jpg.
        Args:
            scores: y-values, one score per finished episode
            episodes: x-values, the episode numbers matching `scores`
    '''
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(episodes, scores, label='Score per Episode')
    ax.set_xlabel('Episode')
    ax.set_ylabel('Score')
    ax.set_title('Deep SARSA: CartPole-v1')
    fig.tight_layout()
    # The file name is fixed, so each call overwrites the previous image.
    fig.savefig('deep_sarsa_plot.jpg')
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


In [None]:
class DeepSARSA(nn.Module):
    '''Small MLP that outputs one action value per action (replaces the Q-table).
        Args:
            state_size (int): number of input features
            action_size (int): number of outputs, one value per action
            c_mid (int): width of the single hidden layer
    '''
    def __init__(self, state_size, action_size, c_mid):
        super().__init__()
        # state_size -> c_mid -> action_size with a ReLU in between.
        # The layers stay at Sequential indices 0 and 2 so state_dict keys are unchanged.
        hidden = nn.Linear(state_size, c_mid)
        head = nn.Linear(c_mid, action_size)
        self.sarsa_net = nn.Sequential(hidden, nn.ReLU(), head)

    def forward(self, x):
        '''Map an input whose last dimension is state_size to action_size values.'''
        return self.sarsa_net(x)
        
class DeepSARSAAgent:
    '''Epsilon-greedy Deep SARSA agent: action selection plus the one-step TD update.
        Args:
            state_size (int): state size
            action_size (int): action size
            c_mid (int): hidden layer size
            device (str | torch.device): device that holds the network and its inputs
    '''
    def __init__(self, state_size, action_size, c_mid=32, device='cpu'):
        self.state_size = state_size
        self.action_size = action_size
        # Keep the device on the instance so the methods below do not depend on
        # a module-level `device` variable happening to exist.
        self.device = torch.device(device)

        # hyperparameters
        self.discount_factor = 0.99
        self.learning_rate = 0.001
        self.epsilon = 1.0
        self.epsilon_decay = 0.9995
        self.epsilon_min = 0.01

        self.model = DeepSARSA(self.state_size, self.action_size, c_mid).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.loss = nn.MSELoss()

    def _to_batch(self, state):
        '''Convert one state to a float32 tensor of shape [1, state_size] on self.device.'''
        return torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)

    def get_action(self, state):
        '''Return an action index for `state` using the epsilon-greedy rule.

        With probability `self.epsilon` a uniformly random action is returned
        (exploration); otherwise the action with the highest predicted value.
        '''
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.action_size)
        with torch.no_grad():
            return self.model(self._to_batch(state)).argmax().item()

    def update(self, state, action, reward, next_state, next_action, done):
        '''Run one SARSA step on the transition (s, a, r, s', a').

        The target is r when `done` is true, otherwise r + gamma * Q(s', a').
        Epsilon is decayed once per call while it is above `epsilon_min`.
        '''
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        # Q(s, a): the only value that receives gradient.
        q = self.model(self._to_batch(state))[0][action]

        # The bootstrap target is treated as a constant, so build it without
        # tracking gradients instead of round-tripping through torch.tensor([...]).
        with torch.no_grad():
            target = torch.as_tensor(reward, dtype=torch.float32, device=self.device)
            if not done:
                next_q = self.model(self._to_batch(next_state))[0][next_action]
                target = target + self.discount_factor * next_q

        loss = self.loss(q.unsqueeze(0), target.reshape(1))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the environment through the open-source gym package.
env = gym.make('CartPole-v1')

state_size = env.observation_space.shape[0] # 4 <= [position, velocity, angle, angular_velocity]
action_size = env.action_space.n # 2 <= [left, right]

agent = DeepSARSAAgent(state_size,action_size, device=device)
# Per-episode score history and the matching episode numbers (the inputs to make_plot).
scores, episodes = [], []

# Early-stop tracker: 'line' is the score threshold, 'count' the current run of
# consecutive episodes that reached it.
score_line = {'line':400, 'count':0}
# Upper bound on the number of training episodes.
EPISODE = 3000

In [None]:
### train

In [None]:
# Train with on-policy SARSA updates. Training stops early once the score has
# reached score_line['line'] for more than 5 consecutive episodes.
solved = False

for episode in range(1, EPISODE + 1):

    # gym >= 0.26 returns (observation, info) from reset(); older releases
    # return the observation alone.
    reset_out = env.reset()
    state = reset_out[0] if isinstance(reset_out, tuple) else reset_out
    action = agent.get_action(state)

    score = 0
    done = False

    while not done:
        # gym >= 0.26 returns (obs, reward, terminated, truncated, info);
        # older releases return (obs, reward, done, info).
        step_out = env.step(action)
        if len(step_out) == 5:
            next_state, reward, terminated, truncated, _ = step_out
            done = terminated or truncated
        else:
            next_state, reward, done, _ = step_out
        next_action = agent.get_action(next_state)

        # Learn from the (s, a, r, s', a') transition.
        agent.update(state, action, reward, next_state, next_action, done)

        score += reward
        state, action = next_state, next_action

    scores.append(score)
    episodes.append(episode)

    if episode % 10 == 0:
        # Each call re-renders the whole history, so refresh the plot on the
        # logging interval instead of after every episode.
        make_plot(scores, episodes)
        print(f'episode : {episode:3d}\t score : {score:3.1f}\t epsilon : {agent.epsilon:.3f}')

    # Leave the loop (rather than sys.exit(), which aborts the notebook run)
    # once the score stays at or above the threshold.
    if score >= score_line['line']:
        score_line['count'] += 1
        if score_line['count'] > 5:
            solved = True
            print(f'episode : {episode:3d}\t score : {score:3.1f}\t epsilon : {agent.epsilon:.3f}')
            break
    else:
        score_line['count'] = 0

if scores:
    # Final plot and checkpoint. Saving unconditionally guarantees that
    # `model_path` exists for the test section even without an early stop.
    make_plot(scores, episodes)
    model_path = f'./deep_sarsa_{episode}.pth'
    torch.save(agent.model.state_dict(), model_path)
    if not solved:
        print(f'training ended without reaching the score line; saved {model_path}')

In [None]:
## test ##

In [None]:
class DeepSARSAAgent:
    '''Inference-only agent that loads trained weights and always acts greedily.

        NOTE: this reuses the name of the training agent class defined earlier
        in the notebook and shadows it from this cell onwards.

        Args:
            state_size (int): state size
            action_size (int): action size
            c_mid (int): hidden layer size (must match the saved checkpoint)
            device (str | torch.device): device used for the network and its inputs
    '''
    def __init__(self, state_size, action_size, c_mid=32, device='cpu'):

        self.state_size = state_size
        self.action_size = action_size
        self.device = torch.device(device)

        self.model = DeepSARSA(state_size, action_size, c_mid)
        # The checkpoint location comes from the module-level `model_path`.
        # map_location lets weights saved on a GPU load on a CPU-only machine.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

    def get_action(self, state):
        '''Return the index of the highest-valued action for `state` (no exploration).'''
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            return self.model(state).argmax().item()

In [None]:
from gym.wrappers import RecordVideo

# Evaluation runs on the CPU: the agent below is built with its defaults, which
# leave the loaded network on the CPU.
DEVICE = torch.device("cpu")

# Record every evaluation episode to ./test_video.
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RecordVideo(env, video_folder="./test_video", episode_trigger=lambda e: True)

state_size = env.observation_space.shape[0]
action_size = env.action_space.n

# Agent restored from the saved checkpoint.
model = DeepSARSAAgent(state_size, action_size)

scores = []
EPISODE = 5
for episode in range(1, EPISODE+1):

    # gym >= 0.26 returns (observation, info) from reset(); older releases
    # return the observation alone.
    reset_out = env.reset()
    state = reset_out[0] if isinstance(reset_out, tuple) else reset_out
    score = 0
    done = False

    while not done:
        # Act with the loaded checkpoint, not the epsilon-greedy training agent.
        action = model.get_action(state)
        # gym >= 0.26 returns (obs, reward, terminated, truncated, info);
        # older releases return (obs, reward, done, info).
        step_out = env.step(action)
        if len(step_out) == 5:
            next_state, reward, terminated, truncated, _ = step_out
            done = terminated or truncated
        else:
            next_state, reward, done, _ = step_out

        score += reward
        state = next_state

    scores.append(score)
    print(f'episode {episode} score : {score}')

# Closing the wrapped env lets RecordVideo finish writing the last video file.
env.close()

print(f'max score : {max(scores)}')