In [1]:
import os
import sys
sys.path.insert(0, "..")
import math

from tqdm import tqdm

import torch

from my_datasets import *

In [16]:
# Three independent random 4x4 matrices, drawn in the order A, X, B.
A, X, B = (torch.rand(4, 4) for _ in range(3))

In [38]:
# Numerical check of vec(A X B) = (B^T kron A) vec(X).
# Flattening the transpose of a matrix stacks its columns, i.e. it is the
# column-major vec(.) operator; S is the residual and should be ~0 up to
# floating-point round-off.
S = (A @ X @ B).T.reshape(-1) - torch.kron(B.T.contiguous(), A) @ X.T.reshape(-1)
S, S.abs().sum()

(tensor([ 2.3842e-07,  0.0000e+00,  0.0000e+00,  1.1921e-07,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  1.1921e-07,  0.0000e+00,  0.0000e+00,
         -1.1921e-07,  1.1921e-07,  2.3842e-07,  0.0000e+00, -2.3842e-07,
          2.3842e-07]),
 tensor(1.4305e-06))

In [64]:
# Same residual as above, written with the Kronecker factor in transposed form:
# (B kron A^T)^T equals B^T kron A, so U should also be ~0 up to round-off.
U = (A @ X @ B).T.reshape(-1) - torch.kron(B, A.T.contiguous()).T.contiguous() @ X.T.reshape(-1)
U

tensor([ 2.3842e-07,  0.0000e+00,  0.0000e+00,  1.1921e-07,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  1.1921e-07,  0.0000e+00,  0.0000e+00,
        -1.1921e-07,  1.1921e-07,  2.3842e-07,  0.0000e+00, -2.3842e-07,
         2.3842e-07])

In [13]:
# Sanity-check the dataset: for every item, recompute the successor with
# `step_rules` and compare it against the stored label, while tracking running
# averages that are shown in the progress bar.
n = 32
dataset = SmallTfSuccTokensDataset(n, 10000, ante_prob=0.5, conseq_prob=0.5)

running_hits = 0
running_state_mean = 0
running_succ_mean = 0
running_succ_when_state_nonzero = 0
num_nonzero_states = 0
# Stays 0 until the first item with a non-zero state is seen.
avg_succ_when_state_nonzero = 0

# `i` is a 1-based count of processed items so it can be used directly as the
# denominator of the running averages.
pbar = tqdm(range(1, 1 + len(dataset)))

for i in pbar:
    # Dataset indices are 0-based: index with i - 1 so item 0 is checked and
    # index len(dataset) is never requested.
    item = dataset[i - 1]
    rules, succ = item["tokens"], item["labels"]
    # The state is read from the last n entries of the last row of `rules`.
    state = rules[-1][-n:]

    succ1, hits = step_rules(rules.unsqueeze(0), state.unsqueeze(0))

    # Every one of the n recomputed entries must match the stored label.
    assert (succ1.view(-1).long() == succ).sum() == n, \
        f"recomputed successor disagrees with label at index {i - 1}"

    running_hits += hits.sum()
    running_state_mean += state.float().mean()
    running_succ_mean += succ.float().mean()

    avg_hits = running_hits / i
    avg_state_mean = running_state_mean / i
    avg_succ_mean = running_succ_mean / i

    if state.sum() > 0:
        num_nonzero_states += 1
        running_succ_when_state_nonzero += succ.float().mean()
        avg_succ_when_state_nonzero = running_succ_when_state_nonzero / num_nonzero_states

    # Only format and push the description every 100 items to keep the loop cheap.
    if i % 100 == 0:
        desc_str = f"n {n} hits {avg_hits:.3f}, " \
            + f"state mean {avg_state_mean:.3f}, " \
            + f"succ mean {avg_succ_mean:.3f}, " \
            + f"nnzsucc mean {avg_succ_when_state_nonzero:.3f}"
        pbar.set_description(desc_str)

    

n 32 hits 3.087, state mean 0.395, succ mean 0.532, nnzsucc mean 0.839: 100%|██████████| 10000/10000 [00:02<00:00, 3650.36it/s]


