In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

# Helper Functions

In [4]:
class HPS:
  """Attribute bag: every keyword argument becomes an instance attribute."""

  def __init__(self, **kwargs):
    # Copy each (name -> value) pair onto the instance so settings can be
    # read as `hps.name` instead of `config['name']`.
    for name in kwargs:
      setattr(self, name, kwargs[name])

# HPS (Hyperparameters)

In [32]:
# Hyperparameters for the model. List-valued entries up to `attn_layers` are
# indexed per encoder layer (entry i configures layer i); `mlp_in`/`mlp_out`
# are indexed per MLP stage of the classification head.
hps_config = dict(
    # Channel widths: each layer's output width is the next layer's input width.
    in_channels=[3, 16, 32, 64, 128],
    out_channels=[16, 32, 64, 128, 256],
    # Convolution settings (identical for all five layers).
    kernel_size=[3] * 5,
    dilations=[[1, 3, 9] for _ in range(5)],  # separate list object per layer
    stride=[1] * 5,
    padding=['same'] * 5,
    # Pooling-attention settings; embed_dim mirrors out_channels.
    embed_dim=[16, 32, 64, 128, 256],
    h_patch=[8] * 5,
    w_patch=[8] * 5,
    num_heads=[8] * 5,
    ff_dim=[16, 32, 64, 128, 256],
    attn_layers=[2, 2, 2, 4, 4],
    # Head: 256 -> 64 -> 32 features, then a linear layer to num_labels.
    mlp_in=[256, 64],
    mlp_out=[64, 32],
    dropout=0.1,
    num_layers=5,
    num_labels=6,
)

# Convolution Block

In [33]:
class ConvolutionBlock(nn.Module):
  """Gated depthwise-separable convolution block.

  Pipeline applied by `forward`:
    BatchNorm -> 1x1 expansion to 2 * in_channels -> GLU over the channel axis
    (back to in_channels) -> depthwise conv -> BatchNorm -> SiLU ->
    1x1 pointwise conv to out_channels -> dropout.

  Args:
    in_channels: number of channels of the input feature map.
    out_channels: number of channels produced by the pointwise convolution.
    kernel_size: kernel size of the depthwise convolution.
    dilation: dilation of the depthwise convolution.
    stride: stride of the depthwise convolution.
    padding: padding of the depthwise convolution (forwarded to nn.Conv2d).
    dropout: dropout probability applied to the block output.
  """

  def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding='same', dropout=0):
    super(ConvolutionBlock, self).__init__()

    # 1x1 conv doubling the channels; the GLU in forward() halves them again.
    self.expansion = nn.Conv2d(in_channels, 2 * in_channels, kernel_size=1)

    # groups=in_channels -> one spatial filter per channel (depthwise).
    self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=in_channels)
    # 1x1 conv mixing channels and projecting to out_channels.
    self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    self.batch_norm1 = nn.BatchNorm2d(in_channels)
    self.batch_norm2 = nn.BatchNorm2d(in_channels)
    self.glu = nn.GLU(dim=1)
    self.swish = nn.SiLU()
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    """Apply the block to a (B, in_channels, H, W) tensor.

    Returns a tensor with out_channels channels; its spatial size is set by
    the depthwise convolution's stride/padding/dilation.
    """
    # Normalize, expand, then gate: GLU splits the 2 * in_channels channels
    # into value and gate halves.
    x = self.batch_norm1(x)
    x = self.expansion(x)
    x = self.glu(x)

    # Per-channel spatial filtering.
    x = self.depthwise(x)
    x = self.batch_norm2(x)
    x = self.swish(x)

    # Channel mixing / projection.
    x = self.pointwise(x)
    x = self.dropout(x)

    return x

# Pooling Attention

