# Reproducing the Oxidation State Assignment Model from Jablonka et al.

This notebook reproduces the key methodology and results from:
**"Using collective knowledge to assign oxidation states of metal cations in metal–organic frameworks"**
([DOI: 10.1038/s41557-021-00717-y](https://doi.org/10.1038/s41557-021-00717-y))

## Table of Contents
1. [Environment Setup](#env)
2. [Load Data](#load)
3. [Featurization](#feat)
4. [Model Training](#train)
5. [Evaluation](#eval)
6. [SHAP Analysis](#shap)

## 1. Environment Setup <a name="env"></a>

In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier, ExtraTreesClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.preprocessing import StandardScaler
import shap
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import pickle
import warnings

# Silence library warnings to keep cell output readable (this also hides deprecation notices)
warnings.filterwarnings('ignore')

## 2. Load Data <a name="load"></a>

In [5]:
# Load feature matrix and labels
X = np.load("features_all.npy")
y = np.load("labels_all.npy")

# Load structure names (CSD IDs, optional but useful for later).
# This absolute Windows path is machine-specific; point it at your own copy.
names_path = r"E:\Projects\names_all.pkl"

with open(names_path, "rb") as f:
    names = pickle.load(f)

# Check shapes
print("Features shape:", X.shape)
print("Labels shape:", y.shape)
print("Example CSD IDs:", names[:5])

Features shape: (211723, 116)
Labels shape: (211723, 1)
Example CSD IDs: ['DEWQUS', 'RIPPEN', 'CIQHOC', 'WIMWIB', 'EMETUL']
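
The printed shapes show the labels stored as a column vector of shape `(211723, 1)`, while scikit-learn estimators expect 1-D labels. A small check along the following lines can confirm that features, labels, and names describe the same rows before any modeling. This is only a sketch: `check_alignment` is a hypothetical helper, and the arrays are synthetic stand-ins for the real data.

```python
import numpy as np


def check_alignment(X, y, names):
    """Hypothetical helper: verify features, labels, and names share a row count.

    Returns the labels flattened to shape (n_samples,).
    """
    y_flat = np.ravel(y)
    if not (X.shape[0] == y_flat.shape[0] == len(names)):
        raise ValueError(
            f"Row mismatch: X={X.shape[0]}, y={y_flat.shape[0]}, names={len(names)}"
        )
    return y_flat


# Synthetic stand-ins for the real arrays (values are illustrative only)
X_demo = np.zeros((4, 3))
y_demo = np.array([[2], [3], [2], [1]])  # column vector, like the loaded labels
names_demo = ["A", "B", "C", "D"]

y_flat_demo = check_alignment(X_demo, y_demo, names_demo)

# Class counts give a first look at how balanced the oxidation states are
classes, counts = np.unique(y_flat_demo, return_counts=True)
print(dict(zip(classes.tolist(), counts.tolist())))
```

On the real data, the same call would be `check_alignment(X, y, names)`.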


## 3. Featurization <a name="feat"></a>

> This section assumes precomputed features. If you start from raw CIFs, compute the features first with matminer or `oximachine_featurizer`, then continue from here.

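```python
# X and y are the arrays loaded in section 2, so no DataFrame is needed here
y = np.ravel(y)  # flatten (n_samples, 1) labels to the 1-D shape scikit-learn expects

# Standardize features to zero mean and unit variance.
# Caveat: fitting on all rows before the split lets test-set statistics reach the scaler.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Save for reproducibility
joblib.dump(scaler, 'scaler.pkl')
```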
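
The cell above fits the scaler on every row before the train/test split, so the held-out samples influence the scaling statistics. One common way to avoid this is to split first and wrap the scaler and classifier in a scikit-learn `Pipeline`, which fits the scaler on the training fold only. The sketch below uses synthetic data and a single logistic-regression model as a stand-in; it is not the setup used in the paper.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data (illustrative only)
rng = np.random.default_rng(0)
X_demo = rng.normal(loc=5.0, scale=2.0, size=(200, 4))
y_demo = (X_demo[:, 0] > 5.0).astype(int)

# Split first, so the test rows never reach the scaler during fitting
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, stratify=y_demo, test_size=0.2, random_state=42
)

# The pipeline fits StandardScaler on X_tr only, then trains the classifier
pipe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
pipe.fit(X_tr, y_tr)

# make_pipeline names each step after its lowercased class name
fitted_scaler = pipe.named_steps["standardscaler"]
print("Scaler mean (train only):", fitted_scaler.mean_)
print("Held-out accuracy:", pipe.score(X_te, y_te))
```

A fitted pipeline can be saved with a single `joblib.dump` call, which keeps the scaler and model together.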

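## 4. Model Training <a name="train"></a>

In [None]:
# Hold out 20% of the samples, stratified so each oxidation state keeps its share
y_flat = np.ravel(y)  # accepts both (n_samples,) and (n_samples, 1) label arrays
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y_flat, stratify=y_flat, test_size=0.2, random_state=42
)

# Base learners for the ensemble
models = [
    ('et', ExtraTreesClassifier(n_estimators=200, random_state=0)),
    ('gb', GradientBoostingClassifier(n_estimators=200, random_state=0)),
    ('knn', KNeighborsClassifier(n_neighbors=5)),
    ('lr', LogisticRegression(max_iter=1000))
]

# Soft voting averages the predicted class probabilities of the base learners
voting_clf = VotingClassifier(estimators=models, voting='soft')
voting_clf.fit(X_train, y_train)

# Save model
joblib.dump(voting_clf, 'voting_model.pkl')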
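
With `voting='soft'`, the ensemble averages the class probabilities predicted by each base learner and picks the class with the highest average. The NumPy sketch below illustrates that rule on made-up probabilities; `soft_vote` is a hypothetical helper, not part of scikit-learn.

```python
import numpy as np


def soft_vote(probas, weights=None):
    """Average per-model class probabilities and pick the most probable class.

    probas: array-like of shape (n_models, n_samples, n_classes).
    Returns the averaged probabilities and the winning class index per sample.
    """
    probas = np.asarray(probas, dtype=float)
    avg = np.average(probas, axis=0, weights=weights)
    return avg, np.argmax(avg, axis=1)


# Three hypothetical models, two samples, three classes (values are made up)
probas_demo = [
    [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]],
    [[0.4, 0.5, 0.1], [0.1, 0.3, 0.6]],
    [[0.5, 0.4, 0.1], [0.2, 0.2, 0.6]],
]

avg_demo, winners_demo = soft_vote(probas_demo)
print(avg_demo)
print(winners_demo)
```

For the first sample, the second model alone would choose class 1, but the averaged probabilities favor class 0. In scikit-learn, the winning index is then mapped to a label through the fitted `classes_` attribute.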

## 5. Evaluation <a name="eval"></a>

In [None]:
y_pred = voting_clf.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
print(classification_report(y_test, y_pred))

# Confusion matrix
conf = confusion_matrix(y_test, y_pred)
sns.heatmap(conf, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.show()
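
A heatmap of raw counts is dominated by the most frequent classes, which can hide poor performance on rare oxidation states. Dividing each row by its total shows the per-class recall on the diagonal instead. The sketch below uses a made-up confusion matrix, and `row_normalize` is a hypothetical helper.

```python
import numpy as np


def row_normalize(conf):
    """Divide each row of a confusion matrix by the number of true samples in that class."""
    conf = np.asarray(conf, dtype=float)
    row_sums = conf.sum(axis=1, keepdims=True)
    # Rows with no true samples stay at zero instead of producing NaN
    return np.divide(conf, row_sums, out=np.zeros_like(conf), where=row_sums > 0)


# Made-up counts: rows are true classes, columns are predicted classes
conf_demo = np.array([
    [90, 10, 0],
    [5, 40, 5],
    [0, 2, 8],
])

norm_demo = row_normalize(conf_demo)
per_class_recall = np.diag(norm_demo)
print(per_class_recall)
```

scikit-learn offers the same normalization directly through `confusion_matrix(y_test, y_pred, normalize='true')`; the normalized matrix can be passed to `sns.heatmap` with `fmt='.2f'`.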

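## 6. SHAP Analysis <a name="shap"></a>

In [None]:
# Explain the first fitted base estimator (the extra-trees model) on 100 test samples
explainer = shap.Explainer(voting_clf.estimators_[0], X_train)
shap_values = explainer(X_test[:100])

# The features are a plain array without column names, so generate placeholder names
feature_names = [f"feature_{i}" for i in range(X_test.shape[1])]
shap.summary_plot(shap_values, features=X_test[:100], feature_names=feature_names)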
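
Besides the summary plot, a plain ranking by mean absolute SHAP value can be handy for reporting the most influential features. The sketch below works on a raw NumPy array, assuming it comes from `shap_values.values`, and handles both 2-D output and 3-D multiclass output. `rank_features` is a hypothetical helper, and the numbers are made up.

```python
import numpy as np


def rank_features(shap_array, feature_names, top_k=5):
    """Rank features by mean absolute SHAP value.

    shap_array: shape (n_samples, n_features) or (n_samples, n_features, n_classes).
    Returns a list of (feature_name, importance) tuples, most important first.
    """
    values = np.abs(np.asarray(shap_array, dtype=float))
    if values.ndim == 3:
        values = values.mean(axis=2)  # average the magnitudes over classes
    importance = values.mean(axis=0)  # average over samples
    order = np.argsort(importance)[::-1][:top_k]
    return [(feature_names[i], float(importance[i])) for i in order]


# Made-up SHAP values: two samples, three features
demo_values = np.array([
    [0.1, -0.5, 0.0],
    [-0.3, 0.7, 0.1],
])

ranking = rank_features(demo_values, ["f0", "f1", "f2"], top_k=2)
print(ranking)
```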

## Notes

- To work from raw structures, replace the precomputed-feature placeholders with the original featurizer (`oximachine_featurizer`); `oximachinerunner` is the companion package for running the published model.
- This notebook can be adapted for transfer learning on other materials datasets.
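
The scaler and model saved with `joblib.dump` can be reloaded later to score new feature vectors, provided the new features are computed the same way and passed through the same scaler. The following is a minimal sketch of that round trip; it uses synthetic data, a single logistic-regression model as a stand-in for the voting ensemble, and a temporary directory instead of the notebook's file names.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data (illustrative only)
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(100, 3))
y_demo = (X_demo[:, 0] > 0).astype(int)

# Fit a scaler and a stand-in model, mirroring the notebook's two saved artifacts
scaler_demo = StandardScaler().fit(X_demo)
model_demo = LogisticRegression(max_iter=1000).fit(scaler_demo.transform(X_demo), y_demo)

with tempfile.TemporaryDirectory() as tmp:
    scaler_path = os.path.join(tmp, "scaler.pkl")
    model_path = os.path.join(tmp, "model.pkl")
    joblib.dump(scaler_demo, scaler_path)
    joblib.dump(model_demo, model_path)

    # Later, or in another process: reload both artifacts
    loaded_scaler = joblib.load(scaler_path)
    loaded_model = joblib.load(model_path)

# New samples must go through the same scaler before prediction
X_new = rng.normal(size=(5, 3))
pred_new = loaded_model.predict(loaded_scaler.transform(X_new))
print(pred_new)
```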