In [3]:
from hnefatal.game import *
import time

def benchmark_function(func, iterations=100):
    start_time = time.perf_counter()
    for _ in range(iterations):
        func()
    end_time = time.perf_counter()
    print(f"Average time per iteration: {(end_time - start_time) / iterations * 1000 :.4f} ms")

import torch
from torch import nn
from sklearn.preprocessing import OneHotEncoder
import numpy as np

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

BOARD_WIDTH = 13


Using cpu device


In [None]:
import torch
from torch import nn
from sklearn.preprocessing import OneHotEncoder
import numpy as np

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

BOARD_WIDTH = 13

# The Neural Network is designed with a OneHot-Encoded input where each piece type is represented as a separate feature.
# The output is a tensor of shape (BOARD_WIDTH, BOARD_WIDTH, 4, DISTANCE_SIZE) representing the policy for each piece's movement in four directions (up, down, left, right) and distances from 1 to DISTANCE_SIZE.

class SimpleDefendersModel(nn.Module):
    def __init__(self):
        super(SimpleDefendersModel, self).__init__()
        self.flatten = nn.Flatten()
        self.nn = nn.Sequential(
            nn.Linear(BOARD_WIDTH**2*4, 1024),
            nn.ReLU(),
            # nn.Linear(512, 512),
            # nn.ReLU(),
            nn.Linear(1024, BOARD_WIDTH**2 * 4 * (BOARD_WIDTH - 1)),  # Output for each direction and a length of maximum BOARD_WIDTH - 1
            nn.Softmax(dim=-1)  # Softmax to get probabilities for each move
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.nn(x)
        return logits
    
model = SimpleDefendersModel().to(device)
print(model)

encoder = OneHotEncoder()
encoder.fit([[Piece.EMPTY.value], [Piece.ATTACKER.value], [Piece.DEFENDER.value], [Piece.KING.value]])

def apply_direction(row, col, direction, distance):
    if direction == 0:   # Up
        return row - distance, col
    elif direction == 1: # Down
        return row + distance, col
    elif direction == 2: # Left
        return row, col - distance
    elif direction == 3: # Right
        return row, col + distance

def find_ai_move(game, player):    
    # Flatten the board and get the value of each piece
    flat_board = [piece.value for row in game.board for piece in row]
    flat_board = torch.tensor(flat_board).reshape(-1, 1)
    x = encoder.transform(flat_board)
    x = torch.tensor(x.toarray(), dtype=torch.float32).to(device)

    # Get the model's prediction
    logits = model(x.flatten().unsqueeze(0))

    flat_policy = logits.reshape((BOARD_WIDTH, BOARD_WIDTH, 4, BOARD_WIDTH - 1))
    
    moves = []
    for row in range(BOARD_WIDTH):
        for col in range(BOARD_WIDTH):
            for direction in range(4):
                for distance in range(BOARD_WIDTH - 1):
                    prob = flat_policy[row, col, direction, distance]
                    # if prob > 0:
                    moves.append(((row, col, direction, distance + 1), prob))

    moves.sort(key=lambda x: x[1], reverse=True)

    # Filter moves to only valid ones
    valid_moves = []
    for move, prob in moves:
        from_row, from_col, direction, steps = move
        to_row, to_col = apply_direction(from_row, from_col, direction, steps)
        if not all(0 <= pos < BOARD_WIDTH for pos in (from_row, from_col, to_row, to_col)):
            continue

        from_pos = Coord(from_row, from_col)
        to_pos = Coord(to_row, to_col)

        if game.is_valid_move(from_pos, to_pos):
            return Move(from_pos, to_pos)

    assert False, "No valid moves found"

game = Game()
game.fill_board_13_by_13()
def test_ai_move():
    move = find_ai_move(game, Player.DEFENDER)

benchmark_function(test_ai_move, 4)

Using cpu device
SimpleDefendersModel(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (nn): Sequential(
    (0): Linear(in_features=676, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=8112, bias=True)
    (3): Softmax(dim=-1)
  )
)
Average time per iteration: 334.0328 ms


In [48]:
import torch
from torch import nn
from sklearn.preprocessing import OneHotEncoder
import numpy as np

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

BOARD_WIDTH = 13

# The Neural Network is designed with a OneHot-Encoded input where each piece type is represented as a separate feature.
# The output is a tensor of shape (BOARD_WIDTH, BOARD_WIDTH, 4, BOARD_WIDTH - 1) representing the policy for each piece's movement in four directions (up, down, left, right) and distances from 1 to BOARD_WIDTH - 1.

class SimpleDefendersModel(nn.Module):
    def __init__(self):
        super(SimpleDefendersModel, self).__init__()
        self.flatten = nn.Flatten()
        self.nn = nn.Sequential(
            nn.Linear(BOARD_WIDTH**2*4, 1024),
            nn.ReLU(),
            # nn.Linear(512, 512),
            # nn.ReLU(),
            nn.Linear(1024, BOARD_WIDTH**2 * 4 * (BOARD_WIDTH - 1)),  # Output for each direction and a length of maximum BOARD_WIDTH - 1
            nn.Softmax(dim=-1)  # Softmax to get probabilities for each move
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.nn(x)
        return logits
    
model = SimpleDefendersModel().to(device)
print(model)

encoder = OneHotEncoder()
encoder.fit([[Piece.EMPTY.value], [Piece.ATTACKER.value], [Piece.DEFENDER.value], [Piece.KING.value]])

def apply_direction(row, col, direction, distance):
    if direction == 0:   # Up
        return row - distance, col
    elif direction == 1: # Down
        return row + distance, col
    elif direction == 2: # Left
        return row, col - distance
    elif direction == 3: # Right
        return row, col + distance

def find_ai_move(game, player):    
    # Flatten the board and get the value of each piece
    flat_board = torch.tensor(game.board, dtype=torch.float32).flatten().reshape(-1, 1)
    x = encoder.transform(flat_board)
    x = torch.tensor(x.toarray(), dtype=torch.float32).to(device)

    # Get the model's prediction
    logits = model(x.flatten().unsqueeze(0))

    flat_policy = logits.reshape((BOARD_WIDTH, BOARD_WIDTH, 4, BOARD_WIDTH - 1))
    
    moves = []
    for row in range(BOARD_WIDTH):
        for col in range(BOARD_WIDTH):
            for direction in range(4):
                for distance in range(BOARD_WIDTH - 1):
                    prob = flat_policy[row, col, direction, distance]
                    # if prob > 0:
                    moves.append(((row, col, direction, distance + 1), prob))

    moves.sort(key=lambda x: x[1], reverse=True)

    # Filter moves to only valid ones
    valid_moves = []
    for move, prob in moves:
        from_row, from_col, direction, steps = move
        to_row, to_col = apply_direction(from_row, from_col, direction, steps)
        if not all(0 <= pos < BOARD_WIDTH for pos in (from_row, from_col, to_row, to_col)):
            continue

        from_pos = Coord(from_row, from_col)
        to_pos = Coord(to_row, to_col)

        if game.is_valid_move(from_pos, to_pos):
            return Move(from_pos, to_pos)

    assert False, "No valid moves found"

game = Game()
game.fill_board_13_by_13()
def test_ai_move():
    move = find_ai_move(game, Player.DEFENDER)
    print(f"Fast AI Move: {move}")

benchmark_function(test_ai_move, 4)

Using cpu device
SimpleDefendersModel(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (nn): Sequential(
    (0): Linear(in_features=676, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=8112, bias=True)
    (3): Softmax(dim=-1)
  )
)
Fast AI Move: Move from (4, 6) to (4, 4)
Fast AI Move: Move from (4, 6) to (4, 4)
Fast AI Move: Move from (4, 6) to (4, 4)
Fast AI Move: Move from (4, 6) to (4, 4)
Average time per iteration: 379.7384 ms


Bottleneck er moves.sort, der gør en forskel af cirka 324 ms altså 72%

In [52]:
import torch
from torch import nn
from sklearn.preprocessing import OneHotEncoder
import numpy as np

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

BOARD_WIDTH = 13

# The Neural Network is designed with a OneHot-Encoded input where each piece type is represented as a separate feature.
# The output is a tensor of shape (BOARD_WIDTH, BOARD_WIDTH, 4, BOARD_WIDTH - 1) representing the policy for each piece's movement in four directions (up, down, left, right) and distances from 1 to BOARD_WIDTH - 1.

class SimpleDefendersModel(nn.Module):
    def __init__(self):
        super(SimpleDefendersModel, self).__init__()
        # self.flatten = nn.Flatten()
        self.encode = nn.functional.one_hot
        self.nn = nn.Sequential(
            nn.Linear(BOARD_WIDTH*BOARD_WIDTH*4, 1024),
            nn.ReLU(),
            # nn.Linear(512, 512),
            # nn.ReLU(),
            nn.Linear(1024, BOARD_WIDTH**2 * 4 * (BOARD_WIDTH - 1)),  # Output for each direction and a length of maximum BOARD_WIDTH - 1
            nn.Softmax(dim=-1)  # Softmax to get probabilities for each move
        )

    def forward(self, x):
        # x = self.flatten(x)
        x = self.encode(x, num_classes=4)
        x = x.type(torch.float32).flatten().reshape(1,-1)  # Ensure the input is float32
        logits = self.nn(x)
        return logits
    
model = SimpleDefendersModel().to(device)
print(model)

def apply_direction(row, col, direction, distance):
    if direction == 0:   # Up
        return row - distance, col
    elif direction == 1: # Down
        return row + distance, col
    elif direction == 2: # Left
        return row, col - distance
    elif direction == 3: # Right
        return row, col + distance

def find_ai_move(game, player):    
    # Flatten the board and get the value of each piece
    
    flat_board = [piece.value + 1 for row in game.board for piece in row]
    x = torch.tensor(flat_board, dtype=torch.long).flatten().reshape(-1, 1)

    # Get the model's prediction
    logits = model(x)

    flat_policy = logits.reshape((BOARD_WIDTH, BOARD_WIDTH, 4, BOARD_WIDTH - 1))
    
    best_move = None
    best_move_prob = -1

    for from_row in range(BOARD_WIDTH):
        for from_col in range(BOARD_WIDTH):
            for direction in range(4):
                for distance in range(BOARD_WIDTH - 1):
                    
                    to_row, to_col = apply_direction(from_row, from_col, direction, distance)
                    if not all(0 <= pos < BOARD_WIDTH for pos in (from_row, from_col, to_row, to_col)):
                        continue

                    from_pos = Coord(from_row, from_col)
                    to_pos = Coord(to_row, to_col)

                    # if game.is_valid_move(from_pos, to_pos):
                    if True:
                        prob = flat_policy[from_row, from_col, direction, distance].item()

                        if prob > best_move_prob:
                            best_move = Move(from_pos, to_pos)
                            best_move_prob = prob

    assert best_move != None, "No valid moves found"

    return best_move

game = Game()
game.fill_board_13_by_13()
def test_ai_move():
    move = find_ai_move(game, Player.DEFENDER)
    print(f"Fast AI Move: {move}")

benchmark_function(test_ai_move, 8)

Using cpu device
SimpleDefendersModel(
  (nn): Sequential(
    (0): Linear(in_features=676, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=8112, bias=True)
    (3): Softmax(dim=-1)
  )
)
Fast AI Move: Move from (0, 3) to (3, 3)
Fast AI Move: Move from (0, 3) to (3, 3)
Fast AI Move: Move from (0, 3) to (3, 3)
Fast AI Move: Move from (0, 3) to (3, 3)
Fast AI Move: Move from (0, 3) to (3, 3)
Fast AI Move: Move from (0, 3) to (3, 3)
Fast AI Move: Move from (0, 3) to (3, 3)
Fast AI Move: Move from (0, 3) to (3, 3)
Average time per iteration: 265.6174 ms


En forøgelse i fart på 144%, eller 44 ms ved at kun bruge et enkelt loop
Lige nu er 17% af tiden brugt på at validere mulige træk
Avg: 265 ms

In [23]:
import torch

class SimpleDefendersModel(nn.Module):
    def __init__(self):
        super(SimpleDefendersModel, self).__init__()
        # self.flatten = nn.Flatten()
        self.encode = nn.functional.one_hot
        self.nn = nn.Sequential(
            nn.Linear(BOARD_WIDTH*BOARD_WIDTH*4, 1024),
            nn.ReLU(),
            # nn.Linear(512, 512),
            # nn.ReLU(),
            nn.Linear(1024, BOARD_WIDTH**2 * 4 * (BOARD_WIDTH - 1)),  # Output for each direction and a length of maximum BOARD_WIDTH - 1
            nn.Softmax(dim=-1)  # Softmax to get probabilities for each move
        )

    def forward(self, x):
        # x = self.flatten(x)
        x = self.encode(x, num_classes=4)
        x = x.type(torch.float32).flatten().reshape(1,-1)  # Ensure the input is float32
        logits = self.nn(x)
        return logits
    
model = SimpleDefendersModel().to(device)

def find_ai_move_fast(game, player):
    # Flatten the board and get the value of each piece
    flat_board = [piece.value + 1 for row in game.board for piece in row]
    x = torch.tensor(flat_board, dtype=torch.long).reshape(-1, 1)

    # Model prediction
    logits = model(x)
    flat_policy = logits.reshape((BOARD_WIDTH, BOARD_WIDTH, 4, BOARD_WIDTH - 1))

    # Precompute all possible moves and their probabilities in a single pass
    from_rows, from_cols = torch.meshgrid(torch.arange(BOARD_WIDTH), torch.arange(BOARD_WIDTH), indexing='ij')
    directions = torch.arange(4)
    distances = torch.arange(BOARD_WIDTH - 1)

    # Expand to all combinations
    from_rows = from_rows.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 4, BOARD_WIDTH - 1)
    from_cols = from_cols.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 4, BOARD_WIDTH - 1)
    directions = directions.view(1, 1, 4, 1).expand(BOARD_WIDTH, BOARD_WIDTH, 4, BOARD_WIDTH - 1)
    distances = distances.view(1, 1, 1, BOARD_WIDTH - 1).expand(BOARD_WIDTH, BOARD_WIDTH, 4, BOARD_WIDTH - 1)

    # Calculate destination coordinates
    to_rows = from_rows + torch.where(directions == 0, -distances, torch.where(directions == 1, distances, torch.zeros_like(distances)))
    to_cols = from_cols + torch.where(directions == 2, -distances, torch.where(directions == 3, distances, torch.zeros_like(distances)))

    # Mask for valid board positions
    valid_mask = (
        (from_rows >= 0) & (from_rows < BOARD_WIDTH) &
        (from_cols >= 0) & (from_cols < BOARD_WIDTH) &
        (to_rows >= 0) & (to_rows < BOARD_WIDTH) &
        (to_cols >= 0) & (to_cols < BOARD_WIDTH)
    )

    # Get probabilities and flatten
    probs = flat_policy * valid_mask
    probs_flat = probs.flatten()

    while True:

        idx = torch.argmax(probs_flat)
        if probs_flat[idx] == 0:
            assert False, "No valid moves found"

         # Recover indices
        unravel = torch.unravel_index(idx, probs.shape)
        from_row, from_col, direction, distance = [i.item() for i in unravel]
        to_row = to_rows[from_row, from_col, direction, distance].item()
        to_col = to_cols[from_row, from_col, direction, distance].item()

        from_pos = Coord(from_row, from_col)
        to_pos = Coord(to_row, to_col)

        if game.is_valid_move(from_pos, to_pos):
            return Move(from_pos, to_pos)
        else:
            # Set the probability of this move to zero and continue
            probs_flat[idx] = 0
            continue

