In [1]:
!pip install torch numpy datasets

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting 

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from datasets import load_dataset
from collections import defaultdict
import json
import time
import os

# Модель нейросети
class DurakNet(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256, num_cards=24, num_actions=4):
        super(DurakNet, self).__init__()

        self.state_encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )

        self.action_head = nn.Linear(hidden_dim // 2, num_actions)
        self.card_head = nn.Linear(hidden_dim // 2, num_cards)
        self.state_head = nn.Linear(hidden_dim // 2, 3)

    def forward(self, state):
        encoded = self.state_encoder(state)
        action_logits = self.action_head(encoded)
        card_logits = self.card_head(encoded)
        state_logits = self.state_head(encoded)
        return action_logits, card_logits, state_logits

# Кодирование состояния
def encode_state(snapshot, my_id, device='cpu'):
    try:
        data = json.loads(snapshot['snapshot'])

        state = np.zeros(128, dtype=np.float32)

        trump = data['trump']
        suits = {'S': 0, 'C': 1, 'D': 2, 'H': 3}
        ranks = {'9': 0, '10': 1, '11': 2, '12': 3, '13': 4, '14': 5}
        trump_rank, trump_suit = trump[:-1], trump[-1]
        state[0 + suits[trump_suit]] = 1.0
        state[4 + ranks[trump_rank]] = 1.0

        for player in data['players']:
            if player['id'] == my_id:
                for card in player['hand']:
                    rank, suit = card[:-1], card[-1]
                    card_idx = ranks[rank] * 4 + suits[suit]
                    state[10 + card_idx] = 1.0

        for i, pair in enumerate(data['table'][:4]):
            attack_card = pair['attack_card']['card']
            rank, suit = attack_card[:-1], attack_card[-1]
            state[34 + i * 12 + ranks[rank]] = 1.0
            state[34 + i * 12 + 6 + suits[suit]] = 1.0
            if 'defend_card' in pair:
                defend_card = pair['defend_card']['card']
                rank, suit = defend_card[:-1], defend_card[-1]
                state[82 + i * 12 + ranks[rank]] = 1.0
                state[82 + i * 12 + 6 + suits[suit]] = 1.0

        state[114] = len(data['deck']) / 24.0

        for player in data['players']:
            if player['id'] == my_id:
                state_map = {'attack': 115, 'defend': 116, 'bat': 117, 'pass': 118, 'take': 119}
                if player['state'] in state_map:
                    state[state_map[player['state']]] = 1.0

        state[120] = data['game_rules']['game_type']

        tensor = torch.from_numpy(state).float().to(device)
        tensor.requires_grad_(True)
        return tensor

    except Exception as e:
        print(f"Error in encode_state: {e}")
        return torch.zeros(128, dtype=torch.float32).to(device)

# Проверка, может ли defend_card побить attack_card
def can_beat(attack_card, defend_card, trump_suit):
    suits = {'S': 0, 'C': 1, 'D': 2, 'H': 3}
    ranks = {'9': 0, '10': 1, '11': 2, '12': 3, '13': 4, '14': 5}

    attack_rank, attack_suit = attack_card[:-1], attack_card[-1]
    defend_rank, defend_suit = defend_card[:-1], defend_card[-1]

    attack_rank_idx = ranks[attack_rank]
    defend_rank_idx = ranks[defend_rank]
    attack_suit_idx = suits[attack_suit]
    defend_suit_idx = suits[defend_suit]
    trump_suit_idx = suits[trump_suit]

    if attack_suit_idx == defend_suit_idx:
        return defend_rank_idx > attack_rank_idx
    elif defend_suit_idx == trump_suit_idx and attack_suit_idx != trump_suit_idx:
        return True
    return False

# Извлечение действия
def get_action(prev_snapshot, curr_snapshot, my_id):
    try:
        prev_data = json.loads(prev_snapshot['snapshot'])
        curr_data = json.loads(curr_snapshot['snapshot'])

        ranks = {'9': 0, '10': 1, '11': 2, '12': 3, '13': 4, '14': 5}
        suits = {'S': 0, 'C': 1, 'D': 2, 'H': 3}
        trump_suit = curr_data['trump'][-1]
        game_type = curr_data['game_rules']['game_type']

        if len(curr_data['table']) > len(prev_data['table']):
            new_pair = curr_data['table'][-1]
            attack_card = new_pair['attack_card']['card']
            attack_user = new_pair['attack_card']['user_id']

            if attack_user == my_id:
                rank, suit = attack_card[:-1], attack_card[-1]
                card_idx = ranks[rank] * 4 + suits[suit]

                if game_type == 1 and prev_data['table']:
                    last_attack = prev_data['table'][-1]['attack_card']['card']
                    if last_attack[:-1] == attack_card[:-1]:
                        print(f"Detected transfer: {attack_card} matches {last_attack}")
                        return {'type': 'attack', 'card_idx': card_idx}
                return {'type': 'attack', 'card_idx': card_idx}

            if 'defend_card' in new_pair and new_pair['defend_card']['user_id'] == my_id:
                defend_card = new_pair['defend_card']['card']
                rank, suit = defend_card[:-1], defend_card[-1]
                card_idx = ranks[rank] * 4 + suits[suit]

                if not can_beat(attack_card, defend_card, trump_suit):
                    print(f"Invalid defend: {defend_card} cannot beat {attack_card}, trump={trump_suit}")
                    return {'type': 'invalid'}
                return {'type': 'defend', 'card_idx': card_idx}

        for curr_p, prev_p in zip(curr_data['players'], prev_data['players']):
            if curr_p['id'] == my_id and curr_p['state'] != prev_p['state']:
                if curr_p['state'] in ['bat', 'pass', 'take']:
                    state_map = {'bat': 0, 'pass': 1, 'take': 2}
                    print(f"State change detected: {prev_p['state']} -> {curr_p['state']}")
                    return {'type': 'state', 'state_idx': state_map[curr_p['state']]}

        if len(curr_data['table']) == 0 and len(prev_data['table']) > 0:
            for curr_p, prev_p in zip(curr_data['players'], prev_data['players']):
                if curr_p['id'] == my_id and curr_p['state'] == 'bat':
                    print("Detected 'bat' action: table cleared")
                    return {'type': 'state', 'state_idx': 0}

        return {'type': 'wait'}

    except Exception as e:
        print(f"Error in get_action: {e}")
        return {'type': 'invalid'}

# Обучение модели
def train_model(max_games=None, device='cpu'):
    print("Loading dataset...")
    dataset = load_dataset("neuronetties/durak")
    print("Dataset loaded.")

    model = DurakNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    action_criterion = nn.CrossEntropyLoss()
    card_criterion = nn.CrossEntropyLoss()
    state_criterion = nn.CrossEntropyLoss()

    for name, param in model.named_parameters():
        if not param.requires_grad:
            print(f"Warning: Parameter {name} does not require grad!")

    games = defaultdict(list)
    for sample in dataset['train']:
        games[sample['game_id']].append(sample)

    if max_games is not None:
        game_ids = list(games.keys())[:max_games]
        games = {gid: games[gid] for gid in game_ids}

    total_games = len(games)
    print(f"Total games to process: {total_games}")

    action_counts = defaultdict(int)
    invalid_samples = 0

    epochs = 10
    for epoch in range(epochs):
        start_time = time.time()
        total_loss = 0.0
        num_samples = 0
        games_processed = 0

        for game_id, snapshots in games.items():
            print(f"Processing game {game_id}...")
            snapshots = sorted(snapshots, key=lambda x: json.loads(x['snapshot'])['timestamp'])

            try:
                my_id = json.loads(snapshots[0]['snapshot'])['players'][0]['id']
            except (KeyError, IndexError):
                print(f"Skipping game {game_id}: invalid player data")
                continue

            for i in range(len(snapshots) - 1):
                try:
                    state = encode_state(snapshots[i], my_id, device=device)
                    print(f"Snapshot {i}, state requires_grad: {state.requires_grad}")

                    # Проверка на nan/inf в state
                    if torch.isnan(state).any() or torch.isinf(state).any():
                        print(f"Invalid values in state for snapshot {i} in game {game_id}")
                        continue

                    action = get_action(snapshots[i], snapshots[i + 1], my_id)

                    if action['type'] == 'invalid':
                        invalid_samples += 1
                        continue

                    if action['type'] not in ['attack', 'defend', 'state', 'wait']:
                        print(f"Unknown action type: {action['type']}")
                        continue

                    action_counts[action['type']] += 1

                    action_logits, card_logits, state_logits = model(state.unsqueeze(0))

                    # Проверка, что выходы модели подключены к графу
                    print(f"action_logits grad_fn: {action_logits.grad_fn}")

                    loss = 0.0
                    if action['type'] == 'attack':
                        action_target = torch.tensor([0], dtype=torch.long, requires_grad=False, device=device)
                        card_target = torch.tensor([action['card_idx']], dtype=torch.long, requires_grad=False, device=device)
                        loss += action_criterion(action_logits, action_target)
                        loss += card_criterion(card_logits, card_target)
                    elif action['type'] == 'defend':
                        action_target = torch.tensor([1], dtype=torch.long, requires_grad=False, device=device)
                        card_target = torch.tensor([action['card_idx']], dtype=torch.long, requires_grad=False, device=device)
                        loss += action_criterion(action_logits, action_target)
                        loss += card_criterion(card_logits, card_target)
                    elif action['type'] == 'state':
                        action_target = torch.tensor([2], dtype=torch.long, requires_grad=False, device=device)
                        state_target = torch.tensor([action['state_idx']], dtype=torch.long, requires_grad=False, device=device)
                        loss += action_criterion(action_logits, action_target)
                        loss += state_criterion(state_logits, state_target)
                    else:
                        action_target = torch.tensor([3], dtype=torch.long, requires_grad=False, device=device)
                        loss += action_criterion(action_logits, action_target)

                    # Проверка на nan/inf в loss
                    if torch.isnan(loss) or torch.isinf(loss):
                        print(f"Loss is invalid: {loss}")
                        continue

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    num_samples += 1

                except Exception as e:
                    print(f"Error processing snapshot {i} in game {game_id}: {e}")
                    continue

            games_processed += 1
            if games_processed % 10 == 0:
                print(f"Epoch {epoch + 1}: Processed {games_processed}/{total_games} games ({games_processed/total_games*100:.1f}%)")

        elapsed_time = time.time() - start_time
        if num_samples > 0:
            print(f"Epoch {epoch + 1}, Loss: {total_loss / num_samples:.4f}, Samples: {num_samples}, Time: {elapsed_time:.2f}s")
        else:
            print(f"Epoch {epoch + 1}: No valid samples processed")

        print(f"Action distribution: {dict(action_counts)}")
        print(f"Invalid samples skipped: {invalid_samples}")

        checkpoint_path = f"durak_checkpoint_epoch_{epoch + 1}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

    print("Saving final model...")
    torch.save(model.state_dict(), "durak_model.pth")
    print("Model weights saved to durak_model.pth")

    return model

# Предсказание действия
def predict_action(model, snapshot, my_id, device='cpu'):
    model.eval()
    with torch.no_grad():
        state = encode_state(snapshot, my_id, device=device)
        action_logits, card_logits, state_logits = model(state.unsqueeze(0))

        print(f"action_logits: {action_logits}")

        action_idx = torch.argmax(action_logits, dim=1).item()
        action_map = {0: 'attack', 1: 'defend', 2: 'state', 3: 'wait'}
        action_type = action_map[action_idx]

        if action_type == 'attack' or action_type == 'defend':
            card_idx = torch.argmax(card_logits, dim=1).item()
            ranks = {0: '9', 1: '10', 2: '11', 3: '12', 4: '13', 5: '14'}
            suits = {0: 'S', 1: 'C', 2: 'D', 3: 'H'}
            card = f"{ranks[card_idx // 4]}{suits[card_idx % 4]}"
            return {"type": action_type, "move": card}
        elif action_type == 'state':
            state_idx = torch.argmax(state_logits, dim=1).item()
            state_map = {0: 'bat', 1: 'pass', 2: 'take'}
            return {"type": "state", "state": state_map[state_idx]}
        else:
            return {"type": "wait"}

# Загрузка модели
def load_model(model_path="durak_model.pth", device='cpu'):
    model = DurakNet().to(device)
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print(f"Model weights loaded from {model_path}")
        return model
    else:
        print(f"Model file {model_path} not found")
        return None

if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    model = train_model(max_games=100, device=device)

    print("Testing predictions...")
    dataset = load_dataset("neuronetties/durak")
    for i in range(3):
        sample = dataset['train'][i]
        my_id = json.loads(sample['snapshot'])['players'][0]['id']
        action = predict_action(model, sample, my_id, device=device)
        print(f"Sample {i}: Predicted action: {action}")

    loaded_model = load_model("durak_model.pth", device=device)
    if loaded_model:
        action = predict_action(loaded_model, dataset['train'][0], my_id, device=device)
        print(f"Loaded model prediction: {action}")

[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
action_logits grad_fn: None
Error processing snapshot 35 in game ee966a6f-4066-418a-be55-06f66a32c877: element 0 of tensors does not require grad and does not have a grad_fn
Snapshot 36, state requires_grad: True
action_logits grad_fn: None
Error processing snapshot 36 in game ee966a6f-4066-418a-be55-06f66a32c877: element 0 of tensors does not require grad and does not have a grad_fn
Snapshot 37, state requires_grad: True
action_logits grad_fn: None
Error processing snapshot 37 in game ee966a6f-4066-418a-be55-06f66a32c877: element 0 of tensors does not require grad and does not have a grad_fn
Snapshot 38, state requires_grad: True
action_logits grad_fn: None
Error processing snapshot 38 in game ee966a6f-4066-418a-be55-06f66a32c877: element 0 of tensors does not require grad and does not have a grad_fn
Snapshot 39, state requires_grad: True
action_logits grad_fn: None
Error processing snapshot 39 in game e

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from datasets import load_dataset
from collections import defaultdict
import json
import time
import os

# Модель нейросети
class DurakNet(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256, num_cards=24, num_actions=4):
        super(DurakNet, self).__init__()

        self.state_encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )

        self.action_head = nn.Linear(hidden_dim // 2, num_actions)
        self.card_head = nn.Linear(hidden_dim // 2, num_cards)
        self.state_head = nn.Linear(hidden_dim // 2, 3)

    def forward(self, state):
        encoded = self.state_encoder(state)
        action_logits = self.action_head(encoded)
        card_logits = self.card_head(encoded)
        state_logits = self.state_head(encoded)
        return action_logits, card_logits, state_logits

# Кодирование состояния на основе предложенной структуры
def encode_state(snapshot, my_id, device='cpu'):
    try:
        data = json.loads(snapshot['snapshot'])

        # Инициализация вектора состояния
        state = np.zeros(128, dtype=np.float32)

        # Козырь (trump)
        trump = data['trump']
        suits = {'S': 0, 'C': 1, 'D': 2, 'H': 3}
        ranks = {'9': 0, '10': 1, '11': 2, '12': 3, '13': 4, '14': 5}
        trump_rank, trump_suit = trump[:-1], trump[-1]
        state[0 + suits[trump_suit]] = 1.0  # Кодируем масть козыря
        state[4 + ranks[trump_rank]] = 1.0  # Кодируем ранг козыря

        # Карты игрока (my_id)
        for player in data['players']:
            if player['id'] == my_id:
                for card in player['hand']:
                    rank, suit = card[:-1], card[-1]
                    card_idx = ranks[rank] * 4 + suits[suit]
                    state[10 + card_idx] = 1.0  # Кодируем карты на руке

        # Карты на столе (table)
        for i, pair in enumerate(data['table'][:4]):  # Ограничиваем до 4 пар
            attack_card = pair['attack_card']['card']
            rank, suit = attack_card[:-1], attack_card[-1]
            state[34 + i * 12 + ranks[rank]] = 1.0  # Кодируем ранг атакующей карты
            state[34 + i *12 + 6 + suits[suit]] = 1.0  # Кодируем масть атакующей карты
            if 'defend_card' in pair:
                defend_card = pair['defend_card']['card']
                rank, suit = defend_card[:-1], defend_card[-1]
                state[82 + i * 12 + ranks[rank]] = 1.0  # Кодируем ранг защитной карты
                state[82 + i * 12 + 6 + suits[suit]] = 1.0  # Кодируем масть защитной карты

        # Колода (deck)
        state[114] = len(data['deck']) / 24.0  # Нормализованное количество карт в колоде

        # Состояние игрока (state)
        for player in data['players']:
            if player['id'] == my_id:
                state_map = {'attack': 115, 'defend': 116, 'bat': 117, 'pass': 118, 'take': 119}
                if player['state'] in state_map:
                    state[state_map[player['state']]] = 1.0

        # Режим игры (game_type)
        state[120] = data['game_rules']['game_type']

        # Создаем тензор
        tensor = torch.tensor(state, dtype=torch.float32, device=device, requires_grad=True)
        return tensor

    except Exception as e:
        print(f"Error in encode_state: {e}")
        return torch.zeros(128, dtype=torch.float32, device=device)

# Проверка, может ли defend_card побить attack_card
def can_beat(attack_card, defend_card, trump_suit):
    suits = {'S': 0, 'C': 1, 'D': 2, 'H': 3}
    ranks = {'9': 0, '10': 1, '11': 2, '12': 3, '13': 4, '14': 5}

    attack_rank, attack_suit = attack_card[:-1], attack_card[-1]
    defend_rank, defend_suit = defend_card[:-1], defend_card[-1]

    attack_rank_idx = ranks[attack_rank]
    defend_rank_idx = ranks[defend_rank]
    attack_suit_idx = suits[attack_suit]
    defend_suit_idx = suits[defend_suit]
    trump_suit_idx = suits[trump_suit]

    if attack_suit_idx == defend_suit_idx:
        return defend_rank_idx > attack_rank_idx
    elif defend_suit_idx == trump_suit_idx and attack_suit_idx != trump_suit_idx:
        return True
    return False

# Извлечение действия
def get_action(prev_snapshot, curr_snapshot, my_id):
    try:
        prev_data = json.loads(prev_snapshot['snapshot'])
        curr_data = json.loads(curr_snapshot['snapshot'])

        ranks = {'9': 0, '10': 1, '11': 2, '12': 3, '13': 4, '14': 5}
        suits = {'S': 0, 'C': 1, 'D': 2, 'H': 3}
        trump_suit = curr_data['trump'][-1]
        game_type = curr_data['game_rules']['game_type']

        if len(curr_data['table']) > len(prev_data['table']):
            new_pair = curr_data['table'][-1]
            attack_card = new_pair['attack_card']['card']
            attack_user = new_pair['attack_card']['user_id']

            if attack_user == my_id:
                rank, suit = attack_card[:-1], attack_card[-1]
                card_idx = ranks[rank] * 4 + suits[suit]

                if game_type == 1 and prev_data['table']:
                    last_attack = prev_data['table'][-1]['attack_card']['card']
                    if last_attack[:-1] == attack_card[:-1]:
                        print(f"Detected transfer: {attack_card} matches {last_attack}")
                        return {'type': 'attack', 'card_idx': card_idx}
                return {'type': 'attack', 'card_idx': card_idx}

            if 'defend_card' in new_pair and new_pair['defend_card']['user_id'] == my_id:
                defend_card = new_pair['defend_card']['card']
                rank, suit = defend_card[:-1], defend_card[-1]
                card_idx = ranks[rank] * 4 + suits[suit]

                if not can_beat(attack_card, defend_card, trump_suit):
                    print(f"Invalid defend: {defend_card} cannot beat {attack_card}, trump={trump_suit}")
                    return {'type': 'invalid'}
                return {'type': 'defend', 'card_idx': card_idx}

        for curr_p, prev_p in zip(curr_data['players'], prev_data['players']):
            if curr_p['id'] == my_id and curr_p['state'] != prev_p['state']:
                if curr_p['state'] in ['bat', 'pass', 'take']:
                    state_map = {'bat': 0, 'pass': 1, 'take': 2}
                    print(f"State change detected: {prev_p['state']} -> {curr_p['state']}")
                    return {'type': 'state', 'state_idx': state_map[curr_p['state']]}

        if len(curr_data['table']) == 0 and len(prev_data['table']) > 0:
            for curr_p, prev_p in zip(curr_data['players'], prev_data['players']):
                if curr_p['id'] == my_id and curr_p['state'] == 'bat':
                    print("Detected 'bat' action: table cleared")
                    return {'type': 'state', 'state_idx': 0}

        return {'type': 'wait'}

    except Exception as e:
        print(f"Error in get_action: {e}")
        return {'type': 'invalid'}

# Обучение модели
def train_model(max_games=None, device='cpu'):
    print("Loading dataset...")
    dataset = load_dataset("neuronetties/durak")
    print("Dataset loaded.")

    model = DurakNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    action_criterion = nn.CrossEntropyLoss()
    card_criterion = nn.CrossEntropyLoss()
    state_criterion = nn.CrossEntropyLoss()

    # Проверка параметров модели
    for name, param in model.named_parameters():
        print(f"Parameter {name}, requires_grad: {param.requires_grad}")
        if not param.requires_grad:
            print(f"Warning: Parameter {name} does not require grad!")

    games = defaultdict(list)
    for sample in dataset['train']:
        games[sample['game_id']].append(sample)

    if max_games is not None:
        game_ids = list(games.keys())[:max_games]
        games = {gid: games[gid] for gid in game_ids}

    total_games = len(games)
    print(f"Total games to process: {total_games}")

    action_counts = defaultdict(int)
    invalid_samples = 0

    epochs = 10
    for epoch in range(epochs):
        start_time = time.time()
        total_loss = 0.0
        num_samples = 0
        games_processed = 0

        for game_id, snapshots in games.items():
            print(f"Processing game {game_id}...")
            snapshots = sorted(snapshots, key=lambda x: json.loads(x['snapshot'])['timestamp'])

            try:
                my_id = json.loads(snapshots[0]['snapshot'])['players'][0]['id']
            except (KeyError, IndexError):
                print(f"Skipping game {game_id}: invalid player data")
                continue

            for i in range(len(snapshots) - 1):
                try:
                    state = encode_state(snapshots[i], my_id, device=device)
                    print(f"Snapshot {i}, state requires_grad: {state.requires_grad}")

                    # Проверка на nan/inf в state
                    if torch.isnan(state).any() or torch.isinf(state).any():
                        print(f"Invalid values in state for snapshot {i} in game {game_id}")
                        continue

                    action = get_action(snapshots[i], snapshots[i + 1], my_id)

                    if action['type'] == 'invalid':
                        invalid_samples += 1
                        continue

                    if action['type'] not in ['attack', 'defend', 'state', 'wait']:
                        print(f"Unknown action type: {action['type']}")
                        continue

                    action_counts[action['type']] += 1

                    # Прямой проход через модель
                    action_logits, card_logits, state_logits = model(state.unsqueeze(0))

                    # Проверка grad_fn
                    print(f"action_logits grad_fn: {action_logits.grad_fn}")
                    print(f"card_logits grad_fn: {card_logits.grad_fn}")
                    print(f"state_logits grad_fn: {state_logits.grad_fn}")

                    loss = 0.0
                    if action['type'] == 'attack':
                        action_target = torch.tensor([0], dtype=torch.long, requires_grad=False, device=device)
                        card_target = torch.tensor([action['card_idx']], dtype=torch.long, requires_grad=False, device=device)
                        loss += action_criterion(action_logits, action_target)
                        loss += card_criterion(card_logits, card_target)
                    elif action['type'] == 'defend':
                        action_target = torch.tensor([1], dtype=torch.long, requires_grad=False, device=device)
                        card_target = torch.tensor([action['card_idx']], dtype=torch.long, requires_grad=False, device=device)
                        loss += action_criterion(action_logits, action_target)
                        loss += card_criterion(card_logits, card_target)
                    elif action['type'] == 'state':
                        action_target = torch.tensor([2], dtype=torch.long, requires_grad=False, device=device)
                        state_target = torch.tensor([action['state_idx']], dtype=torch.long, requires_grad=False, device=device)
                        loss += action_criterion(action_logits, action_target)
                        loss += state_criterion(state_logits, state_target)
                    else:
                        action_target = torch.tensor([3], dtype=torch.long, requires_grad=False, device=device)
                        loss += action_criterion(action_logits, action_target)

                    # Проверка на nan/inf в loss
                    if torch.isnan(loss) or torch.isinf(loss):
                        print(f"Loss is invalid: {loss}")
                        continue

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    num_samples += 1

                except Exception as e:
                    print(f"Error processing snapshot {i} in game {game_id}: {e}")
                    continue

            games_processed += 1
            if games_processed % 10 == 0:
                print(f"Epoch {epoch + 1}: Processed {games_processed}/{total_games} games ({games_processed/total_games*100:.1f}%)")

        elapsed_time = time.time() - start_time
        if num_samples > 0:
            print(f"Epoch {epoch + 1}, Loss: {total_loss / num_samples:.4f}, Samples: {num_samples}, Time: {elapsed_time:.2f}s")
        else:
            print(f"Epoch {epoch + 1}: No valid samples processed")

        print(f"Action distribution: {dict(action_counts)}")
        print(f"Invalid samples skipped: {invalid_samples}")

        checkpoint_path = f"durak_checkpoint_epoch_{epoch + 1}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

    print("Saving final model...")
    torch.save(model.state_dict(), "durak_model.pth")
    print("Model weights saved to durak_model.pth")

    return model

# Предсказание действия
def predict_action(model, snapshot, my_id, device='cpu'):
    model.eval()
    with torch.no_grad():
        state = encode_state(snapshot, my_id, device=device)
        action_logits, card_logits, state_logits = model(state.unsqueeze(0))

        print(f"action_logits: {action_logits}")

        action_idx = torch.argmax(action_logits, dim=1).item()
        action_map = {0: 'attack', 1: 'defend', 2: 'state', 3: 'wait'}
        action_type = action_map[action_idx]

        if action_type == 'attack' or action_type == 'defend':
            card_idx = torch.argmax(card_logits, dim=1).item()
            ranks = {0: '9', 1: '10', 2: '11', 3: '12', 4: '13', 5: '14'}
            suits = {0: 'S', 1: 'C', 2: 'D', 3: 'H'}
            card = f"{ranks[card_idx // 4]}{suits[card_idx % 4]}"
            return {"type": action_type, "move": card}
        elif action_type == 'state':
            state_idx = torch.argmax(state_logits, dim=1).item()
            state_map = {0: 'bat', 1: 'pass', 2: 'take'}
            return {"type": "state", "state": state_map[state_idx]}
        else:
            return {"type": "wait"}

# Загрузка модели
def load_model(model_path="durak_model.pth", device='cpu'):
    model = DurakNet().to(device)
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print(f"Model weights loaded from {model_path}")
        return model
    else:
        print(f"Model file {model_path} not found")
        return None

if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    model = train_model(max_games=100, device=device)

    print("Testing predictions...")
    dataset = load_dataset("neuronetties/durak")
    for i in range(3):
        sample = dataset['train'][i]
        my_id = json.loads(sample['snapshot'])['players'][0]['id']
        action = predict_action(model, sample, my_id, device=device)
        print(f"Sample {i}: Predicted action: {action}")

    loaded_model = load_model("durak_model.pth", device=device)
    if loaded_model:
        action = predict_action(loaded_model, dataset['train'][0], my_id, device=device)
        print(f"Loaded model prediction: {action}")

Using device: cpu
Loading dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/321 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/18.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/292962 [00:00<?, ? examples/s]

[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
Snapshot 25, state requires_grad: True
action_logits grad_fn: <AddmmBackward0 object at 0x7fe6f0887cd0>
card_logits grad_fn: <AddmmBackward0 object at 0x7fe6f0887cd0>
state_logits grad_fn: <AddmmBackward0 object at 0x7fe6f0887cd0>
Snapshot 26, state requires_grad: True
action_logits grad_fn: <AddmmBackward0 object at 0x7fe6f0887cd0>
card_logits grad_fn: <AddmmBackward0 object at 0x7fe6f0887cd0>
state_logits grad_fn: <AddmmBackward0 object at 0x7fe6f0887cd0>
Snapshot 27, state requires_grad: True
action_logits grad_fn: <AddmmBackward0 object at 0x7fe6f0887cd0>
card_logits grad_fn: <AddmmBackward0 object at 0x7fe6f0887cd0>
state_logits grad_fn: <AddmmBackward0 object at 0x7fe6f0887cd0>
Snapshot 28, state requires_grad: True
action_logits grad_fn: <AddmmBackward0 object at 0x7fe6f0887cd0>
card_logits grad_fn: <AddmmBackward0 object at 0x7fe6f0887cd0>
state_logits grad_fn: <AddmmBackward0 object at 0x7fe6f088

In [3]:
import torch
import torch.nn as nn
import numpy as np
from datasets import load_dataset
from collections import defaultdict
import json
import os

# Модель нейросети (должна совпадать с той, что использовалась при обучении)
class DurakNet(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256, num_cards=24, num_actions=4):
        super(DurakNet, self).__init__()

        self.state_encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )

        self.action_head = nn.Linear(hidden_dim // 2, num_actions)
        self.card_head = nn.Linear(hidden_dim // 2, num_cards)
        self.state_head = nn.Linear(hidden_dim // 2, 3)

    def forward(self, state):
        encoded = self.state_encoder(state)
        action_logits = self.action_head(encoded)
        card_logits = self.card_head(encoded)
        state_logits = self.state_head(encoded)
        return action_logits, card_logits, state_logits

# Кодирование состояния
def encode_state(snapshot, my_id, device='cpu'):
    try:
        data = json.loads(snapshot['snapshot'])

        state = np.zeros(128, dtype=np.float32)

        trump = data['trump']
        suits = {'S': 0, 'C': 1, 'D': 2, 'H': 3}
        ranks = {'9': 0, '10': 1, '11': 2, '12': 3, '13': 4, '14': 5}
        trump_rank, trump_suit = trump[:-1], trump[-1]
        state[0 + suits[trump_suit]] = 1.0
        state[4 + ranks[trump_rank]] = 1.0

        for player in data['players']:
            if player['id'] == my_id:
                for card in player['hand']:
                    rank, suit = card[:-1], card[-1]
                    card_idx = ranks[rank] * 4 + suits[suit]
                    state[10 + card_idx] = 1.0

        for i, pair in enumerate(data['table'][:4]):
            attack_card = pair['attack_card']['card']
            rank, suit = attack_card[:-1], attack_card[-1]
            state[34 + i * 12 + ranks[rank]] = 1.0
            state[34 + i * 12 + 6 + suits[suit]] = 1.0
            if 'defend_card' in pair:
                defend_card = pair['defend_card']['card']
                rank, suit = defend_card[:-1], defend_card[-1]
                state[82 + i * 12 + ranks[rank]] = 1.0
                state[82 + i * 12 + 6 + suits[suit]] = 1.0

        state[114] = len(data['deck']) / 24.0

        for player in data['players']:
            if player['id'] == my_id:
                state_map = {'attack': 115, 'defend': 116, 'bat': 117, 'pass': 118, 'take': 119}
                if player['state'] in state_map:
                    state[state_map[player['state']]] = 1.0

        state[120] = data['game_rules']['game_type']

        tensor = torch.tensor(state, dtype=torch.float32, device=device, requires_grad=False)
        return tensor

    except Exception as e:
        print(f"Error in encode_state: {e}")
        return torch.zeros(128, dtype=torch.float32, device=device)

# Проверка, может ли defend_card побить attack_card
def can_beat(attack_card, defend_card, trump_suit):
    suits = {'S': 0, 'C': 1, 'D': 2, 'H': 3}
    ranks = {'9': 0, '10': 1, '11': 2, '12': 3, '13': 4, '14': 5}

    attack_rank, attack_suit = attack_card[:-1], attack_card[-1]
    defend_rank, defend_suit = defend_card[:-1], defend_card[-1]

    attack_rank_idx = ranks[attack_rank]
    defend_rank_idx = ranks[defend_rank]
    attack_suit_idx = suits[attack_suit]
    defend_suit_idx = suits[defend_suit]
    trump_suit_idx = suits[trump_suit]

    if attack_suit_idx == defend_suit_idx:
        return defend_rank_idx > attack_rank_idx
    elif defend_suit_idx == trump_suit_idx and attack_suit_idx != trump_suit_idx:
        return True
    return False

# Извлечение реального действия (та же функция, что использовалась при обучении)
def get_action(prev_snapshot, curr_snapshot, my_id):
    try:
        prev_data = json.loads(prev_snapshot['snapshot'])
        curr_data = json.loads(curr_snapshot['snapshot'])

        ranks = {'9': 0, '10': 1, '11': 2, '12': 3, '13': 4, '14': 5}
        suits = {'S': 0, 'C': 1, 'D': 2, 'H': 3}
        trump_suit = curr_data['trump'][-1]
        game_type = curr_data['game_rules']['game_type']

        if len(curr_data['table']) > len(prev_data['table']):
            new_pair = curr_data['table'][-1]
            attack_card = new_pair['attack_card']['card']
            attack_user = new_pair['attack_card']['user_id']

            if attack_user == my_id:
                rank, suit = attack_card[:-1], attack_card[-1]
                card_idx = ranks[rank] * 4 + suits[suit]

                if game_type == 1 and prev_data['table']:
                    last_attack = prev_data['table'][-1]['attack_card']['card']
                    if last_attack[:-1] == attack_card[:-1]:
                        print(f"Detected transfer: {attack_card} matches {last_attack}")
                        return {'type': 'attack', 'card_idx': card_idx, 'card': attack_card}
                return {'type': 'attack', 'card_idx': card_idx, 'card': attack_card}

            if 'defend_card' in new_pair and new_pair['defend_card']['user_id'] == my_id:
                defend_card = new_pair['defend_card']['card']
                rank, suit = defend_card[:-1], defend_card[-1]
                card_idx = ranks[rank] * 4 + suits[suit]

                if not can_beat(attack_card, defend_card, trump_suit):
                    print(f"Invalid defend: {defend_card} cannot beat {attack_card}, trump={trump_suit}")
                    return {'type': 'invalid'}
                return {'type': 'defend', 'card_idx': card_idx, 'card': defend_card}

        for curr_p, prev_p in zip(curr_data['players'], prev_data['players']):
            if curr_p['id'] == my_id and curr_p['state'] != prev_p['state']:
                if curr_p['state'] in ['bat', 'pass', 'take']:
                    state_map = {'bat': 0, 'pass': 1, 'take': 2}
                    print(f"State change detected: {prev_p['state']} -> {curr_p['state']}")
                    return {'type': 'state', 'state_idx': state_map[curr_p['state']], 'state': curr_p['state']}

        if len(curr_data['table']) == 0 and len(prev_data['table']) > 0:
            for curr_p, prev_p in zip(curr_data['players'], prev_data['players']):
                if curr_p['id'] == my_id and curr_p['state'] == 'bat':
                    print("Detected 'bat' action: table cleared")
                    return {'type': 'state', 'state_idx': 0, 'state': 'bat'}

        return {'type': 'wait'}

    except Exception as e:
        print(f"Error in get_action: {e}")
        return {'type': 'invalid'}

# Предсказание действия моделью
def predict_action(model, snapshot, my_id, device='cpu'):
    model.eval()
    with torch.no_grad():
        state = encode_state(snapshot, my_id, device=device)
        action_logits, card_logits, state_logits = model(state.unsqueeze(0))

        print(f"action_logits: {action_logits}")

        action_idx = torch.argmax(action_logits, dim=1).item()
        action_map = {0: 'attack', 1: 'defend', 2: 'state', 3: 'wait'}
        action_type = action_map[action_idx]

        if action_type == 'attack' or action_type == 'defend':
            card_idx = torch.argmax(card_logits, dim=1).item()
            ranks = {0: '9', 1: '10', 2: '11', 3: '12', 4: '13', 5: '14'}
            suits = {0: 'S', 1: 'C', 2: 'D', 3: 'H'}
            card = f"{ranks[card_idx // 4]}{suits[card_idx % 4]}"
            return {"type": action_type, "move": card}
        elif action_type == 'state':
            state_idx = torch.argmax(state_logits, dim=1).item()
            state_map = {0: 'bat', 1: 'pass', 2: 'take'}
            return {"type": "state", "state": state_map[state_idx]}
        else:
            return {"type": "wait"}

# Загрузка модели
def load_model(model_path="durak_model.pth", device='cpu'):
    model = DurakNet().to(device)
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print(f"Model weights loaded from {model_path}")
        return model
    else:
        print(f"Model file {model_path} not found")
        return None

# Тестирование модели
def test_model(max_games=None, device='cpu'):
    print("Loading dataset...")
    dataset = load_dataset("neuronetties/durak")
    print("Dataset loaded.")

    # Загрузка модели
    model = load_model("durak_model.pth", device=device)
    if model is None:
        print("Cannot proceed with testing: model not found.")
        return

    # Группировка снимков по играм
    games = defaultdict(list)
    for sample in dataset['train']:
        games[sample['game_id']].append(sample)

    if max_games is not None:
        game_ids = list(games.keys())[:max_games]
        games = {gid: games[gid] for gid in game_ids}

    total_games = len(games)
    print(f"Total games to test: {total_games}")

    # Метрики
    total_samples = 0
    correct_action_type = 0
    correct_card = 0  # Для attack и defend
    correct_state = 0  # Для state
    invalid_samples = 0

    for game_id, snapshots in games.items():
        print(f"\nTesting game {game_id}...")
        snapshots = sorted(snapshots, key=lambda x: json.loads(x['snapshot'])['timestamp'])

        try:
            my_id = json.loads(snapshots[0]['snapshot'])['players'][0]['id']
        except (KeyError, IndexError):
            print(f"Skipping game {game_id}: invalid player data")
            continue

        for i in range(len(snapshots) - 1):
            try:
                # Получаем реальное действие
                real_action = get_action(snapshots[i], snapshots[i + 1], my_id)

                if real_action['type'] == 'invalid':
                    invalid_samples += 1
                    continue

                if real_action['type'] not in ['attack', 'defend', 'state', 'wait']:
                    print(f"Unknown real action type: {real_action['type']}")
                    continue

                # Делаем предсказание
                predicted_action = predict_action(model, snapshots[i], my_id, device=device)

                total_samples += 1

                # Сравниваем тип действия
                print(f"\nSnapshot {i} in game {game_id}:")
                print(f"Real action: {real_action}")
                print(f"Predicted action: {predicted_action}")

                if real_action['type'] == predicted_action['type']:
                    correct_action_type += 1
                    print("Action type: CORRECT")

                    # Дополнительно сравниваем детали
                    if real_action['type'] in ['attack', 'defend']:
                        real_card = real_action.get('card', '')
                        predicted_card = predicted_action.get('move', '')
                        if real_card == predicted_card:
                            correct_card += 1
                            print("Card: CORRECT")
                        else:
                            print(f"Card: WRONG (Real: {real_card}, Predicted: {predicted_card})")

                    elif real_action['type'] == 'state':
                        real_state = real_action.get('state', '')
                        predicted_state = predicted_action.get('state', '')
                        if real_state == predicted_state:
                            correct_state += 1
                            print("State: CORRECT")
                        else:
                            print(f"State: WRONG (Real: {real_state}, Predicted: {predicted_state})")
                else:
                    print(f"Action type: WRONG (Real: {real_action['type']}, Predicted: {predicted_action['type']})")

            except Exception as e:
                print(f"Error testing snapshot {i} in game {game_id}: {e}")
                continue

    # Вывод метрик
    print("\n=== Testing Summary ===")
    print(f"Total samples tested: {total_samples}")
    print(f"Invalid samples skipped: {invalid_samples}")
    if total_samples > 0:
        action_accuracy = (correct_action_type / total_samples) * 100
        print(f"Action type accuracy: {action_accuracy:.2f}% ({correct_action_type}/{total_samples})")

        attack_defend_samples = sum(1 for action in dataset['train']
                                   if 'type' in get_action(action, dataset['train'][0], my_id)
                                   and get_action(action, dataset['train'][0], my_id)['type'] in ['attack', 'defend'])
        if attack_defend_samples > 0:
            card_accuracy = (correct_card / attack_defend_samples) * 100
            print(f"Card accuracy (attack/defend): {card_accuracy:.2f}% ({correct_card}/{attack_defend_samples})")

        state_samples = sum(1 for action in dataset['train']
                           if 'type' in get_action(action, dataset['train'][0], my_id)
                           and get_action(action, dataset['train'][0], my_id)['type'] == 'state')
        if state_samples > 0:
            state_accuracy = (correct_state / state_samples) * 100
            print(f"State accuracy (state actions): {state_accuracy:.2f}% ({correct_state}/{state_samples})")
    else:
        print("No valid samples to evaluate.")

if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Тестируем на 10 играх для начала (можно увеличить или убрать max_games)
    test_model(max_games=10, device=device)

Using device: cpu
Loading dataset...
Dataset loaded.
Model weights loaded from durak_model.pth
Total games to test: 10

Testing game 8219abbc-7ec9-48f5-b9be-63374c227d6d...
action_logits: tensor([[  6.2217, -44.8036,  -7.4920,  -6.2623]])

Snapshot 0 in game 8219abbc-7ec9-48f5-b9be-63374c227d6d:
Real action: {'type': 'attack', 'card_idx': 6, 'card': '10D'}
Predicted action: {'type': 'attack', 'move': '10D'}
Action type: CORRECT
Card: CORRECT
action_logits: tensor([[ -2.8293, -51.1464,  -1.5825,   2.1460]])

Snapshot 1 in game 8219abbc-7ec9-48f5-b9be-63374c227d6d:
Real action: {'type': 'wait'}
Predicted action: {'type': 'wait'}
Action type: CORRECT
State change detected: attack -> bat
action_logits: tensor([[ -2.4544, -72.2606,   1.5098, -11.6610]])

Snapshot 2 in game 8219abbc-7ec9-48f5-b9be-63374c227d6d:
Real action: {'type': 'state', 'state_idx': 0, 'state': 'bat'}
Predicted action: {'type': 'state', 'state': 'bat'}
Action type: CORRECT
State: CORRECT
action_logits: tensor([[ -101.86