In [3]:
# Element-wise difference between `succ` and `succ1`; an all-zero result means
# the two agree. NOTE(review): both names come from whatever cell last assigned
# them in the kernel (presumably the dataset loop) -- confirm on a fresh run.
succ - succ1

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]])

In [4]:
def test_rules(n, r, ap, bp):
    running_hits = 0
    running_state_mean = 0
    running_succ_mean = 0
    
    pbar = tqdm(range(1,501))
    for i in pbar:
        a = (torch.rand(r, n) < ap).long()
        b = (torch.rand(r, n) < bp).long()
        rules = torch.cat([a, b], dim=1).long()
        
        state = (a * (torch.rand(r,1) < (1/r))).sum(dim=0).clamp(0,1)
    
        succ, hits = step_rules(rules.unsqueeze(0), state.unsqueeze(0))
        
        running_hits += hits.sum()
        running_state_mean += state.float().mean()
        running_succ_mean += succ.float().mean()
    
        avg_hits = running_hits / i
        avg_state_mean = running_state_mean / i
        avg_succ_mean = running_succ_mean / i
        
        pbar.set_description(f"n {n}: hits {avg_hits:.3f}, " \
                    + f"state mean {avg_state_mean:.3f}, " \
                    + f"succ mean {avg_succ_mean:.3f}"
        )
    return {
        "n": n,
        "r": r,
        "avg_hits": avg_hits,
        "avg_state_mean": avg_state_mean,
        "avg_succ_mean": avg_succ_mean
    }

In [5]:
torch.manual_seed(101)

# Sweep n with a rule count that grows linearly (r = 2n) and dense rules
# (every bit set with probability 0.5).
# Alternatives tried: r = int(n * math.sqrt(n)); ap = bp = 1 / math.sqrt(n).
for n in (10, 100, 1000):
    r, ap, bp = 2 * n, 0.5, 0.5
    test_rules(n, r, ap, bp)

n 10: hits 3.260, state mean 0.423, succ mean 0.598: 100%|██████████| 500/500 [00:00<00:00, 2354.40it/s]
n 100: hits 0.988, state mean 0.367, succ mean 0.504: 100%|██████████| 500/500 [00:00<00:00, 1328.24it/s]
n 1000: hits 0.966, state mean 0.381, succ mean 0.513: 100%|██████████| 500/500 [00:22<00:00, 22.68it/s]


In [6]:
torch.manual_seed(101)

# Sweep n with a rule count that grows like n^1.5 (r = int(n * sqrt(n))) and
# dense rules (every bit set with probability 0.5).
# Alternative tried: ap = bp = 1 / math.sqrt(n).
for n in (10, 30, 50):
    r = int(n * math.sqrt(n))
    ap = bp = 0.5
    test_rules(n, r, ap, bp)

n 10: hits 4.312, state mean 0.404, succ mean 0.588: 100%|██████████| 500/500 [00:00<00:00, 1843.23it/s]
n 30: hits 4.072, state mean 0.403, succ mean 0.548: 100%|██████████| 500/500 [00:00<00:00, 2027.71it/s]
n 50: hits 4.222, state mean 0.393, succ mean 0.529: 100%|██████████| 500/500 [00:00<00:00, 1678.38it/s]


In [7]:
torch.manual_seed(102)

# Single run at n = 40 with r = int(n * sqrt(n)) rules and sparse bits
# (each bit set with probability 1 / sqrt(n)).
n = 40
root_n = math.sqrt(n)
r = int(n * root_n)
ap = bp = 1 / root_n
test_rules(n, r, ap, bp)

n 40: hits 2.236, state mean 0.144, succ mean 0.324: 100%|██████████| 500/500 [00:00<00:00, 1916.14it/s]


{'n': 40,
 'r': 252,
 'avg_hits': tensor(2.2360),
 'avg_state_mean': tensor(0.1437),
 'avg_succ_mean': tensor(0.3240)}

In [8]:
# (Removed a stray reference to the undefined name `s`; it raised NameError and broke Restart & Run All.)