In [29]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Dataset
import os


class Binarize(object):
    def __call__(self, tensor):
        # Binarize based on a 0.5 threshold (after normalization)
        return (tensor > 0.5).float()
    

class RBM:
    def __init__(self, visible_dim, hidden_dim, learning_rate, batch_size, n_iter, verbose, random_state):
        # Number of visible and hidden units
        self.visible_dim = visible_dim
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_iter = n_iter
        self.verbose = verbose
        self.random_state = random_state

    def sample_from_prob(self, prob):
        # Sampling binary states from probabilities
        return torch.bernoulli(prob)

    def v_to_h(self, v):
        # Propagate visible layer to hidden layer
        h_prob = torch.sigmoid(torch.matmul(v, self.W.t()) + self.h_bias)
        h_sample = self.sample_from_prob(h_prob)
        return h_prob, h_sample

    def h_to_v(self, h):
        # Propagate hidden layer to visible layer
        v_prob = torch.sigmoid(torch.matmul(h, self.W) + self.v_bias)
        v_sample = self.sample_from_prob(v_prob)
        return v_prob, v_sample
    
    def persistent_contrastive_divergence(self, v, iter):
        # Gibbs Sampling for Persistent Contrastive Divergence
        h_pos, _ = self.v_to_h(v)
        _, v_neg  = self.h_to_v(self.h_samples_)
        h_neg, _ = self.v_to_h(v_neg)
        
        # Positive and negative phase
        pos_phase = torch.matmul(h_pos.t(), v)  # Should match [hidden_dim, visible_dim]
        neg_phase = torch.matmul(h_neg.t(), v_neg)  # Should match [hidden_dim, visible_dim]

        # Update weights and biases
        lr = self.learning_rate/(self.batch_size)
        self.W += lr*(pos_phase - neg_phase) / v.size(0)
        self.v_bias += lr*torch.sum(v - v_neg, dim=0)
        self.h_bias += lr*torch.sum(h_pos - h_neg, dim=0)

        # Update the persistent chain
        self.h_samples_ = self.sample_from_prob(h_neg)

    def _free_energy(self, v):
        # Energy function of the RBM
        vbias_term = torch.matmul(v, self.v_bias.reshape(-1, 1))
        hidden_term = torch.sum(torch.log(1 + torch.exp(torch.matmul(v, self.W.t()) + self.h_bias)), dim=1)
        return -vbias_term - hidden_term

    def score_samples(self, X):
        """Compute the pseudo-likelihood of X.

        Parameters
        ----------
        X : Tensor of shape (n_samples, n_features)
            Values of the visible layer. Must be all-boolean (not checked).

        Returns
        -------
        pseudo_likelihood : Tensor of shape (n_samples,)
            Value of the pseudo-likelihood (proxy for likelihood).
        """
        # Ensure X is a PyTorch tensor
        X = torch.as_tensor(X).float()

        # Randomly corrupt one feature in each sample in X
        n_samples, n_features = X.shape
        ind = (torch.arange(n_samples), torch.randint(0, n_features, (n_samples,)))

        # Create a copy of X and corrupt one feature in each sample
        X_corrupted = X.clone()
        X_corrupted[ind] = 1 - X_corrupted[ind]

        # Calculate free energy for the original and corrupted inputs
        fe = self._free_energy(X)
        fe_corrupted = self._free_energy(X_corrupted)

        # Compute the pseudo-likelihood using the logistic function of the difference
        pseudo_likelihood = -n_features * torch.logaddexp(torch.tensor(0.0), -(fe_corrupted - fe))

        return pseudo_likelihood

    def encoder(self, dataset: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """
        Generate top level latent variables
        """
        p_h_given_v, _ = self.v_to_h(dataset)
        return p_h_given_v

    def encode(self, dataloader: DataLoader) -> DataLoader:
        """
        Encode data
        """
        latent_vars = []
        labels = []
        for data, label in dataloader:
            data = data.view(-1, self.visible_dim)
            label = label.unsqueeze(1).to(torch.float32)
            latent_vars.append(self.encoder(data, label))
            labels.append(label)
        latent_vars = torch.cat(latent_vars, dim=0)
        labels = torch.cat(labels, dim=0)
        latent_dataset = TensorDataset(latent_vars, labels)

        return DataLoader(latent_dataset, batch_size=dataloader.batch_size, shuffle=False)
    
    def fit(self, data_loader: DataLoader, k=1):
        """
        Fit the RBM model
        """
        self.W = torch.randn(self.hidden_dim, self.visible_dim, dtype=torch.float32)
        self.v_bias = torch.zeros(self.visible_dim, dtype=torch.float32)
        self.h_bias = torch.zeros(self.hidden_dim, dtype=torch.float32)
        self.h_samples_ = torch.zeros(self.batch_size, self.hidden_dim, dtype=torch.float32)
        
        for epoch in range(self.n_iter):
            epoch_loss = 0
            for batch in data_loader:
                v, _ = batch  # Ignore labels
                v = v.view(-1, self.visible_dim)  # Flatten the input images to [batch_size, visible_dim]
                # Update parameters
                self.persistent_contrastive_divergence(v, epoch)

            print(f"Epoch {epoch+1}/{self.n_iter} finished")

In [34]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Load the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    Binarize()
])
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = DataLoader(mnist, batch_size=64, shuffle=False, drop_last=True)