game = Game()
game.fill_board_13_by_13()

def test_ai_move_fast():
    move = find_ai_move_fast(game, Player.DEFENDER)

benchmark_function(test_ai_move_fast, 100)

Average time per iteration: 7.0647 ms


Improvement: 53x
Avg: 7 ms

In [24]:
import torch

class SimpleDefendersModel(nn.Module):
    def __init__(self):
        super(SimpleDefendersModel, self).__init__()
        # self.flatten = nn.Flatten()
        self.encode = nn.functional.one_hot
        self.nn = nn.Sequential(
            nn.Linear(BOARD_WIDTH*BOARD_WIDTH*4, 1024),
            nn.ReLU(),
            # nn.Linear(512, 512),
            # nn.ReLU(),
            nn.Linear(1024, BOARD_WIDTH**2 * 4 * (BOARD_WIDTH - 1)),  # Output for each direction and a length of maximum BOARD_WIDTH - 1
            nn.Softmax(dim=-1)  # Softmax to get probabilities for each move
        )

    def forward(self, x):
        # x = self.flatten(x)
        x = self.encode(x, num_classes=4)
        x = x.type(torch.float32).flatten().reshape(1,-1)  # Ensure the input is float32
        logits = self.nn(x)
        return logits
    
model = SimpleDefendersModel().to(device)

