In [1]:
import torch
from torch_geometric.data import Data
from pymatgen.core import Structure, Element

  from .autonotebook import tqdm as notebook_tqdm


In [None]:


def cif_to_graph(cif_path, cutoff=3.0):

    structure = Structure.from_file(cif_path)

    atom_features = []
    for site in structure:
        Z = site.specie.Z
        en = Element(site.specie.symbol).X if Element(site.specie.symbol).X is not None else 0.0
        atom_features.append([Z, en])
    x = torch.tensor(atom_features, dtype=torch.float)

    edge_index = []
    edge_attr = []
    for i, site in enumerate(structure):
        neighbors = structure.get_neighbors(site, r=cutoff)
        for neigh in neighbors:
            j = neigh.index
            edge_index.append([i, j])
            edge_attr.append([neigh.nn_distance])
            edge_index.append([j, i])
            edge_attr.append([neigh.nn_distance])
            
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    return data


In [None]:
import os
import pandas as pd
import torch
from torch_geometric.data import InMemoryDataset

class MOFDataset(InMemoryDataset):
    def __init__(self, root, csv_file, cif_dir, transform=None, pre_transform=None):
        self.csv_file = csv_file
        self.cif_dir = cif_dir
        super(MOFDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
    
    @property
    def raw_file_names(self):
     
        return []
    
    @property
    def processed_file_names(self):
        return ['data.pt']
    
    def download(self):
        pass
    
    def process(self):
        data_list = []
        
        # Load CSV data
        df = pd.read_csv(self.csv_file).set_index('filename')
        
        for filename, row in df.iterrows():
            cif_path = os.path.join(self.cif_dir, f"{filename}.cif")
            if not os.path.exists(cif_path):
                print(f"CIF file not found: {cif_path}")
                continue
            
            try:
                graph = cif_to_graph(cif_path)
            except Exception as e:
                print(f"Error processing {cif_path}: {e}")
                continue
            

            if 'ASA_m2_g' in row:
                target_value = row['ASA_m2_g']
                graph.y = torch.tensor([target_value], dtype=torch.float)
            else:
                graph.y = torch.tensor([0.0], dtype=torch.float)
                
            data_list.append(graph)
        
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


In [None]:
from torch_geometric.loader import DataLoader


csv_file = './DATASET.csv'
cif_dir = 'TEST_Classification'
processed_root = 'processed_data'
dataset = MOFDataset(root=processed_root, csv_file=csv_file, cif_dir=cif_dir)
print(len(dataset))
print(dataset[0])
loader = DataLoader(dataset, batch_size=32, shuffle=True)



Processing...


CIF file not found: TEST_Classification/BUQWER_clean.cif
CIF file not found: TEST_Classification/BURJOO_clean.cif
CIF file not found: TEST_Classification/BURQIQ_clean.cif
CIF file not found: TEST_Classification/BUSGOM_clean.cif
CIF file not found: TEST_Classification/BUSNAF_clean.cif
CIF file not found: TEST_Classification/BUSQEM_clean.cif
CIF file not found: TEST_Classification/BUSQIQ_clean.cif
CIF file not found: TEST_Classification/BUSRUE_clean.cif
CIF file not found: TEST_Classification/BUVHEH_clean.cif
CIF file not found: TEST_Classification/BUVWOF02_clean.cif
CIF file not found: TEST_Classification/BUVWOF03_clean.cif
CIF file not found: TEST_Classification/BUVXOG_clean.cif
CIF file not found: TEST_Classification/BUVYEX_clean.cif
CIF file not found: TEST_Classification/BUVYIB01_clean.cif
CIF file not found: TEST_Classification/BUVYIB_clean.cif
CIF file not found: TEST_Classification/BUWLUC_clean.cif
CIF file not found: TEST_Classification/BUWMAJ_clean.cif
CIF file not found: TEST_

Done!


In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class GCNNet(nn.Module):
    def __init__(self, num_node_features, hidden_channels):
        super(GCNNet, self).__init__()
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, 1)  

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        

        x = self.conv1(x, edge_index)
        x = F.relu(x)

        x = self.conv2(x, edge_index)
        x = F.relu(x)

        x = global_mean_pool(x, batch)
        

        out = self.lin(x)
        return out.view(-1)


