In [1]:
import gymnasium as gym
from net import DQN
from wrappers import make_env

from dataclasses import dataclass
import time
import numpy as np
import collections
import typing as tt
import yaml

import torch
import torch.nn as nn
import torch.optim as optim

from buffer import Experience, ExperienceBuffer, BatchTensors
import ale_py
from agent import Agent

## Import config file

In [2]:
# Hyperparameters, the environment name and the torch device are all read
# from this YAML file (see the config['env'/'model'/'agent'/...] lookups below).
with open("../config/config.yaml", "r") as cfg_file:
    config = yaml.safe_load(cfg_file)

In [3]:
# Make ale_py's Atari environments usable through Gymnasium, then build the
# environment named in the config via `wrappers.make_env`.
gym.register_envs(ale_py)
env = make_env(config["env"]["name"])

A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]


In [4]:
# Torch device named in the config (the recorded run resolved to 'mps').
device = torch.device(config["model"]["device"])
device  # display the resolved device

device(type='mps')

In [8]:
# Online network (optimized every training step) and target network (used by
# calc_loss for the bootstrap targets; re-synced from `net` every
# config['agent']['sync_target_frames'] frames in the training loop).
net = DQN(env.observation_space.shape, env.action_space.n).to(device)
tgt_net = DQN(env.observation_space.shape, env.action_space.n).to(device)
# Start the target as an exact copy of the online network, as in standard DQN.
# Without this, any updates that run before the first periodic sync bootstrap
# from an unrelated random initialization. Trailing ';' suppresses the
# "<All keys matched successfully>" echo in Jupyter.
tgt_net.load_state_dict(net.state_dict());

In [9]:
# Optionally resume from a checkpoint written by the training loop.
load = False
# Mean-reward tag of the checkpoint to resume from; the training loop saves
# files named '../models/Pong-v5-best_<mean reward rounded to int>.dat'.
checkpoint_reward = -9
if load:
    print('LOADING...')
    # map_location remaps the stored tensors onto this run's device, so a
    # checkpoint saved on e.g. 'mps' can be loaded on 'cpu' or 'cuda'.
    data_loaded = torch.load('../models/Pong-v5' + "-best_%.0f.dat" % checkpoint_reward,
                             map_location=device)

    # Both networks start from the same weights, exactly as after a target sync.
    # NOTE: frame_idx is reset to 0 in a later cell, so exploration restarts
    # from epsilon_start even when resuming.
    net.load_state_dict(data_loaded)
    tgt_net.load_state_dict(data_loaded)

In [10]:
# Replay buffer (capacity presumably config['replay_buffer']['size']) shared
# with the agent, which presumably appends the transitions it collects from `env`.
buffer = ExperienceBuffer(config["replay_buffer"]["size"])
agent = Agent(env, buffer)
# Initial exploration rate; the training loop recomputes epsilon every frame.
epsilon = config["agent"]["epsilon_start"]

# float(): YAML may load values such as `1e-4` as strings rather than floats.
optimizer = optim.Adam(
    net.parameters(),
    lr=float(config["agent"]["learning_rate"]),
)

In [11]:
# Bookkeeping mutated by the training loop.
total_rewards = []        # rewards returned by agent.play_step at episode ends
frame_idx = ts_frame = 0  # current frame / frame at the last speed measurement
ts = time.time()          # wall-clock time at the last speed measurement
best_m_reward = None      # best mean reward over the last 100 episodes so far

In [12]:
def batch_to_tensors(batch: tt.List[Experience], device: torch.device) -> BatchTensors:
    """Collate sampled transitions into tensors placed on ``device``.

    Returns the tuple ``(states, actions, rewards, dones, new_states)``:
    states/new_states are stacked with ``np.asarray`` and keep the stored
    observation dtype; actions are int64, rewards float32 and dones bool.
    """
    obs = torch.as_tensor(np.asarray([exp.state for exp in batch]))
    acts = torch.tensor([exp.action for exp in batch], dtype=torch.int64)
    rews = torch.tensor([exp.reward for exp in batch], dtype=torch.float32)
    terminal = torch.tensor([exp.done_trunc for exp in batch], dtype=torch.bool)
    next_obs = torch.as_tensor(np.asarray([exp.new_state for exp in batch]))
    # Moved to the device in the same order as the returned tuple.
    return tuple(t.to(device) for t in (obs, acts, rews, terminal, next_obs))

