In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [None]:
class ValueNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.board_net = nn.Sequential(
            nn.Conv2d(in_channels=15, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU()
        )

        self.cards_net = nn.Sequential(
            nn.Linear(14,32),
            nn.Linear(32,32)
        )

        # score diff as input. I.e (player_score - opp_score) / MAX_SCORE
        self.score_net = nn.Sequential(
            nn.Linear(1, 8),
            nn.ReLU()
        )

        # stage = max(player_score, opp_score) / MAX_SCORE
        self.stage_net = nn.Sequential(
            nn.Linear(1, 8),
            nn.ReLU()   
        )

        self.fusion_net = nn.Sequential(
            nn.Linear(64+32+8+8, 128), # maybe add 1 more layer or output 128 to 256? (if underfit)
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Tanh()
        )

    def forward(self, board, cards, score, stage):
        board_feat = self.board_net(board)          # (B, 64, 13, 13)
        board_feat = torch.amax(board_feat, dim=(2, 3))  # max pooling

        cards_feat = self.cards_net(cards)           # (B, 32)
        score_feat = self.score_net(score)          # (B, 8)
        stage_feat = self.stage_net(stage)          # (B, 8)

        # Fusion
        fused = torch.cat([board_feat, cards_feat, score_feat, stage_feat], dim=1)

        value = self.fusion_net(fused)
        return value

In [31]:
MAX_SCORE = 20.0

def train_value_net(model, dataloader, optimizer, device, epochs=10):
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0

        for board, cards, target, player_score, opp_score in dataloader:
            board = board.to(device)
            cards = cards.to(device)
            target = target.to(device).float().view(-1, 1)
            score = ((player_score - opp_score) / MAX_SCORE).to(device).float().view(-1, 1)
            stage = (torch.max(player_score, opp_score) / MAX_SCORE).to(device).float().view(-1, 1)

            pred = model(board, cards, score, stage)   # (B, 1)
            loss = F.mse_loss(pred, target)
            # or: loss = F.smooth_l1_loss(pred, target)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")


In [36]:
def eval_value_net(model, dataloader, device):
    model.eval() 
    total_loss = 0.0

    with torch.no_grad():  # disable gradient computation
        for board, cards, target, player_score, opp_score in dataloader:
            board = board.to(device)
            cards = cards.to(device)
            target_ = target.to(device).float().view(-1, 1)
            score = ((player_score - opp_score) / MAX_SCORE).to(device).float().view(-1, 1)
            stage = (torch.max(player_score, opp_score) / MAX_SCORE).to(device).float().view(-1, 1)

            pred = model(board, cards, score, stage)
            loss = F.mse_loss(pred, target_)

            total_loss += loss.item()

            print(f"pred: {pred.detach().cpu().numpy()[0][0]}, target: {target.detach().cpu()[0]}")
           
            
    avg_loss = total_loss / len(dataloader)
    return avg_loss

def eval_value_net2(model, dataloader, device):
    model.eval() 
    all_preds = []
    all_targets = []    

    with torch.no_grad():  # disable gradient computation
        for board, cards, target, player_score, opp_score in dataloader:
            board = board.to(device)
            cards = cards.to(device)
            target = target.to(device).float().view(-1, 1)
            score = ((player_score - opp_score) / MAX_SCORE).to(device).float().view(-1, 1)
            stage = (torch.max(player_score, opp_score) / MAX_SCORE).to(device).float().view(-1, 1)

            pred = model(board, cards, score, stage)
            all_preds.append(pred.cpu())
            all_targets.append(target.cpu())
    
    preds = torch.cat(all_preds).view(-1)
    targets = torch.cat(all_targets).view(-1)
            
    corr = torch.corrcoef(torch.stack([preds, targets]))[0, 1]
    print("Correlation:", corr.item())

    correct_dir = ((preds[-1] * targets[-1]) > 0).float()  # True if same sign
    accuracy = correct_dir.mean().item()
    print("Directional correctness:", accuracy)


In [21]:
SYMBOL_TO_CHANNEL = {
    '0': 0, '1': 1, '2': 2, '3': 3, '4': 4,
    '5': 5, '6': 6, '7': 7, '8': 8, '9': 9,
    '+': 10,
    '-': 11,
    '/': 12,
    'x': 13,
    '': 14
}

class GameOf10Dataset(Dataset):
    def __init__(self, dataset, stage='full'):
        boards = []
        cards = []
        outcomes = []
        player_score = []
        opp_score = []

        for i in range(len(dataset['boards'])):
            if stage == 'late' and (dataset['player_score'][i] > 15 or dataset['opp_score'][i] > 15):
                boards.append(dataset['boards'][i])
                cards.append(dataset['cards'][i])
                outcomes.append(dataset['outcomes'][i])
                player_score.append(dataset['player_score'][i])
                opp_score.append(dataset['opp_score'][i])
            if stage == 'mid' and (8 < dataset['player_score'][i] < 15 or 8 < dataset['opp_score'][i] < 15):
                boards.append(dataset['boards'][i])
                cards.append(dataset['cards'][i])
                outcomes.append(dataset['outcomes'][i])
                player_score.append(dataset['player_score'][i])
                opp_score.append(dataset['opp_score'][i]) 
            if stage == 'early' and (0 < dataset['player_score'][i] < 8 or 0 < dataset['opp_score'][i] < 8):
                boards.append(dataset['boards'][i])
                cards.append(dataset['cards'][i])
                outcomes.append(dataset['outcomes'][i])
                player_score.append(dataset['player_score'][i])
                opp_score.append(dataset['opp_score'][i])
        
        if len(boards) == 0:
            boards = dataset['boards']
            cards = dataset['cards']
            outcomes = dataset['outcomes']
            player_score = dataset['player_score']
            opp_score = dataset['opp_score']

        self.n = len(boards)
        self.boards = torch.zeros(self.n, 15, 13, 13, dtype=torch.float32)
        for k, board in enumerate(boards):
            for i in range(13):
                for j in range(13):
                    cell = board[i][j]
                    ch = SYMBOL_TO_CHANNEL[cell]
                    self.boards[k, ch, i, j] = 1.0
    
        self.cards = torch.zeros(self.n, 14, dtype=torch.float32)
        for k, user_cards in enumerate(cards):
            for card in user_cards:
                self.cards[k, SYMBOL_TO_CHANNEL[card]]+=1.
                
        self.outcomes = outcomes
        self.player_score = player_score
        self.opp_score = opp_score

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        board = self.boards[idx]          # (13, 13)
        cards = self.cards[idx]           # (13,)
        outcome = self.outcomes[idx]
        player_score = self.player_score[idx]
        opp_score = self.opp_score[idx]

        return board.float(), cards.float(), outcome, player_score, opp_score


In [33]:
dataset = torch.load("nn_data/2026-01-09/data_16-59-40.pt")
dataset = GameOf10Dataset(dataset)
train_loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ValueNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train_value_net(
    model=model,
    dataloader=train_loader,
    optimizer=optimizer,
    device=device,
    epochs=20
)

Epoch 1: loss = 0.7867
Epoch 2: loss = 0.5557
Epoch 3: loss = 0.5430
Epoch 4: loss = 0.5188
Epoch 5: loss = 0.4632
Epoch 6: loss = 0.4138
Epoch 7: loss = 0.3713
Epoch 8: loss = 0.3522
Epoch 9: loss = 0.3229
Epoch 10: loss = 0.2951
Epoch 11: loss = 0.2788
Epoch 12: loss = 0.2536
Epoch 13: loss = 0.2314
Epoch 14: loss = 0.2106
Epoch 15: loss = 0.2020
Epoch 16: loss = 0.1858
Epoch 17: loss = 0.1604
Epoch 18: loss = 0.1562
Epoch 19: loss = 0.1380
Epoch 20: loss = 0.1345


In [7]:
def print_board(board):
    cols = len(board[0])

    print("+" + "---+" * cols)
    for row in board:
        print("| " + " | ".join(str(cell) for cell in row) + " |")
        print("+" + "---+" * cols)

In [48]:
dataset = torch.load("nn_data/2026-01-09/data_19-11-20.pt")

# k = len(dataset['boards']) - 1
# print_board(dataset['boards'][k])
# print(dataset['cards'][k])
# print(dataset['player_score'][k].detach().numpy())
# print(dataset['opp_score'][k].detach().numpy())

dataset = GameOf10Dataset(dataset, 'early')
val_loader = DataLoader(
    dataset,
    batch_size=1,
    # shuffle=True
)

# val_loss = eval_value_net(model, val_loader, device)
# print(f"Validation loss: {val_loss:.4f}")

eval_value_net2(model, val_loader, device)



Correlation: -0.21922534704208374
Directional correctness: 1.0


In [None]:
# torch.save(model.state_dict(), "model_weights.pth")