In [1]:
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
import gymnasium as gym
from gymnasium.wrappers import TransformReward, AtariPreprocessing, FrameStack, HumanRendering
from collections import Counter
from itertools import count

In [2]:
# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [3]:
class DeepQNetwork(nn.Module):
    """Convolutional Q-network mapping a stack of frames to one Q-value per action.

    Args:
        n_actions: Number of discrete actions; width of the output layer.
        in_channels: Number of stacked frames per observation (default 4, matching
            ``FrameStack(env, 4)`` used for evaluation in this notebook).
        negative_slope: Leaky-ReLU slope, used both in ``forward`` and for the
            Kaiming initialisation gain so the two stay consistent.

    Input is expected as raw pixel intensities (0-255) shaped
    ``(batch, in_channels, 84, 84)``; ``forward`` rescales to [0, 1]. The hard-coded
    ``7 * 7 * 64`` flatten size follows from 84x84 inputs (84 -> 20 -> 9 -> 7
    through the three conv layers). Parameter names (``conv1``..``fc2``) define the
    ``state_dict`` keys, so renaming them breaks checkpoint loading.
    """

    def __init__(self, n_actions: int, in_channels: int = 4, negative_slope: float = 0.01):
        super().__init__()
        # Plain float attribute (not a buffer/parameter), so state_dict keys are unaffected.
        self.negative_slope = negative_slope

        self.conv1 = nn.Conv2d(in_channels, 32, 8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, 3, stride=1)
        self.fc1 = nn.Linear(7 * 7 * 64, 1024)
        self.fc2 = nn.Linear(1024, n_actions)

        # The Kaiming gain must use the same slope as the activation applied in
        # forward(); with the default a=0 it would compute the plain-ReLU gain.
        for layer in (self.conv1, self.conv2, self.conv3, self.fc1):
            nn.init.kaiming_normal_(layer.weight, a=negative_slope, nonlinearity='leaky_relu')
        # fc2 produces raw Q-values with no activation after it, so use the linear gain.
        nn.init.kaiming_normal_(self.fc2.weight, nonlinearity='linear')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values of shape ``(batch, n_actions)`` for pixel input ``x``."""
        x = x.to(torch.float32) / 255.0
        x = F.leaky_relu(self.conv1(x), self.negative_slope)
        x = F.leaky_relu(self.conv2(x), self.negative_slope)
        x = F.leaky_relu(self.conv3(x), self.negative_slope)
        # Same as nn.Flatten() (start_dim=1) without building a module per call.
        x = torch.flatten(x, start_dim=1)
        x = F.leaky_relu(self.fc1(x), self.negative_slope)
        return self.fc2(x)

In [4]:
# Greedy (argmax-Q) rollout of a trained checkpoint on Breakout, rendered in a window.
CHECKPOINT_PATH = "../checkpoints/checkpoint-35000.pth"  # alternative: "../models/checkpoint-250000.pth"
MAX_STEPS = 100
SEED = 42

env = gym.make("ALE/Breakout-v5", frameskip=1, render_mode="rgb_array")
# Presumably read by HumanRendering to pace the window — TODO confirm desired speed.
env.metadata["render_fps"] = 60
# frameskip=1 above leaves frame skipping to AtariPreprocessing (its frame_skip default).
env = AtariPreprocessing(env, scale_obs=False)
env = FrameStack(env, 4, True)
# Reward clipping (TransformReward) is deliberately not applied here, so printed rewards are unclipped.
env = HumanRendering(env)

qnet = DeepQNetwork(env.action_space.n)
# NOTE(review): the last run failed with missing conv*/fc* keys and unexpected
# layer1..layer3 keys — this checkpoint was saved from a different network
# definition. TODO point CHECKPOINT_PATH at a checkpoint produced by this
# DeepQNetwork (or restore the class the checkpoint was trained with).
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints you trust.
qnet.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
qnet.to(device)
qnet.eval()

actions = []
q_values_sum = torch.zeros((1, env.action_space.n), device=device)
state, _ = env.reset(seed=SEED)
try:
    # no_grad: no autograd graph is built, so summing Q-values across steps
    # does not keep every step's graph alive in GPU memory.
    with torch.no_grad():
        for _ in range(MAX_STEPS):
            # The network casts to float and rescales itself, so the raw uint8 frames are sent as-is.
            obs = torch.as_tensor(np.asarray(state), device=device).unsqueeze(0)
            q_values = qnet(obs)
            action = int(q_values.argmax())  # env.step expects a plain int, not a (CUDA) tensor
            q_values_sum += q_values
            state, reward, terminated, truncated, _ = env.step(action)
            actions.append(action)
            if reward != 0:
                print(reward)
            if terminated or truncated:
                break
finally:
    # Close the render window even if the rollout raises.
    env.close()

# Average over the steps actually taken: the episode may end before MAX_STEPS.
print(q_values_sum / max(len(actions), 1))
print(Counter(actions))

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


RuntimeError: Error(s) in loading state_dict for DeepQNetwork:
	Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias", "conv3.weight", "conv3.bias", "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias". 
	Unexpected key(s) in state_dict: "layer1.weight", "layer1.bias", "layer2.weight", "layer2.bias", "layer3.weight", "layer3.bias". 

In [None]:
# NOTE(review): unlabeled values with no assignment — possibly mean Q-values pasted
# from an earlier run of the evaluation cell. TODO label them in markdown or delete this cell.
0.3180,  0.1097, -0.0419,  0.2139