In [84]:
import os
import os.path as osp
import argparse
import random
import time
import csv
import numpy as np


from torch_geometric.data import Data, Dataset

In [85]:
# Materials Project API key, taken from the MP_API_KEY environment variable so
# that the secret is never stored in the notebook. The value is None when the
# variable is unset.
# NOTE(review): MPRester is expected to fall back on its own configured key
# when given None — confirm against the mp-api documentation.
# NOTE(review): a key was previously committed in this cell; it should be
# treated as leaked and regenerated.
MP_API_KEY = os.environ.get("MP_API_KEY")

# Passed as `conventional_unit_cell` in the (currently commented-out)
# structure download code further down the notebook.
CONVENTIONAL_UNIT_CELL = False

# Base directory; every dataset path below is derived from it.
WORK_DIR = "."

DATASET_DIR = osp.join(WORK_DIR, "dataset")
DATASET_RAW_DIR = osp.join(DATASET_DIR, "raw")
DATASET_PROCESSED_DIR = osp.join(DATASET_DIR, "processed")

In [86]:
# Sanity check: show the three resolved dataset paths on one line.
print(" ".join([DATASET_DIR, DATASET_RAW_DIR, DATASET_PROCESSED_DIR]))

./dataset ./dataset/raw ./dataset/processed


In [87]:
from mp_api.client import MPRester

# Paging limits that keep the query small while debugging.
SEARCH_CHUNK_SIZE = 10  # TODO: remove this if not debugging
SEARCH_NUM_CHUNKS = 1  # TODO: remove this if not debugging

# Use the client as a context manager so its underlying session is released
# as soon as the query finishes. `mpr` stays bound afterwards but must not be
# reused for further requests.
with MPRester(MP_API_KEY) as mpr:
    # Summary documents for three-element materials that contain no oxygen,
    # restricted to the fields the rest of the notebook reads.
    mid_doc = mpr.materials.summary.search(
        fields=["material_id", "formation_energy_per_atom", "structure"],
        exclude_elements=["O"],
        num_elements=(3, 3),
        chunk_size=SEARCH_CHUNK_SIZE,
        num_chunks=SEARCH_NUM_CHUNKS,
    )

Retrieving SummaryDoc documents:   0%|          | 0/10 [00:00<?, ?it/s]

In [88]:
# The raw directory is not created anywhere else before this cell, so make
# sure it exists; otherwise writing the first CIF file fails.
os.makedirs(DATASET_RAW_DIR, exist_ok=True)

# One CIF file per retrieved document (CONFIG_<n>.cif, numbered from 1) plus
# one index record linking that number to the material id and target value.
indices = []
for idx, doc in enumerate(mid_doc, start=1):
    filename = osp.join(DATASET_RAW_DIR, "CONFIG_" + str(idx) + ".cif")
    doc.structure.to_file(filename=filename, fmt="cif")
    indices.append(
        {
            "idx": idx,
            "mid": str(doc.material_id),
            "formation_energy_per_atom": doc.formation_energy_per_atom,
        }
    )

# Persist the index as a CSV file with a header row.
indices_filename = osp.join(DATASET_RAW_DIR, "INDICES")
with open(indices_filename, "w", newline="") as f:
    cw = csv.DictWriter(f, fieldnames=["idx", "mid", "formation_energy_per_atom"])
    cw.writeheader()
    cw.writerows(indices)

