In [68]:
import torch
import itertools
from tqdm import tqdm
import os

In [69]:
device = ("cuda" if torch.cuda.is_available() else "cpu")
device

'cuda'

In [70]:
def generate_all_matrix_combinations(n):
    possible_values = [-1, 0, 1]
    combinations = itertools.product(possible_values, repeat=n*n)
    matrices = []
    for combination in combinations:
        tensor = torch.tensor(combination).view(n, n)
        matrices.append(tensor)
    return matrices

# Example usage:
n = 3
all_matrix_combinations = torch.stack(generate_all_matrix_combinations(n))
(all_matrix_combinations.shape[0]) == 3 ** 9

True

In [71]:
sums = all_matrix_combinations.sum(dim=1).sum(dim=1)
bi_turns_matrxies = all_matrix_combinations[ (sums == 0) | (sums == 1)]
bi_turns_matrxies.shape

torch.Size([6046, 3, 3])

In [76]:
bi_turns_matrxies[0]

tensor([[-1, -1, -1],
        [-1,  0,  1],
        [ 1,  1,  1]])

In [92]:
streaks = torch.cat((bi_turns_matrxies.sum(dim = 1), bi_turns_matrxies.sum(dim = 2)), dim = 1)
pos_won = torch.any(streaks == 3, dim = 1)
neg_won = torch.any(streaks == -3, dim = 1)
~(pos_won & neg_won), streaks

(tensor([False, False,  True,  ...,  True, False, False]),
 tensor([[-1,  0,  1, -3,  0,  3],
         [-1,  1,  0, -3,  0,  3],
         [-2,  1,  1, -3,  1,  2],
         ...,
         [ 2,  0, -1,  3,  0, -2],
         [ 1,  0,  0,  3,  1, -3],
         [ 1,  1, -1,  3,  1, -3]]))

In [110]:
final_matrxies = bi_turns_matrxies[~(pos_won & neg_won)]
final_matrxies.shape

torch.Size([5890, 3, 3])

In [111]:
import torch

# Example game_states_tensor
game_states_tensor = torch.tensor([
    [[1, 0, 1],
     [0, 1, 0],
     [1, 0, 1]],

    [[1, 1, 0],
     [1, 0, 1],
     [1, 1, 1]],

    [[1, 1, 1],
     [1, 0, 1],
     [1, 1, 1]],

    [[1, 0, 1],
     [1, 0, 1],
     [1, 0, 1]]
     
])
size =game_states_tensor.shape[0]
game_size = game_states_tensor.shape[-1]
a = game_states_tensor.repeat(size,1,1).reshape(size,size,game_size,game_size)
b = game_states_tensor.repeat(size,1,1).reshape(size,size,game_size,game_size).permute(1,0,2,3)
c = torch.stack((a,b), dim = 2)
c.shape

torch.Size([4, 4, 2, 3, 3])

In [112]:
c[3][2]

tensor([[[1, 1, 1],
         [1, 0, 1],
         [1, 1, 1]],

        [[1, 0, 1],
         [1, 0, 1],
         [1, 0, 1]]])

In [113]:
parent = c[2:,:,0]
child = c[2:,:,1]
child.shape

torch.Size([2, 4, 3, 3])

In [114]:
add_one_pice = torch.sum((parent == 0) & (child == 1), dim = (-2,-1)) == 1
remove_none = torch.sum(~(parent == child), dim = (-2,-1)) == 1
c[2:][remove_none & add_one_pice]

tensor([[[[1, 1, 0],
          [1, 0, 1],
          [1, 1, 1]],

         [[1, 1, 1],
          [1, 0, 1],
          [1, 1, 1]]]])

In [115]:
def make_combi(game_states_tensor):
    size =game_states_tensor.shape[0]
    game_size = game_states_tensor.shape[-1]
    a = game_states_tensor.to(device).repeat(size,1,1).reshape(size,size,game_size,game_size)
    b = game_states_tensor.to(device).repeat(size,1,1).reshape(size,size,game_size,game_size).permute(1,0,2,3)
    return a, b
def gather_edges(parent, child):
    parent = parent.to(device)
    child = child.to(device)
    add_one_pice = torch.sum((parent == 0) & ((child == 1) | (child == -1)), dim = (-2,-1)) == 1
    one_move_made = torch.sum(~(parent == child), dim = (-2,-1)) == 1
    edges = torch.stack((parent[one_move_made & add_one_pice], child[one_move_made & add_one_pice]),dim = 1)
    return edges
gather_edges(*make_combi(game_states_tensor)).shape

torch.Size([1, 2, 3, 3])

In [116]:
def save_tensor_chunks(tensor, chunk_size, directory, prefix='chunk'):
    num_chunks = tensor.size(0) // chunk_size
    remainder = tensor.size(0) % chunk_size
    if remainder > 0:
        num_chunks += 1

    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, tensor.size(0))
        chunk_tensor = tensor[start_idx:end_idx]
        chunk_filename = f'{prefix}_{i}.pth'
        chunk_path = os.path.join(directory, chunk_filename)
        torch.save(chunk_tensor, chunk_path)

In [117]:
torch.cuda.empty_cache()
edges = gather_edges(*make_combi(final_matrxies))
edges.shape

torch.Size([18459, 2, 3, 3])

In [118]:
def to_tuple(t):

    return to_tuple_prim(t.tolist())

def to_tuple_prim(ls):

    if type(ls[0]) != list:
        return tuple(ls)
    
    return tuple([to_tuple_prim(r) for r in ls])


In [119]:
to_tuple(edges[0][0])

((-1, -1, -1), (0, 1, 1), (0, 1, 1))

