In [2]:
import torch
from torch import nn
import torchvision
from torchvision import transforms

from utils import *
from learner import *



In [3]:
BATCH_SIZE = 64
device = "cuda" if torch.cuda.is_available() else "cpu"

Data processing

In [17]:
act_gr = nn.LeakyReLU(0.1)
nfs = (32, 64, 128, 256, 512, 1024)
nbks = (3, 2, 2, 1, 1)

In [23]:
def conv(ic, oc, ks=3, s=1, act=nn.ReLU, norm=None, bias=True):
    layers = []
    if norm: layers.append(norm(ic))
    if act: layers.append(act())
    layers.append(nn.Conv2d(in_channels=ic, out_channels=oc, kernel_size=ks, stride=s, padding=ks//2, bias=bias))
    return nn.Sequential(*layers)

In [24]:
def conv_block(ic, oc, s, act=act_gr, norm=None, ks=3):
    return nn.Sequential(
        conv(ic=ic, oc=oc, s=1, act=act, norm=norm, ks=ks),
        conv(ic=ic, oc=oc, s=s, act=act, norm=norm, ks=ks)
    )

In [25]:
class ResBlock(nn.Module):
    def __init__(self, ic, oc, ks=3, s=1, act=act_gr, norm=None):
        super().__init__()
        self.convs = conv_block(ic=ic, oc=oc, s=s, act=act, norm=norm)
        self.id_conv = _ if ic == oc else conv(ic=ic, oc=oc, ks=1, s=1, act=None, norm=norm)
        self.pool = _ if s==1 else nn.AvgPool2d(2, ceil_mode=True)

    def forward(self, x):
        return self.convs(x) + self.id_conv(self.pool(x))

In [31]:
def res_blocks(n_bk, ic, oc, s=1, ks=3, act=act_gr, norm=None):
    return nn.Sequential(*[ResBlock(ic if i==0 else oc, oc, s=s if i==n_bk-1 else 1, ks=ks, act=act, norm=norm) for i in range(n_bk)])

In [32]:
def get_dropmodel(act=act_gr, nfs=nfs, nbks=nbks, norm=nn.BatchNorm2d, drop=0.2):
    layers = [nn.Conv2d(3, nfs[0], 5, padding=2)]
    layers += [res_blocks(nbks[i], nfs[i], nfs[i+1], act=act, norm=norm, s=2) for i in range(len(nfs) - 1)]
    layers += [act_gr(), norm(nfs[-1], nn.AdaptiveAvgPool2d(i), nn.Flatten(), nn.Dropout(drop))]
    layers += [nn.Linear(nfs[-1], 200, bias=False), nn.BatchNorm1d(200)]
    return nn.Sequential(*layers)

In [33]:
get_dropmodel(nbks=(4,3,3,2,1), drop=0.1)

TypeError: forward() missing 1 required positional argument: 'input'