# Crystal Dataset

In [1]:
import pymatgen
import pymatgen.core.structure
import numpy as np
import os

Prepare some test data.

In [2]:
# Three small periodic structures that serve as toy input for the dataset.
test_data = [
    pymatgen.core.Structure(
        lattice=np.array([
            [4.34157255, 0., 2.50660808],
            [1.44719085, 4.09327385, 2.50660808],
            [0., 0., 5.01321616],
        ]),
        species=["Te", "Ba"],
        coords=np.array([
            [0.5, 0.5, 0.5],
            [0., 0., 0.],
        ]),
    ),
    pymatgen.core.Structure(
        lattice=np.array([
            [2.95117784, 0., 1.70386332],
            [0.98372595, 2.78239715, 1.70386332],
            [0., 0., 3.40772664],
        ]),
        species=["B", "As"],
        coords=np.array([
            [0.25, 0.25, 0.25],
            [0., 0., 0.],
        ]),
    ),
    pymatgen.core.Structure(
        lattice=np.array([
            [4.3015, 0., 0.],
            [-2.15075, 3.725208, 0.],
            [0., 0., 5.2703],
        ]),
        species=["Ba", "Ga", "Si", "H"],
        coords=np.array([
            [0., 0., 0.],
            [0.6666, 0.3333, 0.5423],
            [0.3334, 0.6667, 0.4555],
            [0.6666, 0.3333, 0.8759],
        ]),
    ),
]

# Write one .cif file per structure into ExampleCrystal/CifFiles.
os.makedirs("ExampleCrystal", exist_ok=True)
os.makedirs("ExampleCrystal/CifFiles", exist_ok=True)
for index, structure in enumerate(test_data):
    structure.to(filename="ExampleCrystal/CifFiles/file_%s.cif" % index, fmt="cif")

# Table that links every .cif file to its label; adjacent literals are
# concatenated into a single string.
csv_data = (
    "file_name,index,label\n"  # Need header!
    "file_0.cif, 0, 98.58577122703691\n"
    "file_1.cif, 1, 701.5857233477558\n"
    "file_2.cif, 2, 1138.5856886491724"
)
with open("ExampleCrystal/data.csv", "w") as csv_file:
    csv_file.write(csv_data)

### 0. Crystal dataset

The data is expected to be organized as follows:

 ```bash
 ├── data_directory
    ├── file_directory
    │   ├── *.cif
    │   ├── *.cif
    │   └── ...
    ├── file_name.csv
    └── file_name.pymatgen.json
 ```

In [3]:
from kgcnn.data.crystal import CrystalDataset

In [4]:
# Locations of the example data, following the directory layout shown above:
# the .cif files live in `file_directory` and the table file is `file_name`,
# both inside `data_directory`.
dataset_kwargs = {
    "data_directory": "ExampleCrystal/",
    "dataset_name": "ExampleCrystal",
    "file_name": "data.csv",
    "file_directory": "CifFiles",
}
dataset = CrystalDataset(**dataset_kwargs)

### 1. Generate a json-serialized list of structures via `prepare_data`

In [5]:
# Collect the structure files listed in the "file_name" column of the table
# and export them as one json file; an existing export is overwritten.
dataset.prepare_data(
    file_column_name="file_name",
    overwrite=True,
)

INFO:kgcnn.data.ExampleCrystal:Searching for structure files in 'ExampleCrystal/CifFiles'
INFO:kgcnn.data.ExampleCrystal:Read 3 single files.
INFO:kgcnn.data.ExampleCrystal:... Read .cif file 0 from 3
INFO:kgcnn.data.ExampleCrystal:Exporting as dict for pymatgen ...
INFO:kgcnn.data.ExampleCrystal:Saving structures as .json ...


<CrystalDataset []>

### 2. Read in memory with `read_in_memory`.

In [6]:
# Load the structures into the dataset, taking labels from the "label" column.
dataset.read_in_memory(label_column_name="label")

# Inspect the first graph to see which properties were generated.
first_graph = dataset[0]
print(first_graph)

