In [2]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchviz import make_dot
from torch.linalg import vector_norm as vnorm
from torch.linalg import solve as solve_matrix_system

import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor, Lambda

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm 

from telegramBot import Terminator

num_cores = 8
torch.set_num_interop_threads(num_cores) # Inter-op parallelism
torch.set_num_threads(num_cores) # Intra-op parallelism

In [None]:
def c3_to_c1(y):
    if y < 2 or y > 7:
        return 0
    return 1

def c3_to_c2(y):
    match y:
        case 0:
            return 0
        case 1:
            return 2
        case 2:
            return 3
        case 3:
            return 5
        case 4:
            return 6
        case 5:
            return 5
        case 6:
            return 4
        case 7:
            return 6
        case 8:
            return 1
        case _:
            return 2

def c2_to_c1(y):
    if y < 3:
        return 0
    return 1

In [None]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

coarser = Lambda(lambda y: torch.tensor([c3_to_c1(y), c3_to_c2(y), int(y)]))

batch_size = 128

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform, target_transform = coarser)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_cores)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform, target_transform = coarser)

testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_cores)

In [None]:
class BCNN3(nn.Module):
    def __init__(self, alpha, beta, gamma, learning_rate, momentum, nesterov, trainloader, testloader, 
                 epochs, num_class_c1, num_class_c2, num_class_c3, labels_c_1, labels_c_2, labels_c_3, 
                 every_print = 512, training_size = 50000):
        
        super().__init__()
        self.trainloader = trainloader
        self.testloader = testloader
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.nesterov = nesterov
        self.alphas = alpha
        self.betas = beta
        self.gammas = gamma
        self.alpha = self.alphas[0]
        self.beta = self.betas[0]
        self.gamma = self.gammas[0]
        self.activation = F.relu
        self.class_levels = 3
        self.num_c_1 = num_class_c1
        self.num_c_2 = num_class_c2
        self.num_c_3 = num_class_c3
        self.epochs = epochs
        self.labels_c_1 = labels_c_1
        self.labels_c_2 = labels_c_2
        self.labels_c_3 = labels_c_3
        self.every_print = every_print - 1 # assumed power of 2, -1 to make the mask
        self.track_size = int( training_size / batch_size / every_print ) 

        self.layer1  = nn.Conv2d(3, 64, (3,3), padding = 'same')
        self.layer2  = nn.BatchNorm2d(64)
        self.layer3  = nn.Conv2d(64, 64, (3,3), padding = 'same')
        self.layer4  = nn.BatchNorm2d(64)
        self.layer5  = nn.MaxPool2d((2,2), stride = (2,2))

        self.layer6  = nn.Conv2d(64, 128, (3,3), padding = 'same')
        self.layer7  = nn.BatchNorm2d(128)
        self.layer8  = nn.Conv2d(128, 128, (3,3), padding = 'same')
        self.layer9  = nn.BatchNorm2d(128)
        self.layer10 = nn.MaxPool2d((2,2), stride = (2,2))

        self.layerb11 = nn.Linear(8*8*128, 256)
        self.layerb12 = nn.BatchNorm1d(256)
        self.layerb13 = nn.Dropout(0.5)
        self.layerb14 = nn.Linear(256, 256)
        self.layerb15 = nn.BatchNorm1d(256)
        self.layerb16 = nn.Dropout(0.5)
        self.layerb17 = nn.Linear(256, self.num_c_1)
        self.layerb27 = nn.Linear(2*256, self.num_c_2)
        self.layerb37 = nn.Linear(2*256, self.num_c_3)


        self.optimizer = optim.SGD(self.parameters(), lr = self.learning_rate[0], 
                                   momentum = self.momentum, nesterov = self.nesterov)
        self.criterion = nn.CrossEntropyLoss()

    
    def forward(self, x):

        # block 1
        z = self.layer1(x)
        z = self.activation(z)
        z = self.layer2(z)
        z = self.layer3(z)
        z = self.activation(z)
        z = self.layer4(z)
        z = self.layer5(z)

        # block 2
        z = self.layer6(z)
        z = self.activation(z)
        z = self.layer7(z)
        z = self.layer8(z)
        z = self.activation(z)
        z = self.layer9(z)
        z = self.layer10(z)
        z = torch.flatten(z, start_dim = 1)

        # branch 1
        z = self.layerb11(z)
        z = self.activation(z)
        z = self.layerb12(z)
        z = self.layerb13(z)
        z = self.layerb14(z)
        z = self.activation(z)
        z = self.layerb15(z)
        z = self.layerb16(z)
        b1 = self.layerb17(z)

        # branch 2
        b2 = self.project(z, self.layerb11)
        b2 = self.layerb27(b2)

        # branch 3
        b3 = self.project(z, torch.hstack((self.layerb17, self.layerb27)))
        b3 = self.layerb37(b3)

        return b1, b2, b3


    # Assumption: W is column full rank. 
    def project(self, z, W): #https://math.stackexchange.com/questions/4021915/projection-orthogonal-to-two-vectors
        
        P = solve_matrix_system(torch.matmul(torch.transpose(W), W))
        P = torch.matmul(P, torch.transpose(W))
        P = torch.matmul(W, P)

        prj = torch.matmul(z.clone().deatch(), P)
        ort = z - prj
        return torch.hstack((prj, ort))

    
    def update_training_params(self, epoch):
        if epoch == 41:
            self.optimizer = optim.SGD(self.parameters(), lr = self.learning_rate[1], 
                               momentum = self.momentum, nesterov = self.nesterov)
        elif epoch == 51:
            self.optimizer = optim.SGD(self.parameters(), lr = self.learning_rate[2], 
                               momentum = self.momentum, nesterov = self.nesterov)


    def predict_and_learn(self, batch, labels):
        self.optimizer.zero_grad()
        predict = self(batch)
        loss =  self.alpha * self.criterion(predict[0], labels[:,0]) + \
                self.beta * self.criterion(predict[1], labels[:,1]) + \
                self.gamma * self.criterion(predict[2], labels[:,2])

        loss.backward()
        self.optimizer.step()

        return loss

    
    def train_model(self, verbose = False):
        self.train()
        
        for epoch in np.arange(self.epochs):
            self.update_training_params(epoch)

            if verbose:
                running_loss = 0.
            
            for iter, (batch, labels) in enumerate(self.trainloader):
                loss = self.predict_and_learn(batch, labels)

                if verbose:
                    running_loss += (loss.item() - running_loss) / (iter+1)
                    if (iter + 1) & self.every_print == 0:
                        print(f'[{epoch + 1}] loss: {running_loss :.3f}')
                        running_loss = 0.0

    
    def train_track(self, filename = None):
        self.train()
        
        self.loss_track = torch.zeros(self.epochs * self.track_size)
        self.accuracy_track = torch.zeros(self.epochs * self.track_size, self.class_levels)
        num_push = 0
        
        for epoch in np.arange(self.epochs):

            self.update_training_params(epoch)

            running_loss = 0.
            
            for iter, (batch, labels) in enumerate(self.trainloader):
                loss = self.predict_and_learn(batch, labels)

                running_loss += (loss.item() - running_loss) / (iter+1)
                if (iter + 1) & self.every_print == 0:
                    self.loss_track[num_push] = running_loss
                    self.accuracy_track[num_push, :] = self.test(mode = "train")
                    num_push += 1
                    running_loss = 0.0

        self.plot_training_loss(filename+"_train_loss.pdf")
        self.plot_test_accuracy(filename+"_test_accuracy_.pdf")

    
    def initialize_memory(self):
        self.correct_c1_pred = torch.zeros(self.num_c_1)
        self.total_c1_pred = torch.zeros_like(self.correct_c1_pred)
        
        self.correct_c2_pred = torch.zeros(self.num_c_2)
        self.total_c2_pred = torch.zeros_like(self.correct_c2_pred)
        
        self.correct_c3_pred = torch.zeros(self.num_c_3)
        self.total_c3_pred = torch.zeros_like(self.correct_c3_pred)

        self.correct_c1_vs_c2_pred = torch.zeros(self.num_c_1)
        self.total_c1_vs_c2_pred = torch.zeros_like(self.correct_c1_vs_c2_pred)

        self.correct_c2_vs_c3_pred = torch.zeros(self.num_c_2)
        self.total_c2_vs_c3_pred = torch.zeros_like(self.correct_c2_vs_c3_pred)

        self.correct_c1_vs_c3_pred = torch.zeros(self.num_c_1)
        self.total_c1_vs_c3_pred = torch.zeros_like(self.correct_c1_vs_c3_pred)

    
    def collect_test_performance(self):
        with torch.no_grad():
            for images, labels in self.testloader:
                predictions = self(images)
                predicted = torch.zeros(predictions[0].size(0), self.class_levels, dtype=torch.long)
                _, predicted[:,0] = torch.max(predictions[0], 1)
                _, predicted[:,1] = torch.max(predictions[1], 1)
                _, predicted[:,2] = torch.max(predictions[2], 1)

                for i in np.arange(predictions[0].size(0)):
                    if labels[i,0] == predicted[i,0]:
                        self.correct_c1_pred[labels[i,0]] += 1
                        
                    if labels[i,1] == predicted[i,1]:
                        self.correct_c2_pred[labels[i,1]] += 1

                    if labels[i,2] == predicted[i,2]:
                        self.correct_c3_pred[labels[i,2]] += 1

                    if predicted[i,1] == c3_to_c2(predicted[i,2]):
                        self.correct_c2_vs_c3_pred[predicted[i,1]] += 1

                    if predicted[i,0] == c3_to_c1(predicted[i,2]):
                        self.correct_c1_vs_c3_pred[predicted[i,0]] += 1

                    if predicted[i,0] == c2_to_c1(predicted[i,1]):
                        self.correct_c1_vs_c2_pred[predicted[i,0]] += 1
                        
                    self.total_c1_pred[labels[i,0]] += 1
                    self.total_c2_pred[labels[i,1]] += 1
                    self.total_c3_pred[labels[i,2]] += 1
                    self.total_c1_vs_c3_pred[predicted[i,0]] += 1
                    self.total_c1_vs_c2_pred[predicted[i,0]] += 1
                    self.total_c2_vs_c3_pred[predicted[i,1]] += 1
                    

    def print_test_results(self):
        # print accuracy for each class
        for i in np.arange(self.num_c_1):
            accuracy_c1 = 100 * float(self.correct_c1_pred[i]) / self.total_c1_pred[i]
            print(f'Accuracy for class {self.labels_c_1[i]:5s}: {accuracy_c1:.2f} %')

        print("")
        for i in np.arange(self.num_c_2):
            accuracy_c2 = 100 * float(self.correct_c2_pred[i]) / self.total_c2_pred[i]
            print(f'Accuracy for class {self.labels_c_2[i]:5s}: {accuracy_c2:.2f} %')

        print("")
        for i in np.arange(self.num_c_3):
            accuracy_c3 = 100 * float(self.correct_c3_pred[i]) / self.total_c3_pred[i]
            print(f'Accuracy for class {self.labels_c_3[i]:5s}: {accuracy_c3:.2f} %')
            
        # print accuracy for the whole dataset
        print("")
        print(f'Accuracy on c1: {100 * self.correct_c1_pred.sum() // self.total_c1_pred.sum()} %')
        print(f'Accuracy on c2: {100 * self.correct_c2_pred.sum() // self.total_c2_pred.sum()} %')
        print(f'Accuracy on c3: {100 * self.correct_c3_pred.sum() // self.total_c3_pred.sum()} %')

        # print cross classes accuracy (tree)
        print("")
        for i in np.arange(self.num_c_1):
            accuracy_c1_c2 = 100 * float(self.correct_c1_vs_c2_pred[i]) / self.total_c1_vs_c2_pred[i]
            print(f'Cross-accuracy {self.labels_c_1[i]:5s} vs c2: {accuracy_c1_c2:.2f} %')
        
        print("")
        for i in np.arange(self.num_c_2):
            accuracy_c2_c3 = 100 * float(self.correct_c2_vs_c3_pred[i]) / self.total_c2_vs_c3_pred[i]
            print(f'Cross-accuracy {self.labels_c_2[i]:5s} vs c3: {accuracy_c2_c3:.2f} %')

        print("")
        for i in np.arange(self.num_c_1):
            accuracy_c1_c3 = 100 * float(self.correct_c1_vs_c3_pred[i]) / self.total_c1_vs_c3_pred[i]
            print(f'Cross-accuracy {self.labels_c_1[i]:5s} vs c3: {accuracy_c1_c3:.2f} %')


    def barplot(self, x, accuracy, labels, title):
        plt.bar(x, accuracy, tick_label = labels)
        plt.xlabel("Classes")
        plt.ylabel("Accuracy")
        plt.title(title)
        plt.show();

    
    def plot_test_results(self):
        # accuracy for each class
        accuracy_c1 = torch.empty(self.num_c_1)
        for i in np.arange(self.num_c_1):
            accuracy_c1[i] = float(self.correct_c1_pred[i]) / self.total_c1_pred[i]
        self.barplot(np.arange(self.num_c_1), accuracy_c1, self.labels_c_1, "Accuracy on the first level")

        accuracy_c2 = torch.empty(self.num_c_2 + 1)
        for i in np.arange(self.num_c_2):
            accuracy_c2[i] = float(self.correct_c2_pred[i]) / self.total_c2_pred[i]
        accuracy_c2[self.num_c_2] = self.correct_c2_pred.sum() / self.total_c2_pred.sum()
        self.barplot(np.arange(self.num_c_2 + 1), accuracy_c2, (*self.labels_c_2, 'overall'), "Accuracy on the second level")

        accuracy_c3 = torch.empty(self.num_c_3 + 1)
        for i in np.arange(self.num_c_3):
            accuracy_c3[i] = float(self.correct_c3_pred[i]) / self.total_c3_pred[i]
        accuracy_c3[self.num_c_3] = self.correct_c3_pred.sum() / self.total_c3_pred.sum()
        self.barplot(np.arange(self.num_c_3 + 1), accuracy_c3, (*self.labels_c_3, 'overall'), "Accuracy on the third level")

    
    def test(self, mode = "print"):
        self.initialize_memory()
        self.eval()

        self.collect_test_performance()

        match mode:
            case "plot":
                self.plot_test_results()
            case "print":
                self.print_test_results()
            case "train":
                accuracy_c1 = self.correct_c1_pred.sum() / self.total_c1_pred.sum()
                accuracy_c2 = self.correct_c2_pred.sum() / self.total_c2_pred.sum()
                accuracy_c3 = self.correct_c3_pred.sum() / self.total_c3_pred.sum()

                self.train()

                return torch.tensor([accuracy_c1, accuracy_c2, accuracy_c3])
            case _:
                raise AttributeError("Test mode not available")
        
    
    def plot_training_loss(self, filename = None):
        plt.figure(figsize=(12, 6))
        plt.plot(np.linspace(1, self.epochs, self.loss_track.size(0)), self.loss_track.numpy())
        plt.title("Training loss")
        plt.xlabel("Epochs")
        plt.ylabel("Error")
        plt.xticks(np.linspace(1, self.epochs, self.epochs)[0::2])
        if filename is not None:
            plt.savefig(filename, bbox_inches='tight')
        plt.show();

    
    def plot_test_accuracy(self, filename = None):
        plt.figure(figsize=(12, 6))
        plt.plot(np.linspace(1, self.epochs, self.accuracy_track.size(0)), self.accuracy_track[:, 0].numpy(), label = "First level")
        plt.plot(np.linspace(1, self.epochs, self.accuracy_track.size(0)), self.accuracy_track[:, 1].numpy(), label = "Second level")
        plt.plot(np.linspace(1, self.epochs, self.accuracy_track.size(0)), self.accuracy_track[:, 2].numpy(), label = "Third level")
        plt.title("Test accuracy")
        plt.xlabel("Epochs")
        plt.ylabel("Accuracy")
        plt.xticks(np.linspace(1, self.epochs, self.epochs)[0::2])
        plt.legend()
        if filename is not None:
            plt.savefig(filename, bbox_inches='tight')
        plt.show();

    
    def save_model(self, path):
        torch.save(self.state_dict(), path)

    
    def load_model(self, path):
        self.load_state_dict(torch.load(path))
        self.eval()