In [None]:
from torch_geometric.loader import DataLoader

# Hyperparameters
hidden_channels = 64
learning_rate = 0.001
batch_size = 32
epochs = 50
dataset = MOFDataset(root='data/processed', csv_file = './DATASET.csv', cif_dir = "./TEST_Classification")
torch.manual_seed(42)
dataset = dataset.shuffle()
train_dataset = dataset[:int(0.8 * len(dataset))]
test_dataset = dataset[int(0.8 * len(dataset)):]

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

model = GCNNet(num_node_features=dataset.num_node_features, hidden_channels=hidden_channels)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
for epoch in range(1, epochs + 1):
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out, batch.y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")


model.eval()
total_loss = 0
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        out = model(batch)
        loss = criterion(out, batch.y.view(-1))
        total_loss += loss.item() * batch.num_graphs
test_loss = total_loss / len(test_loader.dataset)
print(f"Test MSE Loss: {test_loss:.4f}")


Processing...


CIF file not found: ./TEST_Classification/BUQWER_clean.cif
CIF file not found: ./TEST_Classification/BURJOO_clean.cif
CIF file not found: ./TEST_Classification/BURQIQ_clean.cif
CIF file not found: ./TEST_Classification/BUSGOM_clean.cif
CIF file not found: ./TEST_Classification/BUSNAF_clean.cif
CIF file not found: ./TEST_Classification/BUSQEM_clean.cif
CIF file not found: ./TEST_Classification/BUSQIQ_clean.cif
CIF file not found: ./TEST_Classification/BUSRUE_clean.cif
CIF file not found: ./TEST_Classification/BUVHEH_clean.cif
CIF file not found: ./TEST_Classification/BUVWOF02_clean.cif
CIF file not found: ./TEST_Classification/BUVWOF03_clean.cif
CIF file not found: ./TEST_Classification/BUVXOG_clean.cif
CIF file not found: ./TEST_Classification/BUVYEX_clean.cif
CIF file not found: ./TEST_Classification/BUVYIB01_clean.cif
CIF file not found: ./TEST_Classification/BUVYIB_clean.cif
CIF file not found: ./TEST_Classification/BUWLUC_clean.cif
CIF file not found: ./TEST_Classification/BUWMAJ_c

Done!


Epoch 1, Loss: 4724784.8718
Epoch 2, Loss: 4717593.6568
Epoch 3, Loss: 4704901.7682
Epoch 4, Loss: 4683419.2924
Epoch 5, Loss: 4647887.7253
Epoch 6, Loss: 4592851.3402
Epoch 7, Loss: 4513232.1474
Epoch 8, Loss: 4401726.6617
Epoch 9, Loss: 4261938.5720
Epoch 10, Loss: 4081002.6933
Epoch 11, Loss: 3867498.3166
Epoch 12, Loss: 3628664.4867
Epoch 13, Loss: 3369919.1741
Epoch 14, Loss: 3101268.2589
Epoch 15, Loss: 2838632.2268
Epoch 16, Loss: 2615166.7623
Epoch 17, Loss: 2433090.1149
Epoch 18, Loss: 2288493.8648
Epoch 19, Loss: 2201866.6346
Epoch 20, Loss: 2149388.6605
Epoch 21, Loss: 2124556.5380
Epoch 22, Loss: 2111691.8873
Epoch 23, Loss: 2103226.6063
Epoch 24, Loss: 2099453.3242
Epoch 25, Loss: 2096788.2458
Epoch 26, Loss: 2093900.5345
Epoch 27, Loss: 2092260.1849
Epoch 28, Loss: 2089438.2160
Epoch 29, Loss: 2086578.3974
Epoch 30, Loss: 2083440.4231
Epoch 31, Loss: 2081076.8819
Epoch 32, Loss: 2078032.5419
Epoch 33, Loss: 2075676.3733
Epoch 34, Loss: 2072781.8600
Epoch 35, Loss: 2072320