In [None]:
from itertools import count

import matplotlib
import random
import sys
from collections import deque, namedtuple
from torch.distributions import Categorical
from time import time

import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.utils as utils
import numpy as np
import matplotlib.pyplot as plt

sys.path.append(r"C:\Users\takat\PycharmProjects\machine-learning")
import flowdata
import flowenv

# True when the active matplotlib backend name contains "inline"
# (e.g. the Jupyter inline backend).
is_ipython = "inline" in matplotlib.get_backend()

In [None]:
device_name = "cpu"

if True:
    if torch.cuda.is_available():
        device_name = "cuda"
    elif torch.mps.is_available():
        device_name = "mps"
    # elif torch.hip.is_available():
    #     device_name = "hip"
    elif torch.mtia.is_available():
        device_name = "mtia"
    elif torch.xpu.is_available():
        device_name = "xpu"

device = torch.device(device_name)
print(f"device: {device_name}")

In [None]:
# Hyperparameters — edit here; the training cells read these module-level names.
BATCH_SIZE: int = 64     # transitions sampled per optimize_model() call; no update until memory holds this many
LAMBDA: float = 0.5      # weight applied to the critic loss in the combined actor + critic loss
GAMMA: float = 0.99      # intended as the discount factor γ for TD targets
HIDDEN_SIZE: int = 128   # intended hidden-layer width for the actor / critic networks

In [None]:
# Load the train/test splits, then wrap each split in its own registered environment.
raw_data_train, raw_data_test = flowdata.flow_data.using_data()
train_env, test_env = (
    gym.make("flowenv/FlowTrain-v0", data=raw_data_train),
    gym.make("flowenv/FlowTest-v0", data=raw_data_test),
)

In [None]:
class A2C(nn.Module):
    """Common base for the actor and critic networks.

    Args:
        n_inputs: size of the observation vector fed to the network.
        n_outputs: output size requested by the caller (the training cell passes
            the number of discrete actions).
        random_seed: if not None, passed to ``torch.manual_seed`` (global RNG)
            before the subclass creates its layers, so initial weights are reproducible.
        hidden_size: width of the single hidden layer (default mirrors HIDDEN_SIZE).
        gamma: discount factor stored as ``self.gamma`` (default mirrors GAMMA).
    """

    def __init__(self, n_inputs, n_outputs, random_seed=None, hidden_size=128, gamma=0.99):
        super().__init__()

        # `is not None` so that a seed of 0 is honoured as well.
        if random_seed is not None:
            torch.manual_seed(random_seed)

        self.in_size = n_inputs
        self.out_size = n_outputs
        self.hidden_size = hidden_size
        self.gamma = gamma

    def forward(self, state):
        raise NotImplementedError("subclasses of A2C must implement forward()")


class Actor(A2C):
    """Two-layer MLP mapping a state to ``out_size`` unnormalised action logits."""

    def __init__(self, n_inputs, n_outputs, random_seed=None, hidden_size=128, gamma=0.99):
        super().__init__(n_inputs, n_outputs, random_seed, hidden_size, gamma)

        self.fc1 = nn.Linear(self.in_size, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, self.out_size)

    def forward(self, state):
        # The input is detached so no gradient flows back into the caller's tensor.
        x = torch.relu(self.fc1(state.detach()))
        return self.fc2(x)


class Critic(A2C):
    """Two-layer MLP mapping a state to a single scalar value estimate.

    ``n_outputs`` is accepted for signature parity with ``Actor`` but unused.
    """

    def __init__(self, n_inputs, n_outputs, random_seed=None, hidden_size=128, gamma=0.99):
        super().__init__(n_inputs, n_outputs, random_seed, hidden_size, gamma)

        self.fc1 = nn.Linear(self.in_size, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, 1)

    def forward(self, state):
        # The input is detached so no gradient flows back into the caller's tensor.
        x = torch.relu(self.fc1(state.detach()))
        return self.fc2(x)

In [None]:
# One stored environment step (a transition), in push() argument order.
Transaction = namedtuple("Transaction", ["state", "action", "next_state", "reward"])


