# PPO

In [10]:
import matplotlib
import random
import sys
import torch

import numpy as np
import gymnasium as gym
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as utils
import matplotlib.pyplot as plt

from collections import deque, namedtuple
from itertools import count
from torch.distributions import Categorical
from time import time

sys.path.append(r"C:\Users\takat\PycharmProjects\machine-learning")
import flowdata
import flowenv

is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

In [11]:
device_name = "cpu"

if True:
    if torch.cuda.is_available():
        device_name = "cuda"
    elif torch.mps.is_available():
        device_name = "mps"
    # elif torch.hip.is_available():
    #     device_name = "hip"
    elif torch.mtia.is_available():
        device_name = "mtia"
    elif torch.xpu.is_available():
        device_name = "xpu"

device = torch.device(device_name)
print(f"device: {device_name}")

device: cuda


In [12]:
LR = 1e-4
GAMMA = 0.99
EPSILON = 0.2
BATCH_SIZE = 64

In [13]:
raw_data_train, raw_data_test = flowdata.flow_data.using_data()

train_env = gym.make("flowenv/FlowTrain-v0", data=raw_data_train)
test_env = gym.make("flowenv/FlowTest-v0", data=raw_data_test)

In [14]:
class PolicyNetwork(nn.Module):
    def __init__(self, n_inputs, n_outputs):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(n_inputs, 128)
        self.fc2 = nn.ReLU()
        self.fc3 = nn.Linear(128, n_outputs)
        self.fc4 = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        return x

In [15]:
Transaction = namedtuple('Transaction', ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
    def __init__(self, capacity):
        # self.capacity = capacity
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transaction(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [32]:
def select_action(state):
    global policy_net
    probs = policy_net(state)
    return Categorical(probs).sample()

def optimize_model():
    global log_probs, policy_net, memory
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = Transaction(*zip(*transitions))

    states = torch.cat(batch.state)
    actions = torch.cat(batch.action)
    rewards = torch.cat(batch.reward)

    G = 0
    returns = []
    
    for r in torch.flip(rewards, dims=[0]):
        G = r + GAMMA * G
        returns.insert(0, G)

    returns = torch.tensor(returns, device=device, dtype=torch.float32)

    old_probs = policy_net(states).gather(1, actions).detach()

    for _ in range(5):
        new_props = policy_net(states).gather(1, actions)
        ratio = new_props / old_probs

        advantages = returns - returns.mean()

        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - EPSILON, 1 + EPSILON) * advantages
        loss = -torch.min(surr1, surr2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def get_h_m_s(seconds: float):
    hours = int(seconds // 3600)
    minutes = int((seconds - hours * 3600) // 60)
    seconds = seconds - hours * 3600 - minutes * 60
    return hours, minutes, seconds

def loading_bar(episode, total_episodes, interval):
    pro_size_float = (episode + 1) / total_episodes * 20
    show = pro_size_float * 5
    pro_size = int(pro_size_float)

    # episode...interval -> total_episodes...interval * total_episodes / episode
    last_time = interval * (total_episodes - episode) / (episode + 1)
    hours, minutes, seconds = get_h_m_s(last_time)
    print(f"\r[{'#' * pro_size}{' ' * (20 - pro_size)}] {show:3.02f}%, last={hours:02d}:{minutes:02d}:{seconds:03.3f}", end="")

In [33]:
episode_metrics = {
    "accuracy": [],
    "precision": [],
    "recall": [],
    "f1": [],
    "fpr": []
}

def plot_metrics(show_result=False):
    fig = plt.figure(figsize=(16, 20))

    ac = fig.add_subplot(5, 1, 1)
    ac.plot(episode_metrics["accuracy"], label="accuracy")
    ac.grid()
    ac.set_title("Accuracy")

    pr = fig.add_subplot(5, 1, 2)
    pr.plot(episode_metrics["precision"], label="precision", color="green")
    pr.grid()
    pr.set_title("Precision")

    re = fig.add_subplot(5, 1, 3)
    re.plot(episode_metrics["recall"], label="recall", color="red")
    re.grid()
    re.set_title("Recall")

    f1 = fig.add_subplot(5, 1, 4)
    f1.plot(episode_metrics["f1"], label="f1", color="black")
    f1.grid()
    f1.set_title("F1")

    fpr = fig.add_subplot(5, 1, 5)
    fpr.plot(episode_metrics["fpr"], label="fpr", color="purple")
    fpr.grid()
    fpr.set_title("FPR")

    plt.tight_layout()
    plt.pause(0.001)
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())


def calcurate_metrics(tp, tn, fp, fn):
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp) if tp + fp != 0 else -1
    recall = tp / (tp + fn) if tp + fn != 0 else -1
    f1 = 2 * precision * recall / (precision + recall) if precision + recall < 0 else None
    fpr = fp / (fp + tn) if fp + tn != 0 else None

    if precision < 0:
        precision = None
    if recall < 0:
        recall = None
    return accuracy, precision, recall, f1, fpr

In [34]:
policy_net = PolicyNetwork(train_env.observation_space.shape[0], train_env.action_space.n).to(device)
optimizer = torch.optim.Adam(policy_net.parameters(), lr=LR)
memory = ReplayMemory(int(1e6))

steps_done = 0

num_episodes = 10000
episode_rewards = []

start_time = time()
for i_episode in range(num_episodes):
    state, info = train_env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

    sum_reward = 0

    for t in count():
        action = select_action(state)
        next_state, reward, terminated, truncated, _ = train_env.step(action.item())

        reward = torch.tensor([reward], device=device)
        done = bool(terminated)
        steps_done += 1

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(next_state, dtype=torch.float32, device=device).unsqueeze(0)

        print(type(state))
        memory.push(state, action, next_state, reward)

        state = next_state
        sum_reward += reward.item()

        if done:
            episode_rewards.append(sum_reward)
            break

    optimize_model()
    end_time = time()
    loading_bar(i_episode, num_episodes, end_time - start_time)

<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
[                    ] 0.01%, last=00:23:55.931<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
[                    ] 0.02%, last=00:13:35.451<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
[                    ] 0.03%, last=00:10:31.938<class '

RuntimeError: Index tensor must have the same number of dimensions as input tensor

In [None]:
mean_rewards = []

for i in range(0, len(episode_rewards)):
    # print(f"Episode {i}, mean reward: {np.mean(episode_rewards[0:i])}")
    mean_rewards.append(np.mean(episode_rewards[0:i]))

plt.figure(figsize=(10, 5))
plt.plot(episode_rewards)
plt.plot(mean_rewards, color="red")
plt.show()