In [1]:
import joblib
import numpy as np
import pandas as pd

preprocessing_object = joblib.load("/data/foundation_model/data/preprocessing_objects_20250615.pkl.z")
# print(preprocessing_object)

all_data = (
    pd.read_parquet("/data/foundation_model/data/qc_ac_te_mp_dos_reformat_20250615.pd.parquet").drop(
        index=preprocessing_object["dropped_idx"]
    )
    # .reset_index(drop=True)
)
all_data.head(3)

test_data = (
    all_data[all_data["split"] == "test"]
    #  .reset_index(drop=True)
)
test_data.head(3)


desc_trans = pd.read_parquet("/data/foundation_model/data/qc_ac_te_mp_dos_composition_desc_trans_20250615.pd.parquet")


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [2]:
def swap_train_val_split(split, swap_ratio=0.1, random_seed=None, train_ratio: float = 1.0):
    split = split.copy()
    train_idx = split[split == "train"].index
    val_idx = split[split == "val"].index

    # 先交换
    n_swap = int(min(len(train_idx), len(val_idx)) * swap_ratio)
    if n_swap > 0:
        rng = np.random.default_rng(random_seed)
        swap_train = rng.choice(train_idx, n_swap, replace=False)
        swap_val = rng.choice(val_idx, n_swap, replace=False)
        split.loc[swap_train] = "val"
        split.loc[swap_val] = "train"

    # 再采样train
    train_idx = split[split == "train"].index
    if train_ratio < 1.0:
        rng = np.random.default_rng(random_seed)
        n_train = round(len(train_idx) * train_ratio)
        sampled_train_idx = rng.choice(train_idx, n_train, replace=False)
        # 其余train直接丢弃
        drop_idx = train_idx.difference(sampled_train_idx)
        split.loc[drop_idx] = np.nan

    return split.dropna()


In [3]:
preprocessing_object.keys()

dict_keys(['dropped_idx', 'composition_desc_scaler', 'material_type_label_encoder', 'space_group_label_encoder', 'band_gap_scaler', 'density_scaler', 'efermi_scaler', 'final_energy_per_atom_scaler', 'formation_energy_per_atom_scaler', 'total_magnetization_scaler', 'volume_scaler', 'dielectric_constant_scaler', 'thermal_conductivity_scaler', 'electrical_resistivity_scaler', 'power_factor_scaler', 'seebeck_coefficient_scaler', 'zt_scaler', 'magnetic_susceptibility_scaler', 'dielectric_total_scaler', 'dielectric_ionic_scaler', 'dielectric_electronic_scaler'])

In [4]:
import json
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix

N_TRY = 10
prop_name = "Material type"
# Experiment 1: Use Fourier features
print("training for property:", prop_name)

