In [1]:
import sys
import os
sys.path.append(os.path.abspath('../src')) # include top level package in python path

In [2]:
%env CUDA_LAUNCH_BLOCKING=1
import torch
from torch import nn
from torch.utils.data import DataLoader
from data.data_synth import DNFDataset
from model import fuzzy_logic
from model import embed_logic
from model.fuzzy_layer import FuzzyDNF, FuzzySignedConjunction, FuzzyUnsignedDisjunction, FuzzyLoss
from model.embed_layer import EmbedDNF
from cache import TestMetric, TrainingRegime
from plot import plot_loss, plot_bit_density
from dnf import format_dnf, format_vars
import matplotlib.pyplot as plt

env: CUDA_LAUNCH_BLOCKING=1


In [3]:
%%capture
from tqdm.notebook import tqdm
tqdm().pandas()

In [4]:
# Compute device for models and batches: use the GPU when CUDA is available,
# otherwise fall back to CPU so the notebook still runs (slowly) without one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
class OutputDistanceMetric(TestMetric):
    """Mean absolute difference between model outputs and labels.

    Lower is better; 0.0 means every prediction equals its label exactly.
    """

    def name(self):
        return "output-distance"

    def measure_model(self, model, it):
        """Return the average ``|pred - y|`` over every label yielded by ``it()``.

        ``it`` is called with no arguments and must yield ``(xs, y)`` batches;
        inputs and labels are moved to the module-level ``device``.
        """
        total_distance = 0.0
        count = 0
        # Evaluation only: skip building an autograd graph for every batch.
        with torch.no_grad():
            for xs, y in it():
                xs = xs.to(device)
                y = y.to(device)
                pred = model(xs).squeeze()
                total_distance += (pred - y.float()).abs().sum().item()
                count += y.numel()
        return total_distance / count
    
class ClassBalanceMetric(TestMetric):
    """Mean label value (fraction of positive labels) of the evaluation data.

    The value does not depend on the model, so it is computed on the first
    call and cached for all later calls.
    """

    def __init__(self):
        super().__init__()
        # Cached mean label; None until measure_model has run once.
        self.balance = None

    def name(self):
        return "class-balance"

    def measure_model(self, model, it):
        """Return the mean of all labels yielded by ``it()``; ``model`` is unused."""
        if self.balance is None:
            total_ones = 0.0
            total = 0
            for _, y in it():
                total_ones += y.float().sum().item()
                total += y.numel()
            self.balance = total_ones / total
        return self.balance
    
class OutputAverageMetric(TestMetric):
    """Mean model output over the evaluation data."""

    def name(self):
        return "output-avg"

    def measure_model(self, model, it):
        """Return the average of all model outputs for the inputs yielded by ``it()``.

        Inputs are moved to the module-level ``device``; labels are ignored.
        """
        total = 0.0
        count = 0
        # Evaluation only: no gradients are needed.
        with torch.no_grad():
            for xs, _ in it():
                pred = model(xs.to(device)).squeeze()
                total += pred.sum().item()
                count += pred.numel()
        return total / count
    
class CrispnessMetric(TestMetric):
    """Reports ``model.crispness()``, or 0.0 for models that do not define it."""

    def name(self):
        return "crispness"

    def measure_model(self, model, it):
        """Return the model's crispness score; ``it`` is unused.

        Models without a ``crispness`` method score 0.0. Errors raised inside
        ``crispness()`` itself propagate instead of being reported as 0.0, and
        KeyboardInterrupt is no longer swallowed.
        """
        crispness = getattr(model, "crispness", None)
        if crispness is None:
            return 0.0
        return crispness()