# Precompute all possible moves and their probabilities in a single pass
from_rows, from_cols = torch.meshgrid(torch.arange(BOARD_WIDTH), torch.arange(BOARD_WIDTH), indexing='ij')
directions = torch.arange(4)
distances = torch.arange(BOARD_WIDTH - 1)

# Expand to all combinations
from_rows = from_rows.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 4, BOARD_WIDTH - 1)
from_cols = from_cols.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 4, BOARD_WIDTH - 1)
directions = directions.view(1, 1, 4, 1).expand(BOARD_WIDTH, BOARD_WIDTH, 4, BOARD_WIDTH - 1)
distances = distances.view(1, 1, 1, BOARD_WIDTH - 1).expand(BOARD_WIDTH, BOARD_WIDTH, 4, BOARD_WIDTH - 1)

# Calculate destination coordinates
to_rows = from_rows + torch.where(directions == 0, -distances, torch.where(directions == 1, distances, torch.zeros_like(distances)))
to_cols = from_cols + torch.where(directions == 2, -distances, torch.where(directions == 3, distances, torch.zeros_like(distances)))

# Mask for valid board positions
valid_mask = (
    (from_rows >= 0) & (from_rows < BOARD_WIDTH) &
    (from_cols >= 0) & (from_cols < BOARD_WIDTH) &
    (to_rows >= 0) & (to_rows < BOARD_WIDTH) &
    (to_cols >= 0) & (to_cols < BOARD_WIDTH)
)

