In [1]:
import os

import numpy as np

from utils import Ising, conditional, log_unnormalized_p
from plot_utils import *
# Keep after the star import so a `tqdm` exported by plot_utils (if any) cannot shadow tqdm.auto.
from tqdm.auto import tqdm

In [7]:
def gibbs_sampling(model, state, n_samples):
    """Run random-scan Gibbs sweeps over a ``dim x dim`` grid of +/-1 spins.

    Each of the ``n_samples`` iterations visits every site (i, j) exactly once,
    in a freshly shuffled order, and resamples it to +1 with probability
    ``conditional(model.state, i, j, model.Js, model.Jst)`` and to -1 otherwise.

    Args:
        model: Ising-like object exposing ``dim``, ``Js``, ``Jst`` and a
            writable ``state`` attribute.
        state: Initial configuration, indexable as ``state[i][j]`` for
            ``0 <= i, j < model.dim``. It is copied, so the caller's array is
            never modified.
        n_samples: Number of full sweeps to perform.

    Returns:
        List of ``n_samples`` configurations (independent copies), one taken
        after each sweep. ``model.state`` is left holding the final configuration.
    """
    n = model.dim
    # Work on a private copy: assigning ``state`` directly would alias the
    # caller's array, and every spin update below would silently mutate it.
    model.state = np.array(state, copy=True)

    samples = []
    sites = np.array([[i, j] for i in range(n) for j in range(n)])
    for _ in range(n_samples):
        np.random.shuffle(sites)  # random-scan order, re-drawn every sweep
        for i, j in sites:
            p_Xij = conditional(model.state, i, j, model.Js, model.Jst)
            model.state[i][j] = np.random.binomial(1, p_Xij) * 2 - 1  # 0 -> -1, 1 -> 1
        samples.append(model.state.copy())

    return samples


In [8]:
def _log_mean_exp(values, axis=None):
    """Return log(mean(exp(values))) along ``axis`` (over all elements if None), computed stably."""
    values = np.asarray(values, dtype=float)
    shift = np.max(values, axis=axis, keepdims=True)
    # Leave non-finite maxima unshifted so an all -inf slice yields -inf rather than NaN.
    shift = np.where(np.isfinite(shift), shift, 0.0)
    result = shift + np.log(np.mean(np.exp(values - shift), axis=axis, keepdims=True))
    return result.item() if axis is None else np.squeeze(result, axis=axis)


def annealed_importance_sampling(p_0, p_G, betas, n_samples, n_steps, return_all_temps=False):
    """
    Annealed Importance Sampling (AIS) estimate of log Z for the target model p_G.

    Runs ``n_samples`` independent annealing chains along the path
    log f_beta(x) = log_unnormalized_p(x, beta * Js, beta * Jst), beta in ``betas``.
    Each chain starts from a fresh ``p_0.init_state()`` draw and, for every
    intermediate beta_j (j >= 1), performs one Gibbs sweep under the model whose
    couplings are scaled by beta_j. Its log importance weight is
    log w = sum_j [log f_{beta_{j+1}}(x_j) - log f_{beta_j}(x_j)].

    The estimator is Z_G ~= Z_0 * mean(w): the weights themselves are averaged
    (in log space, stably). Averaging the log-weights instead would be biased
    low by Jensen's inequality.

    Args:
        p_0: Base model; ``init_state()`` must draw a starting state into ``p_0.state``.
        p_G: Target model providing ``Js``, ``Jst`` and ``dim``.
        betas: Annealing schedule of length ``n_steps + 1``. log Z_0 is taken as
            dim^2 * log(2), which assumes betas[0] gives a constant unnormalized
            density (presumably betas[0] == 0); betas[-1] should be 1 for the
            result to refer to p_G.
        n_samples: Number of independent AIS chains (must be >= 1).
        n_steps: Number of annealing transitions.
        return_all_temps: If True, return estimates for all temperatures

    Returns:
        ``(log_Z, std)``: ``log_Z`` is log(Z_0 * mean weight) and ``std`` is the
        sample standard deviation (ddof=1) of the per-chain log estimates
        log Z_0 + log w (NaN when ``n_samples == 1``). With ``return_all_temps``
        both are arrays of length ``n_steps + 1`` whose entry m refers to betas[m].

    Raises:
        ValueError: if ``n_samples < 1``.
    """
    assert (len(betas) == n_steps + 1)
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    log_Z_estimates = []
    log_Z_all_temps = [] if return_all_temps else None

    target_Js = p_G.Js
    target_Jst = p_G.Jst
    dim = p_G.dim

    # for all x, p_0(x) = exp(0) = 1, log(Z_0) = log(2^(dim^2)) = dim^2 * log(2)
    log_Z_0 = dim ** 2 * np.log(2)

    for _ in range(n_samples):
        # Per-chain estimate log Z_0 + log w, accumulated one transition at a time.
        sum_log_ratio = log_Z_0
        temp_estimates = [log_Z_0] if return_all_temps else None

        for j in range(n_steps):
            if j == 0:
                # Sample x_1 from p_0
                p_0.init_state()
                x_j = p_0.state.copy()
            else:
                # One Gibbs sweep under p_{beta_j} moves x_j -> x_{j+1}
                p_beta_j = Ising(dim, betas[j] * target_Js, betas[j] * target_Jst)
                x_j = gibbs_sampling(p_beta_j, x_j, 1)[0]

            # Weight increment: log(f_{j+1}(x) / f_j(x)) at the current state
            log_p_j = log_unnormalized_p(x_j, betas[j] * target_Js, betas[j] * target_Jst)
            log_p_j_plus_1 = log_unnormalized_p(x_j, betas[j + 1] * target_Js, betas[j + 1] * target_Jst)
            sum_log_ratio += (log_p_j_plus_1 - log_p_j)

            if return_all_temps:
                temp_estimates.append(sum_log_ratio)

        log_Z_estimates.append(sum_log_ratio)
        if return_all_temps:
            log_Z_all_temps.append(temp_estimates)

    if return_all_temps:
        log_Z_all_temps = np.array(log_Z_all_temps)  # Shape: (n_samples, n_steps+1)
        # Z estimate = Z_0 * mean(w): average the weights, not their logs.
        log_Z_means = _log_mean_exp(log_Z_all_temps, axis=0)
        log_Z_stds = np.std(log_Z_all_temps, axis=0, ddof=1)
        return log_Z_means, log_Z_stds
    return _log_mean_exp(log_Z_estimates), np.std(log_Z_estimates, ddof=1)

