In [1]:
from collections import Counter
from itertools import count
from pathlib import Path

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from gymnasium.wrappers import RecordEpisodeStatistics, AtariPreprocessing, FrameStack, HumanRendering
from torch import nn, optim

from distrl.dqn.agents import DeepQAgent
from distrl.dqn.utils import LinearAnnealer

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [3]:
class DeepQNetwork(nn.Module):
    """Convolutional Q-network producing one Q-value per discrete action.

    The network is three convolutional layers followed by two fully connected
    layers, with a leaky-ReLU after every layer except the output. The first
    convolution takes 4 input channels, and ``fc1`` expects a 7x7x64 feature
    map, which is what the convolution stack yields for 84x84 inputs
    (84 -> 20 -> 9 -> 7).

    Args:
        n_actions: Number of discrete actions, i.e. the width of the output layer.
    """

    # Negative slope shared by every leaky-ReLU in ``forward`` and by the
    # Kaiming gain used at initialisation, so the two cannot drift apart.
    NEGATIVE_SLOPE = 0.01

    def __init__(self, n_actions: int):
        super().__init__()

        self.conv1 = nn.Conv2d(4, 32, 8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, 3, stride=1)
        self.fc1 = nn.Linear(7 * 7 * 64, 1024)
        self.fc2 = nn.Linear(1024, n_actions)

        # Re-initialise the weights (biases keep PyTorch's default init).
        # `a` must be passed explicitly: with nonlinearity='leaky_relu' the
        # default a=0 silently computes the plain-ReLU gain instead.
        for layer in (self.conv1, self.conv2, self.conv3, self.fc1, self.fc2):
            nn.init.kaiming_normal_(
                layer.weight, a=self.NEGATIVE_SLOPE, nonlinearity='leaky_relu'
            )

    def forward(self, x):
        """Return Q-values of shape ``(batch, n_actions)`` for a batch of inputs.

        Args:
            x: Batched input of shape ``(batch, 4, H, W)``; H and W must reduce
                to a 7x7 map after the convolutions (e.g. 84x84).
        """
        slope = self.NEGATIVE_SLOPE
        x = F.leaky_relu(self.conv1(x), slope)
        x = F.leaky_relu(self.conv2(x), slope)
        x = F.leaky_relu(self.conv3(x), slope)
        # Collapse (C, H, W) into one feature vector per sample, keeping the batch dim.
        x = torch.flatten(x, start_dim=1)
        x = F.leaky_relu(self.fc1(x), slope)
        return self.fc2(x)

In [4]:
# Training environment. The base env is created with frameskip=1 because
# AtariPreprocessing applies its own frame skipping on top of it.
env = gym.make("ALE/Tennis-v5", frameskip=1)
env = AtariPreprocessing(env, scale_obs=True)
env = FrameStack(env, num_stack=4, lz4_compress=True)
env = RecordEpisodeStatistics(env, deque_size=1_000_000)

# Size the output layer from the environment instead of hard-coding the
# action count (previously a literal 18), so the cell keeps working if the
# game or its action set changes.
qnet = DeepQNetwork(int(env.action_space.n))

# NOTE(review): the hyperparameters are passed positionally; their meaning is
# defined by DeepQAgent's signature (not visible here) — consider switching to
# keyword arguments once the parameter names are confirmed.
agent = DeepQAgent(qnet, env, 0.99, 10_000, 50_000, 32, 4, 10_000, 50_000, device)

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [5]:
%%time
# Adam over the Q-network's parameters.
# Alternative setting kept for reference: lr=0.0000625 (same eps).
optimizer = optim.Adam(
    qnet.parameters(),
    lr=2.5e-4,
    eps=1.5e-4,
)

# NOTE(review): presumably anneals epsilon linearly from 1.0 to 0.1 over
# 1_000_000 steps — confirm against LinearAnnealer's signature.
epsilon = LinearAnnealer(1.0, 0.1, 1_000_000)

# Full run. For a quick smoke test, pass 10_000 as the first argument instead.
target_network = agent.train(
    10_000_000,
    optimizer,
    epsilon,
)

100%|██████████| 10000000/10000000 [6:32:08<00:00, 425.01it/s] 

2758 episodes lapsed
CPU times: user 6h 29min 39s, sys: 2min 10s, total: 6h 31min 49s
Wall time: 6h 32min 9s





In [6]:
# Persist the trained weights. torch.save does not create missing directories,
# so make sure the target folder exists first; otherwise a multi-hour training
# run could end with a failed save.
checkpoint_path = Path('../models/DQL Atari Tennis - 10.pth')
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(target_network.state_dict(), checkpoint_path)

In [7]:
# Evaluation environment: same preprocessing as training, but rendered to the
# screen. The base env renders RGB arrays, which HumanRendering then displays.
env = gym.make(
    "ALE/Tennis-v5",
    frameskip=1,
    render_mode="rgb_array",
)
# Set the playback rate on the base env before it is wrapped.
# NOTE(review): HumanRendering is expected to pace display from this value — confirm.
env.metadata["render_fps"] = 30
env = FrameStack(
    AtariPreprocessing(env, scale_obs=True),
    num_stack=4,
    lz4_compress=True,
)
env = HumanRendering(env)

In [8]:
# Greedy rollout of the trained network for a single rendered episode,
# followed by a histogram of the actions it chose.
target_network.eval()
state, _ = env.reset()
actions = []

try:
    while True:
        # np.array(...) materialises the stacked frames; unsqueeze adds the batch dim.
        obs = torch.as_tensor(np.array(state), dtype=torch.float32).unsqueeze(0).to(device)

        # Inference only: no autograd graph is needed for action selection.
        with torch.no_grad():
            q_values = target_network(obs)

        # Hand the environment a plain Python int rather than a 0-dim
        # (possibly CUDA) tensor.
        action = int(torch.argmax(q_values).item())
        state, reward, terminated, truncated, _ = env.step(action)
        actions.append(action)

        if terminated or truncated:
            break
finally:
    # Always release the render window, including when the episode is
    # interrupted from the keyboard.
    env.close()

print(Counter(actions))

KeyboardInterrupt: 