def find_ai_move_fast(game, player):
    # Flatten the board and get the value of each piece
    flat_board = [piece.value + 1 for row in game.board for piece in row]
    x = torch.tensor(flat_board, dtype=torch.long).reshape(-1, 1)

    # Model prediction
    logits = model(x)
    flat_policy = logits.reshape((BOARD_WIDTH, BOARD_WIDTH, 4, BOARD_WIDTH - 1))

    # Get probabilities and flatten
    probs = flat_policy * valid_mask
    probs_flat = probs.flatten()

    while True:

        idx = torch.argmax(probs_flat)
        if probs_flat[idx] == 0:
            assert False, "No valid moves found"

         # Recover indices
        unravel = torch.unravel_index(idx, probs.shape)
        from_row, from_col, direction, distance = [i.item() for i in unravel]
        to_row = to_rows[from_row, from_col, direction, distance].item()
        to_col = to_cols[from_row, from_col, direction, distance].item()

        from_pos = Coord(from_row, from_col)
        to_pos = Coord(to_row, to_col)

        if game.is_valid_move(from_pos, to_pos):
            return Move(from_pos, to_pos)
        else:
            # Set the probability of this move to zero and continue
            probs_flat[idx] = 0
            continue

game = Game()
game.fill_board_13_by_13()

