In [None]:
!pip install torch==2.7.0
!pip install scikit-learn==1.6.1
!pip install numpy==2.2.6
!pip install matplotlib==3.10.3
!pip install scipy==1.15.3
!pip install pennylane==0.41.4


In [None]:
import csv
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pennylane
import torch
import torch.nn.functional as F
from google.colab import files
from scipy.stats import norm
from sklearn.cluster import KMeans
from sklearn.model_selection import KFold
from sklearn.neighbors import KernelDensity
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

## Implementation

In [None]:
class RBM(nn.Module):
    """
    Classical Restricted Boltzmann Machine (RBM) with CD-k or PCD-k training.

    Parameter updates are computed by hand from positive-phase (data) and
    negative-phase (Gibbs chain) statistics, written into ``.grad`` and applied
    with Adam; autograd is not used for training.

    Args:
        nv (int): Number of visible units.
        nh (int): Number of hidden units.
        k (int): Number of Gibbs steps run in the negative phase (CD-k / PCD-k),
            and the default number of steps for ``forward``.
        training (str): 'CD' (default; chains start at the data batch) or
            'PCD' (persistent chains carried across batches and epochs).
            Stored as ``self.training_method`` so it does not overwrite
            ``nn.Module.training`` (the flag toggled by ``train()``/``eval()``).
        dataset (array-like, optional): Default training data used by ``fit()``
            when it is called without arguments.
        visible_normalized (bool): If True, visible units are kept as
            probabilities in [0, 1] instead of being Bernoulli-sampled.
        lr (float): Learning rate.
        lr_trend (str): 'linear' (decay to 0 over ``epochs``); any other value
            keeps the learning rate constant.
        bs (int): Batch size (also the number of persistent chains for PCD).
        epochs (int): Number of training epochs.
        print_step (int): Print epoch timing every ``print_step`` epochs.
        verbose (bool): Whether to print progress.
        track_learning (bool): Whether to compute tracking metrics during training.
        track_method (str or list of str): Tracking metric(s). Only 'rec_error'
            is implemented in this class; other names (e.g. 'KDE', 'MMD', 'NDB')
            make ``fit()`` raise ``ValueError`` when tracking is enabled.
        track_step (int): How often (in epochs) to compute tracking metrics.
        KDE_bandwidth (float, optional): Stored for a KDE tracker; not used by
            'rec_error'.

    Raises:
        ValueError: If ``training`` is not 'CD' or 'PCD'.
    """

    _TRAINING_METHODS = ('CD', 'PCD')
    _TRACK_METHODS = ('rec_error',)

    def __init__(self, nv, nh, k=1, training='CD', dataset=None,
                 visible_normalized=False, lr=0.001, lr_trend='linear',
                 bs=1024, epochs=100, print_step=100, verbose=True,
                 track_learning=False, track_method='KDE', track_step=100,
                 KDE_bandwidth=None):

        super().__init__()
        if training not in self._TRAINING_METHODS:
            raise ValueError(
                f"training must be one of {self._TRAINING_METHODS}, got {training!r}")

        self.nv = nv
        self.nh = nh
        self.k = k
        # Deliberately NOT `self.training`: nn.Module reserves that name for its
        # train/eval flag, and train()/eval() would silently overwrite it.
        self.training_method = training
        self.dataset = dataset
        self.visible_normalized = visible_normalized
        self.lr = lr
        self.lr_trend = lr_trend
        self.bs = bs
        self.epochs = epochs
        self.print_step = print_step
        self.verbose = verbose
        self.track_learning = track_learning
        self.track_method = track_method
        self.track_step = track_step
        self.KDE_bandwidth = KDE_bandwidth

        # Model parameters: W maps visible -> hidden, shape (nh, nv).
        self.W = nn.Parameter(torch.randn(nh, nv) * 0.01)
        self.bv = nn.Parameter(torch.zeros(nv))
        self.bh = nn.Parameter(torch.zeros(nh))

        self.rec_errs = []  # 'rec_error' values, one per tracking step

    def _hidden_prob(self, v):
        """Compute hidden probabilities p(h=1 | v)."""
        return torch.sigmoid(F.linear(v, self.W, self.bh))

    def _sample_hidden(self, v):
        """Sample hidden units h ~ p(h=1 | v)."""
        p = self._hidden_prob(v)
        return torch.bernoulli(p)

    def _visible_prob(self, h):
        """Compute visible probabilities p(v=1 | h)."""
        return torch.sigmoid(F.linear(h, self.W.t(), self.bv))

    def _sample_visible(self, h):
        """Sample visible units v ~ p(v=1 | h), or return the probabilities if visible_normalized."""
        p = self._visible_prob(h)
        if self.visible_normalized:
            return p  # Continuous in [0,1]
        else:
            return torch.bernoulli(p)

    @staticmethod
    def _clamp(v, v_given):
        """
        Return v with entries replaced by v_given wherever v_given != -1.

        v_given may have the same shape as v or any shape that broadcasts to it
        (e.g. a single (nv,) pattern applied to every row); -1 marks free units.
        """
        if v_given is None:
            return v
        v_given = torch.as_tensor(v_given, dtype=v.dtype, device=v.device)
        return torch.where(v_given != -1, v_given, v)

    def free_energy(self, v):
        """Compute the free energy F(v) of each row of the visible batch v."""
        v_term = torch.matmul(v, self.bv)
        h_term = torch.sum(F.softplus(F.linear(v, self.W, self.bh)), dim=1)
        return -v_term - h_term

    @torch.no_grad()
    def forward(self, v_initial, v_given=None, n_steps=None):
        """
        Run block Gibbs sampling v -> h -> v starting from v_initial.

        Args:
            v_initial (torch.Tensor): Starting visible states, shape (batch, nv).
            v_given (array-like, optional): Clamp pattern; entries != -1 are
                re-imposed after every visible update.
            n_steps (int, optional): Number of Gibbs steps; ``None`` means
                ``self.k``. ``0`` returns a copy of v_initial.

        Returns:
            torch.Tensor: Visible states after n_steps (no autograd history).
        """
        v = v_initial.clone().detach()
        # Explicit None check: `n_steps or self.k` would turn 0 into k steps.
        if n_steps is None:
            n_steps = self.k
        for _ in range(n_steps):
            h = self._sample_hidden(v)
            v = self._clamp(self._sample_visible(h), v_given)
        return v

    @torch.no_grad()
    def sample(self, n_samples, v_given=None, therm=10000):
        """
        Generate samples by running `therm` Gibbs steps from random binary states.

        Args:
            n_samples (int): Number of chains / returned rows.
            v_given (array-like, optional): Clamp pattern (see ``forward``).
            therm (int): Number of Gibbs steps before returning.

        Returns:
            torch.Tensor: Samples of shape (n_samples, nv) on the model's device.
        """
        device = self.W.device
        v = torch.bernoulli(torch.rand(n_samples, self.nv, device=device))
        v = self._clamp(v, v_given)
        return self.forward(v, v_given, n_steps=therm)

    @torch.no_grad()
    def _cd_statistics(self, v0, chain_start):
        """
        Return (vk, dW, dbv, dbh) for one batch.

        vk is the negative-phase sample after k Gibbs steps from chain_start;
        the remaining terms are the positive-minus-negative gradient estimates.
        """
        ph0 = self._hidden_prob(v0)
        vk = self.forward(chain_start, n_steps=self.k)
        phk = self._hidden_prob(vk)
        # Normalise each phase by its OWN batch size: with PCD the number of
        # persistent chains (bs) differs from a partial final data batch.
        dW = ph0.t() @ v0 / len(v0) - phk.t() @ vk / len(vk)
        dbv = v0.mean(dim=0) - vk.mean(dim=0)
        dbh = ph0.mean(dim=0) - phk.mean(dim=0)
        return vk, dW, dbv, dbh

    def _resolve_track_methods(self):
        """Normalise track_method to a list and reject metrics this class cannot compute."""
        if isinstance(self.track_method, str):
            methods = [self.track_method]
        else:
            methods = list(self.track_method)
        unsupported = [m for m in methods if m not in self._TRACK_METHODS]
        if unsupported:
            raise ValueError(
                f"Unsupported track_method(s) {unsupported}; "
                f"supported: {list(self._TRACK_METHODS)}")
        return methods

    @staticmethod
    def _mean_nearest_distance(data, samples, chunk_size=1024):
        """
        Mean over rows of `data` of the L2 distance to the closest row of `samples`.

        Computed in chunks of `chunk_size` data rows so the full
        (len(data), len(samples)) distance matrix is never held at once.
        Returns NaN for empty `data`.
        """
        if len(data) == 0:
            return float('nan')
        samples = samples.float()
        nearest = [
            torch.cdist(data[i:i + chunk_size].float(), samples, p=2).min(dim=1).values
            for i in range(0, len(data), chunk_size)
        ]
        return torch.cat(nearest).mean().item()

    def _track_metrics(self, data, samples, methods):
        """Compute the requested tracking metrics; returns a {name: float} dict."""
        results = {}
        if 'rec_error' in methods:
            results['rec_error'] = self._mean_nearest_distance(data, samples)
        return results

    def fit(self, dataset=None):
        """
        Train the RBM with CD-k or PCD-k (see ``training``).

        Args:
            dataset (array-like, optional): Training data of shape (n, nv);
                falls back to the ``dataset`` given to the constructor.

        Raises:
            ValueError: If no dataset is available, or if ``track_learning`` is
                enabled with an unsupported ``track_method``.
        """
        if dataset is None:
            dataset = self.dataset
        if dataset is None:
            raise ValueError("No dataset given: pass one to fit() or to the constructor.")
        # Validate before any expensive work so a bad config fails immediately.
        track_methods = self._resolve_track_methods() if self.track_learning else []

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        data = torch.as_tensor(dataset, dtype=self.W.dtype, device=device)

        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        if self.lr_trend == 'linear':
            lr_lambda = lambda epoch: 1 - (epoch / self.epochs)
        else:
            lr_lambda = lambda epoch: 1
        lr_scheduler = LambdaLR(optimizer, lr_lambda)

        n = len(data)
        # Persistent chains for PCD: bs chains carried across batches and epochs.
        memory = torch.bernoulli(torch.rand(self.bs, self.nv, device=device))

        for epoch in range(1, self.epochs + 1):
            t0 = time.perf_counter()
            perm = torch.randperm(n, device=device)

            for i in range(0, n, self.bs):
                v0 = data[perm[i:i + self.bs]]
                chain_start = memory if self.training_method == 'PCD' else v0
                vk, dW, dbv, dbh = self._cd_statistics(v0, chain_start)

                # Adam minimises, so store the negated log-likelihood gradient.
                self.W.grad = -dW
                self.bv.grad = -dbv
                self.bh.grad = -dbh

                optimizer.step()
                optimizer.zero_grad()

                if self.training_method == 'PCD':
                    memory = vk

            lr_scheduler.step()

            if self.track_learning and epoch % self.track_step == 0:
                samples = self.sample(n_samples=n, therm=1000)
                results = self._track_metrics(data, samples, track_methods)
                if 'rec_error' in results:
                    self.rec_errs.append(results['rec_error'])
                if self.verbose:
                    print(f"[Epoch {epoch}] " + " | ".join([f"{k}: {v:.4f}" for k, v in results.items()]))

            elif self.verbose and epoch % self.print_step == 0:
                print(f"[Epoch {epoch}] Time: {time.perf_counter() - t0:.1f}s")

    def get_reconstruction_error(self, dataset, n_samples=10000, therm=10000):
        """
        Compute reconstruction error: mean L2 distance from each data row to the
        nearest of `n_samples` generated samples (after `therm` Gibbs steps).
        """
        with torch.no_grad():
            samples = self.sample(n_samples=n_samples, therm=therm)
            data = torch.as_tensor(dataset, device=samples.device)
            return self._mean_nearest_distance(data, samples)

    def check_overfit(self, train_data, val_data):
        """
        Estimate overfitting as mean free energy (train) minus mean free energy (val).

        Lower free energy corresponds to a higher unnormalised model probability
        exp(-F(v)), so an increasingly negative value means training rows are
        favoured over held-out rows.

        Returns:
            torch.Tensor: Scalar difference.
        """
        train = torch.as_tensor(train_data, dtype=self.W.dtype, device=self.W.device)
        val = torch.as_tensor(val_data, dtype=self.W.dtype, device=self.W.device)
        return self.free_energy(train).mean() - self.free_energy(val).mean()

    def recPlot(self):
        """Plot the tracked reconstruction errors (log y-scale); no-op if none were recorded."""
        if self.rec_errs:
            plt.plot(self.rec_errs)
            plt.yscale('log')
            plt.xlabel("Tracking step")
            plt.ylabel("Reconstruction error (L2)")
            plt.title("Reconstruction error over training")
            plt.show()

    def saveModel(self, filename=None, pickle_protocol=None):
        """Save the model's state_dict to `filename` (default ./rbm_nv{nv}_nh{nh}.pt)."""
        if filename is None:
            filename = f"./rbm_nv{self.nv}_nh{self.nh}.pt"
        torch.save(self.state_dict(), filename, pickle_protocol=pickle_protocol)


## Example Initialisation and Training

In [None]:

# number of visible units
nv = 30
# number of hidden units
nh = 5

# TODO: replace this synthetic placeholder with the real training data: a float
# tensor of shape (n_samples, nv); with visible_normalized=True its values should
# lie in [0, 1]. `fit` moves the data to the model's device itself, so no manual
# `.to('cuda')` is needed (and that call would fail on CPU-only runtimes).
torch.manual_seed(0)
dataset = torch.rand(10_000, nv)

rbm5H = RBM(nv=nv,
            nh=nh,
            k=2,
            training='PCD',
            visible_normalized=True,
            lr=0.007,
            lr_trend='linear',
            bs=1000,
            epochs=10000,
            print_step=100,
            verbose=True,
            track_method=['rec_error'],
            track_learning=True,
            track_step=100)

rbm5H.fit(dataset)
rbm5H.recPlot()