## Load the data

In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
# from sklearn.metrics import mean_squared_error
import numpy as np


In [None]:
import os

# Locate the OQMD CSV inside the "data" folder of the current working directory.
data_name = 'oqmd.csv'
cur_dir = os.getcwd()
data_dir = os.path.join(cur_dir, 'data')
data_path = os.path.join(data_dir, data_name)

# Read the whole table; later cells access it through `df`.
df = pd.read_csv(data_path)

In [None]:
# Preview the first rows of the table, framed by two identical banner lines.
info_banner = '-' * 25 + 'Data Info' + '-' * 25
print(info_banner, df.head(), info_banner, sep='\n')

-------------------------Data Info-------------------------
    formula  energy_per_atom  formation_energy_per_atom  band_gap  \
0  ZrZnNiMo        -6.399036                   0.157939       0.0   
1   DySc2Ir        -6.795189                  -0.266899       0.0   
2       YZr        -7.445319                   0.060478       0.0   
3   CrMoAu2        -6.218335                   0.504200       0.0   
4      Ge3O        -4.382844                   0.215750       0.0   

   magnetization_per_atom  
0                0.452473  
1                0.212367  
2                0.002801  
3                0.715628  
4                0.000000  
-------------------------Data Info-------------------------


In [None]:
# Number of rows in the raw table; a later cell uses it to report dropped rows.
data_size = len(df)
print('Data size:', data_size)

Data size: 561888


# Fetch code from the GitHub repository

In [None]:
!git clone https://github.com/Eric-xin/Chem_ML.git

fatal: destination path 'Chem_ML' already exists and is not an empty directory.


In [None]:
# Copy codes to current directory
!cp -r ./Chem_ML/generator ./
!cp -r ./Chem_ML/data ./
!cp -r ./Chem_ML/model ./

In [None]:
# Replace the removed NumPy alias "np.float" with the builtin "float".
# This is done in Python instead of `sed -i 's/np\.float/float/g'` because:
#   * the word boundaries leave np.float32 / np.float64 / np.float_ / np.floating
#     untouched (the plain sed pattern would turn them into bare, undefined names);
#   * `sed -i` without a backup suffix only works with GNU sed.
import re

lookup_path = os.path.join('.', 'data', 'utils', 'look_up_data.py')
with open(lookup_path, 'r', encoding='utf-8') as handle:
    lookup_source = handle.read()

patched_source = re.sub(r'\bnp\.float\b', 'float', lookup_source)

# Only rewrite the file when something actually changed (keeps re-runs cheap).
if patched_source != lookup_source:
    with open(lookup_path, 'w', encoding='utf-8') as handle:
        handle.write(patched_source)

## Initialize Attribute Generators

In [None]:
# Set up the composition-based attribute generators.
# Features built downstream: stoichiometric, ionicity and elemental-property
# attributes. The charge-dependent generator is instantiated here as well, but
# the feature-building cells currently leave it commented out.
# Targets used downstream: energy_per_atom, formation_energy_per_atom and
# magnetization_per_atom.

# Import the generator classes by name rather than with `from generator import *`,
# so it is obvious where each name comes from and nothing else leaks into the namespace.
from generator import (
    ChargeDependentAttributeGenerator,
    ElementalPropertyAttributeGenerator,
    IonicityAttributeGenerator,
    StoichiometricAttributeGenerator,
)

# One generator instance per attribute family
stoichiometric_generator = StoichiometricAttributeGenerator()
ionicity_generator = IonicityAttributeGenerator()
elemental_generator = ElementalPropertyAttributeGenerator()
charge_generator = ChargeDependentAttributeGenerator()

## Prepare features and targets

In [None]:
from data.utils import CompositionEntry
from tqdm import tqdm

# Columns of `df` that are stacked, in this order, into every target row.
target_columns = ['energy_per_atom', 'formation_energy_per_atom', 'magnetization_per_atom']

# Filter out rows where 'formula' is NaN.
# NOTE(review): pandas' CSV reader treats text such as "NaN" or "NA" as missing by
# default, so a formula spelled that way would be dropped here too - confirm that
# the dropped rows are genuine gaps.
df_filtered = df.dropna(subset=['formula'])