INFO:kgcnn.data.ExampleCrystal:Making node features from structure...
INFO:kgcnn.data.ExampleCrystal:Reading structures from .json ...
INFO:kgcnn.data.ExampleCrystal: ... read structures 0 from 3


{'graph_labels': array(98.58577123), 'node_coordinates': array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [1.31245681e-09, 6.13991078e+00, 2.27324404e-09]]), 'node_frac_coordinates': array([[0. , 0. , 0. ],
       [0.5, 0.5, 0.5]]), 'graph_lattice': array([[ 1.44719085e+00,  4.09327385e+00,  2.50660808e+00],
       [ 1.44719085e+00,  4.09327385e+00, -2.50660808e+00],
       [-2.89438170e+00,  4.09327385e+00,  1.51549528e-09]]), 'abc': array([5.01321616, 5.01321616, 5.01321616]), 'charge': array([0.]), 'volume': array([89.0910946]), 'node_number': array([56, 52])}


To read only the pymatgen structures, use `get_structures_from_json_file`. The structures are returned by the function and are not assigned to the dataset.

In [7]:
# Returns the structures read from the json file; the returned list is shown
# as the cell output and is not stored on the dataset.
dataset.get_structures_from_json_file()

INFO:kgcnn.data.ExampleCrystal:Reading structures from .json ...


