In [5]:
import torch
from torch.nn import Linear, Parameter, BatchNorm1d
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_scatter import scatter_add
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU

unique_value_edge_attr = 2

class GINConv(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information by concatenation.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        embed_input (bool): whether to embed input or not. 
        

    See https://arxiv.org/abs/1810.00826
    """
    def __init__(self, in_channels, emb_dim, aggr = "add"):
        super(GINConv, self).__init__()
        #multi-layer perceptron
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(in_channels, 2*emb_dim),
            torch.nn.BatchNorm1d(2*emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2*emb_dim, 4*emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(4*emb_dim, 2*emb_dim),
            torch.nn.ReLU(),
            )
        self.edge_embedding = torch.nn.Embedding(unique_value_edge_attr, in_channels)

        torch.nn.init.xavier_uniform_(self.edge_embedding.weight.data)

        self.aggr = aggr
        self.in_channels = in_channels
        self.emb_dim = emb_dim

    def forward(self, x, edge_index, edge_attr):
        #add self loops in the edge space
        edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), edge_attr.size(1))
        self_loop_attr[:,-4] = 1 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = torch.zeros((edge_attr.size(0), self.in_channels), dtype=torch.float).to(edge_attr.device).to(edge_attr.dtype)

        for i in range(edge_attr.size(1)):  # Iterate over the second dimension
            embedding_ith = self.edge_embedding(edge_attr[:, i]).clone().detach().to(edge_attr.device).to(edge_attr.dtype)
            edge_embeddings += embedding_ith

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        print(x_j.shape, edge_attr.shape)
        return x_j + edge_attr

    def update(self, aggr_out):
        print(aggr_out.size())
        aggr_out = torch.tensor(aggr_out, dtype=torch.float)
        return self.mlp(aggr_out)

In [6]:
import torch
from torch_geometric.data import Data

# Initialize the GCNConv layer
# Transform from 3-dimensional features to 2-dimensional features
gin_conv = GINConv(in_channels=3, emb_dim=6)

# Define node features (4 nodes with 3 features each)
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float)

# Define the edges in the graph (making it undirected)
# Each pair of nodes is connected in both directions
edge_index = torch.tensor([[0, 1, 3, 3, 2], 
                           [1, 2, 2, 0, 0]], dtype=torch.long)  # Edges: 0-1, 1-2, 2-3, 3-0, 0-2
edge_attr = torch.randint(0, 1, (edge_index.size(1), 19))
# Apply the GCNConv layer to the node features
out_features = gin_conv(x, edge_index, edge_attr)

print("Original node features:\n", x)
print("\nTransformed node features:\n", out_features.shape)


torch.Size([9, 3]) torch.Size([9, 3])
torch.Size([4, 3])
Original node features:
 tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.]])

Transformed node features:
 torch.Size([4, 12])


  aggr_out = torch.tensor(aggr_out, dtype=torch.float)


In [7]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax, remove_self_loops
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros
import numpy as np

num_atom_type = 121 #including the extra motif tokens and graph token
num_chirality_tag = 11  #degree

num_bond_type = 7 
num_bond_direction = 3 

class GINConv(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information by concatenation.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        embed_input (bool): whether to embed input or not. 
        

    See https://arxiv.org/abs/1810.00826
    """
    def __init__(self, emb_dim, aggr = "add"):
        super(GINConv, self).__init__()
        #multi-layer perceptron
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
        self.aggr = aggr

    def forward(self, x, edge_index, edge_attr):
        #add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)

In [1]:
from KGGraph import MoleculeDataset
task_type = 'classification'
dataset = 'bace'
dataset = MoleculeDataset("dataset/" + task_type + "/" + dataset, dataset=dataset)
print(dataset)

MoleculeDataset(1513)


In [138]:
max(dataset.x.unique())

tensor(126)

In [151]:
x = dataset.x

In [166]:
num_categories = 127  # Number of unique categories in your feature x
emb_dim = 512
x_embedding = torch.nn.Embedding(num_categories, emb_dim)

In [167]:
x_clamped = torch.clamp(x, 0, x_embedding.num_embeddings - 1)

In [169]:
x_embedding(x_clamped).sum(dim=1)

tensor([[  9.6752,  -1.3753,  -9.0359,  ...,   2.0165,  27.0869,  29.7233],
        [  6.1366,   1.4360, -11.5866,  ...,   3.0902,  23.8735,  28.4991],
        [  6.1366,   1.4360, -11.5866,  ...,   3.0902,  23.8735,  28.4991],
        ...,
        [ -8.6951,  -0.3717, -11.8091,  ...,   9.5883,  51.8893,  52.0404],
        [ -8.6951,  -0.3717, -11.8091,  ...,   9.5883,  51.8893,  52.0404],
        [ -8.6951,  -0.3717, -11.8091,  ...,   9.5883,  51.8893,  52.0404]],
       grad_fn=<SumBackward1>)

In [141]:
x_embedding(x).sum(dim=1)

IndexError: index out of range in self

In [37]:
x_embeddings = torch.empty((x.size(0), emb_dim)).to(x.device).to(x.dtype)

# for i in range(x.size(1)):  # Iterate over the second dimension
#     print(i)
embedding_ith = x_embedding(x[:, 2]).detach()
embedding_ith

IndexError: index out of range in self

In [133]:
import torch

embedding = torch.nn.Embedding(5, 3)

# Create an input tensor with valid indices.
valid_input_tensor = torch.tensor([1, 2], dtype=torch.long)  

# Verify that all indices in the input tensor are within the valid range.
if torch.all(valid_input_tensor >= 0) and torch.all(valid_input_tensor < embedding.num_embeddings):
    output_tensor = embedding(valid_input_tensor)
    print(output_tensor)
else:
    print("Invalid indices detected.")

tensor([[-0.8772,  0.5183, -0.2573],
        [ 0.0248,  1.2517,  0.3967]], grad_fn=<EmbeddingBackward0>)


In [136]:
import torch
import torch.nn as nn

embedding = nn.Embedding(10, 5)

input_tensor = torch.tensor([1, 16, 7])

# a boolean mask to filter out indices
mask = (input_tensor < embedding.num_embeddings) & (input_tensor >= 0)

# Apply the mask the input tensor
valid_indices = input_tensor[mask]
output_tensor = embedding(valid_indices)

print(output_tensor)

tensor([[ 0.5203, -0.9386, -0.7461, -0.5588,  0.0097],
        [-0.7428,  0.8891, -0.5622, -0.0936,  0.7392]],
       grad_fn=<EmbeddingBackward0>)


In [137]:
import torch

embedding = torch.nn.Embedding(5, 3)

# Create an input tensor with potentially out-of-range indices.
input_tensor = torch.tensor([1, 6, -1, 2.5], dtype=torch.long)

# Using torch.clamp() to clamp the indices to the valid range.
input_tensor_clamped = torch.clamp(input_tensor, 0, embedding.num_embeddings - 1)

output_tensor = embedding(input_tensor_clamped)

# Print the output tensor.
print(output_tensor)

tensor([[ 2.1360,  0.1091, -0.8151],
        [ 1.3550,  0.1267,  1.5197],
        [-0.0916,  0.7949, -0.7115],
        [ 0.2716, -1.0548,  0.2645]], grad_fn=<EmbeddingBackward0>)