In [6]:
class DNFRegime(TrainingRegime):
    """Training regime that fits a model to a randomly generated DNF formula.

    The training set's ``get_data()`` result is passed through
    ``self.cache("weights", ...)``, presumably so repeated runs reuse the same
    target formula (TODO confirm cache semantics in ``TrainingRegime``). It is
    then copied onto a separate test set, so both sets are labelled by the
    same formula.

    Args:
        name: short identifier used in progress strings and file names.
        no_dims: number of input variables.
        new_model_f: zero-argument factory returning a fresh model.
        no_runs: number of independent runs.
        lr: Adam learning rate.
        train_size: number of training samples (not part of the file names).
        test_size: number of test samples (not part of the file names).
        batch_size: batch size of both data loaders.
    """

    def __init__(self, name, no_dims, new_model_f, no_runs=1, lr=1e-2,
                 train_size=50_000, test_size=5_000, batch_size=128):
        super().__init__("./dnfs/", no_runs)

        self.name = name
        self.no_dims = no_dims
        self.new_model_f = new_model_f
        self.lr = lr
        self._batch_size = batch_size

        # The remaining positional DNFDataset arguments are unchanged
        # (0.5, 8, then 0.05 for training vs 0.0 for testing).
        # TODO: name them once their meaning is documented in data_synth.
        self.dataset = DNFDataset(no_dims, train_size, 0.5, 8, 0.05)
        self.dataset.set_data(
            *self.cache("weights", lambda: self.dataset.get_data())
        )

        # Label the test set with the same formula as the training set.
        self.test_dataset = DNFDataset(no_dims, test_size, 0.5, 8, 0.0)
        self.test_dataset.set_data(*self.dataset.get_data())

        self.tests = [
            OutputDistanceMetric(),
            CrispnessMetric(),
            OutputAverageMetric(),
            ClassBalanceMetric(),
        ]

        # One lazily created optimizer per run; index is run_no - 1.
        self.optims = [None] * no_runs

    def get_optim(self, run_no):
        """Return the Adam optimizer for 1-based ``run_no``, creating it on first use."""
        optim = self.optims[run_no - 1]
        if optim is None:
            model = self.get_loaded_model(run_no)
            optim = self.optims[run_no - 1] = (
                torch.optim.Adam(model.parameters(), lr=self.lr)
            )
        return optim

    def regime_str(self):
        """Human-readable description used in progress-bar labels."""
        return "%s | %s DIMS | ADAM, LR = %s" % (
            self.name.upper(),
            self.no_dims,
            self.lr,
        )

    def regime_filename_elems(self):
        """Components used to build this regime's file names."""
        return [
            "dnf",
            self.name,
            "%sdim" % str(self.no_dims),
            "%slr" % str(self.lr),
        ]

    def training_dataloader(self, run_no):
        return DataLoader(self.dataset, batch_size=self._batch_size)

    def testing_dataloader(self, run_no):
        return DataLoader(self.test_dataset, batch_size=self._batch_size)

    def training_step(self, run_no, model):
        """Return a closure that performs one BCE/Adam step on a ``(bs, y)`` batch."""
        optim = self.get_optim(run_no)
        loss_fn = torch.nn.BCELoss()

        def step(tup):
            bs, y = tup
            bs = bs.to(device)
            y = y.to(device)
            # reshape rather than squeeze: a final batch of size 1 would
            # otherwise become 0-d and BCELoss rejects the shape mismatch.
            y_hats = model(bs).reshape(y.shape)
            loss = loss_fn(y_hats, y)

            optim.zero_grad()
            loss.backward()
            optim.step()

        return step

    def new_model(self):
        return self.new_model_f()

