In [1]:
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score

In [18]:
class SoftAUC():
    """Differentiable ("soft") concordance score, exposed as a loss.

    Every ordered pair of samples (i, j) contributes two terms:

    * ``C_NA`` -- the labels differ and the predictions differ in the same
      direction (a concordant pair);
    * ``C_A``  -- the labels agree and the predictions are close together.

    ``loss`` returns the negated mean of ``C_A + C_NA`` over all ``n * n``
    pairs, so minimising the loss maximises the soft concordance.

    Parameters
    ----------
    scale : float, default 10.0
        Multiplier applied to the ReLU-gated agreement terms before the final
        ``tanh``. Larger values saturate the per-pair scores toward 1 faster.
    """
    def __init__(self,scale=10.0):
        self.scale=scale

    def loss(self,preds,labels):
        """Return the negative soft concordance of ``preds`` against ``labels``.

        Parameters
        ----------
        preds : torch.Tensor
            2-D tensor; column 1 is exponentiated and used as the per-sample
            score. NOTE(review): this presumably expects log-probabilities
            (e.g. ``log_softmax`` output) with the positive class in
            column 1 -- confirm against the caller.
        labels : torch.Tensor or array-like
            One label per sample (1-D, length ``n``). Converted to a float
            tensor on the same device as ``preds``.

        Returns
        -------
        torch.Tensor
            0-dim tensor equal to ``-sum(C_A + C_NA) / n**2``.
        """
        scale=self.scale
        # Column vector (n, 1) of per-sample scores.
        scores=torch.exp(preds[:,1]).unsqueeze(-1)
        # as_tensor avoids the copy-construct warning that torch.tensor emits
        # for tensor inputs, and keeps the labels on the scores' device.
        labels=torch.as_tensor(labels,dtype=torch.float,device=scores.device).unsqueeze(-1)
        n=labels.shape[0]

        # Broadcasting (1, n) - (n, 1) yields the same n x n matrix as
        # e @ x.T - x @ e.T, i.e. entry [i, j] = x[j] - x[i], without
        # allocating explicit all-ones helpers on a fixed device.
        label_gap=torch.tanh(labels.t()-labels)
        score_gap=torch.tanh(scores.t()-scores)

        # Labels differ and scores differ with the same sign.
        concordant=torch.tanh(scale*torch.relu(label_gap*score_gap))

        # Labels agree (1 - |gap| is 1) and scores are close together.
        label_same=1-torch.abs(label_gap)
        score_same=1-torch.abs(score_gap)
        agreeing=torch.tanh(scale*torch.relu(label_same*score_same))

        soft_concordance=torch.sum(agreeing+concordant)/(n*n)
        return -soft_concordance

All ones: $e=[1,1,\ldots,1]^T \in \mathbb{R}^{n\times 1}$

Broadcast vector $x \in \mathbb{R}^{n\times 1}$ along rows/columns: $x_r = e x^T$, $x_c = x e^T$ 

Labels not agreeing: $Y_{NA}=\tanh(y_r - y_c)$ (non-zero where labels don't agree; $\pm\tanh(1)\approx\pm 0.76$ for 0/1 labels)

Predictions not agreeing: $P_{NA}=\tanh(p_r - p_c)$ (larger positive/negative magnitude where predictions don't agree)

Labels not agreeing and predictions differing in the same direction: $C_{NA}=\tanh(\sigma\,\mathtt{ReLU}(Y_{NA}P_{NA}))$, where $\sigma$ is the `scale` parameter

Labels agreeing: $Y_A = 1-\lvert Y_{NA} \rvert$ (1 where labels agree)

Predictions agreeing: $P_{A} = 1 - \lvert P_{NA} \rvert$ (close to 1 where predictions agree)

Labels agreeing and predictions also agreeing: $C_{A} = \tanh(\sigma\mathtt{ReLU}(Y_A P_A))$

Soft Concordance Index: $C = \frac{1}{n^2}\sum_{i,j}\left(C_A + C_{NA}\right)_{ij}$ (mean over all $n^2$ ordered pairs)

Use loss (SoftAUC) as $-C$

In [3]:
class WeightedLoss():
    """Combine two loss callables into one normalised weighted average.

    Parameters
    ----------
    lossfn1, lossfn2 : callable
        Each is called as ``fn(preds, labels)`` and returns a loss value.
    a, b : number, default 1
        Weights applied to ``lossfn1`` and ``lossfn2`` respectively.
    """
    def __init__(self,lossfn1,lossfn2,a=1,b=1):
        self.lossfn1, self.lossfn2 = lossfn1, lossfn2
        self.a, self.b = a, b

    def loss(self,preds,labels):
        """Return ``(a*lossfn1 + b*lossfn2) / (a + b)`` for the given inputs."""
        w1, w2 = self.a, self.b
        # lossfn1 is evaluated before lossfn2.
        first_term = w1*self.lossfn1(preds,labels)
        second_term = w2*self.lossfn2(preds,labels)
        return (first_term+second_term)/(w1+w2)