In [1]:
import sys
sys.path.insert(0, "..")

from tqdm import tqdm
import torch
from my_datasets import *

In [2]:
# Experiment configuration; each value is passed to random_rules_with_chain below.
n = 8      # num_vars
r = 160    # num_rules
ap = 0.5   # ante_prob
bp = 0.5   # conseq_prob
cl = 4     # chain_len

In [3]:
# def compose(rules1, rules2):
#     all_as, all_bs = [], []
#     for r1 in rules1:
#         a1, b1 = r1.chunk(2)
#         for r2 in rules2:
#             a2, b2 = r2.chunk(2)
#             # new_a = (a1 + a2 - b1).clamp(0,1)
#             new_a = (a1 + (a2 - b1).clamp(0,1)).clamp(0,1)
#             new_b = (b1 + b2).clamp(0,1)
#             all_as.append(new_a)
#             all_bs.append(new_b)
#     all_as = torch.stack(all_as)
#     all_bs = torch.stack(all_bs)
#     return torch.cat([all_as, all_bs], dim=1).long()

def compose(rules1, rules2):
    """Build one rule set intended to match applying ``rules1`` and then ``rules2``.

    For every row of ``rules1``, the first half (``chunk(2)``) is kept as the
    antecedent ``a``. The new second half is computed in two stages:
    1. run ``step_rules`` with all of ``rules1`` starting from ``a``;
    2. run ``step_rules`` with all of ``rules2`` starting from that result.

    Args:
        rules1: rule tensor iterated row by row. Each row is split in half with
            ``chunk(2)``. Presumably (num_rules, 2 * num_vars); TODO confirm.
        rules2: rule tensor in the same layout, applied after ``rules1``.

    Returns:
        LongTensor with one row per rule in ``rules1``: ``[a, new_consequent]``.
    """
    # step_rules is called with a leading batch dimension; build those views once.
    batched1 = rules1[None, ...]
    batched2 = rules2[None, ...]
    all_as, all_bs = [], []
    for r1 in rules1:
        # Only the antecedent half is needed; the consequent is recomputed below.
        a, _ = r1.chunk(2)
        z, _ = step_rules(batched1, a.view(1, -1))
        new_b, _ = step_rules(batched2, z.view(1, -1))
        all_as.append(a)
        all_bs.append(new_b)
    all_as = torch.stack(all_as)
    all_bs = torch.cat(all_bs, dim=0)
    return torch.cat([all_as, all_bs], dim=1).long()

In [4]:
torch.manual_seed(1234)

num_trials = 100  # random rule-set pairs to test before concluding compose() agrees


def _sample_rules():
    """Draw one rule set from random_rules_with_chain using the config cell's values.

    Returns the dict produced with ``return_dict=True``. Its ``"rules"`` entry is
    the rule tensor.
    """
    return random_rules_with_chain(
        num_rules = r,
        num_vars = n,
        ante_prob = ap,
        conseq_prob = bp,
        chain_len = cl,
        return_dict = True)


# Compare one step of each composed rule set against repeated stepping with the
# original rule sets. The loop stops at the first trial with any nonzero
# difference. That trial's tensors (dictA, rulesA, s0, saa, scaa, ...) are then
# left in scope for the inspection cells below, and `mismatch` records which
# checks failed.
mismatch = None
pbar = tqdm(range(num_trials))
for i in pbar:
    dictA = _sample_rules()
    rulesA = dictA["rules"]

    dictB = _sample_rules()
    rulesB = dictB["rules"]

    # In place: B takes A's first n columns, so the two rule sets share that half.
    rulesB[:,0:n] = rulesA[:,0:n]

    s0 = torch.zeros(1,n)
    # s0 = torch.randint(0,2,(1,n))

    rulesAA = compose(rulesA, rulesA)
    rulesAAA1 = compose(rulesA, rulesAA)
    rulesAAA2 = compose(rulesAA, rulesA)
    rulesAAAA = compose(rulesAA, rulesAA)
    rulesAB = compose(rulesA, rulesB)

    # Reference: repeated step_rules / kstep_rules with the original rule sets.
    sa, _ = step_rules(rulesA[None,...], s0)
    saa, _ = step_rules(rulesA[None,...], sa)
    sab, _ = step_rules(rulesB[None,...], sa)
    saaa = kstep_rules(rulesA[None,...], s0, num_steps=3)
    saaaa = kstep_rules(rulesA[None,...], s0, num_steps=4)

    # Candidate: a single step of each composed rule set from s0.
    scaa, _ = step_rules(rulesAA[None,...], s0)
    scab, _ = step_rules(rulesAB[None,...], s0)
    scaaa1, _ = step_rules(rulesAAA1[None,...], s0)
    scaaa2, _ = step_rules(rulesAAA2[None,...], s0)
    scaaaa, _ = step_rules(rulesAAAA[None,...], s0)

    # L1 distance between reference and candidate; any nonzero value is a mismatch.
    aa_diff = (saa - scaa).abs().sum()
    ab_diff = (sab - scab).abs().sum()
    aaa1_diff = (saaa - scaaa1).abs().sum()
    aaa2_diff = (saaa - scaaa2).abs().sum()
    aaaa_diff = (saaaa - scaaaa).abs().sum()

    diffs = {"aa": aa_diff, "ab": ab_diff, "aaa1": aaa1_diff,
             "aaa2": aaa2_diff, "aaaa": aaaa_diff}
    failed = {name: d.item() for name, d in diffs.items() if d > 0}
    if failed:
        mismatch = {"trial": i, "diffs": failed}
        pbar.set_description(f"mismatch at trial {i}")
        break

