In [17]:
class GridEnvironment:
    """Deterministic square grid world for a single agent.

    Positions are (row, col) tuples. The agent starts on the top-left cell (0, 0)
    and the goal is the bottom-right cell. A move that would leave the grid is
    clamped at the edge, so bumping into a wall leaves the agent where it is.
    """

    def __init__(self, grid_size=4):
        """Create a `grid_size` x `grid_size` grid with the agent on the start cell."""
        last = grid_size - 1
        self.grid_size = grid_size
        self.start_position = (0, 0)
        self.goal_position = (last, last)
        self.state = self.start_position

    def reset(self):
        """Move the agent back to the start cell and return that state."""
        self.state = self.start_position
        return self.state

    def step(self, action):
        """Apply one move and return ``(state, reward, done)``.

        Action: 0=Up, 1=Down, 2=Left, 3=Right; any other value is a no-op.
        The reward is 1 and ``done`` is True only when the new state is the goal;
        otherwise the reward is 0 and ``done`` is False.
        """
        row, col = self.state
        edge = self.grid_size - 1

        if action == 0:    # Up: towards row 0
            moved = (max(row - 1, 0), col)
        elif action == 1:  # Down: towards the last row
            moved = (min(row + 1, edge), col)
        elif action == 2:  # Left: towards column 0
            moved = (row, max(col - 1, 0))
        elif action == 3:  # Right: towards the last column
            moved = (row, min(col + 1, edge))
        else:              # unknown action: stay put
            moved = (row, col)

        self.state = moved
        reached_goal = self.state == self.goal_position
        return self.state, (1 if reached_goal else 0), reached_goal

    def render(self):
        """Print the grid: 'A' = agent, 'G' = goal, '-' = empty, then a blank line.

        The goal marker is written after the agent marker, so 'G' hides 'A'
        when the agent stands on the goal.
        """
        size = self.grid_size
        cells = [['-'] * size for _ in range(size)]
        agent_row, agent_col = self.state
        cells[agent_row][agent_col] = 'A'
        cells[self.goal_position[0]][self.goal_position[1]] = 'G'
        print('\n'.join(' '.join(line) for line in cells))
        print()

# Example usage: start on a 4x4 grid, then move down and right, rendering after each move.
env = GridEnvironment(grid_size=4)
print("Initial State:")
env.render()

for demo_action, demo_caption in ((1, "After moving down:"), (3, "After moving right:")):
    state, reward, done = env.step(demo_action)  # 1 = Down, 3 = Right
    print(demo_caption)
    env.render()


Initial State:
A - - -
- - - -
- - - -
- - - G

After moving down:
- - - -
A - - -
- - - -
- - - G

After moving right:
- - - -
- A - -
- - - -
- - - G



In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical


In [19]:
class PolicyNetwork(nn.Module):
    """Two-layer MLP that maps a state vector to a probability distribution over actions.

    Args:
        input_dim: length of the state vector. The grid state is (y, x), so 2.
        hidden_dim: width of the single hidden layer.
        n_actions: number of discrete actions (Up, Down, Left, Right -> 4).
    """

    def __init__(self, input_dim=2, hidden_dim=128, n_actions=4):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)  # state -> hidden features
        self.fc2 = nn.Linear(hidden_dim, n_actions)  # hidden features -> one logit per action

    def forward(self, x):
        """Return action probabilities for the state(s) in `x`.

        `x` may be a single state of shape (input_dim,) or a batch of shape
        (..., input_dim). Softmax is taken over the last (action) axis, so both
        forms work and each output row sums to 1.
        """
        x = torch.relu(self.fc1(x))
        logits = self.fc2(x)
        # dim=-1 (not dim=1): identical for (batch, actions) input, and also valid for a 1-D state.
        return torch.softmax(logits, dim=-1)


In [20]:
def reinforce(env, policy_network, optimizer, episodes=1000, gamma=0.99, max_steps=None):
    """Train `policy_network` on `env` with vanilla REINFORCE (Monte-Carlo policy gradient).

    For each episode the current stochastic policy is rolled out until `env.step`
    reports done (or `max_steps` actions have been taken). Discounted returns are
    computed backwards as R_t = r_t + gamma * R_{t+1}. One optimizer step is then
    taken on loss = -sum_t log pi(a_t | s_t) * R_t. The loss is printed every 100 episodes.

    Args:
        env: object with `reset() -> state` and `step(action) -> (state, reward, done)`.
            States are assumed to be coordinate tuples such as (y, x).
        policy_network: callable mapping a (1, state_dim) float tensor to action probabilities.
        optimizer: optimizer over `policy_network`'s parameters.
        episodes: number of training episodes.
        gamma: discount factor.
        max_steps: optional cap on actions per episode. None (default) runs each
            episode until `done`. A policy that has become nearly deterministic and
            keeps walking into a wall can otherwise make a single episode run
            practically forever.

    Raises:
        ValueError: if `max_steps` is given and is smaller than 1.
    """
    if max_steps is not None and max_steps < 1:
        raise ValueError('max_steps must be a positive integer or None')

    for episode in range(episodes):
        saved_log_probs = []
        rewards = []
        state = env.reset()
        done = False

        while not done and (max_steps is None or len(rewards) < max_steps):
            # Wrap the state tuple in a batch of one: shape (1, state_dim).
            state_tensor = torch.tensor([state], dtype=torch.float)
            probs = policy_network(state_tensor)
            m = Categorical(probs)
            action = m.sample()
            saved_log_probs.append(m.log_prob(action))

            state, reward, done = env.step(action.item())
            rewards.append(reward)

        # Accumulate discounted returns from the last step backwards, pairing each
        # return with the log-probability of the action taken at that step.
        R = 0
        policy_loss = []
        for log_prob, r in zip(reversed(saved_log_probs), reversed(rewards)):
            R = r + gamma * R
            policy_loss.append(-log_prob * R)

        optimizer.zero_grad()
        # stack (not cat) so 0-d log-probs from a 1-D policy output also work.
        policy_loss = torch.stack(policy_loss).sum()
        policy_loss.backward()
        optimizer.step()

        if episode % 100 == 0:
            print('Episode {}: Loss = {:.4f}'.format(episode, policy_loss.item()))


