In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor, Lambda

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm 
from abc import ABC

from telegramBot import Terminator

# Number of CPU threads handed to PyTorch (and reused as DataLoader worker count).
num_cores = 8
try:
    # Inter-op parallelism. PyTorch only accepts this call once per process and
    # only before parallel work has started, so re-running this cell in a live
    # kernel would otherwise abort with a RuntimeError.
    torch.set_num_interop_threads(num_cores)
except RuntimeError as exc:
    print(f"Inter-op thread count left unchanged: {exc}")
torch.set_num_threads(num_cores) # Intra-op parallelism

In [2]:
def c3_to_c1(y):
    """Map a level-3 (fine) label to its level-1 (coarsest) label.

    Labels below 2 or above 7 map to 0; every other label maps to 1.
    """
    outside_middle_range = y < 2 or y > 7
    return 0 if outside_middle_range else 1

def c3_to_c2(y):
    """Map a level-3 (fine) label to its level-2 (intermediate) label.

    Labels 0..8 are looked up in a fixed table; any other value maps to 2.
    """
    # (fine label, level-2 label) pairs; compared with == so that any value
    # equal to the fine label (not only a plain int) is matched.
    fine_to_mid = (
        (0, 0), (1, 2), (2, 3), (3, 5), (4, 6),
        (5, 5), (6, 4), (7, 6), (8, 1),
    )
    for fine, mid in fine_to_mid:
        if y == fine:
            return mid
    return 2

def c2_to_c1(y):
    """Map a level-2 label to its level-1 label: 0 when ``y < 3``, otherwise 1."""
    return 0 if y < 3 else 1

In [3]:
# Image pipeline: convert to a tensor, then normalise each of the three
# channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def _hierarchical_target(y):
    """Expand a fine label into the tensor [level-1, level-2, level-3]."""
    return torch.tensor([c3_to_c1(y), c3_to_c2(y), int(y)])


# Target transform applied to every dataset label.
coarser = Lambda(_hierarchical_target)

batch_size = 128

# Options shared by the train and test splits / loaders.
_cifar10_options = dict(root='./data', download=False, transform=transform, target_transform=coarser)
_loader_options = dict(batch_size=batch_size, num_workers=num_cores)

trainset = torchvision.datasets.CIFAR10(train=True, **_cifar10_options)

trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_options)

testset = torchvision.datasets.CIFAR10(train=False, **_cifar10_options)

testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_options)

## NA_Layer2

In [4]:
"""class NA_Layer2(ABC, nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        yf = self.layer_f(x[0, :])
        yi = self.layer_f(x[1, :]) + self.layer_i(x[0, :])

        return torch.stack((yf, yi))

    def infinitesimal_gradient(self):
        self.layer_i.weight.grad = self.layer_f.weight.grad
        self.layer_f.weight.grad = None

        if bias:
            self.layer_i.bias.grad = self.layer_f.bias.grad
            self.layer_f.bias.grad = None

class NA_Linear2(NA_Layer2):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):

        super().__init__()
        self.bias = bias
        self.layer_f = nn.Linear(in_features, out_features, bias, device, dtype)
        self.layer_i = nn.Linear(in_features, out_features, bias, device, dtype)

class NA_Conv2d2(NA_Layer2):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, 
                padding_mode='zeros', device=None, dtype=None):

        super().__init__()
        self.bias = bias
        self.layer_f = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, 
                padding_mode, device, dtype)
        self.layer_i = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, 
                padding_mode, device, dtype)"""

"class NA_Layer2(ABC, nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        yf = self.layer_f(x[0, :])\n        yi = self.layer_f(x[1, :]) + self.layer_i(x[0, :])\n\n        return torch.stack((yf, yi))\n\n    def infinitesimal_gradient(self):\n        self.layer_i.weight.grad = self.layer_f.weight.grad\n        self.layer_f.weight.grad = None\n\n        if bias:\n            self.layer_i.bias.grad = self.layer_f.bias.grad\n            self.layer_f.bias.grad = None\n\nclass NA_Linear2(NA_Layer2):\n    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):\n\n        super().__init__()\n        self.bias = bias\n        self.layer_f = nn.Linear(in_features, out_features, bias, device, dtype)\n        self.layer_i = nn.Linear(in_features, out_features, bias, device, dtype)\n\nclass NA_Conv2d2(NA_Layer2):\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dila

## NA_Layer3

