In [22]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.nn.functional import conv2d
from scipy.ndimage import convolve, generate_binary_structure
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [23]:
# Display the torch device (CPU or CUDA) chosen in the setup cell.
device

device(type='cpu')

In [24]:
def get_clustered_temperatures(n_temperatures, center, low, high, fraction_center=0.7, width=0.15):
    """Build an ascending temperature grid that is dense around ``center``.

    A share ``fraction_center`` of the points is spread evenly over
    [center - width, center + width). The remaining points are split between
    the sparse tails [low, center - width) and [center + width, high]. When the
    tails cannot be split evenly, the single leftover point goes to the upper
    tail.

    Parameters
    ----------
    n_temperatures : int
        Total number of temperatures returned.
    center : float
        Midpoint of the dense window.
    low, high : float
        Lowest and highest temperature of the grid (``high`` is included).
    fraction_center : float
        Fraction of points placed inside the dense window (rounded down).
    width : float
        Half-width of the dense window.

    Returns
    -------
    np.ndarray
        ``n_temperatures`` temperatures, sorted ascending.
    """
    lo_edge = center - width
    hi_edge = center + width

    n_dense = int(n_temperatures * fraction_center)
    n_tail = (n_temperatures - n_dense) // 2
    # The upper tail absorbs the extra point when the tails can't split evenly.
    n_upper = n_temperatures - n_dense - n_tail

    segments = (
        np.linspace(low, lo_edge, n_tail, endpoint=False),       # sparse lower tail
        np.linspace(lo_edge, hi_edge, n_dense, endpoint=False),  # dense window
        np.linspace(hi_edge, high, n_upper, endpoint=True),      # sparse upper tail
    )
    return np.sort(np.concatenate(segments))

In [25]:
# Plus-shaped 3x3 structuring element: the centre and its 4 orthogonal neighbours.
kernel_np = generate_binary_structure(2, 1)
# Drop the centre so a convolution sums only the four nearest neighbours.
kernel_np[1, 1] = False
# conv2d weight layout: (out_channels=1, in_channels=1, 3, 3).
KERNEL = torch.tensor(kernel_np, dtype=torch.float32)[None, None].to(device)

def get_energy_arr(lattices):
    """Per-site energy ``-s_i * sum_{j in nn(i)} s_j`` for a batch of lattices.

    The neighbour sum is a 2D convolution with the plus-shaped ``KERNEL``.
    ``padding='same'`` zero-pads the borders, so edge sites have fewer
    neighbours (open boundaries, not periodic).

    Parameters
    ----------
    lattices : torch.Tensor
        +/-1 spins with a single channel, e.g. shape (batch, 1, H, W).

    Returns
    -------
    torch.Tensor
        Same shape as ``lattices``.
    """
    neighbour_sum = conv2d(lattices, KERNEL, padding='same')
    return -(lattices * neighbour_sum)
    
def get_energy(lattices):
    """Total energy of each lattice in a (batch, 1, H, W) batch.

    Sums the per-site energies from ``get_energy_arr`` over channel, rows and
    columns. A nearest-neighbour bond contributes to the per-site energy of
    both of its endpoints, so each bond is counted twice. The result is
    therefore 2 * H with H = -sum_<ij> s_i s_j.

    Returns
    -------
    torch.Tensor
        Shape (batch,).
    """
    per_site = get_energy_arr(lattices)
    return per_site.sum(dim=(1, 2, 3))

def get_dE_arr(lattices):
    """Energy change from flipping each spin on its own, for every site.

    Flipping s_i changes H = -sum_<ij> s_i s_j (each bond counted once) by
    dE_i = 2 * s_i * sum_{j in nn(i)} s_j, which is -2 times the per-site
    energy from ``get_energy_arr``. The result has the same shape as
    ``lattices``.
    """
    site_energy = get_energy_arr(lattices)
    return site_energy * -2

