In [4]:
    import torch
import torch.nn as nn
from torch import Tensor

In [5]:
# Sanity check: does PyTorch see a CUDA device from this kernel? The returned bool is the cell's output.
torch.cuda.is_available()

True

In [3]:
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation.

    With ``stride=1`` this padding keeps the spatial size unchanged, since the
    dilated kernel spans ``2 * dilation + 1`` pixels.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of channel groups (passed straight to ``nn.Conv2d``).
        dilation: spacing between kernel taps; also used as the padding.

    Returns:
        The configured ``nn.Conv2d`` layer (``bias=False``).
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (no padding is applied).

    Returns:
        The configured ``nn.Conv2d`` layer (``bias=False``).
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )

In [7]:
# Residual block adapted from torchvision's ResNet implementation (reference below)
# https://github.com/pytorch/vision/blob/cddad9ca3822011548e18342f52a3e9f4724c2dd/torchvision/models/resnet.py#L88


class ResBlk(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm stages plus a skip connection.

    ``out = relu(bn_2(conv_2(relu(bn_1(conv_1(x))))) + shortcut(x))``

    The channel counts are explicit constructor arguments. The previous version
    read ``ch_in``/``ch_out`` from notebook globals, so it failed on a fresh kernel.
    When ``stride != 1`` or ``ch_in != ch_out``, the input cannot be added to the
    main path as-is. In that case a 1x1 conv + BatchNorm projection is used as the
    shortcut (the role of ``downsample`` in torchvision's ``BasicBlock``).

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        stride: stride of the first 3x3 conv and of the projection shortcut.
    """

    def __init__(self, ch_in: int, ch_out: int, stride: int = 1) -> None:
        super().__init__()
        # Convs feeding directly into BatchNorm omit bias: BN's learnable shift makes it redundant.
        self.conv_1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1, bias=False)
        self.batch_norm_1 = nn.BatchNorm2d(ch_out)
        self.conv_2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.batch_norm_2 = nn.BatchNorm2d(ch_out)

        self.relu = nn.ReLU(inplace=True)

        # Project the input whenever its shape would not match the main path's output.
        if stride != 1 or ch_in != ch_out:
            self.shortcut = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(ch_out),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to ``x``.

        Assumes ``x`` is an ``(N, ch_in, H, W)`` tensor (required by ``nn.Conv2d``).
        The caller's ``x`` is not modified: the in-place add and ReLU act on ``out``.
        """
        out = self.conv_1(x)
        out = self.batch_norm_1(out)
        out = self.relu(out)

        out = self.conv_2(out)
        out = self.batch_norm_2(out)

        out += self.shortcut(x)
        out = self.relu(out)

        return out

In [None]:
# NOTE(review): unexecuted placeholder cell (In [None]); it only builds and displays an empty
# nn.Sequential. Presumably meant to stack ResBlk layers into a network — TODO fill in or delete.
nn.Sequential()