In [None]:
import numpy as np
import matplotlib.pyplot as plt
import dftqml.utils
import dftqml.tfmodel
import dftqml.xgb
import dftqml.data_processing
from os import path
import h5py

import absl.logging
absl.logging.set_verbosity('error')



# Root directory of the CNN models addressed by dft_cnn_path().
DFT_CNN_DIR = "./models/dftio/cnn"
# Base directory of the CNN models addressed by rdmft_cnn_path(); variant suffixes are appended to it.
RDMFT_CNN_DIR = "./models/rdmftio/cnn"
# Directory holding the HDF5 data files.
DATA_DIR = "./data"
# Number of splits (indices 0 .. N_SPLITS-1) evaluated for every training-set size.
N_SPLITS = 5

def history_plot(history_dict):
    """Draw the training and validation MSE curves with pyplot.

    ``history_dict`` must provide the ``"mse"`` and ``"val_mse"`` series; each
    one is plotted against its index on a logarithmic y axis.
    """
    for series_name in ("mse", "val_mse"):
        plt.plot(history_dict[series_name], label=series_name)

    plt.xlabel("Epoch")
    plt.ylabel("Loss (mse)")
    plt.grid()
    plt.yscale("log")
    plt.legend()

def dft_cnn_path(L, N, U, ndata, split):
    """Return the directory of the DFT CNN for one system, training-set size and split.

    The layout is ``DFT_CNN_DIR/L{L}-N{N}-U{U}/ndata{ndata}/split{split}``.
    """
    sub_directories = (f"L{L}-N{N}-U{U}", f"ndata{ndata}", f"split{split}")
    return path.join(DFT_CNN_DIR, *sub_directories)

def rdmft_cnn_path(L, N, U, ndata, split, expand=False, augment=False, shuffle=False, batch_size=None):
    """Return the directory of the RDMFT CNN for one system, training-set size, split and variant.

    The base directory ``RDMFT_CNN_DIR`` receives, in this order, the suffixes
    ``-expanded``, ``-augmented`` and ``-shuffled`` for every enabled flag, and
    ``-batch{batch_size}`` when ``batch_size`` is given and differs from
    ``2*10*L``. The system, ``ndata`` and ``split`` sub-directories follow.
    """
    flagged_suffixes = (
        (expand, "-expanded"),
        (augment, "-augmented"),
        (shuffle, "-shuffled"),
    )
    directory = RDMFT_CNN_DIR + "".join(suffix for enabled, suffix in flagged_suffixes if enabled)

    # A batch size equal to 2*10*L gets no suffix, exactly like batch_size=None.
    unsuffixed_batch_size = 2 * 10 * L
    if batch_size is not None and batch_size != unsuffixed_batch_size:
        directory += f"-batch{batch_size}"

    sub_directories = (f"L{L}-N{N}-U{U}", f"ndata{ndata}", f"split{split}")
    return path.join(directory, *sub_directories)

def mse(prediction, exact):
    """Return the mean squared error between two array-likes, both flattened first."""
    residual = np.ravel(prediction) - np.ravel(exact)
    return np.mean(np.square(residual))

In [None]:
L = 8
N = 8
U = 4.0
max_ndata = 50000
n_test_data = 1000
input_file = path.join(DATA_DIR, f"L{L}-N{N}-U{U}.hdf5")

# Rows [max_ndata, max_ndata + n_test_data) of every dataset are read as the test set.
test_rows = slice(max_ndata, max_ndata + n_test_data)

with h5py.File(input_file, "r") as f:
    test_densities = f["densities"][test_rows]
    test_dft_energies = f["dft_energies"][test_rows]  # kinetic + interaction energy
    # transpose RDMs so the locality index is last, as expected by the data augmentation and model
    test_one_rdms = f["one_rdms"][test_rows].transpose(0, 2, 1)
    test_rdmft_energies = f["rdmft_energies"][test_rows]