In [5]:
class NA_Layer(ABC, nn.Module):
    """Common base class of the multi-component ("NA") layers in this notebook.

    It declares no abstract methods; ``forward`` is a placeholder that the
    subclasses override.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Placeholder forward pass: ignores ``x`` and returns ``None``."""
        return None

class NA_combiner3(NA_Layer):
    """Three-component layer that mixes its sub-layers across components.

    Subclasses must provide ``layer_f``, ``layer_i1``, ``layer_i2`` and a
    boolean ``bias`` attribute. The input is indexed along dimension 0, whose
    first three entries are used as the three components.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the stacked outputs ``(yf, yi1, yi2)``.

        yf  = f(x0)
        yi1 = f(x1) + i1(x0)
        yi2 = f(x2) + i1(x1) + i2(x0)
        """
        x0, x1, x2 = x[0, :], x[1, :], x[2, :]
        out_f = self.layer_f(x0)
        out_i1 = self.layer_f(x1) + self.layer_i1(x0)
        out_i2 = self.layer_f(x2) + self.layer_i1(x1) + self.layer_i2(x0)

        return torch.stack((out_f, out_i1, out_i2))

    def infinitesimal_gradient(self, i):
        """Hand the gradient stored on ``layer_f`` over to ``layer_i<i>``.

        For ``i == 1`` the gradient of ``layer_f`` replaces that of ``layer_i1``.
        For ``i == 2`` it replaces that of ``layer_i2`` and ``layer_i1``'s
        gradient is cleared. In both cases ``layer_f``'s gradient is cleared
        afterwards. The bias is handled like the weight when ``self.bias`` is
        truthy. Any other ``i`` is a no-op.
        """
        if i not in (1, 2):
            return

        param_names = ("weight", "bias") if self.bias else ("weight",)
        for name in param_names:
            p_f = getattr(self.layer_f, name)
            p_i1 = getattr(self.layer_i1, name)
            if i == 1:
                p_i1.grad = p_f.grad
            else:
                getattr(self.layer_i2, name).grad = p_f.grad
                p_i1.grad = None
            p_f.grad = None
            
class NA_separate3(NA_Layer):
    """Three-component layer whose components never interact.

    Component ``k`` of the input (index ``k`` along dimension 0) is processed
    only by its own sub-layer: ``layer_f``, ``layer_i1`` and ``layer_i2``
    respectively. Subclasses must provide those three sub-layers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply each sub-layer to its own component and stack the results."""
        sublayers = (self.layer_f, self.layer_i1, self.layer_i2)
        outputs = tuple(sub(x[k, :]) for k, sub in enumerate(sublayers))

        return torch.stack(outputs)

class NA_independent(NA_Layer):
    """Applies one shared sub-layer (``self.layer``) to every component.

    Subclasses must provide ``self.layer``. Each slice along dimension 0 of
    the input is passed through it separately.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Run ``self.layer`` on each slice ``x[k]`` and stack the outputs."""
        return torch.stack([self.layer(x[k, :]) for k in range(x.size(0))])

class NA_Linear3(NA_combiner3):
    """Combiner layer built from three identically configured ``nn.Linear`` modules."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
        super().__init__()
        # Read by NA_combiner3.infinitesimal_gradient to decide whether the
        # bias gradients are moved as well.
        self.bias = bias

        def new_linear():
            return nn.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)

        # Created in the order f, i1, i2 (parameter registration order).
        self.layer_f = new_linear()
        self.layer_i1 = new_linear()
        self.layer_i2 = new_linear()

class NA_Conv2d3(NA_combiner3):
    """Combiner layer built from three identically configured ``nn.Conv2d`` modules."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        super().__init__()
        # Read by NA_combiner3.infinitesimal_gradient to decide whether the
        # bias gradients are moved as well.
        self.bias = bias

        conv_options = dict(stride=stride, padding=padding, dilation=dilation, groups=groups,
                            bias=bias, padding_mode=padding_mode, device=device, dtype=dtype)

        def new_conv():
            return nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)

        # Created in the order f, i1, i2 (parameter registration order).
        self.layer_f = new_conv()
        self.layer_i1 = new_conv()
        self.layer_i2 = new_conv()

