<a href="https://colab.research.google.com/github/Vasil255/AjedrezDeGardenerConDQN/blob/main/JugarVSIA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [31]:
import os
import torch
import torch.nn as nn
import numpy as np
import requests
import torch.nn.functional as F
from IPython.display import display, HTML, clear_output

# --- 1. CONFIGURATION AND MODEL DOWNLOAD ---
# Inference device: a CUDA GPU when one is available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Raw-file URL of the repository's models/ folder; weight filenames are appended to it.
GITHUB_BASE = (
    "https://raw.githubusercontent.com/Vasil255/"
    "AjedrezDeGardenerConDQN/main/models/"
)

# Network architectures used during training. Their layer layouts must match the
# saved checkpoints, which are restored with load_state_dict.
class QNetworkMLP(nn.Module):
    """Fully connected Q-network: 25 board cells in, 625 action values out.

    The layers live in ``self.model`` (indices 0, 2 and 4 are the Linear layers),
    matching the checkpoint that is restored with ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(25, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 625),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-values for the flattened board batch ``x``."""
        return self.model(x)

class QNetworkSimple(nn.Module):
    """Small convolutional Q-network over a single-channel 5x5 board.

    ``self.conv`` and ``self.fc`` keep the exact layer order of the trained
    checkpoint so that ``load_state_dict`` finds every key.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 5 * 5, 512),
            nn.ReLU(),
            nn.Linear(512, 625),
        )

    def forward(self, x):
        """Reshape ``x`` to (N, 1, 5, 5), convolve, flatten and map to 625 Q-values."""
        batch = x.size(0)
        grid = x.view(-1, 1, 5, 5)
        features = self.conv(grid).view(batch, -1)
        return self.fc(features)

class QNetworkPro(nn.Module):
    """Convolutional Q-network over a 12-plane one-hot encoding of the board.

    ``forward`` accepts either already-encoded ``(N, 12, 5, 5)`` planes or the raw
    ``(N, 25)`` board vector produced by ``GardnerChessEnv.get_state``; the raw
    vector is one-hot encoded on the fly. The layout of ``self.network`` must match
    the trained checkpoint restored with ``load_state_dict``.
    """

    # Piece codes in plane order: planes 0-5 mark these White pieces, planes
    # 6-11 the Black pieces with the same codes negated.
    _PIECE_VALUES = (1, 3, 4, 5, 9, 200)

    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(nn.Conv2d(12, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
                                     nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
                                     nn.Flatten(), nn.Linear(128 * 5 * 5, 512), nn.ReLU(), nn.Dropout(0.2), nn.Linear(512, 625))

    def forward(self, x):
        """Return 625 Q-values per board; raw 25-cell boards are one-hot encoded first."""
        if x.shape[1] == 25:
            board = x.view(x.size(0), 1, 5, 5)
            # Build the planes on the input's own device (not a global DEVICE),
            # so the model works wherever it and its input have been moved.
            pieces = torch.tensor(self._PIECE_VALUES, dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
            x = torch.cat([board == pieces, board == -pieces], dim=1).float()
        return self.network(x)

# --- 2. GardnerChessEnv ENVIRONMENT ---
class GardnerChessEnv:
    """Gardner minichess (5x5) game state and move rules.

    ``board`` is a 5x5 NumPy integer array of piece codes: 1 pawn, 3 knight,
    4 bishop, 5 rook, 9 queen, 200 king. Positive values are White, negative
    values are Black and 0 is an empty square. Row 0 holds Black's back rank and
    row 4 White's (see ``reset``). ``turn`` is 1 when White is to move and -1 when
    Black is. There is no castling, en passant or pawn double step, and a pawn
    reaching the last row always becomes a queen (see ``step``).
    """
    # Store piece-code constants (not referenced elsewhere in this class) and set up the start position.
    def __init__(self): self.P, self.N, self.B, self.R, self.Q, self.K = 1, 3, 4, 5, 9, 200; self.reset()
    # Place the starting position, give the move to White (turn = 1) and return the encoded state.
    def reset(self): self.board = np.array([[-5,-3,-4,-9,-200],[-1,-1,-1,-1,-1],[0,0,0,0,0],[1,1,1,1,1],[5,3,4,9,200]]); self.turn = 1; return self.get_state()
    # Encode the board as a (1, 25) float tensor (row-major flatten of the raw piece codes) on DEVICE.
    def get_state(self): return torch.FloatTensor(self.board.flatten()).unsqueeze(0).to(DEVICE)
    # Return the (row, col) of side p's king (p = 1 White, -1 Black), or None if that king is not on the board.
    def find_king(self, p): pos = np.where(self.board == 200*p); return (pos[0][0], pos[1][0]) if len(pos[0]) > 0 else None
    def is_in_check(self, p):
        """Return True if side ``p``'s king is attacked, or if that king is missing."""
        kp = self.find_king(p)
        if not kp: return True
        opp_kp = self.find_king(-p)
        # Shortcut: the two kings standing on adjacent squares counts as check.
        if opp_kp and abs(kp[0]-opp_kp[0]) <= 1 and abs(kp[1]-opp_kp[1]) <= 1: return True
        # Otherwise: does any enemy piece have a pseudo-legal move onto the king's square?
        for r in range(5):
            for c in range(5):
                if self.board[r,c]*p < 0:
                    for _, target in self._generate_moves_raw(r,c):
                        # kp holds NumPy ints; tuple == still compares element-wise by value.
                        if target == kp: return True
        return False
    def get_legal_moves(self, p):
        """Return side ``p``'s legal moves as ``((r1, c1), (r2, c2))`` pairs.

        Each pseudo-legal move is tried on ``self.board``, kept only if it does
        not leave ``p``'s own king in check, and then undone, so the board is
        unchanged when this returns. Pawn promotion is not applied during the
        trial move.
        """
        moves = []
        for r in range(5):
            for c in range(5):
                if self.board[r,c]*p > 0:
                    for m in self._generate_moves_raw(r, c):
                        # Remember whatever was on the target square (0 or a captured piece).
                        old_t = self.board[m[1]]
                        self.board[m[1]], self.board[r,c] = self.board[r,c], 0
                        if not self.is_in_check(p): moves.append(((r,c), m[1]))
                        # Undo: the right-hand side is evaluated first, so the moved piece
                        # returns to (r, c) and old_t is put back on the target square.
                        self.board[r,c], self.board[m[1]] = self.board[m[1]], old_t
        return moves
    def _generate_moves_raw(self, r, c):
        """Return pseudo-legal moves of the piece on (r, c) as ``((r, c), (nr, nc))`` pairs.

        Moves that would leave the mover's own king in check are NOT filtered out
        here. An empty square yields an empty list.
        """
        p = self.board[r,c]; m = []
        if abs(p)==1: # Pawn
            # White pawns advance toward row 0, Black pawns toward row 4.
            d = -1 if p>0 else 1
            if 0<=r+d<5 and self.board[r+d,c]==0: m.append(((r,c),(r+d,c)))
            # Diagonal steps are only allowed as captures of an enemy piece.
            for dc in [-1,1]:
                if 0<=r+d<5 and 0<=c+dc<5 and self.board[r+d,c+dc]*p < 0: m.append(((r,c),(r+d,c+dc)))
        elif abs(p)==3: # Knight
            for dr, dc in [(-2,-1),(-2,1),(-1,-2),(-1,2),(1,-2),(1,2),(2,-1),(2,1)]:
                nr, nc = r+dr, c+dc
                # <= 0: the target square is empty or holds an enemy piece.
                if 0<=nr<5 and 0<=nc<5 and self.board[nr,nc]*p <= 0: m.append(((r,c),(nr,nc)))
        elif abs(p) in [4, 5, 9]: # Bishop, rook, queen
            dirs = []
            # Everything except the bishop moves orthogonally; everything except the rook moves diagonally.
            if abs(p) != 4: dirs += [(-1,0),(1,0),(0,-1),(0,1)]
            if abs(p) != 5: dirs += [(-1,-1),(-1,1),(1,-1),(1,1)]
            for dr, dc in dirs:
                for i in range(1, 5):
                    nr, nc = r+dr*i, c+dc*i
                    if 0<=nr<5 and 0<=nc<5:
                        if self.board[nr,nc]*p <= 0: m.append(((r,c),(nr,nc)))
                        # Slide stops at the first occupied square (included above if it is an enemy).
                        if self.board[nr,nc] != 0: break
                    else: break
        elif abs(p)==200: # King
            for dr in [-1,0,1]:
                for dc in [-1,0,1]:
                    if dr==0 and dc==0: continue
                    nr, nc = r+dr, c+dc
                    if 0<=nr<5 and 0<=nc<5 and self.board[nr,nc]*p <= 0: m.append(((r,c),(nr,nc)))
        return m
    def get_action_mask(self):
        """Return a length-625 float array with 1 at every legal action for the side to move.

        Action index = from_square * 25 + to_square, where square = row * 5 + col.
        """
        mask = np.zeros(625)
        for m in self.get_legal_moves(self.turn): mask[(m[0][0]*5 + m[0][1])*25 + (m[1][0]*5 + m[1][1])] = 1
        return mask
    def step(self, action_idx):
        """Play ``action_idx`` (encoded as in ``get_action_mask``) and pass the turn.

        Legality is NOT checked here. A pawn landing on row 0 or row 4 is
        replaced by a queen of the same colour.
        """
        r1, c1, r2, c2 = (action_idx // 25) // 5, (action_idx // 25) % 5, (action_idx % 25) // 5, (action_idx % 25) % 5
        self.board[r2, c2], self.board[r1, c1] = self.board[r1, c1], 0
        if abs(self.board[r2, c2]) == 1 and (r2 == 0 or r2 == 4): self.board[r2, c2] = 9 * (1 if self.board[r2, c2] > 0 else -1)
        self.turn *= -1

# --- 3. WOODEN BOARD UI ---
def get_board_html(board, status, turn_color="#dcb35c"):
    """Render the 5x5 board as a self-contained HTML snippet.

    Args:
        board: 5x5 array of piece codes (positive = White, negative = Black,
            0 = empty), indexed as ``board[row, col]``; row 0 is drawn at the top
            and labelled rank 5.
        status: Banner text shown above the board (HTML-escaped before insertion).
        turn_color: CSS background colour of the banner.

    Returns:
        The HTML markup as a string.
    """
    import html

    # Chess glyphs are written as named Unicode escapes so they survive any
    # re-encoding of the source file (mis-decoded literals render as garbage).
    symbols = {
        0: "",
        1: "\N{WHITE CHESS PAWN}",
        3: "\N{WHITE CHESS KNIGHT}",
        4: "\N{WHITE CHESS BISHOP}",
        5: "\N{WHITE CHESS ROOK}",
        9: "\N{WHITE CHESS QUEEN}",
        200: "\N{WHITE CHESS KING}",
        -1: "\N{BLACK CHESS PAWN}",
        -3: "\N{BLACK CHESS KNIGHT}",
        -4: "\N{BLACK CHESS BISHOP}",
        -5: "\N{BLACK CHESS ROOK}",
        -9: "\N{BLACK CHESS QUEEN}",
        -200: "\N{BLACK CHESS KING}",
    }
    row_html = []
    for i in range(5):
        # First cell of each row is the rank label (5 at the top down to 1).
        cells = [f"<td style='color:#dcb35c; font-weight:bold; padding:0 10px;'>{5-i}</td>"]
        for j in range(5):
            bg = "#dcb35c" if (i + j) % 2 == 0 else "#926139"
            v = int(board[i, j])
            c = "white" if v > 0 else "black"
            cells.append(f'<td style="width:55px; height:55px; background:{bg}; text-align:center; font-size:40px; color:{c}; border:none;">{symbols.get(v, "")}</td>')
        row_html.append(f"<tr>{''.join(cells)}</tr>")
    rows = "".join(row_html)
    # status is interpolated into markup, so escape &, <, >, and quotes.
    safe_status = html.escape(status)
    return f'''<div style="background:#3e2723; padding:20px; border-radius:12px; width:360px; margin:auto; font-family:sans-serif;">
        <div style="background:{turn_color}; color:#3e2723; padding:10px; text-align:center; border-radius:6px; margin-bottom:15px; font-weight:bold; font-size:18px;">{safe_status}</div>
        <table style="margin:auto; border-collapse:collapse;">{rows}<tr style="color:#dcb35c; text-align:center; font-weight:bold;"><td></td><td>A</td><td>B</td><td>C</td><td>D</td><td>E</td></tr></table>
    </div>'''

# --- 4. GAME LOOP ---
def _load_model(name, filename):
    """Return the trained Q-network ``name`` ("MLP", "Simple" or "PRO") in eval mode.

    If ``filename`` is not in the working directory, it is downloaded from
    ``GITHUB_BASE`` first.
    """
    if not os.path.exists(filename):
        # Download into memory and write the file only after the request has
        # succeeded. A failed download must not leave an empty .pth behind:
        # os.path.exists would then skip the download forever and torch.load
        # would keep failing.
        response = requests.get(GITHUB_BASE + filename, timeout=60)
        response.raise_for_status()
        with open(filename, "wb") as fh:
            fh.write(response.content)
    net_cls = {"MLP": QNetworkMLP, "Simple": QNetworkSimple, "PRO": QNetworkPro}[name]
    net = net_cls().to(DEVICE)
    # NOTE(review): torch.load unpickles a file fetched from the network;
    # consider weights_only=True if the installed PyTorch supports it.
    net.load_state_dict(torch.load(filename, map_location=DEVICE))
    return net.eval()


def _parse_move(text):
    """Convert input such as ``"A2A3"`` (from-square then to-square) into an action index.

    Files are A-E (columns 0-4) and ranks 1-5, with rank 5 being board row 0.
    The result is ``from_square * 25 + to_square`` where ``square = row * 5 + col``.
    Returns ``None`` for anything that is not exactly two valid squares, so
    out-of-range input can never wrap around to a different square.
    """
    text = text.strip().upper()
    if len(text) != 4:
        return None
    squares = []
    for file_ch, rank_ch in (text[:2], text[2:]):
        if file_ch not in "ABCDE" or rank_ch not in "12345":
            return None
        squares.append((5 - int(rank_ch)) * 5 + (ord(file_ch) - ord("A")))
    return squares[0] * 25 + squares[1]


def jugar():
    """Play an interactive game: the human is White, a trained DQN agent is Black.

    The user picks the agent (1-MLP, 2-Simple, 3-PRO; unrecognised input falls
    back to PRO) and enters moves like ``A2A3``. The agent plays the legal action
    with the highest Q-value. The board is rendered and updated in place with
    IPython's display. The game ends on checkmate or stalemate.
    """
    files = {"MLP": "modelo_mlp_gardner_pesos.pth", "Simple": "modelo_cnn_gardner_pesos.pth", "PRO": "modelo_cnn_pro_gardner_pesos.pth"}

    clear_output()
    print("🤖 Elige IA: 1-MLP, 2-Simple, 3-PRO")
    ia_name = {"1": "MLP", "2": "Simple", "3": "PRO"}.get(input("Opción: ").strip(), "PRO")
    # Only the chosen network is downloaded and loaded.
    model = _load_model(ia_name, files[ia_name])

    env = GardnerChessEnv()
    board_id = display(HTML(get_board_html(env.board, "¡Suerte!")), display_id=True)

    while True:
        mask = env.get_action_mask()
        if not mask.any():
            # No legal moves: checkmate if the side to move is in check, otherwise stalemate.
            if env.is_in_check(env.turn):
                msg = f"🏁 JAQUE MATE: {'IA GANÓ' if env.turn == 1 else '¡GANASTE!'}"
            else:
                msg = "🏁 TABLAS: rey ahogado"
            board_id.update(HTML(get_board_html(env.board, msg, "#ff5252")))
            break

        if env.turn == 1:
            board_id.update(HTML(get_board_html(env.board, "🟢 TU TURNO (Blancas)", "#4CAF50")))
            action = _parse_move(input("Tu jugada (ej: A2A3): "))
            if action is None:
                print("❌ Formato inválido (ej: A2A3)")
                continue
            if mask[action] == 0:
                print("❌ Movimiento ilegal")
                continue
            env.step(action)
        else:
            board_id.update(HTML(get_board_html(env.board, f"🔵 IA PENSANDO ({ia_name})...", "#2196F3")))
            with torch.no_grad():
                q_vals = model(env.get_state())[0].cpu().numpy()
            # Illegal actions get -inf so argmax can only pick a legal move.
            q_vals[mask == 0] = -np.inf
            env.step(int(np.argmax(q_vals)))

# Start the interactive game when run as a notebook cell or script, but not on import.
if __name__ == "__main__":
    jugar()

🤖 Elige IA: 1-MLP, 2-Simple, 3-PRO
Opción: 3


0,1,2,3,4,5
5.0,♕,,,,♚
4.0,,,,♟,♟
3.0,,,,,
2.0,,,♙,,♙
1.0,♖,♘,♞,,♔
,A,B,C,D,E


Tu jugada (ej: A2A3): a2a3
Tu jugada (ej: A2A3): a3b4
Tu jugada (ej: A2A3): b4c5
Tu jugada (ej: A2A3): b2a3
Tu jugada (ej: A2A3): c5d5
Tu jugada (ej: A2A3): a3a4
Tu jugada (ej: A2A3): d1d2
Tu jugada (ej: A2A3): a4a5
Tu jugada (ej: A2A3): d2d5
❌ Movimiento ilegal
Tu jugada (ej: A2A3): d2a5
