In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from tell import LogicalLayer

In [3]:
# Seed PyTorch's global RNG so the random tensors drawn later in the notebook are reproducible.
torch.manual_seed(0)

<torch._C.Generator at 0x7ff035a6f4d0>

In [4]:
# 1000 samples of 10 uniform features, followed by their complements (1 - value),
# giving 20 columns in total.
x_base = torch.rand([1000, 10]).float()
x = torch.hstack([x_base, 1 - x_base])

# Ground-truth label: feature 2 is high AND (feature 1 is low OR feature 3 is high).
feat1_low = x[:, 1] < 0.3
feat2_high = x[:, 2] > 0.5
feat3_high = x[:, 3] > 0.7
y = ((feat1_low & feat2_high) | (feat2_high & feat3_high)).float()

In [5]:
# The first 80% of the rows form the training split; the remaining rows are held out for testing.
split_at = int(x.shape[0]*0.8)
x_train, x_test = x[:split_at], x[split_at:]
y_train, y_test = y[:split_at], y[split_at:]

In [6]:
def _apply_rules(x, rules, thresholds):
    """Return a bool vector that is True where a row of ``x`` satisfies at least one rule.

    A rule is an iterable of column indices; a row satisfies it when every
    indexed column is strictly greater than the threshold at the same index.
    """
    fired = torch.zeros(x.shape[0], dtype=torch.bool, device=x.device)
    for rule in rules:
        idx = list(rule)
        fired |= (x[:, idx] > thresholds[idx]).all(-1)
    return fired