# Report how many rows were dropped
dropped_rows = data_size - df_filtered.shape[0]
print('-'*25 + 'Dropped Rows' + '-'*25)
print('Dropped rows:', dropped_rows)
print('-'*25 + 'Dropped Rows' + '-'*25)

print('-'*25 + 'Start: data processing' + '-'*25)
features = []
targets = []

# Pull the needed columns out once. Indexing the frame with .iloc inside the loop
# builds a fresh Series for every row, and row-wise extraction yields object arrays;
# taken column-wise, the targets keep the numeric dtype of their columns.
formulas = df_filtered['formula'].tolist()
target_rows = df_filtered[target_columns].to_numpy()

for formula, target_row in tqdm(zip(formulas, target_rows), total=len(formulas), desc="Processing Formulas"):
    # A missing formula shows up as a float NaN. The previous check
    # (`formula is float`) compared the value with the type object itself
    # and therefore never skipped anything.
    if isinstance(formula, float):
        continue
    entry = CompositionEntry(formula)
    stoichiometric = np.array(stoichiometric_generator.generate_features([entry])).flatten()
    ionicity = np.array(ionicity_generator.generate_features([entry])).flatten()
    elemental = np.array(elemental_generator.generate_features([entry])).flatten()
    # charge = np.array(charge_generator.generate_features([entry])).flatten()

    # Skip the entry entirely if any attribute group contains NaN
    if np.isnan(stoichiometric).any() or np.isnan(ionicity).any() or np.isnan(elemental).any():
        print('NaN values found in entry:', formula)
        continue

    # feature = np.concatenate([stoichiometric, ionicity, elemental, charge])
    feature = np.concatenate([stoichiometric, ionicity, elemental])
    # Append both together so features[k] and targets[k] always describe the same entry
    targets.append(target_row)
    features.append(feature)

print('-'*25 + 'Completed: data processing' + '-'*25)
features = np.array(features)
targets = np.array(targets)


-------------------------Dropped Rows-------------------------
Dropped rows: 6
-------------------------Dropped Rows-------------------------
-------------------------Start: data processing-------------------------


	Electronegativity: Ar
Processing Formulas:   2%|▏         | 12563/561882 [01:34<54:24, 168.26it/s]

NaN values found in entry: Ar


	MeltingT: He
	Electronegativity: He
Processing Formulas:   6%|▌         | 31676/561882 [03:49<57:42, 153.14it/s]

NaN values found in entry: He


	Electronegativity: Ne
Processing Formulas:   7%|▋         | 36747/561882 [04:25<1:41:32, 86.20it/s]

NaN values found in entry: Ne


	Electronegativity: Ar
Processing Formulas:   8%|▊         | 47044/561882 [05:42<53:02, 161.80it/s]

NaN values found in entry: Ar


	Electronegativity: Ar
Processing Formulas:  10%|▉         | 55801/561882 [06:45<56:43, 148.71it/s]

NaN values found in entry: Ar


	Electronegativity: Ne
Processing Formulas:  10%|█         | 57066/561882 [06:54<53:55, 156.00it/s]

NaN values found in entry: Ne


	Electronegativity: Ar
Processing Formulas:  10%|█         | 58269/561882 [07:03<56:47, 147.81it/s]

NaN values found in entry: Ar


	MeltingT: He
	Electronegativity: He
Processing Formulas:  13%|█▎        | 72066/561882 [08:44<51:36, 158.20it/s]

NaN values found in entry: He


	Electronegativity: Ne
Processing Formulas:  16%|█▌        | 87562/561882 [10:30<49:29, 159.74it/s]

NaN values found in entry: Ne


	Electronegativity: Ar
Processing Formulas:  17%|█▋        | 95970/561882 [11:29<49:05, 158.15it/s]

NaN values found in entry: Ar


	MeltingT: He
	Electronegativity: He
Processing Formulas:  20%|██        | 113211/561882 [13:29<46:52, 159.55it/s]

NaN values found in entry: He


	Electronegativity: Ne
Processing Formulas:  21%|██▏       | 120419/561882 [14:19<42:55, 171.43it/s]

