In [9]:
import torch
import h5py
import numpy as np
import chess
import chess.svg
import plotly.graph_objects as go
import plotly.express as px
from sklearn.manifold import TSNE
from IPython.display import display, clear_output
import ipywidgets as widgets
import random
import re
import os
import sys
from model import ChessEncoder

# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input artefacts (paths relative to the notebook's working directory).
MODEL_PATH = "../weights/v13_temp0.07_emb32.pt"
PGN_FILE = "../data/raw/lichess_elite_2025-11.pgn"
H5_FILE = "../data/lc0-hidden/lichess_elite_2025-11.h5"

# Which of each game's sampled positions gets plotted (offset inside the game's block of 10 rows).
CHOSEN_OFFSET = 7

# NOTE(security): weights_only=False unpickles the whole saved object, which can
# execute arbitrary code — only load checkpoints that come from a trusted source.
model_trained = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model_trained.eval()  # switch the encoder to evaluation mode before embedding


def _sample_fens(moves, positions_per_game=10):
    """Replay ``moves`` from the initial position and return ``positions_per_game`` FENs.

    Follows the sampling used by ``prepare_inputs`` (per the original author's note;
    NOTE(review): confirm the two stay in sync). Games with at most
    ``positions_per_game`` moves contribute every position. Longer games contribute the
    positions after moves ``first_idx, first_idx + step, ...``, where
    ``first_idx = len(moves) % positions_per_game`` and
    ``step = (len(moves) - first_idx) // positions_per_game``.
    If there are fewer positions than needed, the list is padded with the final position.
    Each move is passed to ``chess.Board.push_uci``, so ``moves`` must be UCI strings.
    """
    board = chess.Board()
    n_moves = len(moves)
    first_idx = n_moves % positions_per_game
    step = (n_moves - first_idx) // positions_per_game

    fens = []
    for m_idx, move_uci in enumerate(moves):
        board.push_uci(move_uci)
        on_sample_ply = n_moves <= positions_per_game or m_idx == first_idx + step * len(fens)
        if on_sample_ply and len(fens) < positions_per_game:
            fens.append(board.fen())

    # Short games: repeat the final position so every game yields exactly positions_per_game FENs.
    while len(fens) < positions_per_game:
        fens.append(board.fen())
    return fens


def get_dashboard_data(model, pgn_path, h5_path, num_games=50, positions_per_game=10):
    """Embed the sampled positions of up to ``num_games`` games and collect FENs and metadata.

    Game ``g`` from the PGN file is paired with rows
    ``[g * positions_per_game, (g + 1) * positions_per_game)`` of the ``lc0_hidden``
    dataset in the HDF5 file. NOTE(review): this assumes the HDF5 file was produced
    from the same PGN, in the same order, with no games skipped — confirm against the
    preprocessing script.

    Args:
        model: encoder applied to each block of hidden-state rows. Its output is moved
            to the CPU and converted to a NumPy array.
        pgn_path: PGN file, read with ``process_pgns.read_pgn_iter`` (imported from ./src).
        h5_path: HDF5 file containing the ``lc0_hidden`` dataset.
        num_games: maximum number of games to process.
        positions_per_game: positions sampled per game (= HDF5 rows per game).

    Returns:
        ``(embeddings, fens, metas)``. The embeddings are stacked with ``np.vstack``.
        There is one FEN string and one independent metadata dict per embedding row.

    Raises:
        ValueError: if not a single game could be processed.
    """
    result_map = {-1: "Black", 0: "Draw", 1: "White"}

    # process_pgns lives in ./src. Add it to the import path only once, so repeated
    # calls from the notebook do not keep growing sys.path.
    src_dir = os.path.abspath(os.path.join(os.getcwd(), 'src'))
    if src_dir not in sys.path:
        sys.path.append(src_dir)
    from process_pgns import read_pgn_iter

    all_embeddings, all_fens, all_display_meta = [], [], []

    with h5py.File(h5_path, 'r', swmr=True) as f:
        hidden_ds = f['lc0_hidden']
        n_rows = hidden_ds.shape[0]

        # zip stops on range() first, so at most num_games games are pulled from the reader.
        for g_idx, (meta_obj, moves) in zip(range(num_games), read_pgn_iter(pgn_path)):
            start = g_idx * positions_per_game
            end = start + positions_per_game
            if end > n_rows:
                break  # the HDF5 file holds fewer games than requested

            # Embedding extraction for this game's block of hidden states.
            states = torch.from_numpy(hidden_ds[start:end]).float().to(DEVICE)
            with torch.no_grad():
                embs = model(states).cpu().numpy()
            all_embeddings.append(embs)

            all_fens.extend(_sample_fens(moves, positions_per_game))

            display_info = {
                # `or` also covers an opening_name attribute that exists but is empty/None,
                # which would otherwise break the later `.split(':')`.
                "Opening": getattr(meta_obj, 'opening_name', None) or "Unknown Opening",
                "Result": result_map.get(meta_obj.result, "Unknown"),
                "Elo": f"W: {meta_obj.white_elo} / B: {meta_obj.black_elo}",
                "Termination": meta_obj.termination
            }
            # One independent dict per position. `[display_info] * n` would alias a single
            # dict n times, so editing one entry would silently edit the whole game.
            all_display_meta.extend(dict(display_info) for _ in range(positions_per_game))

    if not all_embeddings:
        raise ValueError(f"No games could be read from {pgn_path!r} / {h5_path!r}")
    return np.vstack(all_embeddings), all_fens, all_display_meta