def test_ai_move_fast():
    move = find_ai_move_fast(game, Player.DEFENDER)

benchmark_function(test_ai_move_fast, 100)

Average time per iteration: 6.1787 ms


Avg: 6 ms
Here the valid_mask is precomputed

In [43]:
import torch

BOARD_WIDTH = 13
DISTANCE_SIZE = BOARD_WIDTH - 2
LABEL_SIZE = 4004 # Equal to valid_mask.sum()

# Precompute all possible moves and their probabilities in a single pass
from_rows, from_cols = torch.meshgrid(torch.arange(BOARD_WIDTH), torch.arange(BOARD_WIDTH), indexing='ij')
directions = torch.arange(4)
distances = torch.arange(1, DISTANCE_SIZE + 1)

# Expand to all combinations
from_rows = from_rows.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 4, DISTANCE_SIZE)
from_cols = from_cols.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 4, DISTANCE_SIZE)
directions = directions.view(1, 1, 4, 1).expand(BOARD_WIDTH, BOARD_WIDTH, 4, DISTANCE_SIZE)
distances = distances.view(1, 1, 1, DISTANCE_SIZE).expand(BOARD_WIDTH, BOARD_WIDTH, 4, DISTANCE_SIZE)

# Calculate destination coordinates
to_rows = from_rows + torch.where(directions == 0, -distances, torch.where(directions == 1, distances, torch.zeros_like(distances)))
to_cols = from_cols + torch.where(directions == 2, -distances, torch.where(directions == 3, distances, torch.zeros_like(distances)))