In [26]:
@torch.no_grad()
def metropolis(spin_tensor_batch, warm_times, eq_times, T, N):
    """Run checkerboard Metropolis sampling on a batch of 2D Ising lattices.

    ``spin_tensor_batch`` is updated **in place**. Each step picks one of the
    four sublattices ``[i::2, j::2]`` (i, j in {0, 1}) at random and proposes
    flipping all of its spins at once. Sites of one sublattice are two cells
    apart, so none are nearest neighbours and their acceptance tests are
    independent. A flip is accepted when dE < 0 or with probability
    exp(-dE / T).

    Parameters
    ----------
    spin_tensor_batch : torch.Tensor
        +/-1 spins of shape (batch, 1, H, W).
    warm_times : int
        Steps run before any measurement is recorded.
    eq_times : int
        Steps recorded after warm-up (one measurement per step).
    T : torch.Tensor, float or array-like
        Temperature: either a scalar shared by all lattices or one value per
        lattice.
    N : int
        Linear lattice size. Kept for backward compatibility. The
        magnetisation is normalised by the lattice's actual number of sites
        (H * W), so a non-square lattice or a mismatched ``N`` cannot skew it.

    Returns
    -------
    avg_spins : torch.Tensor, shape (eq_times, batch)
        Mean spin per site after each recorded step.
    energies : torch.Tensor, shape (eq_times, batch)
        ``get_energy`` of each lattice after each recorded step.
    spin_tensor_batch : torch.Tensor
        The input tensor (same object), holding the final configuration.
    """
    batch_size = spin_tensor_batch.shape[0]
    n_sites = spin_tensor_batch.shape[-2] * spin_tensor_batch.shape[-1]
    out_device = spin_tensor_batch.device

    energies = torch.empty((eq_times, batch_size), device=out_device)
    avg_spins = torch.empty((eq_times, batch_size), device=out_device)
    # Accept a scalar or per-lattice temperatures; reshape to broadcast over (1, H, W).
    T = torch.as_tensor(T, dtype=spin_tensor_batch.dtype, device=out_device).reshape(-1, 1, 1, 1)

    for t in range(warm_times + eq_times):
        # Draw the sublattice offsets from torch's RNG, so a single
        # torch.manual_seed makes the whole run reproducible.
        i, j = torch.randint(0, 2, (2,)).tolist()
        # Basic slicing returns a view: writes through it reach the full lattice.
        sublattice = spin_tensor_batch[:, :, i::2, j::2]
        dE = get_dE_arr(spin_tensor_batch)[:, :, i::2, j::2]
        probs = torch.exp(-dE / T)
        rands = torch.rand_like(dE)

        # Metropolis acceptance test.
        flip_mask = (dE < 0) | (rands < probs)
        sublattice[flip_mask] *= -1

        # After warm-up, record one measurement per step.
        if t >= warm_times:
            idx = t - warm_times
            energies[idx] = get_energy(spin_tensor_batch)
            avg_spins[idx] = spin_tensor_batch.sum(dim=(1, 2, 3)) / n_sites

    return avg_spins, energies, spin_tensor_batch

In [27]:
def generate_random_lattices(n_lattices, N, p=0.75):
    """Create ``n_lattices`` random N x N spin lattices with values -1 or +1.

    Each spin is independently -1 with probability ``p`` and +1 otherwise.

    Returns
    -------
    torch.Tensor
        Shape (n_lattices, 1, N, N), on ``device``, in torch's default
        floating dtype.
    """
    uniform = torch.rand((n_lattices, 1, N, N), device=device)
    is_down = (uniform < p).to(uniform.dtype)
    # 1 - 2*0 -> +1 (up), 1 - 2*1 -> -1 (down)
    return 1 - 2 * is_down

In [30]:
# --- Simulation parameters ---------------------------------------------------
warm_times = 2000   # Metropolis steps discarded before measuring
eq_times = 100      # Metropolis steps recorded after warm-up
SEED = 0            # seeds torch and numpy so re-running this cell reproduces the results

# Set to True for a 100-point grid concentrated around T = 2.26
# (presumably the critical temperature) instead of 6 evenly spaced points.
USE_CLUSTERED_TS = False
if USE_CLUSTERED_TS:
    Ts = get_clustered_temperatures(n_temperatures=100, center=2.26, low=0.5, high=4)
else:
    Ts = np.linspace(0.5, 4, 6)

N = 50            # linear lattice size (N x N spins)
n_lattices = 2    # independent lattices simulated per temperature

# Random numbers may come from both numpy's global RNG and torch's RNG, so seed both.
np.random.seed(SEED)
torch.manual_seed(SEED)

avg_spins = torch.empty((len(Ts), n_lattices), device=device)
std_spins = torch.empty((len(Ts), n_lattices), device=device)
avg_energies = torch.empty((len(Ts), n_lattices), device=device)
std_energies = torch.empty((len(Ts), n_lattices), device=device)
# Final configurations: +/-1 spins fit in int8.
last_states = torch.empty((len(Ts), n_lattices, 1, N, N), device=device, dtype=torch.int8)


for i, T in enumerate(Ts):

    # Fresh random start. p is the probability of a -1 spin, so about 75% of spins start at +1.
    lattices = generate_random_lattices(n_lattices, N, p=0.25)

    temps = T*torch.ones(lattices.shape[0]).to(device)
    spins, energies, state = metropolis(lattices, warm_times, eq_times, temps, N)

    # Statistics over the recorded steps (axis 0): one value per lattice.
    avg_spins[i] = torch.mean(spins, axis=0)
    std_spins[i] = torch.std(spins, axis=0)
    avg_energies[i] = torch.mean(energies, axis=0)
    std_energies[i] = torch.std(energies, axis=0)
    last_states[i] = state

    print(f"Temperature {T:.2f} done")

# Save avg and std of spins and energies, together with the temperature grid,
# final states and run parameters, so the file can be interpreted on its own.
np.savez(f'N_{N}_avg_std_spins_energies.npz',
         avg_spins=avg_spins.cpu().numpy(),
         std_spins=std_spins.cpu().numpy(),
         avg_energies=avg_energies.cpu().numpy(),
         std_energies=std_energies.cpu().numpy(),
         Ts=Ts,
         last_states=last_states.cpu().numpy(),
         N=N,
         warm_times=warm_times,
         eq_times=eq_times,
         seed=SEED,
         )

Temperature 0.50 done
Temperature 1.20 done
Temperature 1.90 done
Temperature 2.60 done
Temperature 3.30 done
Temperature 4.00 done