for ratio in [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3]:
    base_dir = Path(f"logs/classification/{prop_name}/{datetime.now().strftime('%m%d_%H%M')}_r{ratio}")

    mask = all_data[f"{prop_name}"].notnull()
    org_splits = all_data.loc[mask, "split"]
    prop = all_data.loc[mask, f"{prop_name}"]

    for n_try in range(N_TRY):
        # 1. Setup logging directory
        version = f"trial_{n_try + 1}"
        save_dir = base_dir / "predictions" / version
        save_dir.mkdir(parents=True, exist_ok=True)

        splits = swap_train_val_split(org_splits, swap_ratio=0.5, random_seed=None, train_ratio=ratio)
        splits.to_csv(f"{save_dir}/data_split.csv")

        X_train = desc_trans.loc[splits[splits == "train"].index]
        y_train = prop.loc[splits[splits == "train"].index]
        X_test = desc_trans.loc[splits[splits == "test"].index]
        y_test = prop.loc[splits[splits == "test"].index]

        # 编码类别
        le = preprocessing_object["material_type_label_encoder"]
        y_train_enc = le.transform(y_train)
        y_test_enc = le.transform(y_test)

        # 训练随机森林分类器
        clf = RandomForestClassifier(
            n_estimators=300, random_state=n_try, bootstrap=True, max_features="sqrt", n_jobs=60
        )
        clf = clf.fit(X_train, y_train_enc)
        y_pred_enc = clf.predict(X_test)
        y_fit_pred_enc = clf.predict(X_train)
        y_pred, y_true = le.inverse_transform(y_pred_enc), y_test.values
        y_fit_pred, y_fit_true = le.inverse_transform(y_fit_pred_enc), y_train.values

        # 保存预测结果
        results = pd.concat(
            [
                pd.DataFrame({"y_true": y_fit_true, "y_pred": y_fit_pred, "label": "train"}, index=X_train.index),
                pd.DataFrame({"y_true": y_true, "y_pred": y_pred, "label": "test"}, index=X_test.index),
            ]
        )
        results.to_parquet(save_dir / "clf_predictions.parquet")
        results.to_csv(save_dir / "clf_predictions.csv")

        # 保存模型
        joblib.dump(clf, save_dir / "clf_model.pkl.z")

        test_report = classification_report(y_test, y_pred, target_names=le.classes_, digits=4, output_dict=True)
        train_report = classification_report(
            y_fit_true, y_fit_pred, target_names=le.classes_, digits=4, output_dict=True
        )

        # 混淆矩阵
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(21, 8), sharey=True, dpi=130)
        plt.subplots_adjust(wspace=0.05)

        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred, labels=le.classes_)
        hm1 = sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=le.classes_,
            yticklabels=le.classes_,
            ax=ax1,
            annot_kws={"size": 13},
        )
        ax1.set_xlabel("Predicted", fontsize=18)
        ax1.set_ylabel("True", fontsize=18)
        ax1.set_title(f"{prop_name}", fontsize=18)
        ax1.tick_params(axis="both", labelsize=13)
        cbar1 = hm1.collections[0].colorbar
        cbar1.ax.tick_params(labelsize=13)

        # Normalized confusion matrix
        cm_norm = confusion_matrix(y_test, y_pred, labels=le.classes_, normalize="true")
        hm2 = sns.heatmap(
            cm_norm,
            annot=True,
            fmt=".2f",
            cmap="Blues",
            xticklabels=le.classes_,
            yticklabels=le.classes_,
            ax=ax2,
            annot_kws={"size": 13},
        )
        ax2.set_xlabel("Predicted", fontsize=18)
        ax2.set_title(f"{prop_name} (Normalized)", fontsize=18)
        ax2.tick_params(axis="both", labelsize=13)
        cbar2 = hm2.collections[0].colorbar
        cbar2.ax.tick_params(labelsize=13)

        _ = fig.savefig(f"{save_dir}/test_cm.png", bbox_inches="tight")
        plt.cla()
        plt.clf()

        # 分类指标
        print(classification_report(y_true, y_pred, target_names=le.classes_, digits=4))

        # 混淆矩阵
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(21, 8), sharey=True, dpi=150)
        plt.subplots_adjust(wspace=0.05)

        # Confusion matrix
        cm = confusion_matrix(y_fit_true, y_fit_pred, labels=le.classes_)
        hm1 = sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=le.classes_,
            yticklabels=le.classes_,
            ax=ax1,
            annot_kws={"size": 15},
        )
        ax1.set_xlabel("Predicted", fontsize=18)
        ax1.set_ylabel("True", fontsize=18)
        ax1.set_title(f"{prop_name}", fontsize=18)
        ax1.tick_params(axis="both", labelsize=13)
        cbar1 = hm1.collections[0].colorbar
        cbar1.ax.tick_params(labelsize=13)

        # Normalized confusion matrix
        cm_norm = confusion_matrix(y_fit_true, y_fit_pred, labels=le.classes_, normalize="true")
        hm2 = sns.heatmap(
            cm_norm,
            annot=True,
            fmt=".2f",
            cmap="Blues",
            xticklabels=le.classes_,
            yticklabels=le.classes_,
            ax=ax2,
            annot_kws={"size": 13},
        )
        ax2.set_xlabel("Predicted", fontsize=18)
        ax2.set_title(f"{prop_name} (Normalized)", fontsize=18)
        ax2.tick_params(axis="both", labelsize=13)
        cbar2 = hm2.collections[0].colorbar
        cbar2.ax.tick_params(labelsize=13)

        _ = fig.savefig(f"{save_dir}/train_cm.png", bbox_inches="tight")
        plt.cla()
        plt.clf()

        print(classification_report(y_fit_true, y_fit_pred, target_names=le.classes_, digits=4))

        metrics = {"train": train_report, "test": test_report}
        with open(f"{save_dir}/metrics.json", "w") as f:
            json.dump(metrics, f, indent=2)


training for property: Material type
              precision    recall  f1-score   support

         DAC     0.5000    1.0000    0.6667         1
         DQC     0.7500    1.0000    0.8571         3
         IAC     0.8462    0.9167    0.8800        24
         IQC     0.8462    0.7857    0.8148        28
      others     0.9995    0.9992    0.9993      7298

    accuracy                         0.9981      7354
   macro avg     0.7884    0.9403    0.8436      7354
