In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

**Value Network**

This neural network predicts the likelihood of winning the game given the current state. It takes four inputs:

* **Board**: the 13×13 game board, processed with a 2-layer CNN (output 64 features via max-pooling)
* **Cards**: player’s current cards, processed with a small MLP (14 → 32 features)
* **Score**: normalized score difference `(player_score - opp_score) / MAX_SCORE`, processed with a small MLP (1 → 8 features)
* **Stage**: normalized game stage `max(player_score, opp_score) / MAX_SCORE`, processed with a small MLP (1 → 8 features)

The features from all inputs are concatenated (64+32+8+8 = 112) and fed into a **fusion network**: fully connected layers with ReLU and a final `tanh` activation, producing a scalar value in [-1, 1] representing the predicted game outcome.

In [15]:
# uncomment the one that you want to use
# from model_V1 import ValueNet  
from model_V2 import ValueNet

The **value network** is trained over a specified number of epochs. For each batch, we feeds the input into the model to predict the game outcome. Use the MSE between predictions and targets as the loss. Gradients are backpropagated, and the optimizer updates the model parameters. The function prints the average loss per epoch to monitor training progress.

In [3]:
MAX_SCORE = 20.0

def train_value_net(model, dataloader, optimizer, device, epochs=10):
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0

        for board, cards, target, player_score, opp_score in dataloader:
            board = board.to(device)
            cards = cards.to(device)
            target = target.to(device).float().view(-1, 1)
            score = ((player_score - opp_score) / MAX_SCORE).to(device).float().view(-1, 1)
            stage = (torch.max(player_score, opp_score) / MAX_SCORE).to(device).float().view(-1, 1)

            pred = model(board, cards, score, stage)   # (B, 1)
            loss = F.mse_loss(pred, target)
            # or: loss = F.smooth_l1_loss(pred, target)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")


The **value network** can be evaluated in 2 ways:
1. `eval_value_net` that computes the **mean squared error (MSE)** for each data point and returns the average across the dataset. This helps us see **how far the model’s predictions are from the actual outcomes** in absolute terms.

2. `eval_value_net2` that measures **correlation** between all predictions and targets, as well as **directional correctness**—the fraction of cases where the model correctly predicts the winner (or the trend). This tells us whether the model is **ranking game states correctly** and predicting the **right direction**, which is often more important than the exact value.

In [4]:
def eval_value_net(model, dataloader, device):
    model.eval() 
    total_loss = 0.0

    with torch.no_grad():  # disable gradient computation
        for board, cards, target, player_score, opp_score in dataloader:
            board = board.to(device)
            cards = cards.to(device)
            target_ = target.to(device).float().view(-1, 1)
            score = ((player_score - opp_score) / MAX_SCORE).to(device).float().view(-1, 1)
            stage = (torch.max(player_score, opp_score) / MAX_SCORE).to(device).float().view(-1, 1)

            pred = model(board, cards, score, stage)
            loss = F.mse_loss(pred, target_)

            total_loss += loss.item()
            if (target.detach().cpu()[0] > 0):
                print(f"pred: {pred.detach().cpu().numpy()[0][0]}, target: {target.detach().cpu()[0]}")
           
            
    avg_loss = total_loss / len(dataloader)
    return avg_loss

def eval_value_net2(model, dataloader, device):
    model.eval() 
    all_preds = []
    all_targets = []    

    with torch.no_grad():  # disable gradient computation
        for board, cards, target, player_score, opp_score in dataloader:
            board = board.to(device)
            cards = cards.to(device)
            target = target.to(device).float().view(-1, 1)
            score = ((player_score - opp_score) / MAX_SCORE).to(device).float().view(-1, 1)
            stage = (torch.max(player_score, opp_score) / MAX_SCORE).to(device).float().view(-1, 1)

            pred = model(board, cards, score, stage)
            all_preds.append(pred.cpu())
            all_targets.append(target.cpu())
    
    preds = torch.cat(all_preds).view(-1)
    targets = torch.cat(all_targets).view(-1)
            
    corr = torch.corrcoef(torch.stack([preds, targets]))[0, 1]
    print("Correlation:", corr.item())

    correct_dir = ((preds * targets) > 0).float()  # True if same sign
    accuracy = correct_dir.mean().item()
    print("Directional correctness:", accuracy)


**GameOf10Dataset**

A PyTorch dataset that converts raw game states into tensors for the value network. Each sample includes a 13×13 board encoded as a 15-channel one-hot tensor, a 14-dimensional card vector, the outcome (likelihood of winning), and both players’ scores.  

The dataset can optionally filter by game stage (`early`, `mid`, `late`) to focus on different parts of the game.


In [5]:
SYMBOL_TO_CHANNEL = {
    '0': 0, '1': 1, '2': 2, '3': 3, '4': 4,
    '5': 5, '6': 6, '7': 7, '8': 8, '9': 9,
    '+': 10,
    '-': 11,
    '/': 12,
    'x': 13,
    '': 14
}

