In [None]:
import tltorch
import torch
from torch import nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as T
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
def plot_losses_epoch(loss, accuracy, train_losses, test_losses, train_accuracies, test_accuracies):
    """Clear the cell output and draw a 2x2 grid of training curves.

    Top row: per-epoch loss (left) and accuracy (right), train vs. test.
    Bottom row: the ``loss`` (left) and ``accuracy`` (right) series, x-axis labelled "batch".
    Every series is plotted against 1-based indices; all panels get a grid.
    """
    clear_output()
    fig, axs = plt.subplots(2, 2, figsize=(15, 8))

    def _one_based(values):
        # 1-based x coordinates matching the length of ``values``.
        return range(1, len(values) + 1)

    epoch_panels = (
        (axs[0][0], train_losses, test_losses, 'loss'),
        (axs[0][1], train_accuracies, test_accuracies, 'accuracy'),
    )
    for ax, train_values, test_values, y_name in epoch_panels:
        ax.plot(_one_based(train_values), train_values, label='train')
        ax.plot(_one_based(test_values), test_values, label='test')
        ax.set_ylabel(y_name)
        ax.set_xlabel('epoch')
        ax.legend()

    batch_panels = (
        (axs[1][0], loss, 'loss'),
        (axs[1][1], accuracy, 'accuracy'),
    )
    for ax, values, y_name in batch_panels:
        ax.plot(_one_based(values), values)
        ax.set_ylabel(y_name)
        ax.set_xlabel('batch')

    # ``axs`` is the 2x2 array from plt.subplots; ``.flat`` walks it row by row.
    for ax in axs.flat:
        ax.grid()

    plt.show()

In [None]:
import sys
# Make the project's model modules importable. The path is relative to the kernel's
# current working directory — presumably the notebook is run from its own folder; confirm.
sys.path.insert(0,'../models')
from TTCL import TTCL
from conv_models import Model, ModelConv

In [None]:
def load_checkpoint(path, modelclass, **kwargs):
    """Build a model and restore its weights and training history from a checkpoint.

    Args:
        path: checkpoint file readable by ``torch.load``.
        modelclass: model constructor; called as ``modelclass(**kwargs)`` and moved to ``device``.
        **kwargs: forwarded unchanged to ``modelclass``.

    Returns:
        ``(epoch, model, train_loss, test_loss, train_accuracy, test_accuracy)``; every item
        except ``model`` is read from the checkpoint key of the same name, and ``model`` has
        the checkpoint's ``model_state_dict`` loaded.

    NOTE: ``torch.load`` unpickles the file — only load checkpoints you trust.
    """
    model = modelclass(**kwargs).to(device)

    # map_location remaps stored tensors onto the device used by this notebook.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    history_keys = ('train_loss', 'test_loss', 'train_accuracy', 'test_accuracy')
    history = (checkpoint[key] for key in history_keys)
    return (checkpoint['epoch'], model, *history)

# Load the TTCL-p-0.9 checkpoint and plot its stored training history.
path = 'trained_models/TTCL-p-0.9-epoch100.pt'
epoch, model, train_loss, test_loss, train_accuracy, test_accuracy = load_checkpoint(path, Model)

# NOTE(review): the train loss/accuracy histories (presumably per-epoch) are also passed
# as the first two arguments, which plot_losses_epoch draws on the bottom row with an
# x-axis labelled "batch". That row therefore repeats the train curves — confirm intent.
plot_losses_epoch(train_loss, train_accuracy, train_loss, test_loss, train_accuracy, test_accuracy)

