In [1]:
import chess
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.models import resnet50
from typing import List, Tuple
import torch.optim as optim
import time
from tqdm import tqdm

seed = 123456
np.random.seed(seed=seed)

# Разработка модели


In [2]:
class DualHeadChessDataset(Dataset):
    def __init__(self, csv_file: str):
        self.df = pd.read_csv(csv_file)

        self._validate_data()
        self.piece_to_idx = {
            "P": 0,
            "N": 1,
            "B": 2,
            "R": 3,
            "Q": 4,
            "K": 5,
            "p": 6,
            "n": 7,
            "b": 8,
            "r": 9,
            "q": 10,
            "k": 11,
        }

    def _validate_data(self):
        valid_indices = []
        for idx, row in self.df.iterrows():
            try:
                board = chess.Board(row["fen"])
                move = chess.Move.from_uci(row["move"])
                if move in board.legal_moves:
                    valid_indices.append(idx)
            except:
                continue

        self.df = self.df.loc[valid_indices].reset_index(drop=True)
        print(f"Загружено {len(self.df)} валидных позиций")

    def _board_to_tensor(self, fen: str) -> torch.Tensor:
        board = chess.Board(fen)
        tensor = torch.zeros(20, 8, 8, dtype=torch.float32)

        # Фигуры (плоскости 0-11)
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece:
                row, col = square // 8, square % 8
                piece_idx = self.piece_to_idx[piece.symbol()]
                tensor[piece_idx, row, col] = 1.0

        # Чей ход (плоскость 12)
        tensor[12] = 1.0 if board.turn else 0.0

        # Рокировки (плоскости 13-16)
        castling_rights = [
            board.has_kingside_castling_rights(chess.WHITE),
            board.has_queenside_castling_rights(chess.WHITE),
            board.has_kingside_castling_rights(chess.BLACK),
            board.has_queenside_castling_rights(chess.BLACK),
        ]
        for i, has_right in enumerate(castling_rights):
            if has_right:
                tensor[13 + i] = 1.0

        # Взятие на проходе (плоскость 17)
        if board.ep_square is not None:
            row, col = board.ep_square // 8, board.ep_square % 8
            tensor[17, row, col] = 1.0

        # Счетчик полуходов (плоскость 18)
        tensor[18] = board.halfmove_clock / 50.0

        # Номер хода (плоскость 19)
        tensor[19] = board.fullmove_number / 500.0

        return tensor

    def _move_to_dual_labels(self, move_uci: str, fen: str) -> tuple:
        board = chess.Board(fen)
        move = chess.Move.from_uci(move_uci)

        from_square = move.from_square

        # Для to_square учитываем превращения
        # ДИАПАЗОН ДОЛЖЕН БЫТЬ 0-68 (69 классов)
        # print(move.promotion)
        if move.promotion:
            # Кодируем превращения: 64-68
            to_square = 64 + (move.promotion - 1)  # 64, 65, 66, 67, 68
        else:
            to_square = move.to_square  # 0-63

        # Проверяем, что to_square в правильном диапазоне
        if to_square >= 69:
            raise ValueError(f"Некорректный to_square: {to_square} для хода {move_uci}")

        return from_square, to_square

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        fen = row["fen"]
        move = row["move"]

        board_tensor = self._board_to_tensor(fen)
        from_label, to_label = self._move_to_dual_labels(move, fen)

        return board_tensor, from_label, to_label

