In [None]:
import functools
import e3x
from flax import linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

# Disable future warnings.
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
jax.devices()

In [None]:
def calculate_moment_of_inertia_tensor(masses, positions):
    """Moment of inertia tensor of a set of point masses.

    Computes I = sum_i m_i * (|r_i|^2 * Id - r_i r_i^T), summing over the
    point axis.

    Args:
      masses: array of shape (..., N).
      positions: array of shape (..., N, 3).

    Returns:
      Array of shape (..., 3, 3).
    """
    squared_norms = jnp.sum(positions**2, axis=-1)  # (..., N)
    identity_term = squared_norms[..., None, None] * jnp.eye(3)  # (..., N, 3, 3)
    outer_term = positions[..., :, None] * positions[..., None, :]  # (..., N, 3, 3)
    weighted = masses[..., None, None] * (identity_term - outer_term)
    # Sum the per-point contributions (the N axis sits just before the 3x3).
    return weighted.sum(axis=-3)


def generate_datasets(
    key,
    num_train=1000,
    num_valid=100,
    num_points=10,
    min_mass=0.0,
    max_mass=1.0,
    stdev=1.0,
):
    """Draw random point-mass systems and their moment of inertia tensors.

    Positions are Gaussian with standard deviation ``stdev`` and masses are
    uniform in [min_mass, max_mass). The train and validation splits use
    independent PRNG keys derived from ``key``.

    Returns:
      (train_data, valid_data), each a dict with keys "positions"
      (S, num_points, 3), "masses" (S, num_points) and "inertia_tensor"
      (S, 3, 3), where S is num_train or num_valid respectively.
    """
    split_keys = jax.random.split(key, num=4)

    def _make_split(position_key, mass_key, num_samples):
        # One split: random positions/masses plus the target tensor.
        positions = stdev * jax.random.normal(
            position_key, shape=(num_samples, num_points, 3)
        )
        masses = jax.random.uniform(
            mass_key,
            shape=(num_samples, num_points),
            minval=min_mass,
            maxval=max_mass,
        )
        return dict(
            positions=positions,
            masses=masses,
            inertia_tensor=calculate_moment_of_inertia_tensor(masses, positions),
        )

    # Key order matches: train positions, train masses, valid positions, valid masses.
    train_data = _make_split(split_keys[0], split_keys[1], num_train)
    valid_data = _make_split(split_keys[2], split_keys[3], num_valid)
    return train_data, valid_data

In [None]:
def mean_squared_loss(prediction, target):
    """Average of ``optax.l2_loss`` (half squared error) over all elements."""
    per_element = optax.l2_loss(prediction, target)
    return jnp.mean(per_element)

In [None]:
import jax
import jax.numpy as jnp
import numpy as np


