In [34]:
import datasets

# MPtrj training trajectories, all splits merged; columns are exposed as NumPy arrays.
mptrj = datasets.load_dataset(path="nimashoghi/mptrj", split="all")
mptrj.set_format(type="numpy")
mptrj

Dataset({
    features: ['numbers', 'positions', 'forces', 'cell', 'pbc', 'energy', 'stress', 'e_per_atom_relaxed', 'mp_id', 'energy_per_atom', 'ef_per_atom_relaxed', 'corrected_total_energy', 'ef_per_atom', 'task_id', 'calc_id', 'ionic_step', 'filename', 'extxyz_id', 'num_atoms', 'corrected_total_energy_relaxed', 'energy_referenced', 'corrected_total_energy_referenced', 'corrected_total_energy_relaxed_referenced', 'composition'],
    num_rows: 1580395
})

In [35]:
# WBM benchmark structures, all splits merged; columns are exposed as NumPy arrays.
wbm = datasets.load_dataset(path="nimashoghi/wbm", split="all")
wbm.set_format(type="numpy")
wbm

Dataset({
    features: ['formula', 'n_sites', 'volume', 'uncorrected_energy', 'e_form_per_atom_wbm', 'e_above_hull_wbm', 'bandgap_pbe', 'wyckoff_spglib_initial_structure', 'uncorrected_energy_from_cse', 'e_correction_per_atom_mp2020', 'e_correction_per_atom_mp_legacy', 'e_form_per_atom_uncorrected', 'e_form_per_atom_mp2020_corrected', 'e_above_hull_mp2020_corrected_ppd_mp', 'site_stats_fingerprint_init_final_norm_diff', 'wyckoff_spglib', 'unique_prototype', 'formula_from_cse', 'initial_structure', 'id', 'material_id', 'frac_pos', 'cart_pos', 'pos', 'cell', 'num_atoms', 'atomic_numbers', 'composition'],
    num_rows: 256963
})

In [36]:
import pickle

import numpy as np
from tqdm.auto import tqdm


def get_problematic_samples_compositions(
    path: str = "/net/csefiles/coc-fung-cluster/nima/shared/repositories/jmp-peft/notebooks/problematic_samples.relaxdata.pkl",
    minlength: int = 120,
) -> np.ndarray:
    """Load the problematic-sample pickle and return one element-count row per sample.

    Args:
        path: Pickle file holding a sequence of records; each record must have
            ``record["atoms"][0]["numbers"]`` (an integer array of atomic numbers).
            Defaults to the original shared-cluster location.
        minlength: Width of each count row. Index ``Z`` of a row holds the number
            of atoms with atomic number ``Z`` (via ``np.bincount``).

    Returns:
        Integer array of shape ``(n_samples, minlength)``; ``(0, minlength)`` when
        the pickle holds no records.

    Raises:
        ValueError: if some atomic number is ``>= minlength``. ``np.bincount``
            then yields a longer row, and the rows cannot be stacked.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
    with open(path, "rb") as f:
        data = pickle.load(f)

    # Only the first entry of each record's "atoms" list is used
    # (presumably the initial structure of the relaxation — confirm).
    rows = [
        np.bincount(record["atoms"][0]["numbers"], minlength=minlength)
        for record in tqdm(data)
    ]
    if not rows:
        # np.stack refuses an empty list; return an empty, correctly-shaped matrix.
        return np.zeros((0, minlength), dtype=np.intp)
    return np.stack(rows, axis=0)


# Element-count matrix (one row per problematic sample); report its shape.
problematic_compositions = get_problematic_samples_compositions()
print(tuple(problematic_compositions.shape))

  0%|          | 0/75 [00:00<?, ?it/s]

(75, 120)


In [37]:
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import trange

# Apply seaborn's default look (spelled out explicitly) to all matplotlib figures.
sns.set_theme(context="notebook", style="darkgrid", palette="deep")


def composition_to_idxs(composition: np.ndarray):
    """Expand a count vector into the flat list of indices it counts.

    Index ``k`` appears ``composition[k]`` times, e.g. ``[0, 2, 1] -> [1, 1, 2]``.
    """
    element_indices = np.arange(len(composition))
    return element_indices.repeat(composition)


def plot_compositions(compositions: list[tuple[str, np.ndarray]]):
    """Overlay per-source element-frequency histograms on one matplotlib axis.

    Args:
        compositions: ``(name, composition)`` pairs. Each ``composition`` holds
            one count vector per sample; index ``Z`` of a vector is the number of
            atoms with atomic number ``Z``. Rows may be a 2-D array or a
            sequence of (possibly different-length) 1-D vectors.

    Each source is drawn with ``stat="density"`` so sources of very different
    size are comparable. ``discrete=True`` gives unit-width bins centred on each
    atomic number. Without it, seaborn picks automatic bin edges per call, which
    can merge or split integer values and differ between the overlaid sources.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    for name, composition in compositions:
        # Total atoms per element over all samples. Histogramming these totals
        # as weights is equivalent to histogramming the exploded per-atom array,
        # without materializing tens of millions of entries in a Python loop.
        try:
            matrix = np.asarray(composition)
        except ValueError:  # ragged rows: recent NumPy refuses to build an array
            matrix = None
        if matrix is not None and matrix.ndim == 2 and matrix.dtype != object:
            element_totals = matrix.sum(axis=0)
        else:
            rows = [np.asarray(row) for row in composition]
            element_totals = np.zeros(max((len(row) for row in rows), default=0))
            for row in rows:
                element_totals[: len(row)] += row
        print(f"{name}: {int(element_totals.sum()):,} atoms")

        present = np.flatnonzero(element_totals)
        if present.size == 0:
            continue  # no atoms in this source; nothing to draw
        sns.histplot(
            x=present,
            weights=element_totals[present],
            discrete=True,
            stat="density",
            ax=ax,
            label=name,
        )

    ax.set_xlabel("Atomic number")
    ax.set_ylabel("Density")
    ax.legend()

    fig.tight_layout()