In [34]:
class PoolingAttention(nn.Module):
  """Self-attention over a coarse grid of average-pooled patches.

  The (B, C, H, W) input is divided into an h_patch x w_patch grid of
  equally sized patches. Each patch is average-pooled to a single C-vector
  (a token), the tokens go through multi-head self-attention and a
  feed-forward layer, a sigmoid is applied, and the resulting grid is
  expanded back to (B, C, H, W) by repeating each token over its patch.

  Args:
    h_patch: number of patches along the height.
    w_patch: number of patches along the width.
    embed_dim: token dimension; must match the channel count C of the input
      because pooled channel vectors are used directly as tokens.
    ff_dim: hidden width of the feed-forward layer.
    num_heads: number of attention heads.
    dropout: dropout probability (attention weights and both residual branches).
  """

  def __init__(self, h_patch, w_patch, embed_dim, ff_dim, num_heads, dropout):
    super(PoolingAttention, self).__init__()

    self.h_patch = h_patch
    self.w_patch = w_patch

    self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
    self.linear = nn.Linear(embed_dim, ff_dim)
    self.linear2 = nn.Linear(ff_dim, embed_dim)

    self.norm1 = nn.LayerNorm(embed_dim)
    self.norm2 = nn.LayerNorm(embed_dim)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)

    self.swish = nn.SiLU()

    # Learned positional embedding, one vector per grid cell.
    self.pos_embed = nn.Parameter(torch.zeros(1, h_patch * w_patch, embed_dim))
    nn.init.trunc_normal_(self.pos_embed, std=0.02)

  def forward(self, x):
    """Return a (B, C, H, W) tensor of sigmoid outputs, constant within each patch.

    Raises:
      AssertionError: if H or W is not divisible by h_patch / w_patch.
    """
    B, C, H, W = x.shape

    assert H % self.h_patch == 0 and W % self.w_patch == 0, "H and W must be divisible by h_patch and w_patch"

    H_patch = H // self.h_patch
    W_patch = W // self.w_patch

    N = self.h_patch * self.w_patch

    # Average every (H_patch x W_patch) patch in one pooling call: the stride
    # defaults to the kernel size, so the result is (B, C, h_patch, w_patch).
    # Pooling the feature map directly (instead of cutting patches out with
    # view/permute/reshape and then calling .view on the result) also avoids a
    # stride-compatibility RuntimeError when a patch is a single row/pixel.
    x = F.avg_pool2d(x, kernel_size=(H_patch, W_patch))

    # Tokens in row-major grid order: (B, h_patch, w_patch, C) -> (B, N, C).
    x = x.permute(0, 2, 3, 1).reshape(B, N, C)

    x = x + self.pos_embed

    # Note: x is overwritten by the normalized tensor before each residual
    # add, so the skip connections carry the normalized values.
    x = self.norm1(x)
    _x, _ = self.attention(x, x, x)
    x = x + self.dropout1(_x)

    x = self.norm2(x)
    _x = self.linear(x)
    _x = self.swish(_x)
    _x = self.linear2(_x)
    x = x + self.dropout2(_x)

    x = torch.sigmoid(x)

    # Back to a channel-first grid, then repeat each cell over its patch.
    x_out = x.reshape(B, self.h_patch, self.w_patch, C)
    x_out = x_out.permute(0, 3, 1, 2)

    x_out = x_out.repeat_interleave(H_patch, dim=2).repeat_interleave(W_patch, dim=3)

    return x_out


# Residual Block

In [35]:
class ResidualBlock(nn.Module):
  """Stack of residual ConvolutionBlocks followed by a channel-changing one.

  One in_channels -> in_channels ConvolutionBlock is created per entry of
  `dilations` and applied as `x = x + block(x)`. A final ConvolutionBlock
  (dilation 1, stride 1, 'same' padding) then maps in_channels to
  out_channels.
  """

  def __init__(self, in_channels, out_channels, kernel_size, dilations, stride=1, padding='same', dropout=0):
    super().__init__()

    # NOTE(review): the skip connection in forward() adds the branch output to
    # its input, so stride/padding presumably have to keep the spatial size
    # unchanged (stride=1, padding='same') — confirm before passing others.
    self.residual_blocks = nn.ModuleList()
    for rate in dilations:
      branch = ConvolutionBlock(
          in_channels,
          in_channels,
          kernel_size,
          dilation=rate,
          stride=stride,
          padding=padding,
          dropout=dropout,
      )
      self.residual_blocks.append(branch)

    # Projection to the output width; always size-preserving.
    self.conv_block = ConvolutionBlock(
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        padding='same',
        dropout=dropout,
    )

  def forward(self, x):
    """Run the residual branches in order, then the projection block."""
    features = x
    for branch in self.residual_blocks:
      features = features + branch(features)

    projected = self.conv_block(features)
    return projected