In [120]:
matrix2id = {to_tuple(m):i for i, m in enumerate(final_matrxies)}
id2matrix = {i:to_tuple(m) for i, m in enumerate(final_matrxies)}
matrix2id

{((-1, -1, -1), (-1, 1, 1), (0, 1, 1)): 0,
 ((-1, -1, -1), (-1, 1, 1), (1, 0, 1)): 1,
 ((-1, -1, -1), (-1, 1, 1), (1, 1, 0)): 2,
 ((-1, -1, -1), (0, 0, 1), (0, 1, 1)): 3,
 ((-1, -1, -1), (0, 0, 1), (1, 0, 1)): 4,
 ((-1, -1, -1), (0, 0, 1), (1, 1, 0)): 5,
 ((-1, -1, -1), (0, 1, 0), (0, 1, 1)): 6,
 ((-1, -1, -1), (0, 1, 0), (1, 0, 1)): 7,
 ((-1, -1, -1), (0, 1, 0), (1, 1, 0)): 8,
 ((-1, -1, -1), (0, 1, 1), (-1, 1, 1)): 9,
 ((-1, -1, -1), (0, 1, 1), (0, 0, 1)): 10,
 ((-1, -1, -1), (0, 1, 1), (0, 1, 0)): 11,
 ((-1, -1, -1), (0, 1, 1), (0, 1, 1)): 12,
 ((-1, -1, -1), (0, 1, 1), (1, -1, 1)): 13,
 ((-1, -1, -1), (0, 1, 1), (1, 0, 0)): 14,
 ((-1, -1, -1), (0, 1, 1), (1, 0, 1)): 15,
 ((-1, -1, -1), (0, 1, 1), (1, 1, -1)): 16,
 ((-1, -1, -1), (0, 1, 1), (1, 1, 0)): 17,
 ((-1, -1, -1), (1, -1, 1), (0, 1, 1)): 18,
 ((-1, -1, -1), (1, -1, 1), (1, 0, 1)): 19,
 ((-1, -1, -1), (1, -1, 1), (1, 1, 0)): 20,
 ((-1, -1, -1), (1, 0, 0), (0, 1, 1)): 21,
 ((-1, -1, -1), (1, 0, 0), (1, 0, 1)): 22,
 ((-1, -1, -

In [121]:
children_ids = {}
parents_ids = {}
for r in edges:
    parent = matrix2id[to_tuple(r[0])]
    child = matrix2id[to_tuple(r[1])]

    children = children_ids.get(parent, [])
    children.append(child)
    children_ids[parent] = children

    parents = parents_ids.get(child, [])
    parents.append(parent)
    parents_ids[child] = parents


children_ids

{12: [0, 9],
 59: [0, 56],
 409: [0, 406],
 1750: [0, 1747],
 15: [1, 13],
 62: [1, 60],
 412: [1, 410],
 1753: [1, 1751],
 17: [2, 16],
 64: [2, 63],
 414: [2, 413],
 1755: [2, 1754],
 78: [3, 47, 66, 75],
 428: [3, 397, 416, 425],
 1769: [3, 1738, 1757, 1766],
 81: [4, 48, 67, 79],
 431: [4, 398, 417, 429],
 1772: [4, 1739, 1758, 1770],
 83: [5, 49, 68, 82],
 433: [5, 399, 418, 432],
 1774: [5, 1740, 1759, 1773],
 91: [6, 52, 84, 88],
 441: [6, 402, 434, 438],
 1782: [6, 1743, 1775, 1779],
 94: [7, 53, 85, 92],
 444: [7, 403, 435, 442],
 1785: [7, 1744, 1776, 1783],
 96: [8, 54, 86, 95],
 446: [8, 404, 436, 445],
 1787: [8, 1745, 1777, 1786],
 99: [9, 56],
 449: [9, 406],
 1790: [9, 1747],
 102: [10, 57, 97, 100],
 452: [10, 407, 447, 450],
 1793: [10, 1748, 1788, 1791],
 104: [11, 58, 98, 103],
 454: [11, 408, 448, 453],
 1795: [11, 1749, 1789, 1794],
 3: [12, 27],
 6: [12, 39],
 10: [12, 15],
 11: [12, 17],
 106: [13, 60],
 456: [13, 410],
 1797: [13, 1751],
 108: [14, 61, 105, 107

In [122]:
print(torch.tensor(id2matrix[2524]))
print()
for c in children_ids[2524]:
    print(torch.tensor(id2matrix[c]))

tensor([[ 0,  0, -1],
        [ 1,  0,  1],
        [-1,  1,  0]])

tensor([[-1,  0, -1],
        [ 1,  0,  1],
        [-1,  1,  0]])
tensor([[ 0, -1, -1],
        [ 1,  0,  1],
        [-1,  1,  0]])
tensor([[ 0,  0, -1],
        [ 1, -1,  1],
        [-1,  1,  0]])
tensor([[ 0,  0, -1],
        [ 1,  0,  1],
        [-1,  1, -1]])


In [123]:
nodes_ids = (set(children_ids.keys()) | set(parents_ids.keys()))
len(nodes_ids)

5890

In [124]:
set([1,2,3]) - set([1,2,5])

{3}

In [129]:
out_liers = list(set(parents_ids.keys()) - set(children_ids.keys()))
len(out_liers)

198

In [145]:
torch.tensor(id2matrix[out_liers[17]])

tensor([[-1,  1,  1],
        [-1,  1,  1],
        [-1, -1,  0]])

In [None]:
scores = {}
visits = {}

In [None]:
def score_node()

In [None]:
for o in out_liers:
    parent = parents_ids[o]
    while True