In [7]:
class FuzzyNN(nn.Module):
    """Fuzzy DNF network: 16 signed conjunctions feeding one unsigned disjunction.

    Both layers share a single Schweizer-Sklar logic constructed with
    parameter -2.0.
    """

    def __init__(self, in_features):
        super().__init__()
        # NOTE(review): the logic's parameter tensor is placed on the global
        # `device` here; a later `.to(...)` may not move it unless the logic
        # registers it as a buffer/parameter — confirm.
        self.logic = fuzzy_logic.SchweizerSklarLogic(torch.tensor(-2.0).to(device))
        self.conjs = FuzzySignedConjunction(
            in_features=in_features, out_features=16, logic=self.logic
        )
        self.disj = FuzzyUnsignedDisjunction(
            in_features=16, out_features=1, logic=self.logic,
        )
        # `model` and `logsumexp` re-register the same layer objects, so the
        # parameters are shared rather than duplicated. The extra names do
        # appear in state_dict(), so they are kept for checkpoint compatibility.
        self.model = nn.Sequential(
            self.conjs,
            self.disj
        )
        self.logsumexp = nn.Sequential(
            self.conjs,
        )

    def forward(self, input):
        """Disjunction of conjunctions, with size-1 dimensions squeezed away."""
        return self.model(input).squeeze()

    def logsumexp_forward(self, input):
        """Log-sum-exp (a smooth maximum) of the conjunction outputs over the last dim."""
        return self.conjs(input).logsumexp(dim=-1)

    def crispness(self):
        """Mean distance of every gate parameter from its nearest crisp value.

        Covers conjunction weights, conjunction signs and disjunction weights.
        Each value ``p`` contributes ``min(p, 1 - p)``, so 0.0 means every
        value is exactly 0 or 1 (assuming values lie in [0, 1]).
        """
        with torch.no_grad():
            params = (
                self.conjs.weights.value(),
                self.conjs.signs.value(),
                self.disj.weights.value(),
            )
            off_crisp = sum(torch.min(p, 1 - p).sum() for p in params)
            numel = sum(p.numel() for p in params)
        return (off_crisp / numel).item()

In [14]:
def defer_regime(*args, **kwargs):
    """Return a zero-argument factory that builds ``DNFRegime(*args, **kwargs)``.

    Construction is deferred so each regime's datasets are only generated
    when the factory is called, right before that regime is used.
    """
    return lambda: DNFRegime(*args, **kwargs)

def new_fuzzy_f(in_features=50):
    """Return a fresh ``FuzzyNN`` with ``in_features`` inputs, moved to ``device``.

    The default must match the ``no_dims`` of any regime that calls this
    factory without arguments.
    """
    return FuzzyNN(in_features).to(device)

# Regimes to train. Each entry is a zero-argument factory returning a DNFRegime.
regimes = [
    defer_regime("fuzzynn50_2", no_dims=50, new_model_f=new_fuzzy_f, lr=1e-3, no_runs=1),
]

MAX_EPOCHS = 40        # train every run up to and including this epoch
CHECKPOINT_EVERY = 3   # advance loop_until in steps of this many epochs
RUN_TRAINING = True    # set False to only define `regimes` (e.g. for inspection)
DETECT_ANOMALY = True  # helps locate NaNs in backward(), but slows training a lot

torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

# Epoch targets CHECKPOINT_EVERY, 2*CHECKPOINT_EVERY, ..., always ending
# exactly at MAX_EPOCHS (3, 6, ..., 39, 40 for the values above).
checkpoint_epochs = list(range(CHECKPOINT_EVERY, MAX_EPOCHS, CHECKPOINT_EVERY)) + [MAX_EPOCHS]