NaN values found in entry: Ne


	Electronegativity: Ne
Processing Formulas:  23%|██▎       | 128077/561882 [15:10<1:13:02, 98.98it/s]

NaN values found in entry: Ne


	Electronegativity: Ar
Processing Formulas:  25%|██▍       | 138718/561882 [16:24<43:01, 163.94it/s]

NaN values found in entry: Ar


	MeltingT: He
	Electronegativity: He
Processing Formulas:  25%|██▍       | 140011/561882 [16:33<45:10, 155.67it/s]

NaN values found in entry: He


	MeltingT: He
	Electronegativity: He
Processing Formulas:  27%|██▋       | 152452/561882 [18:04<45:53, 148.72it/s]

NaN values found in entry: He


	MeltingT: He
	Electronegativity: He
Processing Formulas:  36%|███▌      | 202094/561882 [23:56<54:13, 110.58it/s]

NaN values found in entry: He


	Electronegativity: Ar
Processing Formulas:  38%|███▊      | 210716/561882 [24:57<58:45, 99.61it/s]  

NaN values found in entry: Ar


	Electronegativity: Ne
Processing Formulas:  39%|███▉      | 221785/561882 [26:14<33:30, 169.20it/s]

NaN values found in entry: Ne


	Electronegativity: Ar
Processing Formulas:  43%|████▎     | 241473/561882 [28:33<35:09, 151.86it/s]

NaN values found in entry: Ar


	Electronegativity: Ne
Processing Formulas:  44%|████▍     | 246470/561882 [29:09<38:28, 136.65it/s]

NaN values found in entry: Ne


	Electronegativity: Ne
Processing Formulas:  45%|████▌     | 255003/561882 [30:10<32:11, 158.87it/s]

NaN values found in entry: Ne


	Electronegativity: Ne
Processing Formulas:  49%|████▊     | 273639/561882 [32:23<31:41, 151.55it/s]

NaN values found in entry: Ne


	Electronegativity: Ar
Processing Formulas:  52%|█████▏    | 289460/561882 [34:16<27:35, 164.58it/s]

NaN values found in entry: Ar


	Electronegativity: Ar
Processing Formulas:  52%|█████▏    | 289596/561882 [34:17<30:12, 150.25it/s]

NaN values found in entry: Ar


	Electronegativity: Ar
Processing Formulas:  52%|█████▏    | 292281/561882 [34:36<27:54, 160.98it/s]

NaN values found in entry: Ar


	MeltingT: He
	Electronegativity: He
Processing Formulas:  52%|█████▏    | 292949/561882 [34:40<27:11, 164.80it/s]

NaN values found in entry: He


	MeltingT: He
	Electronegativity: He
Processing Formulas:  56%|█████▌    | 315837/561882 [37:20<25:45, 159.25it/s]

NaN values found in entry: He


	MeltingT: He
	Electronegativity: He
Processing Formulas:  58%|█████▊    | 325047/561882 [38:24<28:47, 137.08it/s]

NaN values found in entry: He


	Electronegativity: Ne
Processing Formulas:  58%|█████▊    | 327527/561882 [38:42<24:19, 160.62it/s]

NaN values found in entry: Ne


	MeltingT: He
	Electronegativity: He
Processing Formulas:  60%|█████▉    | 335011/561882 [39:36<26:33, 142.41it/s]

NaN values found in entry: He


	Electronegativity: Ar
Processing Formulas:  63%|██████▎   | 352770/561882 [41:42<20:25, 170.66it/s]

NaN values found in entry: Ar


	Electronegativity: Ne
Processing Formulas:  63%|██████▎   | 354864/561882 [41:56<22:04, 156.27it/s]

NaN values found in entry: Ne


	Electronegativity: Ar
Processing Formulas:  65%|██████▌   | 365938/561882 [43:15<20:06, 162.43it/s]

NaN values found in entry: Ar


	MeltingT: He
	Electronegativity: He
Processing Formulas:  69%|██████▊   | 385371/561882 [45:27<19:47, 148.64it/s]

NaN values found in entry: He


	MeltingT: He
	Electronegativity: He
