In [1]:
import os
import time
from collections import Counter, OrderedDict
from typing import List, Tuple

import chess
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.models import resnet50
from tqdm import tqdm

seed = 123456
# Seed every RNG the pipeline draws from: numpy's legacy global state and torch
# (torch.manual_seed seeds the CPU generator and all CUDA devices). DataLoader
# shuffling, layer initialization and dropout all use torch's generator, so seeding
# numpy alone does not make training reproducible.
np.random.seed(seed=seed)
torch.manual_seed(seed)

# Разработка модели


In [2]:
class DualHeadChessDataset(Dataset):
    """Chess positions from a CSV (``fen`` and ``move`` columns) encoded for a dual-head model.

    Each item is a tuple ``(board_tensor, from_label, to_label)``:

    * ``board_tensor`` -- float32 tensor of shape (20, 8, 8), see ``_board_to_tensor``;
    * ``from_label`` -- origin square index 0..63 (python-chess square numbering);
    * ``to_label`` -- destination square 0..63, or 64..68 for promotions.

    Rows whose FEN/UCI cannot be parsed, or whose move is illegal in the given
    position, are dropped when the dataset is built.
    """

    def __init__(self, csv_file: str):
        self.df = pd.read_csv(csv_file)

        # Piece symbol -> plane index: white P,N,B,R,Q,K -> 0..5, black p,n,b,r,q,k -> 6..11.
        self.piece_to_idx = {
            "P": 0,
            "N": 1,
            "B": 2,
            "R": 3,
            "Q": 4,
            "K": 5,
            "p": 6,
            "n": 7,
            "b": 8,
            "r": 9,
            "q": 10,
            "k": 11,
        }

        self._validate_data()

    def _validate_data(self):
        """Keep only rows whose FEN parses and whose UCI move is legal in that position.

        Raises:
            ValueError: if the CSV lacks the ``fen`` or ``move`` column (otherwise
                every row would be silently discarded).
        """
        missing = {"fen", "move"} - set(self.df.columns)
        if missing:
            raise ValueError(f"В CSV отсутствуют обязательные столбцы: {sorted(missing)}")

        valid_indices = []
        # zip over columns is much faster than DataFrame.iterrows on large files.
        for idx, fen, move_uci in zip(self.df.index, self.df["fen"], self.df["move"]):
            try:
                board = chess.Board(fen)
                move = chess.Move.from_uci(move_uci)
                if move in board.legal_moves:
                    valid_indices.append(idx)
            except Exception:
                # Malformed FEN/UCI (including NaN cells) -> skip the row. Catching
                # Exception instead of a bare except lets KeyboardInterrupt through,
                # so the long scan can still be interrupted.
                continue

        self.df = self.df.loc[valid_indices].reset_index(drop=True)
        print(f"Загружено {len(self.df)} валидных позиций")

    def _board_to_tensor(self, fen: str) -> torch.Tensor:
        """Encode a FEN as a (20, 8, 8) float32 tensor indexed ``[plane, square // 8, square % 8]``.

        Planes:
            0-11  one-hot piece placement (see ``piece_to_idx``);
            12    side to move (all ones when ``board.turn`` is truthy, i.e. white);
            13-16 castling rights: white king-side, white queen-side,
                  black king-side, black queen-side;
            17    en-passant target square, if any;
            18    half-move clock / 50;
            19    full-move number / 500.
        """
        board = chess.Board(fen)
        tensor = torch.zeros(20, 8, 8, dtype=torch.float32)

        # Pieces (planes 0-11)
        for square, piece in board.piece_map().items():
            row, col = divmod(square, 8)
            tensor[self.piece_to_idx[piece.symbol()], row, col] = 1.0

        # Side to move (plane 12)
        tensor[12] = 1.0 if board.turn else 0.0

        # Castling rights (planes 13-16)
        castling_rights = (
            board.has_kingside_castling_rights(chess.WHITE),
            board.has_queenside_castling_rights(chess.WHITE),
            board.has_kingside_castling_rights(chess.BLACK),
            board.has_queenside_castling_rights(chess.BLACK),
        )
        for offset, has_right in enumerate(castling_rights):
            if has_right:
                tensor[13 + offset] = 1.0

        # En-passant square (plane 17)
        if board.ep_square is not None:
            row, col = divmod(board.ep_square, 8)
            tensor[17, row, col] = 1.0

        # Move counters (planes 18-19). The scalings are ad hoc; values may exceed 1.
        tensor[18] = board.halfmove_clock / 50.0
        tensor[19] = board.fullmove_number / 500.0

        return tensor

    def _move_to_dual_labels(self, move_uci: str, fen: str) -> tuple:
        """Convert a UCI move into ``(from_label, to_label)``.

        Args:
            move_uci: move in UCI notation, e.g. ``"e2e4"`` or ``"e7e8q"``.
            fen: position the move is played from. Not needed for the encoding
                (rows are validated in ``_validate_data``); kept for interface
                compatibility.

        Returns:
            ``from_label`` in 0..63. ``to_label`` is the destination square (0..63)
            for ordinary moves, or ``64 + (promotion_piece_type - 1)`` for
            promotions. That range is 65..68 for knight..queen; 64 is never produced.
            The label space has 69 classes in total.

        Raises:
            ValueError: if the computed ``to_label`` falls outside 0..68.

        NOTE(review): promotion labels drop the destination square, so b7b8=Q,
        b7xa8=Q and b7xc8=Q all map to the same ``to_label``. A predicted
        promotion therefore cannot be decoded uniquely when a capture-promotion
        is available.
        """
        move = chess.Move.from_uci(move_uci)

        if move.promotion:
            to_square = 64 + (move.promotion - 1)
        else:
            to_square = move.to_square

        if to_square >= 69:
            raise ValueError(f"Некорректный to_square: {to_square} для хода {move_uci}")

        return move.from_square, to_square

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(board_tensor, from_label, to_label)`` for row ``idx``."""
        row = self.df.iloc[idx]
        fen = row["fen"]
        move = row["move"]

        board_tensor = self._board_to_tensor(fen)
        from_label, to_label = self._move_to_dual_labels(move, fen)

        return board_tensor, from_label, to_label

In [3]:
class ChessDataSplitter:
    """Load a FEN/move CSV once and split it into train / val / test ``Subset`` views.

    Args:
        csv_file: path handed to ``DualHeadChessDataset``.
        train_ratio, val_ratio, test_ratio: fractions of the validated dataset;
            they must sum to 1 (checked with a 1e-6 tolerance).
        random_state: seed forwarded to both ``train_test_split`` calls.
    """

    def __init__(
        self,
        csv_file,
        train_ratio=0.7,
        val_ratio=0.15,
        test_ratio=0.15,
        random_state=42,
    ):
        print(train_ratio, val_ratio, test_ratio)
        ratio_total = train_ratio + val_ratio + test_ratio
        assert abs(ratio_total - 1.0) < 1e-6, "Сумма долей должна быть равна 1"

        self.csv_file = csv_file
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.random_state = random_state

        # Parse and validate the whole CSV up front; splits are index views into it.
        self.full_dataset = DualHeadChessDataset(csv_file)

    def split_data(self):
        """Split the dataset into train/val/test and return three ``Subset`` objects.

        The test set is carved out first. The remainder is then split so that the
        train/val proportions match ``train_ratio : val_ratio``.
        """
        total = len(self.full_dataset)
        all_indices = list(range(total))
        split_options = {"random_state": self.random_state, "shuffle": True}

        # Stage 1: hold out the test portion.
        train_val_idx, test_idx = train_test_split(
            all_indices, test_size=self.test_ratio, **split_options
        )

        # Stage 2: rescale val's share relative to what is left after stage 1.
        val_share = self.val_ratio / (self.train_ratio + self.val_ratio)
        train_idx, val_idx = train_test_split(
            train_val_idx, test_size=val_share, **split_options
        )

        train_dataset, val_dataset, test_dataset = (
            Subset(self.full_dataset, part) for part in (train_idx, val_idx, test_idx)
        )

        print("Разбиение завершено:")
        for label, part in (
            ("Train", train_dataset),
            ("Val", val_dataset),
            ("Test", test_dataset),
        ):
            print(f"{label}: {len(part)} samples ({len(part) / total * 100:.1f}%)")

        return train_dataset, val_dataset, test_dataset

In [4]:
class ResidualCNNWithAttention(nn.Module):
    """ResNet-50 backbone with channel+spatial attention and two move-prediction heads.

    The "from" head predicts the origin square (64 classes). The "to" head is
    conditioned on an origin square (ground truth during training, the prediction
    at inference) and predicts the destination (64 squares + 5 promotion slots = 69).
    Both heads return log-probabilities (``log_softmax``).

    NOTE(review): the board is only 8x8. With ResNet-50's stride-2 stem conv, its
    stride-2 max-pool and the stride-2 first block of layer2, the feature map is
    already 1x1 after layer2. ``adaptive_pool`` then just replicates that single
    vector onto an 8x8 grid, so the heads receive no real spatial layout. Consider
    a stride-1 stem / no max-pool / dilated later stages if square-level
    localization matters.
    """

    def __init__(self, device="cpu", pretrained=True):
        super().__init__()
        # Kept for backward compatibility. forward() builds its tensors on the
        # device of its input, so moving the model later with .to() is safe.
        self.device = device

        # ResNet-50 backbone. ``pretrained`` is torchvision's legacy flag
        # (superseded by ``weights=`` in torchvision >= 0.13).
        resnet = resnet50(pretrained=pretrained)

        # 1. Replace the stem convolution so it accepts the 20 board planes.
        original_first_conv = resnet.conv1
        resnet.conv1 = nn.Conv2d(
            20,
            64,
            kernel_size=original_first_conv.kernel_size,
            stride=original_first_conv.stride,
            padding=original_first_conv.padding,
            bias=original_first_conv.bias is not None,
        )
        with torch.no_grad():
            if pretrained:
                # Average the pretrained RGB filters and tile them over the 20 planes.
                # NOTE(review): the extra /3.0 is kept from the original; it is not a
                # sum-preserving factor (that would be *3/20) -- confirm intent.
                mean_weights = original_first_conv.weight.mean(dim=1, keepdim=True)
                resnet.conv1.weight.copy_(mean_weights.repeat(1, 20, 1, 1) / 3.0)
            else:
                nn.init.kaiming_normal_(
                    resnet.conv1.weight, mode="fan_out", nonlinearity="relu"
                )

        # 2. Drop avgpool and fc. Build the Sequential from an OrderedDict so the
        # children keep their names (conv1, bn1, relu, maxpool, layer1..layer4):
        # the code below accesses them by attribute, which a positional
        # nn.Sequential (children named "0", "1", ...) does not support.
        self.backbone = nn.Sequential(OrderedDict(list(resnet.named_children())[:-2]))

        # 3. CBAM-style attention (channel gate followed by spatial gate).
        class SpatialAttention(nn.Module):
            def __init__(self, channels, reduction=16):
                super().__init__()
                # Channel attention: global average pool -> bottleneck MLP -> sigmoid.
                self.channel_attention = nn.Sequential(
                    nn.AdaptiveAvgPool2d(1),
                    nn.Conv2d(channels, channels // reduction, 1, bias=False),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(channels // reduction, channels, 1, bias=False),
                    nn.Sigmoid(),
                )
                # Spatial attention over [mean, max] across channels.
                self.spatial_attention = nn.Sequential(
                    nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False), nn.Sigmoid()
                )

            def forward(self, x):
                """Return (attended features, spatial attention map of shape (B, 1, H, W))."""
                x_ca = x * self.channel_attention(x)

                avg_out = torch.mean(x_ca, dim=1, keepdim=True)
                max_out, _ = torch.max(x_ca, dim=1, keepdim=True)
                sa = self.spatial_attention(torch.cat([avg_out, max_out], dim=1))
                return x_ca * sa, sa

        # Attention after each ResNet stage.
        self.attention1 = SpatialAttention(256)  # after layer1
        self.attention2 = SpatialAttention(512)  # after layer2
        self.attention3 = SpatialAttention(1024)  # after layer3
        self.attention4 = SpatialAttention(2048)  # after layer4

        # Resize the final feature map to 8x8 (see the class NOTE).
        self.adaptive_pool = nn.AdaptiveAvgPool2d((8, 8))

        # Normalization / regularization for the heads.
        self.from_bn1 = nn.BatchNorm1d(512)
        self.from_dropout = nn.Dropout(0.3)

        # to_bn1 normalizes the 512-channel output of to_conv1.
        self.to_bn1 = nn.BatchNorm2d(512)
        self.to_bn2 = nn.BatchNorm1d(512)
        self.to_dropout = nn.Dropout(0.3)

        # "From" head.
        self.from_conv_reduce = nn.Conv2d(2048, 256, 1)
        self.from_bn_conv = nn.BatchNorm2d(256)
        self.from_fc1 = nn.Linear(256 * 8 * 8, 512)
        self.from_output = nn.Linear(512, 64)

        # "To" head: backbone features + 64 one-hot planes for the origin square.
        self.to_conv1 = nn.Conv2d(2048 + 64, 512, 3, padding=1)
        self.to_conv2 = nn.Conv2d(512, 256, 3, padding=1)
        self.to_fc1 = nn.Linear(256 * 8 * 8, 512)
        self.to_output = nn.Linear(512, 69)  # 64 squares + 5 promotion slots

        # Not used by forward(); kept so the state_dict keys stay unchanged.
        self.feature_fusion_conv = nn.Conv2d(2048, 512, 1)
        self.feature_fusion_bn = nn.BatchNorm2d(512)

        self._initialize_weights()

        self.to(device)

    def _initialize_weights(self):
        """Initialize the layers added on top of the backbone.

        Backbone modules are skipped. torchvision already initializes them, and
        re-initializing would discard the pretrained ImageNet weights (including
        the averaged conv1 filters copied in ``__init__``).
        """
        backbone_ids = {id(m) for m in self.backbone.modules()}
        for m in self.modules():
            if id(m) in backbone_ids:
                continue
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _run_backbone(self, board_tensor):
        """Run the ResNet stem and its four stages, applying attention after each stage.

        Returns:
            ``(raw_outputs, attended, attention_maps)``:
            * ``raw_outputs`` maps "layer1".."layer4" to the stage outputs before attention;
            * ``attended`` is the layer-4 output after attention;
            * ``attention_maps`` lists the four spatial attention maps.
        """
        bb = self.backbone
        x = bb.maxpool(bb.relu(bb.bn1(bb.conv1(board_tensor))))

        raw_outputs = {}
        attention_maps = []
        stages = (
            ("layer1", bb.layer1, self.attention1),
            ("layer2", bb.layer2, self.attention2),
            ("layer3", bb.layer3, self.attention3),
            ("layer4", bb.layer4, self.attention4),
        )
        for name, stage, attention in stages:
            raw = stage(x)
            x, att = attention(raw)
            raw_outputs[name] = raw
            attention_maps.append(att)
        return raw_outputs, x, attention_maps

    def forward(self, board_tensor, from_square=None):
        """Predict move log-probabilities.

        Args:
            board_tensor: board planes with 20 channels (see the dataset encoding).
            from_square: optional integer tensor of origin squares (0..63), one per
                sample. When given, the "to" head is evaluated, conditioned on it.

        Returns:
            ``(from_log_probs (B, 64), to_log_probs (B, 69) or None, attention_maps)``.
        """
        batch_size = board_tensor.size(0)
        _, attended, attention_maps = self._run_backbone(board_tensor)
        shared_features = self.adaptive_pool(attended)

        # "From" head
        from_features = F.relu(self.from_bn_conv(self.from_conv_reduce(shared_features)))
        from_hidden = F.relu(self.from_fc1(from_features.flatten(1)))
        from_hidden = self.from_dropout(self.from_bn1(from_hidden))
        from_probs = F.log_softmax(self.from_output(from_hidden), dim=1)

        if from_square is None:
            return from_probs, None, attention_maps

        # "To" head. Build the one-hot on the features' device and dtype, so it works
        # after the model is moved with .to() (self.device may be stale).
        feature_device = shared_features.device
        from_square = torch.as_tensor(from_square, dtype=torch.long, device=feature_device)
        from_onehot = shared_features.new_zeros(batch_size, 64, 8, 8)
        from_onehot[torch.arange(batch_size, device=feature_device), from_square] = 1

        to_input = torch.cat([shared_features, from_onehot], dim=1)
        to_features = F.relu(self.to_bn1(self.to_conv1(to_input)))
        to_features = F.relu(self.to_conv2(to_features))

        to_hidden = F.relu(self.to_fc1(to_features.flatten(1)))
        to_hidden = self.to_dropout(self.to_bn2(to_hidden))
        to_probs = F.log_softmax(self.to_output(to_hidden), dim=1)

        return from_probs, to_probs, attention_maps

    def get_attention_maps(self, board_tensor, layer_idx=3):
        """Return the attention map of stage ``layer_idx`` as a squeezed numpy array.

        Puts the model in eval mode and leaves it there.
        """
        self.eval()
        with torch.no_grad():
            _, _, attention_weights = self.forward(board_tensor)
            return attention_weights[layer_idx].squeeze().cpu().numpy()

    def predict_full_move(self, board_tensor):
        """Greedy move prediction: argmax origin square, then argmax destination given it.

        Puts the model in eval mode and leaves it there.
        """
        self.eval()
        with torch.no_grad():
            from_probs, _, attention_weights = self.forward(board_tensor)
            from_square_pred = torch.argmax(from_probs, dim=1)

            _, to_probs, _ = self.forward(board_tensor, from_square_pred)
            to_square_pred = torch.argmax(to_probs, dim=1)

            return from_square_pred, to_square_pred, attention_weights

    def get_feature_maps(self, board_tensor):
        """Return backbone stage outputs (pre-attention), pooled final features and attention maps.

        Puts the model in eval mode and leaves it there.
        """
        self.eval()
        with torch.no_grad():
            raw_outputs, attended, attention_maps = self._run_backbone(board_tensor)
            return {
                "layer1": raw_outputs["layer1"],
                "layer2": raw_outputs["layer2"],
                "layer3": raw_outputs["layer3"],
                "layer4": raw_outputs["layer4"],
                "final_features": self.adaptive_pool(attended),
                "attention_maps": attention_maps,
            }

    def freeze_backbone(self):
        """Freeze the ResNet backbone for transfer learning; attention modules stay trainable."""
        for param in self.backbone.parameters():
            param.requires_grad = False

        for module in (self.attention1, self.attention2, self.attention3, self.attention4):
            for param in module.parameters():
                param.requires_grad = True

    def unfreeze_backbone(self):
        """Make every parameter trainable again."""
        for param in self.parameters():
            param.requires_grad = True

In [5]:
class ChessTrainer:
    """Train / validate / test loop for a dual-head (from-square, to-square) move model.

    The model must return ``(from_log_probs, to_log_probs, attention_maps)`` and
    accept an optional origin-square tensor that conditions the "to" head. Both
    losses are ``NLLLoss`` because the heads emit log-softmax outputs.
    """

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        test_loader,
        device="cpu",
        lr=0.001,  # lowered learning rate for the ResNet backbone
        weight_decay=0.01,
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device
        self.lr = lr
        self.weight_decay = weight_decay

        # NLLLoss because the model already applies log_softmax.
        self.from_criterion = nn.NLLLoss()
        self.to_criterion = nn.NLLLoss()

        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )

        # ``verbose`` is deprecated for LR schedulers in recent PyTorch; LR reductions
        # are reported by train() instead.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode="min", factor=0.5, patience=5
        )

        self.history = {
            "train_loss": [],
            "val_loss": [],
            "train_from_acc": [],
            "val_from_acc": [],
            "train_to_acc": [],
            "val_to_acc": [],
            "learning_rate": [],
            "attention_maps": [],  # one validation batch of attention maps per epoch (CPU)
        }

    def train_epoch(self):
        """Run one training epoch with teacher forcing on the "to" head.

        Returns:
            ``(mean batch loss, from-square accuracy, to-square accuracy)``.
        """
        self.model.train()
        total_loss = 0
        total_from_correct = 0
        total_to_correct = 0
        total_samples = 0

        train_bar = tqdm(self.train_loader, desc="Training")

        for boards, from_targets, to_targets in train_bar:
            boards = boards.to(self.device)
            from_targets = from_targets.to(self.device)
            to_targets = to_targets.to(self.device)

            self.optimizer.zero_grad()

            # The "to" head is conditioned on the ground-truth origin square.
            from_probs, to_probs, _ = self.model(boards, from_targets)

            from_loss = self.from_criterion(from_probs, from_targets)
            to_loss = self.to_criterion(to_probs, to_targets)
            loss = from_loss + to_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            total_loss += loss.item()
            batch_size = boards.size(0)
            total_samples += batch_size

            from_correct = (torch.argmax(from_probs, dim=1) == from_targets).sum().item()
            to_correct = (torch.argmax(to_probs, dim=1) == to_targets).sum().item()
            total_from_correct += from_correct
            total_to_correct += to_correct

            train_bar.set_postfix(
                {
                    "Loss": f"{loss.item():.4f}",
                    "From Acc": f"{from_correct / batch_size:.3f}",
                    "To Acc": f"{to_correct / batch_size:.3f}",
                }
            )

        avg_loss = total_loss / len(self.train_loader)
        from_acc = total_from_correct / total_samples
        to_acc = total_to_correct / total_samples

        return avg_loss, from_acc, to_acc

    def validate(self):
        """Evaluate on the validation loader (the "to" head is conditioned on the true origin).

        Returns:
            ``(mean batch loss, from acc, to acc, attention maps of the first batch)``;
            the attention maps are moved to CPU so that history and checkpoints do not
            hold device memory.
        """
        self.model.eval()
        total_loss = 0
        total_from_correct = 0
        total_to_correct = 0
        total_samples = 0
        sample_attention_maps = []

        with torch.no_grad():
            for batch_idx, (boards, from_targets, to_targets) in enumerate(
                tqdm(self.val_loader, desc="Validation")
            ):
                boards = boards.to(self.device)
                from_targets = from_targets.to(self.device)
                to_targets = to_targets.to(self.device)

                from_probs, to_probs, attention_weights = self.model(boards, from_targets)

                from_loss = self.from_criterion(from_probs, from_targets)
                to_loss = self.to_criterion(to_probs, to_targets)
                loss = from_loss + to_loss

                total_loss += loss.item()
                total_samples += boards.size(0)

                total_from_correct += (
                    (torch.argmax(from_probs, dim=1) == from_targets).sum().item()
                )
                total_to_correct += (
                    (torch.argmax(to_probs, dim=1) == to_targets).sum().item()
                )

                if batch_idx == 0:
                    sample_attention_maps = [att.detach().cpu() for att in attention_weights]

        avg_loss = total_loss / len(self.val_loader)
        from_acc = total_from_correct / total_samples
        to_acc = total_to_correct / total_samples

        return avg_loss, from_acc, to_acc, sample_attention_maps

    def train(self, num_epochs=50, early_stopping_patience=10):
        """Full training loop with LR-on-plateau, best-checkpoint saving and early stopping."""
        best_val_loss = float("inf")
        patience_counter = 0

        print("Начало обучения ResidualCNNWithAttention модели...")
        print(f"Используется устройство: {self.device}")
        print(f"Размер тренировочного набора: {len(self.train_loader.dataset)}")
        print(f"Размер валидационного набора: {len(self.val_loader.dataset)}")
        print("Architecture: ResNet-50 backbone + Attention + Dual Heads")

        for epoch in range(num_epochs):
            print(f"\nЭпоха {epoch + 1}/{num_epochs}")
            print("-" * 50)

            train_loss, train_from_acc, train_to_acc = self.train_epoch()
            val_loss, val_from_acc, val_to_acc, attention_maps = self.validate()

            previous_lr = self.optimizer.param_groups[0]["lr"]
            self.scheduler.step(val_loss)
            current_lr = self.optimizer.param_groups[0]["lr"]
            if current_lr < previous_lr:
                print(f"Learning rate reduced: {previous_lr:.6f} -> {current_lr:.6f}")

            self.history["train_loss"].append(train_loss)
            self.history["val_loss"].append(val_loss)
            self.history["train_from_acc"].append(train_from_acc)
            self.history["val_from_acc"].append(val_from_acc)
            self.history["train_to_acc"].append(train_to_acc)
            self.history["val_to_acc"].append(val_to_acc)
            self.history["learning_rate"].append(current_lr)
            self.history["attention_maps"].append(attention_maps)

            print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
            print(f"Train From Acc: {train_from_acc:.4f} | Val From Acc: {val_from_acc:.4f}")
            print(f"Train To Acc: {train_to_acc:.4f} | Val To Acc: {val_to_acc:.4f}")
            print(f"Learning Rate: {current_lr:.6f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self.save_model("best_chess_resnet_attention_model.pth")
                print("Сохранена лучшая модель!")
            else:
                patience_counter += 1
                print(f"Early stopping: {patience_counter}/{early_stopping_patience}")

            if patience_counter >= early_stopping_patience:
                print("Ранняя остановка!")
                break

        print("\nОбучение завершено!")
        self.plot_training_history()

    def evaluate(self):
        """Final evaluation on the test loader.

        Returns:
            ``(from acc, to acc given the true origin, full-move acc)``.
        """
        print("\nОценка на тестовом наборе...")
        self.model.eval()

        total_from_correct = 0
        total_to_correct = 0
        total_full_correct = 0
        total_samples = 0

        with torch.no_grad():
            for boards, from_targets, to_targets in tqdm(self.test_loader, desc="Testing"):
                boards = boards.to(self.device)
                from_targets = from_targets.to(self.device)
                to_targets = to_targets.to(self.device)

                # A single pass suffices. In ResidualCNNWithAttention the "from" output
                # does not depend on the conditioning square, and the "to" head is
                # conditioned on the true origin. Full-move accuracy still equals greedy
                # autoregressive decoding: whenever from_preds == from_targets, the "to"
                # head saw exactly the predicted square.
                from_probs, to_probs, _ = self.model(boards, from_targets)
                from_hit = torch.argmax(from_probs, dim=1) == from_targets
                to_hit = torch.argmax(to_probs, dim=1) == to_targets

                total_samples += boards.size(0)
                total_from_correct += from_hit.sum().item()
                total_to_correct += to_hit.sum().item()
                total_full_correct += (from_hit & to_hit).sum().item()

        from_acc = total_from_correct / total_samples
        to_acc = total_to_correct / total_samples
        full_acc = total_full_correct / total_samples

        print(f"\nРезультаты на тестовом наборе:")
        print(f"From-square Accuracy: {from_acc:.4f}")
        print(f"To-square Accuracy: {to_acc:.4f}")
        print(f"Full Move Accuracy: {full_acc:.4f}")

        return from_acc, to_acc, full_acc

    def predict_single_move(self, board_tensor):
        """Greedy prediction for one unbatched board tensor.

        Returns:
            ``(from_square: int, to_label: int, attention maps of the second pass)``.
        """
        self.model.eval()
        with torch.no_grad():
            # Move to the model's device; dataset tensors are created on the CPU.
            batch = board_tensor.unsqueeze(0).to(self.device)

            from_probs, _, _ = self.model(batch)
            from_square = torch.argmax(from_probs, dim=1)

            _, to_probs, attention_weights = self.model(batch, from_square)
            to_square = torch.argmax(to_probs, dim=1)

        return from_square.item(), to_square.item(), attention_weights

    def analyze_attention(self, board_tensor):
        """Return the predicted origin square, attention maps and feature maps for one board."""
        self.model.eval()
        with torch.no_grad():
            batch = board_tensor.unsqueeze(0).to(self.device)
            from_probs, _, attention_weights = self.model(batch)
            from_square = torch.argmax(from_probs, dim=1).item()

            feature_maps = self.model.get_feature_maps(batch)

            return {
                "from_square_pred": from_square,
                "attention_weights": attention_weights,
                "feature_maps": feature_maps,
            }

    def save_model(self, path):
        """Save model, optimizer and scheduler state plus training history to ``path``."""
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "scheduler_state_dict": self.scheduler.state_dict(),
                "history": self.history,
                "epoch": len(self.history["train_loss"]),
            },
            path,
        )

    def load_model(self, path):
        """Restore a checkpoint written by ``save_model``."""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.history = checkpoint["history"]

    def plot_training_history(self):
        """Plot losses, both accuracies and the learning rate per epoch; also save a PNG."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

        ax1.plot(self.history["train_loss"], label="Train Loss", linewidth=2)
        ax1.plot(self.history["val_loss"], label="Val Loss", linewidth=2)
        ax1.set_title("Потери", fontsize=14)
        ax1.set_xlabel("Эпоха")
        ax1.set_ylabel("Loss")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        ax2.plot(self.history["train_from_acc"], label="Train From Acc", linewidth=2)
        ax2.plot(self.history["val_from_acc"], label="Val From Acc", linewidth=2)
        ax2.set_title("From-square Точность", fontsize=14)
        ax2.set_xlabel("Эпоха")
        ax2.set_ylabel("Accuracy")
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        ax3.plot(self.history["train_to_acc"], label="Train To Acc", linewidth=2)
        ax3.plot(self.history["val_to_acc"], label="Val To Acc", linewidth=2)
        ax3.set_title("To-square Точность", fontsize=14)
        ax3.set_xlabel("Эпоха")
        ax3.set_ylabel("Accuracy")
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        ax4.plot(self.history["learning_rate"], linewidth=2, color="purple")
        ax4.set_title("Learning Rate", fontsize=14)
        ax4.set_xlabel("Эпоха")
        ax4.set_ylabel("LR")
        ax4.grid(True, alpha=0.3)
        ax4.set_yscale("log")

        plt.tight_layout()
        plt.savefig("training_history_resnet_attention.png", dpi=300, bbox_inches="tight")
        plt.show()

    def plot_attention_maps(self, board_tensor, epoch=-1):
        """Show a board's piece occupancy next to the stored last-stage attention map.

        The attention map comes from the first validation sample saved for ``epoch``,
        not from ``board_tensor`` itself.
        """
        if len(self.history["attention_maps"]) == 0:
            print("No attention maps available")
            return

        attention_maps = self.history["attention_maps"][epoch]
        if isinstance(attention_maps, list):
            # Deepest stage, first sample: shape (1, H, W).
            attention_map = attention_maps[-1][0].cpu().numpy()
        else:
            attention_map = attention_maps[0].cpu().numpy()

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Row index is square // 8 (rank 1 = row 0), so draw with origin="lower"
        # to put rank 1 at the bottom, as on a normal diagram.
        board_vis = board_tensor[:12].sum(dim=0).cpu().numpy()
        ax1.imshow(board_vis, cmap="RdYlBu_r", origin="lower")
        ax1.set_title("Шахматная доска")
        ax1.grid(True, alpha=0.3)

        im = ax2.imshow(
            attention_map.squeeze(), cmap="hot", interpolation="nearest", origin="lower"
        )
        ax2.set_title("Attention Map")
        ax2.grid(True, alpha=0.3)
        plt.colorbar(im, ax=ax2)

        plt.tight_layout()
        plt.savefig("attention_visualization.png", dpi=300, bbox_inches="tight")
        plt.show()

