In [None]:
import pandas as pd
import torch
import torch.nn as nn
import pytorch_lightning as pl
from tqdm.notebook import tqdm
import networkx as nx
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from sklearn.preprocessing import LabelBinarizer

## Graph Construction

In [None]:
# Load the PSCDB structural-rearrangement table and pull out the column of
# PDB codes ("Free PDB") used later to build one graph per structure.
df = pd.read_csv("../datasets/pscdb/structural_rearrangement_data.csv")
pdbs = df["Free PDB"]
# Call head() so the first rows are displayed; a bare `df.head` only shows the
# bound-method repr.
df.head()

In [None]:
# One-hot encode the motion-type column, then reduce every one-hot row to the
# index of its largest entry, stored as a LongTensor scalar.
one_hot_rows = LabelBinarizer().fit_transform(df.motion_type)
y = [
    torch.argmax(torch.Tensor(row)).type(torch.LongTensor)
    for row in one_hot_rows
]
y

In [None]:
from graphein.protein.config import ProteinGraphConfig
# NOTE: no trailing comma here -- an un-parenthesised `from ... import a, b,`
# is a SyntaxError.
from graphein.protein.edges.distance import add_hydrogen_bond_interactions, add_peptide_bonds
from graphein.protein.graphs import construct_graph

# Number of structures, taken from the start of `pdbs`, to build graphs for.
NUM_STRUCTURES = 31

# Override config with constructors.
# Node-level features can be added through a "node_metadata_functions" entry.
constructors = {
    "edge_construction_functions": [add_hydrogen_bond_interactions, add_peptide_bonds],
}

config = ProteinGraphConfig(**constructors)
print(config.dict())

# Make graphs: one graph per PDB code, with a progress bar.
graph_list = [
    construct_graph(pdb_code=pdb, config=config)
    for pdb in tqdm(pdbs[0:NUM_STRUCTURES])
]

### Convert NetworkX graphs to PyTorch Geometric

In [None]:
def convert_to_pyg_data(G: nx.Graph) -> Data:
    """Convert a NetworkX graph into a PyTorch Geometric ``Data`` object.

    The original node labels are stored under ``node_id`` and the nodes are
    then relabelled to consecutive integers. Node, edge and graph attributes
    are copied into the result, keyed by ``str(attribute_name)``:

    * node attributes become a list with one entry per node;
    * edge attributes are expanded with ``list(value)`` and concatenated
      across edges, so every edge attribute value must be iterable;
    * graph attributes are wrapped in a single-element list and replace any
      node/edge attribute stored under the same name.

    Args:
        G: Graph to convert.

    Returns:
        A ``Data`` object built from the collected attributes, with
        ``edge_index`` viewed as ``(2, -1)`` and ``num_nodes`` set to the
        number of nodes in ``G``.
    """
    # Dict used to construct the Data object; node ids are kept as a feature.
    attrs = {"node_id": list(G.nodes())}
    G = nx.convert_node_labels_to_integers(G)

    # Construct the edge index from the integer-labelled edges.
    # NOTE(review): edges are taken exactly as G.edges yields them; if G is
    # undirected and downstream layers need both directions, the reverse edges
    # are not added here -- confirm.
    edge_index = torch.LongTensor(list(G.edges)).t().contiguous()

    # Add node features. The first node (re)starts each list; later nodes
    # append in place instead of rebuilding the list on every step.
    for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
        for key, value in feat_dict.items():
            key = str(key)
            if i == 0:
                attrs[key] = [value]
            else:
                attrs[key].append(value)

    # Add edge features. Each value is flattened into the running list.
    for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
        for key, value in feat_dict.items():
            key = str(key)
            if i == 0:
                attrs[key] = list(value)
            else:
                attrs[key].extend(value)

    # Add graph-level features, each wrapped in a one-element list.
    for feat_name, feat_value in G.graph.items():
        attrs[str(feat_name)] = [feat_value]

    attrs["edge_index"] = edge_index.view(2, -1)
    pyg_data = torch_geometric.data.Data.from_dict(attrs)
    pyg_data.num_nodes = G.number_of_nodes()

    return pyg_data

In [None]:
# Convert every NetworkX graph to a PyG Data object. Entries that are already
# Data objects are passed through unchanged, so re-running this cell (which
# overwrites `graph_list` with its own output) does not feed converted graphs
# back into the converter.
graph_list = [
    graph if isinstance(graph, Data) else convert_to_pyg_data(graph)
    for graph in graph_list
]
graph_list

### Construct DataLoaders

In [None]:
from torch_geometric.data import Dataset, DataLoader

# Pair each graph with its label. The label slice is sized from `graph_list`
# rather than a hard-coded count, so it stays in step with however many graphs
# were built.
data = list(zip(graph_list, y[:len(graph_list)]))
train_loader = DataLoader(data, batch_size=1)

## Define Model

In [None]:
from torch_geometric.nn import GCNConv, global_add_pool
from torch.nn.functional import mse_loss, nll_loss, relu, softmax, cross_entropy
from pytorch_lightning.metrics.functional import accuracy

In [None]:
class GraphNet(pl.LightningModule):
    """Graph classifier: one GCN layer, sum pooling, then a linear read-out.

    Args:
        in_channels: Size of each input node feature vector (default 3).
        hidden_channels: Output size of the GCN layer (default 16).
        num_classes: Number of output classes (default 7).
    """

    def __init__(self, in_channels: int = 3, hidden_channels: int = 16, num_classes: int = 7):
        super().__init__()
        self.encoder = GCNConv(in_channels=in_channels, out_channels=hidden_channels)
        self.decoder = nn.Linear(hidden_channels, num_classes)

    def _logits(self, x):
        """Return unnormalised class scores for the graph held in ``x``.

        ``x`` must expose ``coords`` and ``edge_index``.
        """
        # assumes x.coords carries a leading singleton axis (one graph per
        # batch) that squeeze(0) removes -- TODO confirm against the loader.
        # The tensor is moved next to edge_index so both encoder inputs live
        # on the same device.
        node_feats = torch.Tensor(x.coords).squeeze(0).to(x.edge_index.device)
        hidden = relu(self.encoder(node_feats, x.edge_index))
        # Every node is assigned to graph 0: an explicit per-node batch vector
        # on the same device as the features, instead of a 0-dim CPU tensor.
        graph_ids = torch.zeros(hidden.size(0), dtype=torch.long, device=hidden.device)
        pooled = global_add_pool(hidden, batch=graph_ids)
        return self.decoder(pooled)

    def forward(self, x):
        """Return class probabilities (softmax over the last dimension)."""
        return softmax(self._logits(x), dim=-1)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one ``(graph, label)`` batch."""
        x, y = batch
        # cross_entropy applies log-softmax itself, so it must receive raw
        # logits rather than the probabilities returned by forward().
        return cross_entropy(self._logits(x), y)

    def configure_optimizers(self):
        """Use Adam over all parameters with a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

## Train!

In [None]:
# Train Model
model = GraphNet()
trainer = pl.Trainer(max_epochs=20)
trainer.fit(model, train_loader)