In [9]:
import torch
from torch import nn

# VGG layer layouts keyed by variant name, written one stage per line.
# NOTE(review): integers look like convolution channel widths and 'M' like a
# max-pool marker closing a stage — confirm against the code that consumes this.
cfg = {
    'VGG11': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'VGG19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
}

### MODEL ###
class ConvBlock(nn.Module):
    """A stack of Conv2d -> GELU -> BatchNorm2d triples closed by one max-pool.

    ``dims`` lists channel counts: consecutive pairs ``(dims[i], dims[i+1])``
    become the in/out channels of the i-th 3x3 convolution (stride 1,
    padding 1). A single ``MaxPool2d(2, stride=2)`` is appended after the
    last triple.
    """

    def __init__(self, dims, affine=False):
        """Build the layer stack.

        Args:
            dims: sequence of channel counts; ``len(dims) - 1`` convolutions
                are created.
            affine: forwarded to every ``nn.BatchNorm2d`` as its ``affine``
                flag.
        """
        super().__init__()
        self.num_conv = len(dims) - 1
        self.ins = dims[:-1]
        self.outs = dims[1:]

        stack = []
        # Pair each input width with the output width that follows it.
        for c_in, c_out in zip(self.ins, self.outs):
            stack.append(nn.Conv2d(c_in, c_out, 3, stride=1, padding=1))
            stack.append(nn.GELU())
            stack.append(nn.BatchNorm2d(c_out, affine=affine))
        stack.append(nn.MaxPool2d(2, stride=2))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        out = self.layers(x)
        return out

def read_cfg(cfg):
    """Split a flat VGG layer list into blocks that each end with 'M'.

    Example:
        ``[64, 'M', 128, 'M', 256, 256, 'M']`` ->
        ``[[64, 'M'], [128, 'M'], [256, 256, 'M']]``

    Args:
        cfg: flat sequence of layer entries, where the string ``'M'`` closes
            a block.

    Returns:
        A list of lists. Every entry of ``cfg`` appears exactly once, in
        order. Entries after the last ``'M'`` (if any) form a final block
        without an ``'M'``. An empty input yields an empty list.
    """
    blocks = []
    current = []
    for entry in cfg:
        # Collect the layer values themselves, not their positions.
        current.append(entry)
        if entry == 'M':
            blocks.append(current)
            current = []
    if current:
        # Keep trailing layers that were not followed by an 'M'.
        blocks.append(current)
    return blocks

# Display what read_cfg returns for the VGG11 layer list.
read_cfg(cfg["VGG11"])

[[64, 'M'], [128, 'M'], [256, 256, 'M'], [512, 512, 'M'], [512, 512, 'M']]

In [None]:

    
class VGG(nn.Module):
    """VGG-style stack of ConvBlock stages with a linear head.

    ``forward`` returns the final activation together with a list of
    per-stage outputs, where each listed output is computed from a
    *detached* copy of that stage's input.

    Args:
        args: object providing
            * ``bn_affine`` -- ``1`` enables affine BatchNorm in every block,
            * ``dataset`` -- only ``"cifar"`` is supported,
            * ``batchsize`` -- stored in ``self.size``,
            * ``vgg_name`` (optional) -- key into the module-level ``cfg``
              table; defaults to ``"VGG11"``.

    Raises:
        ValueError: if ``args.dataset`` is not ``"cifar"``.
    """

    def __init__(self, args):
        super(VGG, self).__init__()
        self.bn_affine = args.bn_affine == 1

        if args.dataset == "cifar":
            if "VGG" in cfg:
                # Backward compatible: use a ready-made "VGG" entry verbatim
                # if the table provides one.
                self.units = cfg["VGG"]
            else:
                # The visible table has no "VGG" key, so derive the
                # per-block channel lists from a named flat layout instead.
                vgg_name = getattr(args, "vgg_name", "VGG11")
                self.units = self._build_units(cfg[vgg_name], in_channels=3)
            # Five 2x2 pools reduce a 32x32 input to 1x1, so the flattened
            # feature size equals the last block's channel count.
            self.output_layer = nn.Linear(self.units[-1][-1], 10)
            self.size = (args.batchsize, 3, 32, 32)
        else:
            # Fail with a clear message instead of a later AttributeError on
            # the missing ``self.units``.
            raise ValueError(
                "Unsupported dataset %r; only 'cifar' is supported" % (args.dataset,)
            )

        self.module_list = nn.ModuleList(
            [ConvBlock(unit, affine=self.bn_affine) for unit in self.units]
        )

        self.f3 = nn.Dropout(p=0.2)
        self.act2 = nn.ReLU()
        # Not used in forward; kept so the attribute remains available.
        self.AP = torch.nn.AvgPool2d(2, stride=1)

    @staticmethod
    def _build_units(layer_cfg, in_channels=3):
        """Turn a flat layer list into per-block channel lists for ConvBlock.

        Each returned list starts with the block's input channel count and is
        followed by its convolution widths, e.g. ``[64, 'M', 128, 'M']`` with
        ``in_channels=3`` gives ``[[3, 64], [64, 128]]``. The ``'M'`` markers
        are dropped because ConvBlock always appends its own max-pool.
        """
        units = []
        dims = [in_channels]
        for entry in layer_cfg:
            if entry == 'M':
                if len(dims) > 1:
                    units.append(dims)
                    # The next block consumes this block's output channels.
                    dims = [dims[-1]]
            else:
                dims.append(entry)
        if len(dims) > 1:
            # Trailing convolutions without a closing 'M' still form a block.
            units.append(dims)
        return units

    def forward(self, data):
        """Run the network.

        Args:
            data: tensor that can be viewed as ``(-1, 3, 32, 32)``.

        Returns:
            ``(x, output)`` where ``x`` is the head activation computed on the
            fully connected graph, and ``output`` is a list with one entry per
            block plus one for the head, each computed from a detached input.
        """
        # Infer the batch dimension so a batch smaller than ``args.batchsize``
        # (e.g. the last one of an epoch) does not fail.
        x = data.view(-1, *self.size[1:])
        output = []
        for module in self.module_list:
            # NOTE(review): each block runs twice per step (detached and
            # attached input), so BatchNorm statistics are updated twice in
            # training mode — confirm this is intended.
            x_ = module(x.detach())
            x = module(x)
            output.append(x_)
        x = torch.flatten(x, 1)
        x = self.f3(x)
        x_ = self.act2(self.output_layer(x.detach()))
        x = self.act2(self.output_layer(x))
        output.append(x_)
        return x, output
    
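# NOTE: this example is stale. VGG.__init__ takes an `args` object (it reads
# args.bn_affine, args.dataset and args.batchsize), and forward returns an
# (x, output) tuple, so `y.size()` below would fail on the tuple.
# net = VGG('VGG11')
# x = torch.randn(2,3,32,32)
# y = net(x)
# print(y.size())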