In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import math, os, time, copy
import torch.fft as tfft
import pandas as pd
import torch_dct as dct
from numpy import size
from mpl_toolkits.axes_grid1 import make_axes_locatable
from ewaldnn2d import *

torch.random.manual_seed(1234) # for reproducibility

# ---------------------------------------------------------------------------
# Global settings
# ---------------------------------------------------------------------------
dtype = torch.float64 # floating-point precision used for every tensor built in this cell
device = "cpu"
N_batch = 100 # mini-batch size
N_epochs = 10000 # maximum number of training epochs
lr = 1e-1 # we will use a LR scheduler, so this is just an initial value
min_delta = 1e-5 # min change in the monitored quantity to qualify as an improvement
patience = 100 # epochs to wait for improvement before stopping training
pad_mode = "reflect" # in this example, we use reflective padding for local feature generation
N_pow = 1 # number of local features per grid point
N_train = 1500 # number of training samples
N_test = 250 # number of test samples
N_val = 250 # number of validation samples
N_energy_terms = 1 # number of energy terms for the LERN model

# ---------------------------------------------------------------------------
# Grid and cosine-basis settings
# ---------------------------------------------------------------------------
N_x = 32 # number of grid points in x direction
N_y = N_x # number of grid points in y direction
m_x = torch.arange(0, N_x, dtype=dtype, device=device)             # (N_x,) harmonic indices along x
m_y = torch.arange(0, N_y, dtype=dtype, device=device)             # (N_y,) harmonic indices along y
abs_val = torch.sqrt(m_x[:, None]**2 + m_y[None, :]**2)            # (M_x, M_y) magnitude of each (m_x, m_y) pair
x = torch.linspace(0, 1, N_x, dtype=dtype, device=device)          # (N_x,) grid on [0, 1], end points included
y = torch.linspace(0, 1, N_y, dtype=dtype, device=device)          # (N_y,)

# The argument pi * m * x is shared by the design matrix and its derivative,
# so it is evaluated once per axis.
phase_x = torch.pi * torch.outer(m_x, x)                           # (M_x, N_x)
phase_y = torch.pi * torch.outer(m_y, y)                           # (M_y, N_y)
DM_x = torch.cos(phase_x)                                          # (M_x, N_x) design matrix cos(pi * m * x)
DM_y = torch.cos(phase_y)                                          # (M_y, N_y)
DerDM_x = -torch.pi * m_x[:, None] * torch.sin(phase_x)            # (M_x, N_x) d/dx of the design matrix
DerDM_y = -torch.pi * m_y[:, None] * torch.sin(phase_y)            # (M_y, N_y) d/dy of the design matrix

# ---------------------------------------------------------------------------
# Standard deviation of each harmonic amplitude
# ---------------------------------------------------------------------------
data_regime = "rough" # "smooth" or "rough"
if data_regime == "smooth":
    M_cutoff = 10 # maximum harmonic
    # Cast the boolean cutoff mask to the configured dtype (not a hard-coded
    # double) so std_harm keeps the precision selected above.
    std_harm = 2.0 / (1.0 + 0.2 * abs_val)**2 * (abs_val <= M_cutoff).to(dtype)  # (M_x, M_y)
elif data_regime == "rough":
    # Zero decay coefficient: every harmonic gets the same standard deviation.
    std_harm = 2.0 / (1.0 + 0.0 * abs_val)**2 # (M_x, M_y)
else:
    raise ValueError("regime must be 'smooth' or 'rough'")
std_harm[0, 0] = 0.0 # no uniform density offset

# ---------------------------------------------------------------------------
# Interaction kernel parameters
# ---------------------------------------------------------------------------
qs = 0.5 # screening momentum
amp = 1.0 # amplitude of interaction kernel
kernel_regime = "screened_coulomb"

