In [2]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import lightning.pytorch as pl
import torch_geometric as tg

torch.seed()

  Referenced from: <DAC8FDCB-770B-356E-BA9C-E2F40A2AA20E> /opt/anaconda3/lib/python3.9/site-packages/torchvision/image.so
  Expected in:     <AE6DCE26-A528-35ED-BB3D-88890D27E6B9> /opt/anaconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib
  warn(f"Failed to load image Python extension: {e}")


9832140095817926393

In [124]:
class GCNConv(tg.nn.MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add') # aggr can be add, mean, max, min, prod, and custom for a custom func, e.g. torch.scatter_add
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x --> N, in_channels
        edge_index, _ = tg.utils.add_self_loops(edge_index, num_nodes=x.shape[0])
        # edge_index --> 2, num_edges(E)
        x = self.linear(x)
        updated_repr = self.propagate(edge_index, x=x, size=(x.shape[0], x.shape[0])) # size should be num_nodes, num_nodes
        # propagate calls message, aggregate and update with aggr being defined through constructor above
        return updated_repr

    def message(self, x_i, x_j, edge_index, size):
        # x_j --> E, in_channels : feature vectors of the source nodes (nodes from which the edge is directed)
        # x_i --> E, in_channels : feature vectors of the neighbour nodes (nodes towards which the edge is directed)
        row, col = edge_index
        # print("row",row)
        # print("col", col)
        deg = tg.utils.degree(index=row, num_nodes=size[0], dtype=x_j.dtype) # index takes the row from which the edge is directed
        # print("xs")
        # print(x_i)
        # print(x_j)
        # print("deg", deg)
        # print("deg row", deg[row]) # degree of each node of the index col
        # print("deg col", deg[col]) # degree of each node in the nodes towards which the edge goes
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm_factor = deg_inv_sqrt[row] * deg_inv_sqrt[col] # because of undirected graph
        norm_factor = norm_factor.view(-1, 1)
        return norm_factor * x_j
    
    def update(self, aggr_out):
        return aggr_out

In [118]:
class EdgeConv(tg.nn.MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr="max")
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels, out_channels),
            nn.LeakyReLU(),
            nn.Linear(out_channels, out_channels),
        )
    
    def forward(self, x, edge_index):
        updated_repr = self.propagate(edge_index, x=x, size=(x.shape[0], x.shape[0]))
        return updated_repr
    
    def message(self, x_i, x_j):
        diff = x_i - x_j
        msg = torch.cat([x_i, diff], dim=1)
        msg = self.mlp(msg)
        return msg
    
    def update(self, aggr_out):
        return aggr_out

In [121]:
from torch_geometric.nn import knn_graph 
# GPU accelerated edge convolutions as the operation apparently computes the entire graph sequentially
class KNNConv(EdgeConv): # dont understand this yet
    def __init__(self, in_channels, out_channels, k=6):
        super(KNNConv, self).__init__(in_channels, out_channels)
        self.k = k

    def forward(self, x, batch=None):
        edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow) # flow defaults to source_to_target w/reference to message passing
        return super.forward(x, edge_index)

In [123]:
gcn = GCNConv(2, 3)
ecn = EdgeConv(2, 3)
# knn = KNNConv(2, 3, k=2)

x = torch.randn(5, 2) # num_nodes * features_per_node
y = torch.randn(5, 3) # num_nodes * features_per_node

# x = torch.arange(0,10).view(5, 2).float()
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                           [1, 4, 2, 1, 3, 3]], dtype=torch.long) # 2, E
adam = optim.Adam(gcn.parameters(), lr=0.001)

# message = gcn.message(x, x, edge_index)
# print(message, message.shape)
x_out = gcn(x, edge_index)
print(x_out, x_out.shape)
x_ecn = ecn(x, edge_index)
print(x_ecn, x_ecn.shape)
# x_knn = knn(x)
# print(x_knn, x_knn.shape)

tensor([[-0.0767, -0.0190,  0.5316],
        [ 0.1705,  0.2976,  0.7172],
        [ 0.2856,  0.3836,  0.3468],
        [ 0.0931,  1.0285,  0.7458],
        [ 0.1694,  0.4096,  0.4468]], grad_fn=<ScatterAddBackward0>) torch.Size([5, 3])
tensor([[ 0.0000,  0.0000,  0.0000],
        [-0.3318,  0.4351, -0.1737],
        [-0.8640,  0.0603, -0.5952],
        [-0.3346,  0.4314, -0.1814],
        [-1.0888, -0.1188, -0.6981]], grad_fn=<ScatterReduceBackward0>) torch.Size([5, 3])


