In [None]:
import tltorch
import torch
from torch import nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as T
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [None]:
def plot_losses_epoch(loss, accuracy, train_losses, test_losses, train_accuracies, test_accuracies):
    """Redraw a 2x2 training-progress dashboard, clearing the cell's previous output first.

    Top row: per-epoch loss (left) and accuracy (right), each with a 'train'
    and a 'test' curve and a legend. Bottom row: the ``loss`` and ``accuracy``
    sequences against an x-axis labelled 'batch'. Every x-axis starts at 1,
    and all four panels get a grid.

    Args:
        loss: sequence plotted in the bottom-left panel.
        accuracy: sequence plotted in the bottom-right panel.
        train_losses, test_losses: sequences for the top-left panel.
        train_accuracies, test_accuracies: sequences for the top-right panel.
    """
    clear_output()
    fig, axes = plt.subplots(2, 2, figsize=(15, 8))

    def one_based(values):
        # x positions 1..len(values), so the first point sits at epoch/batch 1.
        return range(1, len(values) + 1)

    per_epoch = (
        (axes[0][0], 'loss', train_losses, test_losses),
        (axes[0][1], 'accuracy', train_accuracies, test_accuracies),
    )
    for panel, metric, train_curve, test_curve in per_epoch:
        panel.plot(one_based(train_curve), train_curve, label='train')
        panel.plot(one_based(test_curve), test_curve, label='test')
        panel.set_ylabel(metric)
        panel.set_xlabel('epoch')
        panel.legend()

    per_batch = (
        (axes[1][0], 'loss', loss),
        (axes[1][1], 'accuracy', accuracy),
    )
    for panel, metric, curve in per_batch:
        panel.plot(one_based(curve), curve)
        panel.set_ylabel(metric)
        panel.set_xlabel('batch')

    # Row-major traversal, same order as iterating rows then columns.
    for panel in axes.flat:
        panel.grid()

    plt.show()

In [None]:
import sys
sys.path.insert(0,'../models')
from TTCL import TTCL

In [None]:
# Input geometry (channels, height, width); (3, 32, 32) matches CIFAR-10,
# which is imported above. Only `inp_ch` is read by `Model`.
inp_ch, inp_h, inp_w = (3, 32, 32)
# Forwarded as `p=` to every TTCL layer in `Model`.
p = 1

# `device` is passed to every TTCL layer in `Model` and used by
# `load_checkpoint` for `.to(...)` and `map_location`. It must exist before
# `Model` is instantiated, otherwise Restart & Run All fails with a NameError.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Model(nn.Module):
    """Classifier: a plain 3x3 conv stem, five TTCL stages, then a 128 -> 10 linear head.

    Every TTCL layer uses a (3, 3) kernel, ``padding='same'`` and the
    module-level ``p`` and ``device`` values. The mode tuples (4, 4, 4) and
    (4, 8, 4) multiply to 64 and 128, matching the BatchNorm widths that
    follow each layer. Presumably they factor the channel counts; confirm
    against TTCL.

    Submodule names and the layer order inside each ``nn.Sequential`` define
    the ``state_dict`` keys (e.g. ``tcl1.1.weight``) that saved checkpoints
    rely on. Keep them stable.
    """

    def __init__(self):
        super().__init__()

        modes_64 = (4, 4, 4)   # 4*4*4 == 64
        modes_128 = (4, 8, 4)  # 4*8*4 == 128

        def ttcl(in_modes, out_modes, rank):
            # All TT layers share kernel size, padding, `p` and `device`.
            return TTCL(in_modes, out_modes, (3, 3), rank=rank, p=p,
                        padding='same', device=device)

        self.conv = nn.Sequential(
            nn.Conv2d(inp_ch, 64, (3, 3)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.tcl1 = nn.Sequential(
            ttcl(modes_64, modes_64, (20, 20, 20, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),
        )
        self.tcl2 = nn.Sequential(
            ttcl(modes_64, modes_128, (27, 22, 22, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.tcl3 = nn.Sequential(
            ttcl(modes_128, modes_128, (23, 23, 23, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),
        )
        self.tcl4 = nn.Sequential(
            ttcl(modes_128, modes_128, (23, 23, 23, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        # NOTE(review): assuming TTCL with padding='same' keeps H x W, a 32x32
        # input is 30x30 after the unpadded stem, 14x14 after tcl1's pool and
        # 6x6 after tcl3's pool. AvgPool2d(4) (stride 4) therefore averages
        # only the top-left 4x4 window of the 6x6 map. AdaptiveAvgPool2d(1)
        # may have been intended, but changing it would alter the behaviour
        # the saved checkpoints were trained with.
        self.tcl5 = nn.Sequential(
            ttcl(modes_128, modes_128, (23, 23, 23, 1)),
            nn.AvgPool2d(4),
        )
        self.linear = nn.Linear(128, 10)

    def forward(self, x):
        """Run the stem and the five TTCL stages, flatten per sample, and apply the linear head."""
        for stage in (self.conv, self.tcl1, self.tcl2, self.tcl3, self.tcl4, self.tcl5):
            x = stage(x)
        # Flatten all non-batch dimensions before the linear head.
        return self.linear(x.reshape(x.size(0), -1))

In [None]:
def load_checkpoint(path, modelclass, **kwargs):
    """Rebuild a model from a saved checkpoint and return it with its training history.

    A fresh ``modelclass(**kwargs)`` is moved to the module-level ``device``.
    The checkpoint is loaded with ``map_location=device``, and its
    ``'model_state_dict'`` entry is loaded into the new model. The returned
    model is not switched to eval mode.

    NOTE: ``torch.load`` unpickles the file; only load checkpoints from
    trusted sources.

    Args:
        path: checkpoint file path passed to ``torch.load``.
        modelclass: ``nn.Module`` subclass to instantiate.
        **kwargs: forwarded to ``modelclass``.

    Returns:
        Tuple ``(epoch, model, train_loss, test_loss, train_accuracy,
        test_accuracy)``, read from the checkpoint keys of the same names.

    Raises:
        KeyError: if any of the expected checkpoint keys is missing.
    """
    net = modelclass(**kwargs).to(device)

    saved = torch.load(path, map_location=device)
    net.load_state_dict(saved['model_state_dict'])

    last_epoch = saved['epoch']
    curves = [saved[name] for name in ('train_loss', 'test_loss', 'train_accuracy', 'test_accuracy')]

    return (last_epoch, net, *curves)

path = 'trained_models/TTCL-conv-p-1-epoch40.pt'
(epoch, model,
 train_loss, test_loss,
 train_accuracy, test_accuracy) = load_checkpoint(path, Model)

# NOTE(review): no per-batch history is loaded here, so the per-epoch training
# curves are also passed as `loss`/`accuracy`. The bottom-row panels therefore
# duplicate the top-row train curves despite their 'batch' x-axis label.
plot_losses_epoch(
    loss=train_loss,
    accuracy=train_accuracy,
    train_losses=train_loss,
    test_losses=test_loss,
    train_accuracies=train_accuracy,
    test_accuracies=test_accuracy,
)