def to_rgba(color_str, opacity):
    """Return ``color_str`` rewritten as a CSS ``rgba(r,g,b,opacity)`` string.

    Supported inputs:
      * hex: ``#rrggbb`` and the shorthand ``#rgb`` / ``#rgba``. Any digits past the
        first three channels (e.g. an alpha byte) are ignored.
      * ``rgb(...)`` / ``rgba(...)``: the first three numbers are kept (fractional
        values included) and any existing alpha is replaced by ``opacity``.
    Any other value (e.g. a named colour) is returned unchanged.
    """
    if color_str.startswith('#'):
        hex_digits = color_str.lstrip('#')
        if len(hex_digits) in (3, 4):
            # Shorthand hex: every digit stands for a doubled pair ("f" -> "ff").
            hex_digits = ''.join(ch * 2 for ch in hex_digits[:3])
        r, g, b = (int(hex_digits[i:i + 2], 16) for i in (0, 2, 4))
        return f'rgba({r},{g},{b},{opacity})'
    if color_str.startswith('rgb'):
        # Match whole decimal numbers; a bare \d+ would split "12.5" into "12" and "5".
        nums = re.findall(r'\d+(?:\.\d+)?', color_str)
        if len(nums) >= 3:
            return f'rgba({nums[0]},{nums[1]},{nums[2]},{opacity})'
    return color_str

# Opening families kept for the plot. Each is compared against the part of a game's
# "Opening" name before the first ':'; games from any other family are dropped.
TARGET_OPENINGS = [
    "Sicilian Defense", 
    "French Defense", 
    "Caro-Kann Defense", 
    "Ruy Lopez", 
    "Queen's Gambit",
    "Italian Game",
    "Scandinavian Defense"
]

print("Adatok szinkronizálása és csoportosítása...")
raw_embs, raw_fens, raw_metas = get_dashboard_data(model_trained, PGN_FILE, H5_FILE, num_games=1000)

# Rows per game in raw_embs / raw_fens / raw_metas (get_dashboard_data's default).
POSITIONS_PER_GAME = 10
if not 0 <= CHOSEN_OFFSET < POSITIONS_PER_GAME:
    # A negative offset would wrap to the end of the lists; one >= POSITIONS_PER_GAME
    # would silently pick positions from the next game.
    raise ValueError(f"CHOSEN_OFFSET must be in [0, {POSITIONS_PER_GAME}), got {CHOSEN_OFFSET}")

target_openings = set(TARGET_OPENINGS)
embs, fens, metas = [], [], []