# get the expanded 1RDMs test set
expanded_test_one_rdms = dftqml.data_processing.one_rdm_compressed_to_all_correlators(test_one_rdms)

## Visualize a single model

In [None]:
def plot_single_model(
    ndata, functional="dft", split=0, expand=False, augment=False, shuffle=False, batch_size=None
):
    """Plot the training history and the test-set performance of one stored CNN.

    Parameters
    ----------
    ndata : int
        Training-set size; selects the ``ndata{ndata}`` model directory.
    functional : {"dft", "rdmft"}
        Which family of models to load and which test inputs/targets to use.
    split : int
        Index of the ``split{split}`` model directory (honoured for both functionals).
    expand, augment, shuffle, batch_size
        RDMFT only: forwarded to ``rdmft_cnn_path`` to select the model variant.
        ``expand`` additionally switches the test inputs to the expanded 1RDMs.

    Raises
    ------
    ValueError
        If ``functional`` is neither ``"dft"`` nor ``"rdmft"``.
    """
    if functional == "dft":
        # Use the requested split; it was previously hard-coded to 0.
        model, history_dict = dftqml.tfmodel.load_model(
            dft_cnn_path(L, N, U, ndata, split), get_history_dict=True
        )
        test_x = test_densities
        test_y = test_dft_energies
        model_label = "CNN-DFT"

    elif functional == "rdmft":
        model, history_dict = dftqml.tfmodel.load_model(
            rdmft_cnn_path(L, N, U, ndata, split, expand, augment, shuffle, batch_size),
            get_history_dict=True,
        )
        test_x = expanded_test_one_rdms if expand else test_one_rdms
        test_y = test_rdmft_energies
        model_label = (
            "CNN-RDMFT"
            + ("-expanded" if expand else "")
            + ("-augmented" if augment else "")
            + ("-shuffled" if shuffle else "")
            + (f"-batch{batch_size}" if batch_size is not None else "")
        )
    else:
        raise ValueError(f"Unknown functional: {functional}")

    test_predictions = model.predict(test_x)
    test_mse = mse(test_predictions, test_y)

    plt.figure(figsize=(8, 3.5))
    plt.suptitle(
        f"{model_label} model trained on exact data.  "
        f"L{L} N{N} U{U} ndata{ndata}\n"
        f"Test set MSE: {test_mse:.2e}"
    )

    # Left panel: training history
    plt.subplot(1, 2, 1)
    history_plot(history_dict)

    # Right panel: performance on the test set
    plt.subplot(1, 2, 2, aspect="equal")
    dftqml.utils.performance_plot(model, test_x, test_y, label="CNN", s=2, alpha=0.5)

    plt.tight_layout()
    plt.show()

### DFT CNN

In [None]:
# DFT functional: model stored under ndata1000/split0, evaluated on the test densities.
plot_single_model(1000, functional='dft', split=0)

In [None]:
# RDMFT functional, plain variant (no expand/augment/shuffle), evaluated on the compressed 1RDMs.
plot_single_model(1000, functional='rdmft', split=0)

In [None]:
# RDMFT functional, "-expanded" variant, evaluated on the expanded 1RDMs.
plot_single_model(1000, functional='rdmft', split=0, expand=True)

In [None]:
# One figure per split of the "-augmented" RDMFT models; N_SPLITS keeps the
# loop in sync with the split count used for the MSE scaling tables.
for split in range(N_SPLITS):
    plot_single_model(1000, functional='rdmft', split=split, augment=True)

In [None]:
# RDMFT functional, "-augmented-shuffled" variant, no batch-size suffix.
plot_single_model(1000, functional='rdmft', split=0, shuffle=True, augment=True)

In [None]:
# Same variant loaded from the "-batch640" directory.
plot_single_model(1000, functional='rdmft', split=0, shuffle=True, augment=True, batch_size=640)

In [None]:
# "-batch2560" directory; NOTE(review): this cell shows split 1 while the cells above show split 0 - confirm intended.
plot_single_model(1000, functional='rdmft', split=1, shuffle=True, augment=True, batch_size=2560)