# unscreened Hartree-Fock total energy function
def E_HF(rho: torch.Tensor, d_rho_x: torch.Tensor, d_rho_y: torch.Tensor, eng_dens_flag: bool = False) -> torch.Tensor:
    """Unscreened interaction energy of the density ``rho``, scaled by ``amp``.

    Evaluates ``E_int_ms_dct`` with the module-level ``kernel_regime`` and the
    screening momentum forced to ``qs=0.0``.

    Args:
        rho: density passed straight to ``E_int_ms_dct``.
        d_rho_x: accepted so the signature matches ``E_SC``; not used here.
        d_rho_y: accepted so the signature matches ``E_SC``; not used here.
        eng_dens_flag: forwarded unchanged to ``E_int_ms_dct``
            (presumably switches to an energy-density output; TODO confirm).

    Returns:
        ``amp`` times whatever ``E_int_ms_dct`` returns.
    """
    unscreened_energy = E_int_ms_dct(
        rho,
        kernel=kernel_regime,
        eng_dens_flag=eng_dens_flag,
        qs=0.0,  # zero screening momentum
    )
    return amp * unscreened_energy

# screened total energy function
def E_SC(rho: torch.Tensor, d_rho_x: torch.Tensor, d_rho_y: torch.Tensor, eng_dens_flag: bool = False) -> torch.Tensor:
    """Screened interaction energy of the density ``rho``, scaled by ``amp``.

    Evaluates ``E_int_ms_dct`` with the module-level ``kernel_regime`` and the
    module-level screening momentum ``qs``.

    Args:
        rho: density passed straight to ``E_int_ms_dct``.
        d_rho_x: accepted so the signature matches ``E_HF``; not used here.
        d_rho_y: accepted so the signature matches ``E_HF``; not used here.
        eng_dens_flag: forwarded unchanged to ``E_int_ms_dct``
            (presumably switches to an energy-density output; TODO confirm).

    Returns:
        ``amp`` times whatever ``E_int_ms_dct`` returns.
    """
    screened_energy = E_int_ms_dct(
        rho,
        kernel=kernel_regime,
        eng_dens_flag=eng_dens_flag,
        qs=qs,  # finite screening momentum from the global settings
    )
    return amp * screened_energy
    
# Generate the train/test/validation split, or reload a previously saved one.
flag_generate_data = True # if True, generate new data; if False, load existing data from disk

# Single source of truth for the dataset location: the save and the load
# branch must always agree on the file name.
fname = f"DATA2d/LERN_dataset_{data_regime}_{kernel_regime}_{qs}_{amp}_{N_x}_{N_y}.pt"

if flag_generate_data:
    N_batch_int = 10 # number of density profiles per data generation batch
    torch.manual_seed(1234) # for reproducibility

    # Basis arguments shared by all three generation calls.
    gen_kwargs = dict(std_harm=std_harm, DM_x=DM_x, DerDM_x=DerDM_x, DM_y=DM_y, DerDM_y=DerDM_y)

    # NOTE: the call order (train, test, val) determines which random numbers
    # each split consumes after the seed above; do not reorder.
    rho_train, d_rho_x_train, d_rho_y_train, a_train, E_HF_train, E_loc_SC_train = generate_SC_data_2d(
        N_train, N_batch_int, E_HF, E_SC, **gen_kwargs
    )
    rho_test, d_rho_x_test, d_rho_y_test, a_test, E_HF_test, E_loc_SC_test = generate_SC_data_2d(
        N_test, N_batch_int, E_HF, E_SC, **gen_kwargs
    )
    rho_val, d_rho_x_val, d_rho_y_val, a_val, E_HF_val, E_loc_SC_val = generate_SC_data_2d(
        N_val, N_batch_int, E_HF, E_SC, **gen_kwargs
    )

    # save data to disk
    os.makedirs("DATA2d", exist_ok=True)
    torch.save(
        {
            "rho_train": rho_train,
            "d_rho_x_train": d_rho_x_train,
            "d_rho_y_train": d_rho_y_train,
            "a_train": a_train,
            "E_HF_train": E_HF_train,
            "E_loc_SC_train": E_loc_SC_train,
            "rho_val": rho_val,
            "d_rho_x_val": d_rho_x_val,
            "d_rho_y_val": d_rho_y_val,
            "a_val": a_val,
            "E_HF_val": E_HF_val,
            "E_loc_SC_val": E_loc_SC_val,
            "rho_test": rho_test,
            "d_rho_x_test": d_rho_x_test,
            "d_rho_y_test": d_rho_y_test,
            "a_test": a_test,
            "E_HF_test": E_HF_test,
            "E_loc_SC_test": E_loc_SC_test,
            "data_regime": data_regime,
            "kernel_regime": kernel_regime,
        },
        fname,
    )
