In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import ltn

In [28]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

In [3]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)  # 10个类别

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        x = x.view(-1, 64 * 4 * 4)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [12]:
# LTN常量表示类别
class_labels = {
    "airplane": ltn.Constant(torch.tensor([0])),
    "automobile": ltn.Constant(torch.tensor([1])),
    "bird": ltn.Constant(torch.tensor([2])),
    "cat": ltn.Constant(torch.tensor([3])),
    "deer": ltn.Constant(torch.tensor([4])),
    "dog": ltn.Constant(torch.tensor([5])),
    "frog": ltn.Constant(torch.tensor([6])),
    "horse": ltn.Constant(torch.tensor([7])),
    "ship": ltn.Constant(torch.tensor([8])),
    "truck": ltn.Constant(torch.tensor([9]))
}

class_labels_tensor = torch.stack([label.value for label in class_labels.values()])

# 定义一个谓词P用于检测某个样本是否属于某一类别
P = ltn.Predicate(nn.Linear(10, 1))

In [17]:
Not = ltn.Connective(ltn.fuzzy_ops.NotStandard())
Forall = ltn.Quantifier(ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f")
SatAgg = ltn.fuzzy_ops.SatAgg()

x = ltn.Variable("x", torch.randn(64, 10))
y = ltn.Variable("y", class_labels_tensor)

formula = Forall(
    [x, y],
    Not(P(x, y))
)

TypeError: forward() takes 2 positional arguments but 3 were given