# Keep one position per game (its CHOSEN_OFFSET-th sample), and only for target openings.
for g_idx in range(len(raw_metas) // POSITIONS_PER_GAME):
    i = g_idx * POSITIONS_PER_GAME + CHOSEN_OFFSET
    primary_opening = raw_metas[i]['Opening'].split(':')[0].strip()
    if primary_opening not in target_openings:
        continue
    # Copy instead of mutating: the raw meta dict may be shared with the game's other
    # positions, and it is not ours to rename.
    metas.append({**raw_metas[i], 'Opening': primary_opening})
    embs.append(raw_embs[i])
    fens.append(raw_fens[i])

embs = np.array(embs)
print(f"Szűrt és csoportosított állások száma: {len(metas)}")


# t-SNE: project the selected embeddings to 3-D for the interactive plot.
if len(metas) < 2:
    # perplexity=min(30, n-1) must be positive and below the sample count, so fewer
    # than 2 positions (e.g. no game matched TARGET_OPENINGS) cannot be embedded.
    raise ValueError(f"Need at least 2 filtered positions for t-SNE, got {len(metas)}")

tsne = TSNE(n_components=3, perplexity=min(30, len(metas) - 1),
            max_iter=10000, learning_rate='auto', init='pca', random_state=42)
coords = tsne.fit_transform(embs)

# Alphabetical opening -> colour index. The palette is cycled if there are more
# openings than colours.
unique_openings = sorted({m['Opening'] for m in metas})
opening_to_id = {name: i for i, name in enumerate(unique_openings)}
palette = px.colors.qualitative.Bold

base_colors_rgba = [to_rgba(palette[opening_to_id[m['Opening']] % len(palette)], 0.9) for m in metas]
base_line_colors = ['rgba(255, 255, 255, 0.6)'] * len(metas)

# Interactive 3-D scatter: one point per selected position, coloured by opening
# family. Hovering shows the opening name.
fig = go.FigureWidget(
    data=[
        go.Scatter3d(
            x=coords[:, 0],
            y=coords[:, 1],
            z=coords[:, 2],
            mode='markers',
            marker=dict(
                size=5,
                color=base_colors_rgba,
                line=dict(color=base_line_colors, width=0.5),
            ),
            text=[meta['Opening'] for meta in metas],
            hoverinfo='text',
        )
    ],
    layout=dict(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(
            xaxis=dict(title='Latent X'),
            yaxis=dict(title='Latent Y'),
            zaxis=dict(title='Latent Z'),
        ),
        width=600,
        height=600,
        # Transparent backgrounds so the widget blends into the notebook theme.
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
    ),
)

# Right-hand detail panel, plus the button that clears a selection.
out_ui = widgets.Output()
reset_button = widgets.Button(description='Összes mutatása', button_style='info', icon='refresh')


def on_click(trace, points, state):
    """Highlight every point of the clicked point's opening and show that position.

    Used as the scatter trace's click callback. Only the first index in
    ``points.point_inds`` is used; ``trace`` and ``state`` are part of the callback
    signature but unused. Points of the clicked opening are drawn opaque and larger,
    and all others are faded out. The board and the game metadata are rendered into
    ``out_ui``.
    """
    from html import escape  # PGN-derived text is interpolated into raw HTML below

    if not points.point_inds:
        return
    idx = points.point_inds[0]
    m = metas[idx]
    clicked_opening = m['Opening']

    # Loop-invariant: every highlighted point shares the clicked opening's colour.
    highlight_color = to_rgba(palette[opening_to_id[clicked_opening] % len(palette)], 1.0)

    new_colors, new_line_colors, new_sizes = [], [], []
    for other in metas:
        if other['Opening'] == clicked_opening:
            new_colors.append(highlight_color)
            new_line_colors.append('rgba(255, 255, 255, 0.8)')
            new_sizes.append(8)
        else:
            new_colors.append('rgba(255, 255, 255, 0.01)')
            new_line_colors.append('rgba(255, 255, 255, 0.0)')
            new_sizes.append(3)

    with fig.batch_update():
        fig.data[0].marker.color = new_colors
        fig.data[0].marker.line.color = new_line_colors
        fig.data[0].marker.size = new_sizes

    # Refresh the right-hand panel.
    with out_ui:
        clear_output(wait=True)
        res_color = {"White": "#2ecc71", "Black": "#e74c3c"}.get(m["Result"], "#95a5a6")

        # Escape values taken from the PGN headers, so markup in them cannot break or
        # inject into the panel.
        opening_html = escape(str(m['Opening']))
        result_html = escape(str(m['Result']))
        elo_html = escape(str(m['Elo']))
        termination_html = escape(str(m['Termination']))

        meta_panel = f"""
        <div style='border: 1px solid #dcdde1; padding: 15px; border-radius: 10px; background: white; 
                    box-shadow: 2px 2px 8px rgba(0,0,0,0.1); width: 220px; font-family: sans-serif;'>
            <h3 style='margin:0; color:#2f3640; border-bottom: 2px solid #f1f2f6; padding-bottom: 8px; font-size: 1.1em;'>{opening_html}</h3>
            <div style='margin-top: 10px; font-size: 0.9em;'>
                <p><b>Győztes:</b> <span style='color: {res_color}; font-weight: bold;'>{result_html}</span></p>
                <p><b>ELO:</b> {elo_html}</p>
                <p><b>Mód:</b> {termination_html}</p>
            </div>
        </div>
        """
        board_svg = chess.svg.board(chess.Board(fens[idx]), size=380)
        display(widgets.HBox([widgets.HTML(board_svg), widgets.HTML(meta_panel)], 
                             layout=widgets.Layout(align_items='center', justify_content='center')))

fig.data[0].on_click(on_click)

def reset_view(b):
    """Restore the default look of every point and show the placeholder hint.

    Used as the reset button's click callback; ``b`` (the button) is unused.
    """
    with fig.batch_update():
        fig.data[0].marker.color = base_colors_rgba
        fig.data[0].marker.line.color = base_line_colors
        fig.data[0].marker.size = 5
    with out_ui:
        # wait=True (as in on_click) keeps the old content until the new one arrives,
        # which avoids a visible flicker.
        clear_output(wait=True)
        display(widgets.HTML("<div style='padding: 100px; color: gray;'>Kattints egy pontra a részletekért...</div>"))


# Wire the reset button. The scatter's click handler is already registered right
# after on_click's definition, so it is not registered a second time here.
reset_button.on_click(reset_view)

# Start in the reset state, so the right-hand panel shows the hint instead of being blank.
reset_view(None)

# Layout: plot (with the reset button above it) on the left, detail panel on the right.
left_container = widgets.VBox([reset_button, fig], layout=widgets.Layout(width='100%', align_items='center'))
dashboard = widgets.HBox([
    widgets.Box([left_container], layout=widgets.Layout(width='45%', display='flex', align_items='center', justify_content='center')),
    widgets.Box([out_ui], layout=widgets.Layout(width='55%', display='flex', align_items='center', justify_content='center'))
], layout=widgets.Layout(width='100%', height='650px', align_items='center'))

display(dashboard)

Adatok szinkronizálása és csoportosítása...
Szűrt és csoportosított állások száma: 385


HBox(children=(Box(children=(VBox(children=(Button(button_style='info', description='Összes mutatása', icon='r…