In [2]:
import pandas as pd
import json
import rdkit as rd 
import pymatgen as pm
from pymatgen.core import Structure

In [2]:
qmof_attrs = pd.read_csv("qmof_database/qmof_database/qmof.csv")


  qmof_attrs = pd.read_csv("qmof_database/qmof_database/qmof.csv")


In [None]:
import os
import json

import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.nn import radius_graph
from pymatgen.core import Structure


class QMOF(InMemoryDataset):
    raw_file_names = ['qmof_structure_data.json']
    processed_file_names = ['data.pt']

    def __init__(self, root, transform=None, pre_transform=None,
                 pre_filter=None, cutoff: float = 5.0):
        """
        Args:
            root (str): Root directory. Expects
                root/raw/qmof_structure_data.json
            cutoff (float): radius (in Å) for connecting edges
        """
        self.cutoff = cutoff
        super().__init__(root, transform, pre_transform, pre_filter)
        # Load processed data
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return os.path.join(self.root, 'qmof_database')

    @property
    def processed_dir(self):
        return os.path.join(self.root, 'processed')

    def download(self):
        # Nothing to download; assume your qmof_structure_data.json
        # is already in raw_dir.
        pass

    def process(self):
        # 1) Load your JSON
        path = os.path.join(self.raw_dir, self.raw_file_names[0])
        with open(path) as f:
            struct_list = json.load(f)

        # 2) Build lookup tables
        self.ID2NAME = {
            d['qmof_id']: d['name']
            for d in struct_list
        }
        self.STRUCTURE_DATA = {
            d['qmof_id']: d['structure']
            for d in struct_list
        }

        # 3) Convert each entry to a torch_geometric.data.Data
        data_list = []
        for mol_id, struct_dict in self.STRUCTURE_DATA.items():
            data = self.get_graph(mol_id)
            if self.pre_filter and not self.pre_filter(data):
                continue
            if self.pre_transform:
                data = self.pre_transform(data)
            data_list.append(data)

        # 4) Collate & save
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def get_graph(self, mol_id) -> Data:
        # Reconstruct the Pymatgen Structure
        struct = Structure.from_dict(self.STRUCTURE_DATA[mol_id])
        coords = torch.tensor(struct.cart_coords, dtype=torch.float)  # [N,3]

        # Simple node feature: atomic number
        z = torch.tensor([site.specie.Z for site in struct], dtype=torch.long)
        x = z.view(-1, 1).to(torch.float)

        # Build edges by radius_graph
        edge_index = radius_graph(coords, r=self.cutoff, loop=False,
                                  max_num_neighbors=32)
        # Edge attr = distance
        row, col = edge_index
        edge_attr = (coords[row] - coords[col]).norm(dim=1, keepdim=True)

        data = Data(
            x=x,
            pos=coords,
            edge_index=edge_index,
            edge_attr=edge_attr,
            mol_id=mol_id,
            name=self.ID2NAME[mol_id]
        )
        return data


In [None]:
with open("qmof_database/qmof_database/qmof_structure_data.json") as f:
    struct_data = json.load(f)


In [None]:
structure0 = Structure.from_dict(struct_data[0]["structure"])