In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


In [23]:
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, dense_to_sparse
from torch_geometric.data import Data

In [46]:
class Encoder(nn.Module):
    """ Encoder module for AutoEncoder in BoxGCN-VAE. 
    Args:
        num_nodes: number of nodes in the encoder.
    """
    def __init__(self,
                 latent_dims,
                 num_nodes,
                 data_size,
                 ):
        super(Encoder, self).__init__()
       
        # Encoder. Add GC layer
        self.gconv1 = GCLayer(data_size,32)
        self.gconv2 = GCLayer(32,16)
        self.dense_boxes = nn.Linear(5, 16)
        self.dense_labels = nn.linear(1,16)
        self.act = nn.ReLU()
        self.add = torch.add()
        self.flatten = torch.flatten()
        self.dense1 = nn.Linear(16*num_nodes,128)
        self.dense2 = nn.Linear(17*num_nodes,128)
        self.dense3 = nn.Linear(128,128)
        self.concat = torch.cat()
        
        self.latent = nn.Linear(123,latent_dims)

    def forward(self, E, X_data,class_labels):
        
        x = self.gconv1(X_data,E)
        x = self.gconv2(X,E)
        x = self.flatten(x)
        
        boxes = X_data[:,1:]
        boxes = self.act(self.dense_boxes(boxes))
        
        labels = X_data[:,:1]
        labels = self.act(self.dense_labels(labels))
        
        mix = self.add(boxes,labels)
        mix = self.flatten(mix)
        mix = self.act(self.dense1(mix))
        
        x = self.concat([class_labels,x])
        x = self.act(self.dense2(x))
        x = self.add(x,mix)
        x = self.act(self.dense3(x))
        x = self.act(self.dense3(x))
        
        z_mean = self.act(self.latent(x))
        z_logvar = self.act(self.latent(x))
        
        return z_mean,z_logvar

In [47]:
enc = Encoder(32)
data = torch.FloatTensor(np.random.randint(1,100,(3,3,128,128)))
#data = data.long() #torch.LongTensor)
enc(data)

tensor([[[[ 6.4714,  0.2287,  0.2034,  ...,  2.4665,  0.0511,  5.1229],
          [ 8.0046,  1.1054,  4.2178,  ...,  6.7391,  8.2853,  3.5584],
          [ 2.8333,  3.4874,  0.0000,  ...,  2.3832,  2.7975,  3.1580],
          ...,
          [ 6.4242,  6.2622,  4.8089,  ...,  0.8557,  0.7687,  2.6595],
          [ 4.8272,  4.9356,  7.2358,  ...,  1.5183,  6.3316,  4.1390],
          [ 8.5096,  7.7990,  3.6040,  ...,  8.2278,  9.6779,  4.2201]],

         [[ 7.4707,  6.9643,  2.0873,  ...,  3.9388,  2.6321, 11.8568],
          [ 6.1515,  9.8629,  7.2426,  ...,  7.8532,  6.4814,  9.5496],
          [ 6.9265, 11.6347,  8.1229,  ...,  7.8073,  7.2392, 10.6673],
          ...,
          [ 6.8904, 14.7627, 10.5863,  ..., 12.7848,  7.6887, 10.4430],
          [ 3.3812, 12.0242,  8.6113,  ...,  9.0487, 13.6381, 12.1212],
          [ 1.0070,  8.8674,  9.1750,  ...,  5.6600,  4.8935, 12.9907]],

         [[12.2181, 11.5195, 11.5484,  ..., 12.9926, 12.8156,  7.2554],
          [ 5.7650,  3.8132,  

In [20]:
x = torch.tensor([[2,1], [5,6], [3,7], [12,0]], dtype=torch.float)
y = torch.tensor([0, 1, 0, 1], dtype=torch.float)
edge_index = torch.tensor([[0,1,2,3],
                           [1,2,3,0]], 
                          dtype=torch.long)
data = Data(x=x, y=y, edge_index=edge_index)
data

Data(edge_index=[2, 4], x=[4, 2], y=[4])

In [37]:
adj = torch.tensor([[1,0,0,0,1],
                 [0,1,1,0,0],
                 [0,1,1,1,0],
                 [0,0,1,1,0],
                 [1,0,1,0,1]])
edge_index,_ = dense_to_sparse(adj)
data = torch.FloatTensor(np.random.randint(0,100,(5,5))/100)
data

tensor([[0.6600, 0.0100, 0.0900, 0.2400, 0.8900],
        [0.6600, 0.0100, 0.3900, 0.2100, 0.8700],
        [0.4000, 0.5400, 0.2100, 0.6400, 0.3600],
        [0.8900, 0.3900, 0.7400, 0.7500, 0.2900],
        [0.5400, 0.6400, 0.4700, 0.8900, 0.5100]])

(tensor([[0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4],
         [0, 4, 1, 2, 1, 2, 3, 2, 3, 0, 2, 4]]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]))

In [39]:
class GCLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCLayer, self).__init__(aggr='max') #  "Max" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False)
        self.update_act = torch.nn.ReLU()
        
    def forward(self, x, adj):
        # x has shape [N, in_channels]
        # adj has shape [E, E]
        edge_index, _ = dense_to_sparse(adj)
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        
        
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        # x_j has shape [E, in_channels]

        x_j = self.lin(x_j)
        x_j = self.act(x_j)
        
        return x_j

    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]


        new_embedding = torch.cat([aggr_out, x], dim=1)
        
        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)
        
        return new_embedding

In [40]:
gcl1 = GCLayer(5,32)
gcl1#(data,adj)

tensor([[0.0831, 0.0000, 0.0000, 0.0538, 0.0000],
        [0.1434, 0.0000, 0.0000, 0.0530, 0.0000],
        [0.1565, 0.0000, 0.0000, 0.1168, 0.0000],
        [0.1176, 0.0003, 0.0000, 0.0318, 0.0000],
        [0.1699, 0.0000, 0.0000, 0.0572, 0.0000]], grad_fn=<ReluBackward0>)