[Structure Summary
 Lattice
     abc : 5.01321616 5.013216158484504 5.0132161584845045
  angles : 60.00000002 60.00000001 60.00000001
  volume : 89.09109460256703
       A : 1.4471908506158624 4.093273852854227 2.5066080815154956
       B : 1.4471908506158624 4.093273852854227 -2.506608078484504
       C : -2.8943816986068107 4.093273852854227 1.5154952848206449e-09
     pbc : True True True
 PeriodicSite: Ba1 (Ba) (0.0, 0.0, 0.0) [0.0, 0.0, 0.0]
 PeriodicSite: Te0 (Te) (1.312e-09, 6.14, 2.273e-09) [0.5, 0.5, 0.5],
 Structure Summary
 Lattice
     abc : 3.40772664 3.4077266405150777 3.407726637424612
  angles : 60.000000029999995 60.000000024999984 59.999999995
  volume : 27.98203208499981
       A : 0.9837259499337652 2.7823971493851167 1.7038633194849222
       B : 0.9837259499337652 2.7823971493851167 -1.7038633205150777
       C : -1.9674518897566036 2.7823971493851167 -0.0
     pbc : True True True
 PeriodicSite: B0 (B) (7.583e-09, 6.26, -7.726e-10) [0.75, 0.75, 0.75]
 PeriodicSit

Alternatively, save a list of structures directly to the json file with `save_structures_to_json_file`, without collecting individual files first.

In [8]:
# Export the in-memory list of pymatgen structures straight to the dataset's
# json file, bypassing the individual .cif files.
dataset.save_structures_to_json_file(test_data)

INFO:kgcnn.data.ExampleCrystal:Exporting as dict for pymatgen ...
INFO:kgcnn.data.ExampleCrystal:Saving structures as .json ...


### 3. Generate graph

In [9]:
# Apply "set_range_periodic" to every graph in the dataset. As the output
# shows, each graph gains 'range_indices' and 'range_image' entries.
dataset.map_list(
    method="set_range_periodic",
    max_distance=5.0,
    max_neighbours=20,
)

<CrystalDataset [{'graph_labels': array(98.58577123), 'node_coordinates': array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [1.31245681e-09, 6.13991078e+00, 2.27324404e-09]]), 'node_frac_coordinates': array([[0. , 0. , 0. ],
       [0.5, 0.5, 0.5]]), 'graph_lattice': array([[ 1.44719085e+00,  4.09327385e+00,  2.50660808e+00],
       [ 1.44719085e+00,  4.09327385e+00, -2.50660808e+00],
       [-2.89438170e+00,  4.09327385e+00,  1.51549528e-09]]), 'abc': array([5.01321616, 5.01321616, 5.01321616]), 'charge': array([0.]), 'volume': array([89.0910946]), 'node_number': array([56, 52]), 'range_indices': array([[0, 1],
       [0, 1],
       [0, 1],
       [0, 1],
       [0, 1],
       [0, 1],
       [1, 0],
       [1, 0],
       [1, 0],
       [1, 0],
       [1, 0],
       [1, 0]], dtype=int32), 'range_image': array([[-1,  0,  0],
       [ 0, -1, -1],
       [-1,  0, -1],
       [ 0,  0, -1],
       [-1, -1,  0],
       [ 0, -1,  0],
       [ 1,  0,  0],
       [ 0,  1,  1],
   

In [10]:
# Show the first graph again; it now also holds the range properties.
dataset[0]

{'graph_labels': array(98.58577123),
 'node_coordinates': array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.31245681e-09, 6.13991078e+00, 2.27324404e-09]]),
 'node_frac_coordinates': array([[0. , 0. , 0. ],
        [0.5, 0.5, 0.5]]),
 'graph_lattice': array([[ 1.44719085e+00,  4.09327385e+00,  2.50660808e+00],
        [ 1.44719085e+00,  4.09327385e+00, -2.50660808e+00],
        [-2.89438170e+00,  4.09327385e+00,  1.51549528e-09]]),
 'abc': array([5.01321616, 5.01321616, 5.01321616]),
 'charge': array([0.]),
 'volume': array([89.0910946]),
 'node_number': array([56, 52]),
 'range_indices': array([[0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0]], dtype=int32),
 'range_image': array([[-1,  0,  0],
        [ 0, -1, -1],
        [-1,  0, -1],
        [ 0,  0, -1],
        [-1, -1,  0],
        [ 0, -1,  0],
        [ 1,  0,  0],
        [ 0

### 4. Model training

In [11]:
from kgcnn.literature.Schnet import make_crystal_model
from keras.optimizers import Adam

In [12]:
def _shifted_softplus():
    """Return a fresh activation config for kgcnn's shifted softplus."""
    return {"class_name": "function", "config": "kgcnn>shifted_softplus"}


# Input tensors of the model; the names are graph properties of the dataset.
model_inputs = [
    {"shape": (None,), "name": "node_number", "dtype": "int64", "ragged": True},
    {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32", "ragged": True},
    {"shape": (None, 2), "name": "range_indices", "dtype": "int64", "ragged": True},
    {"shape": (None, 3), "name": "range_image", "dtype": "int64", "ragged": True},
    {"shape": (3, 3), "name": "graph_lattice", "dtype": "float32", "ragged": False},
]

# Settings of the interaction blocks.
interaction_args = {
    "units": 128,
    "use_bias": True,
    "activation": _shifted_softplus(),
    "cfconv_pool": "scatter_sum",
}

# Final MLP; its last layer has a single unit with linear activation.
last_mlp = {
    "use_bias": [True, True, True],
    "units": [128, 64, 1],
    "activation": [_shifted_softplus(), _shifted_softplus(), "linear"],
}

model_config = {
    "name": "Schnet",
    "inputs": model_inputs,
    "input_tensor_type": "ragged",
    "input_node_embedding": {"input_dim": 95, "output_dim": 64},
    "interaction_args": interaction_args,
    "node_pooling_args": {"pooling_method": "scatter_mean"},
    "depth": 4,
    "gauss_args": {"bins": 25, "distance": 5, "offset": 0.0, "sigma": 0.4},
    "verbose": 10,
    "last_mlp": last_mlp,
    "output_embedding": "graph",
    "use_output_mlp": False,
    "output_mlp": None,  # Last MLP sets output dimension if None.
}
model = make_crystal_model(**model_config)

INFO:kgcnn.models.utils:Updated model kwargs: '{'name': 'Schnet', 'inputs': [{'shape': (None,), 'name': 'node_number', 'dtype': 'int64', 'ragged': True}, {'shape': (None, 3), 'name': 'node_coordinates', 'dtype': 'float32', 'ragged': True}, {'shape': (None, 2), 'name': 'range_indices', 'dtype': 'int64', 'ragged': True}, {'shape': (None, 3), 'name': 'range_image', 'dtype': 'int64', 'ragged': True}, {'shape': (3, 3), 'name': 'graph_lattice', 'dtype': 'float32', 'ragged': False}], 'input_tensor_type': 'ragged', 'input_embedding': None, 'cast_disjoint_kwargs': {}, 'input_node_embedding': {'input_dim': 95, 'output_dim': 64}, 'make_distance': True, 'expand_distance': True, 'interaction_args': {'units': 128, 'use_bias': True, 'activation': {'class_name': 'function', 'config': 'kgcnn>shifted_softplus'}, 'cfconv_pool': 'scatter_sum'}, 'node_pooling_args': {'pooling_method': 'scatter_mean'}, 'depth': 4, 'gauss_args': {'bins': 25, 'distance': 5, 'offset': 0.0, 'sigma': 0.4}, 'verbose': 10, 'last_m

In [13]:
# Check the dataset against the model's input specification. The call returns
# an array (empty in this run, where the log reports no invalid graphs).
# NOTE(review): presumably graphs lacking a required property are removed and
# their indices returned; confirm against the kgcnn documentation.
dataset.clean(model_config["inputs"])

INFO:kgcnn.data.ExampleCrystal:No invalid graphs for assigned properties found.


array([], dtype=int32)

In [14]:
# One scalar label per graph; a trailing axis is appended so the targets have
# shape (number of graphs, 1).
graph_labels = dataset.get("graph_labels")
y_train = np.expand_dims(graph_labels, axis=-1) / 1000  # Change units by 1000

# Model inputs are built from the same input specification the model uses.
x_train = dataset.tensor(model_config["inputs"])
y_train.shape

(3, 1)

In [15]:
# Mean absolute error is used both as the loss and as the reported metric.
model.compile(
    optimizer=Adam(learning_rate=1e-04),
    loss="mean_absolute_error",
    metrics=["mean_absolute_error"],
)

# Build model with reasonable data.
model.predict(x_train)
# NOTE(review): these two calls use private Keras attributes and build the
# compiled metrics and loss by hand before `fit`; presumably a workaround for
# the model inputs used here. Confirm it is still required.
model._compile_metrics.build(y_train, y_train)
model._compile_loss.build(y_train, y_train)

# All three samples fit into a single batch, so every epoch is one step.
fit_options = {
    "shuffle": True,
    "batch_size": 3,
    "epochs": 20,
    "verbose": 2,
}
model.fit(x_train, y_train, **fit_options)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 520ms/step
Epoch 1/20
1/1 - 4s - 4s/step - loss: 0.6491 - mean_absolute_error: 0.6491
Epoch 2/20
1/1 - 0s - 23ms/step - loss: 0.6362 - mean_absolute_error: 0.6362
Epoch 3/20
1/1 - 0s - 22ms/step - loss: 0.6232 - mean_absolute_error: 0.6232
Epoch 4/20
1/1 - 0s - 20ms/step - loss: 0.6097 - mean_absolute_error: 0.6097
Epoch 5/20
1/1 - 0s - 21ms/step - loss: 0.5952 - mean_absolute_error: 0.5952
Epoch 6/20
1/1 - 0s - 24ms/step - loss: 0.5795 - mean_absolute_error: 0.5795
Epoch 7/20
1/1 - 0s - 23ms/step - loss: 0.5622 - mean_absolute_error: 0.5622
Epoch 8/20
1/1 - 0s - 23ms/step - loss: 0.5429 - mean_absolute_error: 0.5429
Epoch 9/20
1/1 - 0s - 23ms/step - loss: 0.5210 - mean_absolute_error: 0.5210
Epoch 10/20
1/1 - 0s - 24ms/step - loss: 0.4960 - mean_absolute_error: 0.4960
Epoch 11/20
1/1 - 0s - 23ms/step - loss: 0.4674 - mean_absolute_error: 0.4674
Epoch 12/20
1/1 - 0s - 23ms/step - loss: 0.4343 - mean_absolute_error: 0.4343
Ep

<keras.src.callbacks.history.History at 0x1b974302110>