### 1. Lógica de la red neuronal para el bandido de dos brazos (two-armed bandit)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random

# ====== 1. BANDIT ENVIRONMENT ======

class TwoArmedBandit:
    """Two-armed Bernoulli bandit whose arms have complementary odds.

    Arm 0 yields a reward of 1 with probability ``p`` and arm 1 with
    probability ``1 - p``; both are stored in ``self.probs``.

    Args:
        p: Success probability of arm 0. When ``None`` (the default) it is
            drawn uniformly from [0.1, 0.9) using NumPy's global RNG, so
            every new instance represents a freshly sampled task.

    Raises:
        ValueError: If an explicit ``p`` is outside [0, 1].
    """

    def __init__(self, p=None):
        if p is None:
            # A new environment per episode means new hidden probabilities.
            p = np.random.uniform(0.1, 0.9)
        elif not 0.0 <= p <= 1.0:
            raise ValueError(f"p must be a probability in [0, 1], got {p!r}")
        self.probs = [p, 1 - p]

    def pull(self, action):
        """Pull one arm and return its Bernoulli reward.

        Args:
            action: Index of the arm to pull; must be 0 or 1.

        Returns:
            ``1`` with probability ``self.probs[action]``, otherwise ``0``.

        Raises:
            IndexError: If ``action`` is not 0 or 1. Negative indices are
                rejected explicitly because list indexing would otherwise
                silently map ``-1`` onto arm 1.
        """
        if action not in (0, 1):
            raise IndexError(f"action must be 0 or 1, got {action!r}")
        return 1 if np.random.rand() < self.probs[action] else 0


# ====== 2. META-RL AGENT (LSTM) ======

class MetaRLAgent(nn.Module):
    """Recurrent policy for the two-armed bandit.

    An ``nn.LSTMCell`` consumes, at every step, the previous action (one-hot
    over 2 arms), the previous reward and a scaled timestep; a linear head
    maps the hidden state to one logit per arm. The recurrent state is kept
    on the instance (``self.hx`` / ``self.cx``) and carried across calls
    until ``reset_state`` is invoked.

    Args:
        input_size: Width of the per-step input vector. ``forward`` builds
            4 features (2 one-hot action + 1 reward + 1 timestep), so the
            default is 4; any other value makes ``forward`` fail with a
            size mismatch inside the LSTM cell.
        hidden_size: Number of units in the LSTM hidden and cell states.
    """

    def __init__(self, input_size=4, hidden_size=32):
        super().__init__()
        self.lstm = nn.LSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, 2)  # one logit per arm

        self.hidden_size = hidden_size
        self.reset_state()

    def reset_state(self):
        """Zero the hidden and cell states (batch size 1)."""
        # Allocate on the same device as the weights so the state stays
        # usable after the module has been moved with ``.to(...)``.
        device = self.fc.weight.device
        self.hx = torch.zeros(1, self.hidden_size, device=device)
        self.cx = torch.zeros(1, self.hidden_size, device=device)

    def forward(self, action, reward, timestep):
        """Advance the recurrent state by one step and return action logits.

        Args:
            action: Index (0 or 1) of the previously chosen arm.
            reward: Reward obtained for that previous action.
            timestep: Current step index; divided by 10.0 before being fed
                to the network.

        Returns:
            Tensor of shape ``(1, 2)`` holding the unnormalised logits for
            the two arms. ``self.hx`` and ``self.cx`` are updated in place
            as a side effect.
        """
        device = self.hx.device
        action_one_hot = F.one_hot(
            torch.tensor([action], device=device), num_classes=2
        ).float()
        reward_feat = torch.tensor(
            [[reward]], dtype=torch.float32, device=device
        )
        time_feat = torch.tensor(
            [[timestep / 10.0]], dtype=torch.float32, device=device
        )

        # Input layout: [one-hot action (2), reward (1), timestep (1)].
        x = torch.cat([action_one_hot, reward_feat, time_feat], dim=1)
        self.hx, self.cx = self.lstm(x, (self.hx, self.cx))
        return self.fc(self.hx)


### 3. Ejecutar el agente (con pesos sin entrenar; aún falta implementar el entrenamiento)

In [None]:
# Example episode with fixed, randomly initialised weights (a stand-in for a
# post-training rollout; no training happens here).
num_steps = 5

# forward() concatenates 4 features per step (2 one-hot action + reward +
# timestep), so the LSTM input width is passed explicitly to match.
agent = MetaRLAgent(input_size=4)
agent.eval()

env = TwoArmedBandit()
agent.reset_state()

# Placeholder "previous" action/reward fed to the agent on the first step.
last_action = 0
last_reward = 0
total_reward = 0

print("Probabilidades ocultas del entorno:", env.probs)

for t in range(num_steps):
    # Inference only: no gradients are needed to sample an action.
    with torch.no_grad():
        logits = agent(last_action, last_reward, t)
        probs = F.softmax(logits, dim=1)
        # Sample stochastically from the policy instead of taking argmax.
        action = torch.multinomial(probs, num_samples=1).item()

    reward = env.pull(action)
    total_reward += reward

    print(f"Paso {t} | Acción: {action} | Recompensa: {reward}")

    # The outcome of this step becomes the agent's input at the next one.
    last_action = action
    last_reward = reward

print("Recompensa total:", total_reward)