def train(model, x_train, y_train, x_test, y_test, epochs=3000):
    """Train ``model`` with BCE loss, then report how well its extracted rules fit the data.

    Args:
        model: Indexable container of layers (e.g. ``torch.nn.Sequential``) whose
            output, after ``squeeze(-1)``, is a probability per sample. Its first
            layer must provide ``extract_rules()`` and ``phi_in.t``.
        x_train, y_train: Training inputs and 0/1 float targets.
        x_test, y_test: Held-out inputs and 0/1 float targets.
        epochs: Number of full-batch optimisation steps.

    Returns:
        None. Progress is printed every 100 epochs; the extracted rules and their
        train/test accuracy are printed once training finishes.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
    loss_form = torch.nn.BCELoss()
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze(-1)
        loss = loss_form(y_pred, y_train)
        for layer in model.children():
            # Only layers that expose `weight_s` are penalised. Checking for the
            # attribute that is actually read keeps plain layers (which have
            # `weight` but no `weight_s`) from raising AttributeError.
            if hasattr(layer, "weight") and hasattr(layer, "weight_s"):
                loss = loss + 0.01 * layer.weight_s.sum(-1).mean()

        loss.backward()

        # Zero out non-finite gradient entries (NaN, +inf, -inf) so a single bad
        # value cannot corrupt the optimizer state.
        for p in model.parameters():
            if p.grad is None:
                continue
            p.grad = torch.nan_to_num(p.grad, nan=0.0, posinf=0.0, neginf=0.0)

        # Metrics are measured before the optimizer step, i.e. for the parameters
        # that produced `y_pred`. No autograd graph is needed for them.
        with torch.no_grad():
            acc = ((y_pred > 0.5).long() == y_train).float().mean()
            test_acc = ((model(x_test).squeeze(-1) > 0.5).long() == y_test).float().mean()
        optimizer.step()

        if epoch % 100 == 0:
            print(f"epoch: {epoch}, loss: {loss.item():.4f}, acc: {acc.item():.4f}, test_acc: {test_acc.item():.4f}")

    # Score the symbolic rules on their own: a sample is predicted positive when
    # any extracted rule fires on it.
    rules = model[0].extract_rules()[0]
    thresholds = model[0].phi_in.t

    y_rule = _apply_rules(x_train, rules, thresholds)
    y_test_rule = _apply_rules(x_test, rules, thresholds)

    rule_acc = (y_rule == y_train).float().mean()
    rule_test_acc = (y_test_rule == y_test).float().mean()

    print(rules)

    print(f"rule acc: {rule_acc:.4f}, rule test acc: {rule_test_acc:.4f}")



In [7]:
# Everything in this experiment runs on the CPU.
device = torch.device('cpu')

# Move every data split onto the chosen device.
x_train, y_train, x_test, y_test = (
    split.to(device) for split in (x_train, y_train, x_test, y_test)
)

# A single LogicalLayer built from the feature count and 1
# (presumably input/output sizes -- confirm against tell.LogicalLayer).
layers = [LogicalLayer(x.shape[1], 1, dummy_phi_in=False)]
model = torch.nn.Sequential(*layers).to(device)

train(model, x_train, y_train, x_test, y_test)

epoch: 0, loss: 1.3454, acc: 0.4588, test_acc: 0.4050
epoch: 100, loss: 0.2841, acc: 0.9087, test_acc: 0.9050
epoch: 200, loss: 0.2049, acc: 0.9538, test_acc: 0.9500
epoch: 300, loss: 0.1550, acc: 0.9750, test_acc: 0.9550
epoch: 400, loss: 0.1233, acc: 0.9825, test_acc: 0.9650
epoch: 500, loss: 0.1025, acc: 0.9862, test_acc: 0.9650
epoch: 600, loss: 0.0884, acc: 0.9912, test_acc: 0.9700
epoch: 700, loss: 0.0786, acc: 0.9912, test_acc: 0.9750
epoch: 800, loss: 0.0715, acc: 0.9925, test_acc: 0.9750
epoch: 900, loss: 0.0661, acc: 0.9925, test_acc: 0.9750
epoch: 1000, loss: 0.0505, acc: 0.9950, test_acc: 0.9750
epoch: 1100, loss: 0.0454, acc: 0.9937, test_acc: 0.9750
epoch: 1200, loss: 0.0413, acc: 0.9937, test_acc: 0.9750
epoch: 1300, loss: 0.0374, acc: 0.9937, test_acc: 0.9750
epoch: 1400, loss: 0.0343, acc: 0.9937, test_acc: 0.9800
epoch: 1500, loss: 0.0315, acc: 0.9937, test_acc: 0.9800
epoch: 1600, loss: 0.0291, acc: 0.9950, test_acc: 0.9800
epoch: 1700, loss: 0.0269, acc: 0.9950, tes

In [8]:
def transform_to_prop_logic(symbols, rules, thresholds):
    """Render rules as a human-readable disjunction of conjunctions.

    Args:
        symbols: Sequence of feature names indexed by feature position. A name
            starting with ``~`` denotes the complemented feature ``1 - value``.
        rules: Iterable of rules; each rule is an iterable of feature indices
            that are AND-ed together. Rules are OR-ed.
        thresholds: Per-feature thresholds, indexable by the same indices
            (floats or 0-dim tensors).

    Returns:
        A string such as ``'(x2 > 0.482 & x3 > 0.697) | (x2 > 0.482 & x1 < 0.304)'``,
        or ``''`` when ``rules`` is empty.
    """
    clauses = []
    for rule in rules:
        literals = []
        for i in rule:
            name = symbols[i]
            if name.startswith('~'):
                # (1 - v) > t is the same as v < (1 - t), so report the bound on
                # the original feature, using the same 3-significant-digit
                # format as the positive case.
                literals.append(f'{name[1:]} < {1 - thresholds[i]:.3}')
            else:
                literals.append(f'{name} > {thresholds[i]:.3}')
        clauses.append('(' + ' & '.join(literals) + ')')

    return ' | '.join(clauses)

In [9]:
# x holds each feature followed by its complement, so the first half of the
# columns get plain names and the second half the same names prefixed with '~'.
n_base_features = x.shape[1]//2
plain_names = [f'x{i}' for i in range(n_base_features)]
negated_names = [f'~x{i}' for i in range(n_base_features)]
symbols = plain_names + negated_names

In [10]:
# Render the rules and thresholds held by the model's first layer as a readable formula.
learned_rules = model[0].extract_rules()[0]
learned_thresholds = model[0].phi_in.t
transform_to_prop_logic(symbols, learned_rules, learned_thresholds)

'(x2 > 0.482 & x3 > 0.697) | (x2 > 0.482 & x1 < 0.3039592504501343)'