 Helmholtz_Machine

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [4]:
# Seed PyTorch's default random number generator so repeated runs of the
# notebook draw the same sequence of random numbers.
torch.manual_seed(42)

<torch._C.Generator at 0x7a96ac3313d0>

In [5]:
# --- Model dimensions ---
input_size = 28 * 28   # one flattened 28x28 image
hidden_size = 200      # width of the latent layer
num_classes = 10       # one class per MNIST digit

# --- Training settings ---
batch_size = 100
num_epochs = 10
learning_rate = 1e-3
dropout_prob = 0.5     # probability passed to nn.Dropout


In [6]:
# Images are only converted to tensors; no further preprocessing is applied.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split (download requested) and the held-out test split.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
)

# Training batches are reshuffled; test batches keep dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [7]:
class RecognitionNetwork(nn.Module):
    """Bottom-up network mapping an input to latent units and class logits.

    A shared hidden layer (``fc1`` + ReLU + dropout) feeds two heads:
    ``fc2`` followed by a sigmoid gives per-unit latent probabilities, and
    ``fc_class`` gives unnormalised class scores.

    Args:
        input_size: Number of features per example after flattening.
        hidden_size: Width of the shared hidden layer and of the latent layer.
        num_classes: Number of outputs of the classification head.
        dropout_prob: Dropout probability applied to the shared hidden layer.
    """

    def __init__(self, input_size, hidden_size, num_classes, dropout_prob=0.5):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc_class = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Encode ``x`` and classify it.

        Args:
            x: Tensor whose elements can be reshaped to
                ``(-1, input_size)``.

        Returns:
            Tuple ``(z, z_probs, class_logits)``. ``z_probs`` are the sigmoid
            latent probabilities. ``z`` is a Bernoulli sample of ``z_probs``
            in training mode and ``z_probs`` itself in eval mode.
        """
        # Flatten with this layer's own input width instead of a module-level
        # global, so instances built with any input_size work correctly.
        x = x.view(-1, self.fc1.in_features)
        h = self.dropout(self.relu(self.fc1(x)))
        z_probs = torch.sigmoid(self.fc2(h))
        if self.training:
            # Stochastic binary latent units during training.
            z = torch.bernoulli(z_probs)
        else:
            # Deterministic probabilities at evaluation time.
            z = z_probs
        class_logits = self.fc_class(h)
        return z, z_probs, class_logits

In [8]:
# Top-down approach
class GenerativeNetwork(nn.Module):
    """Top-down network turning a latent vector into per-pixel probabilities.

    Args:
        hidden_size: Size of the latent input and of the hidden layer.
        input_size: Number of outputs (one probability per input feature).
        dropout_prob: Dropout probability applied after the hidden ReLU.
    """

    def __init__(self, hidden_size, input_size, dropout_prob=0.5):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, input_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, z):
        """Return sigmoid reconstruction probabilities for latent ``z``."""
        hidden = self.fc1(z)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        logits = self.fc2(hidden)
        return torch.sigmoid(logits)
class HelmholtzMachine(nn.Module):
    """Pair of recognition (bottom-up) and generative (top-down) networks.

    Args:
        input_size: Number of features per flattened input example.
        hidden_size: Number of latent units.
        num_classes: Number of outputs of the recognition classifier head.
        dropout_prob: Dropout probability forwarded to both sub-networks.
    """

    def __init__(self, input_size, hidden_size, num_classes, dropout_prob=0.5):
        super().__init__()
        # Remember the latent width so generate() does not depend on a
        # module-level global that may differ from this instance's size.
        self.hidden_size = hidden_size
        self.recognition = RecognitionNetwork(input_size, hidden_size, num_classes, dropout_prob)
        self.generative = GenerativeNetwork(hidden_size, input_size, dropout_prob)

    def forward(self, x):
        """Encode ``x`` with the recognition net and reconstruct it.

        Returns:
            Tuple ``(x_recon, z, z_probs, class_logits)`` where ``x_recon`` is
            the generative network's output for ``z`` and the remaining items
            are exactly what the recognition network returned.
        """
        z, z_probs, class_logits = self.recognition(x)
        x_recon = self.generative(z)
        return x_recon, z, z_probs, class_logits

    def generate(self, batch_size):
        """Generate samples by sampling from the prior and running the generative model.

        Each latent unit is drawn independently from Bernoulli(0.5).

        Args:
            batch_size: Number of latent vectors to sample.

        Returns:
            Tuple ``(x_dream, z_dream)``: the generative network's output and
            the binary latent sample of shape ``(batch_size, hidden_size)``.
        """
        device = next(self.parameters()).device
        # Build the prior directly on the model's device (no CPU round trip).
        prior = torch.full((batch_size, self.hidden_size), 0.5, device=device)
        z_dream = torch.bernoulli(prior)
        x_dream = self.generative(z_dream)
        return x_dream, z_dream

In [9]:
def train(model, train_loader, optimizer, criterion, device):
    """Run one epoch of two-phase (wake / sleep) training.

    For every batch two optimiser steps are taken:

    * Wake phase: the batch goes through ``model``; the loss is the binary
      cross-entropy between the reconstruction and the flattened input plus
      ``criterion`` applied to the class logits.
    * Sleep phase: ``model.generate`` produces a "dream" batch which is fed
      back through ``model``; the loss is the binary cross-entropy between the
      returned latent probabilities and the latent sample behind the dream.

    Args:
        model: Module whose call returns
            ``(x_recon, z, z_probs, class_logits)`` and which provides
            ``generate(batch_size) -> (x_dream, z_dream)``.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimiser over the model's parameters.
        criterion: Classification loss taking ``(class_logits, target)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(wake_loss, sleep_loss, class_loss, accuracy)``: the three
        losses averaged over batches and the training accuracy in percent
        (measured on the wake-phase logits). All values are ``0.0`` when the
        loader yields no batches.
    """
    model.train()
    wake_loss_total = 0.0
    sleep_loss_total = 0.0
    class_loss_total = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        # Flatten per example from the batch itself rather than relying on a
        # module-level input_size global.
        flat_data = data.view(data.size(0), -1)

        # ---- Wake phase ----
        optimizer.zero_grad()
        x_recon, _, _, class_logits = model(data)
        recon_loss = F.binary_cross_entropy(x_recon, flat_data)
        class_loss = criterion(class_logits, target)
        wake_loss = recon_loss + class_loss
        # The wake graph is never reused, so it does not need to be retained.
        wake_loss.backward()
        optimizer.step()

        # ---- Sleep phase ----
        optimizer.zero_grad()
        # The dream is only used as a fixed input/target, so no autograd
        # graph is needed while generating it.
        with torch.no_grad():
            x_dream, z_dream = model.generate(data.size(0))
        _, _, z_dream_recon_probs, _ = model(x_dream)
        sleep_loss = F.binary_cross_entropy(z_dream_recon_probs, z_dream)
        sleep_loss.backward()
        optimizer.step()

        wake_loss_total += wake_loss.item()
        sleep_loss_total += sleep_loss.item()
        class_loss_total += class_loss.item()
        predicted = class_logits.argmax(dim=1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        num_batches += 1

    if num_batches == 0:
        # Nothing to average; avoid dividing by zero.
        return 0.0, 0.0, 0.0, 0.0

    accuracy = 100.0 * correct / total if total else 0.0
    return (
        wake_loss_total / num_batches,
        sleep_loss_total / num_batches,
        class_loss_total / num_batches,
        accuracy,
    )


In [10]:
def evaluate(model, test_loader, criterion, device):
    """Measure classification loss and accuracy without updating the model.

    Args:
        model: Module whose call returns a 4-tuple with the class logits as
            the last element.
        test_loader: Iterable of ``(data, target)`` batches.
        criterion: Classification loss taking ``(class_logits, target)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(test_loss, accuracy)``: the loss averaged over batches and
        the accuracy in percent. Both are ``0.0`` when the loader yields no
        batches.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            _, _, _, class_logits = model(data)
            test_loss += criterion(class_logits, target).item()
            predicted = class_logits.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            num_batches += 1

    if num_batches == 0:
        # Nothing to average; avoid dividing by zero.
        return 0.0, 0.0

    accuracy = 100.0 * correct / total if total else 0.0
    return test_loss / num_batches, accuracy