class NA_BatchNorm2d3(NA_separate3):
    """Three separate ``nn.BatchNorm2d`` modules, one per component."""

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1, affine: bool = True,
                 track_running_stats: bool = True, device=None, dtype=None):
        super().__init__()

        def new_norm():
            return nn.BatchNorm2d(num_features, eps=eps, momentum=momentum, affine=affine,
                                  track_running_stats=track_running_stats, device=device, dtype=dtype)

        self.layer_f = new_norm()
        self.layer_i1 = new_norm()
        self.layer_i2 = new_norm()

class NA_BatchNorm1d3(NA_separate3):
    """Three separate ``nn.BatchNorm1d`` modules, one per component."""

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1, affine: bool = True,
                 track_running_stats: bool = True, device=None, dtype=None):
        super().__init__()

        def new_norm():
            return nn.BatchNorm1d(num_features, eps=eps, momentum=momentum, affine=affine,
                                  track_running_stats=track_running_stats, device=device, dtype=dtype)

        self.layer_f = new_norm()
        self.layer_i1 = new_norm()
        self.layer_i2 = new_norm()

class NA_MaxPool2d(NA_independent):
    """Applies a single ``nn.MaxPool2d`` to every component of the input."""

    def __init__(self, kernel_size, stride = None, padding = 0, dilation = 1, return_indices = False, ceil_mode = False):
        super().__init__()
        self.layer = nn.MaxPool2d(
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            return_indices=return_indices,
            ceil_mode=ceil_mode,
        )

class NA_Dropout(NA_independent):
    """Applies a single ``nn.Dropout`` to every component of the input.

    The shared module is called once per component by NA_independent.forward.
    """

    def __init__(self, p: float = 0.5, inplace: bool = False):
        super().__init__()
        self.layer = nn.Dropout(p=p, inplace=inplace)

