# Restricted Boltzmann Machine (RBM) Experiments

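This notebook presents experiments with a Restricted Boltzmann Machine (RBM) trained on the MNIST dataset. It covers the training process, weight analysis, sample generation, feature learning, a comparison of sampling methods, a dream experiment, and classification using the learned features.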

In [None]:
# Import necessary libraries
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.utils.data as data
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader

# Add project root directory to path
sys.path.append('..')

# Import project modules
from src.boltzmann.rbm import RBM
from src.boltzmann.sampler import RBMSampler
from src.boltzmann.experiments import RBMExperiments
from src.utils.data_loader import MNISTLoader
from src.utils.preprocessing import binary_to_image
import config

## 1. Data Loading and Preprocessing

In [None]:
# Seed the NumPy and PyTorch RNGs up front so that Restart & Run All reproduces
# the same shuffling, RBM initialisation and sampling chains.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create data loader
loader = MNISTLoader()

# Load MNIST data
train_data = loader.get_train_data(binary_values={0, 1})
test_data = loader.get_test_data(binary_values={0, 1})

# Create data loaders. `DataLoader` is used by name rather than through the
# `data` module alias: loop variables named `data` would shadow that alias and
# break any re-run of this cell.
batch_size = config.RBM_CONFIG['batch_size']
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Visualize the first few samples of the first two batches (one row per batch)
n_rows, n_cols = 2, 5
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 5))
for ax in axes.flat:
    ax.axis('off')  # unused panels (short batch) stay blank instead of showing empty axes
for row, (images, labels) in zip(range(n_rows), train_loader):
    for col in range(min(n_cols, len(images))):
        img = binary_to_image(images[col].numpy(), (28, 28))
        axes[row, col].imshow(img, cmap='binary')
        axes[row, col].set_title(f'Label: {labels[col].item()}')
plt.tight_layout()
plt.show()

## 2. RBM Training

In [None]:
# Create RBM with the hyper-parameters from the project config
rbm_config = config.RBM_CONFIG
rbm = RBM(
    n_visible=rbm_config['n_visible'],
    n_hidden=rbm_config['n_hidden'],
    k=rbm_config['k'],
    learning_rate=rbm_config['learning_rate'],
    momentum=rbm_config['momentum'],
    weight_decay=rbm_config['weight_decay'],
    use_cuda=rbm_config['use_cuda']
)

# Train RBM, recording the mean per-batch error and the wall-clock time of each epoch
print("Training RBM...")
train_errors = []
epoch_times = []
n_epochs = 20

for epoch in range(n_epochs):
    epoch_start_time = time.perf_counter()  # monotonic clock, suited to durations
    batch_errors = []

    # Loop variable is `images` (not `data`) so the `torch.utils.data` alias is not shadowed
    for images, _ in train_loader:
        # Flatten each image into a visible vector and binarize it
        batch = (images.view(images.size(0), -1) > 0.5).float()
        if rbm.use_cuda:
            batch = batch.cuda()

        # One training update on this mini-batch. float() accepts a Python number
        # or a 0-dim tensor (including a CUDA one), which np.mean cannot handle.
        batch_errors.append(float(rbm.train_batch(batch)))

    # Average error over the epoch's batches
    avg_error = np.mean(batch_errors)
    train_errors.append(avg_error)

    # Record training time
    epoch_time = time.perf_counter() - epoch_start_time
    epoch_times.append(epoch_time)

    print(f"Epoch {epoch+1}/{n_epochs}, Error: {avg_error:.6f}, Time: {epoch_time:.2f}s")

print("RBM training completed!")

## 3. Training Process Visualization

In [None]:
# Plot the training curves against 1-based epoch numbers so the x-axis
# matches the "Epoch i/n" lines printed during training
epochs = np.arange(1, len(train_errors) + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Mean per-batch error of each epoch (the value returned by rbm.train_batch)
ax1.plot(epochs, train_errors, 'b-', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Reconstruction Error', fontsize=12)
ax1.set_title('Training Error', fontsize=14)
ax1.grid(True, alpha=0.3)

# Wall-clock duration of each epoch
ax2.plot(epochs, epoch_times, 'g-', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Time (seconds)', fontsize=12)
ax2.set_title('Training Time per Epoch', fontsize=14)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Weight Matrix Analysis

In [None]:
# Get weight matrix as a NumPy array. Columns are treated as hidden units
# below (each column is reshaped to a 28x28 image), so the matrix is
# assumed to be (n_visible, n_hidden).
weights = rbm.get_weights()
weights_np = weights.detach().cpu().numpy()

# Visualize weight matrix heatmap (diverging colormap centred on zero)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(weights_np, cmap='coolwarm', center=0, ax=ax)
ax.set_title('Weight Matrix', fontsize=14)
ax.set_xlabel('Hidden Units', fontsize=12)
ax.set_ylabel('Visible Units', fontsize=12)
plt.show()

# Plot weight distribution histogram
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(weights_np.flatten(), bins=50, alpha=0.7, edgecolor='black')
ax.set_title('Weight Distribution', fontsize=14)
ax.set_xlabel('Weight Value', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
ax.grid(True, alpha=0.3)
plt.show()

# Visualize the weight pattern of each of the first n_display hidden units
n_hidden = weights_np.shape[1]
n_display = min(64, n_hidden)

# Round the grid side up so every displayed unit gets a panel even when
# n_display is not a perfect square; squeeze=False keeps `axes` 2-D for a 1x1 grid
grid_size = int(np.ceil(np.sqrt(n_display)))
fig, axes = plt.subplots(grid_size, grid_size, figsize=(12, 12), squeeze=False)

for idx, ax in enumerate(axes.flat):
    ax.axis('off')
    if idx >= n_display:
        continue  # surplus panel in the last row
    weight_img = weights_np[:, idx].reshape(28, 28)
    # Symmetric limits so a zero weight maps to the neutral centre of 'seismic'
    vlim = float(np.abs(weight_img).max()) or 1.0
    ax.imshow(weight_img, cmap='seismic', vmin=-vlim, vmax=vlim)
    ax.set_title(f'Hidden {idx}', fontsize=8)

# Set main title
fig.suptitle('Hidden Unit Weight Patterns', fontsize=16)
plt.tight_layout()
plt.show()

## 5. Sample Generation

In [None]:
# Generate samples from the trained RBM
print("Generating samples...")
n_samples = 64
n_gibbs_steps = 1000

samples = rbm.generate_samples(n_samples, n_gibbs_steps)

# Convert samples to image format
samples_np = samples.detach().cpu().numpy()
images = binary_to_image(samples_np, (28, 28))

# Round the grid side up so no sample is dropped when n_samples is not a
# perfect square; squeeze=False keeps `axes` 2-D even for a 1x1 grid
grid_size = int(np.ceil(np.sqrt(n_samples)))
fig, axes = plt.subplots(grid_size, grid_size, figsize=(12, 12), squeeze=False)

for idx, ax in enumerate(axes.flat):
    if idx < len(images):
        ax.imshow(images[idx], cmap='binary')
    ax.axis('off')

# Set main title
fig.suptitle('Generated Samples', fontsize=16)
plt.tight_layout()
plt.show()

## 6. Feature Learning Visualization

In [None]:
# Hidden representations of a random subset of training images
n_samples = 1000
# `DataLoader` by name: earlier loop variables called `data` shadow the
# `torch.utils.data` alias, so `data.DataLoader` would fail here.
sample_loader = DataLoader(train_data, batch_size=n_samples, shuffle=True)

# Take one shuffled batch of n_samples images
images_batch, labels = next(iter(sample_loader))
batch = (images_batch.view(images_batch.size(0), -1) > 0.5).float()  # Binarize
if rbm.use_cuda:
    batch = batch.cuda()  # mirror the training loop: inputs go to the GPU when the RBM uses CUDA

# Get hidden layer representations (inference only, so no autograd graph is needed)
with torch.no_grad():
    hidden_repr = rbm.get_hidden_representation(batch)
hidden_np = hidden_repr.detach().cpu().numpy()
labels_np = labels.numpy()

# Use t-SNE for dimensionality reduction
print("Running t-SNE...")
tsne = TSNE(n_components=2, random_state=42)
hidden_tsne = tsne.fit_transform(hidden_np)

# Use PCA for dimensionality reduction
pca = PCA(n_components=2)
hidden_pca = pca.fit_transform(hidden_np)

# Visualize dimensionality reduction results, coloured by digit label
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Plot t-SNE results
scatter1 = ax1.scatter(hidden_tsne[:, 0], hidden_tsne[:, 1], c=labels_np, cmap='tab10', alpha=0.7)
ax1.set_xlabel('t-SNE Component 1', fontsize=12)
ax1.set_ylabel('t-SNE Component 2', fontsize=12)
ax1.set_title('t-SNE Visualization of Hidden Representations', fontsize=14)
plt.colorbar(scatter1, ax=ax1)

# Plot PCA results
scatter2 = ax2.scatter(hidden_pca[:, 0], hidden_pca[:, 1], c=labels_np, cmap='tab10', alpha=0.7)
ax2.set_xlabel('PCA Component 1', fontsize=12)
ax2.set_ylabel('PCA Component 2', fontsize=12)
ax2.set_title('PCA Visualization of Hidden Representations', fontsize=14)
plt.colorbar(scatter2, ax=ax2)

plt.tight_layout()
plt.show()

## 7. Hidden Layer Activation Analysis

In [None]:
# Mean hidden-unit activation per digit class, computed over the subset
# embedded in the previous section (hidden_np rows align with labels_np)
n_digits = 10
n_hidden = hidden_np.shape[1]

# Per-class sums: np.add.at adds row i of hidden_np into row labels_np[i],
# processing samples one after another exactly like an explicit loop would
digit_activations = np.zeros((n_digits, n_hidden))
np.add.at(digit_activations, labels_np, hidden_np)

# Per-class sample counts; classes that never occur keep an all-zero row
digit_counts = np.bincount(labels_np, minlength=n_digits).astype(float)
seen = digit_counts > 0
digit_activations[seen] /= digit_counts[seen][:, None]

# Heatmap: one row per digit, one column per hidden unit
fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(digit_activations, cmap='viridis', ax=ax)
ax.set_xlabel('Hidden Units', fontsize=12)
ax.set_ylabel('Digits', fontsize=12)
ax.set_title('Average Hidden Activations per Digit', fontsize=14)

plt.tight_layout()
plt.show()

## 8. Sampling Methods Comparison

In [None]:
# Compare the samplers offered by RBMSampler on the same trained RBM
sampler = RBMSampler(rbm)

n_samples = 64
n_steps = 1000

# One entry per sampler, in display order: (progress message, row title, sampling method)
sampling_methods = [
    ("Running Gibbs sampling...", 'Gibbs Sampling', sampler.gibbs_sample),
    ("Running block Gibbs sampling...", 'Block Gibbs Sampling', sampler.block_gibbs_sample),
    ("Running tempered transition sampling...", 'Tempered Transition', sampler.tempered_transition_sample),
    ("Running parallel tempering sampling...", 'Parallel Tempering', sampler.parallel_tempering_sample),
]

# Run every sampler first; only the samples are kept, the second return value is discarded
raw_samples = []
for message, _, sample_fn in sampling_methods:
    print(message)
    method_samples, _ = sample_fn(n_samples, n_steps)
    raw_samples.append(method_samples)
gibbs_samples, block_samples, annealed_samples, pt_samples = raw_samples

# Then convert every set of samples to 28x28 images
gibbs_images, block_images, annealed_images, pt_images = [
    binary_to_image(method_samples.detach().cpu().numpy(), (28, 28))
    for method_samples in raw_samples
]
images_per_method = [gibbs_images, block_images, annealed_images, pt_images]

# One row per sampler, showing its first n_display samples
n_display = 16
fig, axes = plt.subplots(len(sampling_methods), n_display, figsize=(20, 8))

for row_axes, (_, title, _), method_images in zip(axes, sampling_methods, images_per_method):
    for col, ax in enumerate(row_axes):
        if col < len(method_images):
            ax.imshow(method_images[col], cmap='binary')
        ax.axis('off')
    row_axes[0].set_title(title, fontsize=12)  # label the row on its first panel

# Set main title
fig.suptitle('Sampling Methods Comparison', fontsize=16)
plt.tight_layout()
plt.show()

## 9. Dream Experiment

In [None]:
# Run dream experiment: start sampler.dream from a real digit and show how the state evolves
digit = 0
n_steps = 1000

print(f"Running dream experiment starting from digit {digit} for {n_steps} steps")

# Get one sample of the requested digit and shape it like the training batches:
# flattened, binarized, with a leading batch dimension of 1. reshape(1, -1)
# equals unsqueeze(0) for an already-flat sample, and it also handles an
# image-shaped (28, 28) sample.
sample = loader.get_specific_digit_sample(digit, binary_values={0, 1})
sample = (sample.reshape(1, -1) > 0.5).float()
if rbm.use_cuda:
    sample = sample.cuda()  # keep inputs on the RBM's device, as in training

# Dream; dream_history holds the intermediate states returned by the sampler
dream_sample, dream_history = sampler.dream(sample, n_steps)


def _state_to_image(state):
    """Return the first image of a batch of visible states as a float 28x28 array.

    Float dtype matters: subtracting two unsigned-integer images would wrap
    around instead of going negative.
    """
    return np.asarray(binary_to_image(state.detach().cpu().numpy(), (28, 28))[0], dtype=float)


# Convert samples and dream history to images
original_image = _state_to_image(sample)
dream_image = _state_to_image(dream_sample)
dream_images = [_state_to_image(state) for state in dream_history]

# Select up to 8 evenly spaced key frames to display
n_frames = min(8, len(dream_images))
frame_indices = np.linspace(0, len(dream_images) - 1, n_frames, dtype=int)

# squeeze=False keeps `axes` 2-D even when the history is empty (a single column)
fig, axes = plt.subplots(2, n_frames + 1, figsize=(15, 6), squeeze=False)

# Display original image
axes[0, 0].imshow(original_image, cmap='binary')
axes[0, 0].set_title('Original', fontsize=12)
axes[0, 0].axis('off')
axes[1, 0].axis('off')

# Display key frames during dream process
for col, idx in enumerate(frame_indices, start=1):
    axes[0, col].imshow(dream_images[idx], cmap='binary')
    # NOTE(review): idx indexes dream_history; it equals the Gibbs step only if
    # the sampler records every step — TODO confirm
    axes[0, col].set_title(f'Step {idx}', fontsize=10)
    axes[0, col].axis('off')

    # Absolute pixel difference from the starting image
    diff = np.abs(dream_images[idx] - original_image)
    axes[1, col].imshow(diff, cmap='hot')
    axes[1, col].set_title(f'Diff: {np.mean(diff):.3f}', fontsize=10)
    axes[1, col].axis('off')

# Set main title
fig.suptitle('Dream Process', fontsize=16)
plt.tight_layout()
plt.show()

## 10. Feature Extraction and Classification

In [None]:
# Extract hidden layer features from training and test sets


def extract_hidden_features(batch_loader):
    """Return (features, labels) NumPy arrays covering every batch of `batch_loader`.

    Inputs are flattened and binarized exactly as in training, and moved to the
    GPU when the RBM uses CUDA. Features are copied back to the CPU per batch,
    so GPU memory does not grow with the dataset size. Rows of the two arrays
    stay aligned, in the loader's iteration order.
    """
    features, targets = [], []
    with torch.no_grad():  # inference only; no autograd graph needed
        for images_batch, labels_batch in batch_loader:
            batch = (images_batch.view(images_batch.size(0), -1) > 0.5).float()
            if rbm.use_cuda:
                batch = batch.cuda()
            features.append(rbm.get_hidden_representation(batch).detach().cpu())
            targets.append(labels_batch)
    return torch.cat(features, dim=0).numpy(), torch.cat(targets, dim=0).numpy()


print("Extracting features from training set...")
train_features, train_labels = extract_hidden_features(train_loader)

print("Extracting features from test set...")
test_features, test_labels = extract_hidden_features(test_loader)

print(f"Training features shape: {train_features.shape}")
print(f"Test features shape: {test_features.shape}")

In [None]:
# Use logistic regression classifier on the RBM hidden features
print("Training logistic regression on RBM features...")
lr_rbm = LogisticRegression(max_iter=1000)
lr_rbm.fit(train_features, train_labels)
rbm_pred = lr_rbm.predict(test_features)
rbm_accuracy = accuracy_score(test_labels, rbm_pred)
print(f"Accuracy with RBM features: {rbm_accuracy:.4f}")


def collect_pixels(dataset):
    """Return flattened, binarized pixels and their labels for `dataset`.

    Iterates in dataset order (shuffle=False) and takes images and labels
    from the same batches, so row i of the pixel matrix always belongs to
    label i. Pixels are binarized exactly like the RBM inputs, so the
    baseline differs from the RBM pipeline only by the hidden-feature transform.
    """
    pixel_loader = DataLoader(dataset, batch_size=config.RBM_CONFIG['batch_size'], shuffle=False)
    pixels, targets = [], []
    for images_batch, labels_batch in pixel_loader:
        pixels.append((images_batch.view(images_batch.size(0), -1) > 0.5).float())
        targets.append(labels_batch)
    return torch.cat(pixels, dim=0).numpy(), torch.cat(targets, dim=0).numpy()


# Get raw pixel features for comparison. Labels must come from the same pass:
# train_labels above follows the *shuffled* train_loader order, so it does not
# line up with pixels read from the dataset in stored order.
print("Training logistic regression on raw pixels...")
train_pixels, train_pixel_labels = collect_pixels(train_data)
test_pixels, test_pixel_labels = collect_pixels(test_data)

lr_raw = LogisticRegression(max_iter=1000)
lr_raw.fit(train_pixels, train_pixel_labels)
raw_pred = lr_raw.predict(test_pixels)
raw_accuracy = accuracy_score(test_pixel_labels, raw_pred)
print(f"Accuracy with raw pixels: {raw_accuracy:.4f}")

# Compare results (positive means the RBM features helped)
print(f"\nImprovement: {rbm_accuracy - raw_accuracy:.4f}")

# Visualize comparison results
fig, ax = plt.subplots(figsize=(8, 6))
methods = ['Raw Pixels', 'RBM Features']
accuracies = [raw_accuracy, rbm_accuracy]
bars = ax.bar(methods, accuracies, color=['blue', 'green'])
ax.set_ylabel('Accuracy', fontsize=12)
ax.set_title('Classification Accuracy Comparison', fontsize=14)
ax.set_ylim(0, 1)

# Add value labels just above each bar
for bar, acc in zip(bars, accuracies):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
            f'{acc:.4f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

## Summary

This notebook presents a range of experiments with an RBM trained on the MNIST dataset, including:

1. RBM training process and reconstruction error analysis
2. Weight matrix analysis and visualization
3. Sample generation and visual inspection of sample quality
4. Feature learning and dimensionality reduction visualization
5. Hidden layer activation pattern analysis
6. Comparison of multiple sampling methods
7. Dream experiment and process visualization
8. Feature extraction and classification performance comparison

These experiments help us understand the working principles, learning capabilities, and application potential of RBMs.