In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchviz import make_dot
from torch.linalg import vector_norm as vnorm
from torch.linalg import solve as solve_matrix_system

import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor, Lambda

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm 

from telegramBot import Terminator

num_cores = 8
torch.set_num_interop_threads(num_cores) # Inter-op parallelism
torch.set_num_threads(num_cores) # Intra-op parallelism

In [None]:
class prova(nn.Module):
    """Toy three-headed MLP whose later heads read projections of the hidden state.

    A shared trunk (``layer0`` + ReLU) produces a 12-dimensional activation ``z``.
    Head 1 reads ``z`` directly. Heads 2 and 3 each split ``z`` into

    * ``ort``: the component of ``z`` orthogonal to the rows of the earlier
      heads' weight matrices (restricted to the units that are active for that
      sample), and
    * ``prj``: the remainder ``z - ort``, detached from the autograd graph.

    Because ``prj`` is detached and the projector is built from detached
    weights, the only gradient path from heads 2/3 back into the trunk is
    through ``ort``.
    """

    def __init__(self):
        super().__init__()

        # Shared trunk: 4 input features -> 12 hidden units.
        self.layer0 = nn.Linear(4, 12)
        # Head 1 reads the hidden activation directly.
        self.layer1 = nn.Linear(12, 2)
        # Head 2: layer2 reads the orthogonal part, layer2_1 the projected part.
        self.layer2 = nn.Linear(12, 4)
        self.layer2_1 = nn.Linear(12, 4)
        # Head 3: layer3 reads the orthogonal part, layer3_1 the projected part.
        self.layer3 = nn.Linear(12, 4)
        self.layer3_1 = nn.Linear(12, 4)

        self.criterion = nn.MSELoss()
        self.optimizer = optim.SGD(self.parameters(), lr = 1e-3)

    def forward(self, x):
        """Return the three head outputs ``(o1, o2, o3)`` for input ``x``.

        ``x`` has 4 features in its last dimension; ``o1`` has 2 and ``o2`` /
        ``o3`` have 4 features in theirs.
        """
        z = self.layer0(x)
        z = F.relu(z)
        o1 = self.layer1(z)
        prj2, ort2, prj3, ort3 = self.project(z)
        o2 = self.layer2(ort2) + self.layer2_1(prj2)
        o3 = self.layer3(ort3) + self.layer3_1(prj3)

        return o1, o2, o3

    def project(self, z): #https://math.stackexchange.com/questions/4021915/projection-orthogonal-to-two-vectors
        """Split ``z`` into projected and orthogonal parts for heads 2 and 3.

        For every sample the weights are first masked to the units that are
        non-zero in ``z`` (equivalent to ``W @ diag(z != 0)``). Then

        * ``ort2`` is ``z`` projected onto the null space of the masked
          ``layer1`` weights,
        * ``ort3`` is ``z`` projected onto the null space of the masked
          ``layer1`` and ``layer2`` weights stacked together,
        * ``prj2`` / ``prj3`` are ``z - ort2`` / ``z - ort3``, detached.

        Args:
            z: hidden activation of shape ``(batch, 12)`` (a single unbatched
                vector of shape ``(12,)`` also works).

        Returns:
            Tuple ``(prj2, ort2, prj3, ort3)``, each shaped like ``z``.
        """
        # Weights and the activation pattern are treated as constants, so the
        # projector carries no gradient; gradients flow only through z itself.
        W1 = self.layer1.weight.detach()
        W2 = self.layer2.weight.detach()
        active = (z.detach() != 0).to(z.dtype)

        # Broadcasting over the row axis zeroes the columns of inactive units,
        # giving one masked weight matrix per sample without a Python loop.
        column_mask = active.unsqueeze(-2)
        W1k = W1 * column_mask
        W12k = torch.cat((W1k, W2 * column_mask), dim=-2)

        ort2 = self.compute_othogonal(z, W1k)
        ort3 = self.compute_othogonal(z, W12k)

        prj2 = z.detach() - ort2.detach()
        prj3 = z.detach() - ort3.detach()

        return prj2, ort2, prj3, ort3

    def compute_othogonal(self, z, W, eps = 1e-8):
        """Project ``z`` onto the null space of ``W``.

        Uses the projector ``P = I - W^+ W`` built from the Moore-Penrose
        pseudo-inverse. For a full-row-rank ``W`` this equals
        ``I - W^T (W W^T)^{-1} W``, but it stays exact and deterministic when
        ``W`` is rank-deficient (e.g. after columns were masked to zero), where
        inverting ``W W^T`` is not possible.

        Args:
            z: vector(s) to project, shape ``(..., n)``.
            W: matrix or batch of matrices, shape ``(..., r, n)``.
            eps: kept for backward compatibility; acts as a lower bound on the
                relative singular-value cutoff passed to ``torch.linalg.pinv``.

        Returns:
            Tensor shaped like ``z`` (after broadcasting against the batch
            dimensions of ``W``) satisfying ``W @ result == 0`` up to rounding.
        """
        n = W.size(-1)
        # Never go below the usual machine-precision cutoff, otherwise rounding
        # noise in zero singular values would be inverted into huge entries.
        rtol = max(eps, max(W.shape[-2:]) * torch.finfo(W.dtype).eps)
        W_pinv = torch.linalg.pinv(W, rtol = rtol)
        # Build the identity on W's device/dtype so the module also works after
        # .to(device) or .double().
        P = torch.eye(n, dtype = W.dtype, device = W.device) - torch.matmul(W_pinv, W)

        # Row-vector times (possibly batched) projector.
        return torch.matmul(z.unsqueeze(-2), P).squeeze(-2)

    # Correctly spelled alias; the original name is kept for existing callers.
    compute_orthogonal = compute_othogonal

    def print_forward(self, x):
        """Run a forward pass on ``x`` and print the three head outputs."""
        o1, o2, o3 = self(x)
        print(o1, end = "\n\n")
        print(o2, end = "\n\n")
        print(o3, end = "\n\n")

In [None]:
# Fix the RNG seed so the weight initialisation and the random input/target
# below are the same on every Restart & Run All.
torch.manual_seed(0)

p = prova()
x = torch.rand(2, 4)  # batch of 2 samples, 4 input features each
u = p(x)
y = torch.rand(2,4)   # random regression target with the shape of heads 2 and 3

In [None]:
#make_dot(p(x), params=dict(list(p.named_parameters()))).render("../imgs/hcnn3_torchviz", format="png")

In [None]:
# Outputs before the update; kept in `u` for the comparison at the end of the cell.
u = p(x)

# One SGD step driven only by the second head's MSE against `y`.
p.optimizer.zero_grad()
loss2 = p.criterion(input=u[1], target=y)
loss2.backward()
p.optimizer.step()

# Change in the first head's output caused by that step (shown as the cell result).
p(x)[0] - u[0]

In [None]:
# Outputs before the update; kept in `uu` for the comparisons below.
uu = p(x)

# One SGD step driven only by the third head's MSE against `y`.
p.optimizer.zero_grad()
loss3 = p.criterion(uu[2], y)
loss3.backward()
# layer2's weights enter head 3 only through detached copies inside `project`,
# so this is expected to print None or zeros (depending on how zero_grad
# resets gradients in the installed PyTorch version).
print(p.layer2.weight.grad)
p.optimizer.step()

# Run the post-step forward pass once and reuse it for both comparisons
# instead of recomputing the whole network for every print.
uu_after = p(x)
print(uu_after[0]-uu[0])
print(uu_after[1]-uu[1])