In [34]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import sklearn
import matplotlib.pyplot as plt
import torchvision

In [None]:
# VGG-16 architecture ("Conv3, C" = 3x3 convolution with C output channels)
# Conv3, 64 -> Conv3, 64 -> MaxPool -> Conv3, 128 -> Conv3, 128 -> MaxPool
# -> Conv3, 256 -> Conv3, 256 -> Conv3, 256 -> MaxPool
# -> Conv3, 512 -> Conv3, 512 -> Conv3, 512 -> MaxPool
# -> Conv3, 512 -> Conv3, 512 -> Conv3, 512 -> MaxPool
# FC - 4096
# FC - 4096
# FC - 1000
# Softmax

In [22]:
# Layer plan for VGG-16: an int is the number of output channels of a 3x3
# convolution, and 'M' denotes a 2x2 max-pooling step.
VGG16 = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']


class VGG(nn.Module):
    """VGG-style image classifier with batch normalisation after every convolution.

    The network is a convolutional feature extractor built from a layer plan,
    an adaptive average pool to a 7x7 feature map, and a three-layer fully
    connected classifier. ``forward`` returns raw class scores; no softmax is
    applied inside the model.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Size of the output score vector.
        architecture: Layer plan, a sequence whose items are either an int
            (out_channels of a 3x3 conv block) or the string ``'M'``
            (2x2 max pool). Defaults to the module-level ``VGG16`` plan.

    Raises:
        ValueError: If the layer plan contains an item that is neither an int
            nor ``'M'``.
    """

    def __init__(self, in_channels=3, num_classes=1000, architecture=None):
        super().__init__()
        self.in_channels = in_channels
        if architecture is None:
            architecture = VGG16
        self.conv_layers = self.creat_conv_layers(architecture)

        # Channel count leaving the conv stack: the last conv width in the plan,
        # or the input channels when the plan contains no convolution at all.
        feature_channels = next(
            (item for item in reversed(architecture) if self._is_channel_count(item)),
            in_channels,
        )

        # Pooling to a fixed 7x7 map keeps the classifier input size constant for
        # any input resolution. A feature map that is already 7x7 (224x224 input
        # with the default plan) passes through unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.fcs = nn.Sequential(
            nn.Linear(feature_channels * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Tensor shaped (N, in_channels, H, W). H and W must be large
                enough to survive every 'M' step, each of which halves them.

        Returns:
            Tensor shaped (N, num_classes) holding unnormalised scores.
        """
        x = self.conv_layers(x)
        x = self.avgpool(x)
        x = x.flatten(start_dim=1)  # keep the batch axis, merge C/H/W
        return self.fcs(x)

    @staticmethod
    def _is_channel_count(item):
        """Return True when ``item`` is a plain int (bools are not channel counts)."""
        return isinstance(item, int) and not isinstance(item, bool)

    def creat_conv_layers(self, architecture):
        """Translate a layer plan into an ``nn.Sequential`` feature extractor.

        Each int becomes Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU,
        which preserves the spatial size; each ``'M'`` becomes a 2x2 max pool
        with stride 2.

        Args:
            architecture: Sequence of ints and ``'M'`` markers.

        Returns:
            nn.Sequential containing the layers in plan order.

        Raises:
            ValueError: If an item is neither an int nor ``'M'``.
        """
        layers = []
        in_channels = self.in_channels

        for item in architecture:
            if self._is_channel_count(item):
                layers += [
                    nn.Conv2d(in_channels=in_channels, out_channels=item,
                              kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                    nn.BatchNorm2d(item),
                    nn.ReLU(),
                ]
                # The next convolution consumes what this one produced.
                in_channels = item
            elif item == 'M':
                layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
            else:
                # Fail loudly: silently skipping a typo would build a different net.
                raise ValueError(
                    f"Unsupported architecture entry {item!r}: expected an int or 'M'"
                )

        return nn.Sequential(*layers)

    # Correctly spelled alias; the original name is kept for existing callers.
    create_conv_layers = creat_conv_layers
        

In [23]:
# Pick the compute device. torch.device builds the device handle; torch.cuda is
# a module and cannot be called.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VGG(in_channels=3, num_classes=1000).to(device)

# Smoke test: one random 224x224 RGB image should map to a (1, 1000) score tensor.
# The input is created on the same device as the model, and no autograd graph
# is needed for a shape check.
x = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    print(model(x).shape)

torch.Size([1, 1000])