# Mask for valid board positions
valid_mask = (
    (from_rows >= 0) & (from_rows < BOARD_WIDTH) &
    (from_cols >= 0) & (from_cols < BOARD_WIDTH) &
    (to_rows >= 0) & (to_rows < BOARD_WIDTH) &
    (to_cols >= 0) & (to_cols < BOARD_WIDTH)
)

assert valid_mask.sum() == LABEL_SIZE, f"Expected {LABEL_SIZE} valid moves, but got {valid_mask.sum()}"

label_index_by_move = [Move(Coord(0, 0), Coord(0, 0))] * LABEL_SIZE

def encode_move_to_index(row, col, direction, distance):
    return (row) * BOARD_WIDTH * 4 * DISTANCE_SIZE + col * 4 * DISTANCE_SIZE + direction * DISTANCE_SIZE + distance

def apply_direction(row, col, direction, distance):
    if direction == 0:   # Up
        return row - distance, col
    elif direction == 1: # Down
        return row + distance, col
    elif direction == 2: # Left
        return row, col - distance
    elif direction == 3: # Right
        return row, col + distance

i = 0
for from_row in range(BOARD_WIDTH):
    for from_col in range(BOARD_WIDTH):
        for direction in range(4):
            for distance in range(DISTANCE_SIZE):
                to_row, to_col = apply_direction(from_row, from_col, direction, distance + 1)

                from_pos = Coord(from_row, from_col)
                to_pos = Coord(to_row, to_col)

                if valid_mask[from_row, from_col, direction, distance]:
                    label_index_by_move[i] = Move(from_pos, to_pos)
                    i += 1


class SimpleDefendersModel(nn.Module):
    def __init__(self):
        super(SimpleDefendersModel, self).__init__()
        # self.flatten = nn.Flatten()
        self.encode = nn.functional.one_hot
        self.nn = nn.Sequential(
            nn.Linear(BOARD_WIDTH*BOARD_WIDTH*4, 1024),
            nn.ReLU(),
            # nn.Linear(512, 512),
            # nn.ReLU(),
            nn.Linear(1024, LABEL_SIZE),
            nn.Softmax(dim=-1)  # Softmax to get probabilities for each move
        )

    def forward(self, x):
        # x = self.flatten(x)
        x = self.encode(x, num_classes=4)
        x = x.type(torch.float32).flatten().reshape(1,-1)
        logits = self.nn(x)
        return logits
    
model = SimpleDefendersModel().to(device)



def find_ai_move_fast(game, player):
    # Flatten the board and get the value of each piece
    flat_board = [piece.value + 1 for row in game.board for piece in row]
    x = torch.tensor(flat_board, dtype=torch.long).reshape(-1, 1)

    # Model prediction
    logits = model(x)
    logits = logits.flatten()

    while True:

        idx = torch.argmax(logits).item()
        if logits[idx] == 0:
            assert False, "No valid moves found"

        # Recover indices
        move = label_index_by_move[idx]

        if game.is_valid_move(move.from_pos, move.to_pos):
            return move
        else:
            # Set the probability of this move to zero and continue
            logits[idx] = 0
            continue

game = Game()
game.fill_board_13_by_13()

def test_ai_move_fast():
    move = find_ai_move_fast(game, Player.DEFENDER)

benchmark_function(test_ai_move_fast, 1000)

Average time per iteration: 1.7943 ms


Avg: 2.3 ms
By subtracting all impossible moves, the size of the output label can be reduced to 4004, 54% of the original size.
Besides the redundant distance zero was has been removed

Further improvements are: Caching the NN and multithreading

In [44]:
import torch

BOARD_WIDTH = 13
DISTANCE_SIZE = BOARD_WIDTH - 2
LABEL_SIZE = 4004 # Equal to valid_mask.sum()

