In [41]:
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   latentgnn_v1.py
@Time    :   2019/05/27 13:39:43
@Author  :   Songyang Zhang 
@Version :   1.0
@Contact :   sy.zhangbuaa@hotmail.com
@License :   (C)Copyright 2019-2020, PLUS Group@ShanhaiTech University
@Desc    :   None
'''
import sys 
sys.path.append("")

from backend.torch.networks import FlattenMlp

import torch
import torch.nn as nn 
import torch.nn.functional as F

import numpy as np

class LatentGNNV1(nn.Module):
    """
    Latent Graph Neural Network for Non-local Relations Learning (MLP variant).

    Runs ``num_kernels`` independent ``LatentGNN_Kernel`` modules on the same
    input and mixes their outputs with a bias-free 1x1 ``nn.Conv1d`` whose
    input channels are the kernels.

    Args:
        input_dim (int): Feature dimension of each input row.
        latent_dims (list): Number of latent nodes for each kernel; its length
            must equal ``num_kernels``.
        num_kernels (int): Number of latent kernels used. Default: 1
        hidden_dict (dict): Hidden-layer sizes for the kernel MLPs. The
            visible-to-latent sizes are read from ``'in_'`` (or ``'in'``) and
            the latent-to-output sizes from ``'out_'`` (or ``'out'``).
            Default: {'in': [128, 128], 'out': [32, 32]}
        norm_layer (nn.Module): Forwarded to every kernel. Default: nn.BatchNorm2d.
        norm_func (function): Function used for normalization. Default: F.normalize

    Raises:
        AssertionError: if ``len(latent_dims) != num_kernels``.
        KeyError: if ``hidden_dict`` lacks an ``'in'``/``'in_'`` or
            ``'out'``/``'out_'`` entry.
    """
    def __init__(self, input_dim, latent_dims,
                    num_kernels=1,
                    hidden_dict = {'in':[128, 128],
                                   'out':[32, 32]},
                    norm_layer=nn.BatchNorm2d, norm_func=F.normalize):
        super(LatentGNNV1, self).__init__()
        self.num_kernels = num_kernels
        self.norm_func = norm_func

        # Define the latentgnn kernel
        assert len(latent_dims) == num_kernels, 'Latent dimensions mismatch with number of kernels'

        # The kernels look up 'in_'/'out_', while the default dict spells them
        # 'in'/'out'. Build a fresh dict with canonical keys so both spellings
        # work and the shared mutable default argument is never modified.
        hidden_layers = {
            'in_': self._hidden_sizes(hidden_dict, 'in'),
            'out_': self._hidden_sizes(hidden_dict, 'out'),
        }

        for i in range(num_kernels):
            self.add_module('LatentGNN_Kernel_{}'.format(i),
                            LatentGNN_Kernel(input_dim=input_dim,
                                             num_kernels=num_kernels,
                                             hidden_layers=hidden_layers,
                                             latent_dim=latent_dims[i],
                                             norm_layer=norm_layer,
                                             norm_func=norm_func))

        # Residual scale: registered as a parameter but not used by forward().
        self.gamma = nn.Parameter(torch.zeros(1))

        # Mixes the per-kernel outputs (one input channel per kernel) into one.
        self.kernel_channel = nn.Sequential(
            nn.Conv1d(in_channels=num_kernels,
                      out_channels=1,
                      kernel_size=1, padding=0, bias=False)
        )

    @staticmethod
    def _hidden_sizes(hidden_dict, key):
        """Return ``hidden_dict[key + '_']``, falling back to ``hidden_dict[key]``."""
        if key + '_' in hidden_dict:
            return hidden_dict[key + '_']
        if key in hidden_dict:
            return hidden_dict[key]
        raise KeyError("hidden_dict must contain '{0}_' or '{0}'".format(key))

    def forward(self, feature):
        """
        Run every kernel on ``feature`` and mix their outputs.

        NOTE(review): assumes each kernel returns a 2-D ``(input_dim, 1)``
        tensor; the result is then ``(input_dim, 1)`` as well — confirm.
        """
        # getattr instead of eval(): same lookup, no code evaluation.
        out_features = [getattr(self, 'LatentGNN_Kernel_{}'.format(i))(feature)
                        for i in range(self.num_kernels)]

        # Stack kernel outputs column-wise: (input_dim, num_kernels).
        # torch.cat also handles the single-kernel case.
        out_features = torch.cat(out_features, dim=1)
        # Conv1d expects (batch, channels=num_kernels, length=input_dim).
        out_features = self.kernel_channel(out_features.unsqueeze(dim=0).permute(0, 2, 1))
        return out_features.squeeze(0).T

class LatentGNN_Kernel(nn.Module):
    """
    A LatentGNN kernel operating on a 2-D feature matrix (MLP variant).

    forward() projects the input rows onto ``latent_dim`` latent nodes,
    propagates messages between latent nodes with a softmax affinity matrix,
    then pools the latent nodes into one column with learned attention weights.

    Args:
        input_dim (int): Feature dimension of each input row.
        latent_dim (int): Number of latent nodes.
        hidden_layers (dict): Hidden sizes for the two ``FlattenMlp`` modules,
            under ``'in_'``/``'in'`` (visible-to-latent) and ``'out_'``/``'out'``
            (latent-to-output).
        num_kernels (int): Not used by this kernel.
        norm_layer: Not used by this kernel.
        norm_func (callable): Normalization applied to latent node features
            (called with ``dim=-1``) before computing their affinity.
    """
    def __init__(self, input_dim, latent_dim, hidden_layers, num_kernels,
                        norm_layer,
                        norm_func):
        super(LatentGNN_Kernel, self).__init__()

        self.norm_func = norm_func
        #----------------------------------------------
        # Step 1 & 3: Visible-to-Latent & Latent-to-Visible
        #----------------------------------------------
        # psi_in scores every input row against each latent node.
        self.psi_in = FlattenMlp(
            hidden_sizes=self._hidden_sizes(hidden_layers, 'in'),
            input_size=input_dim,
            output_size=latent_dim,
            bias=False
        )

        # psi_out produces one attention score per latent node.
        self.psi_out = FlattenMlp(
            hidden_sizes=self._hidden_sizes(hidden_layers, 'out'),
            input_size=input_dim,
            output_size=1,
            bias=False
        )

    @staticmethod
    def _hidden_sizes(hidden_layers, key):
        """Return ``hidden_layers[key + '_']``, falling back to ``hidden_layers[key]``."""
        if key + '_' in hidden_layers:
            return hidden_layers[key + '_']
        if key in hidden_layers:
            return hidden_layers[key]
        raise KeyError("hidden_layers must contain '{0}_' or '{0}'".format(key))

    def forward(self, feature):
        """
        Aggregate ``feature`` through the latent graph.

        NOTE(review): assumes ``feature`` is ``(N, input_dim)`` and that
        ``FlattenMlp`` maps ``(rows, input_size) -> (rows, output_size)``; under
        that assumption the result is ``(input_dim, 1)`` — confirm.
        """
        #----------------------------------------------
        # Step1 : Contexts-to-Latent
        #----------------------------------------------
        # (N, latent_dim): each row's weights over latent nodes sum to 1.
        graph_adj_in = F.softmax(self.psi_in(feature), dim=1)
        latent_node_feature = graph_adj_in.T @ feature      # (latent_dim, input_dim)

        #----------------------------------------------
        # Step2 : Latent-to-Latent
        #----------------------------------------------
        # Dense latent graph from normalized feature similarities.
        latent_node_feature_n = self.norm_func(latent_node_feature, dim=-1)
        affinity_matrix = F.softmax(latent_node_feature_n @ latent_node_feature_n.T, dim=-1)
        latent_node_feature = affinity_matrix @ latent_node_feature

        #----------------------------------------------
        # Step3: Latent-to-Output
        #----------------------------------------------
        # psi_out gives a (latent_dim, 1) score column. Normalize over the
        # latent-node axis (dim=0). A softmax over the size-1 axis would be
        # identically 1, making psi_out inert and gradient-free.
        graph_adj_out = F.softmax(self.psi_out(latent_node_feature), dim=0)
        return latent_node_feature.T @ graph_adj_out

ModuleNotFoundError: No module named 'backend'

In [39]:
if __name__ == "__main__":
    # Smoke test of the MLP variant: 12-dim context rows (1 + 3 + 4*2),
    # two kernels with 6 latent nodes each.
    torch.manual_seed(0)  # reproducible weight init and dummy inputs
    ctxt_dim = 1 + 3 + 4*2
    latent_nodes = 6
    num_kernels = 2
    hidden_dict = dict(
        in_=[64, 64],
        out_=[32, 32],
    )
    network = LatentGNNV1(input_dim=ctxt_dim, hidden_dict=hidden_dict,
                          latent_dims=[latent_nodes] * num_kernels,
                          num_kernels=num_kernels)

    dump_inputs = torch.rand((100, ctxt_dim))
    print(str(network))
    output = network(dump_inputs)
    print(output.shape)

NameError: name 'FlattenMlp' is not defined

In [5]:
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   latentgnn_v1.py
@Time    :   2019/05/27 13:39:43
@Author  :   Songyang Zhang 
@Version :   1.0
@Contact :   sy.zhangbuaa@hotmail.com
@License :   (C)Copyright 2019-2020, PLUS Group@ShanhaiTech University
@Desc    :   None
'''

