In this notebook we try to reduce the number of parameters of our model. Currently, the model overfits heavily on the nodes.

In [9]:
# Put the local graph2mat4abn checkout first on the import path so that the
# `graph2mat4abn.*` imports below resolve to it.
# NOTE(review): this path ('/home/alapena/...') differs from the one given to
# os.chdir below ('/home/ICN2/alapena/...'); presumably one of them is stale -
# confirm which checkout is intended.
import sys
sys.path.insert(0, '/home/alapena/GitHub/graph2mat4abn')
import os
os.chdir('/home/ICN2/alapena/GitHub/graph2mat4abn') # Change to the root directory of the project; relative paths used later (e.g. "results/...") are resolved from here

from graph2mat4abn.tools.import_utils import load_config, get_object_from_module
from graph2mat4abn.tools.tools import get_basis_from_structures_paths, get_kwargs
from graph2mat4abn.tools.scripts_utils import get_model_dataset
from pathlib import Path
from e3nn import o3
from mace.modules import MACE, RealAgnosticResidualInteractionBlock
from graph2mat.models import MatrixMACE
from graph2mat.bindings.e3nn import E3nnGraph2Mat
import torch
import warnings
from graph2mat import BasisTableWithEdges

# Silence two specific warnings; `message` is a regular expression that
# warnings.filterwarnings matches against the start of the warning text.
warnings.filterwarnings("ignore", message="The TorchScript type system doesn't support")
warnings.filterwarnings("ignore", message=".*is not a known matrix type key.*")

In [20]:
# Baseline: the current model, rebuilt from the config stored in its results directory.
model_dir = Path("results/h_crystalls_11")
config = load_config(model_dir / "config.yaml")

# The basis table is required to initialize the model; it is built from the
# train and validation paths taken together.
train_paths, val_paths = get_model_dataset(model_dir, verbose=False)
paths = train_paths + val_paths
basis = get_basis_from_structures_paths(
    paths,
    verbose=True,
    num_unique_z=config["dataset"].get("num_unique_z", None),
)
table = BasisTableWithEdges(basis)

print("Initializing model...")

# --- Environment descriptor (MACE) ---
env_config = config["environment_representation"]
num_interactions = env_config["num_interactions"]
hidden_irreps = o3.Irreps(env_config["hidden_irreps"])

# Options forwarded to MACE exactly as they are written in the config.
mace_passthrough = {
    key: env_config[key]
    for key in (
        "r_max",
        "num_bessel",
        "num_polynomial_cutoff",
        "max_ell",
        "num_elements",
        "avg_num_neighbors",
        "atomic_numbers",
        "correlation",
    )
}

# Note: constructing MACE is somewhat time-consuming.
mace_descriptor = MACE(
    interaction_cls=RealAgnosticResidualInteractionBlock,
    interaction_cls_first=RealAgnosticResidualInteractionBlock,
    num_interactions=num_interactions,
    hidden_irreps=hidden_irreps,
    MLP_irreps=o3.Irreps(env_config["MLP_irreps"]),
    atomic_energies=torch.tensor(env_config["atomic_energies"]),
    gate=get_object_from_module(env_config["gate"], "torch.nn.functional"),
    **mace_passthrough,
)

# --- Full model: MACE glued to E3nnGraph2Mat ---
model_config = config["model"]
g2m_modules = 'graph2mat.bindings.e3nn.modules'

# Every configurable component contributes two arguments: the class looked up
# by name in the graph2mat e3nn modules, and its keyword arguments.
component_args = {}
for component in (
    "preprocessing_edges",
    "preprocessing_nodes",
    "node_operation",
    "edge_operation",
):
    component_name = model_config[component]
    component_args[component] = get_object_from_module(component_name, g2m_modules)
    component_args[f"{component}_kwargs"] = get_kwargs(component_name, config)

model = MatrixMACE(
    mace=mace_descriptor,
    readout_per_interaction=model_config.get("readout_per_interaction", False),
    graph2mat_cls=E3nnGraph2Mat,
    # Readout-specific arguments
    unique_basis=table,
    symmetric=True,
    # Preprocessing functions and node/edge operations
    **component_args,
)

Basis computation.
Number of structures to look on: 399
Looking for unique atoms in each structure...


1it [00:00, 72.51it/s]

Found enough basis points. Breaking the search...
Found enough basis points. Breaking the search...
Found the following atomic numbers: [7, 5]
Corresponding path indices: [0, 0]
Basis with 2 elements built!

Basis for atom 0.
	Atom type: 5
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [3.02420918 2.02341372 3.73961942 3.73961942 3.73961942 2.51253945
 2.51253945 2.51253945 3.73961942 3.73961942 3.73961942 3.73961942
 3.73961942]

Basis for atom 1.
	Atom type: 7
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [2.25704422 1.4271749  2.78012609 2.78012609 2.78012609 1.75309697
 1.75309697 1.75309697 2.78012609 2.78012609 2.78012609 2.78012609
 2.78012609]
Initializing model...



  torch.tensor(atomic_energies, dtype=torch.get_default_dtype()),


In [26]:
from prettytable import PrettyTable

