In [5]:
import torch
import torch.nn as nn
from torchvision.models import resnet34
from torchvision.models import resnet34 as tv_resnet34

In [None]:
# Reference ResNet-34 from torchvision, for comparison with the hand-written
# version built below. The `tv_resnet34` alias is used because a later cell
# redefines `resnet34`; calling the bare name here would silently build the
# custom model if this cell were re-run after that one.
model = tv_resnet34()
model

In [None]:
class Block(nn.Module):
  """ResNet "basic" residual block: two 3x3 conv + BatchNorm layers plus a shortcut.

  Args:
    input: number of channels of the incoming feature map.
    output: number of channels produced by both convolutions.
    stride: stride of the first convolution only; 2 halves height and width.
    downsample: optional module applied to the block input on the shortcut
      path. It must map the input to the main path's output shape, so it is
      needed whenever `stride != 1` or `input != output`. `None` means an
      identity shortcut.
  """
  def __init__(self, input, output, stride = 1, downsample = None):
    super().__init__()
    # padding = 1 makes a 3x3 conv size-preserving apart from its stride.
    # Without it every conv trims 2 pixels from H and W and the residual add
    # no longer matches the shortcut's shape.
    self.conv1 = nn.Conv2d(input, output, kernel_size = 3, stride = stride, padding = 1, bias = False)
    self.bn1 = nn.BatchNorm2d(output)
    self.relu = nn.ReLU(inplace = True)
    # Only the first conv downsamples. The second always uses stride 1;
    # otherwise a stride-2 block would shrink the map twice while the
    # shortcut shrinks it only once.
    self.conv2 = nn.Conv2d(output, output, kernel_size = 3, stride = 1, padding = 1, bias = False)
    self.bn2 = nn.BatchNorm2d(output)
    self.downsample = downsample

  def forward(self, x):
    """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
    residual = x
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.conv2(out)
    out = self.bn2(out)
    # Project the shortcut when channels or resolution change.
    if self.downsample is not None:
      residual = self.downsample(x)
    out += residual
    return self.relu(out)

In [None]:
class Resnet(nn.Module):
  """ResNet classifier: 7x7 stem, max-pool, four residual stages, global pool, linear head.

  Args:
    block: residual block class, called as
      `block(in_channels, out_channels, stride, downsample)`.
    layers: number of blocks in each of the four stages (e.g. [3, 4, 6, 3]).
    num_classes: size of the output logits.
  """
  def __init__(self, block, layers, num_classes = 1000):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
    # Running channel count of the feature map; updated by _make_layer.
    self.filter = 64
    self.bn1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU(inplace = True)
    self.max1 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
    # Stage 1 keeps the resolution because the max-pool has already downsampled.
    self.layer1 = self._make_layer(block, 64, layers[0], stride = 1)
    self.layer2 = self._make_layer(block, 128, layers[1], stride = 2)
    self.layer3 = self._make_layer(block, 256, layers[2], stride = 2)
    self.layer4 = self._make_layer(block, 512, layers[3], stride = 2)
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(512, num_classes)

  def _make_layer(self, block, output, blocks, stride = 1):
    """Build one stage of `blocks` residual blocks producing `output` channels.

    Only the first block may change resolution (`stride`) or channel count,
    so only it gets a 1x1 projection shortcut. The remaining blocks use
    stride 1 and an identity shortcut.
    """
    downsample = None
    if stride != 1 or self.filter != output:
      downsample = nn.Sequential(
          nn.Conv2d(self.filter, output, 1, stride = stride, bias = False),
          nn.BatchNorm2d(output)
      )
    layers = [block(self.filter, output, stride, downsample)]
    self.filter = output
    for _ in range(1, blocks):
      # A fresh block each time: no shared downsample module, no re-striding.
      layers.append(block(self.filter, output, 1, None))
    return nn.Sequential(*layers)

  def forward(self, x):
    """Map an image batch to class logits of shape (batch, num_classes)."""
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.max1(out)
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.layer4(out)
    out = self.avgpool(out)
    out = torch.flatten(out, 1)
    out = self.fc(out)
    return out

In [None]:
def resnet34(num_classes=1000):
    """Build the hand-written ResNet-34: basic blocks with [3, 4, 6, 3] per stage.

    Note: this definition shadows the `resnet34` imported from
    torchvision.models in the first cell.

    Args:
        num_classes: size of the classifier output (default 1000, as before).
    """
    layers = [3, 4, 6, 3]
    model = Resnet(Block, layers, num_classes)
    return model