### Extract PyG Graphs from FS-MOL

This notebook extracts every molecular graph from FS-MOL's training fold and converts it to a PyG graph for use in contrastive learning.

In [4]:
import os
import sys

FS_MOL_CHECKOUT_PATH = os.path.abspath('../')

os.chdir(FS_MOL_CHECKOUT_PATH)
sys.path.insert(0, FS_MOL_CHECKOUT_PATH)

In [13]:
from fs_mol.data import FSMolDataset, DataFold, FSMolTask
from dpu_utils.utils import RichPath

# from fs_mol.modules.graph_feature_extractor import (
#     GraphFeatureExtractor,
#     GraphFeatureExtractorConfig,
# )
# from fs_mol.data.fsmol_batcher import FSMolBatcher
import torch
from typing import List
from fs_mol.custom.utils import convert_to_pyg_graph

In [7]:
fsmol_dataset = FSMolDataset.from_directory(
    directory=RichPath.create('/FS-MOL/datasets/fs-mol/'),
    task_list_file=RichPath.create('/FS-MOL/datasets/fsmol-0.1.json'),
)

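def get_all_samples(fsmol_dataset):
    # Reader callback: load one task file and return its molecule samples.
    # `idx` is part of the callback signature and is not used here.
    def task_to_samples(paths: List[RichPath], idx: int):
        task = FSMolTask.load_from_file(paths[0])

        return task.samples

    # Iterate over the molecule samples of every task in the training fold.
    return iter(fsmol_dataset.get_task_reading_iterable(DataFold.TRAIN, task_reader_fn=task_to_samples))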

a = get_all_samples(fsmol_dataset)
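
Because `get_all_samples` wraps its result in `iter(...)`, the returned object can only be consumed once. The sketch below shows how one might peek at a few items with `itertools.islice` without loading everything. It uses a hypothetical stand-in generator (`stand_in_samples`) in place of the FS-MOL iterator, so none of the names are part of FS-MOL.

```python
from itertools import islice


def stand_in_samples():
    """Hypothetical stand-in for the FS-MOL sample iterator."""
    for i in range(5):
        yield f"molecule_{i}"


samples = iter(stand_in_samples())

# Take the first two items; this advances the iterator by two.
preview = list(islice(samples, 2))

# The rest of the iterator holds only what has not been consumed yet.
remaining = list(samples)

print(preview)         # ['molecule_0', 'molecule_1']
print(len(remaining))  # 3
```

Note that any items taken for a preview are gone from the iterator, so for a complete conversion it is safer to call `get_all_samples` again afterwards.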

The `a` variable is now an iterator over all the molecules in FS-MOL's training set. Next, we convert each sample from FS-MOL's data structure to a PyG `Data` object.

We do this using the `convert_to_pyg_graph` function.
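
The implementation of `convert_to_pyg_graph` is not shown in this notebook. As a rough illustration of one step such a conversion plausibly involves, the sketch below merges per-bond-type adjacency lists into a single `[2, num_edges]` index array plus one integer type per edge. It assumes each graph stores one `[num_edges, 2]` array per bond type, which is an assumption about the data layout. The function name `adjacency_lists_to_edge_index` is hypothetical, and NumPy arrays are used instead of tensors to keep the example dependency-light.

```python
import numpy as np


def adjacency_lists_to_edge_index(adjacency_lists):
    """Merge per-edge-type adjacency lists into COO form (hypothetical helper).

    Assumes one array of (source, target) pairs per edge type.
    """
    # reshape(-1, 2) also normalises empty lists to shape (0, 2).
    pairs = [np.asarray(adj, dtype=np.int64).reshape(-1, 2) for adj in adjacency_lists]
    # Stack all pairs and transpose to shape [2, num_edges].
    edge_index = np.concatenate(pairs, axis=0).T
    # Record the index of the originating list as the edge type.
    edge_type = np.concatenate(
        [np.full(len(p), t, dtype=np.int64) for t, p in enumerate(pairs)]
    )
    return edge_index, edge_type


# Toy graph: two edges of type 0, one of type 1, none of type 2.
toy_adjacency = [
    np.array([[0, 1], [1, 2]]),
    np.array([[2, 3]]),
    np.zeros((0, 2), dtype=np.int64),
]
edge_index, edge_type = adjacency_lists_to_edge_index(toy_adjacency)
print(edge_index.shape)    # (2, 3)
print(edge_type.tolist())  # [0, 0, 1]
```

The resulting shapes match the pattern of the `edge_index=[2, 55], edge_attr=[55]` fields in the `Data` output shown in this notebook, though the actual function may build them differently.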

In [9]:
pyg_graphs = [convert_to_pyg_graph(graph) for graph in a]

In [12]:
pyg_graphs[12]

Data(x=[52, 32], edge_index=[2, 55], edge_attr=[55], y=1, bool_label=True)

We now save the converted graphs to disk for later use.

In [14]:
torch.save(pyg_graphs, './fsmol_pygs.pt')
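
To read a saved list back, `torch.load` can be used. The sketch below is a minimal round trip under stated assumptions: it uses plain dictionaries of tensors as a stand-in for the PyG graphs and writes to a temporary directory rather than `./fsmol_pygs.pt`. On recent PyTorch releases, where `weights_only` defaults to `True`, loading pickled custom objects such as PyG `Data` may require `weights_only=False`; that setting runs the full unpickler, so it should only be used on files you trust.

```python
import os
import tempfile

import torch

# Stand-in for `pyg_graphs`: plain dicts of tensors instead of PyG Data objects.
graphs = [
    {"x": torch.randn(4, 32), "edge_index": torch.tensor([[0, 1, 2], [1, 2, 3]])},
    {"x": torch.randn(3, 32), "edge_index": torch.tensor([[0, 1], [1, 2]])},
]

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "graphs.pt")
    torch.save(graphs, path)
    # weights_only=False uses the full pickle loader: trusted files only.
    restored = torch.load(path, weights_only=False)

print(len(restored))                 # 2
print(tuple(restored[0]["x"].shape)) # (4, 32)
```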