In [140]:
class prova(nn.Module):

    def __init__(self):
        super().__init__()

        self.layer0 = nn.Linear(4, 12)
        self.layer1 = nn.Linear(12, 2)
        self.layer2 = nn.Linear(12, 4)
        self.layer2_1 = nn.Linear(12, 4)
        self.layer3 = nn.Linear(12, 4)
        self.layer3_1 = nn.Linear(12, 4)

        self.criterion = nn.MSELoss()
        self.optimizer = optim.SGD(self.parameters(), lr = 1e-3)

    def forward(self, x):
        z = self.layer0(x)
        z = F.relu(z)
        o1 = self.layer1(z)
        prj2, ort2 = self.project(z, self.layer1.weight.clone().detach())
        o2 = self.layer2(ort2) + self.layer2_1(prj2) 
        prj3, ort3 = self.project(z, torch.vstack((self.layer1.weight.clone().detach(), self.layer2.weight.clone().detach())))
        o3 = self.layer3(ort3) + self.layer3_1(prj2)

        return o1, o2, o3

    
    # Assumption: W is column full rank. 
    def project(self, z, W): #https://math.stackexchange.com/questions/4021915/projection-orthogonal-to-two-vectors
        
        P = solve_matrix_system(torch.matmul(W, W.T), torch.eye(W.size(0)))
        P = torch.matmul(P, W)
        P = torch.eye(W.size(1)) - torch.matmul(W.T, P)

        ort = torch.matmul(z, P)
        prj = z.clone().detach() - ort.clone().detach()
        print(W.mm(ort.T))
        return prj, ort


    def print_forward(self, x):
        o1, o2, o3 = self.forward(x)
        print(o1, end = "\n\n")
        print(o2, end = "\n\n")
        print(o3, end = "\n\n")

