# In[ ]:
import math
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import Polytope_basics.simplex_coordinates2 as simplex_coordinates2
verbose = False  # Module-level debug switch: when True, Fixed_weight_loss methods print diagnostic messages.

# In[ ]:
class Fixed_weight_loss(nn.Module):
    """Loss head whose classifier weights are fixed (non-trainable) simplex vertices.

    The weight matrix is built once from ``simplex_coordinates2`` (presumably one
    row per class) plus an extra column equal to ``alpha * row_sum``, and is
    created with ``requires_grad=False``. Provided loss terms:

    * ``arc_loss`` - ArcFace-style additive angular margin logits.
    * ``digit_angle_loss`` - angle penalty between each sample's target capsule
      and the remaining capsules.
    * ``margin_loss`` - CapsNet margin loss on capsule lengths.
    * ``forward`` - ``L_angle + margin_loss(x, labels)``.

    Args:
        s: scale applied to the logits returned by ``arc_loss``.
        in_feature: unused; kept for interface compatibility.
        out_feature: number of classes; the simplex dimension is
            ``out_feature - 1``. Must be at least 2.
        m: additive angular margin in radians used by ``arc_loss``
            (an earlier variant used ``math.acos(-1 / d)``).

    Raises:
        ValueError: if ``out_feature < 2`` (``alpha`` would divide by zero).
    """

    def __init__(self, s=2.0, in_feature=10, out_feature=10, m=0.5):
        super().__init__()
        if out_feature < 2:
            raise ValueError("out_feature must be >= 2, got {}".format(out_feature))

        self.d = out_feature - 1
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.alpha = (1 - math.sqrt(self.d + 1)) / self.d
        self.s = s

        # Transpose so each row holds one vertex.
        # NOTE(review): assumes simplex_coordinates2(d) returns (dim, n_vertices) - confirm.
        vertex_simplex = torch.Tensor(simplex_coordinates2.simplex_coordinates2(self.d)).permute(1, 0)
        # Extra column: alpha times the sum of each vertex's coordinates.
        extra_col = self.alpha * vertex_simplex.sum(dim=1, keepdim=True)
        self.weight = Parameter(torch.cat((vertex_simplex, extra_col), dim=1), requires_grad=False)
        if verbose:
            print("Norm_Arc_loss _ weight matrix {}".format(self.weight.size()))

        # From ArcFace: thresholds that keep cos(theta + m) monotonically decreasing
        # for theta in [0, pi]. NOTE(review): computed but not used by arc_loss.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        print("margin {}  cos_m {}  sin_m {}".format(m, self.cos_m, self.sin_m))

    def saveFigure(self, data, epoch, batch_id, folder_name, name_var):
        """Save a heat-map of ``data[1]`` (the second sample of the batch) as a JPEG.

        Axes are labelled with the ten CIFAR-10 class names, so ``data[1]`` is
        expected to be a 10 x 10 matrix. The file is written to
        ``<folder_name>/<batch_id><name_var><epoch>.jpg``.

        NOTE(review): relies on ``plt`` (matplotlib.pyplot) being available at
        module level; the visible imports do not provide it.
        """
        classes = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]
        ticks = list(range(len(classes)))

        fig, ax = plt.subplots()
        try:
            im = ax.imshow(data[1, :, :].cpu().detach().numpy())
            fig.colorbar(im)
            ax.set_xticks(ticks, classes)
            ax.set_yticks(ticks, classes)
            fig.savefig(folder_name + "/" + str(batch_id) + name_var + str(epoch) + ".jpg")
        finally:
            # Close the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

    def arc_loss(self, x, label, epoch, batch_id, folder_name, val=0):
        """Return ArcFace-style scaled logits against the fixed simplex weights.

        Args:
            x: tensor of shape ``(batch, *self.weight.shape)``; each row of each
                sample is compared (cosine) with every row of ``self.weight``.
            label: tensor broadcast against the ``(batch, n, n)`` cosine tensor;
                entries equal to 1 receive the angular margin.
            epoch, batch_id, folder_name: unused here; accepted for interface
                compatibility (they match the ``saveFigure`` debug helper).
            val: when non-zero, return the plain scaled cosines (no margin).

        Returns:
            ``s * (label * cos(theta + m) + (1 - label) * cos(theta))``, or
            ``s * cos(theta)`` when ``val`` is non-zero.

        Raises:
            ValueError: if ``x.shape[1:]`` differs from ``self.weight.shape``.

        Side effects: stores ``self.cosine``, ``self.sine`` and ``self.phi``.
        """
        if verbose:
            print("ARC LOSS")
        if tuple(x.shape[1:]) != tuple(self.weight.shape):
            # Previously this only printed and broke out of the loop, silently
            # producing a truncated (or empty) result.
            raise ValueError(
                "x dimension and weight dimension do not match: got {} per sample, expected {}".format(
                    tuple(x.shape[1:]), tuple(self.weight.shape)))

        # Batched form of F.linear(F.normalize(x[i]), F.normalize(self.weight)).
        self.cosine = torch.matmul(F.normalize(x, dim=-1), F.normalize(self.weight, dim=1).t())
        # Clamp: rounding can push |cos| slightly above 1, which would give NaN.
        self.sine = torch.sqrt((1.0 - torch.pow(self.cosine, 2)).clamp(0.0, 1.0))
        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        self.phi = self.cosine * self.cos_m - self.sine * self.sin_m

        if val == 0:
            output = label * self.phi + (1.0 - label) * self.cosine
        else:
            output = self.cosine
        return output * self.s

    def digit_angle_loss(self, x, labels, margin=0.5):
        """Penalise the angle between each sample's target capsule and the others.

        For every sample, the capsule selected by ``labels == 1`` is the
        reference; for every other capsule ``b`` the term
        ``acos(cos(x_ref, x_b)) - margin`` is summed. The per-sample sums are
        averaged over the batch. Unlike a ``torch.tensor(...)`` rebuild, the
        result stays connected to the autograd graph of ``x``.

        Args:
            x: tensor of shape ``(batch, n_capsules, dim)``.
            labels: tensor of shape ``(batch, n_capsules)`` with exactly one entry
                equal to 1 per row.
            margin: angle in radians subtracted from each pairwise angle.

        Returns:
            Scalar tensor.

        Raises:
            ValueError: if a row of ``labels`` does not contain exactly one 1.
        """
        hits = labels == 1
        if not bool((hits.sum(dim=1) == 1).all()):
            raise ValueError("digit_angle_loss expects exactly one label equal to 1 per sample")
        if verbose:
            print("digit_angle_loss reference capsules {}".format(hits.nonzero(as_tuple=True)[1]))

        unit = F.normalize(x, dim=2)
        # Exactly one hit per row, so the masked sum selects the reference capsule.
        ref = (unit * hits.unsqueeze(-1).to(unit.dtype)).sum(dim=1)
        cosine = (unit * ref.unsqueeze(1)).sum(dim=2)
        # The reference's angle to itself is excluded; filling it with 0 before
        # acos also keeps its (discarded) gradient finite.
        cosine = cosine.masked_fill(hits, 0.0)
        eps = 1e-7  # keep acos away from +-1, where its derivative is infinite
        angles = torch.acos(cosine.clamp(-1.0 + eps, 1.0 - eps))
        penalty = (angles - margin).masked_fill(hits, 0.0)
        return penalty.sum(dim=1).mean()

    # CapsNet margin loss
    def margin_loss(self, x, labels, size_average=True):
        """Return the CapsNet margin loss computed on capsule lengths.

        Uses the un-squared hinges ``relu(0.9 - |v|)`` for the target and
        ``0.5 * relu(|v| - 0.1)`` for the others.

        Args:
            x: capsule outputs; lengths are taken over dim 2 (presumably
                ``(batch, n_classes, dim)``).
            labels: targets broadcast against ``(batch, n_classes)``.
            size_average: unused; kept for interface compatibility.

        Returns:
            Scalar tensor: per-sample sum over classes, averaged over the batch.
        """
        if verbose:
            print("x {}".format(x.size()))
        if verbose:
            print("labels {}".format(labels.size()))
        batch_size = x.size(0)

        v_c = torch.sqrt((x ** 2).sum(dim=2, keepdim=True))  # L2 length of each capsule
        if verbose:
            print("v_c {}".format(v_c.size()))
        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)

        loss = labels * left + 0.5 * (1.0 - labels) * right
        return loss.sum(dim=1).mean()

    def forward(self, x, labels, L_angle):
        """Return ``L_angle + margin_loss(x, labels)``.

        ``L_angle`` is an angular loss term computed by the caller.
        """
        L_margin = self.margin_loss(x, labels)
        return L_angle + L_margin
    