In [6]:
class NA_BCNN3(nn.Module):
    def __init__(self, learning_rate, momentum, nesterov, trainloader, testloader, 
                 epochs, num_class_c1, num_class_c2, num_class_c3, labels_c_1, labels_c_2, labels_c_3, 
                 every_print = 512, training_size = 50000):
        """Build the three-level branch network, its optimizer and its loss.

        Args:
            learning_rate: sequence of three learning rates; index 0 is used
                here, indices 1 and 2 are used by ``update_training_params``.
            momentum, nesterov: forwarded to ``optim.SGD``.
            trainloader, testloader: iterables yielding ``(images, labels)``.
            epochs: number of training epochs.
            num_class_c1, num_class_c2, num_class_c3: output sizes of the
                three classification heads.
            labels_c_1, labels_c_2, labels_c_3: class names per level, used
                when printing and plotting results.
            every_print: logging period in iterations; it is turned into a bit
                mask, so it is assumed to be a power of two.
            training_size: number of training samples, used only to size the
                tracking buffers.
        """

        super().__init__()
        self.trainloader = trainloader
        self.testloader = testloader
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.nesterov = nesterov
        self.activation = F.relu
        self.class_levels = 3
        self.num_c_1 = num_class_c1
        self.num_c_2 = num_class_c2
        self.num_c_3 = num_class_c3
        self.epochs = epochs
        self.labels_c_1 = labels_c_1
        self.labels_c_2 = labels_c_2
        self.labels_c_3 = labels_c_3
        self.every_print = every_print - 1 # assumed power of 2, -1 to make the mask

        # Number of tracking points per epoch. Prefer the batch size of the
        # loader actually passed in; fall back to the notebook-level
        # `batch_size` only when the loader does not expose one.
        loader_batch_size = getattr(trainloader, "batch_size", None) or batch_size
        self.track_size = int( training_size / loader_batch_size / every_print ) 

        # block 1
        self.layer1  = NA_Conv2d3(3, 64, (3,3), padding = 'same')
        self.layer2  = NA_BatchNorm2d3(64)
        self.layer3  = NA_Conv2d3(64, 64, (3,3), padding = 'same')
        self.layer4  = NA_BatchNorm2d3(64)
        self.layer5  = NA_MaxPool2d((2,2), stride = (2,2))

        # block 2
        self.layer6  = NA_Conv2d3(64, 128, (3,3), padding = 'same')
        self.layer7  = NA_BatchNorm2d3(128)
        self.layer8  = NA_Conv2d3(128, 128, (3,3), padding = 'same')
        self.layer9  = NA_BatchNorm2d3(128)
        self.layer10 = NA_MaxPool2d((2,2), stride = (2,2))

        # dense trunk followed by one plain linear head per hierarchy level
        self.layerb11 = NA_Linear3(8*8*128, 256)
        self.layerb12 = NA_BatchNorm1d3(256)
        self.layerb13 = NA_Dropout(0.5)
        self.layerb14 = NA_Linear3(256, 256)
        self.layerb15 = NA_BatchNorm1d3(256)
        self.layerb16 = NA_Dropout(0.5)
        self.layerb17 = nn.Linear(256, self.num_c_1)
        self.layerb18 = nn.Linear(256, self.num_c_2)
        self.layerb19 = nn.Linear(256, self.num_c_3)

        self.optimizer = optim.SGD(self.parameters(), lr = self.learning_rate[0], 
                                   momentum = self.momentum, nesterov = self.nesterov)
        self.criterion = nn.CrossEntropyLoss()
        #self.criterion = nn.MSE()

        # Layers whose gradients are re-routed by `infinitesimal_gradient`
        # inside `predict_and_learn`.
        self.combiner_layers = [self.layer1, self.layer3, self.layer6, self.layer8,
                                   self.layerb11, self.layerb14]


    def forward(self, x):

        # block 1
        z = self.layer1(x)
        z = self.activation(z)
        z = self.layer2(z)
        z = self.layer3(z)
        z = self.activation(z)
        z = self.layer4(z)
        z = self.layer5(z)

        # block 2
        z = self.layer6(z)
        z = self.activation(z)
        z = self.layer7(z)
        z = self.layer8(z)
        z = self.activation(z)
        z = self.layer9(z)
        z = self.layer10(z)

        # branch 1
        b1 = torch.flatten(z, start_dim = 2)
        b1 = self.layerb11(b1)
        b1 = self.activation(b1)
        b1 = self.layerb12(b1)
        b1 = self.layerb13(b1)
        b1 = self.layerb14(b1)
        b1 = self.activation(b1)
        b1 = self.layerb15(b1)
        b1 = self.layerb16(b1)
        b1_f = self.layerb17(b1[0, :])
        b1_i1 = self.layerb18(b1[1, :])
        b1_i2 = self.layerb19(b1[2, :])

        return b1_f, b1_i1, b1_i2


    def update_training_params(self, epoch):
        
        if epoch == 41:
            self.optimizer = optim.SGD(self.parameters(), lr = self.learning_rate[1], 
                               momentum = self.momentum, nesterov = self.nesterov)
        elif epoch == 51:
            self.optimizer = optim.SGD(self.parameters(), lr = self.learning_rate[2], 
                               momentum = self.momentum, nesterov = self.nesterov)


    def predict_and_learn(self, batch, labels):
        """Run one optimisation step and return the three per-level losses.

        ``batch`` is passed to the model as-is; ``labels[:, k]`` is used as
        the target for the k-th output of the model. The three losses are
        back-propagated one after another (finest level first) and the
        gradients held by the combiner layers are re-routed between the
        backward passes. The statement order below is significant.

        Returns:
            A new tensor ``[loss_f, loss_i1, loss_i2]`` built from the three
            loss values.
        """
        self.optimizer.zero_grad()
        predict = self(batch)
        loss_f = self.criterion(predict[0], labels[:,0])
        loss_i1 = self.criterion(predict[1], labels[:,1])
        loss_i2 = self.criterion(predict[2], labels[:,2])

        # Third-level loss first; the graph is kept for the later backward calls.
        loss_i2.backward(retain_graph=True)
        # Each combiner layer hands its layer_f gradient to layer_i2 and
        # clears the layer_f / layer_i1 gradients (see infinitesimal_gradient).
        for l in self.combiner_layers:
            l.infinitesimal_gradient(2)

        # Second-level loss; the layer_f gradient is then handed to layer_i1.
        loss_i1.backward(retain_graph=True)
        for l in self.combiner_layers:
            l.infinitesimal_gradient(1)

        # First-level loss; no retain_graph here, this is the last backward call.
        loss_f.backward()
        
        self.optimizer.step()

        return torch.tensor([loss_f, loss_i1, loss_i2])


    def train_model(self, verbose = False):
        self.train()
        
        for epoch in np.arange(self.epochs):
            self.update_training_params(epoch)

            if verbose:
                running_loss = torch.zeros(self.class_levels)
            
            for iter, (batch, labels) in enumerate(self.trainloader):
                loss = self.predict_and_learn(torch.stack((batch, torch.zeros_like(batch), torch.zeros_like(batch))), labels)

                if verbose:
                    running_loss += (loss - running_loss) / (iter+1)
                    if (iter + 1) & self.every_print == 0:
                        print(f'[{epoch + 1}] loss_f : {running_loss[0] :.3f}')
                        print(f'[{epoch + 1}] loss_i1: {running_loss_[1] :.3f}')
                        print(f'[{epoch + 1}] loss_i2: {running_loss_[2] :.3f}')
                        for i in np.arange(self.class_levels):
                            running_loss[i] = 0.0

    
    def train_track(self, filename = None):
        self.train()
        
        self.loss_track = torch.zeros(self.epochs * self.track_size, self.class_levels)
        self.accuracy_track = torch.zeros(self.epochs * self.track_size, self.class_levels)
        num_push = 0
        
        for epoch in np.arange(self.epochs):

            self.update_training_params(epoch)

            running_loss = torch.zeros(self.class_levels)
            
            for iter, (batch, labels) in enumerate(self.trainloader):
                loss = self.predict_and_learn(torch.stack((batch, torch.zeros_like(batch), torch.zeros_like(batch))), labels)

                running_loss += (loss - running_loss) / (iter+1)
                if (iter + 1) & self.every_print == 0:
                    self.loss_track[num_push, :] = running_loss
                    self.accuracy_track[num_push, :] = self.test(mode = "train")
                    num_push += 1
                    for i in np.arange(self.class_levels):
                            running_loss[i] = 0.0

        self.plot_training_loss(filename+"_train_loss.pdf")
        self.plot_test_accuracy(filename+"_test_accuracy_.pdf")


    def initialize_memory(self):
        self.correct_c1_pred = torch.zeros(self.num_c_1)
        self.total_c1_pred = torch.zeros_like(self.correct_c1_pred)
        
        self.correct_c2_pred = torch.zeros(self.num_c_2)
        self.total_c2_pred = torch.zeros_like(self.correct_c2_pred)
        
        self.correct_c3_pred = torch.zeros(self.num_c_3)
        self.total_c3_pred = torch.zeros_like(self.correct_c3_pred)

        self.correct_c1_vs_c2_pred = torch.zeros(self.num_c_1)
        self.total_c1_vs_c2_pred = torch.zeros_like(self.correct_c1_vs_c2_pred)

        self.correct_c2_vs_c3_pred = torch.zeros(self.num_c_2)
        self.total_c2_vs_c3_pred = torch.zeros_like(self.correct_c2_vs_c3_pred)

        self.correct_c1_vs_c3_pred = torch.zeros(self.num_c_1)
        self.total_c1_vs_c3_pred = torch.zeros_like(self.correct_c1_vs_c3_pred)


    def collect_test_performance(self):
        with torch.no_grad():
            for images, labels in self.testloader:
                predictions = self(torch.stack((images, torch.zeros_like(images), torch.zeros_like(images))))
                predicted = torch.zeros(batch_size, self.class_levels, dtype=torch.long)
                _, predicted[:,0] = torch.max(predictions[0], 1)
                _, predicted[:,1] = torch.max(predictions[1], 1)
                _, predicted[:,2] = torch.max(predictions[2], 1)

                for i in np.arange(batch_size):
                    if labels[i,0] == predicted[i,0]:
                        self.correct_c1_pred[labels[i,0]] += 1
                        
                    if labels[i,1] == predicted[i,1]:
                        self.correct_c2_pred[labels[i,1]] += 1

                    if labels[i,2] == predicted[i,2]:
                        self.correct_c3_pred[labels[i,2]] += 1

                    if predicted[i,1] == c3_to_c2(predicted[i,2]):
                        self.correct_c2_vs_c3_pred[predicted[i,1]] += 1

                    if predicted[i,0] == c3_to_c1(predicted[i,2]):
                        self.correct_c1_vs_c3_pred[predicted[i,0]] += 1

                    if predicted[i,0] == c2_to_c1(predicted[i,1]):
                        self.correct_c1_vs_c2_pred[predicted[i,0]] += 1
                        
                    self.total_c1_pred[labels[i,0]] += 1
                    self.total_c2_pred[labels[i,1]] += 1
                    self.total_c3_pred[labels[i,2]] += 1
                    self.total_c1_vs_c3_pred[predicted[i,0]] += 1
                    self.total_c1_vs_c2_pred[predicted[i,0]] += 1
                    self.total_c2_vs_c3_pred[predicted[i,1]] += 1

    def print_test_results(self):
        # print accuracy for each class
        for i in np.arange(self.num_c_1):
            accuracy_c1 = 100 * float(self.correct_c1_pred[i]) / self.total_c1_pred[i]
            print(f'Accuracy for class {self.labels_c_1[i]:5s}: {accuracy_c1:.2f} %')

        print("")
        for i in np.arange(self.num_c_2):
            accuracy_c2 = 100 * float(self.correct_c2_pred[i]) / self.total_c2_pred[i]
            print(f'Accuracy for class {self.labels_c_2[i]:5s}: {accuracy_c2:.2f} %')

        print("")
        for i in np.arange(self.num_c_3):
            accuracy_c3 = 100 * float(self.correct_c3_pred[i]) / self.total_c3_pred[i]
            print(f'Accuracy for class {self.labels_c_3[i]:5s}: {accuracy_c3:.2f} %')
            
        # print accuracy for the whole dataset
        print("")
        print(f'Accuracy on c1: {100 * self.correct_c1_pred.sum() // self.total_c1_pred.sum()} %')
        print(f'Accuracy on c2: {100 * self.correct_c2_pred.sum() // self.total_c2_pred.sum()} %')
        print(f'Accuracy on c3: {100 * self.correct_c3_pred.sum() // self.total_c3_pred.sum()} %')

        # print cross classes accuracy (tree)
        print("")
        for i in np.arange(self.num_c_1):
            accuracy_c1_c2 = 100 * float(self.correct_c1_vs_c2_pred[i]) / self.total_c1_vs_c2_pred[i]
            print(f'Cross-accuracy {self.labels_c_1[i]:5s} vs c2: {accuracy_c1_c2:.2f} %')
        
        print("")
        for i in np.arange(self.num_c_2):
            accuracy_c2_c3 = 100 * float(self.correct_c2_vs_c3_pred[i]) / self.total_c2_vs_c3_pred[i]
            print(f'Cross-accuracy {self.labels_c_2[i]:5s} vs c3: {accuracy_c2_c3:.2f} %')

        print("")
        for i in np.arange(self.num_c_1):
            accuracy_c1_c3 = 100 * float(self.correct_c1_vs_c3_pred[i]) / self.total_c1_vs_c3_pred[i]
            print(f'Cross-accuracy {self.labels_c_1[i]:5s} vs c3: {accuracy_c1_c3:.2f} %')


    def barplot(self, x, accuracy, labels, title):
        plt.bar(x, accuracy, tick_label = labels)
        plt.xlabel("Classes")
        plt.ylabel("Accuracy")
        plt.title(title)
        plt.show();

    def plot_test_results(self):
        # accuracy for each class
        accuracy_c1 = torch.empty(self.num_c_1)
        for i in np.arange(self.num_c_1):
            accuracy_c1[i] = float(self.correct_c1_pred[i]) / self.total_c1_pred[i]
        self.barplot(np.arange(self.num_c_1), accuracy_c1, self.labels_c_1, "Accuracy on the first level")

        accuracy_c2 = torch.empty(self.num_c_2 + 1)
        for i in np.arange(self.num_c_2):
            accuracy_c2[i] = float(self.correct_c2_pred[i]) / self.total_c2_pred[i]
        accuracy_c2[self.num_c_2] = self.correct_c2_pred.sum() / self.total_c2_pred.sum()
        self.barplot(np.arange(self.num_c_2 + 1), accuracy_c2, (*self.labels_c_2, 'overall'), "Accuracy on the second level")

        accuracy_c3 = torch.empty(self.num_c_3 + 1)
        for i in np.arange(self.num_c_3):
            accuracy_c3[i] = float(self.correct_c3_pred[i]) / self.total_c3_pred[i]
        accuracy_c3[self.num_c_3] = self.correct_c3_pred.sum() / self.total_c3_pred.sum()
        self.barplot(np.arange(self.num_c_3 + 1), accuracy_c3, (*self.labels_c_3, 'overall'), "Accuracy on the third level")

    
    def test(self, mode = "print"):
        self.initialize_memory()
        self.eval()

        self.collect_test_performance()

        match mode:
            case "plot":
                self.plot_test_results()
            case "print":
                self.print_test_results()
            case "train":
                accuracy_c1 = self.correct_c1_pred.sum() / self.total_c1_pred.sum()
                accuracy_c2 = self.correct_c2_pred.sum() / self.total_c2_pred.sum()
                accuracy_c3 = self.correct_c3_pred.sum() / self.total_c3_pred.sum()

                self.train()

                return torch.tensor([accuracy_c1, accuracy_c2, accuracy_c3])
            case _:
                raise AttributeError("Test mode not available")
        
    
    def plot_training_loss(self, filename = None):
        plt.figure(figsize=(12, 6))
        plt.plot(np.linspace(1, self.epochs, self.loss_track.size(0)), self.loss_track[:, 0].numpy(), label = "First level")
        plt.plot(np.linspace(1, self.epochs, self.loss_track.size(0)), self.loss_track[:, 1].numpy(), label = "Second level")
        plt.plot(np.linspace(1, self.epochs, self.loss_track.size(0)), self.loss_track[:, 2].numpy(), label = "Third level")
        plt.title("Training loss")
        plt.xlabel("Epochs")
        plt.ylabel("Error")
        plt.xticks(np.linspace(1, self.epochs, self.epochs)[0::2])
        plt.legend()
        if filename is not None:
            plt.savefig(filename, bbox_inches='tight')
        plt.show();

    def plot_test_accuracy(self, filename = None):
        plt.figure(figsize=(12, 6))
        plt.plot(np.linspace(1, self.epochs, self.accuracy_track.size(0)), self.accuracy_track[:, 0].numpy(), label = "First level")
        plt.plot(np.linspace(1, self.epochs, self.accuracy_track.size(0)), self.accuracy_track[:, 1].numpy(), label = "Second level")
        plt.plot(np.linspace(1, self.epochs, self.accuracy_track.size(0)), self.accuracy_track[:, 2].numpy(), label = "Third level")
        plt.title("Test accuracy")
        plt.xlabel("Epochs")
        plt.ylabel("Accuracy")
        plt.xticks(np.linspace(1, self.epochs, self.epochs)[0::2])
        plt.legend()
        if filename is not None:
            plt.savefig(filename, bbox_inches='tight')
        plt.show();

    def save_model(self, path):
        torch.save(self.state_dict(), path)

    def load_model(self, path):
        self.load_state_dict(torch.load(path))
        self.eval()

