In [1]:
import gc
import math
from functools import partial
from timeit import default_timer as timer

import jax
import jax.numpy as jnp
import numpy as np
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

torch.cuda.is_available()

True

In [2]:
def init_spins(P):
    """Enumerate every configuration of ``P`` binary +-1 spins.

    Parameters
    ----------
    P : int
        Number of spins per configuration.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(2**P, P, 1)``; ``spins[k, :, 0]`` is the
        k-th configuration, and every configuration appears exactly once.
    """
    spins_base = np.full((P, 2), np.array([-1., 1.]))

    # meshgrid over P copies of {-1, +1} yields P grids of shape (2,)*P;
    # transposing moves the spin index last so each grid point becomes a row.
    grids = np.array(np.meshgrid(*spins_base)).T

    # Shape is passed positionally: the ``newshape=`` keyword of np.reshape
    # is deprecated since NumPy 2.1.
    return np.reshape(grids, (-1, P, 1))

In [3]:
def analytical_free_energy_difference(student_A, student_B, beta, P, device):
    """Brute-force free-entropy differences between two students.

    Enumerates all ``2**P`` hidden +-1 configurations ``tau`` (via
    ``init_spins``) and, for each student's couplings ``xi``, takes a stable
    log-mean-exp over ``tau`` of two per-configuration scores:

    1. quadratic form ``1/2 * beta**2 * sum_i (xi @ tau)_i**2``
       (presumably Gaussian visible units -- confirm);
    2. logcosh form ``sum_i logcosh(beta * (xi @ tau)_i)``
       (presumably +-1 visible units -- confirm).

    The cost grows as ``2**P``, so this is only usable for small ``P``. For
    each student the largest field ``beta * xi @ tau`` is printed as
    "Epsilon" (A first, then B).

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        ``f_B - f_A`` for the quadratic form and for the logcosh form
        (0-dim tensors).
    """
    spins = torch.from_numpy(np.float32(init_spins(P))).to(device)

    def log_mean_exp(H):
        # Shift by the maximum before exponentiating so large H cannot overflow.
        c = torch.max(H)
        return c + torch.log(torch.mean(torch.exp(H - c)))

    def both_forms(xi):
        H_quadratic = 1/2 * beta**2 * torch.sum((xi @ spins)**2, dim = 1)
        print("Epsilon: %.2f" % torch.max(beta * xi @ spins))
        H_logcosh = torch.sum(logcosh(beta * xi @ spins), dim = 1)
        return log_mean_exp(H_quadratic), log_mean_exp(H_logcosh)

    f_A_quadratic, f_A_logcosh = both_forms(student_A.xi)
    f_B_quadratic, f_B_logcosh = both_forms(student_B.xi)

    f_difference_1 = f_B_quadratic - f_A_quadratic
    f_difference_2 = f_B_logcosh - f_A_logcosh

    return f_difference_1, f_difference_2

In [4]:
def logcosh(x):
    """Numerically stable ``log(cosh(x))`` for a torch tensor.

    Uses ``log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2)``. The
    exponential only ever sees non-positive arguments, so large ``|x|``
    cannot overflow.

    ``log(2)`` is a Python float. A 0-dim ``torch.tensor(2)`` would be
    float32: it would round ``log(2)`` to single precision even for float64
    inputs (about 2e-9 error), and it would allocate a tensor on every call.
    """
    magnitude = torch.abs(x)
    return magnitude + torch.log1p(torch.exp(-2 * magnitude)) - math.log(2.0)

In [5]:
def jax_logcosh(x):
    """JAX counterpart of ``logcosh``: overflow-free ``log(cosh(x))``.

    Evaluates ``|x| + log1p(exp(-2|x|)) - log(2)`` so the exponential never
    receives a positive argument.
    """
    magnitude = jnp.maximum(-x, x)
    tail = jnp.log1p(jnp.exp(-2*magnitude))
    return magnitude + tail - jnp.log(2)

In [6]:
def rademacher(prob, generator : torch.Generator):
    """Draw +-1 samples: +1 with probability ``prob``, -1 otherwise.

    ``prob`` is a tensor of per-entry probabilities; the result has its shape
    and dtype. All randomness comes from ``generator``.
    """
    coin_flips = torch.bernoulli(prob, generator = generator)
    return 2 * coin_flips - 1

In [7]:
def jax_rademacher(prob, key : int):
    """JAX counterpart of ``rademacher``: +1 with probability ``prob``, else -1.

    ``key`` is a JAX PRNG key (despite the ``int`` annotation). It is used
    as-is, not split, so callers must split it themselves to get fresh
    randomness.
    """
    heads = jax.random.bernoulli(key, prob)
    return 2 * heads - 1

In [8]:
class RBM(torch.nn.Module):
    """Restricted Boltzmann machine with N visible and P hidden +-1 units, trained by CD.

    Couplings ``xi`` have shape (N, P). Visible batches ``sigma`` are
    (batch, N) and hidden batches ``tau`` are (batch, P). Every random draw
    uses the ``random_number_generator`` given to the constructor, so a
    seeded generator makes sampling and training reproducible.

    NOTE(review): ``free_entropy`` is ``sum(logcosh(beta * sigma @ xi))``,
    which is the hidden marginal of ``exp(beta * sigma @ xi @ tau)`` with +-1
    hidden units. For that joint, the exact Gibbs conditional is
    ``sigmoid(2 * beta * field)``. Both samplers below instead use
    ``sigmoid(beta * field)``, an effective inverse temperature of beta / 2.
    This is left unchanged because it shifts every result; confirm which
    convention is intended.
    """

    def __init__(self, N : int, P : int, standard_deviation : float, device : torch.device, random_number_generator : torch.Generator):
        super(RBM, self).__init__()

        self.N = N
        self.P = P

        self.training_device = device

        self.random_number_generator = random_number_generator

        # Updated by hand in weights_update, so autograd is switched off.
        self.xi = torch.nn.Parameter(torch.zeros((N, P), device = device), requires_grad = False)

        self.initialize_weights(standard_deviation)

        # Momentum buffer used by weights_update.
        self.descent_vector = torch.nn.Parameter(torch.zeros((N, P), device = device), requires_grad = False)

    @staticmethod
    def _monitoring_period(total : int, number_monitored : int):
        """Return the interval between progress reports, or None when monitoring is off.

        The interval is clamped to at least 1. Asking for more reports than
        there are steps then reports every step instead of raising
        ZeroDivisionError.
        """
        if number_monitored == 0:
            return None
        return max(1, total // number_monitored)

    def initialize_weights(self, standard_deviation : float):
        """Draw random couplings whose columns are exactly orthogonal.

        1. Gaussian entries are antisymmetrised along the visible axis: row
           ``i`` becomes the negative of row ``N - 1 - i``. For odd ``N`` the
           middle row is zero.
        2. The result is whitened with a Cholesky factor so that
           ``xi.T @ xi / N`` is the identity.
        3. It is scaled by ``standard_deviation``.
        """
        xi = torch.randn((self.N, self.P), device = self.training_device, generator = self.random_number_generator)

        xi.copy_((xi - torch.flip(xi, dims = (0,)))/torch.sqrt(torch.tensor(2)))

        C = (torch.transpose(xi, 0, 1) @ xi) / self.N

        L = torch.linalg.cholesky(C)
        # xi <- xi @ L^{-T}, hence xi.T @ xi / N = L^{-1} C L^{-T} = I.
        xi.copy_(torch.linalg.solve_triangular(L, xi.T, upper = False).T)

        self.xi.copy_(standard_deviation * xi)

    def sample_hidden_given_visible(self, tau, sigma, beta : float):
        """Resample ``tau`` in place from ``sigma`` and return P(tau = +1 | sigma)."""
        P_tau_given_sigma = torch.sigmoid(beta * sigma @ self.xi)
        tau.copy_(rademacher(P_tau_given_sigma, generator = self.random_number_generator))

        return P_tau_given_sigma

    def sample_visible_given_hidden(self, sigma, tau, beta : float):
        """Resample ``sigma`` in place from ``tau`` and return P(sigma = +1 | tau)."""
        P_sigma_given_tau = torch.sigmoid(beta * tau @ torch.transpose(self.xi, 0, 1))
        sigma.copy_(rademacher(P_sigma_given_tau, generator = self.random_number_generator))

        return P_sigma_given_tau

    def sample_visible_and_hidden(self, sigma, tau, beta : float, number_sampling_steps : int,
                                  number_monitored_sampling_steps : int, calculate_loss : bool):
        """Run block-Gibbs sweeps, updating ``sigma`` and ``tau`` in place.

        Each of the ``number_sampling_steps`` sweeps samples visible given
        hidden, then hidden given visible. When
        ``number_monitored_sampling_steps`` is non-zero, the mean free entropy
        is printed at the start and roughly that many times during the run.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            The last visible conditional probabilities (zeros if no sweep ran).
            When ``calculate_loss`` is true, also the change in mean free
            entropy over the run; otherwise a zero tensor.
        """
        f_0 = torch.tensor(0.)

        P_sigma_given_tau = torch.zeros_like(sigma, device = self.training_device)

        if number_monitored_sampling_steps != 0:
            f_0 = torch.mean(self.free_entropy(sigma, beta))
            print("Step [{}/{}], free entropy: {:.4f}".format(0, number_sampling_steps, f_0))

        elif calculate_loss:
            f_0 = torch.mean(self.free_entropy(sigma, beta))

        monitoring_period = self._monitoring_period(number_sampling_steps, number_monitored_sampling_steps)

        for sampling_step in range(1, number_sampling_steps + 1):
            monitor_this_step = monitoring_period is not None and sampling_step % monitoring_period == 0

            P_sigma_given_tau = self.sample_visible_given_hidden(sigma, tau, beta)

            self.sample_hidden_given_visible(tau, sigma, beta)

            if monitor_this_step:
                f = torch.mean(self.free_entropy(sigma, beta))
                print("Step [{}/{}], free entropy: {:.4f}".format(sampling_step, number_sampling_steps, f))

        if calculate_loss:
            f = torch.mean(self.free_entropy(sigma, beta))
            loss = f - f_0
        else:
            loss = torch.tensor(0.)

        return P_sigma_given_tau, loss

    def sample_visible(self, sigma, beta : float, number_sampling_steps : int,
                       number_monitored_sampling_steps : int):
        """Gibbs-sample ``sigma`` in place and return the last P(sigma = +1 | tau)."""
        tau = torch.zeros((len(sigma), self.P), device = self.training_device)

        # Initialise the hidden layer from the starting visible configuration.
        self.sample_hidden_given_visible(tau, sigma, beta)

        P_sigma_given_tau, _ = self.sample_visible_and_hidden(sigma, tau, beta, number_sampling_steps,
                                                              number_monitored_sampling_steps = number_monitored_sampling_steps,
                                                              calculate_loss = False)

        del tau

        return P_sigma_given_tau

    def free_entropy(self, sigma, beta : float):
        """Per-sample ``sum_mu logcosh(beta * (sigma @ xi)_mu)`` with shape (batch,)."""
        f = torch.sum(logcosh(beta * sigma @ self.xi), dim = 1)

        return f

    def contrastive_divergence(self, sigma, beta : float, number_sampling_steps : int,
                               monitor_sampling : bool, calculate_loss : bool):
        """CD-k estimate of the log-likelihood gradient for a visible batch.

        NOTE: ``sigma`` is overwritten in place by the negative-phase samples.

        Returns
        -------
        (loss, reconstruction_error, gradient)
            ``loss`` is the value from sample_visible_and_hidden.
            ``reconstruction_error`` is the squared norm of the unscaled
            gradient.
            ``gradient`` is ``beta * (<sigma tau>_data - <sigma tau>_model)``,
            shape (N, P).
        """
        number_monitored_sampling_steps = number_sampling_steps if monitor_sampling else 0

        tau = torch.zeros((len(sigma), self.P), device = self.training_device)

        self.sample_hidden_given_visible(tau, sigma, beta)

        # Batch average of the outer products sigma_b tau_b^T -> (N, P).
        positive_gradient = torch.mean(torch.reshape(sigma, (-1, self.N, 1)) @ torch.reshape(tau, (-1, 1, self.P)), dim = 0)

        _, loss = self.sample_visible_and_hidden(sigma, tau, beta, number_sampling_steps,
                                                 number_monitored_sampling_steps = number_monitored_sampling_steps,
                                                 calculate_loss = calculate_loss)

        negative_gradient = torch.mean(torch.reshape(sigma, (-1, self.N, 1)) @ torch.reshape(tau, (-1, 1, self.P)), dim = 0)

        del tau

        gradient = positive_gradient - negative_gradient

        reconstruction_error = torch.sum(gradient**2)

        gradient = beta*gradient

        return loss, reconstruction_error, gradient

    def weights_update(self, sigma, beta : float, alpha : float, learning_rate : float,
                       decay_rate : float, momentum : float, number_sampling_steps : int,
                       monitor_sampling : bool, calculate_loss : bool):
        """One stochastic update of ``xi`` with momentum, weight decay and Gaussian noise.

        ``descent_vector <- momentum * descent_vector
        + learning_rate * (alpha * grad - decay_rate * xi)
        + sqrt(2 * (1 - momentum) / N) * noise``,
        then ``xi <- xi + learning_rate * descent_vector``.
        Returns ``(loss, reconstruction_error)`` from contrastive_divergence.
        """
        loss, reconstruction_error, gradient = self.contrastive_divergence(sigma, beta, number_sampling_steps,
                                                                           monitor_sampling = monitor_sampling,
                                                                           calculate_loss = calculate_loss)

        noise = torch.randn((self.N, self.P), device = self.training_device, generator = self.random_number_generator)

        with torch.no_grad():
            gradient = alpha * gradient - decay_rate * self.xi

            self.descent_vector.copy_(momentum * self.descent_vector + learning_rate * gradient + torch.sqrt(torch.tensor(2 * (1 - momentum) / self.N)) * noise)

            self.xi.copy_(self.xi + learning_rate * self.descent_vector)

        return loss, reconstruction_error

    def train_weights(self, loader : torch.utils.data.DataLoader, beta : float, alpha : float, initial_learning_rate : float,
                      decay_rate : float, momentum : float, number_sampling_steps : int,
                      number_training_epochs : int, number_monitored_training_epochs : int, monitor_sampling : bool):
        """Train ``xi`` for ``number_training_epochs`` passes over ``loader``.

        Monitored epochs (about ``number_monitored_training_epochs`` of them)
        compute the loss and print the running per-batch means of loss and
        reconstruction error. If ``monitor_sampling`` is set, Gibbs sampling
        is also monitored on those epochs. The learning rate stays at
        ``initial_learning_rate``.
        """
        monitoring_period = self._monitoring_period(number_training_epochs, number_monitored_training_epochs)

        for training_epoch in range(1, number_training_epochs + 1):
            monitor_training_this_epoch = monitoring_period is not None and training_epoch % monitoring_period == 0
            monitor_sampling_this_epoch = monitor_training_this_epoch and monitor_sampling

            learning_rate = initial_learning_rate

            average_loss = torch.tensor(0.)
            average_reconstruction_error = torch.tensor(0.)

            for batch, sigma_batch in enumerate(loader):

                sigma_batch = sigma_batch.view(-1, self.N).to(self.training_device)

                loss, reconstruction_error = self.weights_update(sigma_batch, beta, alpha, learning_rate, decay_rate, momentum, number_sampling_steps,
                                                                 monitor_sampling = monitor_sampling_this_epoch, calculate_loss = monitor_training_this_epoch)

                if monitor_training_this_epoch:
                    # Incremental means over the batches seen so far. Both
                    # need `+=`; a plain `=` keeps only a scaled last batch.
                    average_loss += (loss.detach().item() - average_loss) / (batch + 1)
                    average_reconstruction_error += (reconstruction_error.detach().item() - average_reconstruction_error) / (batch + 1)

            if monitor_training_this_epoch:
                print("Epoch [{}/{}], loss: {:.4f}, reconstruction error: {:.4f}"
                      .format(training_epoch, number_training_epochs, average_loss, average_reconstruction_error))

In [None]:
class RBM_combination():
    """First A/B pairing of two trained students for annealed sampling in JAX.

    Couplings are stacked with shape (2, 1, N, P), so a (2, M, N) batch of
    visible chains broadcasts against both models at once.

    NOTE(review): later cells redefine ``RBM_combination``; only the most
    recently executed definition is live.
    NOTE(review): the conditionals use ``sigmoid(field)`` while
    ``free_entropy`` uses ``logcosh(field)``. For +-1 units, the joint
    ``exp(field * tau)`` implies ``sigmoid(2 * field)``; confirm the
    intended convention.
    """
    def __init__(self, RBM_A, RBM_B, beta : float):
        
        # Copy both students' torch couplings to host and stack them for JAX.
        self.xi = jnp.array([[RBM_A.xi.detach().cpu().numpy()],
                             [RBM_B.xi.detach().cpu().numpy()]])
        # (2, 1, N, P)
        
        self.beta = beta
    
    @partial(jax.jit, static_argnums = 0)
    def energy(self, sigma):
        """Fields ``beta * sigma @ xi`` for every (model, chain) pair.

        ``sigma`` has shape (2, M, N), one visible batch per chain. Matmul
        broadcasting against ``xi`` (2, 1, N, P) gives ``H`` of shape
        (2, 2, M, P) with ``H[a, b] = beta * sigma[b] @ xi[a, 0]``: axis 0
        indexes the model and axis 1 the chain.
        """
        
        H = self.beta * sigma @ self.xi
        
        return H
    
    @partial(jax.jit, static_argnums = 0)
    def free_entropy(self, H, t : float):
        """Interpolated free entropy per chain and sample.

        ``t`` broadcasts against ``H``; callers in this file pass arrays such
        as shape (2, 2, 1, 1), one weight per model and chain. The function
        sums ``logcosh(t * H)`` over hidden units and then over models:
        (2, 2, M, P) -> (2, M).
        """
        
        f = jnp.sum(jnp.sum(jax_logcosh(t * H), axis = -1), axis = 0)
        
        return f
    
    @partial(jax.jit, static_argnums = 0)
    def sample_hidden_given_visible(self, H, t : float, key : int):
        """Sample +-1 hidden units with P(tau = +1) = sigmoid(t * H).

        ``key`` is a JAX PRNG key (despite the ``int`` annotation). Returns
        ``(tau, key)``: ``tau`` has the broadcast shape of ``t * H`` and
        ``key`` has been advanced by one split.
        """
        key, key_tau = jax.random.split(key, num = 2)
        
        P_tau_given_sigma = jax.nn.sigmoid(t * H)
        tau = jax_rademacher(P_tau_given_sigma, key_tau)
        # (2, 2, M, P)
        
        return tau, key
    
    @partial(jax.jit, static_argnums = 0)
    def sample_visible_given_hidden(self, tau, t : float, key : int):
        """Sample +-1 visible units given the hidden units of both models.

        ``tau`` (2, 2, M, P) @ ``xi`` transposed to (2, 1, P, N) gives
        (2, 2, M, N). The ``t``- and ``beta``-weighted sum over axis 0
        (models) is the visible field of each chain, shape (2, M, N).
        Returns ``(sigma, key)`` with ``key`` advanced.
        """
        key, key_sigma = jax.random.split(key, num = 2)
        
        P_sigma_given_tau = jax.nn.sigmoid(jnp.sum(t * self.beta * tau @ jnp.transpose(self.xi, (0, 1, 3, 2)), axis = 0))
        sigma = jax_rademacher(P_sigma_given_tau, key_sigma)
        
        return sigma, key
    
    @partial(jax.jit, static_argnums = 0)
    def update_visible(self, H, t : float, key : int):
        """One block-Gibbs sweep from fields ``H``; returns ``(sigma, key)``."""
        
        tau, key = self.sample_hidden_given_visible(H, t, key)
        
        sigma, key = self.sample_visible_given_hidden(tau, t, key)
        
        return sigma, key

In [9]:
class RBM_combination():
    """A/B pairing with couplings stacked as (2, N, P) for annealed sampling in JAX.

    Index 0 of ``xi`` is student A and index 1 is student B. Every method is
    jitted with ``self`` static, so compiled code is cached per instance.
    ``t`` arguments broadcast against (2, M, P) arrays; callers in this file
    pass shape (2, 1, 1), one weight per model.

    NOTE(review): the conditionals use ``sigmoid(field)`` while
    ``free_entropy`` uses ``logcosh(field)``. For +-1 units, the joint
    ``exp(field * tau)`` implies ``sigmoid(2 * field)``; confirm the
    intended convention.
    """
    def __init__(self, RBM_A, RBM_B, beta : float):

        # (2, N, P): couplings of A and B stacked along axis 0.
        self.xi = jnp.array([RBM_A.xi.detach().cpu().numpy(), RBM_B.xi.detach().cpu().numpy()])

        self.beta = beta

    @partial(jax.jit, static_argnums = 0)
    def energy(self, sigma):
        """Fields ``beta * sigma @ xi`` for both models: (M, N) -> (2, M, P)."""

        H = self.beta * sigma @ self.xi

        return H

    @partial(jax.jit, static_argnums = 0)
    def free_entropy(self, H, t : float):
        """Per-model, per-sample ``sum_P logcosh(t * H)``: (2, M, P) -> (2, M)."""

        f = jnp.sum(jax_logcosh(t * H), axis = -1)

        return f

    @partial(jax.jit, static_argnums = 0)
    def sample_hidden_given_energy(self, H, t : float, key : int):
        """Sample +-1 hidden units of both models with P(tau = +1) = sigmoid(t * H).

        ``key`` is a JAX PRNG key (despite the ``int`` annotation). Returns
        ``(tau, key)`` with ``tau`` of shape (2, M, P) and ``key`` advanced.
        """
        key, key_tau = jax.random.split(key, num = 2)

        P_tau_given_sigma = jax.nn.sigmoid(t * H)
        tau = jax_rademacher(P_tau_given_sigma, key_tau)

        return tau, key

    @partial(jax.jit, static_argnums = 0)
    def sample_energy_given_hidden(self, tau, t : float, key : int):
        """Sample visible units from both models' hidden units, then return their fields.

        ``tau`` (2, M, P) @ ``xi^T`` (2, P, N) gives (2, M, N). The ``t``-
        and ``beta``-weighted sum over models is the visible field (M, N).
        Returns ``(H, key)`` with ``H = energy(sigma)`` of shape (2, M, P).
        """
        key, key_sigma = jax.random.split(key, num = 2)

        P_sigma_given_tau = jax.nn.sigmoid(jnp.sum(t * self.beta * tau @ jnp.transpose(self.xi, (0, 2, 1)), axis = 0))
        sigma = jax_rademacher(P_sigma_given_tau, key_sigma)
        H = self.energy(sigma)

        return H, key

    @partial(jax.jit, static_argnums = 0)
    def update_energy(self, H, t : float, key : int):
        """One block-Gibbs sweep: hidden given fields, then new visible fields.

        Calls the two samplers this class defines. Calling the older names
        ``sample_hidden_given_visible`` / ``sample_visible_given_hidden``
        raised AttributeError here.
        """

        tau, key = self.sample_hidden_given_energy(H, t, key)

        H, key = self.sample_energy_given_hidden(tau, t, key)

        return H, key

In [10]:
class RBM_combination():
    """Two RBMs (A, B) driving one shared visible layer, for annealed sampling.

    The visible units see the sum of both models' fields, with A's couplings
    weighted by ``1 - t`` and B's by ``t``. Couplings are copied into host
    numpy arrays of shape (N, P) at construction. Every method is jitted with
    ``self`` static, so compiled code is cached per instance.

    NOTE(review): the conditionals use ``sigmoid(field)`` while
    ``free_entropy`` uses ``logcosh(field)``. For +-1 units, the joint
    ``exp(field * tau)`` implies ``sigmoid(2 * field)``; confirm the
    intended convention.
    """

    def __init__(self, RBM_A, RBM_B, beta : float):
        def host_couplings(model):
            return model.xi.detach().cpu().numpy()

        self.xi_A = host_couplings(RBM_A)   # (N, P)
        self.xi_B = host_couplings(RBM_B)   # (N, P)

        self.beta = beta

    @partial(jax.jit, static_argnums = 0)
    def energy(self, sigma):
        """Return ``(beta * sigma @ xi_A, beta * sigma @ xi_B)``, each (M, P)."""
        stacked_couplings = jnp.array([self.xi_A, self.xi_B])   # (2, N, P)
        fields = self.beta * sigma @ stacked_couplings          # (2, M, P)
        return fields[0], fields[1]

    @partial(jax.jit, static_argnums = 0)
    def free_entropy(self, H):
        """Sum of ``logcosh`` over the last (hidden) axis of ``H``."""
        return jnp.sum(jax_logcosh(H), axis = -1)

    @partial(jax.jit, static_argnums = 0)
    def sample_hidden_given_energy(self, H_A, H_B, t : float, key : int):
        """Sample +-1 hidden units of A and B from their weighted fields.

        A's fields are scaled by ``1 - t`` and B's by ``t``. Both layers are
        drawn by a single Bernoulli call on the stacked probabilities.
        Returns ``(tau_A, tau_B, key)`` with ``key`` advanced.
        """
        key, subkey = jax.random.split(key, num = 2)

        weighted_fields = jnp.array([(1 - t) * H_A, t * H_B])
        tau = jax_rademacher(jax.nn.sigmoid(weighted_fields), subkey)

        return tau[0], tau[1], key

    @partial(jax.jit, static_argnums = 0)
    def sample_energy_given_hidden(self, tau_A, tau_B, t : float, key : int):
        """Sample visible units from both hidden layers; return the new fields.

        Returns ``(H_A, H_B, key)``, where the fields come from ``energy`` on
        the freshly drawn visible configuration.
        """
        key, subkey = jax.random.split(key, num = 2)

        weight_A = (1 - t) * self.beta
        weight_B = t * self.beta
        visible_field = ((weight_A * tau_A) @ jnp.transpose(self.xi_A, (1, 0))
                         + (weight_B * tau_B) @ jnp.transpose(self.xi_B, (1, 0)))

        sigma = jax_rademacher(jax.nn.sigmoid(visible_field), subkey)

        H_A, H_B = self.energy(sigma)

        return H_A, H_B, key

    @partial(jax.jit, static_argnums = 0)
    def update_energy(self, H_A, H_B, t : float, key : int):
        """One block-Gibbs sweep at interpolation ``t``; returns ``(H_A, H_B, key)``."""
        tau_A, tau_B, key = self.sample_hidden_given_energy(H_A, H_B, t, key)
        return self.sample_energy_given_hidden(tau_A, tau_B, t, key)

In [None]:
# AIS step written for the first RBM_combination variant (update_visible API,
# t of shape (2, 2, 1, 1)). The first evaluate drives it through
# jax.lax.fori_loop(1, number_annealing_steps + 1,
#                   partial(free_entropy_difference, student_combination, number_annealing_steps),
#                   (f_difference, H, key)).
# NOTE(review): redefined by the following cells, which change the signature.
@partial(jax.jit, static_argnums = (0, 1))
def free_entropy_difference(student_combination, number_annealing_steps, annealing_step, carry_over):
    """One annealed-importance-sampling step for both annealing directions at once.

    ``carry_over`` is ``(f_difference, H, key)``. ``H`` is presumably laid out
    as ``[model, chain]``, as produced by the first RBM_combination.energy.
    The function:

    1. adds the free entropy of the current samples under the new ``t``;
    2. takes one Gibbs sweep;
    3. subtracts the free entropy of the new samples under the same ``t``.

    Returns the updated carry.
    """
    f_difference, H, key = carry_over
    
    # t[model, chain] with s = annealing_step / number_annealing_steps:
    # chain 0 weights A by 1 - s and B by s (A -> B); chain 1 does the reverse.
    t = jax.numpy.array([[1 - annealing_step/number_annealing_steps, annealing_step/number_annealing_steps],
                         [annealing_step/number_annealing_steps, 1 - annealing_step/number_annealing_steps]])[..., jnp.newaxis, jnp.newaxis]
    
    f_difference += student_combination.free_entropy(H, t)
    
    sigma, key = student_combination.update_visible(H, t, key)
    
    H = student_combination.energy(sigma)
    
    # Together with the caller's initial "-F(t_i)" and final "+F(t_f)" terms,
    # these increments telescope into the per-sample AIS log-weight.
    f_difference -= student_combination.free_entropy(H, t)
    
    return f_difference, H, key

In [None]:
# AIS step written for the (2, N, P)-stacked RBM_combination variant
# (update_energy(H, t, key), free_entropy(H, t) returning (2, M)).
# Intended use: jax.lax.fori_loop over partial(free_entropy_difference,
# student_combination, number_annealing_steps, iterate_backward).
# NOTE(review): redefined by the next cell with a split (H_A, H_B) carry.
@partial(jax.jit, static_argnums = (0, 1, 2))
def free_entropy_difference(student_combination, number_annealing_steps, iterate_backward, annealing_step, carry_over):
    """One AIS step along a single annealing direction.

    ``carry_over`` is ``(f_difference, H, key)``. ``t`` holds one weight per
    model, shape (2, 1, 1). With ``iterate_backward`` false, A's weight goes
    from 1 to 0 and B's from 0 to 1; with it true, the direction is reversed.
    Per-model free entropies are summed over axis 0 before updating
    ``f_difference``. Returns the updated carry.
    """
    f_difference, H, key = carry_over
    
    if iterate_backward:
        t = jax.numpy.array([annealing_step/number_annealing_steps, 1 - annealing_step/number_annealing_steps])[..., jnp.newaxis, jnp.newaxis]
    else:
        t = jax.numpy.array([1 - annealing_step/number_annealing_steps, annealing_step/number_annealing_steps])[..., jnp.newaxis, jnp.newaxis]
    
    # Free entropy of the current samples under the new t ...
    f_difference += jnp.sum(student_combination.free_entropy(H, t), axis = 0)
    
    H, key = student_combination.update_energy(H, t, key)
    
    # ... minus that of the updated samples under the same t.
    f_difference -= jnp.sum(student_combination.free_entropy(H, t), axis = 0)
    
    return f_difference, H, key

In [11]:
@partial(jax.jit, static_argnums = (0, 1, 2))
def free_entropy_difference(student_combination, number_annealing_steps, iterate_backward, annealing_step, carry_over):
    """One annealed-importance-sampling step, shaped for ``jax.lax.fori_loop``.

    ``carry_over`` is ``(f_difference, H_A, H_B, key)``. The interpolation
    parameter is ``t = annealing_step / number_annealing_steps``, or
    ``1 - that`` when ``iterate_backward`` is true. Each step:

    1. adds ``(1 - t) * F(H_A) + t * F(H_B)`` for the current samples;
    2. takes one Gibbs sweep via ``update_energy``;
    3. subtracts the same mixture for the new samples.

    Returns the updated carry.

    NOTE(review): the weights follow the geometric path
    ``(1 - t) * F_A + t * F_B``. RBM_combination's samplers instead scale
    the fields themselves by ``1 - t`` and ``t``, which targets
    ``F((1 - t) * H_A) + F(t * H_B)``. AIS needs the weight path and the
    sampler to agree; TODO confirm which path is intended.
    """
    f_difference, H_A, H_B, key = carry_over

    progress = annealing_step/number_annealing_steps
    t = 1 - progress if iterate_backward else progress

    def mixed_free_entropy(energy_A, energy_B):
        return (1 - t) * student_combination.free_entropy(energy_A) + t * student_combination.free_entropy(energy_B)

    f_difference = f_difference + mixed_free_entropy(H_A, H_B)

    H_A, H_B, key = student_combination.update_energy(H_A, H_B, t, key)

    f_difference = f_difference - mixed_free_entropy(H_A, H_B)

    return f_difference, H_A, H_B, key

In [None]:
def _welford_update(mean, m2, value, count):
    """Fold ``value`` into a running mean and sum of squared deviations.

    This is Welford's algorithm: after ``count`` values, ``mean`` is their
    mean and ``m2 / (count - 1)`` their unbiased sample variance.
    """
    delta = value - mean
    mean = mean + delta / count
    m2 = m2 + delta * (value - mean)
    return mean, m2


def evaluate(teacher, student_A, student_B, beta, M, number_sampling_steps,
             number_annealing_steps, number_annealing_runs,
             number_monitored_sampling_steps, number_monitored_annealing_runs, seed):
    """Estimate free-entropy, log-likelihood and log-posterior differences of two students.

    Written against the first RBM_combination / free_entropy_difference
    variants (``update_visible``, ``t`` of shape (2, 2, 1, 1)).

    Each annealing run:

    1. draws ``M`` teacher samples and ``M`` samples from each student by
       Gibbs sampling;
    2. runs AIS in both directions at once: chain 0 goes A -> B and chain 1
       goes B -> A;
    3. flips chain 1's sign so both entries estimate ``log Z_B - log Z_A``;
    4. adds ``<F_A - F_B>`` over the teacher samples to get the
       log-likelihood difference ``L_A - L_B``.

    Statistics across runs use Welford's algorithm.

    Returns
    -------
    tuple
        ``(mean, std)`` pairs, each of shape (2,), in this order:
        free-entropy difference, log-likelihood difference and log-posterior
        difference. The log-posterior difference subtracts
        ``1 / (2 alpha) * (|xi_A|^2 - |xi_B|^2)``. Standard deviations are
        NaN when ``number_annealing_runs < 2``.
    """
    student_combination = RBM_combination(student_A, student_B, beta)

    device = teacher.training_device

    N = teacher.N

    alpha = M/N

    random_number_generator = teacher.random_number_generator
    key = jax.random.PRNGKey(seed)

    sigma_teacher_torch = torch.nn.Parameter(torch.zeros((M, N), device = device), requires_grad = False)
    sigma_student_torch = torch.nn.Parameter(torch.zeros((2, M, N), device = device), requires_grad = False)

    # Interpolation weights t[model, chain], shape (2, 2, 1, 1):
    # chain 0 starts on A and ends on B, chain 1 the reverse.
    t_i = jnp.array([[1., 0.],
                     [0., 1.]])[..., jnp.newaxis, jnp.newaxis]

    t_f = jnp.array([[0., 1.],
                     [1., 0.]])[..., jnp.newaxis, jnp.newaxis]

    # One-hot model selectors, shape (2, 1, 1, 1). logcosh is even, so a
    # signed weight such as [1, -1] inside free_entropy would yield F_A + F_B
    # rather than F_A - F_B; the two models are evaluated separately instead.
    t_only_A = jnp.array([[1.],
                          [0.]])[..., jnp.newaxis, jnp.newaxis]
    t_only_B = jnp.array([[0.],
                          [1.]])[..., jnp.newaxis, jnp.newaxis]

    # +1 for A, -1 for B, used for the prior term (outside logcosh).
    t_diff = jnp.array([[1.],
                        [-1.]])[..., jnp.newaxis, jnp.newaxis]

    mean_f_difference, m2_f_difference = 0, 0
    mean_log_likelihood_difference, m2_log_likelihood_difference = 0, 0

    if number_monitored_annealing_runs != 0:
        monitoring_period = max(1, number_annealing_runs // number_monitored_annealing_runs)
    else:
        monitoring_period = None

    for annealing_run in range(1, number_annealing_runs + 1):

        monitor_annealing_this_run = monitoring_period is not None and annealing_run % monitoring_period == 0

        # Uniformly random +-1 starting points, then Gibbs-equilibrate in place.
        sigma_teacher_torch.copy_(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)))
        sigma_student_torch.copy_(torch.sign(torch.randn((2, M, N), device = device, generator = random_number_generator)))

        teacher.sample_visible(sigma_teacher_torch, beta, number_sampling_steps,
                               number_monitored_sampling_steps)

        student_A.sample_visible(sigma_student_torch[0], beta, number_sampling_steps,
                                 number_monitored_sampling_steps)

        student_B.sample_visible(sigma_student_torch[1], beta, number_sampling_steps,
                                 number_monitored_sampling_steps)

        sigma_teacher = jnp.array(sigma_teacher_torch.detach().cpu().numpy())
        sigma_student = jnp.array(sigma_student_torch.detach().cpu().numpy())

        H = student_combination.energy(sigma_student)

        sigma_student, key = student_combination.update_visible(H, t_i, key)

        H = student_combination.energy(sigma_student)

        f_difference = -student_combination.free_entropy(H, t_i)

        start = timer()

        f_difference, H, key = jax.lax.fori_loop(1, number_annealing_steps + 1, partial(free_entropy_difference, student_combination, number_annealing_steps),
                                                 (f_difference, H, key))
        # JAX dispatches asynchronously; wait for the result so the timer
        # measures the computation rather than just its launch.
        f_difference.block_until_ready()

        end = timer()
        print("Time elapsed: %.4f." % (end - start))

        f_difference += student_combination.free_entropy(H, t_f)

        # Stable log-mean-exp of the per-sample AIS log-weights, per chain.
        c = jnp.max(f_difference, axis = -1, keepdims = True)
        f_difference = jnp.squeeze(c) + jnp.log(jnp.mean(jnp.exp(f_difference - c), axis = -1))
        # Chain 1 estimates log Z_A - log Z_B; flip it to match chain 0.
        f_difference = f_difference.at[1].multiply(-1.)

        H = student_combination.energy(sigma_teacher)   # (2, 1, M, P)

        teacher_free_entropy_difference = jnp.mean(student_combination.free_entropy(H, t_only_A)
                                                   - student_combination.free_entropy(H, t_only_B))

        # L_A - L_B = <F_A - F_B>_teacher + (log Z_B - log Z_A).
        log_likelihood_difference = teacher_free_entropy_difference + f_difference

        if monitor_annealing_this_run:
            print("Run [{}/{}], free entropy difference: {:.4f}, {:.4f}".format(annealing_run, number_annealing_runs, np.mean(f_difference[0]), np.mean(f_difference[1])))

        mean_f_difference, m2_f_difference = _welford_update(mean_f_difference, m2_f_difference,
                                                             f_difference, annealing_run)
        mean_log_likelihood_difference, m2_log_likelihood_difference = _welford_update(mean_log_likelihood_difference, m2_log_likelihood_difference,
                                                                                       log_likelihood_difference, annealing_run)

    del sigma_teacher_torch
    del sigma_student_torch
    gc.collect()

    if number_annealing_runs > 1:
        var_f_difference = m2_f_difference / (number_annealing_runs - 1)
        var_log_likelihood_difference = m2_log_likelihood_difference / (number_annealing_runs - 1)
    else:
        # A single run carries no information about the spread.
        var_f_difference = m2_f_difference * np.nan
        var_log_likelihood_difference = m2_log_likelihood_difference * np.nan

    # Gaussian-prior term 1/(2 alpha) * (|xi_A|^2 - |xi_B|^2), shared by both entries.
    prior_difference = 1/alpha * 1/2 * jnp.sum(t_diff * jnp.sum(student_combination.xi**2,
                                                                axis = (-2, -1), keepdims = True))
    mean_log_posterior_difference = mean_log_likelihood_difference - prior_difference
    var_log_posterior_difference = var_log_likelihood_difference

    return mean_f_difference, np.sqrt(var_f_difference), mean_log_likelihood_difference, np.sqrt(var_log_likelihood_difference), mean_log_posterior_difference, np.sqrt(var_log_posterior_difference)

In [None]:
def evaluate(teacher, student_A, student_B, beta, M, number_sampling_steps,
             number_annealing_steps, number_annealing_runs,
             number_monitored_sampling_steps, number_monitored_annealing_runs, seed):
    """AIS comparison of two student RBMs through a stacked ``RBM_combination`` (JAX version).

    For every annealing run, fresh teacher samples are drawn. Then, for each student, visible
    configurations are produced by ``student.sample_visible`` (initialised from the teacher
    samples) and the per-sample AIS log-weights are accumulated with ``free_entropy_difference``
    inside ``jax.lax.fori_loop``, forward from student_A (index 0) and backward from student_B
    (index 1). The weights are reduced with a numerically stable log-mean-exp, and the backward
    estimate is sign-flipped so both entries estimate the same free-entropy difference.

    Returns
    -------
    (mean_f_difference, std_f_difference, mean_log_likelihood_difference,
     std_log_likelihood_difference, mean_log_posterior_difference, std_log_posterior_difference)
        Length-2 arrays (forward / backward estimate). Standard deviations are sample standard
        deviations over annealing runs (NaN when ``number_annealing_runs < 2``).
    """
    student_combination = RBM_combination(student_A, student_B, beta)
    
    device = teacher.training_device
    
    N = teacher.N
    
    alpha = M/N
    
    random_number_generator = teacher.random_number_generator
    key = jax.random.PRNGKey(seed)
    
    sigma_teacher_torch = torch.nn.Parameter(torch.zeros((M, N), device = device), requires_grad = False)
    sigma_student_torch = torch.nn.Parameter(torch.zeros((M, N), device = device), requires_grad = False)
    
    # (2, 1, 1) interpolation weights: both students at full strength.
    t_diff = jnp.array([1., 1.])[..., jnp.newaxis, jnp.newaxis]
    
    f_difference = jnp.array([0., 0.])
    
    mean_f_difference = jnp.array([0., 0.])
    var_f_difference = jnp.array([0., 0.])
    
    mean_log_likelihood_difference = jnp.array([0., 0.])
    var_log_likelihood_difference = jnp.array([0., 0.])
    
    students = [student_A, student_B]
    
    # Integer division yields 0 when more runs are monitored than exist; clamp to 1 so the
    # modulo below cannot divide by zero.
    if number_monitored_annealing_runs != 0:
        monitor_interval = max(1, number_annealing_runs // number_monitored_annealing_runs)
    else:
        monitor_interval = 0
    
    for annealing_run in range(1, number_annealing_runs + 1):
        
        monitor_annealing_this_run = monitor_interval != 0 and annealing_run % monitor_interval == 0
        
        sigma_teacher_torch.copy_(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)))
        
        teacher.sample_visible(sigma_teacher_torch, beta, number_sampling_steps,
                               number_monitored_sampling_steps)
        
        sigma_teacher = jnp.array(sigma_teacher_torch.detach().cpu().numpy())
        
        for j, student in enumerate(students):
            iterate_backward = (j == 1)
            
            # (2, 1, 1) weights selecting the starting / final student of this anneal.
            if iterate_backward:
                t_ini = jnp.array([0., 1.])[..., jnp.newaxis, jnp.newaxis]
                t_fin = jnp.array([1., 0.])[..., jnp.newaxis, jnp.newaxis]
            else:
                t_ini = jnp.array([1., 0.])[..., jnp.newaxis, jnp.newaxis]
                t_fin = jnp.array([0., 1.])[..., jnp.newaxis, jnp.newaxis]
            
            # The student chain starts from the teacher samples.
            sigma_student_torch.copy_(sigma_teacher_torch)
            
            student.sample_visible(sigma_student_torch, beta, number_sampling_steps,
                                   number_monitored_sampling_steps)
            
            sigma_student = jnp.array(sigma_student_torch.detach().cpu().numpy())
            
            H = student_combination.energy(sigma_student)
            
            f_difference_cur = -jnp.sum(student_combination.free_entropy(H, t_ini), axis = 0)
            
            start = timer()
            
            ### Bottleneck!
            f_difference_cur, H, key = jax.lax.fori_loop(1, number_annealing_steps + 1,
                                                         partial(free_entropy_difference, student_combination, number_annealing_steps, iterate_backward),
                                                         (f_difference_cur, H, key))
            
            end = timer()
            print("Time elapsed: %.4f." % (end - start))
            
            f_difference_cur += jnp.sum(student_combination.free_entropy(H, t_fin), axis = 0)
            
            # Numerically stable log-mean-exp of the per-sample AIS log-weights.
            c = jnp.max(f_difference_cur, axis = -1)
            f_difference = f_difference.at[j].set(c + jnp.log(jnp.mean(jnp.exp(f_difference_cur - c), axis = -1)))
        
        # The backward pass estimates the negated quantity.
        f_difference = f_difference.at[1].multiply(-1.)
        
        H = student_combination.energy(sigma_teacher)
        
        f_numerator = student_combination.free_entropy(H, t_diff)
        f_numerator = f_numerator[0] - f_numerator[1]
        
        log_likelihood_difference = jnp.mean(f_numerator) + f_difference
        
        if monitor_annealing_this_run:
            print("Run [{}/{}], free entropy difference: {:.4f}, {:.4f}".format(annealing_run, number_annealing_runs, f_difference[0], f_difference[1]))
        
        # Welford update of the running mean and population variance (n = annealing_run).
        # The previous update (adding (x - mean_n)^2 / n, then scaling by (n - 1) / n)
        # underestimated the variance.
        delta = f_difference - mean_f_difference
        mean_f_difference += delta / annealing_run
        var_f_difference += (delta * (f_difference - mean_f_difference) - var_f_difference) / annealing_run
        
        delta = log_likelihood_difference - mean_log_likelihood_difference
        mean_log_likelihood_difference += delta / annealing_run
        var_log_likelihood_difference += (delta * (log_likelihood_difference - mean_log_likelihood_difference) - var_log_likelihood_difference) / annealing_run
    
    del sigma_teacher_torch
    del sigma_student_torch
    gc.collect()
    
    # Population -> unbiased sample variance; undefined (NaN) for fewer than two runs.
    if number_annealing_runs > 1:
        var_f_difference *= number_annealing_runs / (number_annealing_runs - 1)
        var_log_likelihood_difference *= number_annealing_runs / (number_annealing_runs - 1)
    else:
        var_f_difference = var_f_difference * jnp.nan
        var_log_likelihood_difference = var_log_likelihood_difference * jnp.nan
    
    # Prior term: + 1/(2 alpha) * (|xi_B|^2 - |xi_A|^2). The previous expression summed both squared
    # norms (t_diff is all ones), giving -(|xi_A|^2 + |xi_B|^2) / (2 alpha).
    # NOTE(review): assumes student_combination.xi stacks student A then student B on axis 0, matching
    # the index convention of t_ini and f_numerator above; confirm against RBM_combination.
    squared_norms = jnp.sum(student_combination.xi**2, axis = (-2, -1))
    mean_log_posterior_difference = mean_log_likelihood_difference + 1/2 * 1/alpha * (squared_norms[1] - squared_norms[0])
    var_log_posterior_difference = var_log_likelihood_difference
    
    return mean_f_difference, np.sqrt(var_f_difference), mean_log_likelihood_difference, np.sqrt(var_log_likelihood_difference), mean_log_posterior_difference, np.sqrt(var_log_posterior_difference)

In [12]:
def evaluate(teacher, student_A, student_B, beta, M, number_sampling_steps,
             number_annealing_steps, number_annealing_runs,
             number_monitored_sampling_steps, number_monitored_annealing_runs, seed):
    """Estimate free-entropy, log-likelihood and log-posterior differences between two student RBMs.

    Each annealing run draws fresh teacher samples and performs two AIS passes with
    ``free_entropy_difference`` inside ``jax.lax.fori_loop``:
    - one starting from samples of ``student_A`` (t = 0 -> 1, index 0);
    - one starting from samples of ``student_B`` (t = 1 -> 0, index 1), sign-flipped
      afterwards so both entries estimate the same quantity.

    The log-likelihood difference is ``<F_A>_teacher - <F_B>_teacher + f_difference``. The
    log-posterior difference adds ``(|xi_B|^2 - |xi_A|^2) / (2 alpha)``.

    Returns
    -------
    tuple of six length-2 arrays
        Mean and std of the free-entropy difference, of the log-likelihood difference and of
        the log-posterior difference. Standard deviations are sample standard deviations over
        annealing runs (NaN when ``number_annealing_runs < 2``).
    """
    student_combination = RBM_combination(student_A, student_B, beta)
    
    device = teacher.training_device
    
    N = teacher.N
    
    alpha = M/N
    
    random_number_generator = teacher.random_number_generator
    key = jax.random.PRNGKey(seed)
    
    sigma_teacher_torch = torch.nn.Parameter(torch.zeros((M, N), device = device), requires_grad = False)
    sigma_student_torch = torch.nn.Parameter(torch.zeros((M, N), device = device), requires_grad = False)
    
    mean_f_difference = jnp.array([0., 0.])
    var_f_difference = jnp.array([0., 0.])
    
    mean_log_likelihood_difference = jnp.array([0., 0.])
    var_log_likelihood_difference = jnp.array([0., 0.])
    
    students = [student_A, student_B]
    
    # Integer division yields 0 when more runs are monitored than exist; clamp to 1.
    if number_monitored_annealing_runs != 0:
        monitor_interval = max(1, number_annealing_runs // number_monitored_annealing_runs)
    else:
        monitor_interval = 0
    
    for annealing_run in range(1, number_annealing_runs + 1):
        
        monitor_annealing_this_run = monitor_interval != 0 and annealing_run % monitor_interval == 0
        
        sigma_teacher_torch.copy_(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)))
        
        teacher.sample_visible(sigma_teacher_torch, beta, number_sampling_steps,
                               number_monitored_sampling_steps)
        
        sigma_teacher = jnp.array(sigma_teacher_torch.detach().cpu().numpy())
        
        f_difference = jnp.array([0., 0.])
        
        for j, student in enumerate(students):
            
            iterate_backward = (j == 1)
            # t = 0 weights student A fully, t = 1 weights student B fully.
            t_ini = np.float32(j)
            t_fin = 1. - t_ini
            
            sigma_student_torch.copy_(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)))
            
            student.sample_visible(sigma_student_torch, beta, number_sampling_steps,
                                   number_monitored_sampling_steps)
            
            sigma_student = jnp.array(sigma_student_torch.detach().cpu().numpy())
            
            H_A, H_B = student_combination.energy(sigma_student)
            
            # Reset the per-sample AIS log-weights for every direction. They used to be initialised
            # once per run, so the B -> A pass accumulated on top of the A -> B result.
            f_difference_cur = jnp.zeros(M)
            f_difference_cur += -(1 - t_ini) * student_combination.free_entropy(H_A) - t_ini * student_combination.free_entropy(H_B)
            
            start = timer()
            
            ### Bottleneck!
            f_difference_cur, H_A, H_B, key = jax.lax.fori_loop(1, number_annealing_steps,
                                                                partial(free_entropy_difference, student_combination, number_annealing_steps, iterate_backward),
                                                                (f_difference_cur, H_A, H_B, key))
            
            end = timer()
            print("Time elapsed: %.4f." % (end - start))
            
            f_difference_cur += (1 - t_fin) * student_combination.free_entropy(H_A) + t_fin * student_combination.free_entropy(H_B)
            
            # Numerically stable log-mean-exp over samples.
            c = jnp.max(f_difference_cur, axis = -1)
            f_difference = f_difference.at[j].set(c + jnp.log(jnp.mean(jnp.exp(f_difference_cur - c), axis = -1)))
        
        # The backward pass estimates the negated quantity.
        f_difference = f_difference.at[1].multiply(-1.)
        
        H_A, H_B = student_combination.energy(sigma_teacher)
        
        log_likelihood_difference = jnp.mean(student_combination.free_entropy(H_A)) - jnp.mean(student_combination.free_entropy(H_B)) + f_difference
        
        if monitor_annealing_this_run:
            print("Run [{}/{}], free entropy difference: {:.4f}, {:.4f}".format(annealing_run, number_annealing_runs, f_difference[0], f_difference[1]))
            # Exact-enumeration cross-check. It is stored separately so it no longer replaces the AIS
            # estimate used by the running statistics; both students must have the same number of
            # hidden units because the same spin enumeration is used for both.
            if student_A.xi.shape[-1] == student_B.xi.shape[-1]:
                f_difference_analytical = analytical_free_energy_difference(student_A, student_B, beta, student_A.xi.shape[-1], device)
                print(f_difference_analytical)
        
        # Welford update of the running mean and population variance (n = annealing_run).
        delta = f_difference - mean_f_difference
        mean_f_difference += delta / annealing_run
        var_f_difference += (delta * (f_difference - mean_f_difference) - var_f_difference) / annealing_run
        
        delta = log_likelihood_difference - mean_log_likelihood_difference
        mean_log_likelihood_difference += delta / annealing_run
        var_log_likelihood_difference += (delta * (log_likelihood_difference - mean_log_likelihood_difference) - var_log_likelihood_difference) / annealing_run
    
    del sigma_teacher_torch
    del sigma_student_torch
    gc.collect()
    
    # Population -> unbiased sample variance; undefined (NaN) for fewer than two runs.
    if number_annealing_runs > 1:
        var_f_difference *= number_annealing_runs / (number_annealing_runs - 1)
        var_log_likelihood_difference *= number_annealing_runs / (number_annealing_runs - 1)
    else:
        var_f_difference = var_f_difference * jnp.nan
        var_log_likelihood_difference = var_log_likelihood_difference * jnp.nan
    
    mean_log_posterior_difference = mean_log_likelihood_difference + 1/2 * 1/alpha * (jnp.sum(student_combination.xi_B**2) - jnp.sum(student_combination.xi_A**2))
    var_log_posterior_difference = var_log_likelihood_difference
    
    return mean_f_difference, np.sqrt(var_f_difference), mean_log_likelihood_difference, np.sqrt(var_log_likelihood_difference), mean_log_posterior_difference, np.sqrt(var_log_posterior_difference)

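Memory note: each `sigma` buffer of shape (M, N) in float32 takes about 2 GB (e.g. M = 2N = 32,000 and N = 16,000 give 32,000 × 16,000 × 4 bytes ≈ 2 GB).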

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

T = 0.25
beta = 1/T

n_alpha = 20
alpha_range = np.array([1]) # alpha_range = np.linspace(0.1, 1, num = n_alpha, endpoint = True)

N = 16000
P = 2
P_t = 3
P_large = 18
m_0 = 0.2

seed = 693
random_number_generator = torch.Generator(device = device)
random_number_generator.manual_seed(seed)

data_loader_seed = 45
data_loader_generator = torch.Generator(device = "cpu")
data_loader_generator.manual_seed(data_loader_seed)

# Overlap summaries, one row per alpha (indexed by the loop counter i below).
m_A_range = np.zeros(n_alpha)
m_B_range = np.zeros((n_alpha, 2))
# m_C_range = np.zeros(n_alpha)


def seed_teacher_overlap(student_A, student_B, teacher, m_0, P, P_t):
    """Mix teacher patterns into the students' current weights, in place.

    Columns 0..P-1 of both students become sqrt(1 - m_0) * xi + sqrt(m_0) * teacher.xi.
    student_B's extra columns P..P_t-1 are mixed with the first teacher pattern instead.
    Uses no random numbers, so calling it after both students are (re)initialised
    consumes the generator in the same order as seeding each student right after creation.
    """
    student_A.xi[:, 0 : P].copy_(np.sqrt(1 - m_0) * student_A.xi[:, 0 : P] + np.sqrt(m_0) * teacher.xi.detach())
    student_B.xi[:, 0 : P].copy_(np.sqrt(1 - m_0) * student_B.xi[:, 0 : P] + np.sqrt(m_0) * teacher.xi.detach())
    student_B.xi[:, P : P_t].copy_(np.sqrt(1 - m_0) * student_B.xi[:, P : P_t] + np.sqrt(m_0) * teacher.xi[:, 0 : 1].detach())


teacher = RBM(N, P, 1/np.sqrt(N), device = device, random_number_generator = random_number_generator).to(device)

student_A = RBM(N, P_t, 1/np.sqrt(N), device = device, random_number_generator = random_number_generator).to(device)
student_B = RBM(N, P_t, 1/np.sqrt(N), device = device, random_number_generator = random_number_generator).to(device)
seed_teacher_overlap(student_A, student_B, teacher, m_0, P, P_t)

# student_C = RBM(N, P_large, 1/np.sqrt(N), device = device, random_number_generator = random_number_generator).to(device)

# Sampling converges extremely quickly
number_teacher_sampling_steps = 50

# decay_rate = 1 is the theoretically correct value in our setting. The other hyperparameters are set heuristically.
number_student_sampling_steps = 1
learning_rate = 0.005 # 0.5 # 0.005
decay_rate = 1
momentum = 0.9
number_student_training_epochs = 3000 # 30 # 3000
number_monitored_training_epochs = 30

# Settings for evaluate(...) (call currently commented out below). number_monitored_sampling_steps
# is also the value used by the teacher sampling inside the loop.
number_sampling_steps = 50
number_annealing_steps = 1000
number_annealing_runs = 20
number_monitored_sampling_steps = 2
number_monitored_annealing_runs = 20
seed = 10

for i, alpha in enumerate(alpha_range):
    M = int(alpha*N)
    
    data_batch_size = M
    
    # Training data (~M*N*4 bytes); kept alive after the loop.
    sigma = torch.nn.Parameter(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)), requires_grad = False)
    
    teacher.sample_visible(sigma, beta, number_teacher_sampling_steps,
                           number_monitored_sampling_steps)
    
    loader = torch.utils.data.DataLoader(dataset = sigma, batch_size = data_batch_size, shuffle = True, generator = data_loader_generator)
    
    # loader : torch.utils.data.DataLoader, beta : float, alpha : float, initial_learning_rate : float,
    # decay_rate : float, momentum : float, number_sampling_steps : int,
    # number_training_epochs : int, number_monitored_training_epochs : int = 0, monitor_sampling : bool = False
    
    student_A.train_weights(loader, beta, alpha, learning_rate, decay_rate, momentum, number_student_sampling_steps,
                            number_student_training_epochs, number_monitored_training_epochs, monitor_sampling = False)
    
    student_B.train_weights(loader, beta, alpha, learning_rate, decay_rate, momentum, number_student_sampling_steps,
                            number_student_training_epochs, number_monitored_training_epochs, monitor_sampling = False)
    
    # student_C.train_weights(loader, beta, learning_rate, decay_rate, momentum, number_student_sampling_steps,
    #                         number_student_training_epochs, number_monitored_training_epochs)
    
    m_A = torch.transpose(teacher.xi, 0, 1) @ student_A.xi
    print(m_A)
    m_A_range[i] = torch.mean(torch.diagonal(m_A)).item()
    
    s_A = torch.transpose(student_A.xi, 0, 1) @ student_A.xi
    
    print(s_A)
    
    m_B = torch.transpose(teacher.xi, 0, 1) @ student_B.xi
    print(m_B)
    m_B_range[i, 0] = torch.mean(torch.diagonal(m_B)[1 :]).item()
    m_B_range[i, 1] = (m_B[0, 0] + m_B[0, P]).item()/2
    
    s_B = torch.transpose(student_B.xi, 0, 1) @ student_B.xi
    
    print(s_B)
    
    # del sigma
    
    # outputs = evaluate(teacher, student_A, student_B, beta, M, number_sampling_steps,
    #                    number_annealing_steps, number_annealing_runs,
    #                    number_monitored_sampling_steps, number_monitored_annealing_runs, seed)
    
    # mean_f_difference, std_f_difference, mean_log_likelihood_difference, std_log_likelihood_difference, mean_log_posterior_difference, std_log_posterior_difference = outputs
    
    f_difference_1, f_difference_2 = analytical_free_energy_difference(student_A, student_B, beta, P_t, device)
    print(f_difference_1)
    print(f_difference_2)
    
    # <F_A> - <F_B> on the training data, shared by both log-likelihood estimates.
    free_entropy_gap = torch.mean(student_A.free_entropy(sigma, beta)) - torch.mean(student_B.free_entropy(sigma, beta))
    log_likelihood_difference_1 = free_entropy_gap + f_difference_1
    log_likelihood_difference_2 = free_entropy_gap + f_difference_2
    print(log_likelihood_difference_1)
    print(log_likelihood_difference_2)
    
    log_prior_difference = 1/2 * 1/alpha * (torch.sum(student_B.xi**2) - torch.sum(student_A.xi**2))
    log_posterior_difference_1 = log_likelihood_difference_1 + log_prior_difference
    log_posterior_difference_2 = log_likelihood_difference_2 + log_prior_difference
    print(log_posterior_difference_1)
    print(log_posterior_difference_2)
    
    # m_C = torch.transpose(teacher.xi, 0, 1) @ student_C.xi
    # print(m_C)
    # m_C_range[i] = torch.mean(torch.diagonal(m_C)).item()
    
    # s_C = torch.transpose(student_C.xi, 0, 1) @ student_C.xi
    
    # print(s_C)
    
    # Reinitialize weights for the next alpha.
    student_A.initialize_weights(1/np.sqrt(N))
    student_B.initialize_weights(1/np.sqrt(N))
    seed_teacher_overlap(student_A, student_B, teacher, m_0, P, P_t)
    
    # student_C.initialize_weights(1/np.sqrt(N))

# del teacher
# del student_A
# del student_B
# del student_C

gc.collect()

Step [0/50], free entropy: 5.1559
Step [25/50], free entropy: 14.7833
Step [50/50], free entropy: 14.7538
Epoch [100/3000], loss: 7.7587, reconstruction error: 9.9800
Epoch [200/3000], loss: 4.6092, reconstruction error: 8.8268
Epoch [300/3000], loss: 2.6840, reconstruction error: 7.8814
Epoch [400/3000], loss: 1.1724, reconstruction error: 7.2050
Epoch [500/3000], loss: 0.3511, reconstruction error: 6.5575
Epoch [600/3000], loss: -0.1981, reconstruction error: 6.0733
Epoch [700/3000], loss: -0.4034, reconstruction error: 5.8065
Epoch [800/3000], loss: -0.5930, reconstruction error: 5.5223
Epoch [900/3000], loss: -0.6000, reconstruction error: 5.2862
Epoch [1000/3000], loss: -0.5228, reconstruction error: 5.1892
Epoch [1100/3000], loss: -0.5493, reconstruction error: 5.1105
Epoch [1200/3000], loss: -0.4820, reconstruction error: 4.8985
Epoch [1300/3000], loss: -0.4318, reconstruction error: 4.9004
Epoch [1400/3000], loss: -0.3423, reconstruction error: 4.9164
Epoch [1500/3000], loss: -

0

In [None]:
# Displays the six evaluate(...) outputs; these names exist only once the evaluate call in the
# training cell is uncommented.
(
    mean_f_difference, std_f_difference,
    mean_log_likelihood_difference, std_log_likelihood_difference,
    mean_log_posterior_difference, std_log_posterior_difference,
)

In [None]:
# Same display as the previous cell; requires the (currently commented-out) evaluate call to have run.
(
    mean_f_difference, std_f_difference,
    mean_log_likelihood_difference, std_log_likelihood_difference,
    mean_log_posterior_difference, std_log_posterior_difference,
)

In [None]:
def evaluate(teacher, student_1, student_2, M, number_sampling_steps,
             number_annealing_steps, number_annealing_runs,
             number_monitored_sampling_steps, number_monitored_annealing_steps,
             monitor_student_sampling = False, inverse_temperature = None):
    """Compare two student RBMs on teacher data with annealed importance sampling (PyTorch version).

    Each annealing run draws fresh teacher samples. Then, for both orderings
    (student_1 -> student_2 at index 0, student_2 -> student_1 at index 1), it estimates the
    free-entropy difference between the students by AIS and combines it with the mean free
    entropies of the teacher samples into a log-likelihood difference.

    Parameters
    ----------
    teacher, student_1, student_2
        Project RBM objects providing sampling and free-entropy methods.
    M : int
        Number of visible configurations sampled per run.
    number_sampling_steps : int
        Passed to ``teacher.sample_visible``.
    number_annealing_steps : int
        Number of points of the schedule t = 1/K, ..., 1.
    number_annealing_runs : int
        Independent repetitions used for the running mean / variance.
    number_monitored_sampling_steps : int
        Passed to ``teacher.sample_visible``.
    number_monitored_annealing_steps : int
        Approximate number of annealing steps at which progress is printed (0 disables it).
    monitor_student_sampling : bool
        Unused; kept for call compatibility.
    inverse_temperature : float, optional
        Inverse temperature. When omitted, the notebook-level ``beta`` is used, which
        reproduces the previous behaviour.

    Returns
    -------
    tuple of six length-2 tensors
        Mean / std of the free-entropy, log-likelihood and log-posterior differences per
        direction. Standard deviations are sample standard deviations over runs (NaN for
        fewer than two runs).
    """
    # This function previously read the global ``beta`` implicitly; keep that as the fallback.
    evaluation_beta = inverse_temperature if inverse_temperature is not None else beta
    
    device = teacher.training_device
    
    t_range = torch.linspace(1/number_annealing_steps, 1, number_annealing_steps)
    
    N = teacher.N
    P_1 = student_1.P
    P_2 = student_2.P
    
    alpha = M/N
    
    random_number_generator = teacher.random_number_generator
    
    mean_f_difference = torch.zeros(2, device = device, requires_grad = False)
    var_f_difference = torch.zeros(2, device = device, requires_grad = False)
    
    mean_log_likelihood_difference = torch.zeros(2, device = device, requires_grad = False)
    var_log_likelihood_difference = torch.zeros(2, device = device, requires_grad = False)
    
    # requires_grad must be given to Parameter itself (it defaults to True there); otherwise the
    # in-place copy_ calls below raise on a leaf tensor that requires grad.
    sigma_teacher = torch.nn.Parameter(torch.zeros((M, N), device = device), requires_grad = False)
    sigma_student = torch.nn.Parameter(torch.zeros((M, N), device = device), requires_grad = False)
    
    tau_1 = torch.zeros((M, P_1), device = student_1.training_device)
    tau_2 = torch.zeros((M, P_2), device = student_2.training_device)
    
    # Integer division gives 0 when more steps are monitored than exist; clamp to 1.
    if number_monitored_annealing_steps != 0:
        monitor_interval = max(1, number_annealing_steps // number_monitored_annealing_steps)
    else:
        monitor_interval = 0
    
    for annealing_run in range(number_annealing_runs):
        sigma_teacher.copy_(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)))
        # Overwritten below before use; kept so the random stream matches earlier results.
        sigma_student.copy_(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)))
        
        teacher.sample_visible(sigma_teacher, evaluation_beta, number_sampling_steps,
                               number_monitored_sampling_steps, anneal = False)
        
        directions = [(student_1, student_2, tau_1, tau_2), (student_2, student_1, tau_2, tau_1)]
        for i, (student_A, student_B, tau_A, tau_B) in enumerate(directions):
            sigma_student.copy_(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)))
            
            # NOTE(review): AIS assumes the starting configurations are samples of student_A; a single
            # hidden -> visible half-sweep from random signs may not equilibrate. Confirm.
            student_A.sample_hidden_given_visible(tau_A, sigma_student, evaluation_beta)
            
            # NOTE(review): for +/-1 spins coupled as beta * sigma . (xi tau), P(sigma = +1) would be
            # sigmoid(2 * beta * ...); confirm the convention against RBM.sample_visible.
            P_sigma_given_tau = torch.sigmoid(evaluation_beta * tau_A @ torch.transpose(student_A.xi, 0, 1))
            sigma_student.copy_(rademacher(P_sigma_given_tau, generator = random_number_generator))
            
            # Per-sample AIS log-weights.
            f_difference = -student_A.free_entropy(sigma_student, evaluation_beta)
            
            for annealing_step, t in enumerate(t_range):
                monitor_annealing_this_epoch = monitor_interval != 0 and annealing_step % monitor_interval == 0
                beta_A = (1 - t)*evaluation_beta
                beta_B = t*evaluation_beta
                
                f_difference.copy_(f_difference + student_A.free_entropy(sigma_student, beta_A) + student_B.free_entropy(sigma_student, beta_B))
                
                student_A.sample_hidden_given_visible(tau_A, sigma_student, beta_A)
                student_B.sample_hidden_given_visible(tau_B, sigma_student, beta_B)
                
                P_sigma_given_tau = torch.sigmoid(beta_A * tau_A @ torch.transpose(student_A.xi, 0, 1) + beta_B * tau_B @ torch.transpose(student_B.xi, 0, 1))
                sigma_student.copy_(rademacher(P_sigma_given_tau, generator = random_number_generator))
                
                f_difference.copy_(f_difference - student_A.free_entropy(sigma_student, beta_A) - student_B.free_entropy(sigma_student, beta_B))
                
                if monitor_annealing_this_epoch:
                    # f_difference holds one weight per sample, so report their mean.
                    print("Run [{}/{}], annealing step [{}/{}], free entropy difference: {:.4f}".format(annealing_run, number_annealing_runs,
                                                                                                        annealing_step, number_annealing_steps,
                                                                                                        torch.mean(f_difference).item()))
            
            f_difference.copy_(f_difference + student_B.free_entropy(sigma_student, evaluation_beta))
            
            # Numerically stable log-mean-exp of the per-sample weights.
            c = torch.max(f_difference)
            f_difference = c + torch.log(torch.mean(torch.exp(f_difference - c)))
            
            log_likelihood_difference = torch.mean(student_A.free_entropy(sigma_teacher, evaluation_beta)) - torch.mean(student_B.free_entropy(sigma_teacher, evaluation_beta)) + f_difference
            
            # Welford update of the running mean and population variance (n samples so far).
            n = annealing_run + 1
            f_value = f_difference.detach().item()
            delta = f_value - mean_f_difference[i]
            mean_f_difference[i] += delta / n
            var_f_difference[i] += (delta * (f_value - mean_f_difference[i]) - var_f_difference[i]) / n
            
            log_likelihood_value = log_likelihood_difference.detach().item()
            delta = log_likelihood_value - mean_log_likelihood_difference[i]
            mean_log_likelihood_difference[i] += delta / n
            var_log_likelihood_difference[i] += (delta * (log_likelihood_value - mean_log_likelihood_difference[i]) - var_log_likelihood_difference[i]) / n
    
    # Population -> unbiased sample variance; undefined (NaN) for fewer than two runs.
    if number_annealing_runs > 1:
        var_f_difference *= number_annealing_runs / (number_annealing_runs - 1)
        var_log_likelihood_difference *= number_annealing_runs / (number_annealing_runs - 1)
    else:
        var_f_difference.fill_(float("nan"))
        var_log_likelihood_difference.fill_(float("nan"))
    
    # Quadratic prior term + 1/(2 alpha) * (|xi_B|^2 - |xi_A|^2) per direction: index 0 compares
    # student_1 with student_2, index 1 the reverse. The previous code used the loop variables
    # after the loop, i.e. the reversed pair, for both entries.
    log_prior_difference = 1/alpha * 1/2 * (torch.sum(student_2.xi**2) - torch.sum(student_1.xi**2))
    mean_log_posterior_difference = mean_log_likelihood_difference + torch.stack((log_prior_difference, -log_prior_difference)).to(device)
    var_log_posterior_difference = var_log_likelihood_difference
    
    return mean_f_difference, torch.sqrt(var_f_difference), mean_log_likelihood_difference, torch.sqrt(var_log_likelihood_difference), mean_log_posterior_difference, torch.sqrt(var_log_posterior_difference)

In [None]:
# Teacher/student overlap matrices (m_*) and student self-overlaps (s_*) for the current students.
m_A_range, m_C_range = np.zeros(n_alpha), np.zeros(n_alpha)
m_B_range = np.zeros((n_alpha, 2))

m_A = teacher.xi.transpose(0, 1) @ student_A.xi
print(m_A)
m_A_range[0] = m_A.diagonal().mean().item()

s_A = student_A.xi.transpose(0, 1) @ student_A.xi
print(s_A)

m_B = teacher.xi.transpose(0, 1) @ student_B.xi
print(m_B)
# Entry 0: overlap of the non-duplicated patterns; entry 1: mean overlap of the two columns
# seeded from teacher pattern 0.
m_B_range[0, 0] = m_B.diagonal()[1 :].mean().item()
m_B_range[0, 1] = (m_B[0, 0] + m_B[0, P]).item()/2

s_B = student_B.xi.transpose(0, 1) @ student_B.xi
print(s_B)

In [None]:
# Single AIS run from student A (t = 0) to student B (t = 1) on M = 2N fresh configurations.
# Intermediate target at step t: student A at inverse temperature (1 - t)*beta combined with
# student B at t*beta; each step is one hidden -> visible Gibbs sweep of that combined model.
# Relies on teacher, student_A, student_B, device, random_number_generator and rademacher from earlier cells.
# NOTE(review): alpha, P, m_0, number_repetitions and average_probability_ratio are not used in this cell.
alpha = 1

# Schedule t_1..t_K; the t = 0 endpoint enters through the initial -F_A(beta) term below.
t_range = np.linspace(0.001, 1, num = 10000, endpoint = True)

N = 16000
M = int(2*N)
P = 2
P_t = 3
m_0 = 0.2
beta = 4

sigma_teacher = torch.nn.Parameter(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)), requires_grad = False)
sigma_student = torch.nn.Parameter(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)), requires_grad = False)

number_sampling_steps = 100
number_monitored_sampling_steps = 100

number_repetitions = 20
average_probability_ratio = 0

tau_A = torch.zeros((M, P_t), device = student_A.training_device)
tau_B = torch.zeros((M, P_t), device = student_B.training_device)

teacher.sample_visible(sigma_teacher, beta, number_sampling_steps,
                       number_monitored_sampling_steps)

# NOTE(review): the chain starts from random signs plus a single hidden -> visible half-sweep of
# student A; AIS assumes equilibrium samples of A here. Confirm whether more sampling is intended.
student_A.sample_hidden_given_visible(tau_A, sigma_student, beta)

# NOTE(review): for +/-1 spins coupled as beta * sigma . (xi tau), P(sigma = +1) would be
# sigmoid(2 * beta * ...); confirm the convention against RBM.sample_visible.
P_sigma_given_tau = torch.sigmoid(beta * tau_A @ torch.transpose(student_A.xi, 0, 1))
sigma_student.copy_(rademacher(P_sigma_given_tau, generator = student_A.random_number_generator))

# Per-sample AIS log-weights.
f_difference = -student_A.free_entropy(sigma_student, beta)

for t in t_range:
    # Telescoping AIS update: add the step-t log-target at the current state, move with a Gibbs
    # sweep targeting step t, then subtract the step-t log-target at the new state.
    f_difference.copy_(f_difference + student_A.free_entropy(sigma_student, (1 - t)*beta) + student_B.free_entropy(sigma_student, t*beta))
    
    student_A.sample_hidden_given_visible(tau_A, sigma_student, (1 - t)*beta)
    student_B.sample_hidden_given_visible(tau_B, sigma_student, t*beta)
    
    P_sigma_given_tau = torch.sigmoid((1 - t)*beta * tau_A @ torch.transpose(student_A.xi, 0, 1) + t*beta * tau_B @ torch.transpose(student_B.xi, 0, 1))
    sigma_student.copy_(rademacher(P_sigma_given_tau, generator = student_A.random_number_generator))
    
    f_difference.copy_(f_difference - student_A.free_entropy(sigma_student, (1 - t)*beta) - student_B.free_entropy(sigma_student, t*beta))

# Close the chain with the final target, student B at full beta.
f_difference.copy_(f_difference + student_B.free_entropy(sigma_student, beta))

# Numerically stable log-mean-exp of the per-sample weights.
c = torch.max(f_difference)
f_difference = c + torch.log(torch.mean(torch.exp(f_difference - c)))

# AIS estimate (presumably log Z_B - log Z_A) and the implied <F_A> - <F_B> + f_difference on teacher data.
print(f_difference)
print(torch.mean(student_A.free_entropy(sigma_teacher, beta)) - torch.mean(student_B.free_entropy(sigma_teacher, beta)) + f_difference)

In [None]:
# Re-print the AIS estimate and the resulting log-likelihood difference, one per line.
print(
    f_difference,
    torch.mean(student_A.free_entropy(sigma_teacher, beta)) - torch.mean(student_B.free_entropy(sigma_teacher, beta)) + f_difference,
    sep = "\n",
)

In [None]:
# Mirror of the A -> B run: single AIS run from student B (t = 0) to student A (t = 1).
# Intermediate target at step t: student B at inverse temperature (1 - t)*beta combined with
# student A at t*beta; each step is one hidden -> visible Gibbs sweep of that combined model.
# Relies on teacher, student_A, student_B, device, random_number_generator and rademacher from earlier cells.
# NOTE(review): alpha, P, m_0, number_repetitions and average_probability_ratio are not used in this cell.
alpha = 1

# Schedule t_1..t_K; the t = 0 endpoint enters through the initial -F_B(beta) term below.
t_range = np.linspace(0.001, 1, num = 10000, endpoint = True)

N = 16000
M = int(2*N)
P = 2
P_t = 3
m_0 = 0.2
beta = 4

sigma_teacher = torch.nn.Parameter(torch.zeros((M, N), device = device), requires_grad = False)
sigma_student = torch.nn.Parameter(torch.zeros((M, N), device = device), requires_grad = False)

number_sampling_steps = 100
number_monitored_sampling_steps = 100

number_repetitions = 20
average_probability_ratio = 0

sigma_teacher.copy_(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)))
sigma_student.copy_(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)))

tau_A = torch.zeros((M, P_t), device = student_A.training_device)
tau_B = torch.zeros((M, P_t), device = student_B.training_device)

teacher.sample_visible(sigma_teacher, beta, number_sampling_steps,
                       number_monitored_sampling_steps = 0)

# NOTE(review): the chain starts from random signs plus a single hidden -> visible half-sweep of
# student B; AIS assumes equilibrium samples of B here. Confirm whether more sampling is intended.
student_B.sample_hidden_given_visible(tau_B, sigma_student, beta)

# NOTE(review): for +/-1 spins coupled as beta * sigma . (xi tau), P(sigma = +1) would be
# sigmoid(2 * beta * ...); confirm the convention against RBM.sample_visible.
P_sigma_given_tau = torch.sigmoid(beta * tau_B @ torch.transpose(student_B.xi, 0, 1))
sigma_student.copy_(rademacher(P_sigma_given_tau, generator = student_B.random_number_generator))

# Per-sample AIS log-weights.
f_difference = -student_B.free_entropy(sigma_student, beta)

for t in t_range:
    # Telescoping AIS update: add the step-t log-target at the current state, move with a Gibbs
    # sweep targeting step t, then subtract the step-t log-target at the new state.
    f_difference.copy_(f_difference + student_B.free_entropy(sigma_student, (1 - t)*beta) + student_A.free_entropy(sigma_student, t*beta))
    
    student_B.sample_hidden_given_visible(tau_B, sigma_student, (1 - t)*beta)
    student_A.sample_hidden_given_visible(tau_A, sigma_student, t*beta)
    
    P_sigma_given_tau = torch.sigmoid((1 - t)*beta * tau_B @ torch.transpose(student_B.xi, 0, 1) + t*beta * tau_A @ torch.transpose(student_A.xi, 0, 1))
    sigma_student.copy_(rademacher(P_sigma_given_tau, generator = student_B.random_number_generator))
    
    f_difference.copy_(f_difference - student_B.free_entropy(sigma_student, (1 - t)*beta) - student_A.free_entropy(sigma_student, t*beta))

# Close the chain with the final target, student A at full beta.
f_difference.copy_(f_difference + student_A.free_entropy(sigma_student, beta))

# Numerically stable log-mean-exp of the per-sample weights.
c = torch.max(f_difference)
f_difference = c + torch.log(torch.mean(torch.exp(f_difference - c)))

# AIS estimate (presumably log Z_A - log Z_B) and the implied <F_B> - <F_A> + f_difference on teacher data.
print(f_difference)
print(torch.mean(student_B.free_entropy(sigma_teacher, beta)) - torch.mean(student_A.free_entropy(sigma_teacher, beta)) + f_difference)

In [None]:
print(beta)

# Teacher/student overlap matrices.
print(torch.transpose(teacher.xi, 0, 1) @ student_A.xi)

print(torch.transpose(teacher.xi, 0, 1) @ student_B.xi)

# Each free_entropy call processes the whole teacher sample, so compute each mean once
# instead of twice.
mean_free_entropy_A = torch.mean(student_A.free_entropy(sigma_teacher, beta))
mean_free_entropy_B = torch.mean(student_B.free_entropy(sigma_teacher, beta))

print(mean_free_entropy_A)

print(mean_free_entropy_B)

print(mean_free_entropy_B - mean_free_entropy_A)

In [None]:
# Hand-entered arithmetic on hard-coded values (presumably transcribed from earlier printed free entropies).
print(-(0.4344 - 0.3735))

print(0.4356 - 0.3735)

In [None]:
alpha = 1

N = 16000
M = int(2*N)
P = 2
P_t = 3
m_0 = 0.2
beta = 4

D = 20

sigma_teacher = torch.nn.Parameter(torch.zeros((M, N), device = device), requires_grad = False)
sigma_student = torch.nn.Parameter(torch.zeros((M, N), device = device), requires_grad = False)

number_sampling_steps = 100
number_monitored_sampling_steps = 100

number_repetitions = 20
average_probability_ratio = 0

# Each repetition draws fresh teacher and student-A samples and estimates
# mean_teacher[exp(F_A - F_B)] * Z_B / Z_A. This assumes free_entropy is the unnormalised
# log-marginal of sigma. The loop keeps a running mean of these ratios.
for repetition in range(number_repetitions):
    
    sigma_teacher.copy_(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)))
    sigma_student.copy_(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)))

    teacher.sample_visible(sigma_teacher, beta, number_sampling_steps,
                           number_monitored_sampling_steps,
                           anneal = False, monitor_sampling = False)

    student_A.sample_visible(sigma_student, beta, number_sampling_steps,
                             number_monitored_sampling_steps,
                             anneal = False, monitor_sampling = False)
    
    # log(Z_B / Z_A) by importance sampling with student-A samples (stable log-mean-exp).
    f = student_B.free_entropy(sigma_student, beta) - student_A.free_entropy(sigma_student, beta)
    c = torch.max(f)
    log_Z_ratio = c + torch.log(torch.mean(torch.exp(f - c)))
    
    # log of the teacher-sample mean of exp(F_A - F_B).
    f = student_A.free_entropy(sigma_teacher, beta) - student_B.free_entropy(sigma_teacher, beta)
    c = torch.max(f)
    log_mean_ratio = c + torch.log(torch.mean(torch.exp(f - c)))
    
    # Combine in log space and exponentiate once. The former exp(c) * ... * exp(c') form overflows
    # once c exceeds ~88 in float32, and inf * 0 then yields nan.
    probability_ratio = torch.exp(log_mean_ratio + log_Z_ratio)
    print(probability_ratio)
    
    average_probability_ratio += (probability_ratio.detach().item() - average_probability_ratio) / (repetition + 1)

In [None]:
# Running mean of the probability ratio accumulated by the preceding sampling cell.
print(f"{average_probability_ratio}")

In [None]:
# Estimate of mean_teacher[exp(F_B - F_A)] * (Z_A / Z_B), averaged over independent repetitions.
# Z_A / Z_B is estimated from samples of student_B; the first factor from teacher samples.
# If free_entropy is the log of the unnormalised marginal (TODO confirm), this is the
# teacher-averaged likelihood ratio P_B(sigma) / P_A(sigma).
alpha = 1

N = 16000
M = int(2*N)
P = 2
P_t = 3
m_0 = 0.2
beta = 4

D = 20

sigma_teacher = torch.nn.Parameter(torch.zeros((M, N), device = device), requires_grad = False)
sigma_student = torch.nn.Parameter(torch.zeros((M, N), device = device), requires_grad = False)

number_sampling_steps = 100
number_monitored_sampling_steps = 100

number_repetitions = 20
average_probability_ratio = 0

for repetition in range(number_repetitions):
    
    # Fresh random +-1 starting configurations for both chains.
    sigma_teacher.copy_(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)))
    sigma_student.copy_(torch.sign(torch.randn((M, N), device = device, generator = random_number_generator)))

    teacher.sample_visible(sigma_teacher, beta, number_sampling_steps,
                           number_monitored_sampling_steps,
                           anneal = False, monitor_sampling = False)

    student_B.sample_visible(sigma_student, beta, number_sampling_steps,
                             number_monitored_sampling_steps,
                             anneal = False, monitor_sampling = False)
    
    # log(Z_A / Z_B) ~ log mean_{sigma ~ B} exp(F_A - F_B), computed as a max-shifted log-mean-exp.
    f = student_A.free_entropy(sigma_student, beta) - student_B.free_entropy(sigma_student, beta)
    c = torch.max(f)
    log_Z_ratio = c + torch.log(torch.mean(torch.exp(f - c)))
    # Kept for inspection/backward compatibility; may be inf on its own in float32.
    Z_ratio = torch.exp(log_Z_ratio)
    
    f = student_B.free_entropy(sigma_teacher, beta) - student_A.free_entropy(sigma_teacher, beta)
    c = torch.max(f)
    
    # Combine both factors in log space before exponentiating. Over N spins the max shift c
    # can exceed float32's exp range (~88.7), and the linear-space product
    # exp(c1) * ... * exp(c2) then becomes inf * 0 = nan even when the true ratio is moderate.
    log_probability_ratio = c + torch.log(torch.mean(torch.exp(f - c))) + log_Z_ratio
    probability_ratio = torch.exp(log_probability_ratio)
    print(probability_ratio)
    
    # Incremental running mean over repetitions.
    average_probability_ratio += (probability_ratio.detach().item() - average_probability_ratio) / (repetition + 1)

In [None]:
# Running mean of the probability ratio accumulated by the preceding sampling cell.
print(f"{average_probability_ratio}")