In [1]:
import torch
import torch_geometric
from graph_tools import plotter, extract_adsorbate
from functions import get_graph_sample, get_graph_formula, contcar_to_graph
from classes import HetGraphDataset
from torch_geometric.data import DataLoader

## Biomass

In [2]:
biomass = []
folder = "BM_dataset/Biomass/"
metals = ["ni", "ru"]
molecules = ["mol1", "mol2", "mol3", "mol4", "mol5"]
for metal in metals:
    for molecule in molecules:
        graph = get_graph_sample("{}{}-{}".format(folder, metal, molecule),
                                 "{}{}-0000".format(folder, metal), family="biomass")
        adsorbate = extract_adsorbate(graph)
        if graph.num_nodes == adsorbate.num_nodes:  #filter
            pass
        else:
            biomass.append(graph)
for molecule in molecules:
    biomass.append(get_graph_sample("{}{}".format(folder, molecule), gas_mol=True, family="biomass"))    
print(len(biomass))


15


## Polyurethanes

In [3]:

polyurethanes = []
folder = "BM_dataset/Polyurethanes/"
metals = ["ag", "au", "cu"]
molecules = ["mol16", "mol17", "mol18", "mol19", "mol41"]
for metal in metals:
    if metal == "au":
        for molecule in molecules:
            graph = get_graph_sample("{}{}-{}".format(folder, metal, molecule),
                                     "{}{}-sur".format(folder, metal), family="polyurethanes", 
                                     surf_multiplier=4)
            adsorbate = extract_adsorbate(graph)
            if adsorbate.num_nodes == graph.num_nodes:
                pass
            else:
                polyurethanes.append(graph)
    else:
        for molecule in molecules:
            graph = get_graph_sample("{}{}-{}".format(folder, metal, molecule),
                                     "{}{}-sur".format(folder, metal), family="polyurethanes")
            adsorbate = extract_adsorbate(graph)
            if adsorbate.num_nodes == graph.num_nodes:
                pass
            else:
                polyurethanes.append(graph)
for molecule in molecules:
    polyurethanes.append(get_graph_sample("{}{}".format(folder, molecule), gas_mol=True, family="polyurethanes"))
        
print(len(polyurethanes))

13


In [4]:
for graph in polyurethanes:
    print(graph.y)

-118.05121928
-157.42335498
-157.34867860999998
-196.70168272
-244.58779939
-117.80113190000009
-157.0177393900001
-156.93564073000005
-196.24380310000004
-244.31561998000006
-118.46814558999995
-157.59770521999997
-157.54930794999996
-196.82126069999993
-244.91745026
-117.29813031
-156.61705379
-156.57552837
-195.85739787
-243.47832296


## Plastics

In [4]:
plastics = []
folder = "BM_dataset/Plastics/"
metals = ["pt", "ru"]
molecules = ["PE", "PPit", "PPst", "PET", "PS"]
for metal in metals:
    for molecule in molecules:
        graph = get_graph_sample("{}{}-{}".format(folder, metal, molecule),
                                 "{}{}-0000".format(folder, metal), 
                                 family="plastics", 
                                 surf_multiplier=4)
        adsorbate = extract_adsorbate(graph)
        if graph.num_nodes == adsorbate.num_nodes:
            pass
        else:
            plastics.append(graph)
        
for molecule in molecules:
    plastics.append(get_graph_sample("{}{}".format(folder, molecule), gas_mol=True, family="plastics"))
        
print(len(plastics))

12


In [5]:
for graph in plastics:
    print(graph.y)

-207.79174933000013
-257.18159251
-337.67323951000003
-214.55030080000006
-256.22772469999995
-339.5979853900001
-214.70236036000006
-206.61948565
-256.25239953
-256.61073717
-334.81520255
-211.59859686


In [6]:
BM_dataset = plastics + polyurethanes + biomass
print(BM_dataset)

[Data(x=[39, 17], edge_index=[2, 76], y=-207.79174933000013, formula='C12H26-Pt1    ', family='plastics'), Data(x=[48, 17], edge_index=[2, 94], y=-257.18159251, formula='C15H32-Pt1    ', family='plastics'), Data(x=[56, 17], edge_index=[2, 118], y=-337.67323951000003, formula='C22H22O8-Pt4  ', family='plastics'), Data(x=[42, 17], edge_index=[2, 94], y=-214.55030080000006, formula='C16H18-Pt8    ', family='plastics'), Data(x=[48, 17], edge_index=[2, 94], y=-256.22772469999995, formula='C15H32-Ru1    ', family='plastics'), Data(x=[66, 17], edge_index=[2, 152], y=-339.5979853900001, formula='C22H22O84-Ru1 ', family='plastics'), Data(x=[42, 17], edge_index=[2, 102], y=-214.70236036000006, formula='C16H18-Ru8    ', family='plastics'), Data(x=[38, 17], edge_index=[2, 74], y=-206.61948565, formula='C12H26-(g)    ', family='plastics'), Data(x=[47, 17], edge_index=[2, 92], y=-256.25239953, formula='C15H32-(g)    ', family='plastics'), Data(x=[47, 17], edge_index=[2, 92], y=-256.61073717, formula

In [7]:
print(len(BM_dataset))

40


In [12]:
from torch_geometric.loader import DataLoader


In [8]:

torch.save(BM_dataset, "./BM_dataset/Graph_dataset.pt")



In [9]:
check = torch.load("./BM_dataset/Graph_dataset.pt")

In [10]:
check

[Data(x=[39, 17], edge_index=[2, 76], y=-207.79174933000013, formula='C12H26-Pt1    ', family='plastics'),
 Data(x=[48, 17], edge_index=[2, 94], y=-257.18159251, formula='C15H32-Pt1    ', family='plastics'),
 Data(x=[56, 17], edge_index=[2, 118], y=-337.67323951000003, formula='C22H22O8-Pt4  ', family='plastics'),
 Data(x=[42, 17], edge_index=[2, 94], y=-214.55030080000006, formula='C16H18-Pt8    ', family='plastics'),
 Data(x=[48, 17], edge_index=[2, 94], y=-256.22772469999995, formula='C15H32-Ru1    ', family='plastics'),
 Data(x=[66, 17], edge_index=[2, 152], y=-339.5979853900001, formula='C22H22O84-Ru1 ', family='plastics'),
 Data(x=[42, 17], edge_index=[2, 102], y=-214.70236036000006, formula='C16H18-Ru8    ', family='plastics'),
 Data(x=[38, 17], edge_index=[2, 74], y=-206.61948565, formula='C12H26-(g)    ', family='plastics'),
 Data(x=[47, 17], edge_index=[2, 92], y=-256.25239953, formula='C15H32-(g)    ', family='plastics'),
 Data(x=[47, 17], edge_index=[2, 92], y=-256.61073717