In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [3]:
# LeNet5
class LeNet5(nn.Module):
    """LeNet-5 style convolutional classifier.

    Two (5x5 conv -> ReLU -> 2x2 max-pool) stages are followed by a
    three-layer fully connected head. The head's first layer takes
    16*5*5 features, which is what the feature extractor yields for
    e.g. 32x32 inputs (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        in_channels: number of channels of the input images.
        num_classes: number of output units of the last linear layer.
    """

    def __init__(self, in_channels, num_classes):
        super(LeNet5, self).__init__()
        self._model_name = 'LeNet5'
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
        )
        # The attribute keeps its original spelling ("classfier"): the
        # state_dict keys are derived from it, so renaming would break
        # loading of previously saved weights.
        self.classfier = nn.Sequential(
            nn.Linear(5*5*16, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        """Return the (batch, num_classes) output of the head for batch ``x``.

        Raises:
            RuntimeError: if the feature map does not hold exactly 16*5*5
                values per sample (i.e. the input spatial size is wrong).
        """
        out = self.features(x)
        # Flatten per sample before the fully connected head. The batch
        # dimension is pinned explicitly: with view(-1, 16*5*5) a wrong
        # input size could be absorbed into the batch dimension and yield
        # an output with the wrong number of rows instead of an error.
        out = out.view(out.size(0), 16*5*5)
        out = self.classfier(out)
        return out

In [None]:
# AlexNet
class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier.

    Five convolution layers (with three 3x3/stride-2 max-pools) are
    followed by a fully connected head with dropout. The head's first
    layer takes 6*6*256 features, which is what the feature extractor
    yields for e.g. 224x224 inputs (224 -> 55 -> 27 -> 13 -> 6).

    Args:
        in_channels: number of channels of the input images.
        num_classes: number of output units of the last linear layer.
    """

    def __init__(self, in_channels, num_classes):
        super(AlexNet, self).__init__()
        self._model_name = 'AlexNet'
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2),
            nn.Conv2d(96, 256, 5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2),
            nn.Conv2d(256, 384, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2),
        )
        # The attribute keeps its original spelling ("classfier"): the
        # state_dict keys are derived from it, so renaming would break
        # loading of previously saved weights.
        self.classfier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(6*6*256, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            # Output width follows num_classes; it used to be hard-coded
            # to 1000, which silently ignored the constructor argument.
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return the (batch, num_classes) output of the head for batch ``x``.

        Raises:
            RuntimeError: if the feature map does not hold exactly 6*6*256
                values per sample (i.e. the input spatial size is wrong).
        """
        out = self.features(x)
        # Flatten per sample before the fully connected head. The batch
        # dimension is pinned explicitly so a wrong input size raises
        # instead of being folded into the batch dimension.
        out = out.view(out.size(0), 6*6*256)
        out = self.classfier(out)
        return out

In [5]:
# AlexNet for CIFAR
class AlexNet_CIFAR(nn.Module):
    """Small three-convolution classifier for 3-channel images.

    Three (5x5 conv -> ReLU -> 3x3/stride-2 pool) stages (one max-pool,
    two average-pools) are followed by a two-layer fully connected head.
    The head's first layer takes 64*3*3 features, which is what the
    feature extractor yields for e.g. 32x32 inputs (32 -> 15 -> 7 -> 3).

    Args:
        num_classes: number of output units of the last linear layer.
    """

    def __init__(self, num_classes=10):
        super(AlexNet_CIFAR, self).__init__()
        self._model_name = 'AlexNet_CIFAR'
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
            nn.Conv2d(32, 32, 5, 1, padding=2),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(3, 2),
            nn.Conv2d(32, 64, 5, 1, padding=2),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(3, 2),
        )
        # No activation after the last Linear: the model returns raw scores.
        # A trailing ReLU would clamp every negative score to zero (and zero
        # its gradient), which distorts softmax / cross-entropy training.
        # Removing it leaves the parameter keys (classifier.0.*,
        # classifier.2.*) unchanged, since ReLU has no parameters.
        self.classifier = nn.Sequential(
            nn.Linear(64*3*3, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return the (batch, num_classes) raw scores for batch ``x``.

        Raises:
            RuntimeError: if the feature map does not hold exactly 64*3*3
                values per sample (i.e. the input spatial size is wrong).
        """
        out = self.features(x)
        # Flatten per sample before the fully connected head.
        out = out.view(out.size(0), 64*3*3)
        out = self.classifier(out)
        return out