PyTorch, and by extension PyTorch Geometric, has standardized ways of handling data and datasets. We first need to convert our Graphein graphs into PyTorch Geometric `Data` objects, using Graphein's conversion functions.

In [None]:
from graphein.ml import GraphFormatConvertor
from src import dataloader
import torch

In [None]:
columns = [
    "chain_id",
    "coords",
    "edge_index",
    "kind",
    "node_id",
    "residue_number",
    "amino_acid_one_hot",
    "meiler",
]

convertor = GraphFormatConvertor(src_format="nx", dst_format="pyg", columns=columns, verbose=None)

In [None]:
graphein_graph, interface_labels = dataloader.load_graph("1A22", "A", "B")

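In the function below, `x` is the node feature matrix: one row per residue, built by stacking the selected node attribute columns (here the one-hot amino acid encoding and the Meiler embedding) side by side. `y` holds the per-node labels: 1 if the residue appears in `interface_labels`, 0 otherwise. The node coordinates are stored separately in `pos`.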

In [None]:
def graphein_to_torch_graph(graphein_graph, interface_labels, convertor,
                            node_attr_columns=("amino_acid_one_hot", "meiler")):
    """
    Converts a Graphein graph to a PyTorch Geometric Data object.

    Sets `x` (node features), `pos` (node coordinates) and `y` (interface labels).
    """
    data = convertor(graphein_graph)
    data_dict = data.to_dict()
    x_data = []
    for column in node_attr_columns:
        if data_dict[column].ndim == 1:
            # One scalar per node: reshape (num_nodes,) -> (num_nodes, 1)
            x_data.append(torch.atleast_2d(data_dict[column]).T)
        else:
            x_data.append(torch.atleast_2d(data_dict[column]))
    # Concatenate all feature columns into one (num_nodes, num_features) matrix
    data.x = torch.hstack(x_data).float()
    data.pos = data.coords.float()
    # Binary node labels: 1 if the residue is part of the interface, else 0
    data.y = torch.zeros(data.num_nodes)
    for i, node_id in enumerate(data.node_id):
        if node_id in interface_labels:
            data.y[i] = 1
    return data

In [None]:
torch_geometric_graph = graphein_to_torch_graph(graphein_graph, interface_labels, convertor)
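To make the construction of `x` and `y` concrete, here is a minimal sketch that repeats the same stacking and labelling logic on made-up tensors. The feature values, node ids and interface labels are invented for illustration and are not real Graphein output.

```python
import torch

# Made-up per-node features for a 4-residue toy graph (not real Graphein output)
one_hot = torch.tensor([[1., 0., 0.],
                        [0., 1., 0.],
                        [0., 0., 1.],
                        [1., 0., 0.]])           # 2-D: (num_nodes, 3)
charge = torch.tensor([0.5, -1.0, 0.0, 2.0])    # 1-D: (num_nodes,)

feature_columns = []
for feature in (one_hot, charge):
    if feature.ndim == 1:
        # (num_nodes,) -> (1, num_nodes) -> (num_nodes, 1)
        feature_columns.append(torch.atleast_2d(feature).T)
    else:
        feature_columns.append(feature)

# Stack the columns side by side: 3 one-hot columns + 1 scalar column
x = torch.hstack(feature_columns).float()
print(x.shape)  # torch.Size([4, 4])

# Hypothetical node ids and interface labels
node_ids = ["A:MET:1", "A:LYS:2", "A:GLY:3", "A:SER:4"]
interface_labels = {"A:LYS:2", "A:SER:4"}

y = torch.zeros(len(node_ids))
for i, node_id in enumerate(node_ids):
    if node_id in interface_labels:
        y[i] = 1
print(y)  # tensor([0., 1., 0., 1.])
```

Each row of `x` describes one residue and the matching entry of `y` says whether that residue is labelled as part of the interface.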

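The resulting `Data` object carries the node features `x`, the coordinates `pos`, the labels `y` and the connectivity `edge_index`, along with the other columns requested from the convertor (such as `node_id` and `chain_id`). Printing it shows the shape of each attribute.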

The `torch_geometric.data.Dataset` class is the standard way of representing a graph dataset in PyTorch Geometric. It is an abstract class that you can subclass to create your own dataset. Here's what the typical architecture of a dataset looks like:

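A subclass needs to implement the following:
- `raw_file_names`: the files expected in `raw_dir`; if any are missing, `download()` is called
- `processed_file_names`: the files expected in `processed_dir`; if any are missing, `process()` is called
- `download()`: fetches the raw data into `raw_dir`
- `process()`: converts the raw data into `Data` objects and saves them to `processed_dir`
- `len()`: returns the number of graphs in the dataset
- `get(idx)`: loads and returns the graph at index `idx`

Instantiating the dataset with a root directory and a list of protein names (in the form `<pdb_id>_<chain>`) runs `download()` and `process()` for any missing files; individual graphs are then available by indexing.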

In [None]:
from torch_geometric.data import Dataset
from pathlib import Path
import pickle
import torch

class ProteinDataset(Dataset):
    """
    torch-geometric Dataset class for loading protein files as graphs.
    """
    def __init__(self, root,
                 protein_names: list, 
                 pre_transform=None, 
                 transform=None):
        self.protein_names = protein_names
        super(ProteinDataset, self).__init__(root, pre_transform=pre_transform, transform=transform)

    def download(self):
        for protein_name in self.protein_names:
            output = Path(self.raw_dir) / f'{protein_name}.pkl'
            if not output.exists():
                pdb_id, chain = protein_name.split("_")
                graphein_graph, interface_labels = dataloader.load_graph(pdb_id, chain)
                with open(output, "wb") as f:
                    pickle.dump((graphein_graph, interface_labels), f)

    @property
    def raw_file_names(self):
        # File names only: PyG joins them with `self.raw_dir` to build `self.raw_paths`
        return [f"{protein_name}.pkl" for protein_name in self.protein_names]

    @property
    def processed_file_names(self):
        # File names only: PyG joins them with `self.processed_dir` to build `self.processed_paths`
        return [f"{protein_name}.pt" for protein_name in self.protein_names]

    def process(self):
        for protein_name in self.protein_names:
            output = Path(self.processed_dir) / f'{protein_name}.pt'
            if output.exists():
                continue
            with open(Path(self.raw_dir) / f"{protein_name}.pkl", "rb") as f:
                graphein_graph, interface_labels = pickle.load(f)
            # `convertor` is the GraphFormatConvertor defined earlier in the notebook
            torch_graph = graphein_to_torch_graph(graphein_graph, interface_labels, convertor)
            if self.pre_transform is not None:
                torch_graph = self.pre_transform(torch_graph)
            torch.save(torch_graph, output)

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        # Recent PyTorch versions may need `weights_only=False` here to unpickle Data objects
        data = torch.load(self.processed_paths[idx])
        return data


This puts together everything we've implemented so far: loading proteins as graphs with Graphein, converting them into PyTorch Geometric `Data` objects, and wrapping those in a PyTorch Geometric `Dataset`.

Graphein also has a built-in `ProteinGraphDataset` class that combines these steps. It has some nice additional features, such as (1) the ability to load a dataset of proteins from either the PDB or the AlphaFold Database, or from a directory of PDB files, and (2) the ability to apply custom transformations from your bioinformatics tools of choice to the PDB files (with the `pdb_transform` argument).

## Bonus
- pre-transforms and transforms
- how to use the `ProteinGraphDataset` class to include AlphaFold models