weighted avg     0.9982    0.9981    0.9981      7354

              precision    recall  f1-score   support

         DAC     1.0000    1.0000    1.0000        12
         DQC     1.0000    1.0000    1.0000        10
         IAC     1.0000    0.9759    0.9878        83
         IQC     0.9667    0.9603    0.9635       151
      others     0.9998    0.9999    0.9998     34039

    accuracy                         0.9997     34295
   macro avg     0.9933    0.9872    0.9902     34295
weighted avg     0.9996    0.9997    0.9

  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(21, 8), sharey=True, dpi=130)


              precision    recall  f1-score   support

         DAC     0.5000    1.0000    0.6667         1
         DQC     0.7500    1.0000    0.8571         3
         IAC     0.8400    0.8750    0.8571        24
         IQC     0.8519    0.8214    0.8364        28
      others     0.9996    0.9993    0.9995      7298

    accuracy                         0.9982      7354
   macro avg     0.7883    0.9391    0.8434      7354
weighted avg     0.9983    0.9982    0.9983      7354

              precision    recall  f1-score   support

         DAC     1.0000    1.0000    1.0000        11
         DQC     1.0000    1.0000    1.0000        12
         IAC     1.0000    0.9870    0.9935        77
         IQC     0.9927    0.9784    0.9855       139
      others     0.9999    1.0000    0.9999     30627

    accuracy                         0.9998     30866
   macro avg     0.9985    0.9931    0.9958     30866
weighted avg     0.9998    0.9998    0.9998     30866

              precisio

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


              precision    recall  f1-score   support

         DAC     0.0000    0.0000    0.0000         1
         DQC     0.6000    1.0000    0.7500         3
         IAC     0.8333    0.8333    0.8333        24
         IQC     0.8519    0.8214    0.8364        28
      others     0.9992    0.9992    0.9992      7298

    accuracy                         0.9978      7354
   macro avg     0.6569    0.7308    0.6838      7354
weighted avg     0.9978    0.9978    0.9978      7354

              precision    recall  f1-score   support

         DAC     1.0000    1.0000    1.0000         3
         DQC     1.0000    1.0000    1.0000         9
         IAC     1.0000    1.0000    1.0000        64
         IQC     0.9703    0.9608    0.9655       102
      others     0.9998    0.9999    0.9999     23828

    accuracy                         0.9997     24006
   macro avg     0.9940    0.9921    0.9931     24006
weighted avg     0.9997    0.9997    0.9997     24006

              precisio

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


              precision    recall  f1-score   support

         DAC     0.0000    0.0000    0.0000         1
         DQC     0.5000    0.6667    0.5714         3
         IAC     0.8400    0.8750    0.8571        24
         IQC     0.7857    0.7857    0.7857        28
      others     0.9990    0.9989    0.9990      7298

    accuracy                         0.9974      7354
   macro avg     0.6250    0.6653    0.6427      7354
weighted avg     0.9974    0.9974    0.9974      7354

              precision    recall  f1-score   support

         DAC     1.0000    1.0000    1.0000         3
         DQC     1.0000    1.0000    1.0000         4
         IAC     1.0000    1.0000    1.0000        53
         IQC     0.9770    0.9884    0.9827        86
      others     1.0000    0.9999    0.9999     20431

    accuracy                         0.9999     20577
   macro avg     0.9954    0.9977    0.9965     20577
weighted avg     0.9999    0.9999    0.9999     20577

              precisio

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


              precision    recall  f1-score   support

         DAC     0.0000    0.0000    0.0000         1
         DQC     1.0000    1.0000    1.0000         3
         IAC     0.9048    0.7917    0.8444        24
         IQC     0.8636    0.6786    0.7600        28
      others     0.9982    0.9996    0.9989      7298

    accuracy                         0.9976      7354
   macro avg     0.7533    0.6940    0.7207      7354
weighted avg     0.9973    0.9976    0.9974      7354

              precision    recall  f1-score   support

         DAC     1.0000    1.0000    1.0000         2
         DQC     1.0000    1.0000    1.0000         3
         IAC     1.0000    1.0000    1.0000        29
         IQC     0.9828    0.9828    0.9828        58
      others     0.9999    0.9999    0.9999     13626

    accuracy                         0.9999     13718
   macro avg     0.9965    0.9965    0.9965     13718
weighted avg     0.9999    0.9999    0.9999     13718

              precisio

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


ValueError: Number of classes, 4, does not match size of target_names, 5. Try specifying the labels parameter

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>

<Figure size 2730x1040 with 0 Axes>

<Figure size 3150x1200 with 0 Axes>