In [89]:
class MPDataset(Dataset):
    """Skeleton dataset on top of torch_geometric's ``Dataset``.

    Nothing is implemented yet: ``download`` and ``process`` only print a
    marker, ``get`` always returns ``None`` and ``len`` reports the number of
    declared processed files.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        """Pass every argument straight through to the ``Dataset`` constructor."""
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        """Names of the raw files this dataset declares."""
        return ["INDICES"]

    @property
    def processed_file_names(self) -> list[str]:
        """Names of the processed files this dataset declares."""
        return ["data.pt"]

    def download(self):
        """Placeholder: only reports that it was called."""
        print("download")

    def process(self):
        """Placeholder: only reports that it was called."""
        print("process")

    def len(self) -> int:
        """Return how many processed file names are declared."""
        declared = self.processed_file_names
        return len(declared)

    def get(self, idx):
        """Placeholder: no sample is loaded yet, so ``None`` is returned."""
        # TODO: load the sample for `idx` from the processed directory.
        return None


# Build the dataset from the configured root instead of repeating the path
# literal, so this cell follows any change to WORK_DIR / DATASET_DIR.
dataset = MPDataset(root=DATASET_DIR)
dataset.processed_dir

'dataset/processed'

In [90]:
# # TODO: remove this if not debugging
# mid_list = mid_list[0:100]

# indices = []
# structures = []

# for i, d in enumerate(mid_list):
#     print("Progress: {}/{}".format(i + 1, len(mid_list)))
#     st = mpr.get_structure_by_material_id(
#         d, conventional_unit_cell=CONVENTIONAL_UNIT_CELL
#     )
#     structures.append(st)

In [91]:
# st = mpr.get_structure_by_material_id("mp-978908", conventional_unit_cell=True)
# print(st)
# print(st.to_primitive())

In [116]:
import torch
from ase.io import read as ase_read
from torch_geometric.utils import dense_to_sparse

# Atom pairs farther apart than this get no edge.
# NOTE(review): presumably in angstrom (ASE's length unit) — confirm.
CUTOFF_DISTANCE = 4.0

# Load the index written earlier: one row per structure, header row skipped.
indices_filename = osp.join(DATASET_RAW_DIR, "INDICES")
assert osp.exists(indices_filename), "INDICES file not exist in " + indices_filename
with open(indices_filename, newline="") as f:
    reader = csv.reader(f)
    next(reader, None)  # drop the header row
    indices = list(reader)

# Only the first entry is explored in this cell.
d = indices[0]
idx, mid, y = d[0], d[1], d[2]
filename = osp.join(DATASET_RAW_DIR, "CONFIG_" + idx + ".cif")

compound = ase_read(filename, format="cif")

# Pairwise distance matrix between all atoms.
# NOTE(review): mic=False ignores periodic images; for a periodic crystal the
# minimum-image convention (mic=True) may be what is wanted — confirm.
distance_matrix = compound.get_all_distances(mic=False)
# Keep distances within the cutoff and zero everything else (including any
# NaN/inf entry), so that dense_to_sparse creates no edge for those pairs.
distance_matrix = np.where(distance_matrix <= CUTOFF_DISTANCE, distance_matrix, 0.0)
# make it a (float32) tensor
distance_matrix = torch.Tensor(distance_matrix)

# Target value: the CSV stores it as text, so convert explicitly.
y = torch.Tensor([float(y)])
print(y)

# Dense -> sparse conversion yields the (edge_index, edge_attr) pair.
sparse_distance_matrix = dense_to_sparse(distance_matrix)
edge_index, edge_attr = sparse_distance_matrix
print(sparse_distance_matrix)
print(edge_index)
print(edge_attr)
print(indices)

# Atomic numbers as a row vector, then transposed to one row per atom.
atomic_numbers = torch.LongTensor(np.array([compound.get_atomic_numbers()]))
print(atomic_numbers)
print(atomic_numbers.t().contiguous())

# Edge attributes as a row vector, then transposed to one row per edge.
edge_attr_row = edge_attr[None, :]
print(edge_attr_row.shape)
print(edge_attr_row.t().contiguous().shape)

tensor([-0.4133])
(tensor([[0, 0, 1, 2, 2, 3],
        [2, 3, 2, 0, 1, 0]]), tensor([3.2950, 3.2950, 3.2950, 3.2950, 3.2950, 3.2950]))
tensor([[0, 0, 1, 2, 2, 3],
        [2, 3, 2, 0, 1, 0]])
tensor([3.2950, 3.2950, 3.2950, 3.2950, 3.2950, 3.2950])
[['1', 'mp-861724', '-0.41328523750000556'], ['2', 'mp-1183076', '-0.4802780425000037'], ['3', 'mp-1183068', '-0.40283124874999743'], ['4', 'mp-1183063', '-0.46480851375000043'], ['5', 'mp-1183086', '-0.4336823506250056'], ['6', 'mp-862319', '-0.5426539487500008'], ['7', 'mp-862786', '-0.38625132874999935'], ['8', 'mp-861883', '-0.3422987462500018'], ['9', 'mp-1183120', '-0.01683831749999598'], ['10', 'mp-867122', '-0.2711647174999996']]
tensor([[89, 89, 47, 77]])
tensor([[89],
        [89],
        [47],
        [77]])
torch.Size([1, 6])
torch.Size([6, 1])


In [93]:
# Two hand-written symmetric 4x4 matrices with a zero diagonal, used to try
# out element-wise thresholding.
a = np.array(
    (
        (0.0, 7.04464783, 3.52232392, 10.56697175),
        (7.04464783, 0.0, 3.52232392, 3.52232392),
        (3.52232392, 3.52232392, 0.0, 7.04464783),
        (10.56697175, 3.52232392, 7.04464783, 0.0),
    )
)

b = np.array(
    (
        (0.0, 5.86651732, 2.93325866, 2.93325866),
        (5.86651732, 0.0, 8.79977597, 2.93325866),
        (2.93325866, 8.79977597, 0.0, 5.86651732),
        (2.93325866, 2.93325866, 5.86651732, 0.0),
    )
)

# Print, for each matrix, the boolean mask of entries above 3.0.
for matrix in (a, b):
    print(matrix > 3.0)

[[False  True  True  True]
 [ True False  True  True]
 [ True  True False  True]
 [ True  True  True False]]
[[False  True False False]
 [ True False  True False]
 [False  True False  True]
 [False False  True False]]


In [94]:
# list.append mutates the list in place; capture what the call itself
# returns and print it.
filenames = []
append_result = filenames.append(["1"])
print(append_result)

None


In [1]:
# NOTE(review): the execution counter restarts at 1 here, so this cell was
# presumably run in a fresh kernel, separately from the cells above.
# NOTE(review): wildcard imports hide where names come from. `args` and
# `MPDataset` used below are presumably provided by these two local modules;
# replace with explicit imports once the exporting module is confirmed.
from args import *
from dataset import *

# NOTE(review): this MPDataset is called with a single `args` object, unlike
# the MPDataset class defined earlier in this notebook (which takes `root`).
# If both run in one kernel, the star import replaces that earlier class.
dataset = MPDataset(args)

No module named 'phonopy'
No module named 'phonopy'
Downloading raw dataset...


Retrieving SummaryDoc documents:   0%|          | 0/100 [00:00<?, ?it/s]

Processing...
Processing dataset: 100%|██████████| 100/100 [00:00<00:00, 1120.49it/s]
Done!


In [2]:
# Fetch the first sample and print it followed by each of its graph fields,
# one per line.
d = dataset[0]
for field in (d, d.edge_index, d.edge_attr, d.x, d.y):
    print(field)

Data(edge_index=[2, 6], edge_attr=[6, 1], x=[4, 1], y=[1])
tensor([[0, 0, 1, 2, 2, 3],
        [2, 3, 2, 0, 1, 0]])
tensor([[3.2950],
        [3.2950],
        [3.2950],
        [3.2950],
        [3.2950],
        [3.2950]])
tensor([[89],
        [89],
        [47],
        [77]])
tensor([-0.4133])


In [3]:
import torch

# Small integer tensor 0..5; print its smallest and then its largest element
# as plain Python numbers.
a = torch.tensor(list(range(6)))
for reducer in (torch.min, torch.max):
    print(reducer(a).item())

5