In [1]:
import utilities
import torch
t = utilities.Tokenizer()
import numpy as np
import random

In [2]:
a = torch.tensor([[1, 0, 0],[0,0,0], [-2, -2, -2]])
b = torch.tensor([[-2,1,0],[0,1,1], [-1, 1, -1]])

In [3]:
x = t.tokenize(a)
y = t.tokenize(b)
print(a, b)
print(x, y)
w = t.detokenize(x)
x = t.detokenize(y)

print(w, x)

tensor([[ 1,  0,  0],
        [ 0,  0,  0],
        [-2, -2, -2]]) tensor([[-2,  1,  0],
        [ 0,  1,  1],
        [-1,  1, -1]])
tensor([63, 62,  0], dtype=torch.int32) tensor([65, 92, 41], dtype=torch.int32)
tensor([[ 1,  0,  0],
        [ 0,  0,  0],
        [-2, -2, -2]], dtype=torch.int32) tensor([[-2,  1,  0],
        [ 0,  1,  1],
        [-1,  1, -1]], dtype=torch.int32)


In [4]:
def generate_sample_r1(S: int, vals: list[int], factor_dist: list[float]):
    nonzero = False
    while not nonzero:
        t = np.random.choice(vals, size=(3, S), p=factor_dist)
        m = np.tensordot(np.tensordot(t[0, :], t[1, :], axes=0), t[2, :], axes=0)
        assert m.shape == (S, S, S)
        nonzero = np.any(m)
    return t, m

In [5]:
S = 3
vals = [-2, -1, 0, 1, 2]
factor_dist = [0.1, 0.2, 0.4, 0.2, 0.1]

t, m = generate_sample_r1(S, vals, factor_dist)

print(f"tensors {t}")
print(f"result {m}")

tensors [[-1  0  0]
 [ 1  1  0]
 [ 1  0  2]]
result [[[-1  0 -2]
  [-1  0 -2]
  [ 0  0  0]]

 [[ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]]

 [[ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]]]


In [12]:
def main(S: int, r_limit: int, factor_dist: dict, N: int, seed: int = None):
    if seed is not None:
        random.seed(seed)

    low, high = min(factor_dist.keys()), max(factor_dist.keys())
    vals = list(factor_dist.keys())
    dist = [factor_dist[i] for i in factor_dist.keys()]

    
    tokenizer = utilities.Tokenizer(range=(low, high))

    SAR_pairs = []
    
    for i in range(N):
        R = random.randint(1, r_limit)
        T = torch.zeros((S, S, S), dtype=torch.int)
        reward = 0
        for j in range(R):
            sample, m = generate_sample_r1(S, vals, dist)
            T += torch.from_numpy(m)
            tokens = tokenizer.tokenize(torch.from_numpy(sample.T))
            reward += -1
            SAR_pairs.append((T, tokens, reward))

    return SAR_pairs
            

In [14]:
print(main(3, 10, {-2: .1, -1 : .2, 0: 0.4, 1: 0.2, 2: 0.1}, 10, seed=0))

[(tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([109,  62,  28], dtype=torch.int32), -1), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([38, 58, 56], dtype=torch.int32), -2), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([93, 37, 65], dtype=torch.int32), -3), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   