In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import scipy
import torch
from tqdm.auto import tqdm

import rbms
import rbms.parser
from rbms.bernoulli_bernoulli.classes import BBRBM
from rbms.bernoulli_bernoulli.functional import init_chains
from rbms.dataset import load_dataset
from rbms.io import load_model
from rbms.io import load_params
from rbms.partition_function.ais import compute_partition_function_ais
from rbms.plot import plot_PCA
from rbms.sampling.gibbs import sample_state
from rbms.utils import compute_log_likelihood
from rbms.utils import get_eigenvalues_history
from rbms.utils import get_flagged_updates
from rbms.utils import get_saved_updates

from fastrbm.io import load_rcm
from fastrbm.trajectory.partition_function import compute_partition_function_ptt
from fastrbm.trajectory.pt import ptt_sampling, init_sampling

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32

# Plot styling: STIX fonts, 15pt text, and full LaTeX rendering
# (the `bm` package is loaded so `\bm{...}` bold symbols work in labels).
plt.rcParams.update(
    {
        "mathtext.fontset": "stix",
        "font.family": "STIXGeneral",
        "font.size": 15,
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{bm}",
    }
)

# Reload imported modules (e.g. rbms, fastrbm) before each cell executes, so
# edits to their source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# MNIST training file restricted to the digits 0 and 1.
# NOTE(review): train_size=1.0 presumably keeps every sample in `dataset`; the
# actual train/test split is done later from the model's saved hyperparameters.
dataset, _ = load_dataset(
    "../../data/MNIST_train.h5", subset_labels=[0, 1], train_size=1.0, device=device, dtype=dtype
)
print(dataset)

In [None]:
filename = "../../output/rbm/MNIST01_012_wb_PCD_test.h5"

# Load the model stored at the last entry of the saved updates: parameters,
# permanent chains, training time and hyperparameters.
# NOTE(review): assumes get_saved_updates returns updates in increasing order.
updates = get_saved_updates(filename)
params, chains, t, hyperparameters = load_model(filename, updates[-1], device, dtype)
print(hyperparameters)

# Split the dataset using the seed and train_size stored with the model
# (presumably reproducing the split used during training — confirm).
rng = np.random.default_rng(hyperparameters["seed"])
train_dataset, test_dataset = dataset.split_train_test(rng, hyperparameters["train_size"])

In [None]:
data = train_dataset.data.to(device).float()

# Principal directions of the centered training data.
# full_matrices=False avoids materialising the (n_samples x n_samples) left
# singular matrix U, which can be huge and is never needed here; the leading
# right singular vectors (rows of V_dataT) used below are unaffected.
U_data, S_data, V_dataT = torch.linalg.svd(data - data.mean(0), full_matrices=False)

# Project the (uncentered) data and the permanent chains on those directions.
data_proj = (data @ V_dataT.mT).cpu().numpy()
pc_proj = (chains["visible"] @ V_dataT.mT).cpu().numpy()

# Scatter plots on the component pairs (0, 1), (2, 3) and (4, 5).
for dir1 in range(0, 5, 2):
    plot_PCA(
        data_proj,
        pc_proj,
        labels=["dataset", "Permanent chains"],
        dir1=dir1,
        dir2=dir1 + 1,
    )

In [None]:
updates = get_saved_updates(filename)
print(updates)

# Inspect the model stored at the first saved update. Distinct names are used
# so that `params`, `chains`, `t` and `hyperparameters` loaded for the last
# update (and used by the PCA cell) are not silently overwritten when cells
# are re-run out of order.
first_params, first_chains, first_t, first_hyperparameters = load_model(
    filename, updates[0], device, dtype
)
print(first_t)
w, vbias, hbias = first_params.weight_matrix, first_params.vbias, first_params.hbias
print(first_hyperparameters["learning_rate"])
# A bare expression is only displayed when it is the last line of a cell,
# so print the shape explicitly.
print(hbias.shape)

# Updates flagged "ptt" in the file.
ptt_updates = get_flagged_updates(filename, "ptt")
print(ptt_updates)

In [None]:
# Evolution of the singular values of W during training, with the updates
# flagged for PTT marked by vertical dashed lines.
x, y = get_eigenvalues_history(filename)
fig, ax = plt.subplots(1, 1)
ax.plot(x, y)
# Draw on `ax` explicitly (not via the pyplot state machine) and in one call.
if len(ptt_updates) > 0:
    ax.vlines(ptt_updates, 0, y.max(), color="grey", linestyles="dashed")
ax.semilogx()
ax.set_title(r"Singular values of $\bm W$")
ax.set_xlabel("Training time (gradient updates)")
ax.set_ylabel(r"$\bm w$")

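# Parallel Trajectory Tempering

Sample all the models saved during training with Parallel Trajectory Tempering (PTT), using the RCM to sample the first model. The resulting chains are then used to estimate the log-likelihood along the training trajectory and to compare it with AIS estimates.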

In [None]:
# RCM used to sample the first saved model.
rcm = load_rcm(filename, device=device, dtype=dtype)

# Parameters of every saved model, in the same order as `updates`.
list_params = [load_params(filename, upd, device, dtype) for upd in updates]

# Annealing to initialize the chains.
# NOTE(review): 2000 is presumably the number of chains per model — confirm in fastrbm.
list_chains = init_sampling(2000, list_params, device=device, dtype=dtype, rcm=rcm)

# PTT sampling over all the saved models.
list_chains, _, _ = ptt_sampling(
    list_params,
    list_chains,
    index=None,
    rcm=rcm,
    it_mcmc=1000,
    increment=1,
    show_pbar=True,
)

In [None]:
# Compare the PTT samples of the saved machines with the dataset in PCA space.
# Use a step > 1 in `idx_plot` to look at only some of the machines.
idx_plot = list(range(len(list_chains)))

for idx in idx_plot:
    pc_proj = (list_chains[idx]["visible"] @ V_dataT.mT).cpu().numpy()
    plot_PCA(data_proj, pc_proj, ["dataset", f"update {updates[idx]}"])
    plt.show()

In [None]:
# Train/test log-likelihood of every saved model, using the log partition
# functions computed from the PTT chains.
log_z = compute_partition_function_ptt(list_params, list_chains)
train_ll_ptt, test_ll_ptt = [], []
for i, machine_params in enumerate(list_params):
    train_ll_ptt.append(
        compute_log_likelihood(train_dataset.data, train_dataset.weights, machine_params, log_z[i])
    )
    test_ll_ptt.append(
        compute_log_likelihood(test_dataset.data, test_dataset.weights, machine_params, log_z[i])
    )

In [None]:
# Recover the log partition functions saved during training.
# The HDF5 file is opened once for all updates (instead of once per update) and
# is closed before parameters are reloaded and the log-likelihoods computed.
saved_log_z = []
with h5py.File(filename, "r") as f:
    for upd in updates:
        log_z = f[f"update_{upd}"]["log_z"][()]
        # log_z may be stored as a 1-element array rather than a scalar.
        if isinstance(log_z, np.ndarray):
            log_z = log_z[0]
        saved_log_z.append(log_z)

# Compute the associated train/test log-likelihoods.
ais_traj_train_ll = []
ais_traj_test_ll = []
for upd, log_z in zip(updates, saved_log_z):
    params = load_params(filename, upd, device, dtype)
    ais_traj_train_ll.append(
        compute_log_likelihood(train_dataset.data, train_dataset.weights, params, log_z)
    )
    ais_traj_test_ll.append(
        compute_log_likelihood(test_dataset.data, test_dataset.weights, params, log_z)
    )

In [None]:
# Train/test log-likelihood with log Z re-estimated by AIS in temperature (beta)
# for every saved model.
all_train_ll_ais_temp = []
all_test_ll_ais_temp = []
for upd in tqdm(updates, desc="AIS (temperature)"):
    params = load_params(filename, upd, device, dtype)
    # NOTE(review): 1000 and 5000 are passed positionally — presumably the number
    # of chains and of beta steps; confirm against rbms.partition_function.ais.
    log_z = compute_partition_function_ais(1000, 5000, params)
    all_train_ll_ais_temp.append(
        compute_log_likelihood(train_dataset.data, train_dataset.weights, params, log_z)
    )
    all_test_ll_ais_temp.append(
        compute_log_likelihood(test_dataset.data, test_dataset.weights, params, log_z)
    )

In [None]:
fig, ax = plt.subplots(1, 1)

# (legend label, color, train LL, test LL) for each estimator of log Z.
ll_curves = [
    ("PTT estimate", "green", train_ll_ptt, test_ll_ptt),
    ("AIS traj estimate", "red", ais_traj_train_ll, ais_traj_test_ll),
    (r"AIS $\beta$ estimate", "blue", all_train_ll_ais_temp, all_test_ll_ais_temp),
]
for label, color, train_ll, test_ll in ll_curves:
    ax.plot(updates, train_ll, color=color, label=label)  # solid: train
    ax.plot(updates, test_ll, linestyle="dashed", color=color)  # dashed: test

# Legend-only proxy lines explaining the solid/dashed convention.
ax.plot([], [], color="black", label="train")
ax.plot([], [], color="black", linestyle="dashed", label="test")

ax.semilogx()
ax.legend()
# The data are the MNIST digits 0 and 1 (see the dataset loading cell).
ax.set_title("LL MNIST 0/1")
ax.set_xlabel("Training time (gradient updates)")
ax.set_ylabel("LL (nats)")