else:
    # map_location keeps the tensors on the configured device regardless of
    # where they lived when the file was written.
    data = torch.load(fname, map_location=device)

    # Restore every tensor written by the save branch above, so both branches
    # leave the same set of names defined.
    rho_train = data["rho_train"]
    d_rho_x_train = data["d_rho_x_train"]
    d_rho_y_train = data["d_rho_y_train"]
    a_train = data["a_train"]
    E_HF_train = data["E_HF_train"]
    E_loc_SC_train = data["E_loc_SC_train"]

    rho_test = data["rho_test"]
    d_rho_x_test = data["d_rho_x_test"]
    d_rho_y_test = data["d_rho_y_test"]
    a_test = data["a_test"]
    E_HF_test = data["E_HF_test"]
    E_loc_SC_test = data["E_loc_SC_test"]

    rho_val = data["rho_val"]
    d_rho_x_val = data["d_rho_x_val"]
    d_rho_y_val = data["d_rho_y_val"]
    a_val = data["a_val"]
    E_HF_val = data["E_HF_val"]
    E_loc_SC_val = data["E_loc_SC_val"]

# Local features computed from each density split; the split order
# (train, test, val) is kept throughout this cell.
features_train, features_test, features_val = (
    generate_loc_features_rs(rho_split, N_pow=N_pow)  # (N_split, N_x, N_y, N_pow)
    for rho_split in (rho_train, rho_test, rho_val)
)

# Extend features with neighbor information
R_feat = 1.0 # radius for neighbor feature extension
features_train, features_test, features_val = (
    extend_features_neighbors_2d(feat_split, R=R_feat)
    for feat_split in (features_train, features_test, features_val)
)

# Normalize features: statistics come from the training split only and are
# then applied to all three splits.
mean_feat, std_feat = compute_normalization_stats(features_train)
features_train_norm, features_test_norm, features_val_norm = (
    normalize_features(feat_split, mean_feat, std_feat)
    for feat_split in (features_train, features_test, features_val)
)

# Supplement the features with the unnormalized SC energy term (for the LERN model);
# it is appended as the trailing entries of the last axis.
features_train_norm, features_test_norm, features_val_norm = (
    torch.cat([feat_split, e_loc_split], dim=-1)  # (N_split, N_x, N_y, N_feat + 1)
    for feat_split, e_loc_split in (
        (features_train_norm, E_loc_SC_train),
        (features_test_norm, E_loc_SC_test),
        (features_val_norm, E_loc_SC_val),
    )
)

# Normalize targets with the mean/std of the training energies
E_mean = E_HF_train.mean()
E_std = E_HF_train.std()
E_HF_train_norm, E_HF_test_norm, E_HF_val_norm = (
    (energy_split - E_mean) / E_std
    for energy_split in (E_HF_train, E_HF_test, E_HF_val)
)

# Datasets
train_dataset = TensorDataset(features_train_norm, E_HF_train_norm)
val_dataset = TensorDataset(features_val_norm, E_HF_val_norm)
test_dataset = TensorDataset(features_test_norm, E_HF_test_norm)

# Loaders: only the training loader shuffles; no loader drops the last batch
loader_kwargs = dict(batch_size=N_batch, drop_last=False)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Feature count is read from the extended, pre-concatenation features, so it
# does not include the appended SC energy term.
_, _, _, N_feat = features_train.shape
print(f"Number of local features per grid point: {N_feat}")

In [None]:
ckpt_dir = "LearningSC2d_checkpoints"
flag_train = True  # set to True to train models
learning_regime = "LERN2d"

# Architecture sweep: every (n_hidden, n_neurons) combination is trained once.
n_hidden_list = [1, 2, 3, 4]
n_neurons_list = [8, 16, 32, 64]