In [16]:
# Seed torch's global RNG so weight initialisation and action sampling (and hence the
# printed losses and the learned policy) are reproducible on Restart & Run All.
torch.manual_seed(0)

env = GridEnvironment(grid_size=4)
policy_network = PolicyNetwork()
optimizer = optim.Adam(policy_network.parameters(), lr=0.01)

reinforce(env, policy_network, optimizer, episodes=1000)


Episode 0: Loss = 31.8666
Episode 100: Loss = 6.1345
Episode 200: Loss = 3.1809
Episode 300: Loss = 9.8116
Episode 400: Loss = 0.9939
Episode 500: Loss = 3.4993
Episode 600: Loss = 4.3215
Episode 700: Loss = 2.9629
Episode 800: Loss = 7.4055
Episode 900: Loss = 2.6622


In [24]:
# def reinforce_with_baseline(env, policy_network, optimizer, episodes=1000, gamma=0.99):
#     for episode in range(episodes):
#         saved_log_probs = []
#         rewards = []
#         state = env.reset()
#         done = False

#         while not done:
#             state = torch.tensor([state], dtype=torch.float)
#             probs = policy_network(state)
#             m = Categorical(probs)
#             action = m.sample()
#             saved_log_probs.append(m.log_prob(action))

#             state, reward, done = env.step(action.item())
#             rewards.append(reward)

#         R = 0
#         policy_loss = []
#         returns = []
#         for r in reversed(rewards):
#             R = r + gamma * R
#             returns.insert(0, R)

#         returns = torch.tensor(returns)
#         returns = (returns - returns.mean()) / (returns.std() + 1e-9)  # baseline: standardize returns (NOTE: std() is NaN for a 1-step episode)

#         for log_prob, R in zip(saved_log_probs, returns):
#             policy_loss.append(-log_prob * R)

#         optimizer.zero_grad()
#         policy_loss = torch.cat(policy_loss).sum()
#         policy_loss.backward()
#         optimizer.step()

#         if episode % 100 == 0:
#             print(f'Episode {episode}: Loss = {policy_loss.item()}')


In [25]:
# env = GridEnvironment(grid_size=4)
# policy_network = PolicyNetwork()
# optimizer = optim.Adam(policy_network.parameters(), lr=0.01)

# reinforce_with_baseline(env, policy_network, optimizer, episodes=1000)


Episode 0: Loss = 0.470025897026062
Episode 100: Loss = 0.12820225954055786
Episode 200: Loss = -0.23298974335193634
Episode 300: Loss = 0.3718506693840027
Episode 400: Loss = -0.07241284102201462
Episode 500: Loss = 0.2414531111717224
Episode 600: Loss = 2.145578384399414
Episode 700: Loss = 0.01262718066573143
Episode 800: Loss = -2.2827460765838623
Episode 900: Loss = 0.014667188748717308


In [26]:
def visualize_policy(env, policy_network):
    """Print the greedy (argmax) action of `policy_network` for every cell of `env`.

    Each cell is printed as ' <symbol> ': 'G' on the goal cell, 'S' on the start
    cell, otherwise an arrow for the most probable action (0=↑, 1=↓, 2=←, 3=→).
    One printed line per grid row, followed by a trailing blank line.
    """
    arrows = ('↑', '↓', '←', '→')

    for row in range(env.grid_size):
        for col in range(env.grid_size):
            cell_state = torch.tensor([[row, col]], dtype=torch.float)
            with torch.no_grad():
                action_probs = policy_network(cell_state)
            greedy_action = action_probs.argmax().item()

            # Landmarks take precedence over the policy arrow: goal first, then start.
            cell = (row, col)
            if cell == env.goal_position:
                symbol = 'G'
            elif cell == env.start_position:
                symbol = 'S'
            else:
                symbol = arrows[greedy_action]
            print(' ' + symbol + ' ', end='')
        print()
    print()

# Show the greedy action the trained policy picks in each grid cell.
visualize_policy(env=env, policy_network=policy_network)


 S  →  ↓  ↓ 
 →  →  ↓  ↓ 
 →  →  →  ↓ 
 →  →  →  G 