import torch
import torch.nn as nn 
import torch.nn.functional as F

import numpy as np

class LatentGNNV1(nn.Module):
    """
    Latent Graph Neural Network for Non-local Relations Learning (conv variant).

    Reduces the channels of a 4-D feature map with 1x1 convolutions, runs
    ``num_kernels`` LatentGNN kernels, concatenates their outputs along the
    channel axis, and restores the original channel count with a 1x1
    convolution.

    Args:
        in_channels (int): Number of channels in the input feature
        latent_dims (list): List of latent dimensions, one per kernel
        channel_stride (int): Channel reduction factor. Default: 4
        num_kernels (int): Number of latent kernels used. Default: 1
        mode (str): Mode of bipartite graph message propagation, either
            'asymmetric' or 'symmetric'. Default: 'asymmetric'.
        without_residual (bool): If True, return only the LatentGNN branch;
            otherwise return ``input + gamma * branch`` where ``gamma`` is a
            learnable scalar initialized to 0. Default: False
        norm_layer (nn.Module): Module used for batch normalization. Default: nn.BatchNorm2d.
        norm_func (function): Function used for normalization. Default: F.normalize
        graph_conv_flag (bool): Flag of use graph convolution layer. Default: False

    Raises:
        NotImplementedError: if ``mode`` is not 'asymmetric' or 'symmetric'.
        AssertionError: if ``len(latent_dims) != num_kernels``.
    """
    def __init__(self, in_channels, latent_dims,
                    channel_stride=4, num_kernels=1,
                    mode='asymmetric', without_residual=False,
                    norm_layer=nn.BatchNorm2d, norm_func=F.normalize,
                    graph_conv_flag=False):
        super(LatentGNNV1, self).__init__()
        # Attribute name is misspelled; kept as-is so existing code that reads
        # or sets it keeps working.
        self.without_resisual = without_residual
        self.num_kernels = num_kernels
        self.mode = mode
        self.norm_func = norm_func

        inter_channel = in_channels // channel_stride

        # Reduce the channel dimension for efficiency
        if mode == 'asymmetric':
            self.down_channel_v2l = nn.Sequential(
                                    nn.Conv2d(in_channels=in_channels,
                                            out_channels=inter_channel,
                                            kernel_size=1, padding=0, bias=False),
                                    norm_layer(inter_channel),
            )

            self.down_channel_l2v = nn.Sequential(
                                    nn.Conv2d(in_channels=in_channels,
                                            out_channels=inter_channel,
                                            kernel_size=1, padding=0, bias=False),
                                    norm_layer(inter_channel),
            )

        elif mode == 'symmetric':
            self.down_channel = nn.Sequential(
                                    nn.Conv2d(in_channels=in_channels,
                                            out_channels=inter_channel,
                                            kernel_size=1, padding=0, bias=False),
                                    norm_layer(inter_channel),
            )
        else:
            raise NotImplementedError

        # Define the latentgnn kernel
        assert len(latent_dims) == num_kernels, 'Latent dimensions mismatch with number of kernels'

        for i in range(num_kernels):
            self.add_module('LatentGNN_Kernel_{}'.format(i),
                                LatentGNN_Kernel(in_channels=inter_channel,
                                                num_kernels=num_kernels,
                                                latent_dim=latent_dims[i],
                                                norm_layer=norm_layer,
                                                norm_func=norm_func,
                                                mode=mode,
                                                graph_conv_flag=graph_conv_flag))
        # Increase the channel for the output (kernel outputs are concatenated)
        self.up_channel = nn.Sequential(
                                    nn.Conv2d(in_channels=inter_channel*num_kernels,
                                                out_channels=in_channels,
                                                kernel_size=1, padding=0,bias=False),
                                    norm_layer(in_channels),
        )

        # Residual Connection: starts at 0, so the block is initially the identity.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, conv_feature):
        """
        Apply the LatentGNN branch to ``conv_feature``.

        Returns the branch output alone when ``without_residual`` was set,
        otherwise ``conv_feature + gamma * branch``.
        """
        # Generate visible space features: channel-reduced, then normalized
        # with norm_func over the channel axis.
        if self.mode == 'asymmetric':
            v2l_conv_feature = self.norm_func(self.down_channel_v2l(conv_feature), dim=1)
            l2v_conv_feature = self.norm_func(self.down_channel_l2v(conv_feature), dim=1)
        elif self.mode == 'symmetric':
            v2l_conv_feature = self.norm_func(self.down_channel(conv_feature), dim=1)
            l2v_conv_feature = None

        # getattr instead of eval(): same lookup, no code evaluation.
        out_features = [getattr(self, 'LatentGNN_Kernel_{}'.format(i))(v2l_conv_feature, l2v_conv_feature)
                        for i in range(self.num_kernels)]

        # Concatenate along channels; torch.cat also handles a single kernel.
        out_features = self.up_channel(torch.cat(out_features, dim=1))
        if self.without_resisual:
            return out_features
        return conv_feature + out_features * self.gamma

