In [68]:
import torch
import torch.nn as nn

In [69]:
class double_cnns(nn.Module):
    """Two 3x3 conv -> BatchNorm -> ReLU stages followed by a stride-2 max-pool.

    Channel flow: ``in_channels -> intern_channels -> intern_channels * 2``.
    The convolutions keep the spatial size (kernel 3, padding 1); the pool
    (kernel 3, stride 2, padding 1) roughly halves it.

    Args:
        in_channels: number of channels of the input feature map.
        intern_channels: output channels of the first conv; the block emits
            ``intern_channels * 2`` channels.
        expansion: not used by this block (presumably kept for signature
            parity with ``last_cnns``).
    """

    def __init__(self, in_channels, intern_channels, expansion=None):
        super(double_cnns, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, intern_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(intern_channels)
        self.conv2 = nn.Conv2d(intern_channels, intern_channels * 2, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(intern_channels * 2)
        # ReLU has no parameters, so adding it leaves the state_dict keys unchanged.
        # In-place is safe: it only overwrites the fresh BatchNorm output.
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # A non-linearity after each stage; without it the two convs would
        # compose into a single linear map.
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return self.pool(x)

In [70]:
class last_cnns(nn.Module):
    """Three or four 3x3 conv -> BatchNorm -> ReLU stages, optionally max-pooled.

    Channel flow: ``in_channels -> intern_channels -> intern_channels
    [-> intern_channels] -> intern_channels * expansion``. All convs keep the
    spatial size (kernel 3, padding 1).

    Args:
        in_channels: channels of the input feature map.
        intern_channels: width of the internal conv stages.
        expansion: multiplier for the last conv's output channels. When it is
            exactly 1 the block ends with a stride-2 max-pool; otherwise the
            spatial size is left unchanged.
        is_four: truthy to run the optional third stage (``conv3``/``bn3``);
            ``None`` or ``False`` skip it. ``conv3``/``bn3`` are always
            created, so the parameter set does not depend on this flag.
    """

    def __init__(self, in_channels, intern_channels, expansion, is_four):
        super(last_cnns, self).__init__()
        self.expansion = expansion
        self.is_four = is_four
        self.conv1 = nn.Conv2d(in_channels, intern_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(intern_channels)
        self.conv2 = nn.Conv2d(intern_channels, intern_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(intern_channels)
        self.conv3 = nn.Conv2d(intern_channels, intern_channels, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(intern_channels)
        self.conv4 = nn.Conv2d(intern_channels, intern_channels * self.expansion, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(intern_channels * self.expansion)
        # Parameter-free, so state_dict keys are unchanged; in-place is safe
        # because it only overwrites fresh BatchNorm outputs.
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))

        # Truthiness, not `is not None`: is_four=False must also skip the stage.
        if self.is_four:
            x = self.relu(self.bn3(self.conv3(x)))

        x = self.relu(self.bn4(self.conv4(x)))

        if self.expansion == 1:
            x = self.pool(x)

        return x

In [72]:
class VGG(nn.Module):
    """VGG-style classifier: two stacks of conv blocks, global average pool, 3-layer MLP head.

    Args:
        double_cnns: block class for the first stack. It is called as
            ``double_cnns(in_channels, intern_channels)``; ``_make_layer``
            assumes it emits ``intern_channels * 2`` channels.
        last_cnns: block class for the second stack. It is called as
            ``last_cnns(in_channels, 512, expansion[i], is_four)``;
            ``_make_layer`` assumes it emits ``512 * expansion[i]`` channels.
        num_repeats: ``(n_first, n_last)`` block counts for the two stacks.
        expansion: per-block channel multipliers for the second stack; must
            have at least ``num_repeats[1]`` entries.
        image_channels: channels of the input image.
        num_classes: size of the output logits.
        is_four: forwarded unchanged to every ``last_cnns`` block.
    """

    def __init__(self, double_cnns, last_cnns, num_repeats, expansion, image_channels, num_classes, is_four=None):
        super(VGG, self).__init__()

        # Running channel count; _make_layer updates it after every block.
        self.in_channels = image_channels
        self.layer1 = self._make_layer(block=double_cnns, num_rep=num_repeats[0], intern_channels=64)
        self.layer2 = self._make_layer(block=last_cnns, num_rep=num_repeats[1], intern_channels=512, expansion=expansion, is_four=is_four)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Size the head from the channels the conv stack actually emits
        # (512 * expansion[-1] for the standard configurations).
        self.fc1 = nn.Linear(in_features=self.in_channels, out_features=4096)
        self.fc2 = nn.Linear(in_features=4096, out_features=4096)
        self.fc3 = nn.Linear(in_features=4096, out_features=num_classes)
        # Parameter-free, so adding it leaves state_dict keys unchanged.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        # Non-linearities between the FC layers; without them fc1..fc3
        # collapse into a single affine map.
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

    def _make_layer(self, block, num_rep, intern_channels, expansion=None, is_four=None):
        """Stack ``num_rep`` blocks into an ``nn.Sequential`` and track ``self.in_channels``.

        ``intern_channels == 512`` selects the ``last_cnns``-style constructor
        (with ``expansion``/``is_four``). Any other value selects the
        ``double_cnns``-style constructor, whose width doubles per block.
        """
        layers = []

        if intern_channels != 512:
            for _ in range(num_rep):
                layers.append(block(self.in_channels, intern_channels))
                # The block emits intern_channels * 2 channels; the next block
                # uses that both as its input width and as its internal width.
                intern_channels = intern_channels * 2
                self.in_channels = intern_channels
        else:
            for i in range(num_rep):
                layers.append(block(self.in_channels, intern_channels, expansion[i], is_four))
                # Feed the next block what this one actually emits instead of
                # assuming 512; non-unit expansions on earlier blocks used to
                # cause a channel mismatch.
                self.in_channels = intern_channels * expansion[i]

        return nn.Sequential(*layers)

In [76]:
def VGG16(image_channels, num_classes):
    """Build this notebook's "VGG16" variant.

    Three ``double_cnns`` blocks followed by two ``last_cnns`` blocks with
    expansions ``[1, 4]``; the optional extra conv stage is disabled, so each
    ``last_cnns`` block runs three convs.
    """
    return VGG(double_cnns, last_cnns, [3, 2], [1, 4], image_channels, num_classes)


def VGG19(image_channels, num_classes):
    """Same layout as ``VGG16`` but each ``last_cnns`` block runs its extra conv stage."""
    return VGG(double_cnns, last_cnns, [3, 2], [1, 4], image_channels, num_classes, is_four=True)


def test():
    """Smoke test: both variants map a (4, 3, 224, 224) batch to (4, 100) logits."""
    x = torch.randn(4, 3, 224, 224)
    # Only shapes are checked, so skip autograd bookkeeping; building one model
    # at a time also keeps just one set of weights alive.
    with torch.no_grad():
        for build in (VGG16, VGG19):
            model = build(image_channels=3, num_classes=100)
            out = model(x)
            print(out.size())
            assert out.shape == (4, 100), f"{build.__name__}: unexpected output shape {tuple(out.shape)}"


test()

torch.Size([4, 100])
torch.Size([4, 100])