Processing Formulas:  70%|██████▉   | 392789/561882 [46:18<20:04, 140.38it/s]

NaN values found in entry: He


	Electronegativity: Ne
Processing Formulas:  70%|██████▉   | 393031/561882 [46:20<17:44, 158.63it/s]

NaN values found in entry: Ne


	Electronegativity: Ar
Processing Formulas:  74%|███████▎  | 414238/561882 [48:46<25:41, 95.77it/s]

NaN values found in entry: Ar


	MeltingT: He
	Electronegativity: He
Processing Formulas:  76%|███████▌  | 428203/561882 [50:23<13:11, 168.97it/s]

NaN values found in entry: He


	Electronegativity: Ar
Processing Formulas:  76%|███████▋  | 428739/561882 [50:26<13:56, 159.10it/s]

NaN values found in entry: Ar


	Electronegativity: Ne
Processing Formulas:  78%|███████▊  | 438550/561882 [51:34<14:09, 145.13it/s]

NaN values found in entry: Ne


	Electronegativity: Ne
Processing Formulas:  79%|███████▉  | 442959/561882 [52:04<12:54, 153.56it/s]

NaN values found in entry: Ne


	Electronegativity: Ar
Processing Formulas:  80%|████████  | 451903/561882 [53:09<12:26, 147.34it/s]

NaN values found in entry: Ar


	Electronegativity: Ar
Processing Formulas:  82%|████████▏ | 461283/561882 [54:18<10:04, 166.47it/s]

NaN values found in entry: Ar


	Electronegativity: Ne
Processing Formulas:  86%|████████▌ | 484484/561882 [56:51<10:00, 128.99it/s]

NaN values found in entry: Ne


	MeltingT: He
	Electronegativity: He
Processing Formulas:  88%|████████▊ | 497093/561882 [58:13<06:30, 166.05it/s]

NaN values found in entry: He


	Electronegativity: Ne
Processing Formulas:  89%|████████▊ | 497761/561882 [58:17<06:34, 162.59it/s]

NaN values found in entry: Ne


	Electronegativity: Ne
Processing Formulas:  90%|████████▉ | 503914/561882 [58:57<05:32, 174.19it/s]

NaN values found in entry: Ne


	Electronegativity: Ar
Processing Formulas:  91%|█████████ | 509159/561882 [59:31<05:19, 164.79it/s]

NaN values found in entry: Ar


	MeltingT: He
	Electronegativity: He
Processing Formulas:  91%|█████████ | 510681/561882 [59:41<07:45, 109.95it/s]

NaN values found in entry: He


	MeltingT: He
	Electronegativity: He
Processing Formulas:  93%|█████████▎| 521053/561882 [1:00:48<06:09, 110.52it/s]

NaN values found in entry: He


Processing Formulas: 100%|██████████| 561882/561882 [1:05:26<00:00, 143.09it/s]


-------------------------Completed: data processing-------------------------


In [None]:
# from data.utils import CompositionEntry
# import numpy as np
# import pandas as pd
# # from concurrent.futures import ProcessPoolExecutor
# from concurrent.futures import ThreadPoolExecutor
# from tqdm import tqdm

# def generate_features_for_formula(formula):
#     if isinstance(formula, float):
#         return None
#     entry = CompositionEntry(formula)
#     stoichiometric = np.array(stoichiometric_generator.generate_features([entry])).flatten()
#     ionicity = np.array(ionicity_generator.generate_features([entry])).flatten()
#     elemental = np.array(elemental_generator.generate_features([entry])).flatten()
#     # charge = np.array(charge_generator.generate_features([entry])).flatten()

#     # filter all the nan values, if there is any, skip this entry
#     if np.isnan(stoichiometric).any() or np.isnan(ionicity).any() or np.isnan(elemental).any():
#         return None

#     # feature = np.concatenate([stoichiometric, ionicity, elemental, charge])
#     feature = np.concatenate([stoichiometric, ionicity, elemental])
#     return feature

# # Filter out rows where 'formula' is NaN
# df_filtered = df.dropna(subset=['formula'])