In [11]:
def main():
    """Build the model, then train and evaluate it for ``num_epochs`` epochs.

    Uses CUDA when available, otherwise the CPU. After each epoch the
    training and test metrics are printed.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model = HelmholtzMachine(input_size, hidden_size, num_classes, dropout_prob)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        train_metrics = train(model, train_loader, optimizer, criterion, device)
        wake_loss, sleep_loss, class_loss, train_acc = train_metrics
        test_metrics = evaluate(model, test_loader, criterion, device)
        test_loss, test_acc = test_metrics

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Wake Loss: {wake_loss:.4f}, Sleep Loss: {sleep_loss:.4f}, Class Loss: {class_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'  Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')


# Run the experiment only when executed as a script / notebook cell.
if __name__ == "__main__":
    main()

Epoch 1/10:
  Wake Loss: 0.6618, Sleep Loss: 0.6931, Class Loss: 0.4095, Train Acc: 88.66%
  Test Loss: 0.1821, Test Acc: 94.63%
Epoch 2/10:
  Wake Loss: 0.4212, Sleep Loss: 0.6861, Class Loss: 0.2047, Train Acc: 94.00%
  Test Loss: 0.1313, Test Acc: 96.16%
Epoch 3/10:
  Wake Loss: 0.3712, Sleep Loss: 0.6859, Class Loss: 0.1631, Train Acc: 95.25%
  Test Loss: 0.1089, Test Acc: 96.72%
Epoch 4/10:
  Wake Loss: 0.3420, Sleep Loss: 0.6857, Class Loss: 0.1380, Train Acc: 95.91%
  Test Loss: 0.0930, Test Acc: 97.33%
Epoch 5/10:
  Wake Loss: 0.3242, Sleep Loss: 0.6852, Class Loss: 0.1230, Train Acc: 96.28%
  Test Loss: 0.0851, Test Acc: 97.30%
Epoch 6/10:
  Wake Loss: 0.3094, Sleep Loss: 0.6850, Class Loss: 0.1107, Train Acc: 96.61%
  Test Loss: 0.0801, Test Acc: 97.49%
Epoch 7/10:
  Wake Loss: 0.2989, Sleep Loss: 0.6845, Class Loss: 0.1026, Train Acc: 96.90%
  Test Loss: 0.0778, Test Acc: 97.70%
Epoch 8/10:
  Wake Loss: 0.2902, Sleep Loss: 0.6843, Class Loss: 0.0956, Train Acc: 97.02%
  Test