In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pymatgen.core import Structure
from pymatgen.io.cif import CifWriter

In [None]:
# Load the electrode dataset from its HDF5 store into a DataFrame.
df_final = pd.read_hdf(path_or_buf="Li_Na_electrodes_uma_m/final_dataset.h5")

In [None]:
# Add a reduced-formula label per row, read from each entry of the
# 'Li_structure' column via its .composition.reduced_formula attribute.
df_final['pretty_formula'] = df_final['Li_structure'].map(
    lambda structure: structure.composition.reduced_formula
)

In [None]:
# Keep only rows whose Li structure has at most 20 atoms, then renumber rows.
df_final = (
    df_final
    .loc[df_final['Li_structure_n_atoms'].le(20)]
    .reset_index(drop=True)
)

# 1. Pick the export columns. Keys are the names in df_final, values are the
#    names used in the exported table; dict order fixes the column order.
export_columns = {
    "pretty_formula": "pretty_formula",
    "host_mp_id": "material_id",
    "Li_structure_n_atoms": "num_atoms",
    "host_energy_per_atom": "host_energy_per_atom",
    "Li_energy_per_atom": "Li_energy_per_atom",
    "Na_energy_per_atom": "Na_energy_per_atom",
    "Li_voltage": "Li_voltage",
    "Na_voltage": "Na_voltage",
    "Li_structure": "cif",
}
# Copy first so later column assignments never write back into df_final.
df_out = df_final[list(export_columns)].copy()
df_out = df_out.rename(columns=export_columns)

# 2. Convert Li_structure (pymatgen Structure) -> CIF string
def structure_to_cif(struct):
    """Serialize a structure to a CIF string, mapping missing values to "".

    Parameters
    ----------
    struct : object or None
        Expected to be a pymatgen ``Structure`` (any object exposing
        ``.to(fmt="cif")`` works). ``None`` and float ``NaN`` are both
        treated as "no structure".

    Returns
    -------
    str
        The multi-line CIF text, or an empty string when ``struct`` is missing.
    """
    if struct is None:
        return ""
    # A missing DataFrame cell may come through as NaN rather than None;
    # calling .to() on a float would raise AttributeError.
    if isinstance(struct, float) and np.isnan(struct):
        return ""
    return struct.to(fmt="cif")  # multi-line CIF string

# Replace each structure object in the 'cif' column with its CIF text.
df_out["cif"] = df_out["cif"].map(structure_to_cif)

# The column already carries the name 'cif' (set when df_out was built),
# so no rename from 'Li_structure' is needed at this point.

# 3. Optional: dump the full, unsplit table to one CSV; pandas quotes the
#    multi-line CIF strings automatically.
# df_out.to_csv("final_data.csv", index=False)

In [None]:
# Shuffle all rows once with a fixed seed so the split is reproducible.
df_shuffled = df_out.sample(frac=1, random_state=42).reset_index(drop=True)

n = len(df_shuffled)
if n == 0:
    # Fail early with a clear message instead of a ZeroDivisionError in the
    # percentage report below.
    raise ValueError("df_out is empty; nothing to split")

# Split boundaries at 80% and 90% of the rows. int() truncates, so both
# boundaries round down and the last slice takes whatever rows remain.
train_end = int(0.8 * n)
test_end = int(0.9 * n)

# Positional slices: first 80% -> train, next 10% -> test, remainder -> val.
df_train = df_shuffled.iloc[:train_end]
df_test = df_shuffled.iloc[train_end:test_end]
df_val = df_shuffled.iloc[test_end:]

# The three splits must partition the shuffled frame exactly.
assert len(df_train) + len(df_test) + len(df_val) == n

# Verify the splits
print(f"Total samples: {n}")
print(f"Train: {len(df_train)} ({len(df_train)/n*100:.1f}%)")
print(f"Test: {len(df_test)} ({len(df_test)/n*100:.1f}%)")
print(f"Val: {len(df_val)} ({len(df_val)/n*100:.1f}%)")

# Save each split to its own CSV. Create the target directory first:
# to_csv does not create missing parent directories, so a fresh checkout
# would otherwise fail here.
output_dir = Path("li_data_20")
output_dir.mkdir(parents=True, exist_ok=True)
df_train.to_csv(output_dir / "train.csv", index=False)
df_test.to_csv(output_dir / "test.csv", index=False)
df_val.to_csv(output_dir / "val.csv", index=False)