class LatentGNN_Kernel(nn.Module):
    """
    A LatentGNN Kernel Implementation (conv variant).

    Builds a bipartite graph between the H*W visible positions and
    ``latent_dim`` latent nodes. It aggregates visible features into the latent
    nodes, propagates over a dense latent graph, then projects back to the
    visible positions.

    Args:
        in_channels (int): Channels of the (already channel-reduced) input.
        num_kernels (int): Not used by this kernel.
        latent_dim (int): Number of latent nodes.
        norm_layer: Normalization module constructor, called with a channel count.
        norm_func (callable): Normalization function called with a ``dim`` kwarg.
        mode (str): 'asymmetric' (separate v2l / l2v projections) or
            'symmetric' (one shared projection).
        graph_conv_flag (bool): Apply a 1x1 conv + norm + ReLU to the output.

    Raises:
        NotImplementedError: if ``mode`` is not 'asymmetric' or 'symmetric'.
    """
    def __init__(self, in_channels, num_kernels,
                        latent_dim, norm_layer,
                        norm_func, mode, graph_conv_flag):
        super(LatentGNN_Kernel, self).__init__()
        self.mode = mode
        self.norm_func = norm_func
        #----------------------------------------------
        # Step1 & 3: Visible-to-Latent & Latent-to-Visible
        #----------------------------------------------

        if mode == 'asymmetric':
            self.psi_v2l = nn.Sequential(
                            nn.Conv2d(in_channels=in_channels,
                                        out_channels=latent_dim,
                                        kernel_size=1, padding=0,
                                        bias=False),
                            norm_layer(latent_dim),
                            nn.ReLU(inplace=True),
            )
            self.psi_l2v = nn.Sequential(
                            nn.Conv2d(in_channels=in_channels,
                                        out_channels=latent_dim,
                                        kernel_size=1, padding=0,
                                        bias=False),
                            norm_layer(latent_dim),
                            nn.ReLU(inplace=True),
            )

        elif mode == 'symmetric':
            self.psi = nn.Sequential(
                            nn.Conv2d(in_channels=in_channels,
                                        out_channels=latent_dim,
                                        kernel_size=1, padding=0,
                                        bias=False),
                            norm_layer(latent_dim),
                            nn.ReLU(inplace=True),
            )
        else:
            # Fail at construction instead of with an UnboundLocalError in forward().
            raise NotImplementedError('Unsupported mode: {}'.format(mode))

        #----------------------------------------------
        # Step2: Latent Message Passing
        #----------------------------------------------
        self.graph_conv_flag = graph_conv_flag
        if graph_conv_flag:
            self.GraphConvWeight = nn.Sequential(
                            nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=False),
                            norm_layer(in_channels),
                            nn.ReLU(inplace=True),
                        )
            nn.init.normal_(self.GraphConvWeight[0].weight, std=0.01)

    def forward(self, v2l_conv_feature, l2v_conv_feature):
        """
        Propagate a (B, C, H, W) feature map through the latent graph.

        ``l2v_conv_feature`` must be None in 'symmetric' mode. Returns a
        (B, C, H, W) tensor. ``reshape`` is used instead of ``view`` so that
        non-contiguous inputs are accepted.
        """
        B, C, H, W = v2l_conv_feature.shape
        num_pixels = H * W

        # Generate Bipartite Graph Adjacency Matrix, shape (B, latent_dim, H*W)
        if self.mode == 'asymmetric':
            v2l_graph_adj = self.psi_v2l(v2l_conv_feature)
            l2v_graph_adj = self.psi_l2v(l2v_conv_feature)
            # v2l: normalize each latent node's weights over pixels (dim=2);
            # l2v: normalize each pixel's weights over latent nodes (dim=1).
            v2l_graph_adj = self.norm_func(v2l_graph_adj.reshape(B, -1, num_pixels), dim=2)
            l2v_graph_adj = self.norm_func(l2v_graph_adj.reshape(B, -1, num_pixels), dim=1)
        elif self.mode == 'symmetric':
            assert l2v_conv_feature is None
            l2v_graph_adj = v2l_graph_adj = self.norm_func(
                self.psi(v2l_conv_feature).reshape(B, -1, num_pixels), dim=1)
        else:
            raise NotImplementedError('Unsupported mode: {}'.format(self.mode))

        #----------------------------------------------
        # Step1 : Visible-to-Latent -> (B, latent_dim, C)
        #----------------------------------------------
        latent_node_feature = torch.bmm(
            v2l_graph_adj, v2l_conv_feature.reshape(B, -1, num_pixels).permute(0, 2, 1))

        #----------------------------------------------
        # Step2 : Latent-to-Latent
        #----------------------------------------------
        # Generate Dense-connected Graph Adjacency Matrix (B, latent_dim, latent_dim)
        latent_node_feature_n = self.norm_func(latent_node_feature, dim=-1)
        affinity_matrix = torch.bmm(latent_node_feature_n, latent_node_feature_n.permute(0, 2, 1))
        affinity_matrix = F.softmax(affinity_matrix, dim=-1)

        latent_node_feature = torch.bmm(affinity_matrix, latent_node_feature)

        #----------------------------------------------
        # Step3: Latent-to-Visible -> (B, C, H, W)
        #----------------------------------------------
        visible_feature = torch.bmm(latent_node_feature.permute(0, 2, 1), l2v_graph_adj).reshape(B, -1, H, W)

        if self.graph_conv_flag:
            visible_feature = self.GraphConvWeight(visible_feature)

        return visible_feature

In [7]:
def test_latentgnn():
    """Smoke-test the conv LatentGNNV1: print the module, run one forward pass, print the output shape."""
    model_kwargs = dict(
        in_channels=1024,
        latent_dims=[100, 100],
        channel_stride=8,
        num_kernels=2,
        mode='asymmetric',
        graph_conv_flag=False,
    )
    model = LatentGNNV1(**model_kwargs)

    # Dummy batch: 8 feature maps of 1024 channels on a 30x30 grid.
    batch = torch.rand((8, 1024, 30, 30))
    print(str(model))
    print(model(batch).shape)

test_latentgnn()


LatentGNNV1(
  (down_channel_v2l): Sequential(
    (0): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (down_channel_l2v): Sequential(
    (0): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (LatentGNN_Kernel_0): LatentGNN_Kernel(
    (psi_v2l): Sequential(
      (0): Conv2d(128, 100, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (psi_l2v): Sequential(
      (0): Conv2d(128, 100, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (LatentGNN_Kernel_1): LatentGNN_Kernel(
    (psi_v2l): Sequential(
      (0