In [3]:
class ChessDataSplitter:
    def __init__(
        self,
        csv_file,
        train_ratio=0.7,
        val_ratio=0.15,
        test_ratio=0.15,
        random_state=42,
    ):
        print(train_ratio, val_ratio, test_ratio)
        assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, (
            "Сумма долей должна быть равна 1"
        )

        self.csv_file = csv_file
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.random_state = random_state

        # Загружаем полный dataset
        self.full_dataset = DualHeadChessDataset(csv_file)

    def split_data(self):
        """Разбивает данные на train/val/test"""
        # Получаем индексы
        n_total = len(self.full_dataset)
        indices = list(range(n_total))

        # Первое разбиение: отделяем test
        train_val_idx, test_idx = train_test_split(
            indices,
            test_size=self.test_ratio,
            random_state=self.random_state,
            shuffle=True,
        )

        # Второе разбиение: разделяем train и val
        train_idx, val_idx = train_test_split(
            train_val_idx,
            test_size=self.val_ratio / (self.train_ratio + self.val_ratio),
            random_state=self.random_state,
            shuffle=True,
        )

        # Создаем Subset datasets
        train_dataset = Subset(self.full_dataset, train_idx)
        val_dataset = Subset(self.full_dataset, val_idx)
        test_dataset = Subset(self.full_dataset, test_idx)

        print(f"Разбиение завершено:")
        print(
            f"Train: {len(train_dataset)} samples ({len(train_dataset) / n_total * 100:.1f}%)"
        )
        print(
            f"Val: {len(val_dataset)} samples ({len(val_dataset) / n_total * 100:.1f}%)"
        )
        print(
            f"Test: {len(test_dataset)} samples ({len(test_dataset) / n_total * 100:.1f}%)"
        )

        return train_dataset, val_dataset, test_dataset

In [4]:
class ResidualCNNWithAttention(nn.Module):
    def __init__(self, device="cpu", pretrained=True):
        super().__init__()
        self.device = device

        # Используем более легкую версию ResNet
        self.backbone = resnet50(pretrained=pretrained)
        
        # Модифицируем первый слой для 20 каналов
        original_conv1 = self.backbone.conv1
        self.backbone.conv1 = nn.Conv2d(
            20, 64,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            bias=original_conv1.bias is not None
        )
        
        # Копируем веса из предобученной модели
        if pretrained:
            with torch.no_grad():
                original_weights = original_conv1.weight.data
                mean_weights = original_weights.mean(dim=1, keepdim=True)
                new_weights = mean_weights.repeat(1, 20, 1, 1) / 3.0
                self.backbone.conv1.weight.data = new_weights
        
        # Замораживаем первые слои ResNet
        self._freeze_early_layers()
        
        # Упрощаем attention механизмы
        self.attention4 = SimplifiedAttention(2048)  # Только последний слой
        
        # Уменьшаем размерность фич
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))  # Уменьшаем с 8x8 до 4x4
        
        # Упрощаем головы
        self.from_fc1 = nn.Linear(2048 * 4 * 4, 256)  # Уменьшаем размерность
        self.from_output = nn.Linear(256, 64)
        
        self.to_conv1 = nn.Conv2d(2048 + 64, 256, 3, padding=1)  # Меньше каналов
        self.to_fc1 = nn.Linear(256 * 4 * 4, 128)
        self.to_output = nn.Linear(128, 69)
        
        # Заменяем BatchNorm на более быстрые альтернативы
        self.from_ln1 = nn.LayerNorm(256)
        self.to_ln1 = nn.LayerNorm(128)

    def _freeze_early_layers(self):
        """Замораживаем ранние слои ResNet"""
        # Замораживаем conv1, bn1, layer1, layer2
        for param in self.backbone.conv1.parameters():
            param.requires_grad = False
        for param in self.backbone.bn1.parameters():
            param.requires_grad = False
        for param in self.backbone.layer1.parameters():
            param.requires_grad = False
        for param in self.backbone.layer2.parameters():
            param.requires_grad = False

    def forward(self, board_tensor, from_square=None):
        batch_size = board_tensor.size(0)
        
        # Убедимся, что входной тензор на правильном устройстве
        board_tensor = board_tensor.to(self.device)
        
        # Проход через ResNet backbone (только незамороженные слои)
        x = self.backbone.conv1(board_tensor)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        
        # Пропускаем замороженные слои (проход без вычисления градиентов)
        with torch.no_grad():
            x = self.backbone.layer1(x)
            x = self.backbone.layer2(x)
        
        # Обучаемые слои
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        
        # Attention на последнем слое
        shared_features, att4 = self.attention4(x)
        
        # Приводим к размеру 4x4
        shared_features = self.adaptive_pool(shared_features)

        # Голова для исходной клетки
        from_flat = shared_features.reshape(shared_features.size(0), -1)
        from_hidden = F.relu(self.from_fc1(from_flat))
        from_hidden = self.from_ln1(from_hidden)
        from_logits = self.from_output(from_hidden)
        from_probs = F.log_softmax(from_logits, dim=1)

        # Если только предсказание from_square
        if from_square is None:
            return from_probs, None, [att4]  # Только один attention map

        # Голова для целевой клетки
        # Создаем one-hot кодирование для from_square на правильном устройстве
        from_onehot = torch.zeros(batch_size, 64, 4, 4, device=self.device)
        from_onehot[torch.arange(batch_size, device=self.device), from_square, :, :] = 1

        # Конкатенируем ResNet features с one-hot кодированием
        to_input = torch.cat([shared_features, from_onehot], dim=1)

        # Обработка через сверточные слои
        to_features = F.relu(self.to_conv1(to_input))
        to_flat = to_features.reshape(to_features.size(0), -1)
        to_hidden = F.relu(self.to_fc1(to_flat))
        to_hidden = self.to_ln1(to_hidden)
        to_logits = self.to_output(to_hidden)
        to_probs = F.log_softmax(to_logits, dim=1)

        return from_probs, to_probs, [att4]

    def to(self, device):
        """Переопределяем метод to для установки устройства"""
        self.device = device
        return super().to(device)