In [141]:
p = prova()

In [142]:
x = torch.rand(2, 4)

In [143]:
p.layer1.weight.shape

torch.Size([2, 12])

In [144]:
p(x)

tensor([[ 3.7253e-09, -3.7253e-09],
        [-1.4901e-08, -2.2352e-08]], grad_fn=<MmBackward0>)
tensor([[-1.8626e-08, -4.2841e-08],
        [-1.8626e-09, -5.5879e-09],
        [ 2.6077e-08,  4.4703e-08],
        [ 4.4703e-08,  4.0978e-08],
        [-2.3283e-08, -4.2375e-08],
        [-5.5879e-09, -1.3039e-08]], grad_fn=<MmBackward0>)


(tensor([[-0.2395,  0.3159],
         [-0.2360,  0.3514]], grad_fn=<AddmmBackward0>),
 tensor([[ 0.3891, -0.4173,  0.0448, -0.0094],
         [ 0.5370, -0.2331, -0.0271,  0.0128]], grad_fn=<AddBackward0>),
 tensor([[-0.4176,  0.5648, -0.1441,  0.3095],
         [-0.4094,  0.5595, -0.1750,  0.2645]], grad_fn=<AddBackward0>))

In [91]:
make_dot(p(x), params=dict(list(p.named_parameters()))).render("rnn_torchviz", format="png")

