In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


In [23]:
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, dense_to_sparse
from torch_geometric.data import Data

In [91]:
class GCLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCLayer, self).__init__(aggr='max') #  "Max" aggregation.
        self.mlp = nn.Sequential(nn.Linear(2 * in_channels, out_channels),
                                   nn.ReLU(),
                                   nn.Linear(out_channels, out_channels))

    def forward(self, x, adj):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        edge_index,_ = dense_to_sparse(adj)
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i has shape [E, in_channels]
        # x_j has shape [E, in_channels]

        tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp has shape [E, 2 * in_channels]
        return self.mlp(tmp)

In [98]:
class Encoder(nn.Module):
    """ Encoder module for AutoEncoder in BoxGCN-VAE. 
    Args:
        num_nodes: number of nodes in the encoder.
    """
    def __init__(self,
                 latent_dims,
                 num_nodes,
                 data_size,
                 ):
        super(Encoder, self).__init__()
       
        # Encoder. Add GC layer
        self.gconv1 = GCLayer(data_size,32)
        self.gconv2 = GCLayer(32,16)
        self.dense_boxes = nn.Linear(4, 16)
        self.dense_labels = nn.Linear(1,16)
        self.act = nn.ReLU()
        self.dense1 = nn.Linear(16*num_nodes,128)
        self.dense2 = nn.Linear(17*num_nodes,128)
        self.dense3 = nn.Linear(128,128)
        
        self.latent = nn.Linear(128,latent_dims)

    def forward(self, E, X_data,class_labels):
        
        x = self.gconv1(X_data,E)
        x = self.gconv2(x,E)
        x = torch.flatten(x)
        
        boxes = X_data[:,1:]
        boxes = self.act(self.dense_boxes(boxes))
        
        labels = X_data[:,:1]
        labels = self.act(self.dense_labels(labels))
        
        mix = torch.add(boxes,labels)
        mix = torch.flatten(mix)
        mix = self.act(self.dense1(mix))
        
        x = torch.cat([class_labels,x])
        x = self.act(self.dense2(x))
        x = torch.add(x,mix)
        x = self.act(self.dense3(x))
        x = self.act(self.dense3(x))
        print(x.size())
        
        z_mean = self.act(self.latent(x))
        z_logvar = self.act(self.latent(x))
        
        return z_mean,z_logvar

In [109]:
class Decoder(nn.Module):
    """ Decoder module for Box-Vae
    """
    def __init__(self,
                 latent_dims,
                 num_nodes,
                 bbx_size,
                 class_size,
                 label_size=1
                 ):
        super(Decoder, self).__init__()
       
        self.num_nodes = num_nodes
        self.bbx_size = bbx_size
        self.class_size = class_size
        self.label_size = label_size
        self.dense1 = nn.Linear(latent_dims,128)  
        self.dense2 = nn.Linear(128,128)
        self.dense_bbx = nn.Linear(128,num_nodes*bbx_size)
        self.dense_lbl = nn.Linear(128,num_nodes*label_size)
        self.dense_edge = nn.Linear(128,num_nodes*num_nodes)
        self.dense_cls = nn.Linear(128,class_size)
        self.act1 = nn.Sigmoid()
        self.act2 = nn.Softmax()

    def forward(self, embedding):
        x = self.act1(self.dense1(embedding))
        x = self.act1(self.dense2(x))
        x = self.act1(self.dense2(x))
        
        x_bbx = self.act1(self.dense_bbx(x))
        x_bbx = torch.reshape(x_bbx,[self.num_nodes,self.bbx_size])
        
        x_lbl = self.act1(self.dense_lbl(x))
        x_lbl = torch.reshape(x_lbl,[self.num_nodes,self.label_size])
        
        x_edge = self.act1(self.dense_edge(x))
        x_edge = torch.reshape(x_bbx,[self.num_nodes,self.num_nodes])
        
        class_pred = self.act2(self.dense_cls(x))
              
        return x_bbx, x_lbl, x_edge, class_pred


In [129]:
class AutoEncoder(nn.Module):
    
    """ AutoEncoder module for Box-Vae
    """
    def __init__(self,
                 latent_dims,
                 num_nodes,
                 bbx_size,
                 num_obj_classes,
                 label_size=1
                ):
        
        super(AutoEncoder, self).__init__()
        self.latent_dims = latent_dims
        self.num_nodes = num_nodes
        self.encoder = Encoder(latent_dims,
                               num_nodes,
                               bbx_size)
        
        self.decoder = Decoder(latent_dims,
                               num_nodes,
                               bbx_size,
                               num_obj_classes,
                               label_size)
        
    def forward(self,E, X , nodes, obj_class):

        z_mean, z_logvar = self.encoder(E, X, obj_class)
        z_latent = z_mean + torch.randn(self.latent_dims)*torch.exp(z_logvar)
        x_bbx, x_lbl, x_edge, class_pred = self.decoder(z_latent)
        #true_edge=E, true_node=X, latent_dim,  true_class=nodes, class_vec=class_pred)
        # conditioning has to be added
        return x_bbx, x_lbl, x_edge, class_pred

    

