In [6]:
import torch
import torch.nn as nn

# nn.Flatten(start_dim=1, end_dim=-1) keeps dim 0 (the batch) and merges every
# remaining dim into one: (32, 1, 1, 1280) -> (32, 1 * 1 * 1280) = (32, 1280).
# Named `features` rather than `input` so the builtin input() is not shadowed
# for the rest of the kernel session.
features = torch.randn(32, 1, 1, 1280)
m = nn.Flatten(start_dim=1, end_dim=-1)
output = m(features)
output.size()

torch.Size([32, 1280])

In [11]:
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv/BN layers plus a shortcut, summed then ReLU'd.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the block.
        stride: stride of the first 3x3 conv; the projection shortcut (when
            used) applies the same stride so both branches match spatially.

    ``forward`` returns the tuple ``(out, out2, out3)`` rather than only the
    final activation, so the individual branches can be inspected.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        # Residual branch: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
        # bias=False because each conv is immediately followed by BatchNorm.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=3,
                      stride=stride,
                      padding=1,
                      bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels,
                      out_channels=out_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False),
            nn.BatchNorm2d(out_channels)
        )

        # The shortcut must produce the same shape as the residual branch for
        # the addition in forward(). A 1x1 projection is needed when either the
        # spatial size changes (stride != 1) OR the channel count changes;
        # checking only the stride made e.g. BasicBlock(3, 64, 1) crash.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=1,
                          stride=stride,
                          bias=False)
            )
        else:
            # Shapes already match: pass the input through unchanged.
            self.shortcut = nn.Identity()

        self.act = nn.ReLU()

    def forward(self, x):
        """Apply the block to ``x``.

        Returns:
            out:  output of the residual (conv) branch, before the addition.
            out2: output of the shortcut branch.
            out3: ReLU(out + out2), the block's actual output.
        """
        out = self.conv(x)
        out2 = self.shortcut(x)
        out3 = self.act(out + out2)
        return out, out2, out3

In [12]:
# Stride 2 halves the spatial size (56 -> 28) and the 1x1 projection shortcut
# maps 3 -> 64 channels, so all three returned tensors share one shape.
# Named `batch` rather than `input` so the builtin input() is not shadowed.
batch = torch.randn(32, 3, 56, 56)
m = BasicBlock(3, 64, 2)
out, out2, out3 = m(batch)
out.shape, out2.shape, out3.shape

(torch.Size([32, 64, 28, 28]),
 torch.Size([32, 64, 28, 28]),
 torch.Size([32, 64, 28, 28]))

In [4]:
# Stamp the run with the wall-clock time (day/month-hour:minute:second) so the
# saved outputs can be matched to a session.
from datetime import datetime

now = datetime.now()
# f-string format spec delegates to datetime.strftime with the same pattern.
dt_string = f"{now:%d/%m-%H:%M:%S}"
print("date and time =", dt_string)

date and time = 13/03-14:40:51