In [None]:
# "-batch10240" directory, split 1.
plot_single_model(1000, functional='rdmft', split=1, shuffle=True, augment=True, batch_size=10240)

## MSE scaling with training set size

In [None]:
# Training-set sizes whose stored models are evaluated (one model directory per size and split).
ndata_list = np.array([100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000])


def mse_model_dft(ndata, split):
    """Return the test-set MSE of the DFT CNN stored for ``ndata`` and ``split``."""
    model_dir = dft_cnn_path(L, N, U, ndata, split)
    dft_model = dftqml.tfmodel.load_model(model_dir, get_history_dict=False)
    predicted_energies = dft_model(test_densities)
    return mse(predicted_energies, test_dft_energies)


def mse_model_rdmft(ndata, split, expand=False, augment=False, shuffle=False, batch_size=None):
    """Return the test-set MSE of one stored RDMFT CNN, or ``None`` if it cannot be read.

    The variant flags select the model directory through ``rdmft_cnn_path``;
    ``expand`` also selects the expanded test 1RDMs as model input. Any
    ``OSError`` (which includes ``FileNotFoundError``) raised while loading or
    evaluating the model yields ``None``.
    """
    try:
        model_dir = rdmft_cnn_path(
            L, N, U, ndata, split,
            expand=expand, augment=augment, shuffle=shuffle, batch_size=batch_size,
        )
        rdmft_model = dftqml.tfmodel.load_model(model_dir, get_history_dict=False)
        if expand:
            test_inputs = expanded_test_one_rdms
        else:
            test_inputs = test_one_rdms
        return mse(rdmft_model(test_inputs), test_rdmft_energies)
    except OSError:
        return None


def _mse_table(evaluate, **variant):
    """Evaluate ``evaluate(ndata, split, **variant)`` for every training-set size and split.

    Rows follow ``ndata_list``, columns follow ``range(N_SPLITS)``.
    """
    return np.array(
        [[evaluate(ndata, split, **variant) for split in range(N_SPLITS)] for ndata in ndata_list]
    )


mse_list_dft = _mse_table(mse_model_dft)

mse_list_rdmft = _mse_table(mse_model_rdmft)

mse_list_exp_rdmft = _mse_table(mse_model_rdmft, expand=True)

mse_list_aug_rdmft = _mse_table(mse_model_rdmft, augment=True)

mse_list_aug_shuf_rdmft = _mse_table(mse_model_rdmft, augment=True, shuffle=True)

mse_list_aug_shuf_b10240_rdmft = _mse_table(
    mse_model_rdmft, augment=True, shuffle=True, batch_size=10240
)

In [None]:
# *** DFT ***
plt.plot(ndata_list, mse_list_dft, "_", c='C0', alpha=1)
plt.errorbar(ndata_list, np.mean(mse_list_dft, axis=1), np.std(mse_list_dft, axis=1), label="DFT")

# Straight-line fit in log-log space, using only the three smallest training-set sizes
log_ndata = np.log10(ndata_list)
log_mse = np.log10(np.mean(mse_list_dft, axis=1))
coefficients = np.polyfit(log_ndata[:3], log_mse[:3], 1)
poly = np.poly1d(coefficients)
plt.plot(
    ndata_list, 10 ** poly(log_ndata), label=f"Fit DFT: slope={coefficients[0]:.2f}", linestyle="--", c='k'
)

# *** RDMFT variants: (MSE table, marker, marker color, legend label) ***
rdmft_series = [
    (mse_list_rdmft, "_", 'C1', "RDMFT, compressed 1RDM"),
    (mse_list_exp_rdmft, "_", 'C2', "RDMFT, expanded 1RDM"),
    (mse_list_aug_rdmft, "_", 'C3', "RDMFT, perm-augmented 1RDM"),
    (mse_list_aug_shuf_rdmft, "_", 'C4', "RDMFT, perm-augmented and shuffled 1RDM"),
    (mse_list_aug_shuf_b10240_rdmft, "x", 'C5', "RDMFT, perm-augmented, shuffled and batched 1RDM"),
]

