In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

In [2]:
# 定义分类器网络 q(y|x)
class Classifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)
    
    def forward(self, x):
        logits = self.fc(x)
        q_y = F.softmax(logits, dim=1)
        return q_y, logits



In [3]:
# 计算 ELBO
def compute_elbo(x, q_y, logits, mu_y, epsilon=1e-8):
    # x: (batch_size, num_pixels)
    # q_y: (batch_size, num_classes)
    # logits: (batch_size, num_classes)
    # mu_y: (num_classes, num_pixels)
    
    num_classes = mu_y.size(0)
    batch_size = x.size(0)
    
    # log p(y) = log(1/num_classes)
    log_p_y = torch.log(torch.tensor(1.0 / num_classes, dtype=torch.float32, device=x.device))
    
    # log p(x|y) = sum(x * log mu_y + (1 - x) * log (1 - mu_y))
    log_mu = torch.log(mu_y + epsilon)
    log_one_minus_mu = torch.log(1 - mu_y + epsilon)
    log_p_x_given_y = (x @ log_mu.t()) + ((1 - x) @ log_one_minus_mu.t())
    
    # E_q[log p(x, y)] = sum_q log p(x, y)
    e_q_log_p_x_y = torch.sum(q_y * log_p_x_given_y, dim=1)
    e_q_log_p_y = torch.sum(q_y * log_p_y, dim=1)
    e_q_log_p_x_y += e_q_log_p_y
    
    # E_q[log q(y|x)] = sum_q log q(y|x)
    log_q_y = nn.functional.log_softmax(logits, dim=1)
    e_q_log_q_y = torch.sum(q_y * log_q_y, dim=1)
    
    # ELBO = E_q_log_p_x_y - E_q_log_q_y
    elbo = e_q_log_p_x_y - e_q_log_q_y
    return elbo


In [4]:
# 训练循环
def train(model, mu_y, optimizer, train_loader, epochs=10):
    model.train()
    for epoch in range(epochs):
        for batch_idx, (images, labels) in enumerate(train_loader):
            images = images.view(images.size(0), -1)
            q_y, logits = model(images)
            elbo = compute_elbo(images, q_y, logits, mu_y)
            loss = -elbo.mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')
        test_accuracy(model, test_loader)


In [5]:
# 测试准确率
def test_accuracy(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.view(images.size(0), -1)
            q_y, logits = model(images)
            _, predicted = torch.max(logits.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

In [8]:
# 获取mu_y
def compute_class_means(dataset, num_classes, input_dim):
    class_means = torch.zeros(num_classes, input_dim, dtype=torch.float32)
    class_counts = torch.zeros(num_classes, dtype=torch.int32)
    for image, label in dataset:
        # 确保图像展平为 [784] 的形状
        class_means[label] += image.view(-1)
        class_counts[label] += 1
    for y in range(num_classes):
        if class_counts[y] > 0:
            class_means[y] /= class_counts[y]
    return class_means

In [9]:
# 主程序
if __name__ == '__main__':
    input_dim = 784
    num_classes = 10
    batch_size = 64

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = datasets.MNIST(root='D:\\jupyter\\softHebb\\data', train=True, download=False, transform=transform)
    test_dataset = datasets.MNIST(root='D:\\jupyter\\softHebb\\data', train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    mu_y = compute_class_means(train_dataset, num_classes, input_dim)

    model = Classifier(input_dim, num_classes)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    train(model, mu_y, optimizer, train_loader, epochs=10)

Epoch 1, Loss: nan
Test Accuracy: 9.80%
Epoch 2, Loss: nan
Test Accuracy: 9.80%
Epoch 3, Loss: nan
Test Accuracy: 9.80%
Epoch 4, Loss: nan
Test Accuracy: 9.80%
Epoch 5, Loss: nan
Test Accuracy: 9.80%
Epoch 6, Loss: nan
Test Accuracy: 9.80%
Epoch 7, Loss: nan
Test Accuracy: 9.80%
Epoch 8, Loss: nan
Test Accuracy: 9.80%
Epoch 9, Loss: nan
Test Accuracy: 9.80%
Epoch 10, Loss: nan
Test Accuracy: 9.80%