# plot_compositions(
#     [
#         ("mptrj", mptrj["composition"]),
#         ("wbm", wbm["composition"]),
#         ("problematic", problematic_compositions),
#     ]
# )

In [39]:
import numpy as np
import plotly.graph_objects as go


def composition_to_idxs(composition: np.ndarray):
    """Return each index ``k`` repeated ``composition[k]`` times, in order."""
    n_slots = len(composition)
    return np.repeat(np.arange(n_slots), repeats=composition)


def _element_totals(composition) -> np.ndarray:
    """Sum per-sample count vectors into one float64 total per index.

    Accepts a 2-D numeric array (one row per sample) or any sequence of 1-D
    count vectors, possibly of different lengths (short rows act as zero-padded).
    The result has the length of the widest row.
    """
    try:
        matrix = np.asarray(composition)
    except ValueError:  # ragged rows: recent NumPy refuses to build an array
        matrix = None
    if matrix is not None and matrix.ndim == 2 and matrix.dtype != object:
        return matrix.sum(axis=0).astype(np.float64)
    rows = [np.asarray(row) for row in composition]
    totals = np.zeros(max((len(row) for row in rows), default=0), dtype=np.float64)
    for row in rows:
        totals[: len(row)] += row
    return totals


def plot_compositions(
    compositions: list[tuple[str, np.ndarray]],
    element_labels=None,
    max_atomic_number: int | None = None,
):
    """Plot per-source element frequencies as grouped plotly bars.

    Args:
        compositions: ``(name, composition)`` pairs. Each ``composition`` holds
            one count vector per sample; index ``Z`` of a vector is the number of
            atoms with atomic number ``Z`` (as produced by ``np.bincount``).
        element_labels: Optional symbols where ``element_labels[k]`` names the
            element with atomic number ``k + 1`` (e.g. ``["H", "He", ...]``).
        max_atomic_number: Exclusive upper bound. Bars are drawn for indices
            ``0 .. max_atomic_number - 1``. Defaults to the widest count vector
            across all sources.

    Each source is normalized to sum to 1 over the plotted range. This matches
    ``np.histogram(..., density=True)`` with unit-width bins, so atoms with
    ``Z >= max_atomic_number`` are excluded from the normalization.
    """
    fig = go.Figure()

    totals_by_source = [(name, _element_totals(comp)) for name, comp in compositions]
    if max_atomic_number is None:
        max_atomic_number = max((len(totals) for _, totals in totals_by_source), default=0)

    atomic_numbers = np.arange(max_atomic_number)
    for name, totals in totals_by_source:
        print(f"{name}: {int(totals.sum()):,} atoms")

        counts = np.zeros(max_atomic_number)
        n_kept = min(max_atomic_number, len(totals))
        counts[:n_kept] = totals[:n_kept]
        in_range = counts.sum()
        density = counts / in_range if in_range > 0 else counts

        # Bars are centred on their atomic number (not on a bin edge).
        fig.add_trace(go.Bar(x=atomic_numbers, y=density, name=name))

    # Separators sit halfway between neighbouring elements.
    for z in range(max_atomic_number - 1):
        fig.add_vline(x=z + 0.5, line_width=1, line_dash="dash", line_color="gray")

    # Put each element symbol directly under its bar (label k -> Z = k + 1).
    if element_labels:
        ticks = [
            (z, symbol)
            for z, symbol in enumerate(element_labels, start=1)
            if z < max_atomic_number
        ]
        fig.update_layout(
            xaxis=dict(
                tickmode="array",
                tickvals=[z for z, _ in ticks],
                ticktext=[symbol for _, symbol in ticks],
                tickangle=45,
            )
        )

    fig.update_layout(
        xaxis_title="Atomic number / Element",
        yaxis_title="Density",
        barmode="group",
        legend_title="Compositions",
        width=1200,
        height=600,
        bargap=0.1,
        bargroupgap=0.05,
    )

    fig.show()


# Get element symbols from the package `pymatgen`
from pymatgen.core import Element

# Symbols for Z = 1..118, so element_labels[k] is the element with Z = k + 1.
element_labels = [Element.from_Z(z).symbol for z in range(1, 119)]


plot_compositions(
    [
        ("mptrj", mptrj["composition"]),
        ("wbm", wbm["composition"]),
        ("problematic", problematic_compositions),
    ],
    max_atomic_number=95,
    element_labels=element_labels,
)

  0%|          | 0/1580395 [00:00<?, ?it/s]

(49295660,)


  0%|          | 0/256963 [00:00<?, ?it/s]

(1994133,)


  0%|          | 0/75 [00:00<?, ?it/s]

(509,)