'rnn_torchviz.png'

In [92]:
u = p(x)

In [93]:
u

(tensor([[-0.4516,  0.3955],
         [-0.6314,  0.4501]], grad_fn=<AddmmBackward0>),
 tensor([[ 0.3884, -0.2266,  0.3830, -0.0802],
         [ 0.3142, -0.3167,  0.4863, -0.1240]], grad_fn=<AddBackward0>),
 tensor([[-0.1178,  0.1066,  0.0310, -0.4078],
         [-0.1460,  0.1493, -0.1682, -0.5276]], grad_fn=<AddBackward0>))

In [96]:
y = torch.rand(2,4)

In [97]:
loss2 = p.criterion(u[1], y)

In [98]:
loss2.backward()

In [99]:
p.optimizer.step()

In [100]:
p(x)

(tensor([[-0.4516,  0.3955],
         [-0.6314,  0.4501]], grad_fn=<AddmmBackward0>),
 tensor([[ 0.3887, -0.2255,  0.3833, -0.0791],
         [ 0.3144, -0.3153,  0.4866, -0.1227]], grad_fn=<AddBackward0>),
 tensor([[-0.1177,  0.1064,  0.0310, -0.4079],
         [-0.1460,  0.1490, -0.1681, -0.5277]], grad_fn=<AddBackward0>))