class ReplayMemory:
    """Bounded FIFO buffer of ``Transaction`` records with uniform random sampling.

    When ``capacity`` records are already held, pushing a new one evicts the oldest.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Append one transition built from positional args (state, action, next_state, reward)."""
        record = Transaction(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` stored records drawn without replacement.

        Raises ValueError (from ``random.sample``) if fewer records are stored.
        """
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        return self.memory.__len__()

In [None]:
def select_action(state_tensor):
    """Sample a discrete action from the notebook-level ``actor`` policy.

    The actor's outputs are treated as unnormalised logits of a Categorical
    distribution.

    Args:
        state_tensor: observation tensor passed straight to ``actor``.

    Returns:
        The sampled action index as a tensor (``Categorical.sample()`` output).

    Raises:
        ValueError: if ``state_tensor`` or the produced logits contain NaN.
    """
    if torch.isnan(state_tensor).any():
        raise ValueError("state_tensor has NaN")

    # Acting needs no gradients; avoid building an autograd graph on every env step.
    with torch.no_grad():
        logits = actor(state_tensor)

    if torch.isnan(logits).any():
        raise ValueError("logits has NaN")

    return Categorical(logits=logits).sample()

def optimize_model():
    """Run one actor-critic update on a minibatch sampled from replay memory.

    Reads the notebook-level globals ``actor``, ``critic``, ``actor_optimizer``,
    ``critic_optimizer``, ``memory``, ``Transaction``, ``BATCH_SIZE``, ``GAMMA``
    and ``LAMBDA``. Returns without doing anything until ``memory`` holds at
    least ``BATCH_SIZE`` transitions.

    NOTE(review): A2C is normally on-policy; sampling old transitions from a replay
    buffer makes this an off-policy update without importance correction.
    """
    if len(memory) < BATCH_SIZE:
        return
    critic_optimizer.zero_grad()
    actor_optimizer.zero_grad()

    transactions = memory.sample(BATCH_SIZE)
    batch = Transaction(*zip(*transactions))

    state_batch = torch.stack(batch.state)
    batch_device = state_batch.device
    # Shape (B, 1) as required by gather(); reshape handles 0-d or 1-element actions.
    action_batch = torch.stack(batch.action).reshape(-1, 1).long()
    # Flatten each stored reward (1-element tensor) into a (B,) vector; a (B, 1)
    # reward would broadcast against the (B,) values into a (B, B) target matrix.
    reward_batch = torch.cat([
        torch.as_tensor(r, dtype=torch.float32, device=batch_device).reshape(-1)
        for r in batch.reward
    ])
    # next_state is stored raw (as returned by env.step); move it to the states' device.
    next_state_batch = torch.stack([
        torch.as_tensor(ns, dtype=torch.float32, device=batch_device)
        for ns in batch.next_state
    ])

    values = critic(state_batch).squeeze(-1)
    with torch.no_grad():
        next_values = critic(next_state_batch).squeeze(-1)
    # NOTE(review): Transaction carries no done flag, so terminal next_states are
    # still bootstrapped here.
    targets = reward_batch + GAMMA * next_values
    advantages = targets - values

    # The actor returns logits (select_action samples with Categorical(logits=...)),
    # so normalise with log_softmax instead of taking log of raw outputs.
    logits = actor(state_batch)
    action_log_probs = torch.log_softmax(logits, dim=-1).gather(1, action_batch).squeeze(-1)
    actor_loss = -torch.mean(action_log_probs * advantages.detach())

    critic_loss = nn.MSELoss()(values, targets.detach())

    total_loss = actor_loss + LAMBDA * critic_loss
    total_loss.backward()

    utils.clip_grad_norm_(actor.parameters(), 1.0)
    utils.clip_grad_norm_(critic.parameters(), 1.0)

    actor_optimizer.step()
    critic_optimizer.step()

def get_h_m_s(seconds: float):
    """Split a duration in seconds into (whole hours, whole minutes, leftover seconds).

    Hours and minutes are ints; the leftover keeps the input's fractional part.
    """
    whole_hours = int(seconds // 3600)
    after_hours = seconds - whole_hours * 3600
    whole_minutes = int(after_hours // 60)
    return whole_hours, whole_minutes, after_hours - whole_minutes * 60

def loading_bar(episode, total_episodes, interval):
    """Print a 20-cell, carriage-return progress bar with a remaining-time estimate.

    Args:
        episode: zero-based index of the episode that just finished.
        total_episodes: total number of episodes in the run.
        interval: wall-clock seconds elapsed since the run started.
    """
    finished = episode + 1
    filled_float = finished / total_episodes * 20
    percent = filled_float * 5
    filled = int(filled_float)

    # Average time per finished episode times the episodes still to run
    # (reaches exactly 0 after the last episode).
    remaining = interval * (total_episodes - finished) / finished
    hours, minutes, seconds = get_h_m_s(remaining)
    # 06.3f zero-pads seconds to "SS.mmm" so the line keeps a fixed width.
    print(
        f"\r[{'#' * filled}{' ' * (20 - filled)}] {percent:3.02f}%, "
        f"last={hours:02d}:{minutes:02d}:{seconds:06.3f}",
        end="",
    )

In [None]:
actor = Actor(train_env.observation_space.shape[0], train_env.action_space.n).to(device)
critic = Critic(train_env.observation_space.shape[0], train_env.action_space.n).to(device)

actor_optimizer = torch.optim.Adam(actor.parameters(), lr=1e-3)
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=1e-3)

memory = ReplayMemory(10000)
episode_rewards = []

num_episodes = 10000

start_time = time()
for i_episode in range(num_episodes):
    state, _ = train_env.reset()
    state_tensor = torch.tensor(state, dtype=torch.float32, device=device)
    sum_rewards = 0.0

    for t in count():
        action = select_action(state_tensor)
        next_state, reward, terminated, truncated, _ = train_env.step(action.item())

        reward = torch.tensor([reward], device=device)
        memory.push(state_tensor, action, next_state, reward)
        sum_rewards += reward.item()

        # An episode ends on termination OR truncation (gymnasium step API);
        # otherwise continue from the new observation.
        done = bool(terminated or truncated)
        if done:
            break
        state_tensor = torch.tensor(next_state, dtype=torch.float32, device=device)

    end_time = time()
    loading_bar(i_episode, num_episodes, end_time - start_time)
    episode_rewards.append(sum_rewards)
    optimize_model()

In [None]:
# Running mean: entry i averages episodes 0..i inclusive. The cumulative-sum form
# includes the current episode and never averages an empty slice (which yields NaN).
rewards_array = np.asarray(episode_rewards, dtype=float)
mean_rewards = (np.cumsum(rewards_array) / np.arange(1, len(rewards_array) + 1)).tolist()

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(episode_rewards, label="episode reward")
ax.plot(mean_rewards, color="red", label="running mean")
ax.set(title="Training reward per episode", xlabel="episode", ylabel="total reward")
ax.legend()
plt.show()