In [1]:
import os
import numpy as np
import datetime
import shutil
import subprocess
import shlex

from scipy.spatial.distance import cdist
from potfit_nn import PotfitNN

In [2]:
_POTFIT_M_TEMPLATE = os.path.abspath("potfit.m")

In [3]:
p_list = np.load("/home/jhli/H2O-Ar-pes/cqpes-legacy-low-level/data/p.npy")[:, 1:]

V_delta_list = np.load("/home/jhli/H2O-Ar-pes/cqpes-legacy-delta-machine-learning/data/delta_V.npy").reshape((-1, 1))
V_low_list = np.load("/home/jhli/H2O-Ar-pes/cqpes-legacy-low-level/data/V.npy").reshape((-1, 1))

In [4]:
n_models = 3
n_queries = 19
n_samples_per_query = 100

query_idx = np.random.choice(range(len(p_list)), size=n_samples_per_query, replace=False)
pool_idx = np.delete(range(len(p_list)), query_idx, axis=0)

In [5]:
def npy_to_potfit(p_list, V_list, potfit_txt):
    potfit_data = np.hstack((p_list, V_list))
    np.savetxt(potfit_txt, potfit_data)

In [6]:
# initial
workdir = os.path.abspath(f"al-model-{datetime.datetime.now()}")
initial_dir = os.path.join(workdir, "initial")
os.makedirs(initial_dir)

np.savetxt(os.path.join(initial_dir, "query_idx.txt"), query_idx, fmt="%d")

committee = []

for idx in range(n_models):
    ckpt_dir = os.path.join(workdir, "initial", f"model-{idx:03d}")
    os.makedirs(ckpt_dir)
    
    shutil.copy(_POTFIT_M_TEMPLATE, ckpt_dir)
    npy_to_potfit(
        p_list=p_list[query_idx],
        V_list=V_delta_list[query_idx],
        potfit_txt=os.path.join(ckpt_dir, "data.txt"),
    )
    subprocess.run(
        args=shlex.split("bash -c 'source /opt/hpc4you/scripts/enable_matlab.sh; matlab < potfit.m > potfit.log'"),
        cwd=ckpt_dir,
    )

    model = PotfitNN(
        weights_file=os.path.join(ckpt_dir, "weights-1.txt"),
        biases_file=os.path.join(ckpt_dir, "biases-1.txt"),
    )

    committee.append(model)

In [7]:
def vote(committee, query_idx, pool_idx, n):
    # energy
    error_pred = np.zeros_like(V_low_list[pool_idx])
    for model in committee:
        V_pred = np.apply_along_axis(model.forward, axis=1, arr=p_list[pool_idx])
        V_true = V_delta_list[pool_idx]
        error_pred += (V_pred - V_true)**2.0

    error_pred = error_pred.flatten()

    print(f"{error_pred.shape = }")

    # distance
    dist = cdist(
        p_list[pool_idx],
        p_list[query_idx],
        metric="euclidean",
    ).sum(axis=1).reshape((-1, 1)).flatten()

    # print(f"{dist.shape = }")

    # lower energy
    # f3 = np.ones_like(V_low_list[pool_idx].flatten())
    # f3 = np.exp(-1.5 * V_low_list[pool_idx].flatten())
    # f3 = np.exp(-2.0 * V_low_list[pool_idx].flatten())
    # f3 = np.exp(-1.0 * V_low_list[pool_idx].flatten())
    f3 = np.exp(-2.5 * V_low_list[pool_idx].flatten())

    # print(f"{f3.shape = }")

    res_list = (error_pred**0.5 * dist * f3).flatten()
    # print(res_list.shape)

    vote_idx = res_list.argsort()[-n:]

    return vote_idx

In [8]:
for query_loop in range(n_queries):
    print(f"Query: {query_loop}")
    query_dir = os.path.join(workdir, f"query-{query_loop:03d}")
    os.makedirs(query_dir)

    # vote
    vote_idx = vote(committee, query_idx, pool_idx, n_samples_per_query)
    query_idx = np.concatenate((query_idx, pool_idx[vote_idx]))
    pool_idx = np.delete(pool_idx, vote_idx, axis=0)

    print(f"{len(query_idx) = }")
    print(f"{len(pool_idx) = }")

    np.savetxt(os.path.join(query_dir, "query_idx.txt"), query_idx, fmt="%d")

    committee = []

    for model_idx in range(n_models):
        ckpt_dir = os.path.join(query_dir, f"model-{model_idx:03d}")
        os.makedirs(ckpt_dir)

        shutil.copy(_POTFIT_M_TEMPLATE, ckpt_dir)
        npy_to_potfit(
            p_list=p_list[query_idx],
            V_list=V_delta_list[query_idx],
            potfit_txt=os.path.join(ckpt_dir, "data.txt"),
        )
        subprocess.run(
            args=shlex.split("bash -c 'source /opt/hpc4you/scripts/enable_matlab.sh; matlab < potfit.m > potfit.log'"),
            cwd=ckpt_dir,
        )

        model = PotfitNN(
            weights_file=os.path.join(ckpt_dir, "weights-1.txt"),
            biases_file=os.path.join(ckpt_dir, "biases-1.txt"),
        )

        committee.append(model)

Query: 0
error_pred.shape = (32424,)
dist.shape = (32424,)
f3.shape = (32424,)
(32424,)
len(query_idx) = 200
len(pool_idx) = 32324
Query: 1
error_pred.shape = (32324,)
dist.shape = (32324,)
f3.shape = (32324,)
(32324,)
len(query_idx) = 300
len(pool_idx) = 32224
Query: 2
error_pred.shape = (32224,)
dist.shape = (32224,)
f3.shape = (32224,)
(32224,)
len(query_idx) = 400
len(pool_idx) = 32124
Query: 3
error_pred.shape = (32124,)
dist.shape = (32124,)
f3.shape = (32124,)
(32124,)
len(query_idx) = 500
len(pool_idx) = 32024
Query: 4
error_pred.shape = (32024,)
dist.shape = (32024,)
f3.shape = (32024,)
(32024,)
len(query_idx) = 600
len(pool_idx) = 31924
Query: 5
error_pred.shape = (31924,)
dist.shape = (31924,)
f3.shape = (31924,)
(31924,)
len(query_idx) = 700
len(pool_idx) = 31824
Query: 6
error_pred.shape = (31824,)
dist.shape = (31824,)
f3.shape = (31824,)
(31824,)
len(query_idx) = 800
len(pool_idx) = 31724
Query: 7
error_pred.shape = (31724,)
dist.shape = (31724,)
f3.shape = (31724,)
(317