In [None]:
import gymnasium as gym
import ale_py
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import random
from torchvision import transforms

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Frame preprocessing pipeline applied to every raw environment observation.
preprocess = transforms.Compose(
    [
        transforms.ToPILImage(),      # array -> PIL image
        transforms.Grayscale(),       # single output channel
        transforms.Resize((84, 84)),  # fixed 84x84 spatial size
        transforms.ToTensor(),        # PIL image -> float tensor (C, H, W)
    ]
)

def preprocess_frame(frame):
    """Turn one raw observation into a 2-D grayscale NumPy array.

    The frame goes through the module-level ``preprocess`` pipeline, which
    yields a tensor of shape (1, 84, 84). The leading channel axis is dropped
    and the result is converted to NumPy.

    NOTE(review): assumes ``frame`` is an image array accepted by
    ``ToPILImage`` (presumably an HWC uint8 RGB Atari frame) — confirm.
    """
    single_channel = preprocess(frame)
    without_channel = single_channel.squeeze(0)
    return without_channel.numpy()

# Neural network
class PolicyNet(nn.Module):
    """Fully connected policy: flattened 4-frame stack -> one score per action.

    Input: float tensor whose last dimension is 84 * 84 * 4 (four stacked
    84x84 frames, flattened). Output: unnormalized scores of shape
    (..., action_dim).
    """

    def __init__(self, action_dim):
        super().__init__()
        in_features = 84 * 84 * 4
        hidden_units = 256
        # The attribute name "net" and the layer positions (0 and 2) determine
        # the state_dict keys ("net.0.*", "net.2.*") used by saved checkpoints.
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_dim),
        )

    def forward(self, x):
        """Return the action scores for the batch ``x``."""
        scores = self.net(x)
        return scores

# Evaluation of one individual
def evaluate(params, action_dim, episodes=2, max_steps_per_episode=None):
    """Play greedy episodes with the policy in ``params`` and return its fitness.

    Args:
        params: state_dict loadable into ``PolicyNet(action_dim)``.
        action_dim: number of discrete actions of the environment.
        episodes: number of episodes to play; their rewards are summed.
        max_steps_per_episode: optional cap on environment steps per episode.
            ``None`` (default, original behaviour) runs each episode until the
            environment reports ``terminated`` or ``truncated``. With
            ``frameskip=1`` and a deterministic argmax policy, that can take a
            very long time, so a cap is useful to bound evaluation time.

    Returns:
        Total undiscounted reward summed over all played episodes.
    """
    # Build the model before creating the environment so that a failing
    # load_state_dict cannot leak an open environment.
    model = PolicyNet(action_dim).to(device)
    model.load_state_dict(params)
    model.eval()

    env = gym.make('ALE/Breakout-v5', frameskip=1)
    total_reward = 0
    try:
        for _ in range(episodes):
            obs, _ = env.reset()
            frame = preprocess_frame(obs)
            # Seed the 4-frame history with the first frame repeated (oldest first).
            state_stack = np.stack([frame] * 4, axis=0)
            steps = 0
            done = False
            while not done:
                if max_steps_per_episode is not None and steps >= max_steps_per_episode:
                    break
                state_tensor = torch.as_tensor(
                    state_stack.flatten(), dtype=torch.float32, device=device
                ).unsqueeze(0)
                with torch.no_grad():
                    logits = model(state_tensor)
                action = torch.argmax(logits).item()
                next_obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                steps += 1
                # Slide the window: drop the oldest frame, append the newest.
                next_frame = preprocess_frame(next_obs)
                state_stack = np.append(state_stack[1:], [next_frame], axis=0)
                total_reward += reward
    finally:
        # Always release the emulator, even if a step or preprocessing raises.
        env.close()
    return total_reward

# Cross two individuals
def crossover(p1, p2):
    """Blend two parents by element-wise averaging.

    For every key of ``p1``, the child value is ``(p1[k] + p2[k]) / 2``.
    Keys that appear only in ``p2`` are ignored. A new dict is returned and
    neither parent is modified.
    """
    return {name: (p1[name] + p2[name]) / 2 for name in p1}