In [9]:
def simulated_tempering(model, state, temps, log_Z_temps, n_samples):
    """
    Gibbs sampling with simulated tempering over a ladder of coupling scales.

    The chain lives on pairs (x, i). Each iteration, with probability 1/2, it
    performs one Gibbs sweep of x under the model with couplings
    ``temps[i] * Js`` and ``temps[i] * Jst``. Otherwise it proposes moving i to a
    neighbouring ladder index and accepts by Metropolis-Hastings for the target
    p(x, i) proportional to f_i(x) / Z_i. Here log f_i is ``log_unnormalized_p``
    at scale ``temps[i]`` and log Z_i = ``log_Z_temps[i]``. The acceptance ratio
    includes the Hastings correction needed because the ends of the ladder have
    only one neighbour.

    Note: ``temps[i]`` multiplies the couplings, so larger values mean stronger
    coupling (it acts like an inverse temperature).

    Args:
        model: Ising model with target Js and Jst
        state: Initial state (copied; the caller's array is not modified)
        temps: Coupling scale for each rung of the ladder
        log_Z_temps: Log partition functions for each entry of ``temps`` (e.g. estimated by AIS)
        n_samples: Number of iterations

    Returns:
        samples: List of states, one per iteration
        temp_indices: Ladder index i at each iteration
    """
    dim = model.dim
    target_Js = model.Js
    target_Jst = model.Jst
    L = len(temps)

    def n_neighbours(k):
        # Number of ladder moves proposable from index k: one at either end, two inside.
        return 1 if k == 0 or k == L - 1 else 2

    x = state.copy()
    i = 0  # Start at temperature index 0

    samples = []
    temp_indices = []

    for _ in tqdm(range(n_samples)):
        if np.random.rand() < 0.5:
            # Gibbs sampling step at the current rung
            model_i = Ising(dim, temps[i] * target_Js, temps[i] * target_Jst)
            x = gibbs_sampling(model_i, x, 1)[0]
        elif L > 1:
            # Metropolis-Hastings step on the ladder index (no move possible with a single rung)
            if i == 0:
                j = 1
            elif i == L - 1:
                j = L - 2
            else:
                j = i + 1 if np.random.rand() < 0.5 else i - 1

            log_p_j = log_unnormalized_p(x, temps[j] * target_Js, temps[j] * target_Jst)
            log_p_i = log_unnormalized_p(x, temps[i] * target_Js, temps[i] * target_Jst)

            # Proposal is asymmetric at the ladder ends (q(0->1) = 1, q(1->0) = 1/2),
            # so add the Hastings factor q(j->i) / q(i->j) = n_neighbours(i) / n_neighbours(j).
            log_ratio = ((log_p_j - log_Z_temps[j]) - (log_p_i - log_Z_temps[i])
                         + np.log(n_neighbours(i)) - np.log(n_neighbours(j)))

            # exp(min(0, r)) equals min(1, exp(r)) but cannot overflow for large r.
            if np.random.rand() < np.exp(min(0.0, log_ratio)):
                i = j

        samples.append(x.copy())
        temp_indices.append(i)

    return samples, temp_indices

In [18]:
dim = 5
Js = 0
Jsts = [1.0, 1.2, 1.5, 2.0]
# Ladder of coupling scales for tempering; T = 1 reproduces the target model Ising(dim, Js, Jst).
temps = np.linspace(0.5, 2.0, 31)
n_iterations = 100000
K = 50  # number of AIS annealing transitions
M = 50  # number of independent AIS chains per partition function
# Annealing schedule from beta = 0 (p_0) to beta = 1 (target), tied to K so it
# always ends at 1 (a hard-coded step of 0.02 only does so for K = 50).
betas = np.linspace(0.0, 1.0, K + 1)