In [7]:
# --- optimisation schedule ---
epochs = 60
# One learning rate per training phase: NA_BCNN3 starts with index 0 and
# switches to index 1 and then index 2 in update_training_params.
learning_rate = [3e-3, 5e-4, 1e-4]
momentum, nesterov = 0.9, True

# --- number of classes at each hierarchy level ---
num_class_c1, num_class_c2, num_class_c3 = 2, 7, 10

# Tracking / logging period in iterations (used as a bit mask, so a power of two).
every_print = 32

# --- coarse 1 classes ---
labels_c_1 = ('transport', 'animal')
# --- coarse 2 classes ---
labels_c_2 = ('sky', 'water', 'road', 'bird', 'reptile', 'pet', 'medium')
# --- fine classes ---
labels_c_3 = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [8]:
# Notifier used by the training cell to report how the run ended.
bot = Terminator()

# Model configured from the hyperparameters and data loaders defined above.
cnn = NA_BCNN3(
    learning_rate=learning_rate,
    momentum=momentum,
    nesterov=nesterov,
    trainloader=trainloader,
    testloader=testloader,
    epochs=epochs,
    num_class_c1=num_class_c1,
    num_class_c2=num_class_c2,
    num_class_c3=num_class_c3,
    labels_c_1=labels_c_1,
    labels_c_2=labels_c_2,
    labels_c_3=labels_c_3,
    every_print=every_print,
)

In [None]:
#cnn.train_model(verbose = False)
err = False
filename = "models/B-CNN3_CIFAR10_NA"

# Create the output directory before training starts: the figures and the
# checkpoint are only written after the full run, so a missing directory
# would otherwise fail at the very end and lose the trained weights.
os.makedirs(os.path.dirname(filename), exist_ok=True)

try:
    cnn.train_track(filename)
    cnn.save_model(filename+".pt")
    cnn.test(mode = "print")

except Exception as errore:
    # Keep the exception available in `err`, notify, then re-raise it with
    # its original traceback.
    err = errore
    bot.sendMessage("Programma NON terminato correttamente\nTipo di errore: "+err.__class__.__name__+"\nMessaggio: "+str(err))
    raise

else:
    bot.sendMessage("Programma terminato correttamente")