if flag_train:
    # Make sure the checkpoint directory exists before training writes into it.
    # Harmless if train_with_early_stopping already creates it.
    os.makedirs(ckpt_dir, exist_ok=True)

    for n_hidden in n_hidden_list:
        for n_neurons in n_neurons_list:

            # The run name encodes every setting that distinguishes one run from
            # another; later cells rebuild the same string to locate the files.
            run_name = (
                f"LERN2d_{data_regime}_{kernel_regime}_{qs}_{amp}_{N_x}_{N_y}"
                f"_{N_feat}_{N_energy_terms}_{n_hidden}_{n_neurons}"
            )

            torch.manual_seed(1234) # for reproducibility: identical seed for every architecture
            model = LERN2d(
                N_x=N_x,
                N_y=N_y,
                N_energy_terms=N_energy_terms,
                N_feat=N_feat,
                n_hidden=n_hidden,
                n_neurons=n_neurons,
                mean_feat=mean_feat,
                std_feat=std_feat,
                E_mean=E_mean,
                E_std=E_std,
            ).to(device=device, dtype=dtype)

            optimizer = optim.Adam(model.parameters(), lr=lr)
            criterion = nn.MSELoss()

            # Reduce LR when val loss plateaus
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min', factor=0.5, patience=50, cooldown=2, min_lr=1e-6
            )

            # hist / best_epoch are overwritten on every iteration; only the
            # values of the last trained architecture remain after the loop.
            hist, best_epoch = train_with_early_stopping(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                max_epochs=N_epochs,
                patience=patience,
                min_delta=min_delta,
                ckpt_dir=ckpt_dir,
                run_name=run_name,
                learning_regime=learning_regime,
                N_x=N_x,
                N_y=N_y,
                device=device,
            )

In [None]:
# Architecture whose training history is inspected in this cell
n_hidden, n_neurons = 1, 8

# Rebuild the run name exactly as the training cell does so the matching files are found
run_name = f"LERN2d_{data_regime}_{kernel_regime}_{qs}_{amp}_{N_x}_{N_y}_{N_feat}_{N_energy_terms}_{n_hidden}_{n_neurons}"

# Per-epoch history CSV stored in the checkpoint directory
path = f"{ckpt_dir}/{run_name}_history.csv"
hist_df = pd.read_csv(path)
print(hist_df.head())

# Training vs. validation loss on a logarithmic y-axis
hist_df.plot(
    x="epoch",
    y=["train_loss", "val_loss"],
    logy=True,
    grid=True,
    title=run_name,
)

In [None]:
def show_example_field(field, cbar_label, title=None):
    """Draw one 2D field on the unit square as an image with a colorbar.

    Args:
        field: 2D NumPy array indexed as ``[x, y]``.
        cbar_label: label shown next to the colorbar.
        title: optional figure title; omitted when ``None``.
    """
    plt.figure(figsize=(5, 4))
    im = plt.imshow(
        field.T,               # transpose so x is horizontal, y vertical
        origin="lower",
        extent=[0, 1, 0, 1],   # x from 0 to 1, y from 0 to 1
        aspect="equal",
    )
    plt.colorbar(im, label=cbar_label)
    plt.xlabel(r"$x$")
    plt.ylabel(r"$y$")
    if title is not None:
        plt.title(title)
    plt.tight_layout()
    plt.show()


# Restore the best checkpoint of the run selected by `run_name`.
model, normalization, epoch, val_loss = load_checkpoint(
    ckpt_dir + f"/{run_name}_best.pt",
    LERN2d,
    device=device,
)
model = model.to(device=device, dtype=dtype)
model.eval()  # inference only: put the model in evaluation mode

k = 0  # index of the test sample to visualize
# Drop the appended SC-energy entries: only the first N_feat entries are fed to local_nn.
features_example = features_test_norm[k:k+1, :, :, :N_feat].to(device=device, dtype=dtype)  # (1, N_x, N_y, N_feat)
with torch.no_grad():  # no autograd graph is needed for plotting
    factors = model.local_nn(features_example)
rho_example = rho_test[k:k+1, :, :]  # (1, N_x, N_y)

# Move to host memory before converting, so this also works when `device` is not the CPU.
rho_image = rho_example.detach().cpu().numpy().squeeze()
factors_image = factors.detach().cpu().numpy().squeeze()  # 2D as long as local_nn emits one value per grid point

show_example_field(
    rho_image,
    cbar_label=r"$\rho(x,y)$",
    title=f"Sampled 2D density profile ({data_regime})",
)
show_example_field(
    factors_image,
    cbar_label=r"$factors(x,y)$",
    title="Local network output (factors)",
)