In [None]:
def get_cifar10_transform(train=True):
    """Build the torchvision preprocessing pipeline for CIFAR-10 images.

    Args:
        train: if truthy, prepend augmentation (random 32x32 crop with 4 px padding and a
            random horizontal flip) and normalize in place; otherwise only convert to a
            tensor and normalize.

    Returns:
        A ``T.Compose`` pipeline.
    """
    # Per-channel mean / std (one value per image channel) shared by both pipelines.
    channel_mean = (0.49139968, 0.48215827, 0.44653124)
    channel_std = (0.24703233, 0.24348505, 0.26158768)

    if not train:
        return T.Compose([
            T.ToTensor(),
            T.Normalize(channel_mean, channel_std),
        ])

    return T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std, inplace=True),
    ])

In [None]:
batch_size = 32

# CIFAR-10 test split under the relative root 'CIFAR10'; download=True fetches it if missing.
test_set = CIFAR10('CIFAR10', train=False, download=True,
                   transform=get_cifar10_transform(train=False))
# Pinned host memory only helps host-to-GPU copies; requesting it with no CUDA device
# just makes PyTorch warn, so enable it only when evaluation runs on CUDA.
test_loader = DataLoader(test_set, batch_size=batch_size,
                         pin_memory=device.type == 'cuda', num_workers=2)

In [None]:
def get_acc(model, dataloader, device="cuda:3", desc='Evaluating...'):
    """Compute top-1 and top-5 classification accuracy of ``model`` over ``dataloader``.

    The model is moved to ``device`` and switched to eval mode, and is left that way
    after the call. Correct predictions are counted per batch instead of concatenating
    every prediction, so the cost is linear in the number of batches.

    Args:
        model: classifier returning logits with classes on the last dimension
            (assumed ``(batch, num_classes)`` — TODO confirm for other models).
        dataloader: iterable of ``(images, labels)`` batches; ``labels`` hold integer
            class indices.
        device: device to run inference on. NOTE(review): the ``"cuda:3"`` default is
            machine-specific; pass the device explicitly.
        desc: currently unused; kept for backward compatibility (presumably intended
            as a progress-bar label).

    Returns:
        ``(top1_accuracy, top5_accuracy)`` as Python floats. Both are ``nan`` when the
        dataloader yields no samples. With fewer than 5 classes, "top-5" is computed
        over all classes.
    """
    model.to(device)
    model.eval()

    n_seen = 0
    n_top1 = 0
    n_top5 = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            # Clamp k so that models with fewer than 5 classes don't make topk raise.
            k = min(5, logits.size(-1))
            topk_indices = torch.topk(logits, k).indices.reshape(len(images), -1)

            n_top1 += (logits.argmax(dim=-1) == labels).sum().item()
            n_top5 += (topk_indices == labels.reshape(-1, 1)).any(dim=1).sum().item()
            n_seen += labels.numel()

    if n_seen == 0:
        # Same result as the mean of an empty tensor in the previous implementation.
        return float('nan'), float('nan')

    return n_top1 / n_seen, n_top5 / n_seen

In [None]:
# Reference parameter count; compression ratios are computed relative to it.
baseline_number_of_parameters = 557642

def count_parameters(model):
    """Return the total number of scalar parameters in ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

folder = 'trained_models/'
# One checkpoint file per value of p.
files = [f'TTCL-p-{p}-epoch100.pt' for p in ('1', '0.9', '0.8', '0.7', '0.6')]

# For each checkpoint, report:
# - the best test accuracy stored in its history;
# - its trainable-parameter count and the compression ratio against the baseline;
# - top-1/top-5 accuracy re-measured on the test loader.
for file in files:
    path = folder + file
    epoch, model, train_loss, test_loss, train_accuracy, test_accuracy = load_checkpoint(path, Model)

    print(file)
    print(f'Best top-1 test accuracy: {np.max(test_accuracy)}')

    n_params = count_parameters(model)
    print(f'Number of parameters: {n_params}')
    print(f'Compression ratio: {baseline_number_of_parameters / n_params}')

    acc, top5_acc = get_acc(model, test_loader, device)
    print(f'Top-1 accuracy: {acc:.3f}\nTop-5 accuracy: {top5_acc:.3f}')

    print()