def prepare_datasets_test(
    filename, key, num_train, num_valid, batch_size, atomic_numbers=None
):
    """Load an ``.npz`` trajectory and split it into training batches and a validation set.

    Frames are drawn at random, without replacement, from ``dataset["R"]``
    using ``key``. The returned frames are therefore NOT in file order.

    Args:
      filename: path to an ``.npz`` file with a positions array ``"R"``.
        Assumed to be (num_frames, num_atoms, 3) — TODO confirm for other files.
      key: JAX PRNG key used to choose the frames.
      num_train: number of frames placed in the training batches.
      num_valid: number of frames placed in the validation set.
      batch_size: maximum frames per training batch; the last batch may be smaller.
      atomic_numbers: optional 1-D sequence of per-atom atomic numbers shared
        by every frame. Defaults to one atom with Z=23 followed by 16 atoms
        with Z=14 (the previously hard-coded composition).

    Returns:
      (train_batches, valid_data):
        - train_batches: list of dicts with "atomic_numbers" (b, num_atoms)
          and "positions" (b, num_atoms, 3).
        - valid_data: dict with the same keys for the num_valid frames.

    Raises:
      RuntimeError: if num_train + num_valid exceeds the number of frames.
      ValueError: if batch_size < 1, or the atomic numbers do not match the
        number of atoms per frame in "R".
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Read "R" exactly once: every NpzFile["R"] access re-reads it from disk.
    # The context manager closes the underlying file handle.
    with np.load(filename) as dataset:
        all_positions = dataset["R"]
    num_data = len(all_positions)

    num_draw = num_train + num_valid
    if num_draw > num_data:
        raise RuntimeError(
            f"datasets only contains {num_data} points, requested num_train={num_train}, num_valid={num_valid}"
        )

    if atomic_numbers is None:
        atom_row = jnp.concatenate((jnp.full(1, 23), jnp.full(16, 14)))
    else:
        atom_row = jnp.asarray(atomic_numbers)
    num_atoms = all_positions.shape[1] if all_positions.ndim > 1 else None
    if atom_row.ndim != 1 or atom_row.shape[0] != num_atoms:
        # E.g. an un-squeezed (1, frames, atoms, 3) "R" ends up here instead of
        # silently producing mismatched features.
        raise ValueError(
            f"atomic numbers of shape {atom_row.shape} do not match 'R' of shape "
            f"{all_positions.shape}"
        )

    # Randomly draw train and validation sets from the dataset.
    choice = np.asarray(
        jax.random.choice(key, num_data, shape=(num_draw,), replace=False)
    )
    train_choice = choice[:num_train]
    valid_choice = choice[num_train:]

    def _select_frames(indices):
        # Every frame shares the same composition, so broadcast one row
        # instead of materialising a (num_data, num_atoms) table.
        return dict(
            atomic_numbers=jnp.broadcast_to(
                atom_row, (len(indices), atom_row.shape[0])
            ),
            positions=jnp.asarray(all_positions[indices]),
        )

    train_data = _select_frames(train_choice)
    valid_data = _select_frames(valid_choice)

    # Split the training data into consecutive batches of at most batch_size.
    train_batches = [
        {
            "atomic_numbers": train_data["atomic_numbers"][start : start + batch_size],
            "positions": train_data["positions"][start : start + batch_size],
        }
        for start in range(0, num_train, batch_size)
    ]

    return train_batches, valid_data

In [None]:
class Dipole_Moment(nn.Module):
    """Predict one vector per structure (dipole moment) from atoms and positions.

    Each atom's input feature packs its atomic number and its position
    (relative to the structure's first atom) into one e3x feature. The
    features go through a stack of e3x Dense layers, then TensorDense layers,
    each followed by ReLU and LayerNorm. A final TensorDense layer is summed
    over atoms.
    """

    # Output widths of the e3x Dense layers, applied in this order.
    dense_features: tuple = (1024, 512, 256, 128)
    # Output widths of the intermediate e3x TensorDense layers (max_degree=2).
    tensor_features: tuple = (64, 32, 16)

    @nn.compact
    def __call__(self, atomic_numbers, positions):  # Shapes (..., N) and (..., N, 3).
        # 1. Initialize features.
        # Center every structure on its own first atom. The previous
        # `positions -= positions[0, ...]` subtracted the first *structure* of
        # a batch from all structures for (B, N, 3) input. For a single (N, 3)
        # structure both forms are identical.
        positions = positions - positions[..., :1, :]
        x = jnp.concatenate(
            (atomic_numbers[..., None], positions), axis=-1
        )  # Shape (..., N, 4).
        # Add a parity axis and a feature axis, both of size 1. e3x presumably
        # reads the size-4 axis as degrees 0..1: Z in the scalar slot, the
        # position in the three degree-1 slots.
        x = x[..., None, :, None]  # Shape (..., N, 1, 4, 1).

        # NOTE(review): flax LayerNorm normalizes each component separately
        # over the feature axis, which likely breaks rotational equivariance of
        # the non-scalar channels. It is kept because the trained parameters
        # depend on it.

        # Dense layers of decreasing width. Creation order, and therefore
        # Flax's auto-generated submodule names, match the unrolled original.
        for width in self.dense_features:
            x = e3x.nn.Dense(features=width)(x)
            x = e3x.nn.relu(x)
            x = nn.LayerNorm()(x)

        # Additional TensorDense layers (degrees up to 2) for more capacity.
        for width in self.tensor_features:
            x = e3x.nn.TensorDense(features=width, max_degree=2)(x)
            x = e3x.nn.relu(x)
            x = nn.LayerNorm()(x)

        # Final layer: a single feature with degrees up to 1.
        x = e3x.nn.TensorDense(features=1, max_degree=1)(x)
        # Sum per-atom contributions over the atom axis N.
        x = jnp.sum(x, axis=-4)
        # Take parity index 1, degree-1 components 1..3, feature 0.
        # NOTE(review): assumes the output parity axis has size 2 with index 1
        # holding odd-parity (vector) components — confirm against e3x.
        y = x[..., 1, 1:4, 0]

        return y

In [None]:
# Instantiate the module; the trained parameters are passed to model.apply later.
model = Dipole_Moment()

In [13]:
# Predict on the 200k-frame trajectory (huziel).
# The source file stores "R" with a leading singleton axis,
# (1, frames, atoms, 3). Squeeze it away and write a reshaped copy.
import os

raw_filename = "/home/beemoqc2/Documents/e3x_tranfer/SI16VPLUS_E3X_RETRAINED_WB97X_D_TIGHT_TRP_400K_1B_01_POSITION_0.npz"
raw_root, raw_extension = os.path.splitext(raw_filename)
filename = f"{raw_root}_reshape{raw_extension}"

# allow_pickle=True: some entries may be object arrays. Only use this with
# trusted files. Read every array once and close the file afterwards.
with np.load(raw_filename, allow_pickle=True) as raw_dataset:
    raw_arrays = {name: raw_dataset[name] for name in raw_dataset.files}
for name in raw_arrays:
    print(name)
print("R", raw_arrays["R"].shape)

# Drop the leading axis of "R" (np.squeeze raises if it is not of size 1).
# Other entries are kept unchanged and in the same order.
dataset_modified = dict(raw_arrays, R=np.squeeze(raw_arrays["R"], axis=0))

# Save the modified dataset.
np.savez(filename, **dataset_modified)

dataset = np.load(filename)
for name in dataset.keys():
    print(name)
print("R", dataset["R"].shape)
assert dataset["R"].shape == dataset_modified["R"].shape

R
typ
InLine_txt
name_original
bead
number_beads
theory_level
SuperCell
converter_used
R (1, 200001, 17, 3)
R
typ
InLine_txt
name_original
bead
number_beads
theory_level
SuperCell
converter_used
R (200001, 17, 3)


In [None]:
import pickle

# Load the trained Dipole_Moment parameters used by model.apply.
# SECURITY: pickle.load can execute arbitrary code — only open trusted files.
model_save_path = "mode_training_Si16Vplus..DFT.SP-GRD.wB97X-D.tight.Data.5042.R_E_F_D_Q.pkl"
with open(model_save_path, mode="rb") as params_file:
    best_params = pickle.load(params_file)

In [None]:
# Run the trained model over the drawn frames in batches.
num_train = 200000
num_val = 1
batch_size = 1024  # Adjust to the memory available.
filename = "/home/beemoqc2/Documents/e3x_tranfer/SI16VPLUS_E3X_RETRAINED_WB97X_D_TIGHT_TRP_400K_1B_01_POSITION_0_reshape.npz"
key = jax.random.PRNGKey(0)
train_batches, valid_data = prepare_datasets_test(
    filename, key, num_train, num_val, batch_size
)
# NOTE(review): prepare_datasets_test draws frames in random order and holds
# num_val of them out, so data_final is NOT in trajectory order.

# Compile once instead of executing every layer eagerly per batch. The
# smaller final batch triggers one extra compilation.
predict = jax.jit(model.apply)

print(len(train_batches))
i = 0
data_final = []
for i, batch in enumerate(train_batches, start=1):
    Z, positions = batch["atomic_numbers"], batch["positions"]
    # Center each structure on its own first atom. positions[0, ...] would
    # subtract the batch's first structure from every structure.
    positions = positions - positions[..., :1, :]
    data_final.append(predict(best_params, Z, positions))
    if i % 20 == 0 or i == len(train_batches):
        print(i)

# Stack the per-batch predictions into one (num_train, 3) array.
dipoles = jnp.concatenate(data_final) if data_final else jnp.zeros((0, 3))