In [6]:
def main(
    csv_file="fens_training_set.csv",
    batch_size=32,
    num_epochs=20,
    early_stopping_patience=10,
    pretrained=True,
):
    """Run the pipeline: load and split data, build the model, train, evaluate on test.

    All parameters default to the values previously hard-coded here. Passing
    ``pretrained=False`` skips the ImageNet weight download.

    Returns:
        The ``ChessTrainer``, so notebook cells can inspect history or plot attention.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Используется устройство: {device}")

    print("Загрузка данных...")
    splitter = ChessDataSplitter(
        test_ratio=0.01,
        val_ratio=0.2 * (1 - 0.01),
        train_ratio=0.8 * (1 - 0.01),
        csv_file=csv_file,
        random_state=seed,
    )
    train_dataset, val_dataset, test_dataset = splitter.split_data()
    # drop_last avoids a final training batch of size 1, which BatchNorm1d rejects in train mode.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True
    )
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    print("Инициализация модели...")
    # Pass the device explicitly: the model keeps it in self.device.
    model = ResidualCNNWithAttention(device=device, pretrained=pretrained)

    trainer = ChessTrainer(model, train_loader, val_loader, test_loader, device)
    trainer.train(num_epochs=num_epochs, early_stopping_patience=early_stopping_patience)

    # Report held-out test metrics for the best checkpoint, not the last epoch's weights.
    best_path = "best_chess_resnet_attention_model.pth"
    if os.path.exists(best_path):
        trainer.load_model(best_path)
    trainer.evaluate()

    return trainer

In [7]:
# Entry point. Keep the returned trainer so later notebook cells can inspect
# training history, evaluate again or plot attention maps.
if __name__ == "__main__":
    trainer = main()

Используется устройство: cpu
Загрузка данных...
0.792 0.198 0.01
Загружено 268549 валидных позиций
Разбиение завершено:
Train: 212690 samples (79.2%)
Val: 53173 samples (19.8%)
Test: 2686 samples (1.0%)
Инициализация модели...




Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /Users/ivan/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


 18%|█▊        | 17.1M/97.8M [06:20<29:51, 47.2kB/s]  


KeyboardInterrupt: 