class SimplifiedAttention(nn.Module):
    """Упрощенный attention механизм"""
    def __init__(self, channels, reduction=8):  # Уменьшили reduction
        super().__init__()
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        ca = self.channel_attention(x)
        return x * ca, ca  # Убрали spatial attention для скорости

In [None]:
class ChessTrainer:
    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        test_loader,
        device="cpu",
        lr=0.001,
        weight_decay=0.01,
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device
        self.lr = lr
        self.weight_decay = weight_decay

        # Функции потерь для dual-head архитектуры
        self.from_criterion = nn.NLLLoss()  # Т.к. используем log_softmax
        self.to_criterion = nn.NLLLoss()

        # Оптимизатор
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )

        # Планировщик learning rate
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer, step_size=5, gamma=0.1
        )

        # История обучения
        self.history = {
            "train_loss": [],
            "val_loss": [],
            "train_from_acc": [],
            "val_from_acc": [],
            "train_to_acc": [],
            "val_to_acc": [],
            "learning_rate": [],
            "attention_maps": []  # Для сохранения примеров attention maps
        }

    def train_epoch(self):
        """Одна эпоха обучения для dual-head модели"""
        self.model.train()
        total_loss = 0
        total_from_correct = 0
        total_to_correct = 0
        total_samples = 0

        train_bar = tqdm(self.train_loader, desc="Training")

        for batch_idx, (boards, from_targets, to_targets) in enumerate(train_bar):
            try:
                # Перемещаем данные на device
                boards = boards.to(self.device, dtype=torch.float32)
                from_targets = from_targets.to(self.device, dtype=torch.long)
                to_targets = to_targets.to(self.device, dtype=torch.long)

                # Валидация целей
                assert from_targets.min() >= 0 and from_targets.max() < 64, \
                    f"Неверные from_targets: {from_targets.min()}-{from_targets.max()}"
                assert to_targets.min() >= 0 and to_targets.max() < 69, \
                    f"Неверные to_targets: {to_targets.min()}-{to_targets.max()}"

                # Обнуляем градиенты
                self.optimizer.zero_grad()

                # Прямой проход через dual-head модель - сначала только from_head
                from_probs, _, attention_weights = self.model(boards)

                # Вычисляем потери для from_head
                from_loss = self.from_criterion(from_probs, from_targets)

                # Прямой проход для to_head с истинными from_targets (teacher forcing)
                _, to_probs, _ = self.model(boards, from_targets)
                to_loss = self.to_criterion(to_probs, to_targets)
                
                # Комбинируем потери
                loss = from_loss + to_loss

                # Проверка на NaN
                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"Пропуск батча {batch_idx} из-за невалидного loss")
                    continue

                # Обратный проход
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                # Шаг оптимизатора
                self.optimizer.step()

                # Статистика
                total_loss += loss.item()
                batch_size = boards.size(0)
                total_samples += batch_size

                # Точность для from-square
                from_preds = torch.argmax(from_probs, dim=1)
                from_correct = (from_preds == from_targets).sum().item()
                total_from_correct += from_correct

                # Точность для to-square
                to_preds = torch.argmax(to_probs, dim=1)
                to_correct = (to_preds == to_targets).sum().item()
                total_to_correct += to_correct

                # Обновляем progress bar
                train_bar.set_postfix(
                    {
                        "Loss": f"{loss.item():.4f}",
                        "From Acc": f"{from_correct / batch_size:.3f}",
                        "To Acc": f"{to_correct / batch_size:.3f}",
                    }
                )

            except Exception as e:
                print(f"Ошибка в батче {batch_idx}: {e}")
                continue

        # Вычисляем средние метрики за эпоху
        avg_loss = total_loss / len(self.train_loader) if len(self.train_loader) > 0 else 0
        from_acc = total_from_correct / total_samples if total_samples > 0 else 0
        to_acc = total_to_correct / total_samples if total_samples > 0 else 0

        return avg_loss, from_acc, to_acc

    def validate(self):
        """Валидация dual-head модели"""
        self.model.eval()
        total_loss = 0
        total_from_correct = 0
        total_to_correct = 0
        total_samples = 0

        # Сохраняем пример attention maps для визуализации
        sample_attention_maps = []

        with torch.no_grad():
            for batch_idx, (boards, from_targets, to_targets) in enumerate(tqdm(
                self.val_loader, desc="Validation"
            )):
                try:
                    boards = boards.to(self.device, dtype=torch.float32)
                    from_targets = from_targets.to(self.device, dtype=torch.long)
                    to_targets = to_targets.to(self.device, dtype=torch.long)

                    # Прямой проход для from_head
                    from_probs, _, attention_weights = self.model(boards)

                    # Потери для from_head
                    from_loss = self.from_criterion(from_probs, from_targets)

                    # Прямой проход для to_head с истинными from_targets
                    _, to_probs, _ = self.model(boards, from_targets)
                    to_loss = self.to_criterion(to_probs, to_targets)
                    
                    loss = from_loss + to_loss

                    total_loss += loss.item()
                    batch_size = boards.size(0)
                    total_samples += batch_size

                    # Точность
                    from_preds = torch.argmax(from_probs, dim=1)
                    from_correct = (from_preds == from_targets).sum().item()
                    total_from_correct += from_correct

                    to_preds = torch.argmax(to_probs, dim=1)
                    to_correct = (to_preds == to_targets).sum().item()
                    total_to_correct += to_correct

                    # Сохраняем attention maps из первого батча
                    if batch_idx == 0 and len(sample_attention_maps) == 0:
                        sample_attention_maps = attention_weights

                except Exception as e:
                    print(f"Ошибка в валидационном батче {batch_idx}: {e}")
                    continue

        avg_loss = total_loss / len(self.val_loader) if len(self.val_loader) > 0 else 0
        from_acc = total_from_correct / total_samples if total_samples > 0 else 0
        to_acc = total_to_correct / total_samples if total_samples > 0 else 0

        return avg_loss, from_acc, to_acc, sample_attention_maps

    def train(self, num_epochs=50, early_stopping_patience=10):
        """Полный цикл обучения для dual-head модели"""
        best_val_loss = float("inf")
        patience_counter = 0

        print("Начало обучения ResidualCNNWithAttention модели...")
        print(f"Используется устройство: {self.device}")
        print(f"Размер тренировочного набора: {len(self.train_loader.dataset)}")
        print(f"Размер валидационного набора: {len(self.val_loader.dataset)}")
        print(f"Architecture: ResNet-50 backbone + Attention + Dual Heads")

        for epoch in range(num_epochs):
            print(f"\nЭпоха {epoch + 1}/{num_epochs}")
            print("-" * 50)

            # Обучение
            train_loss, train_from_acc, train_to_acc = self.train_epoch()

            # Валидация
            val_loss, val_from_acc, val_to_acc, attention_maps = self.validate()

            # Обновление learning rate
            self.scheduler.step(val_loss)
            current_lr = self.optimizer.param_groups[0]["lr"]

            # Сохраняем историю
            self.history["train_loss"].append(train_loss)
            self.history["val_loss"].append(val_loss)
            self.history["train_from_acc"].append(train_from_acc)
            self.history["val_from_acc"].append(val_from_acc)
            self.history["train_to_acc"].append(train_to_acc)
            self.history["val_to_acc"].append(val_to_acc)
            self.history["learning_rate"].append(current_lr)
            self.history["attention_maps"].append(attention_maps)

            # Выводим результаты
            print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
            print(f"Train From Acc: {train_from_acc:.4f} | Val From Acc: {val_from_acc:.4f}")
            print(f"Train To Acc: {train_to_acc:.4f} | Val To Acc: {val_to_acc:.4f}")
            print(f"Learning Rate: {current_lr:.6f}")

            # Сохранение лучшей модели
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self.save_model("best_chess_resnet_attention_model.pth")
                print("Сохранена лучшая модель!")
            else:
                patience_counter += 1
                print(f"Early stopping: {patience_counter}/{early_stopping_patience}")

            # Early stopping
            if patience_counter >= early_stopping_patience:
                print("Ранняя остановка!")
                break

        print("\nОбучение завершено!")
        self.plot_training_history()

    def evaluate(self):
        """Финальная оценка на тестовом наборе"""
        print("\nОценка на тестовом наборе...")
        self.model.eval()

        total_from_correct = 0
        total_to_correct = 0
        total_full_correct = 0
        total_samples = 0

        with torch.no_grad():
            for boards, from_targets, to_targets in tqdm(
                self.test_loader, desc="Testing"
            ):
                boards = boards.to(self.device, dtype=torch.float32)
                from_targets = from_targets.to(self.device, dtype=torch.long)
                to_targets = to_targets.to(self.device, dtype=torch.long)

                # Предсказание from-square
                from_probs, _, _ = self.model(boards)
                from_preds = torch.argmax(from_probs, dim=1)

                # Предсказание to-square с истинными from-square для оценки качества модели
                _, to_probs, _ = self.model(boards, from_targets)
                to_preds = torch.argmax(to_probs, dim=1)

                # Подсчет правильных предсказаний
                batch_size = boards.size(0)
                total_samples += batch_size
                
                from_correct = (from_preds == from_targets).sum().item()
                to_correct = (to_preds == to_targets).sum().item()
                full_correct = ((from_preds == from_targets) & (to_preds == to_targets)).sum().item()

                total_from_correct += from_correct
                total_to_correct += to_correct
                total_full_correct += full_correct

        from_acc = total_from_correct / total_samples if total_samples > 0 else 0
        to_acc = total_to_correct / total_samples if total_samples > 0 else 0
        full_acc = total_full_correct / total_samples if total_samples > 0 else 0

        print(f"\nРезультаты на тестовом наборе:")
        print(f"From-square Accuracy: {from_acc:.4f}")
        print(f"To-square Accuracy: {to_acc:.4f}")
        print(f"Full Move Accuracy: {full_acc:.4f}")

        return from_acc, to_acc, full_acc

    def predict_single_move(self, board_tensor):
        """Предсказание хода для одного тензора доски"""
        self.model.eval()
        with torch.no_grad():
            # Предсказываем from_square
            from_probs, _, _ = self.model(board_tensor.unsqueeze(0).to(self.device))
            from_square = torch.argmax(from_probs, dim=1).item()
            
            # Предсказываем to_square на основе предсказанного from_square
            _, to_probs, attention_weights = self.model(
                board_tensor.unsqueeze(0).to(self.device), 
                torch.tensor([from_square], device=self.device)
            )
            to_square = torch.argmax(to_probs, dim=1).item()
            
            return from_square, to_square, attention_weights

    def analyze_attention(self, board_tensor):
        """Анализ attention maps для заданной позиции"""
        self.model.eval()
        with torch.no_grad():
            board_input = board_tensor.unsqueeze(0).to(self.device)
            from_probs, _, attention_weights = self.model(board_input)
            from_square = torch.argmax(from_probs, dim=1).item()
            
            # Получаем feature maps для анализа
            if hasattr(self.model, 'get_feature_maps'):
                feature_maps = self.model.get_feature_maps(board_input)
            else:
                feature_maps = None
            
            return {
                'from_square_pred': from_square,
                'attention_weights': attention_weights,
                'feature_maps': feature_maps
            }

    def save_model(self, path):
        """Сохранение модели"""
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "scheduler_state_dict": self.scheduler.state_dict(),
                "history": self.history,
                "epoch": len(self.history["train_loss"]),
            },
            path,
        )

    def load_model(self, path):
        """Загрузка модели"""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.history = checkpoint["history"]

    def plot_training_history(self):
        """Визуализация истории обучения"""
        import matplotlib.pyplot as plt

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

        # Потери
        epochs = range(1, len(self.history["train_loss"]) + 1)
        ax1.plot(epochs, self.history["train_loss"], label="Train Loss", linewidth=2)
        ax1.plot(epochs, self.history["val_loss"], label="Val Loss", linewidth=2)
        ax1.set_title("Потери", fontsize=14)
        ax1.set_xlabel("Эпоха")
        ax1.set_ylabel("Loss")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # From-square точность
        ax2.plot(epochs, self.history["train_from_acc"], label="Train From Acc", linewidth=2)
        ax2.plot(epochs, self.history["val_from_acc"], label="Val From Acc", linewidth=2)
        ax2.set_title("From-square Точность", fontsize=14)
        ax2.set_xlabel("Эпоха")
        ax2.set_ylabel("Accuracy")
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # To-square точность
        ax3.plot(epochs, self.history["train_to_acc"], label="Train To Acc", linewidth=2)
        ax3.plot(epochs, self.history["val_to_acc"], label="Val To Acc", linewidth=2)
        ax3.set_title("To-square Точность", fontsize=14)
        ax3.set_xlabel("Эпоха")
        ax3.set_ylabel("Accuracy")
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Learning rate
        ax4.plot(epochs, self.history["learning_rate"], linewidth=2, color='purple')
        ax4.set_title("Learning Rate", fontsize=14)
        ax4.set_xlabel("Эпоха")
        ax4.set_ylabel("LR")
        ax4.grid(True, alpha=0.3)
        ax4.set_yscale('log')

        plt.tight_layout()
        plt.savefig("training_history_resnet_attention.png", dpi=300, bbox_inches="tight")
        plt.show()

    def plot_attention_maps(self, board_tensor, epoch=-1):
        """Визуализация attention maps для примера доски"""
        if len(self.history["attention_maps"]) == 0:
            print("No attention maps available")
        
        attention_maps = self.history["attention_maps"][epoch]
        if isinstance(attention_maps, list) and len(attention_maps) > 0:
            # Берем последний уровень attention (самый высокоуровневый)
            attention_map = attention_maps[-1][0].cpu().numpy()  # [1, H, W]
        else:
            attention_map = attention_maps[0].cpu().numpy() if attention_maps is not None else None
        
        if attention_map is None:
            print("No attention map available")
            return
            
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        
        # Исходная доска
        board_vis = board_tensor[:12].sum(dim=0).cpu().numpy()  # Сумма по фигурам
        ax1.imshow(board_vis, cmap='RdYlBu_r')
        ax1.set_title("Шахматная доска")
        ax1.grid(True, alpha=0.3)
        
        # Attention map
        im = ax2.imshow(attention_map.squeeze(), cmap='hot', interpolation='nearest')
        ax2.set_title("Attention Map")
        ax2.grid(True, alpha=0.3)
        plt.colorbar(im, ax=ax2)
        
        plt.tight_layout()
        plt.savefig("attention_visualization.png", dpi=300, bbox_inches="tight")
        plt.show()

