In [1]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import ltn

batch_size = 4096

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=True)

In [3]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*4*4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, training=False):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 16*4*4)
        x = self.fc1(x)
        #if training:
            #x = self.dropout(x)
        x = self.fc2(x)
        if training:
            x = self.dropout(x)
        x = self.fc3(x)
        return x             

class LogitsToPredicate(nn.Module):
    def __init__(self, logits_model):
        super(LogitsToPredicate, self).__init__()
        self.logits_model = logits_model
        self.softmax = torch.nn.Softmax(dim=1)
    
    def forward(self, x, l, training=False):
        logits = self.logits_model(x, training=training)
        probs = self.softmax(logits)
        out = torch.sum(probs * l, dim=1)
        return out

In [4]:
lenet = LeNet()
P = ltn.Predicate(LogitsToPredicate(lenet))

Not = ltn.Connective(ltn.fuzzy_ops.NotStandard())
And = ltn.Connective(ltn.fuzzy_ops.AndProd())
Forall = ltn.Quantifier(ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f")
SatAgg = ltn.fuzzy_ops.SatAgg()

l_0 = ltn.Constant(torch.tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
l_1 = ltn.Constant(torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
l_2 = ltn.Constant(torch.tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 0]))
l_3 = ltn.Constant(torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 0]))
l_4 = ltn.Constant(torch.tensor([0, 0, 0, 0, 1, 0, 0, 0, 0, 0]))
l_5 = ltn.Constant(torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0]))
l_6 = ltn.Constant(torch.tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0]))
l_7 = ltn.Constant(torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0]))
l_8 = ltn.Constant(torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]))
l_9 = ltn.Constant(torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))

In [5]:
from sklearn.metrics import accuracy_score

def compute_sat_level(loader):
    mean_sat = 0
    for data, label in loader:
        x_0 = ltn.Variable("x_0", data[label == 0])
        x_1 = ltn.Variable("x_1", data[label == 1])
        x_2 = ltn.Variable("x_2", data[label == 2])
        x_3 = ltn.Variable("x_3", data[label == 3])
        x_4 = ltn.Variable("x_4", data[label == 4])
        x_5 = ltn.Variable("x_5", data[label == 5])
        x_6 = ltn.Variable("x_6", data[label == 6])
        x_7 = ltn.Variable("x_7", data[label == 7])
        x_8 = ltn.Variable("x_8", data[label == 8])
        x_9 = ltn.Variable("x_9", data[label == 9])
        mean_sat += SatAgg(
            Forall(x_0, P(x_0, l_0)),
            Forall(x_1, P(x_1, l_1)),
            Forall(x_2, P(x_2, l_2)),
            Forall(x_3, P(x_3, l_3)),
            Forall(x_4, P(x_4, l_4)),
            Forall(x_5, P(x_5, l_5)),
            Forall(x_6, P(x_6, l_6)),
            Forall(x_7, P(x_7, l_7)),
            Forall(x_8, P(x_8, l_8)),
            Forall(x_9, P(x_9, l_9))
        )
    mean_sat /= len(loader)
    return mean_sat

def compute_acc(loader):
    mean_acc = 0
    for data, label in loader:
        predictioins = lenet(data).detach().numpy()
        predictioins = np.argmax(predictioins, axis=1)
        mean_acc += accuracy_score(label, predictioins)
    
    return mean_acc / len(loader)


In [6]:
optimizer = torch.optim.Adam(P.parameters(), lr=0.001)

In [7]:
for epoch in range(20):
    train_loss = 0.0
    for batch_idx, (data, label) in enumerate(trainloader):
        optimizer.zero_grad()
        x_0 = ltn.Variable("x_0", data[label == 0])
        x_1 = ltn.Variable("x_1", data[label == 1])
        x_2 = ltn.Variable("x_2", data[label == 2])
        x_3 = ltn.Variable("x_3", data[label == 3])
        x_4 = ltn.Variable("x_4", data[label == 4])
        x_5 = ltn.Variable("x_5", data[label == 5])
        x_6 = ltn.Variable("x_6", data[label == 6])
        x_7 = ltn.Variable("x_7", data[label == 7])
        x_8 = ltn.Variable("x_8", data[label == 8])
        x_9 = ltn.Variable("x_9", data[label == 9])
        sat_agg = SatAgg(
            Forall(x_0, P(x_0, l_0, training=True)),
            Forall(x_1, P(x_1, l_1, training=True)),
            Forall(x_2, P(x_2, l_2, training=True)),
            Forall(x_3, P(x_3, l_3, training=True)),
            Forall(x_4, P(x_4, l_4, training=True)),
            Forall(x_5, P(x_5, l_5, training=True)),
            Forall(x_6, P(x_6, l_6, training=True)),
            Forall(x_7, P(x_7, l_7, training=True)),
            Forall(x_8, P(x_8, l_8, training=True)),
            Forall(x_9, P(x_9, l_9, training=True))
        )

        loss = 1. - sat_agg
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(trainloader)

    print(f"Epoch {epoch} - Train Loss: {train_loss} - Train SAT: {compute_sat_level(trainloader)} - Test SAT: {compute_sat_level(testloader)} - Test Acc: {compute_acc(testloader)}")

Epoch 0 - Train Loss: 0.859576694170634 - Train SAT: 0.24459140002727509 - Test SAT: 0.24902211129665375 - Test Acc: 0.4899931150903392
Epoch 1 - Train Loss: 0.5942545255025228 - Train SAT: 0.5663737654685974 - Test SAT: 0.5763744115829468 - Test Acc: 0.8042438698377582
Epoch 2 - Train Loss: 0.39217299620310464 - Train SAT: 0.6659346222877502 - Test SAT: 0.6777152419090271 - Test Acc: 0.8814275961006638
Epoch 3 - Train Loss: 0.32418870528539023 - Train SAT: 0.709102213382721 - Test SAT: 0.7165148854255676 - Test Acc: 0.9080727726308998
Epoch 4 - Train Loss: 0.2860700766245524 - Train SAT: 0.7403632998466492 - Test SAT: 0.748497486114502 - Test Acc: 0.9302902896847346
Epoch 5 - Train Loss: 0.25621575911839806 - Train SAT: 0.767669141292572 - Test SAT: 0.775672435760498 - Test Acc: 0.9446434538624632
Epoch 6 - Train Loss: 0.22957884470621745 - Train SAT: 0.7909060716629028 - Test SAT: 0.8014370799064636 - Test Acc: 0.95149091422382
Epoch 7 - Train Loss: 0.20888715585072834 - Train SAT: 0

In [8]:
lenet.eval()
with torch.no_grad():
    data, label = next(iter(testloader))
    reslut = lenet(data).detach().numpy()
    print("ground truth", label)
    print("prediction", np.argmax(reslut, axis=1))


ground truth tensor([1, 3, 5,  ..., 0, 8, 1])
prediction [1 3 5 ... 0 8 1]
