In [1]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from semantic_loss_pytorch import SemanticLoss

In [2]:
# Hyperparameters and Configuration
batch_size = 64                        # mini-batch size handed to the DataLoaders
epochs = 5                             # number of full passes over the training set
constraint_sdd = "constraint.sdd"      # SDD file passed to SemanticLoss (relative path)
constraint_vtree = "constraint.vtree"  # vtree file passed to SemanticLoss (relative path)
use_semantic_loss = True               # False -> train with cross-entropy only

# Seed PyTorch's global RNG so that weight initialisation, dropout and other
# torch-driven randomness repeat across "Restart Kernel -> Run All".
seed = 42
torch.manual_seed(seed)

In [3]:
# Define the MLP Model
class MLP(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10.

    Each of the two hidden linear layers is followed by batch normalisation
    and a ReLU. Dropout (p=0.3) is applied once, after the first hidden
    block. The final linear layer produces 10 raw scores; no softmax is
    applied inside the model.
    """

    def __init__(self):
        super().__init__()
        # First hidden block.
        self.fc1 = nn.Linear(28 * 28, 128)
        self.bn1 = nn.BatchNorm1d(128)
        # Second hidden block.
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)
        # Regulariser used between the two hidden blocks.
        self.dropout = nn.Dropout(0.3)
        # Output layer: one score per class.
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Map a batch of inputs to a ``(batch, 10)`` tensor of class scores.

        The input is reshaped with ``view(-1, 784)``, so any tensor whose
        element count is a multiple of 784 is accepted.
        """
        flat = x.view(-1, 28 * 28)

        hidden = self.fc1(flat)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)

        hidden = self.fc2(hidden)
        hidden = self.bn2(hidden)
        hidden = F.relu(hidden)

        return self.fc3(hidden)

In [4]:
# Accuracy Computation
def compute_accuracy(logits, labels):
    """Return the fraction of samples whose arg-max class matches its label.

    Args:
        logits: tensor of per-class scores; the predicted class is the
            arg-max along dimension 1.
        labels: target class indices, compared element-wise with the
            predictions.

    Returns:
        A zero-dimensional float tensor: number of matches divided by
        ``len(labels)``.
    """
    predicted = logits.argmax(dim=1)
    matches = torch.eq(predicted, labels).float()
    return matches.sum() / len(labels)

In [5]:
# Training Step
def train_step(model, images, labels, optimizer, sl_module=None, sl_weight=0.1):
    """Run one optimisation step and report the batch metrics.

    Args:
        model: network to train; it is put in training mode on every call,
            so the step is safe to use after the model was switched to
            eval mode.
        images: input batch fed to ``model``.
        labels: target class indices for ``images``.
        optimizer: optimizer whose gradients are reset, then stepped once.
        sl_module: optional callable invoked as ``sl_module(logits=logits)``.
            Its result is added to the cross-entropy loss, scaled by
            ``sl_weight``. It must return a single-element value, because
            the result is converted with ``float()``.
        sl_weight: multiplier for the semantic-loss term. Defaults to 0.1,
            the value that was previously hard-coded.

    Returns:
        Tuple of Python floats
        ``(total_loss, cross_entropy, semantic_loss, accuracy)``.
        ``semantic_loss`` is 0.0 when ``sl_module`` is None.
    """
    model.train()
    logits = model(images)

    ce_loss = F.cross_entropy(logits, labels)
    if sl_module is not None:
        sl = sl_module(logits=logits)
        loss = ce_loss + sl_weight * sl
    else:
        sl = 0
        loss = ce_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Accuracy comes from the logits computed before the parameter update,
    # i.e. from the same training-mode forward pass that produced the loss.
    acc = compute_accuracy(logits, labels)
    return loss.item(), ce_loss.item(), float(sl), acc.item()

In [6]:
# Data pipeline, model, optimizer and optional semantic-loss setup

# Training images get a random-rotation augmentation (degrees=10) before the
# tensor conversion and normalisation; test images are only converted and
# normalised. (0.1307, 0.3081) are presumably the MNIST mean/std - TODO confirm.
transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Both splits are stored under ./data and downloaded if missing.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)

# Only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
)

model = MLP()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# NOTE(review): with step_size=5 the first halving of the learning rate
# happens after five scheduler.step() calls. If the training loop steps the
# scheduler once per epoch and runs for 5 epochs, no training step ever sees
# the reduced rate - confirm this is intended.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# The semantic-loss module is built only when enabled in the configuration.
if use_semantic_loss:
    sl_module = SemanticLoss(constraint_sdd, constraint_vtree)
else:
    sl_module = None

In [7]:
# Main Training Loop
# Each epoch: train over every batch (logging every 100th step), advance the
# LR scheduler once, then measure accuracy on the full test set.
for epoch in range(epochs):
    for step, (images, labels) in enumerate(train_loader):
        loss, ce_loss, sl, acc = train_step(model, images, labels, optimizer, sl_module)
        if step % 100 == 0:
            print(f"Epoch {epoch}, Step {step}: "
                  f"Loss: {loss:.4f}, Cross Entropy: {ce_loss:.4f}, Semantic Loss: {sl:.4f}, Accuracy: {acc:.4f}")
    scheduler.step()

    # Evaluation
    # Count correct predictions and samples over the whole test set rather
    # than averaging per-batch accuracies: a smaller final batch would
    # otherwise carry the same weight as a full one and bias the result.
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            logits = model(images)
            test_correct += (torch.argmax(logits, dim=1) == labels).sum().item()
            test_total += len(labels)
    # An empty test loader yields NaN instead of a ZeroDivisionError.
    test_acc = test_correct / test_total if test_total else float("nan")
    print(f"Epoch {epoch} - Test Accuracy: {test_acc:.4f}")

Epoch 0, Step 0: Loss: 2.8552, Cross Entropy: 2.4254, Semantic Loss: 4.2983, Accuracy: 0.1406
Epoch 0, Step 100: Loss: 0.5670, Cross Entropy: 0.4712, Semantic Loss: 0.9580, Accuracy: 0.8438
Epoch 0, Step 200: Loss: 0.5388, Cross Entropy: 0.4792, Semantic Loss: 0.5961, Accuracy: 0.8906
Epoch 0, Step 300: Loss: 0.3120, Cross Entropy: 0.2414, Semantic Loss: 0.7065, Accuracy: 0.9375
Epoch 0, Step 400: Loss: 0.2408, Cross Entropy: 0.1779, Semantic Loss: 0.6295, Accuracy: 0.9375
Epoch 0, Step 500: Loss: 0.5037, Cross Entropy: 0.4445, Semantic Loss: 0.5922, Accuracy: 0.8438
Epoch 0, Step 600: Loss: 0.4624, Cross Entropy: 0.3888, Semantic Loss: 0.7358, Accuracy: 0.9219
Epoch 0, Step 700: Loss: 0.1714, Cross Entropy: 0.1345, Semantic Loss: 0.3693, Accuracy: 0.9531
Epoch 0, Step 800: Loss: 0.3669, Cross Entropy: 0.2950, Semantic Loss: 0.7191, Accuracy: 0.9062
Epoch 0, Step 900: Loss: 0.1394, Cross Entropy: 0.1109, Semantic Loss: 0.2847, Accuracy: 0.9844
Epoch 0 - Test Accuracy: 0.9615
Epoch 1, S