# Outcome, stated explicitly rather than inferred from where the progress bar stopped.
"all trials matched" if mismatch is None else mismatch

100%|██████████| 100/100 [00:07<00:00, 13.34it/s]


In [5]:
# Initial state used for every check in the loop above: torch.zeros(1, n).
s0

tensor([[0., 0., 0., 0., 0., 0., 0., 0.]])

In [6]:
# One step_rules pass of the composed rule set rulesAA from s0 (last or first-failing trial).
scaa

tensor([[1, 1, 0, 0, 0, 1, 1, 1]])

In [7]:
# Two sequential step_rules passes of rulesA from s0; the loop flags any difference from scaa.
saa

tensor([[1, 1, 0, 0, 0, 1, 1, 1]])

In [8]:
# Ask prove_theorem about the all-ones target for this trial's rule set
# (unsqueeze(0) adds the leading batch dimension, same as [None, ...]).
qed = prove_theorem(
    rulesA.unsqueeze(0),
    torch.ones(1, n),
)
qed

{'rules': tensor([[[0, 0, 0,  ..., 0, 1, 1],
          [0, 0, 0,  ..., 1, 0, 1],
          [0, 1, 0,  ..., 1, 0, 1],
          ...,
          [0, 1, 0,  ..., 0, 1, 0],
          [1, 1, 0,  ..., 0, 0, 0],
          [1, 1, 0,  ..., 0, 0, 0]]]),
 'theorem': tensor([[1., 1., 1., 1., 1., 1., 1., 1.]]),
 'qed': tensor([0]),
 'states': tensor([[[1, 0, 0, 0, 0, 1, 1, 1],
          [1, 1, 0, 0, 0, 1, 1, 1],
          [1, 1, 1, 0, 0, 1, 1, 1],
          [1, 1, 1, 0, 1, 1, 1, 1],
          [1, 1, 1, 0, 1, 1, 1, 1],
          [1, 1, 1, 0, 1, 1, 1, 1],
          [1, 1, 1, 0, 1, 1, 1, 1],
          [1, 1, 1, 0, 1, 1, 1, 1]]]),
 'hits': tensor([[[1, 0, 0,  ..., 0, 0, 0],
          [1, 1, 0,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          ...,
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0]]]),
 'chain_len': tensor([4])}

In [9]:
# Dict returned by random_rules_with_chain for rulesA in the last (or first-failing) trial.
dictA

{'rules': tensor([[0, 0, 0,  ..., 0, 1, 1],
         [0, 0, 0,  ..., 1, 0, 1],
         [0, 1, 0,  ..., 1, 0, 1],
         ...,
         [0, 1, 0,  ..., 0, 1, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0]]),
 'init_state': tensor([0., 0., 0., 0., 0., 0., 0., 0.]),
 'chain_bits': tensor([6, 1, 2, 4]),
 'other_bits': tensor([7, 5, 0]),
 'bad_bit': tensor(3)}

In [10]:
# Rule tensor of the last (or first-failing) trial; same object as dictA["rules"].
rulesA

tensor([[0, 0, 0,  ..., 0, 1, 1],
        [0, 0, 0,  ..., 1, 0, 1],
        [0, 1, 0,  ..., 1, 0, 1],
        ...,
        [0, 1, 0,  ..., 0, 1, 0],
        [1, 1, 0,  ..., 0, 0, 0],
        [1, 1, 0,  ..., 0, 0, 0]])