for series, marker, marker_color, series_label in rdmft_series:
    # Error bars only for training-set sizes where every split produced a value
    mask = np.all(series != None, axis=1)
    plt.plot(ndata_list, series, marker, c=marker_color, alpha=1)
    plt.errorbar(
        ndata_list[mask],
        np.mean(series[mask], axis=1),
        np.std(series[mask].astype(float), axis=1),
        label=series_label
    )


plt.legend(loc = "lower center", bbox_to_anchor=(0.5,1), ncol=2)
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Number of training data")
plt.ylabel("MSE on test set")

## Compare batch sizes

In [None]:
def plot_mse_vs_ndata(color, expand=False, augment=False, shuffle=False, batch_size=None):
    """Plot test-set MSE against training-set size for one RDMFT model variant.

    Parameters
    ----------
    color
        Matplotlib color used for both the per-split markers and the
        mean/std error-bar line of this series.
    expand, augment, shuffle, batch_size
        Variant selectors forwarded to ``mse_model_rdmft``. They also build the
        legend label; ``batch_size=None`` is labelled as ``2*L*10``.

    Only training-set sizes for which every split produced a value are drawn.
    """
    # dtype=object keeps the ``None`` entries returned for models that could not be read.
    data = np.array(
        [
            [
                mse_model_rdmft(
                    ndata,
                    split,
                    augment=augment,
                    expand=expand,
                    shuffle=shuffle,
                    batch_size=batch_size,
                )
                for split in range(N_SPLITS)
            ]
            for ndata in ndata_list
        ],
        dtype=object,
    )

    label = 'RDMFT'
    if expand:
        label += ' exp.'
    if augment:
        label += ' aug.'
    if shuffle:
        label += ' shuf.'
    if batch_size is not None:
        label += f' batchsize {batch_size}'
    else:
        label += f' batchsize {2*L*10}'

    # Keep the rows (training-set sizes) where all splits are available.
    mask = np.array([all(value is not None for value in row) for row in data], dtype=bool)
    # Cast once so that mean and std are both computed on a float array.
    values = data[mask].astype(float)

    plt.plot(ndata_list[mask], values, "_", c=color, alpha=1)
    plt.errorbar(
        ndata_list[mask],
        values.mean(axis=1),
        values.std(axis=1),
        color=color,
        label=label,
    )

In [None]:
# *** DFT ***
plt.plot(ndata_list, mse_list_dft, "_", c='C0', alpha=1)
plt.errorbar(ndata_list, np.mean(mse_list_dft, axis=1), np.std(mse_list_dft, axis=1), label="DFT")

# Fit a linear model to the log-log data (first three training-set sizes only)
log_ndata = np.log10(ndata_list)
log_mse = np.log10(np.mean(mse_list_dft, axis=1))
coefficients = np.polyfit(log_ndata[:3], log_mse[:3], 1)
poly = np.poly1d(coefficients)
plt.plot(
    ndata_list, 10 ** poly(log_ndata), label=f"Fit DFT: slope={coefficients[0]:.2f}", linestyle="--", c='k'
)

# *** RDMFT, augmented + shuffled, one series per batch size ***
plot_mse_vs_ndata(color='C2', augment=True, shuffle=True)

plot_mse_vs_ndata(color='C3', augment=True, shuffle=True, batch_size=640)

plot_mse_vs_ndata(color='C4', augment=True, shuffle=True, batch_size=2560)

# Own color for this series: 'C4' is already requested for batch_size=2560.
plot_mse_vs_ndata(color='C5', augment=True, shuffle=True, batch_size=10240)


plt.legend(loc = "lower center", bbox_to_anchor=(0.5,1), ncol=2)
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Number of training data")
plt.ylabel("MSE on test set")
plt.show()