for (data, label) in data_loader:
    print(data.squeeze(dim=1).shape)
    print(data.view(-1, 28*28).shape)
    break
    

torch.Size([64, 28, 28])
torch.Size([64, 784])


In [18]:
# Initialize and train the RBM
visible_dim = 28 * 28  # MNIST images are 28x28
hidden_dim = 128       # You can adjust this value
learning_rate = 0.06
batch_size = 64
n_iter = 10
verbose = 1
random_state = 42
rbm = RBM(visible_dim, hidden_dim, learning_rate, batch_size, n_iter, verbose, random_state)

rbm.fit(data_loader)

Epoch 1/10 finished
Epoch 2/10 finished
Epoch 3/10 finished
Epoch 4/10 finished
Epoch 5/10 finished
Epoch 6/10 finished
Epoch 7/10 finished
Epoch 8/10 finished
Epoch 9/10 finished
Epoch 10/10 finished


In [19]:
rbm.W

tensor([[-1.1653,  0.8080, -0.1664,  ..., -0.5660,  0.3053,  0.4966],
        [ 0.3320, -0.3331, -0.9235,  ..., -0.1498, -1.1905,  0.6766],
        [ 0.2480,  1.1154, -0.6515,  ..., -1.5787, -0.8422,  0.6336],
        ...,
        [-0.6814, -0.2682,  0.9927,  ...,  0.5435,  0.9702,  1.1919],
        [ 2.2777, -1.0583, -0.3714,  ...,  0.4620, -0.1399,  0.6953],
        [-0.9509,  1.8501,  1.3058,  ..., -1.3820,  1.3236,  0.2145]])

In [12]:
rbm.W.shape

torch.Size([128, 784])

In [35]:
import numpy as np
from sklearn.datasets import fetch_openml

# Load the MNIST dataset from OpenML
mnist = fetch_openml('mnist_784', version=1, as_frame=False)

# Extract features and labels
X, y = mnist['data'], mnist['target']

# Convert labels to integers
y = y.astype(np.uint8)

# Check the data shapes and some samples
print('Features shape:', X.shape)  # Should be (70000, 784)
print('Labels shape:', y.shape)     # Should be (70000,)
print('First 5 labels:', y[:5])     # Check the first 5 labels


Features shape: (70000, 784)
Labels shape: (70000,)
First 5 labels: [5 0 4 1 9]


In [40]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("Training set shape:", X_train.shape, y_train.shape)
print("Testing set shape:", X_test.shape, y_test.shape)


Training set shape: (56000, 784) (56000,)
Testing set shape: (14000, 784) (14000,)