# # Print dropped rows
# dropped_rows = data_size - df_filtered.shape[0]
# print('-'*25 + 'Dropped Rows' + '-'*25)
# print('Dropped rows:', dropped_rows)
# print('-'*25 + 'Dropped Rows' + '-'*25)

# # Use ThreadPoolExecutor for parallel processing
# formulas = df_filtered['formula'].tolist()

# with ThreadPoolExecutor() as executor:
#     features = list(tqdm(executor.map(generate_features_for_formula, formulas), total=len(formulas)))

# # Filter out None values
# features = [f for f in features if f is not None]

# features = np.array(features)

In [None]:
# Sanity check of the feature matrix: its shape and one example row.
features_banner = '-' * 25 + 'Features' + '-' * 25
print(features_banner)
print("shape of features:", features.shape)
print(features[1])
print(features_banner)

-------------------------Features-------------------------
shape of features: (561831, 141)
[3.00000000e+00 6.12372436e-01 5.38608673e-01 5.06099365e-01
 5.01108669e-01 5.00097571e-01 0.00000000e+00 2.13450798e-01
 6.83327026e-02 4.62500000e+01 5.60000000e+01 2.52500000e+01
 7.70000000e+01 2.10000000e+01 2.10000000e+01 2.82500000e+01
 4.90000000e+01 1.72500000e+01 6.00000000e+01 1.10000000e+01
 1.10000000e+01 1.11157206e+02 1.47261088e+02 6.62012940e+01
 1.92217000e+02 4.49559120e+01 4.49559120e+01 2.01300000e+03
 1.05400000e+03 3.63000000e+02 2.73900000e+03 1.68500000e+03
 1.81400000e+03 4.50000000e+00 6.00000000e+00 2.25000000e+00
 9.00000000e+00 3.00000000e+00 3.00000000e+00 5.00000000e+00
 2.00000000e+00 1.00000000e+00 6.00000000e+00 4.00000000e+00
 4.00000000e+00 1.68250000e+02 5.10000000e+01 1.36250000e+01
 1.92000000e+02 1.41000000e+02 1.70000000e+02 1.53500000e+00
 9.80000000e-01 3.32500000e-01 2.20000000e+00 1.22000000e+00
 1.36000000e+00 2.00000000e+00 0.00000000e+00 0.000000

In [None]:
# Sanity check of the target matrix: its shape and one example row.
print('-'*25 + 'Targets' + '-'*25)
# The label previously read "shape of features" although it reports the targets.
print("shape of targets:", targets.shape)
print(targets[1])
print('-'*25 + 'Targets' + '-'*25)

-------------------------Targets-------------------------
shape of features: (561831, 3)
[-6.795188535 -0.266899046666668 0.21236685]
-------------------------Targets-------------------------


In [None]:
# Persist the processed arrays so the slow feature generation need not be repeated.
cur_dir = os.getcwd()
save_dir = os.path.join(cur_dir, 'data', 'datasets')
# np.save does not create missing directories, so make sure the folder exists first.
os.makedirs(save_dir, exist_ok=True)
print('-'*25 + 'Saving Data' + '-'*25)
np.save(os.path.join(save_dir, 'features_OQMD_full.npy'), features)
np.save(os.path.join(save_dir, 'targets_OQMD_full.npy'), targets)
print('-'*25 + 'Data Saved' + '-'*25)

-------------------------Saving Data-------------------------
-------------------------Data Saved-------------------------


In [None]:
import random

# Convert features and targets to float32 PyTorch tensors.
# Going through np.asarray keeps this cell re-runnable: on a second run
# `features` / `targets` are already tensors, and a tensor has no `.astype`,
# so the previous `targets.astype(np.float32)` raised AttributeError.
features = torch.tensor(np.asarray(features, dtype=np.float32))
targets = torch.tensor(np.asarray(targets, dtype=np.float32))

# Fixed seed so the train/test partition is the same on every run
# random_seed = random.randint(0, 1000)
random_seed = 1024

# Split the data into training (80%) and testing (20%) sets
X_train, X_test, y_train, y_test = train_test_split(features, targets, test_size=0.2, random_state=random_seed)

print('-'*25 + 'datasets information' + '-'*25)
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

# One example row from each training tensor
print(X_train[0])
print(y_train[0])
print('-'*25 + 'datasets information' + '-'*25)

-------------------------datasets information-------------------------
torch.Size([449464, 141]) torch.Size([449464, 3])
torch.Size([112367, 141]) torch.Size([112367, 3])
tensor([4.0000e+00, 5.0000e-01, 3.9685e-01, 3.2988e-01, 3.0475e-01, 2.8717e-01,
        0.0000e+00, 9.4461e-02, 2.7165e-02, 4.5750e+01, 5.2000e+01, 1.7750e+01,
        7.8000e+01, 2.6000e+01, 4.5750e+01, 6.5500e+01, 2.0000e+01, 6.5000e+00,
        7.5000e+01, 5.5000e+01, 6.5500e+01, 1.0778e+02, 1.3924e+02, 4.7169e+01,
        1.9508e+02, 5.5845e+01, 1.0778e+02, 1.2437e+03, 1.6117e+03, 6.8249e+02,
        2.0414e+03, 4.2975e+02, 1.2437e+03, 1.0750e+01, 5.0000e+00, 1.7500e+00,
        1.3000e+01, 8.0000e+00, 1.0750e+01, 4.7500e+00, 2.0000e+00, 7.5000e-01,
        6.0000e+00, 4.0000e+00, 4.7500e+00, 1.3300e+02, 2.0000e+01, 6.0000e+00,
        1.4200e+02, 1.2200e+02, 1.3300e+02, 1.8850e+00, 6.3000e-01, 1.9750e-01,
        2.2800e+00, 1.6500e+00, 1.8850e+00, 1.7500e+00, 1.0000e+00, 3.7500e-01,
        2.0000e+00, 1.0000e+0

In [None]:
# # Define the neural network architecture
# class Net(nn.Module):
#     def __init__(self):
#         super(Net, self).__init__()
#         self.fc1 = nn.Linear(X_train.shape[1], 128)
#         self.fc2 = nn.Linear(128, 64)
#         self.fc3 = nn.Linear(64, 32)
#         self.fc4 = nn.Linear(32, y_train.shape[1])

#     def forward(self, x):
#         x = torch.relu(self.fc1(x))
#         x = torch.relu(self.fc2(x))
#         x = torch.relu(self.fc3(x))
#         x = self.fc4(x)
#         return x

In [None]:
# from model.simple_model import Net
from model.simple_model_revised import Net
# from model.simple_model_residue import Net
# from model.model import Net
# from model.residue import Net

In [None]:
# from tqdm import tqdm
# import matplotlib.pyplot as plt

# # Note this is the deprecated code

# # Initialize the model, loss function, and optimizer
# input_dim = X_train.shape[1]
# output_dim = y_train.shape[1]
# print("input dim {}, output dim {}".format(input_dim, output_dim))
# model = Net(input_dim, output_dim)
# criterion = nn.MSELoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

# # Train the model
# num_epochs = 1000
# losses = []
# for epoch in tqdm(range(num_epochs), desc="Training Epochs"):
#     model.train()
#     optimizer.zero_grad()
#     outputs = model(X_train)
#     loss = criterion(outputs, y_train)
#     loss.backward()
#     optimizer.step()

#     losses.append(loss.item())

# # Evaluate the model
# model.eval()
# with torch.no_grad():
#     predictions = model(X_test)
#     mse = criterion(predictions, y_test).item()
#     print(f'Mean Squared Error: {mse}')

# # Save the trained model
# torch.save(model.state_dict(), cur_dir + '/model/' + 'OQMD_100k_model')

# # Plot the loss curve
# plt.figure(figsize=(10, 5))
# plt.plot(range(num_epochs), losses, label='Training Loss')
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.title('Training Loss Curve')
# plt.legend()
# plt.show()


In [17]:
from tqdm import tqdm
import matplotlib.pyplot as plt

# Check if MPS is available and set the device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Move data to the device
X_train, X_test, y_train, y_test = X_train.to(device), X_test.to(device), y_train.to(device), y_test.to(device)

# Seed PyTorch so the weight initialisation is repeatable between runs; the
# same seed as the train/test split is reused.
torch.manual_seed(random_seed)

# Initialize the model, loss function, and optimizer
input_dim = X_train.shape[1]
output_dim = y_train.shape[1]
print("input dim {}, output dim {}".format(input_dim, output_dim))
model = Net(input_dim, output_dim).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model.
# Each epoch is a single full-batch step: the whole training tensor goes
# through the network in one forward/backward pass.
num_epochs = 1000
losses = []
for epoch in tqdm(range(num_epochs), desc="Training Epochs"):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    # Keep a plain float per epoch for the loss curve below
    losses.append(loss.item())

    # if (epoch+1) % 10 == 0:
    #     print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Evaluate the model on the held-out split
model.eval()
with torch.no_grad():
    predictions = model(X_test)
    mse = criterion(predictions, y_test).item()
    print(f'Mean Squared Error: {mse}')

# Save the trained weights, creating the target folder if it is missing
model_dir = os.path.join(cur_dir, 'model')
os.makedirs(model_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(model_dir, 'OQMD_100k_model'))

# Plot the loss curve; the x-axis follows the number of losses actually recorded
plt.figure(figsize=(10, 5))
plt.plot(range(len(losses)), losses, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.legend()
plt.show()

print('-'*25 + 'Training Summary' + '-'*25)
print(f'Mean Squared Error: {mse}')
print('-'*25 + 'Training Summary' + '-'*25)


Using device: cpu
input dim 141, output dim 3


Training Epochs:   2%|▏         | 20/1000 [03:49<3:07:24, 11.47s/it]


KeyboardInterrupt: 

In [None]:
# model = Net()
# model.load_state_dict(torch.load(current_dir + '/model/' + 'OQMD_100k_model'))

# # Make predictions
# model.eval()
# with torch.no_grad():
#     entry = CompositionEntry("Y2I6")
#     stoichiometric = np.array(stoichiometric_generator.generate_features([entry])).flatten()
#     ionicity = np.array(ionicity_generator.generate_features([entry])).flatten()
#     elemental = np.array(elemental_generator.generate_features([entry])).flatten()
#     # charge = np.array(charge_generator.generate_features([entry])).flatten()

#     # feature = np.concatenate([stoichiometric, ionicity, elemental, charge])
#     feature = np.concatenate([stoichiometric, ionicity, elemental])
#     feature = torch.tensor(feature, dtype=torch.float32)
#     prediction = model(feature)
#     print(prediction)

# # 'bandgap', 'energy_pa', 'volume_pa', 'magmom_pa', 'fermi', 'delta_e'

In [None]:
# Rebuild the network with the layer sizes implied by the training tensors,
# then restore the saved weights into it.
input_dim = X_train.shape[1]
output_dim = y_train.shape[1]
model = Net(input_dim, output_dim)
# map_location='cpu': the checkpoint is written from whichever device trained the
# model (possibly "mps"); remapping lets it load on machines without that device.
model_path = os.path.join(cur_dir, 'model', 'OQMD_100k_model')
model.load_state_dict(torch.load(model_path, map_location='cpu'))

def print_prediction(pred):
    """Print the three predicted properties held in the first row of `pred`.

    `pred` is indexed as pred[0][k]; columns 0, 1 and 2 are reported as energy
    per atom, formation energy per atom and magnetization per atom, each
    formatted with four decimals.
    """
    first_row = pred[0]
    print(f'Energy per atom: {first_row[0]:.4f} eV')
    print(f'Formation energy per atom: {first_row[1]:.4f} eV')
    print(f'Magnetization per atom: {first_row[2]:.4f} Bohr magneton')

# Predict the three targets for one example composition.
model.eval()

entry = CompositionEntry("Y2I6")
# Build the same three attribute groups, in the same order, as for the training features.
stoichiometric, ionicity, elemental = (
    np.array(attr_gen.generate_features([entry])).flatten()
    for attr_gen in (stoichiometric_generator, ionicity_generator, elemental_generator)
)
# charge = np.array(charge_generator.generate_features([entry])).flatten()

# feature = np.concatenate([stoichiometric, ionicity, elemental, charge])
feature = np.concatenate([stoichiometric, ionicity, elemental])
# unsqueeze(0) adds a leading axis so the single sample is passed as a batch of one
# (the original author added it to fix an error in the model call).
feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0)

with torch.no_grad():
    prediction = model(feature)

# print(prediction)
print_prediction(prediction)

In [None]:
from sklearn.metrics import mean_squared_error

# Rebuild the network with the layer sizes implied by the training tensors,
# then restore the saved weights into it.
input_dim = X_train.shape[1]
output_dim = y_train.shape[1]
model = Net(input_dim, output_dim)
# map_location='cpu': the checkpoint is written from whichever device trained the
# model (possibly "mps"); remapping lets it load on machines without that device.
model_path = os.path.join(cur_dir, 'model', 'OQMD_100k_model')
model.load_state_dict(torch.load(model_path, map_location='cpu'))

# use the data from "small_set.txt" for testing
def read_data(file_path):
    """Read a whitespace-delimited property table into a DataFrame.

    The first line of the file is skipped and the columns are named
    name, bandgap, energy_pa, volume_pa, magmom_pa, fermi, hull_distance
    and delta_e. The literal text 'None' is treated as a missing value.
    Every column except 'name' is converted to a numeric dtype; entries
    that cannot be parsed as numbers become NaN.

    Parameters
    ----------
    file_path : str or path-like
        Location of the text file to read.

    Returns
    -------
    pandas.DataFrame
        One row per line of the file, with the eight columns listed above.
    """
    # Define column names
    column_names = [
        'name', 'bandgap', 'energy_pa', 'volume_pa', 'magmom_pa',
        'fermi', 'hull_distance', 'delta_e'
    ]

    # Read the data. sep=r'\s+' replaces the deprecated delim_whitespace=True,
    # and na_values marks 'None' as missing at parse time instead of mutating
    # the frame in place afterwards.
    data = pd.read_csv(
        file_path,
        sep=r'\s+',
        names=column_names,
        skiprows=1,
        na_values=['None'],
    )

    # Convert the property columns explicitly; errors='ignore' is deprecated in
    # pandas.to_numeric, and silently keeping text here would only postpone the
    # failure to the metric computation.
    numeric_columns = [column for column in column_names if column != 'name']
    data[numeric_columns] = data[numeric_columns].apply(pd.to_numeric, errors='coerce')

    return data

file_path = './data/datasets/small_set.txt'
data = read_data(file_path)

# Only keep Name, Energy per Atom and Magnetization per Atom
data = data[['name', 'energy_pa', 'magmom_pa']]

# Evaluate the model and calculate the mean squared error
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
model.to(device)
model.eval()

all_predictions = []
all_targets = []

# Reference values as one float matrix (column order: energy_pa, magmom_pa).
# Anything non-numeric becomes NaN so the row can be skipped below; extracting
# rows with .iloc instead would produce object arrays.
target_matrix = (
    data[['energy_pa', 'magmom_pa']]
    .apply(pd.to_numeric, errors='coerce')
    .to_numpy(dtype=float)
)

with torch.no_grad():
    for name, target in zip(data['name'], target_matrix):
        entry = CompositionEntry(name)
        stoichiometric = np.array(stoichiometric_generator.generate_features([entry])).flatten()
        ionicity = np.array(ionicity_generator.generate_features([entry])).flatten()
        elemental = np.array(elemental_generator.generate_features([entry])).flatten()

        # Skip the entry if any attribute group contains NaN
        if np.isnan(stoichiometric).any() or np.isnan(ionicity).any() or np.isnan(elemental).any():
            print('NaN values found in entry:', name)
            continue

        # Skip entries without a complete reference value: a single NaN target
        # would make mean_squared_error reject the whole input.
        if np.isnan(target).any():
            print('Missing target values for entry:', name)
            continue

        feature = np.concatenate([stoichiometric, ionicity, elemental])
        feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)

        output = model(feature)
        output = output[0, [0, 2]]  # Only keep Energy per Atom and Magnetization per Atom

        all_predictions.append(output.cpu().numpy())
        all_targets.append(target)

mse = mean_squared_error(all_targets, all_predictions)
print(f'Mean Squared Error: {mse}')