In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, models, transforms

In [None]:
class ResUnit(nn.Module):
  """Basic residual unit: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus a shortcut.

  Args:
    in_channel: number of channels of the input tensor.
    out_channel: number of channels produced by the unit.
    bot_layer: if True, the first 3x3 conv and the shortcut use stride 2, which
      halves height and width. Defaults to False, which keeps the spatial size.

  The shortcut is the identity when the input can be added to the output
  unchanged. When the unit downsamples (``bot_layer``) or changes the channel
  count, the shortcut is a 1x1 conv + BN projection, so the residual sum
  always has matching shapes.
  """

  def __init__(self, in_channel, out_channel, bot_layer: bool = False):
    super().__init__()
    self.bot_layer = bot_layer
    stride = 2 if bot_layer else 1
    self.cn1 = nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU()
    )
    self.cn2 = nn.Sequential(
        nn.Conv2d(out_channel, out_channel, 3, 1, 1),
        nn.BatchNorm2d(out_channel)
    )
    # Build the projection only when shapes differ. An always-present but unused
    # cn_bl would add dead parameters (DDP rejects those unless
    # find_unused_parameters=True).
    if bot_layer or in_channel != out_channel:
      self.cn_bl = nn.Sequential(
          nn.Conv2d(in_channel, out_channel, 1, stride, 0),
          nn.BatchNorm2d(out_channel)
      )
    else:
      self.cn_bl = nn.Identity()

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return relu(cn2(cn1(x)) + shortcut(x))."""
    out = self.cn2(self.cn1(x))
    # residual connection
    out = out + self.cn_bl(x)
    return F.relu(out)

# Sanity check: a ResUnit without downsampling must preserve the input shape.
x = torch.rand(32, 64, 32, 32)
model = ResUnit(64, 64, False)
out_shape = model(x).size()
assert out_shape == x.shape, f"unexpected output shape {out_shape}"
out_shape

torch.Size([32, 64, 32, 32])

In [None]:
class ResNEt34(nn.Module):
  """ResNet-34 classifier built from ``ResUnit`` residual units.

  Layout:
    * stem: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool, 64 channels;
    * four stages of 3, 4, 6 and 3 ``ResUnit``s with 64, 128, 256 and 512
      channels. Every stage after the first downsamples in its first unit
      (``bot_layer=True``);
    * head: global average pool -> flatten -> linear layer to ``num_classes``.

  Args:
    in_channel: number of channels of the input image tensor.
    num_classes: number of output logits.
  """

  # (channels, number of ResUnits) for each of the four ResNet-34 stages.
  _STAGES = ((64, 3), (128, 4), (256, 6), (512, 3))

  def __init__(self, in_channel, num_classes):
    super().__init__()
    self.input_block = nn.Sequential(
        nn.Conv2d(in_channel, 64, 7, 2, 3),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    )

    # One nn.Sequential per stage, registered in a ModuleList so that the
    # parameters are tracked.
    self.res_block = nn.ModuleList()
    prev_channels = 64
    for stage_idx, (channels, depth) in enumerate(self._STAGES):
      # Stage 0 keeps the stem's resolution. Later stages halve it (and change
      # the channel count) in their first unit.
      units = [ResUnit(prev_channels, channels, bot_layer=stage_idx > 0)]
      units += [ResUnit(channels, channels, bot_layer=False) for _ in range(depth - 1)]
      self.res_block.append(nn.Sequential(*units))
      prev_channels = channels

    # Global pooling to 1x1 leaves exactly `prev_channels` (512) features,
    # which matches the linear layer's input size.
    self.classifier = nn.Sequential(
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(prev_channels, num_classes)
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return class logits of shape (batch, num_classes).

    Assumes ``x`` is a batched image tensor (batch, in_channel, H, W).
    """
    x = self.input_block(x)
    for stage in self.res_block:
      x = stage(x)
    return self.classifier(x)

SyntaxError: invalid syntax (<ipython-input-14-5cede7d2a1b0>, line 35)