# All samplers use NumPy's legacy global RNG, so seed it once for reproducible runs.
SEED = 0
np.random.seed(SEED)

# Estimate log Z(T) by AIS for every coupling strength Jst and every ladder value T;
# the AIS target is the model with couplings scaled by T.
log_Z_all_temps_dict = {}
for Jst in tqdm(Jsts):
    log_Z_list = []
    for T in temps:
        p_0 = Ising(dim, 0, 0)
        p_G = Ising(dim, Js * T, Jst * T)
        log_Z, _ = annealed_importance_sampling(p_0, p_G, betas, M, K)
        log_Z_list.append(log_Z)
    log_Z_all_temps_dict[Jst] = np.array(log_Z_list)


  0%|          | 0/4 [00:00<?, ?it/s]

In [19]:
output_dir = "plots"
# exist_ok=True already makes this a no-op when the directory exists; a separate
# exists() check is redundant and would hide a non-directory "plots" until savefig.
os.makedirs(output_dir, exist_ok=True)

def save_plot(data, ylabel, title, filename, scalar_ylim=(0, 2.5)):
    """Render one diagnostic trace and write it to ``filename``.

    Args:
        data: Either a scalar (any real number, including NumPy scalars), drawn
            as a horizontal line, or a 1-D sequence plotted against its index.
        ylabel: Y-axis label.
        title: Figure title.
        filename: Output image path (saved at 150 dpi).
        scalar_ylim: Y-limits applied only in the scalar case; the default
            (0, 2.5) covers the ladder values 0.5-2.0 used in this notebook.

    The figure is closed after saving.
    """
    plt.figure(figsize=(10, 4))
    # np.ndim(...) == 0 also matches NumPy scalars such as np.int64, which an
    # isinstance(data, (int, float)) check misses (they would be drawn as a point).
    if np.ndim(data) == 0:
        plt.axhline(y=data, linewidth=0.5)
        plt.ylim(*scalar_ylim)
    else:
        plt.plot(data, linewidth=0.5, alpha=0.7)
    plt.xlabel('Iteration')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    plt.close()

filenames = []
# Ladder index closest to T = 1, i.e. the untempered target model.
T1_idx = np.argmin(np.abs(temps - 1.0))

for Jst in tqdm(Jsts):
    # One random starting configuration shared by both samplers.
    model = Ising(dim, Js, Jst)
    model.init_state()
    init_state = model.state.copy()

    # Baseline: plain Gibbs sampling at the target coupling, one sweep per iteration.
    gibbs_sums = []
    x_gibbs = init_state.copy()
    for _ in tqdm(range(n_iterations), desc=f"Gibbs Jst={Jst}", leave=False):
        x_gibbs = gibbs_sampling(Ising(dim, Js, Jst), x_gibbs, 1)[0]
        gibbs_sums.append(np.sum(x_gibbs))

    # Simulated tempering across the whole ladder.
    samples_temp, temp_indices = simulated_tempering(
        model, init_state.copy(), temps, log_Z_all_temps_dict[Jst], n_iterations
    )
    temp_sums = [np.sum(sample) for sample in samples_temp]
    # Only the iterations the tempering chain spent at T = 1.
    sums_at_T1 = [
        np.sum(sample)
        for sample, idx in zip(samples_temp, temp_indices)
        if idx == T1_idx
    ]

    # Five panels per Jst, in the order merge_images expects: (path, data, ylabel, title).
    panels = [
        (f'{output_dir}/gibbs_temp_Jst{Jst}.png', 1.0, 'Temperature', f'Gibbs Temperature (Jst={Jst})'),
        (f'{output_dir}/gibbs_sum_Jst{Jst}.png', gibbs_sums, 'Sum of assignments', f'Gibbs Sum (Jst={Jst})'),
        (f'{output_dir}/tempering_temp_Jst{Jst}.png', temps[temp_indices], 'Temperature', f'Tempering Temperature (Jst={Jst})'),
        (f'{output_dir}/tempering_sum_Jst{Jst}.png', temp_sums, 'Sum of assignments', f'Tempering Sum (Jst={Jst})'),
        (f'{output_dir}/tempering_T1_Jst{Jst}.png', sums_at_T1, 'Sum of assignments', f'Tempering at T=1 (Jst={Jst})'),
    ]
    filenames.extend(path for path, *_ in panels)
    for path, data, ylabel, title in panels:
        save_plot(data, ylabel, title, path)


  0%|          | 0/4 [00:00<?, ?it/s]

Gibbs Jst=1.0:   0%|          | 0/100000 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

Gibbs Jst=1.2:   0%|          | 0/100000 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

Gibbs Jst=1.5:   0%|          | 0/100000 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

Gibbs Jst=2.0:   0%|          | 0/100000 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

In [20]:
# Combine all saved plots into "simulated_tempering.png". `filenames` holds five
# plots per Jst, in Jsts order.
# NOTE(review): assumes merge_images(paths, n_rows, n_cols, out_path) lays them out
# as len(Jsts) rows x 5 columns — confirm against plot_utils.merge_images.
merge_images(filenames, len(Jsts), 5, "simulated_tempering.png")