if RUN_TRAINING:
    for regime_f in regimes:
        regime = regime_f()
        regime.load_latest_models()
        regime.load_all_results()
        for run_no in range(1, regime.no_runs + 1):
            for target_epoch in checkpoint_epochs:
                regime.loop_until(run_no, target_epoch)
        del regime

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #1, Training:   0%|          | 0/391 [00:00<?, ?it/s]

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #1, Testing [output-distance]:   0%|          | 0/40 …

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #1, Testing [output-avg]:   0%|          | 0/40 [00:0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #1, Testing [class-balance]:   0%|          | 0/40 [0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #2, Training:   0%|          | 0/391 [00:00<?, ?it/s]

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #2, Testing [output-distance]:   0%|          | 0/40 …

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #2, Testing [output-avg]:   0%|          | 0/40 [00:0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #2, Testing [class-balance]:   0%|          | 0/40 [0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #3, Training:   0%|          | 0/391 [00:00<?, ?it/s]

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #3, Testing [output-distance]:   0%|          | 0/40 …

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #3, Testing [output-avg]:   0%|          | 0/40 [00:0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #3, Testing [class-balance]:   0%|          | 0/40 [0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #4, Training:   0%|          | 0/391 [00:00<?, ?it/s]

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #4, Testing [output-distance]:   0%|          | 0/40 …

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #4, Testing [output-avg]:   0%|          | 0/40 [00:0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #4, Testing [class-balance]:   0%|          | 0/40 [0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #5, Training:   0%|          | 0/391 [00:00<?, ?it/s]

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #5, Testing [output-distance]:   0%|          | 0/40 …

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #5, Testing [output-avg]:   0%|          | 0/40 [00:0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #5, Testing [class-balance]:   0%|          | 0/40 [0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #6, Training:   0%|          | 0/391 [00:00<?, ?it/s]

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #6, Testing [output-distance]:   0%|          | 0/40 …

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #6, Testing [output-avg]:   0%|          | 0/40 [00:0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #6, Testing [class-balance]:   0%|          | 0/40 [0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #7, Training:   0%|          | 0/391 [00:00<?, ?it/s]

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #7, Testing [output-distance]:   0%|          | 0/40 …

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #7, Testing [output-avg]:   0%|          | 0/40 [00:0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #7, Testing [class-balance]:   0%|          | 0/40 [0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #8, Training:   0%|          | 0/391 [00:00<?, ?it/s]

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #8, Testing [output-distance]:   0%|          | 0/40 …

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #8, Testing [output-avg]:   0%|          | 0/40 [00:0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #8, Testing [class-balance]:   0%|          | 0/40 [0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #9, Training:   0%|          | 0/391 [00:00<?, ?it/s]

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #9, Testing [output-distance]:   0%|          | 0/40 …

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #9, Testing [output-avg]:   0%|          | 0/40 [00:0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #9, Testing [class-balance]:   0%|          | 0/40 [0…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #10, Training:   0%|          | 0/391 [00:00<?, ?it/s…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #10, Testing [output-distance]:   0%|          | 0/40…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #10, Testing [output-avg]:   0%|          | 0/40 [00:…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #10, Testing [class-balance]:   0%|          | 0/40 […

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #11, Training:   0%|          | 0/391 [00:00<?, ?it/s…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #11, Testing [output-distance]:   0%|          | 0/40…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #11, Testing [output-avg]:   0%|          | 0/40 [00:…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #11, Testing [class-balance]:   0%|          | 0/40 […

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #12, Training:   0%|          | 0/391 [00:00<?, ?it/s…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #12, Testing [output-distance]:   0%|          | 0/40…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #12, Testing [output-avg]:   0%|          | 0/40 [00:…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #12, Testing [class-balance]:   0%|          | 0/40 […

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #13, Training:   0%|          | 0/391 [00:00<?, ?it/s…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #13, Testing [output-distance]:   0%|          | 0/40…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #13, Testing [output-avg]:   0%|          | 0/40 [00:…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #13, Testing [class-balance]:   0%|          | 0/40 […

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #14, Training:   0%|          | 0/391 [00:00<?, ?it/s…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #14, Testing [output-distance]:   0%|          | 0/40…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #14, Testing [output-avg]:   0%|          | 0/40 [00:…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #14, Testing [class-balance]:   0%|          | 0/40 […

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #15, Training:   0%|          | 0/391 [00:00<?, ?it/s…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #15, Testing [output-distance]:   0%|          | 0/40…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #15, Testing [output-avg]:   0%|          | 0/40 [00:…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #15, Testing [class-balance]:   0%|          | 0/40 […

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #16, Training:   0%|          | 0/391 [00:00<?, ?it/s…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #16, Testing [output-distance]:   0%|          | 0/40…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #16, Testing [output-avg]:   0%|          | 0/40 [00:…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #16, Testing [class-balance]:   0%|          | 0/40 […

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #17, Training:   0%|          | 0/391 [00:00<?, ?it/s…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #17, Testing [output-distance]:   0%|          | 0/40…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #17, Testing [output-avg]:   0%|          | 0/40 [00:…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #17, Testing [class-balance]:   0%|          | 0/40 […

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #18, Training:   0%|          | 0/391 [00:00<?, ?it/s…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #18, Testing [output-distance]:   0%|          | 0/40…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #18, Testing [output-avg]:   0%|          | 0/40 [00:…

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #18, Testing [class-balance]:   0%|          | 0/40 […

[FUZZYNN50_2 | 50 DIMS | ADAM, LR = 0.001] Run #1, Epoch #19, Training:   0%|          | 0/391 [00:00<?, ?it/s…