In [119]:
enc = Encoder(10,5,5)
adj = torch.tensor([[1,0,0,0,1],
                 [0,1,1,0,0],
                 [0,1,1,1,0],
                 [0,0,1,1,0],
                 [1,0,1,0,1]])

data = torch.FloatTensor(np.random.randint(0,100,(5,5))/100)
class_labels = torch.tensor([1,1,1,1,1])
z_mean, z_logvar = enc(adj,data,class_labels)

torch.Size([128])


In [120]:
dec = Decoder(10,5,5,5,1)
x_bbx, x_lbl, x_edge, class_pred = dec(z_mean)

  class_pred = self.act2(self.dense_cls(x))


In [130]:
vae = AutoEncoder(10,5,5,5,1)
obj_class = torch.tensor([1,0,0,0,0])
vae(adj,data,class_labels,obj_class)

torch.Size([128])


  class_pred = self.act2(self.dense_cls(x))


(tensor([[0.5498, 0.4814, 0.4957, 0.4198, 0.5253],
         [0.4671, 0.3859, 0.6136, 0.4118, 0.6067],
         [0.3947, 0.4343, 0.4350, 0.3817, 0.5505],
         [0.5631, 0.4041, 0.3958, 0.3992, 0.5827],
         [0.4714, 0.5076, 0.3885, 0.4637, 0.4431]], grad_fn=<ViewBackward>),
 tensor([[0.4062],
         [0.4861],
         [0.5251],
         [0.5621],
         [0.5086]], grad_fn=<ViewBackward>),
 tensor([[0.5498, 0.4814, 0.4957, 0.4198, 0.5253],
         [0.4671, 0.3859, 0.6136, 0.4118, 0.6067],
         [0.3947, 0.4343, 0.4350, 0.3817, 0.5505],
         [0.5631, 0.4041, 0.3958, 0.3992, 0.5827],
         [0.4714, 0.5076, 0.3885, 0.4637, 0.4431]], grad_fn=<ViewBackward>),
 tensor([0.1267, 0.1716, 0.3279, 0.1830, 0.1909], grad_fn=<SoftmaxBackward>))

In [20]:
x = torch.tensor([[2,1], [5,6], [3,7], [12,0]], dtype=torch.float)
y = torch.tensor([0, 1, 0, 1], dtype=torch.float)
edge_index = torch.tensor([[0,1,2,3],
                           [1,2,3,0]], 
                          dtype=torch.long)
data = Data(x=x, y=y, edge_index=edge_index)
data

Data(edge_index=[2, 4], x=[4, 2], y=[4])

In [37]:
adj = torch.tensor([[1,0,0,0,1],
                 [0,1,1,0,0],
                 [0,1,1,1,0],
                 [0,0,1,1,0],
                 [1,0,1,0,1]])
edge_index,_ = dense_to_sparse(adj)
data = torch.FloatTensor(np.random.randint(0,100,(5,5))/100)
data

tensor([[0.6600, 0.0100, 0.0900, 0.2400, 0.8900],
        [0.6600, 0.0100, 0.3900, 0.2100, 0.8700],
        [0.4000, 0.5400, 0.2100, 0.6400, 0.3600],
        [0.8900, 0.3900, 0.7400, 0.7500, 0.2900],
        [0.5400, 0.6400, 0.4700, 0.8900, 0.5100]])

(tensor([[0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4],
         [0, 4, 1, 2, 1, 2, 3, 2, 3, 0, 2, 4]]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]))

In [40]:
gcl1 = GCLayer(5,32)
gcl1#(data,adj)

tensor([[0.0831, 0.0000, 0.0000, 0.0538, 0.0000],
        [0.1434, 0.0000, 0.0000, 0.0530, 0.0000],
        [0.1565, 0.0000, 0.0000, 0.1168, 0.0000],
        [0.1176, 0.0003, 0.0000, 0.0318, 0.0000],
        [0.1699, 0.0000, 0.0000, 0.0572, 0.0000]], grad_fn=<ReluBackward0>)