In [6]:
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Используется устройство: {device}")

    torch.cuda.empty_cache()

    # Загрузка данных
    print("Загрузка данных...")
    splitter = ChessDataSplitter(
        test_ratio=0.01,
        val_ratio=0.2 * (1 - 0.01),
        train_ratio=0.8 * (1 - 0.01),
        csv_file="fens_training_set.csv",
        random_state=seed
    )
    train_dataset, val_dataset, test_dataset = splitter.split_data()
    train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

    # Модель
    print("Инициализация модели...")
    model = ResidualCNNWithAttention()

    # Тренер
    trainer = ChessTrainer(model, train_loader, val_loader, test_loader, device)

    # Обучение
    trainer.train(num_epochs=20, early_stopping_patience=10)

    # Валидация
    trainer.validate()

    # # Сохранение финальной модели
    # trainer.save_model("final_chess_model.pth")
    # print("Модель сохранена!")

In [7]:
if __name__ == "__main__":
    main()

Используется устройство: cuda
Загрузка данных...
0.792 0.198 0.01
Загружено 268549 валидных позиций
Разбиение завершено:
Train: 212690 samples (79.2%)
Val: 53173 samples (19.8%)
Test: 2686 samples (1.0%)
Инициализация модели...




Начало обучения ResidualCNNWithAttention модели...
Используется устройство: cuda
Размер тренировочного набора: 212690
Размер валидационного набора: 53173
Architecture: ResNet-50 backbone + Attention + Dual Heads