KeyboardInterrupt: 

In [15]:
# Compare the DNF learned by run 1 with the formula that generated the data.
reg = regimes[0]()
reg.load_latest_models()
reg.load_all_results()
torch.set_printoptions(sci_mode=False)

learned = reg.models[0]
if learned is not None:
    with torch.no_grad():
        disj_weights = learned.disj.weights.value()
        print(disj_weights)
        # Keep only conjunctions whose disjunction weight exceeds 0.5.
        used_disjs = (disj_weights > 0.5).squeeze()
        conj_weights = learned.conjs.weights.value()[:, used_disjs]
        conj_signs = learned.conjs.signs.value()[:, used_disjs]
    print(format_dnf(conj_signs, conj_weights))
    print("=")
print(format_dnf(reg.dataset.conj_signs, reg.dataset.conj_weights))

tensor([[0.9595],
        [0.8233],
        [0.9011],
        [0.9997],
        [0.9996],
        [0.7219],
        [0.9213],
        [0.8220],
        [0.9979],
        [0.9993],
        [0.9897],
        [0.3377],
        [0.6928],
        [0.8525],
        [0.9790],
        [0.9824]], device='cuda:0', grad_fn=<SigmoidBackward0>)
( 1 ∧ ¬25 )
∨ ( ¬16 )
∨ ( ¬25 ∧ 46 )
∨ ( ¬2 ∧ ¬37 )
∨ ( ¬25 ∧ ¬30 )
∨ ( ¬16 )
∨ ( ¬37 ∧ ¬49 )
∨ ( ¬16 ∧ ¬19 )
∨ ( ¬16 )
∨ ( ¬16 )
∨ ( 13 )
∨ ( ¬16 ∧ ¬30 ∧ ¬36 ∧ 47 )
∨ ( 13 )
∨ ( ¬2 ∧ ¬6 ∧ ¬18 ∧ ¬41 )
∨ ( ¬16 )
=
( 9 ∧ ¬14 ∧ ¬34 ∧ ¬44 )
∨ ( ¬16 )
∨ ( 5 ∧ 13 )
∨ ( ¬30 ∧ 46 )
∨ ( ¬25 ∧ ¬29 )
∨ ( ¬4 ∧ ¬19 ∧ 23 ∧ 25 ∧ ¬30 ∧ ¬36 ∧ ¬38 )
∨ ( ¬2 ∧ ¬37 )
∨ ( 2 ∧ ¬17 ∧ ¬39 ∧ ¬49 )