# Encoder Layer

In [36]:
class EncoderLayer(nn.Module):
  """One encoder stage: a ResidualBlock followed by attention gating.

  The ResidualBlock output is multiplied element-wise, `attn_layers` times in
  sequence, by the output of a PoolingAttention module evaluated on the
  current feature map.
  """

  def __init__(self,
               in_channels,
               out_channels,
               kernel_size,
               dilations,
               stride,
               padding,
               embed_dim,
               h_patch,
               w_patch,
               num_heads,
               ff_dim,
               attn_layers,
               dropout=0):
    super().__init__()

    # Convolutional feature extractor: in_channels -> out_channels.
    self.residual_block = ResidualBlock(
        in_channels,
        out_channels,
        kernel_size,
        dilations,
        stride=stride,
        padding=padding,
        dropout=dropout,
    )

    # Independent attention modules, applied one after another in forward().
    self.attns = nn.ModuleList()
    for _ in range(attn_layers):
      self.attns.append(
          PoolingAttention(h_patch, w_patch, embed_dim, ff_dim, num_heads, dropout)
      )

  def forward(self, x):
    """Return the residual-block features after all attention gates."""
    gated = self.residual_block(x)

    for gate in self.attns:
      # Each gate is computed from the already-gated features.
      gated = gated * gate(gated)

    return gated

# Model

In [39]:
class Model(nn.Module):
  """Encoder stack with a per-pixel MLP classification head.

  `hps` supplies per-layer lists (indexed 0 .. num_layers - 1) for the
  encoder, `mlp_in` / `mlp_out` for the head, plus scalar `dropout`,
  `num_layers` and `num_labels`. `forward` returns softmax probabilities
  with the label axis in position 1.
  """

  def __init__(self, hps):
    super().__init__()

    # Encoder: one EncoderLayer per index, each reading the i-th entry of
    # every per-layer hyperparameter list.
    self.encoder_layers = nn.ModuleList()
    for idx in range(hps.num_layers):
      self.encoder_layers.append(
          EncoderLayer(
              in_channels=hps.in_channels[idx],
              out_channels=hps.out_channels[idx],
              kernel_size=hps.kernel_size[idx],
              dilations=hps.dilations[idx],
              stride=hps.stride[idx],
              padding=hps.padding[idx],
              embed_dim=hps.embed_dim[idx],
              h_patch=hps.h_patch[idx],
              w_patch=hps.w_patch[idx],
              num_heads=hps.num_heads[idx],
              ff_dim=hps.ff_dim[idx],
              attn_layers=hps.attn_layers[idx],
              dropout=hps.dropout,
          )
      )

    # Head: Linear -> ReLU -> Dropout stages acting on the last axis.
    self.mlps = nn.ModuleList()
    for idx in range(len(hps.mlp_in)):
      stage = nn.Sequential(
          nn.Linear(hps.mlp_in[idx], hps.mlp_out[idx]),
          nn.ReLU(inplace=True),
          nn.Dropout(hps.dropout),
      )
      self.mlps.append(stage)

    self.classifier = nn.Linear(hps.mlp_out[-1], hps.num_labels)

  def forward(self, x):
    """Map a (B, C, H, W) input to (B, num_labels, H', W') probabilities."""
    features = x
    for layer in self.encoder_layers:
      features = layer(features)

    # Channels-last so the Linear layers act on each spatial position.
    features = features.permute(0, 2, 3, 1)

    for stage in self.mlps:
      features = stage(features)

    logits = self.classifier(features)
    probs = F.softmax(logits, dim=-1)

    # Back to channels-first: (B, H', W', L) -> (B, L, H', W').
    return probs.permute(0, 3, 1, 2)

# Model Initialization

In [40]:
# Wrap the config dict so settings are available as attributes.
hps = HPS(**hps_config)

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

# Build the network and move its parameters to the chosen device.
model = Model(hps)
model = model.to(device)