In [3]:
import torch
torch.manual_seed(42)
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader

In [4]:
class Student(nn.Module):
    def __init__(self, in_channels, out_channels, num_classes, dropout):
        super(Student, self).__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.Dropout()
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.Dropout()
        )

        self.classifier = nn.Sequential(
            nn.Linear(out_channels, 30),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(30, num_classes)
        )
        
    def forward(self, x):
        out_1 = self.conv_1(x)
        out_2 = self.conv_2(out_1)
        out_3 = torch.mean(out_2, dim=(2, 3))  
        out_4 = self.classifier(out_3)
        
        return out_4

In [6]:
model = Student(3, 32, 10, 0.2)

In [8]:
model.load_state_dict(torch.load('model/cifar10_github/epoch_99.bin'))
model.eval()

Student(
  (conv_1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Dropout(p=0.5, inplace=False)
  )
  (conv_2): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Dropout(p=0.5, inplace=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=32, out_features=30, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=30, out_features=10, bias=True)
  )
)

In [10]:
@torch.no_grad()
def evaluate_model(dataloader, model, loss_fn):
    loss, accuracy = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            y_hat = model(X)
            loss += loss_fn(y_hat, y).item()
            accuracy += (y_hat.argmax(1) == y).type(torch.float).sum().item()
    loss = loss / len(dataloader.dataset)
    accuracy = accuracy / len(dataloader.dataset)
    return (loss, accuracy)

In [12]:
def load_CIFAR10_dataset():
    transform = torchvision.transforms.ToTensor()
    cifar10_path = './data'
    cifar10_train = torchvision.datasets.CIFAR10(root=cifar10_path, train=True, transform=transform, download=True)
    cifar10_test = torchvision.datasets.CIFAR10(root=cifar10_path, train=False, transform=transform)
    cifar10_splitted_train, cifar10_validation = torch.utils.data.random_split(
        cifar10_train, [45000, 5000], generator=torch.Generator().manual_seed(42))
    return (cifar10_train, cifar10_test, cifar10_splitted_train, cifar10_validation)

In [13]:
def construct_dataloaders(dataset, batch_size, shuffle_train=True):
    train_dataset, test_dataset, splitted_train_dataset, validation_dataset = dataset
    train_dataloader = DataLoader(train_dataset,
                                batch_size = batch_size,
                                shuffle = shuffle_train,)
    test_dataloader = DataLoader(test_dataset,
                                batch_size = 100,
                                shuffle = False,)
    splitted_train_dataloader = DataLoader(splitted_train_dataset,
                                batch_size = batch_size,
                                shuffle = shuffle_train,)
    validation_dataloader = DataLoader(validation_dataset,
                                batch_size = 100,
                                shuffle = False,)

    dataloaders = {}
    dataloaders['train'] = train_dataloader
    dataloaders['test'] = test_dataloader
    dataloaders['splitted_train'] = splitted_train_dataloader
    dataloaders['validation'] = validation_dataloader
    return dataloaders

In [21]:
def get_parameter_num(model, trainable = True):
    if trainable:
        num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    else:
        num =  sum(p.numel() for p in model.parameters())
    return num

In [16]:
dataset = load_CIFAR10_dataset()
dataloaders = construct_dataloaders(dataset, 100, shuffle_train=True)
loss_fn = torch.nn.CrossEntropyLoss()

Files already downloaded and verified


In [17]:
train_dataloader = dataloaders['splitted_train']
validation_dataloader = dataloaders['validation']

In [18]:
tr_loss, tr_accuracy = evaluate_model(train_dataloader, model, loss_fn)
va_loss, va_accuracy = evaluate_model(validation_dataloader, model, loss_fn)

In [19]:
tr_loss, tr_accuracy

(0.012904320503605737, 0.5353555555555556)

In [20]:
va_loss, va_accuracy

(0.013179720997810365, 0.5298)

In [22]:
get_parameter_num(model)

11444