In [None]:
def train_dnf(model: nn.Module, dataloader, lr, log=False):
    """Train ``model`` with Adam for one pass over ``dataloader``.

    Args:
        model: module mapping a batch ``X`` to probabilities, one per label
            (any shape with that many elements; it is reshaped to ``y``'s
            shape). When ``log`` is True it must also define
            ``logsumexp_forward``.
        dataloader: yields ``(X, y)`` batches with float labels.
        lr: Adam learning rate.
        log: if True, minimise ``sum(sign * model.logsumexp_forward(X))``, where
            ``sign`` is -1 for labels > 0.5 and +1 otherwise. If False,
            minimise the squared BCE loss.

    Returns:
        ``(items, losses, props, dnf_strs)``:
        - cumulative samples seen after each batch;
        - per-batch BCE loss;
        - per-batch mean label;
        - a list reserved for DNF snapshots, always empty because
          snapshotting is currently disabled.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr)
    loss_fn = torch.nn.BCELoss()

    seen = 0
    items = []
    obs_losses = []
    props = []
    dnf_strs = []  # DNF snapshotting (via model.harden_params) is disabled

    for X, y in tqdm(dataloader, total=len(dataloader)):
        X = X.to(device)
        y = y.to(device)
        # reshape: FuzzyNN returns (B,) (0-d when B == 1), PerceptronDNF returns (B, 1).
        test_loss = loss_fn(model(X).reshape(y.shape), y)

        optimizer.zero_grad()
        if log:
            scores = model.logsumexp_forward(X)
            # -1 for positive labels (push score up), +1 for negative ones.
            sign = 1.0 - 2.0 * (y > 0.5).to(scores.dtype)
            (sign * scores).sum().backward()
        else:
            (test_loss ** 2).backward()
        optimizer.step()

        seen += len(X)
        items.append(seen)
        obs_losses.append(test_loss.item())
        # .item(): store a Python float, not a (possibly GPU) tensor.
        props.append(y.mean().item())

    return torch.tensor(items), torch.tensor(obs_losses), torch.tensor(props), dnf_strs

In [None]:
# Show generated samples: each input assignment and the label the dataset's
# formula gives it.
dnf_dataset = DNFDataset(10, 100, 0.5, 2, 0.5)
# The formula string does not depend on the sample, so format it only once.
formula_str = format_dnf(dnf_dataset.conj_signs, dnf_dataset.conj_weights)
for X, y in dnf_dataset:
    print(format_vars(X))
    print("Given formula:")
    print(formula_str)
    print("Produces:")
    print(str(y) + "\n")

In [None]:
# Train a FuzzyNN on a fresh 10-variable DNF problem and plot its progress.
dnf_dim = 10
dnf_dataset = DNFDataset(dnf_dim, 200000, 0.5, 2, 0.5)
dnf_dataloader = DataLoader(dnf_dataset, shuffle=True, batch_size=128)

dnf_model = FuzzyNN(dnf_dim).to(device)

dnf_items, dnf_losses, dnf_props, dnf_strs = train_dnf(
    dnf_model, dnf_dataloader, lr=1e-3, log=False,
)
print(dnf_strs)

for series, y_label in (
    (dnf_losses, 'Observed Loss'),
    (dnf_props, 'Observed Proportion of True Output'),
):
    plot_loss(dnf_items, series, ylabel=y_label, approx=True)
plot_loss(dnf_items, dnf_props, ylabel='Observed Proportion of True Output', approx=True)

In [None]:
# Train a second, freshly initialised FuzzyNN on the same dataloader
# (presumably to compare outcomes across initialisations).
dnf_model = FuzzyNN(dnf_dim).to(device)

dnf_items, dnf_losses, dnf_props, dnf_strs = train_dnf(
    dnf_model, dnf_dataloader, lr=1e-3, log=False,
)
print(dnf_strs)

for series, y_label in (
    (dnf_losses, 'Observed Loss'),
    (dnf_props, 'Observed Proportion of True Output'),
):
    plot_loss(dnf_items, series, ylabel=y_label, approx=True)

In [None]:
class PerceptronDNF(nn.Module):
    """Plain MLP baseline: in -> hidden -> hidden -> 1 with ReLU, sigmoid output.

    ``forward`` maps a ``(batch, in_features)`` input to a ``(batch, 1)``
    output in (0, 1). Unlike FuzzyNN, the trailing dimension is not squeezed.
    """

    def __init__(self, in_features, hidden_features=16):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, 1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.model(input)

In [None]:
# Baseline: train a plain MLP on a fresh 10-variable DNF problem.
dnf_dim = 10
dnf_dataset = DNFDataset(dnf_dim, 200000, 0.5, 2, 0.5)
dnf_dataloader = DataLoader(dnf_dataset, batch_size=128, shuffle=True)

dnf_model = PerceptronDNF(dnf_dim)
dnf_model.to(device)

# log=False: the logsumexp objective needs `logsumexp_forward`, which
# PerceptronDNF does not define.
dnf_items, dnf_losses, dnf_props, dnf_strs = train_dnf(dnf_model, dnf_dataloader, lr=5e-4, log=False)

print(dnf_strs)

plot_loss(dnf_items, dnf_losses, ylabel='Observed Loss', approx=True)
plot_loss(dnf_items, dnf_props, ylabel='Observed Proportion of True Output', approx=True)

In [None]:
def train_embed_dnf(model: nn.Module, dataloader, optim, reg_weight=1e-2):
    """Train an embedding-logic model for one pass over ``dataloader``.

    Each step minimises ``BCE(pred, y) + reg_weight * model.logic.logic_reg()``
    with ``optim``, then calls ``model.logic.zero_reg()``.

    Returns:
        ``(items, losses, props)``:
        - cumulative samples seen after each batch;
        - per-batch BCE loss (without the regulariser);
        - per-batch mean label.
    """
    loss_fn = torch.nn.BCELoss()

    seen = 0
    items = []
    obs_losses = []
    props = []

    for X, y in tqdm(dataloader, total=len(dataloader)):
        X = X.to(device)
        y = y.to(device)
        # reshape rather than squeeze, so a batch of size 1 keeps y's shape.
        pred = model(X).reshape(y.shape)
        loss = loss_fn(pred, y)

        optim.zero_grad()
        # NOTE(review): only model.logic's regulariser is added and reset; other
        # logics on the model (e.g. EmbedNN.logic2) contribute none — confirm.
        (loss + reg_weight * model.logic.logic_reg()).backward()
        optim.step()
        model.logic.zero_reg()

        seen += len(X)
        items.append(seen)
        obs_losses.append(loss.item())
        # .item(): store a Python float, not a (possibly GPU) tensor.
        props.append(y.mean().item())

    return torch.tensor(items), torch.tensor(obs_losses), torch.tensor(props)

In [None]:
class EmbedNN(nn.Module):
    """Two stacked EmbedDNF layers (in -> 8 -> 1) built on embedding logics.

    Inputs are encoded with ``self.logic``, passed through both layers, and the
    result is decoded with ``self.logic`` again.
    """

    def __init__(self, in_features):
        super().__init__()
        # Each layer gets its own EmbedLogic; both are built with calculate_reg=False.
        self.logic = embed_logic.EmbedLogic(5, calculate_reg=False)
        self.logic2 = embed_logic.EmbedLogic(5, calculate_reg=False)
        first_layer = EmbedDNF((in_features, 16, 8), logic=self.logic)
        second_layer = EmbedDNF((8, 16, 1), logic=self.logic2)
        self.model = nn.Sequential(first_layer, second_layer)

    def forward(self, input):
        # NOTE(review): the last layer uses self.logic2, but the output is
        # decoded with self.logic — confirm this mismatch is intended.
        encoded = self.logic.encode(input)
        hidden = self.model(encoded)
        return self.logic.decode(hidden)

In [None]:
# Train the embedding-logic network on a fresh 10-variable DNF problem.
embed_input_dim = 10
embed_dataset = DNFDataset(embed_input_dim, 150000, 0.5, 2, 0.5)
embed_dataloader = DataLoader(embed_dataset, batch_size=128, shuffle=True)

embed_model = EmbedNN(embed_input_dim)
embed_model.to(device)
embed_optimizer = torch.optim.Adam(embed_model.parameters(), lr=1e-4)

# reg_weight=0: optimise the BCE loss alone.
embed_items, embed_losses, embed_props = train_embed_dnf(
    embed_model, embed_dataloader, embed_optimizer, reg_weight=0
)

plot_loss(embed_items, embed_losses, ylabel='Observed Loss', approx=True)
plot_loss(embed_items, embed_props, ylabel='Observed Proportion of True Output', approx=True)

# Display the embedding logic of the model trained in this cell.
embed_model.logic

In [None]:
# Inspect: the embedding logic's `neg` applied to its `F()` value
# (presumably the embedding of "false" — TODO confirm).
embed_model.logic.neg(embed_model.logic.F())

In [None]:
# Last per-batch BCE loss from whichever train_dnf cell ran most recently
# (each training cell overwrites `dnf_losses`).
dnf_losses[-1]

In [None]:
# Inspect the value returned by the embedding logic's `F()`.
embed_model.logic.F()

In [None]:
# Inspect `w_1` on the embedding logic's private `_not` component
# (presumably a learned weight of the negation operator — confirm).
embed_model.logic._not.w_1

In [None]:
# Inspect `b_1` on the embedding logic's private `_not` component
# (presumably a learned bias of the negation operator — confirm).
embed_model.logic._not.b_1

In [None]:
# Inspect `w_2` on the embedding logic's private `_not` component
# (presumably a second learned weight of the negation operator — confirm).
embed_model.logic._not.w_2