# In [11]:
import torch
import torch_geometric.nn as nn
from bottleneckblock import BottleneckBlock

# In [12]:
class SmallEncoder(torch.nn.Module):
    """Small node-feature encoder built from Chebyshev graph convolutions.

    Pipeline: ``ChebConv`` stem -> normalization -> ReLU -> three
    ``BottleneckBlock`` stages (two blocks each) -> ``ChebConv`` (``K=1``)
    projection to ``output_dim`` -> dropout (training only, if enabled).

    NOTE(review): the layout (32-channel stem, stages of 32/64/96, list inputs
    split back afterwards) appears to mirror RAFT's ``SmallEncoder``; the
    constructor defaults below follow that assumption -- confirm.

    Args:
        output_dim: Output feature channels per node.
        norm_fn: Normalization after the stem, one of ``'batch'``,
            ``'instance'`` or ``'none'``. Also forwarded to every
            ``BottleneckBlock``.
        dropout: Dropout probability for the output; ``0`` disables it.
        in_channels: Input feature channels per node.
        kernel_size: Chebyshev filter size ``K`` of the stem convolution.
        layer_dims: Output channels of the three bottleneck stages; the last
            entry is also the input width of the final projection.

    Raises:
        ValueError: If ``norm_fn`` is not a supported value.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0,
                 in_channels=3, kernel_size=3, layer_dims=(32, 64, 96)):
        super().__init__()
        self.norm_fn = norm_fn
        stem_channels = 32

        if norm_fn == 'batch':
            self.norm1 = nn.BatchNorm(stem_channels)
        elif norm_fn == 'instance':
            # PyG InstanceNorm takes only the channel count; GroupNorm-style
            # keywords (num_groups / num_channels) are not accepted.
            self.norm1 = nn.InstanceNorm(stem_channels)
        elif norm_fn == 'none':
            self.norm1 = torch.nn.Identity()
        else:
            raise ValueError(f"unsupported norm_fn: {norm_fn!r}")

        self.conv1 = nn.ChebConv(in_channels, stem_channels, K=kernel_size)
        # torch_geometric.nn provides no ReLU/Dropout modules; use torch.nn.
        self.relu1 = torch.nn.ReLU(inplace=True)

        # _make_layer reads and advances self.in_planes stage by stage.
        self.in_planes = stem_channels
        dim1, dim2, dim3 = layer_dims
        self.layer1 = self._make_layer(dim1, stride=1)
        self.layer2 = self._make_layer(dim2, stride=2)
        self.layer3 = self._make_layer(dim3, stride=2)

        self.dropout = None
        if dropout > 0:
            # Node features are (num_nodes, channels); Dropout2d targets
            # (N, C, H, W) feature maps, so element-wise Dropout is used.
            self.dropout = torch.nn.Dropout(p=dropout)

        # ChebConv's filter size is the ``K`` argument (there is no
        # ``kernel_size``); K=1 makes this a per-node linear projection.
        self.conv2 = nn.ChebConv(dim3, output_dim, K=1)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every ChebConv filter; set affine norm params to identity."""
        for m in self.modules():
            if isinstance(m, nn.ChebConv):
                # ChebConv stores one Linear per Chebyshev order in ``lins``
                # and has no single ``weight`` tensor of its own.
                for lin in m.lins:
                    torch.nn.init.kaiming_normal_(
                        lin.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (torch.nn.BatchNorm1d, nn.InstanceNorm)):
                # PyG's BatchNorm is (or wraps) a torch.nn.BatchNorm1d, which
                # is matched here; a wrapper itself has no weight/bias.
                if m.weight is not None:
                    torch.nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build one stage of two BottleneckBlocks ending at ``dim`` channels.

        The first block maps ``self.in_planes`` -> ``dim`` with ``stride``;
        the second maps ``dim`` -> ``dim`` with stride 1. Updates
        ``self.in_planes`` to ``dim`` for the next stage.
        """
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return torch.nn.Sequential(*layers)

    def forward(self, x, edge_index=None):
        """Encode node features.

        Args:
            x: Node-feature tensor, or a list/tuple of node-feature tensors
                that are concatenated along dim 0, encoded in one pass, and
                split back afterwards.
            edge_index: Connectivity in PyG ``[2, num_edges]`` format, required
                by the ChebConv layers. For list input it must index rows of
                the concatenated tensor (element i's nodes follow element
                i-1's).

        Returns:
            Encoded node features with ``output_dim`` channels; a tuple of
            tensors (one per element, original sizes) when ``x`` was a
            list/tuple.

        Raises:
            TypeError: If ``edge_index`` is not given.
        """
        if edge_index is None:
            raise TypeError(
                "SmallEncoder.forward() requires edge_index for its ChebConv layers")

        # If input is a list/tuple, encode all elements in a single pass.
        is_list = isinstance(x, (tuple, list))
        batch = None
        if is_list:
            sizes = [t.shape[0] for t in x]
            x = torch.cat(x, dim=0)
            # Element index of every row, so InstanceNorm computes statistics
            # per element rather than over the whole concatenation.
            batch = torch.repeat_interleave(
                torch.arange(len(sizes), device=x.device),
                torch.tensor(sizes, device=x.device),
            )

        x = self.conv1(x, edge_index)
        if isinstance(self.norm1, nn.InstanceNorm):
            x = self.norm1(x, batch)
        else:
            x = self.norm1(x)
        x = self.relu1(x)

        # NOTE(review): the stages are plain torch.nn.Sequential and receive
        # node features only. If BottleneckBlock contains graph convolutions
        # it also needs edge_index (e.g. via torch_geometric.nn.Sequential)
        # -- confirm against bottleneckblock.py.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x, edge_index)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # Split back by each element's own size (assumes the stages keep
            # the node count unchanged).
            x = torch.split(x, sizes, dim=0)

        return x