In [None]:
import operator as ops
import pathlib
from collections import Counter, defaultdict
from itertools import count

import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

from arkanoid import Arkanoid
from dqn import DQNAgent
from nes_py.wrappers import JoypadSpace
from torchinfo import summary

# Discrete action set: action id `i` presses the button combination ACTIONS[i]
# (each combination here is a single button, or "NOOP" for no input).
ACTIONS = [[button] for button in ("NOOP", "left", "right", "A")]

# JoypadSpace exposes only the combinations listed in ACTIONS to the agent.
# NOTE(review): the positional `False` passed to Arkanoid is defined in
# arkanoid.Arkanoid — confirm what it toggles.
env = JoypadSpace(Arkanoid(False), ACTIONS)

In [None]:
# Training hyper-parameters and checkpoint layout.
batch_size: int = 32  # forwarded to DQNAgent as its batch size
episodes: int = 1200  # the training loop runs episodes 1..episodes
save_every: int = 10_000  # checkpoint period in frames; the loop skips saving when None
checkpoint_dir: pathlib.Path = pathlib.Path("checkpoints")
latest_dir = checkpoint_dir.joinpath("latest")  # most recent weights are mirrored here

agent = DQNAgent(env, batch_size=batch_size)

# Create checkpoints/ (with any missing parents) first, then checkpoints/latest/.
checkpoint_dir.mkdir(exist_ok=True, parents=True)
latest_dir.mkdir(exist_ok=True)

# Frame counter consumed by the training loop. It is created once here and is
# never reset between episodes, so it counts frames across the whole run.
c = count()


def rgb2gray(rgb):
    """Collapse an RGB(A) image array to a single grayscale channel.

    Only the first three entries of the last axis are read, so any alpha
    channel is ignored. They are combined with the luma weights
    0.2989 / 0.5870 / 0.1140 (the weights MATLAB's ``rgb2gray`` uses).
    The last axis is contracted away: an ``(H, W, 3)`` or ``(H, W, 4)`` input
    gives an ``(H, W)`` floating-point result.
    """
    luma_weights = np.array([0.2989, 0.5870, 0.1140])
    color_channels = rgb[..., :3]
    return np.dot(color_channels, luma_weights)


# Layer-by-layer overview of the policy network (displayed as the cell output).
summary(model=agent.policy_net)

In [None]:
def _observation_to_tensors(raw_screen, raw_info):
    """Convert one raw env observation into the tensors the agent consumes.

    The screen is cropped with ``env.crop_screen``, converted to grayscale
    with ``rgb2gray``, cast to float32 and given a leading size-1 dimension
    via ``unsqueeze(0)`` (presumably a channel axis — confirm against
    DQNAgent). The info object is flattened with ``env.info_to_array``.
    Both tensors are placed on ``agent.device``.

    Returns:
        tuple: ``(screen_tensor, info_tensor)``.
    """
    screen_tensor = torch.tensor(
        rgb2gray(env.crop_screen(raw_screen)),
        dtype=torch.float32,
        device=agent.device,
    ).unsqueeze(0)
    info_tensor = torch.tensor(env.info_to_array(raw_info), device=agent.device)
    return screen_tensor, info_tensor


for episode in tqdm(range(1, episodes + 1), desc="Episode: ", position=0):
    raw_screen, raw_info = env.reset()
    screen, info = _observation_to_tensors(raw_screen, raw_info)
    episode_score = 0

    # `c` is one counter shared by all episodes, so `frame` is the global frame
    # index and checkpoints are keyed by it. leave=False removes each finished
    # per-episode bar; otherwise one dead bar per episode piles up in the output.
    for frame in tqdm(c, desc="Frame: ", position=1, leave=False):
        action = agent.get_action(screen, info)
        raw_next_screen, reward, done, raw_next_info = env.step(action)
        next_screen, next_info = _observation_to_tensors(raw_next_screen, raw_next_info)

        episode_score += reward
        agent.update(screen, info, action, reward, done, next_screen, next_info)

        screen, info = next_screen, next_info

        if save_every is not None and frame % save_every == 0:
            # Snapshot the weights once and write them both to a per-frame
            # directory and to the "latest" mirror.
            policy_state = agent.policy_net.state_dict()
            step_dir = checkpoint_dir / f"{frame}"
            step_dir.mkdir(exist_ok=True)
            torch.save(policy_state, step_dir / "policy_net.pth")
            torch.save(policy_state, latest_dir / "policy_net.pth")

        if done:
            # Per-episode bookkeeping is keyed by the env's own episode counter.
            losses = agent.loss[env.episode]
            agent.scores[env.episode] = episode_score
            print(
                f"Episode {episode}: final score={env.game['score']} total rewards={episode_score} mean loss = {torch.mean(torch.tensor(losses)):.4f}",
                flush=True,
            )
            break

In [None]:
# Persist the final policy weights at the top of checkpoint_dir, then mirror them into latest_dir.
for target_dir in (checkpoint_dir, latest_dir):
    torch.save(agent.policy_net.state_dict(), target_dir / "policy_net.pth")

In [None]:
# Episode length (in frames) for every recorded episode, keyed by episode id.
episode_durations = agent.durations

fig, ax = plt.subplots()
sns.scatterplot(
    x=list(episode_durations.keys()),
    y=list(episode_durations.values()),
    s=10,
    ax=ax,
)
ax.set(title="Duration per episode", xlabel="Episode", ylabel="Frames")
plt.show()

In [None]:
# Total reward per episode. The training loop keys this dict by env.episode,
# so plot each score against its own key rather than a synthetic 1..N range,
# which only lines up if the keys happen to be exactly 1..N. This also matches
# how the duration plot uses its dict's keys.
episodes_scores = agent.scores

fig, ax = plt.subplots()
sns.scatterplot(
    x=list(episodes_scores.keys()),
    y=list(episodes_scores.values()),
    s=10,
    ax=ax,
)
ax.set(title="Score per episode", xlabel="Episode", ylabel="score")
plt.show()

In [None]:
# Per-episode action proportions: rows are episodes, columns are action names.
# Assumes agent.actions maps episode -> {action index: count}. TODO: confirm
# against DQNAgent.
df = (
    pd.DataFrame(agent.actions)
    .T
    # An action never taken in an episode shows up as NaN; it was taken 0 times.
    .fillna(0)
    # Order columns by action index so they (and the plot hue) follow ACTIONS.
    .sort_index(axis=1)
    .rename(columns=lambda action_idx: ACTIONS[action_idx][0])
    # Normalise each row so an episode's proportions sum to 1.
    .pipe(lambda counts: counts.div(counts.sum(axis=1), axis=0))
)

In [None]:
# Long format (episode, action, prop) for a grouped bar chart. rename_axis
# names the episode index explicitly, so this works whether or not df's index
# already has a name.
# NOTE(review): with many episodes (1200 configured) the categorical x-axis
# becomes very dense; consider sub-sampling episodes if it is unreadable.
fig, ax = plt.subplots()
sns.barplot(
    data=(
        df.melt(ignore_index=False, value_name="prop", var_name="action")
        .rename_axis("episode")
        .reset_index()
    ),
    x="episode",
    y="prop",
    hue="action",
    ax=ax,
)
ax.set(
    title="Action proportions per episode",
    xlabel="Episode",
    ylabel="Proportion of actions",
)
plt.show()