Эпоха 1/20
--------------------------------------------------


Training: 100%|██████████| 208/208 [08:27<00:00,  2.44s/it, Loss=5.8964, From Acc=0.137, To Acc=0.312]
Validation: 100%|██████████| 1662/1662 [02:02<00:00, 13.59it/s]


Train Loss: 6.4809 | Val Loss: 6.3799
Train From Acc: 0.1211 | Val From Acc: 0.1232
Train To Acc: 0.2346 | Val To Acc: 0.2517
Learning Rate: 0.000100
Сохранена лучшая модель!

Эпоха 2/20
--------------------------------------------------


Training: 100%|██████████| 208/208 [08:02<00:00,  2.32s/it, Loss=5.7845, From Acc=0.154, To Acc=0.313]
Validation: 100%|██████████| 1662/1662 [01:57<00:00, 14.20it/s]


Train Loss: 5.7648 | Val Loss: 5.7616
Train From Acc: 0.1654 | Val From Acc: 0.1659
Train To Acc: 0.3137 | Val To Acc: 0.3141
Learning Rate: 0.000100
Сохранена лучшая модель!

Эпоха 3/20
--------------------------------------------------


Training: 100%|██████████| 208/208 [07:33<00:00,  2.18s/it, Loss=5.5682, From Acc=0.194, To Acc=0.338]
Validation: 100%|██████████| 1662/1662 [01:54<00:00, 14.54it/s]