In [102]:
p(x)[0]-u[0]

tensor([[ 3.6955e-06, -2.2650e-06],
        [ 2.4498e-05, -7.1228e-06]], grad_fn=<SubBackward0>)

In [110]:
uu = p(x)

In [111]:
loss3 = p.criterion(uu[2], y)

In [112]:
loss3.backward()

In [113]:
p.optimizer.step()

In [114]:
print(p(x)[0]-uu[0])
print(p(x)[1]-uu[1])

tensor([[3.6061e-05, 1.7196e-05],
        [2.0903e-04, 9.3848e-05]], grad_fn=<SubBackward0>)
tensor([[1.1700e-04, 1.2027e-03, 1.8212e-04, 1.1274e-03],
        [9.7781e-05, 1.4476e-03, 9.4593e-05, 1.2653e-03]],
       grad_fn=<SubBackward0>)


In [115]:
uu

(tensor([[-0.4516,  0.3955],
         [-0.6312,  0.4501]], grad_fn=<AddmmBackward0>),
 tensor([[ 0.3888, -0.2243,  0.3835, -0.0780],
         [ 0.3146, -0.3139,  0.4868, -0.1214]], grad_fn=<AddBackward0>),
 tensor([[-0.1156,  0.1071,  0.0330, -0.4047],
         [-0.1437,  0.1498, -0.1657, -0.5239]], grad_fn=<AddBackward0>))

In [116]:
p(x)

(tensor([[-0.4516,  0.3955],
         [-0.6310,  0.4502]], grad_fn=<AddmmBackward0>),
 tensor([[ 0.3889, -0.2231,  0.3837, -0.0768],
         [ 0.3147, -0.3125,  0.4869, -0.1202]], grad_fn=<AddBackward0>),
 tensor([[-0.1125,  0.1083,  0.0360, -0.3998],
         [-0.1403,  0.1511, -0.1621, -0.5182]], grad_fn=<AddBackward0>))

In [117]:
print(p(x)[2]-uu[2])

tensor([[0.0031, 0.0012, 0.0030, 0.0049],
        [0.0034, 0.0013, 0.0036, 0.0057]], grad_fn=<SubBackward0>)
