In [2]:
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

import numpy as np
import matplotlib.pyplot as plt

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader 
from torch.optim import Adam
import torchvision



In [3]:
# Seed every RNG library this notebook imports so Restart & Run All is reproducible.
SEED = 1
torch.manual_seed(SEED)  # seeds the torch RNG on all devices (CPU and CUDA)
np.random.seed(SEED)     # numpy is imported above; seed its global RNG too
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Data

In [4]:
# Dataset location and loader settings live in one place so they are easy to tune.
DATA_ROOT = './../data_cifar10'
batch_size = 256
NUM_WORKERS = 2

# ToTensor() only: images become float tensors scaled to [0, 1];
# no normalization or augmentation is applied.
transform = transforms.Compose([transforms.ToTensor()])

train_set = CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
# NOTE: despite its name, val_set is CIFAR-10's official test split (train=False);
# it is what test_loader below iterates over.
val_set = CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)
assert len(train_set) == 50_000 and len(val_set) == 10_000, 'unexpected CIFAR-10 split sizes'

# Pinned host memory only helps host->GPU copies, so enable it only when CUDA exists.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
test_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)

# Model

In [6]:
class CNNModel(nn.Module):
    """VGG-like convolutional classifier for 32x32 images such as CIFAR-10.

    The body is twelve ``Conv2d(3x3, stride 1, padding 1) -> ReLU -> BatchNorm2d``
    blocks, exposed as ``conv_layer1`` .. ``conv_layer12``. They are grouped into
    four stages of three blocks with widths 64/128/256/512. The third block of
    each stage ends with a 2x2 max-pool, so the spatial size is halved four times
    (32 -> 2). The head is ``Flatten -> Linear(512*2*2, 512) -> ReLU -> Linear``.

    Attribute names and the layer order inside each block match the original
    hand-written definition, so saved ``state_dict`` checkpoints still load.

    Args:
        n_classes: Number of output scores produced by ``dense2``.
        in_channels: Number of channels of the input images (3 for RGB).
    """

    # Output channels of each stage, and number of conv blocks per stage.
    _STAGE_WIDTHS = (64, 128, 256, 512)
    _BLOCKS_PER_STAGE = 3
    # Conv block attribute names in forward order: conv_layer1 .. conv_layer12.
    _CONV_LAYER_NAMES = tuple(
        f'conv_layer{i}' for i in range(1, len(_STAGE_WIDTHS) * _BLOCKS_PER_STAGE + 1)
    )

    def __init__(self, n_classes=10, in_channels=3):
        super(CNNModel, self).__init__()

        # Blocks are created in the same order as the original explicit definition,
        # so a given torch seed yields identical initial weights.
        layer_names = iter(self._CONV_LAYER_NAMES)
        prev_channels = in_channels
        for width in self._STAGE_WIDTHS:
            for block_idx in range(self._BLOCKS_PER_STAGE):
                ends_stage = block_idx == self._BLOCKS_PER_STAGE - 1
                # setattr on an nn.Module registers the block as a submodule.
                setattr(self, next(layer_names),
                        self._conv_block(prev_channels, width, pool=ends_stage))
                prev_channels = width

        self.flatten = nn.Flatten()
        # Four 2x2 pools reduce a 32x32 input to 2x2, hence the 2*2 factor;
        # other input sizes will not match this layer.
        self.dense1 = nn.Sequential(
            nn.Linear(self._STAGE_WIDTHS[-1] * 2 * 2, 512),
            nn.ReLU(),
        )
        self.dense2 = nn.Linear(512, n_classes)

    @staticmethod
    def _conv_block(in_channels, out_channels, pool=False):
        """Build Conv(3x3, pad 1) -> ReLU -> BatchNorm2d, plus a 2x2 max-pool if ``pool``."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ]
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Image batch; expected shape (N, in_channels, 32, 32), since the head
                is sized for a 2x2 feature map after four poolings.

        Returns:
            Tensor of shape (N, n_classes) with raw scores (no softmax applied).
        """
        for name in self._CONV_LAYER_NAMES:
            x = getattr(self, name)(x)
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)