def calc_loss(batch: tt.List[Experience], net: DQN, tgt_net: DQN,
              device: torch.device, gamma: tt.Optional[float] = None) -> torch.Tensor:
    """One-step temporal-difference (DQN) loss for a batch of transitions.

    Args:
        batch: transitions sampled from the replay buffer.
        net: online network; gradients flow through its Q(s, a) estimates.
        tgt_net: target network; used only (under no_grad) for bootstrap targets.
        device: device the batch tensors are moved to.
        gamma: discount factor. When None (the default, so existing callers are
            unchanged) it falls back to the notebook-global
            ``config['agent']['gamma']``; pass it explicitly to avoid that
            hidden dependency.

    Returns:
        Scalar MSE between Q(s, a) and r + gamma * max_a' Q_tgt(s', a'),
        with the bootstrap term zeroed where the transition's done flag is set.
    """
    if gamma is None:
        gamma = config['agent']['gamma']
    states_t, actions_t, rewards_t, dones_t, new_states_t = batch_to_tensors(batch, device)

    # Q(s, a) for the action actually taken in each transition.
    state_action_values = net(states_t).gather(1, actions_t.unsqueeze(-1)).squeeze(-1)
    with torch.no_grad():
        # no_grad already keeps the targets out of the graph, so no .detach() is needed.
        next_state_values = tgt_net(new_states_t).max(1)[0]
        # NOTE(review): `done_trunc` looks like it is set for truncated as well as
        # terminated episodes; strictly, bootstrapping should only be cut on
        # termination — confirm against the Agent/ExperienceBuffer code.
        next_state_values[dones_t] = 0.0

    expected_state_action_values = rewards_t + gamma * next_state_values
    return nn.functional.mse_loss(state_action_values, expected_state_action_values)

In [None]:
# Main training loop: one environment step per iteration and, once the replay
# buffer holds replay_start_size transitions, one gradient step per iteration.
# Mutates the notebook-level bookkeeping set up in an earlier cell.
while True:
    frame_idx += 1
    # Epsilon-greedy schedule: interpolate linearly from epsilon_start down to
    # epsilon_final over the first epsilon_decay_last_frame frames, then hold.
    # The previous `start - frame_idx / decay_last_frame` form only reached
    # epsilon_final exactly at that frame when start == 1 and final == 0.
    eps_start = config['agent']['epsilon_start']
    eps_final = config['agent']['epsilon_final']
    epsilon = max(eps_final,
                  eps_start - (eps_start - eps_final) * frame_idx / config['agent']['epsilon_decay_last_frame'])
    # Presumably returns the finished episode's reward at episode end and None
    # otherwise — TODO confirm against agent.Agent.play_step.
    reward = agent.play_step(net, device, epsilon)
    if reward is not None:
        total_rewards.append(reward)
        speed = (frame_idx - ts_frame) / (time.time() - ts)
        ts_frame = frame_idx
        ts = time.time()
        # Mean over (up to) the last 100 episodes.
        m_reward = np.mean(total_rewards[-100:])
        print(f"{frame_idx}: done {len(total_rewards)} games, reward {m_reward:.3f}, eps {epsilon:.2f}, speed {speed:.2f} f/s")
        if best_m_reward is None or best_m_reward < m_reward:
            torch.save(net.state_dict(), '../models/Pong-v5' + "-best_%.0f.dat" % m_reward)
            if best_m_reward is not None:
                print(f"Best reward updated {best_m_reward:.3f} -> {m_reward:.3f}")
            best_m_reward = m_reward
        if m_reward > config['agent']['mean_rew_bound']:
            print("Solved in %d frames!" % frame_idx)
            break
    # Only collect experience until the buffer is warm.
    if len(buffer) < config['replay_buffer']['replay_start_size']:
        continue
    # Periodically copy the online weights into the target network.
    if frame_idx % config['agent']['sync_target_frames'] == 0:
        tgt_net.load_state_dict(net.state_dict())

    optimizer.zero_grad()
    batch = buffer.sample(config['agent']['batch_size'])
    loss_t = calc_loss(batch, net, tgt_net, device)
    loss_t.backward()
    optimizer.step()