def count_parameters(model, verbose=True):
    """Count the trainable parameters of a model.

    Only parameters with ``requires_grad`` set are counted.

    Parameters
    ----------
    model
        Object exposing ``named_parameters()``, yielding ``(name, parameter)``
        pairs where each parameter has ``requires_grad`` and ``numel()``.
    verbose : bool, optional
        If True, also print a table with one row per trainable parameter.
        The line with the total is printed regardless of this flag.

    Returns
    -------
    int
        Total number of trainable parameter elements.
    """
    trainable = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]
    total_params = sum(count for _, count in trainable)

    if verbose:
        # Build the table only when it is going to be shown, so the
        # non-verbose path does no table work and never touches PrettyTable.
        table = PrettyTable(["Modules", "Parameters"])
        for name, count in trainable:
            table.add_row([name, count])
        print(table)

    print(f"Total Trainable Params: {total_params}")
    return total_params
    
# Trainable-parameter counts: the full baseline model and its MACE descriptor.
p_tot = count_parameters(model, verbose=False)
p_mace = count_parameters(mace_descriptor, verbose=False)

# Whatever is not counted in MACE is attributed to the Graph2Mat part.
p_g2m = p_tot - p_mace
print("Parameters of G2M:", p_g2m)
None  # the cell ends on None, so no value is displayed as its output

Total Trainable Params: 311098
Total Trainable Params: 45114
Parameters of G2M: 265984


Most of the parameters are inside G2M (the Graph2Mat part of the model). Let's try to reduce the parameters associated with the nodes, starting by removing the node preprocessing function $p_n$:

In [22]:
# Variant 1: same config as the current model, but built without the node
# preprocessing function.
model_dir = Path("results/h_crystalls_11")
config = load_config(model_dir / "config.yaml")

# The basis table is required to initialize the model; it is built from the
# train and validation paths taken together.
train_paths, val_paths = get_model_dataset(model_dir, verbose=False)
paths = train_paths + val_paths
basis = get_basis_from_structures_paths(
    paths,
    verbose=True,
    num_unique_z=config["dataset"].get("num_unique_z", None),
)
table = BasisTableWithEdges(basis)

print("Initializing model...")

# --- Environment descriptor (MACE) ---
env_config = config["environment_representation"]
num_interactions = env_config["num_interactions"]
hidden_irreps = o3.Irreps(env_config["hidden_irreps"])

# Options forwarded to MACE exactly as they are written in the config.
mace_passthrough = {
    key: env_config[key]
    for key in (
        "r_max",
        "num_bessel",
        "num_polynomial_cutoff",
        "max_ell",
        "num_elements",
        "avg_num_neighbors",
        "atomic_numbers",
        "correlation",
    )
}

# Note: constructing MACE is somewhat time-consuming.
mace_descriptor = MACE(
    interaction_cls=RealAgnosticResidualInteractionBlock,
    interaction_cls_first=RealAgnosticResidualInteractionBlock,
    num_interactions=num_interactions,
    hidden_irreps=hidden_irreps,
    MLP_irreps=o3.Irreps(env_config["MLP_irreps"]),
    atomic_energies=torch.tensor(env_config["atomic_energies"]),
    gate=get_object_from_module(env_config["gate"], "torch.nn.functional"),
    **mace_passthrough,
)

# --- Full model: MACE glued to E3nnGraph2Mat ---
model_config = config["model"]
g2m_modules = 'graph2mat.bindings.e3nn.modules'

# Every configurable component contributes two arguments: the class looked up
# by name in the graph2mat e3nn modules, and its keyword arguments.
# "preprocessing_nodes" is intentionally left out of this list: neither
# preprocessing_nodes nor preprocessing_nodes_kwargs is passed to MatrixMACE,
# so its own default for node preprocessing applies.
component_args = {}
for component in (
    "preprocessing_edges",
    "node_operation",
    "edge_operation",
):
    component_name = model_config[component]
    component_args[component] = get_object_from_module(component_name, g2m_modules)
    component_args[f"{component}_kwargs"] = get_kwargs(component_name, config)

model_1 = MatrixMACE(
    mace=mace_descriptor,
    readout_per_interaction=model_config.get("readout_per_interaction", False),
    graph2mat_cls=E3nnGraph2Mat,
    # Readout-specific arguments
    unique_basis=table,
    symmetric=True,
    # Edge preprocessing and node/edge operations
    **component_args,
)

Basis computation.
Number of structures to look on: 399
Looking for unique atoms in each structure...


1it [00:00, 76.14it/s]

Found enough basis points. Breaking the search...
Found enough basis points. Breaking the search...
Found the following atomic numbers: [7, 5]
Corresponding path indices: [0, 0]
Basis with 2 elements built!

Basis for atom 0.
	Atom type: 5
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [3.02420918 2.02341372 3.73961942 3.73961942 3.73961942 2.51253945
 2.51253945 2.51253945 3.73961942 3.73961942 3.73961942 3.73961942
 3.73961942]

Basis for atom 1.
	Atom type: 7
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [2.25704422 1.4271749  2.78012609 2.78012609 2.78012609 1.75309697
 1.75309697 1.75309697 2.78012609 2.78012609 2.78012609 2.78012609
 2.78012609]
Initializing model...



  torch.tensor(atomic_energies, dtype=torch.get_default_dtype()),


In [29]:
# Trainable-parameter counts for the variant built without node preprocessing.
p_tot_1 = count_parameters(model_1, verbose=False)
p_mace = count_parameters(mace_descriptor, verbose=False)

# Graph2Mat share of each model: total minus the MACE descriptor.
# p_tot is the total of the baseline model, computed in an earlier cell.
p_g2m_1 = p_tot_1 - p_mace
p_g2m_before = p_tot - p_mace
print("Parameters of G2M:", p_g2m_1)
print("Compared to before:", p_g2m_before)

Total Trainable Params: 280766
Total Trainable Params: 45114
Parameters of G2M: 235652
Compared to before: 265984
