In [1]:
%load_ext autoreload
%autoreload 2

In [11]:
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
from torch_geometric import transforms
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import global_mean_pool, knn_graph

from geometric_vector_perceptron import GVP, GVP_MPNN, GVP_Network


KeyboardInterrupt: 

In [3]:
# Root folder holding cnn.npy, synthetic.npy and answers.npz (relative to the notebook).
data_path = Path("..") / "data" / "synthetic"

In [4]:
# Load the raw numpy arrays and expose them as torch tensors (torch.from_numpy shares memory).
cnn = torch.from_numpy(np.load(data_path / "cnn.npy"))
synthetic = torch.from_numpy(np.load(data_path / "synthetic.npy"))

# Per-structure targets (used as `y` when building graphs below), read from a single
# .npz archive; the `with` block closes the archive's file handle afterwards.
with np.load(data_path / "answers.npz") as data:
    off_center, perimeter = (
        torch.from_numpy(data[key]) for key in ("off_center", "perimeter")
    )

In [6]:
# Number of protein structures along the leading axis (20k in this dataset).
num_structs = len(synthetic)

In [7]:
# add one-hot vector to the last dimension
is_special = torch.zeros((20000, 2, 100, 1))
is_special[:, 1, :3] = 1 
synthetic = torch.cat([synthetic, is_special], dim=3) # (2000, 2, 100, 3) -> (20000, 2, 100, 4) -- last channel corresponds to is_special

In [8]:
# Baseline graph construction: 10-nearest-neighbour graph whose edge features are the
# relative Cartesian offsets followed by the edge length.
# NOTE(review): `data_list` is rebuilt with different transforms in a later cell.
synthetic_transforms = transforms.Compose([
    transforms.KNNGraph(k=10),
    transforms.Cartesian(),
    transforms.Distance(),
])

data_list = [
    synthetic_transforms(
        Data(
            x=synthetic[idx, 1],
            pos=synthetic[idx, 0, :, :3],
            y=torch.tensor([off_center[idx], perimeter[idx]]),
        )
    )
    for idx in range(num_structs)
]

In [8]:
class ExtendedPPF:
    """Edge transform producing ``[distance, cos(a1), cos(a2), cos(a3)]`` per edge.

    The three angles come from ``transforms.PointPairFeatures``, whose first output
    column (the raw distance) is replaced by the result of ``transforms.Distance``
    (optionally normalized). PointPairFeatures reads ``data.norm``, so the input graph
    must carry per-node normals and an ``edge_index``.

    Args:
        norm: Forwarded to ``transforms.Distance`` to normalize edge lengths.
        cat: If True and the graph already has ``edge_attr``, the new features are
            concatenated after it; otherwise they replace it.
    """

    def __init__(self, norm=True, cat=True):
        self.norm, self.cat = norm, cat
        self.ppf = transforms.PointPairFeatures(cat=False)
        self.distance = transforms.Distance(norm=norm, cat=False)

    def __call__(self, data):
        # Capture any pre-existing edge features before the sub-transforms touch edge_attr.
        previous = data.edge_attr

        # Drop PPF column 0 (its distance) and keep the cosines of the three angles.
        angle_cosines = torch.cos(self.ppf(data).edge_attr)[:, 1:]
        edge_lengths = self.distance(data).edge_attr
        extended = torch.cat([edge_lengths, angle_cosines], dim=-1)

        keep_previous = self.cat and previous is not None
        data.edge_attr = torch.cat([previous, extended], dim=-1) if keep_previous else extended
        return data

        
        

In [12]:
# Graph construction for the GNN: 10-nearest-neighbour edges, then ExtendedPPF edge
# features (distance + cosines of the point-pair angles). Slice 1's first three
# channels are passed as `norm`, which the point-pair features read.
gnn_transforms = transforms.Compose([
    transforms.KNNGraph(k=10),
    ExtendedPPF(),
])

data_list = [
    gnn_transforms(
        Data(
            x=synthetic[idx, 1],
            pos=synthetic[idx, 0, :, :3],
            norm=synthetic[idx, 1, :, :3],
            y=torch.tensor([off_center[idx], perimeter[idx]]),
        )
    )
    for idx in range(num_structs)
]

In [None]:
# Take the first transformed graph for interactive inspection of its feature tensors.
data_0 = data_list[0]

In [None]:
data_0.edge_attr.size()

torch.Size([1000, 4])

In [None]:
# Mini-batches of 32 graphs, served in dataset order (no shuffling).
dataloader = DataLoader(dataset=data_list, batch_size=32)

In [None]:
# NOTE(review): `SyntheticModel` is not defined in any visible cell and every line
# of this cell is commented out; either define the model or remove this cell.
# model = SyntheticModel(
#     feats_x_in=1,
#     vectors_x_in=1,
#     feats_edge_in=1,
#     vectors_edge_in=1,
#     feats_h=20,
#     vectors_h=4,
# )

In [None]:
# Pull only the first mini-batch yielded by `dataloader` for a quick look at batched tensors.
data = next(iter(dataloader))

In [16]:
from torch_geometric.nn import GCNConv
from torchmetrics.functional import mean_squared_error
from torch import optim

In [None]:
class SyntheticGNN(pl.LightningModule):
    """Three-layer GCN regressing per-graph targets from node features.

    Node embeddings are mean-pooled per graph before the linear head, so predictions
    have shape ``(num_graphs, num_targets)``. The dataset cells build
    ``y = [off_center, perimeter]`` per graph, hence ``num_targets=2`` by default.

    Args:
        num_node_features: Size of the last dimension of ``data.x``.
        hidden_channels: Width of every GCN layer.
        num_targets: Number of regression targets per graph.
    """

    def __init__(self, num_node_features, hidden_channels=32, num_targets=2):
        super().__init__()
        self.layers = nn.ModuleList([
            GCNConv(num_node_features, hidden_channels),
            GCNConv(hidden_channels, hidden_channels),
            GCNConv(hidden_channels, hidden_channels),
        ])
        # Attribute name kept from the original; this is a regression head.
        self.classifier = nn.Linear(hidden_channels, num_targets)

    def forward(self, data):
        """Return per-graph predictions of shape ``(num_graphs, num_targets)``."""
        # NOTE(review): node features come from numpy via torch.from_numpy and may be
        # float64; cast to the parameter dtype so the GCN matmuls do not hit a dtype mismatch.
        x = data.x.to(self.classifier.weight.dtype)
        edge_index = data.edge_index

        for layer in self.layers:
            # Without a non-linearity, stacked GCN layers collapse to one linear map.
            x = torch.relu(layer(x, edge_index))

        # `batch` maps each node to its graph; None (a single un-batched graph)
        # pools over all nodes and yields one row.
        x = global_mean_pool(x, getattr(data, "batch", None))
        return self.classifier(x)

    def training_step(self, batch, batch_idx):
        y_hat = self(batch)
        # Batching concatenates each graph's 1-D `y` along dim 0, giving
        # (num_graphs * num_targets,); reshape to match the predictions and align dtype.
        y = batch.y.view(y_hat.shape).to(y_hat.dtype)
        loss = mean_squared_error(y_hat, y)

        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters())