# Gaussian mutation
def mutate(params, mutation_rate=0.1):
    """Return a copy of ``params`` with Gaussian noise added to each float tensor.

    Args:
        params: mapping of name -> tensor (e.g. a state_dict). It is NOT
            modified. The previous in-place version would silently alter any
            state_dict shared with other individuals.
        mutation_rate: standard deviation of the noise. It is a scale, not a
            probability: every element of every floating tensor is perturbed.

    Returns:
        A new dict with the same keys, in the same order. Non-floating tensors
        (e.g. integer counters) are copied unchanged, because ``randn_like``
        is not defined for integer dtypes.
    """
    mutated = {}
    for name, tensor in params.items():
        if torch.is_floating_point(tensor):
            mutated[name] = tensor + torch.randn_like(tensor) * mutation_rate
        else:
            mutated[name] = tensor.clone()
    return mutated

# Hyperparameters of the genetic algorithm.
population_size = 4
generations = 5
mutation_rate = 0.02  # standard deviation of the Gaussian noise added by mutate()
elite_fraction = 0.2

# Crossover draws two distinct parents with random.sample(), so the
# population must hold at least two individuals.
if population_size < 2:
    raise ValueError("population_size must be at least 2 for crossover")

# At least two elites must survive each generation to serve as parents.
# Without the lower bound, int(4 * 0.2) == 0 leaves the elite list empty and
# random.sample() raises ValueError on the very first generation.
elite_count = min(population_size, max(2, int(population_size * elite_fraction)))

# Query the number of discrete actions from a throwaway environment.
env = gym.make('ALE/Breakout-v5', frameskip=1)
num_actions = env.action_space.n
env.close()

# Initialize the population: each individual is the state_dict of a freshly
# initialized PolicyNet.
population = []
for _ in range(population_size):
    model = PolicyNet(num_actions).to(device)
    population.append(model.state_dict())

reward_history = []

# Best individual of the most recently *evaluated* generation. It must be
# captured before `population` is replaced, because indices into `fitness`
# do not refer to the next population. If generations == 0, this falls back
# to an unevaluated individual.
best_params = population[0]

# ==== generations ====
for gen in range(generations):
    fitness = [evaluate(individual, num_actions) for individual in population]

    # Population indices sorted by fitness, best first.
    sorted_indices = np.argsort(fitness)[::-1]
    best_idx = sorted_indices[0]
    best_params = population[best_idx]

    print(f"Gen {gen+1}, Best Fitness: {fitness[best_idx]}")
    reward_history.append(fitness[best_idx])

    # Elitism: the top `elite_count` individuals pass to the next generation unchanged.
    next_population = [population[idx] for idx in sorted_indices[:elite_count]]

    # Fill the remaining slots with mutated children. Parents are drawn from
    # the growing next_population, so children created earlier in this
    # generation can also be chosen as parents.
    while len(next_population) < population_size:
        parents = random.sample(next_population, 2)
        child = crossover(parents[0], parents[1])
        child = mutate(child, mutation_rate)
        next_population.append(child)

    population = next_population

# ==== plot ====
plt.plot(reward_history)
plt.xlabel('Generaciones')
plt.ylabel('Mejor fitness')
# NOTE(review): the title says "NES" and "MinAtar", but this script runs a
# genetic algorithm on ALE Breakout. It also plots the best per-generation
# fitness (reward summed over evaluate()'s episodes), not per-episode reward.
# Consider retitling.
plt.title('NES Breakout MinAtar - Recompensa por episodio')
plt.show()

# ==== save the best policy ====
# Loading into a PolicyNet validates that the parameters match the architecture.
best_model = PolicyNet(num_actions).to(device)
best_model.load_state_dict(best_params)
torch.save(best_model.state_dict(), "nes_breakout.pth")
print("✅ Política guardada como 'nes_breakout.pth'")