Train Loss: 5.6567 | Val Loss: 5.7010
Train From Acc: 0.1780 | Val From Acc: 0.1721
Train To Acc: 0.3236 | Val To Acc: 0.3182
Learning Rate: 0.000100
Сохранена лучшая модель!

Эпоха 4/20
--------------------------------------------------


Training: 100%|██████████| 208/208 [07:37<00:00,  2.20s/it, Loss=5.6005, From Acc=0.191, To Acc=0.328]
Validation: 100%|██████████| 1662/1662 [01:54<00:00, 14.54it/s]


Train Loss: 5.5644 | Val Loss: 5.6502
Train From Acc: 0.1901 | Val From Acc: 0.1811
Train To Acc: 0.3319 | Val To Acc: 0.3244
Learning Rate: 0.000100
Сохранена лучшая модель!

Эпоха 5/20
--------------------------------------------------


Training: 100%|██████████| 208/208 [07:25<00:00,  2.14s/it, Loss=5.4656, From Acc=0.219, To Acc=0.319]
Validation: 100%|██████████| 1662/1662 [01:50<00:00, 14.99it/s]


Train Loss: 5.4613 | Val Loss: 5.6158
Train From Acc: 0.2033 | Val From Acc: 0.1843
Train To Acc: 0.3419 | Val To Acc: 0.3250
Learning Rate: 0.000100
Сохранена лучшая модель!

Эпоха 6/20
--------------------------------------------------


Training:  69%|██████▉   | 144/208 [05:11<02:18,  2.16s/it, Loss=5.2654, From Acc=0.219, To Acc=0.375]


KeyboardInterrupt: 