In [5]:
import gymnasium as gym
import ale_py
import torch
import numpy as np
from torchvision import transforms
from torch.distributions import Categorical

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocesamiento de frames igual que en entrenamiento
preprocess = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(),
    transforms.Resize((84, 84)),
    transforms.ToTensor()
])

def preprocess_frame(frame):
    return preprocess(frame).squeeze(0).numpy()

# Red Actor-Critic CNN igual al entrenamiento
class CNNActorCritic(torch.nn.Module):
    def __init__(self, action_dim):
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(4, 32, kernel_size=8, stride=4), torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, kernel_size=4, stride=2), torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=1), torch.nn.ReLU()
        )
        self.fc = torch.nn.Linear(64 * 7 * 7, 512)
        self.policy_head = torch.nn.Linear(512, action_dim)
        self.value_head = torch.nn.Linear(512, 1)

    def forward(self, x):
        x = self.conv(x).view(x.size(0), -1)
        x = torch.relu(self.fc(x))
        action_probs = torch.softmax(self.policy_head(x), dim=-1)
        state_value = self.value_head(x)
        return action_probs, state_value

# === Cargar el entorno y el modelo entrenado ===
env = gym.make('ALE/Breakout-v5', render_mode="human", frameskip=1)
num_actions = env.action_space.n
model = CNNActorCritic(num_actions).to(device)
model.load_state_dict(torch.load('ppo_breakout.pth', map_location=device))
model.eval()

test_episodes = 5

for ep in range(test_episodes):
    obs, _ = env.reset()
    frame = preprocess_frame(obs)
    state_stack = np.stack([frame] * 4, axis=0)
    total_reward = 0
    done = False

    while not done:
        state_tensor = torch.FloatTensor(state_stack).unsqueeze(0).to(device)
        with torch.no_grad():
            action_probs, _ = model(state_tensor)
        dist = Categorical(action_probs)
        action = dist.probs.argmax().item()

        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        next_frame = preprocess_frame(next_obs)
        state_stack = np.append(state_stack[1:], [next_frame], axis=0)
        total_reward += reward

    print(f"✅ Test Episode {ep + 1}, Total Reward: {total_reward}")

env.close()



RuntimeError: Error(s) in loading state_dict for CNNActorCritic:
	Missing key(s) in state_dict: "policy_head.weight", "policy_head.bias", "value_head.weight", "value_head.bias". 
	Unexpected key(s) in state_dict: "politica_up.weight", "politica_up.bias", "valor_up.weight", "valor_up.bias". 