# In[4]:
import torch
import torch_geometric.nn as nn

# In[5]:
class BottleneckBlock(torch.nn.Module):
    """Residual bottleneck block built from Chebyshev spectral graph convolutions.

    Graph counterpart of the convolutional bottleneck block: three ``ChebConv``
    layers (``in_planes -> planes//4 -> planes//4 -> planes``), each followed by
    a normalization layer and a ReLU, plus a residual connection that is
    projected when needed.

    Args:
        in_planes: Number of input features per node.
        planes: Number of output features per node.
        norm_fn: ``'batch'`` (PyG ``BatchNorm``), ``'instance'`` (PyG
            ``InstanceNorm``) or ``'none'`` (no normalization).
        stride: Kept for interface compatibility with the image version of this
            block. Graph convolutions do not subsample nodes, so a value other
            than 1 only forces a learned projection on the residual path.
        kernel_sizes: Optional ``(K1, K2, K3)`` Chebyshev filter sizes for the
            three convolutions. Defaults to the ``bottleneck_layer*_kernel_size``
            class attributes.

    Raises:
        ValueError: If ``norm_fn`` is not one of the supported names.
    """

    # Default Chebyshev filter sizes K. The original code read these attributes
    # without ever defining them; giving them class-level defaults keeps the
    # attribute names (subclasses or later assignments can still override them).
    # K=1 uses only each node's own features (analogous to a 1x1 conv), K=3
    # reaches up to 2 hops, mirroring the 1x1 -> 3x3 -> 1x1 image bottleneck.
    bottleneck_layer1_kernel_size = 1
    bottleneck_layer2_kernel_size = 3
    bottleneck_layer3_kernel_size = 1

    def __init__(self, in_planes, planes, norm_fn='batch', stride=1,
                 kernel_sizes=None):
        super().__init__()

        if kernel_sizes is None:
            kernel_sizes = (
                self.bottleneck_layer1_kernel_size,
                self.bottleneck_layer2_kernel_size,
                self.bottleneck_layer3_kernel_size,
            )
        k1, k2, k3 = kernel_sizes

        self.norm_fn = norm_fn
        self.conv1 = nn.ChebConv(in_planes, planes // 4, k1)
        self.conv2 = nn.ChebConv(planes // 4, planes // 4, k2)
        self.conv3 = nn.ChebConv(planes // 4, planes, k3)
        self.relu = torch.nn.ReLU(inplace=True)

        # _make_norm validates norm_fn, so an unsupported value fails here
        # instead of surfacing later as a missing attribute in forward().
        self.norm1 = self._make_norm(norm_fn, planes // 4)
        self.norm2 = self._make_norm(norm_fn, planes // 4)
        self.norm3 = self._make_norm(norm_fn, planes)

        # The residual sum x + y needs matching feature widths. Project the
        # input whenever the original design asked for downsampling
        # (stride != 1) or the widths differ. A per-node Linear layer is the
        # graph analogue of the 1x1 Conv2d used in the image version.
        if stride == 1 and in_planes == planes:
            self.downsample = None
        else:
            self.downsample = torch.nn.Linear(in_planes, planes)
            self.norm4 = self._make_norm(norm_fn, planes)

    @staticmethod
    def _make_norm(norm_fn, channels):
        """Return the normalization layer selected by ``norm_fn`` for ``channels`` features."""
        if norm_fn == 'batch':
            return nn.BatchNorm(channels)
        if norm_fn == 'instance':
            return nn.InstanceNorm(channels)
        if norm_fn == 'none':
            return torch.nn.Identity()
        raise ValueError(
            f"norm_fn must be 'batch', 'instance' or 'none', got {norm_fn!r}")

    def _apply_norm(self, norm, y, batch):
        """Apply ``norm`` to ``y``; PyG InstanceNorm additionally needs the batch vector."""
        if self.norm_fn == 'instance':
            # batch=None makes InstanceNorm treat all nodes as a single graph.
            return norm(y, batch)
        return norm(y)

    def forward(self, x, edge_index=None, edge_weight=None, batch=None):
        """Run the bottleneck on node features ``x`` over the graph ``edge_index``.

        Args:
            x: Node feature matrix; its last dimension must be ``in_planes``.
            edge_index: Graph connectivity passed to every ``ChebConv``.
                Required. It defaults to ``None`` only so the signature stays
                call-compatible with the old ``forward(x)``.
            edge_weight: Optional edge weights forwarded to ``ChebConv``.
            batch: Optional batch assignment vector, forwarded to ``ChebConv``
                and to ``InstanceNorm`` when ``norm_fn == 'instance'``.

        Returns:
            ``relu(residual + bottleneck(x))`` with ``planes`` features per node.

        Raises:
            ValueError: If ``edge_index`` is not provided.
        """
        if edge_index is None:
            raise ValueError(
                "BottleneckBlock.forward requires edge_index: ChebConv cannot "
                "run without graph connectivity")

        y = self.conv1(x, edge_index, edge_weight, batch=batch)
        y = self.relu(self._apply_norm(self.norm1, y, batch))
        y = self.conv2(y, edge_index, edge_weight, batch=batch)
        y = self.relu(self._apply_norm(self.norm2, y, batch))
        y = self.conv3(y, edge_index, edge_weight, batch=batch)
        y = self.relu(self._apply_norm(self.norm3, y, batch))

        if self.downsample is not None:
            x = self._apply_norm(self.norm4, self.downsample(x), batch)

        return self.relu(x + y)