# Precompute all possible moves and their probabilities in a single pass
from_rows, from_cols = torch.meshgrid(torch.arange(BOARD_WIDTH), torch.arange(BOARD_WIDTH), indexing='ij')
directions = torch.arange(4)
distances = torch.arange(1, DISTANCE_SIZE + 1)

# Expand to all combinations
from_rows = from_rows.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 4, DISTANCE_SIZE)
from_cols = from_cols.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 4, DISTANCE_SIZE)
directions = directions.view(1, 1, 4, 1).expand(BOARD_WIDTH, BOARD_WIDTH, 4, DISTANCE_SIZE)
distances = distances.view(1, 1, 1, DISTANCE_SIZE).expand(BOARD_WIDTH, BOARD_WIDTH, 4, DISTANCE_SIZE)

# Calculate destination coordinates
to_rows = from_rows + torch.where(directions == 0, -distances, torch.where(directions == 1, distances, torch.zeros_like(distances)))
to_cols = from_cols + torch.where(directions == 2, -distances, torch.where(directions == 3, distances, torch.zeros_like(distances)))

# Mask for valid board positions
valid_mask = (
    (from_rows >= 0) & (from_rows < BOARD_WIDTH) &
    (from_cols >= 0) & (from_cols < BOARD_WIDTH) &
    (to_rows >= 0) & (to_rows < BOARD_WIDTH) &
    (to_cols >= 0) & (to_cols < BOARD_WIDTH)
)

assert valid_mask.sum() == LABEL_SIZE, f"Expected {LABEL_SIZE} valid moves, but got {valid_mask.sum()}"

label_index_by_move = [Move(Coord(0, 0), Coord(0, 0))] * LABEL_SIZE

def encode_move_to_index(row, col, direction, distance):
    return (row) * BOARD_WIDTH * 4 * DISTANCE_SIZE + col * 4 * DISTANCE_SIZE + direction * DISTANCE_SIZE + distance

def apply_direction(row, col, direction, distance):
    if direction == 0:   # Up
        return row - distance, col
    elif direction == 1: # Down
        return row + distance, col
    elif direction == 2: # Left
        return row, col - distance
    elif direction == 3: # Right
        return row, col + distance

i = 0
for from_row in range(BOARD_WIDTH):
    for from_col in range(BOARD_WIDTH):
        for direction in range(4):
            for distance in range(DISTANCE_SIZE):
                to_row, to_col = apply_direction(from_row, from_col, direction, distance + 1)

                from_pos = Coord(from_row, from_col)
                to_pos = Coord(to_row, to_col)

                if valid_mask[from_row, from_col, direction, distance]:
                    label_index_by_move[i] = Move(from_pos, to_pos)
                    i += 1


class SimpleDefendersModel(nn.Module):
    def __init__(self):
        super(SimpleDefendersModel, self).__init__()
        # self.flatten = nn.Flatten()
        self.encode = nn.functional.one_hot
        self.nn = nn.Sequential(
            nn.Linear(BOARD_WIDTH*BOARD_WIDTH*4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, LABEL_SIZE),
            nn.Softmax(dim=-1)  # Softmax to get probabilities for each move
        )

    def forward(self, x):
        # x = self.flatten(x)
        x = self.encode(x, num_classes=4)
        x = x.type(torch.float32).flatten().reshape(1,-1)
        logits = self.nn(x)
        return logits
    
model = SimpleDefendersModel().to(device)



def find_ai_move_fast(game, player):
    # Flatten the board and get the value of each piece
    flat_board = [piece.value + 1 for row in game.board for piece in row]
    x = torch.tensor(flat_board, dtype=torch.long).reshape(-1, 1)

    # Model prediction
    logits = model(x)
    logits = logits.flatten()

    while True:

        idx = torch.argmax(logits).item()
        if logits[idx] == 0:
            assert False, "No valid moves found"

        # Recover indices
        move = label_index_by_move[idx]

        if game.is_valid_move(move.from_pos, move.to_pos):
            return move
        else:
            # Set the probability of this move to zero and continue
            logits[idx] = 0
            continue

game = Game()
game.fill_board_13_by_13()

def test_ai_move_fast():
    move = find_ai_move_fast(game, Player.DEFENDER)

benchmark_function(test_ai_move_fast, 1000)

Average time per iteration: 6.4307 ms


Adding a hidden layer increases the time by 3 ms, doubling the time