In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import torch
from hydra.utils import instantiate
from hydra import initialize, compose
import hydra
import wandb
from omegaconf import OmegaConf
import copy

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from model.rbm.rbm import RBM
from model.rbm.rbm_torch import RBMtorch
from model.rbm.rbm_fulltorch import RBMTorchFull
from scripts.run import setup_model, load_model_instance


In [None]:
# Threshold every pixel at 0.5 so each image becomes a {0, 1}-valued float tensor.
binarize = transforms.Lambda(lambda x: (x > 0.5).float())
transform = transforms.Compose([transforms.ToTensor(), binarize])

# Fetch the MNIST train and test splits into ./data (download=True), applying
# the same binarizing transform to both.
mnist_kwargs = dict(root="./data", transform=transform, download=True)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Shuffled mini-batches of 256 training images.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=256)
# Report how many batches make up one pass over the training set.
print(f"Number of training batches: {len(train_loader)}")

In [None]:
def preprocess_batch(x, dev):
    """
    Move a batch to `dev`, flatten each sample, and split the features into chunks.

    x: batch tensor whose first dimension is the batch size
    dev: target torch device
    Returns the tuple produced by chunking the flattened (batch, features)
    tensor into 4 pieces along the feature dimension; for 28x28 MNIST images
    (784 features) that is four (batch, 196) partitions.
    """
    batch_size = x.size(0)
    flattened = x.to(dev).view(batch_size, -1)  # (batch, features)
    return flattened.chunk(4, dim=1)

In [None]:
# Clear any existing global Hydra instance before calling initialize() below.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path="config")
config = compose(config_name="config.yaml")

# Use the first GPU listed in the config when the config asks for "gpu" and
# CUDA is actually available; otherwise run on the CPU.
use_cuda = config.device == "gpu" and torch.cuda.is_available()
dev = torch.device(f"cuda:{config.gpu_list[0]}" if use_cuda else "cpu")

# Build the model from the config and move it to the chosen device.
# NOTE(review): this rebinds the name `RBM`, which the import cell also binds
# to the class from model.rbm.rbm; later cells use `RBM` as the model instance.
RBM = RBMTorchFull(config).to(dev)

In [None]:
num_epochs = 5
losses = []  # one entry per training step, in order

for epoch in range(num_epochs):
    for batch_idx, (images, _labels) in enumerate(train_loader):
        # Labels are unused: only the flattened, partitioned images are passed to the model.
        partitions = preprocess_batch(images, dev)
        loss = RBM.step_on_batch(partitions)
        losses.append(loss)

        # Log every 100th batch.
        if batch_idx % 100 != 0:
            continue
        print(f"Epoch {epoch} Batch {batch_idx}: Loss = {loss:.4f}")

In [None]:
def visualize_flat_samples(p0, p1, p2, p3, n=8):
    """
    Stitch partitions from a flat-split RBM into 28x28 images and show them in a row.

    p0..p3: (batch, 196) binary samples; row i of each partition is concatenated
            in order p0, p1, p2, p3 and must total 784 values
    n: number of images to display; clamped to the smallest partition batch
       size, and nothing is drawn when the resulting count is not positive
    Returns None; the figure is displayed with plt.show().
    """
    # Never index past the end of any partition.
    count = min(n, len(p0), len(p1), len(p2), len(p3))
    if count <= 0:
        return

    samples = []
    for i in range(count):
        flat = torch.cat([p0[i], p1[i], p2[i], p3[i]], dim=0)  # length 784
        # detach() so tensors that still require grad can be converted to NumPy.
        samples.append(flat.detach().reshape(28, 28).cpu().numpy())

    # squeeze=False keeps `axes` a 2-D array even when count == 1; otherwise
    # plt.subplots returns a single Axes object that cannot be iterated.
    fig, axes = plt.subplots(1, count, figsize=(count * 2, 2), squeeze=False)
    for ax, img in zip(axes.ravel(), samples):
        ax.imshow(img, cmap="gray")
        ax.axis("off")
    plt.show()



In [None]:
# Generate unconditional samples: request a batch of 16 from the trained model.
# The call is unpacked into the four partitions that visualize_flat_samples expects.
# (The receiver was missing here, which made the cell a syntax error; `RBM` is
# the model instance created in the configuration cell.)
p0, p1, p2, p3 = RBM.block_gibbs_sampling(batch_size=16)

# Visualize the first 8 samples
visualize_flat_samples(p0, p1, p2, p3, n=8)