class GameOf10Dataset(Dataset):
    def __init__(self, dataset, stage='full'):
        boards = []
        cards = []
        outcomes = []
        player_score = []
        opp_score = []

        for i in range(len(dataset['boards'])):
            if stage == 'late' and (dataset['player_score'][i] > 15 or dataset['opp_score'][i] > 15):
                boards.append(dataset['boards'][i])
                cards.append(dataset['cards'][i])
                outcomes.append(dataset['outcomes'][i])
                player_score.append(dataset['player_score'][i])
                opp_score.append(dataset['opp_score'][i])
            if stage == 'mid' and (8 < dataset['player_score'][i] < 15 or 8 < dataset['opp_score'][i] < 15):
                boards.append(dataset['boards'][i])
                cards.append(dataset['cards'][i])
                outcomes.append(dataset['outcomes'][i])
                player_score.append(dataset['player_score'][i])
                opp_score.append(dataset['opp_score'][i]) 
            if stage == 'early' and (0 < dataset['player_score'][i] < 8 or 0 < dataset['opp_score'][i] < 8):
                boards.append(dataset['boards'][i])
                cards.append(dataset['cards'][i])
                outcomes.append(dataset['outcomes'][i])
                player_score.append(dataset['player_score'][i])
                opp_score.append(dataset['opp_score'][i])
        
        if len(boards) == 0:
            boards = dataset['boards']
            cards = dataset['cards']
            outcomes = dataset['outcomes']
            player_score = dataset['player_score']
            opp_score = dataset['opp_score']

        self.n = len(boards)
        self.boards = torch.zeros(self.n, 15, 13, 13, dtype=torch.float32)
        for k, board in enumerate(boards):
            for i in range(13):
                for j in range(13):
                    cell = board[i][j]
                    ch = SYMBOL_TO_CHANNEL[cell]
                    self.boards[k, ch, i, j] = 1.0
    
        self.cards = torch.zeros(self.n, 14, dtype=torch.float32)
        for k, user_cards in enumerate(cards):
            for card in user_cards:
                self.cards[k, SYMBOL_TO_CHANNEL[card]]+=1.
                
        self.outcomes = outcomes
        self.player_score = player_score
        self.opp_score = opp_score

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        board = self.boards[idx]          # (13, 13)
        cards = self.cards[idx]           # (13,)
        outcome = self.outcomes[idx]
        player_score = self.player_score[idx]
        opp_score = self.opp_score[idx]

        return board.float(), cards.float(), outcome, player_score, opp_score


Example of how to train a new model

In [None]:
dataset = torch.load("nn_data/2026-01-10/data_11-48-45.pt")
dataset = GameOf10Dataset(dataset)
train_loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ValueNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train_value_net(
    model=model,
    dataloader=train_loader,
    optimizer=optimizer,
    device=device,
    epochs=5
)

# Uncomment this line to store the model locally
# torch.save(model.state_dict(), "model_weightsV3.pth")

Epoch 1: loss = 0.6920
Epoch 2: loss = 0.5992
Epoch 3: loss = 0.5077
Epoch 4: loss = 0.4172
Epoch 5: loss = 0.3581


Example of how to load parameters from a local file (remember to import the correct version of ValueNet)

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ValueNet().to(device)
model.load_state_dict(torch.load("model_weightsV2.pth"))

<All keys matched successfully>

Example of how to evaluate the model

In [None]:
def print_board(board):
    cols = len(board[0])

    print("+" + "---+" * cols)
    for row in board:
        print("| " + " | ".join(str(cell) for cell in row) + " |")
        print("+" + "---+" * cols)

dataset = torch.load("nn_data/2026-01-09/data_16-59-40.pt")

# Uncomment the following code to see the data
# k = len(dataset['boards']) - 1
# print_board(dataset['boards'][k])
# print(dataset['cards'][k])
# print(dataset['player_score'][k].detach().numpy())
# print(dataset['opp_score'][k].detach().numpy())

dataset = GameOf10Dataset(dataset, 'mid')
val_loader = DataLoader(
    dataset,
    batch_size=1,
)

# Uncomment the following code to evaluate using MSE
# val_loss = eval_value_net(model, val_loader, device)
# print(f"Validation loss: {val_loss:.4f}")

eval_value_net2(model, val_loader, device)



Correlation: 0.722737729549408
Directional correctness: 0.8319767713546753


Better way to evaluate the model using correlation + direction and stage of the game

In [17]:
def eval(stage, data_path):
    print(stage)
    stage = stage.lower()
    dataset = torch.load(data_path)


    # k = len(dataset['boards']) - 1
    # print_board(dataset['boards'][k])
    # print(dataset['cards'][k])
    # print(dataset['player_score'][k].detach().numpy())
    # print(dataset['opp_score'][k].detach().numpy())

    dataset = GameOf10Dataset(dataset, stage)
    val_loader = DataLoader(
        dataset,
        batch_size=1,
    )

    # Uncomment the following code to see the MSE loss
    # val_loss = eval_value_net(model, val_loader, device)
    # print(f"Validation loss: {val_loss:.4f}")

    eval_value_net2(model, val_loader, device)
    print()

eval('Early', "nn_data/2026-01-10/data_13-56-47.pt")
eval('Mid', "nn_data/2026-01-10/data_13-56-47.pt")
eval('Late', "nn_data/2026-01-10/data_13-56-47.pt")
eval('Full', "nn_data/2026-01-10/data_13-56-47.pt")

Early
Correlation: 0.6745660901069641
Directional correctness: 0.7666666507720947

Mid
Correlation: 0.8631322979927063
Directional correctness: 0.9166666865348816

Late
Correlation: 0.9092518091201782
Directional correctness: 0.9411764740943909

Full
Correlation: 0.7298254370689392
Directional correctness: 0.807692289352417

