In [7]:
from typing import Union, Tuple, List, Any, Dict

import torch
import torch.nn.functional as F
from torch import nn


In [5]:
class SDNPool(nn.Module):
    """Learnable mix of adaptive max-pooling and adaptive average-pooling.

    Computes ``alpha * max_pool(x) + (1 - alpha) * avg_pool(x)`` where ``alpha``
    is a single learnable scalar, initialised uniformly at random in [0, 1).

    Args:
        target_size: spatial output size passed to ``nn.AdaptiveMaxPool2d`` and
            ``nn.AdaptiveAvgPool2d`` (an int for a square output, or an
            ``(H, W)`` tuple).
    """

    def __init__(self, target_size: Union[int, Tuple[int, int]]):
        super().__init__()
        # NOTE(review): alpha is not clamped to [0, 1]; training may push it
        # outside that range, turning the convex mix into an extrapolation.
        self._alpha = nn.Parameter(torch.rand(1))
        self._max_pool = nn.AdaptiveMaxPool2d(target_size)
        self._avg_pool = nn.AdaptiveAvgPool2d(target_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the alpha-weighted sum of the max- and average-pooled ``x``."""
        max_term = self._alpha * self._max_pool(x)
        avg_term = (1 - self._alpha) * self._avg_pool(x)
        return max_term + avg_term


class StandardHead(nn.Module):
    """Classification head: ReLU -> SDNPool -> flatten -> Linear.

    Args:
        in_channels: number of channels in the incoming feature map.
        num_classes: number of output logits.
        pool_size: spatial size the feature map is pooled to before the
            Linear layer; an int for a square ``pool_size x pool_size`` grid,
            or an ``(H, W)`` tuple (``None`` entries are not supported because
            the Linear input size must be known up front).
    """

    def __init__(self, in_channels: int, num_classes: int,
                 pool_size: Union[int, Tuple[int, int]] = 4):
        super().__init__()
        self._num_classes = num_classes
        self._pooling = SDNPool(pool_size)
        # The Linear input size must equal the flattened pooled map: C * H * W.
        if isinstance(pool_size, int):
            pool_h, pool_w = pool_size, pool_size
        else:
            pool_h, pool_w = pool_size
        self._fc = nn.Linear(in_channels * pool_h * pool_w, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a feature map to class logits.

        Assumes ``x`` is ``(N, in_channels, H, W)``; returns ``(N, num_classes)``.
        """
        x = F.relu(x)
        x = self._pooling(x)
        # flatten (reshape semantics) instead of view: view raises when the
        # pooled tensor's strides are not view-compatible, e.g. channels_last.
        x = torch.flatten(x, 1)
        return self._fc(x)

In [17]:
class SimpleCNNRes(nn.Module):
    """Small residual CNN: strided stem -> residual conv blocks -> StandardHead.

    ``layers_dim`` is read as ``[in_channels, width, ..., width, num_classes]``:

    * ``layers_dim[0] -> layers_dim[1]``: stem (3x3 conv, ReLU, stride-2 3x3
      conv, ReLU).
    * each consecutive pair in ``layers_dim[2:-1]``: one residual block (two
      3x3 convs with ReLU, output added to the block input), so
      ``max(len(layers_dim) - 4, 0)`` blocks are built.
    * ``layers_dim[-2] -> layers_dim[-1]``: ``StandardHead`` classifier.

    Residual blocks use identity shortcuts (no projection), and the stem output
    feeds the first block (or the head) directly. All of ``layers_dim[1:-1]``
    must therefore be equal.

    Raises:
        ValueError: if ``layers_dim`` has fewer than 3 entries or its interior
            widths differ (such configs would otherwise fail with a shape
            mismatch inside ``forward``).
    """

    def __init__(self, layers_dim: List[int]):
        super().__init__()
        if len(layers_dim) < 3:
            raise ValueError(
                "layers_dim needs at least [in_channels, width, num_classes], "
                f"got {layers_dim!r}"
            )
        interior = layers_dim[1:-1]
        if len(set(interior)) != 1:
            raise ValueError(
                "layers_dim[1:-1] must all be equal (residual blocks use identity "
                f"shortcuts), got {interior!r}"
            )

        self.first_layer = nn.Sequential(
            nn.Conv2d(layers_dim[0], layers_dim[1], 3, padding=1),
            nn.ReLU(),
            # stride 2 downsamples the spatial size by 2 (rounded up), once.
            nn.Conv2d(layers_dim[1], layers_dim[1], 3, padding=1, stride=2),
            nn.ReLU(),
        )

        # One residual block per consecutive pair in layers_dim[2:-1].
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(dim_out, dim_out, 3, padding=1),
                nn.ReLU(),
            )
            for dim_in, dim_out in zip(layers_dim[2:-2], layers_dim[3:-1])
        ])
        self.final_layer = StandardHead(layers_dim[-2], layers_dim[-1], pool_size=4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits; assumes ``x`` is ``(N, layers_dim[0], H, W)``."""
        x = self.first_layer(x)
        for block in self.blocks:
            x = block(x) + x  # identity shortcut
        return self.final_layer(x)

In [29]:
# Config: [in_channels, width, width, width, width, num_classes] -> 64-wide stem,
# two residual blocks and a 10-way head. Seed first so the weight init (and the
# random SDNPool alpha) is reproducible across Restart & Run All.
torch.manual_seed(0)
model = SimpleCNNRes([3, 64, 64, 64, 64, 10])

In [30]:
# Show the module tree to confirm the layout: 2 residual blocks and a head whose
# Linear layer takes 64 * 4 * 4 = 1024 features.
model

SimpleCNNRes(
  (first_layer): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
  )
  (blocks): ModuleList(
    (0-1): 2 x Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU()
    )
  )
  (final_layer): StandardHead(
    (_pooling): SDNPool(
      (_max_pool): AdaptiveMaxPool2d(output_size=4)
      (_avg_pool): AdaptiveAvgPool2d(output_size=4)
    )
    (_fc): Linear(in_features=1024, out_features=10, bias=True)
  )
)

In [21]:
# Smoke test: one random 3x32x32 image should produce one row of 10 logits.
# no_grad because this only checks the shape; no autograd graph is needed.
with torch.no_grad():
    logits = model(torch.rand(1, 3, 32, 32))
assert logits.shape == (1, 10)
logits.shape

torch.Size([1, 10])

In [31]:
def entropy_loss(y_pred: torch.Tensor, dim: int = 1, reduction: str = "sum") -> torch.Tensor:
    """Shannon entropy (in nats) of the softmax distribution over ``dim``.

    Args:
        y_pred: unnormalised logits whose class axis is ``dim``.
        dim: axis holding the class logits (default 1).
        reduction: how per-position entropies are combined:
            ``"sum"`` (default) totals them over every remaining position;
            ``"mean"`` averages them, so the value does not scale with batch size;
            ``"none"`` returns them unreduced (``y_pred`` with ``dim`` removed).

    Returns:
        A scalar tensor for ``"sum"``/``"mean"``, otherwise the per-position
        entropies.

    Raises:
        ValueError: if ``reduction`` is not one of ``"sum"``, ``"mean"``, ``"none"``.
    """
    if reduction not in ("sum", "mean", "none"):
        raise ValueError(f"reduction must be 'sum', 'mean' or 'none', got {reduction!r}")
    # Compute the log-probabilities once (log_softmax is the numerically stable
    # path) and exponentiate them, rather than running softmax separately.
    log_probs = F.log_softmax(y_pred, dim=dim)
    per_position = -(log_probs.exp() * log_probs).sum(dim=dim)
    if reduction == "sum":
        return per_position.sum()
    if reduction == "mean":
        return per_position.mean()
    return per_position

In [36]:
# Sanity check: identical logits over 6 classes -> uniform softmax, whose
# entropy is ln(6) ~= 1.7918 nats.
entropy_loss(torch.full((1, 6), 4.0))

tensor(1.7918)