In [102]:
def train(num_steps):
    grads = []
    for _ in range(num_steps):

        adam.zero_grad()
        x_out = gcn(x, edge_index)
        loss = torch.nn.functional.mse_loss(x_out, y)
        loss.backward()
        grads.append(gcn.linear.weight.grad)
        adam.step()
        if _ % 100 == 0:
            print(loss.item())
    return grads, gcn(x, edge_index)

In [115]:
#vanilla GCN layer in pytorch without a non_linearity and torch_geometric
class GraphConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GraphConvLayer, self).__init__()
        self.linear = nn.Linear(in_channels, out_channels) # provides the weight matrix for weighted sum of the neighbours' features

    def forward(self, x, adj_mat):
        # x --> N, in_channels
        # adj_mat --> N, N
        adj_hat = adj_mat + torch.eye(adj_mat.shape) # N, N
        d_hat = torch.diag(torch.pow(torch.sum(adj_mat, dim=1), -0.5)) # N, N
        a_hat = d_hat @ adj_mat @ d_hat # symmetrically normalized adjacency matrix
        # a_hat --> N, N
        x = self.linear(x)
        x = a_hat @ x
        return x # aggregation through summing the neighbours' features 


In [117]:
class GraphConvNet(nn.Module):
    def __init__(self, in_channels, hidden_dim, out_channels, num_layers):
        super(GraphConvNet, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(GraphConvLayer(in_channels, hidden_dim),
                           nn.LeakyReLU()
                        )
        for _ in range(num_layers - 2):
            self.layers.append(GraphConvLayer(hidden_dim, hidden_dim),
                              nn.LeakyReLU()
                            )
        self.layers.append(GraphConvLayer(hidden_dim, out_channels),
                          nn.LeakyReLU()
                          )
        self.model = nn.Sequential(*self.layers)
    
    def forward(self, x, adj_mat):
        return self.model(x, adj_mat)


In [129]:
rand = torch.randn(2, 4, 3)
rand.norm(dim=-1, keepdim=True).shape

torch.Size([2, 4, 1])

In [146]:
class GVP(nn.Module):
    def __init__(self, in_shapes, out_shapes, h_dim):
        super(GVP, self).__init__()
        s_in, v_in = in_shapes
        s_out, v_out = out_shapes
        print("in_shapes: ",in_shapes)
        print("out_shapes: ",out_shapes)
        self.lin_s = nn.Linear(s_in + h_dim, s_out)
        self.lin_v = nn.Linear(v_in, h_dim)
        self.lin_out_v = nn.Linear(h_dim, v_out)
        self.lin_out_s = nn.Linear(s_in + h_dim, s_out)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
    
    def forward(self, x):
        s, v = x
        # s --> N, s_in
        # v --> N, v_in, 3
        print("s: ",s.shape)
        print("v: ",v.shape)
        v = v.permute(0, 2, 1)
        vh = self.lin_v(v)
        print("vh: ",vh.shape)
        vu = self.lin_out_v(vh)
        print("vu: ",vu.shape)
        vh = vh.transpose(-1, -2)
        vu = vu.transpose(-1, -2)
        sh = torch.norm(vh, dim=-1)
        s_hn = torch.cat([s, sh], dim=-1)
        print("s_hn: ",s_hn.shape)
        s_out = self.lin_out_s(s_hn)
        print("s_out: ",s_out.shape)
        s_out = self.tanh(s_out)
        v_out = self.relu(vu * torch.norm(vu, dim=-1, keepdim=True))
        print("v_out: ",v_out.shape)
        return (s_out, v_out)
        


In [148]:
gvp = GVP((2, 4), (3, 4), 5)
x = (torch.randn(2, 2), torch.randn(2, 4, 3))
gvp(x)

in_shapes:  (2, 4)
out_shapes:  (3, 4)
s:  torch.Size([2, 2])
v:  torch.Size([2, 4, 3])
vh:  torch.Size([2, 3, 5])
vu:  torch.Size([2, 3, 4])
s_hn:  torch.Size([2, 7])
s_out:  torch.Size([2, 3])
v_out:  torch.Size([2, 4, 3])


(tensor([[-0.0216, -0.3689,  0.2534],
         [ 0.2418,  0.2044,  0.2630]], grad_fn=<TanhBackward0>),
 tensor([[[1.6820, 1.0839, 0.8515],
          [0.0389, 0.0726, 0.1266],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000]],
 
         [[0.7354, 0.9130, 0.8214],
          [0.0618, 0.0622, 0.1660],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000]]], grad_fn=<ReluBackward0>))