In [1]:
import torch
from torch import nn
from transformers import BertModel, BertTokenizer

from mlguess.torch.class_losses import relu_evidence



### Example usage for K-class problem

In [2]:
class DNABert(nn.Module):
    def __init__(self, n_classes):
        super(DNABert, self).__init__()
        self.n_classes = n_classes
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        
        # note how we only take one hidden state from the sequeunce, which corresponds with the CLS token
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        
        out = self.fc(cls_hidden_state)
        return out
    
    def predict_uncertainty(self, input_ids, attention_mask, token_type_ids=None):
        y_pred = self(input_ids, attention_mask, token_type_ids)
        
        # dempster-shafer theory
        evidence = relu_evidence(y_pred) # can also try softplus and exp evidence schemes
        alpha = evidence + 1
        S = torch.sum(alpha, dim=1, keepdim=True)
        u = self.n_classes / S
        prob = alpha / S
        
        # law of total uncertainty 
        epistemic = prob * (1 - prob) / (S + 1)
        aleatoric = prob - prob**2 - epistemic
        return prob, u, aleatoric, epistemic

In [3]:
# Initialize the model
num_classes = 10

model = DNABert(n_classes=num_classes)

dna_sequence = "AGCTAGCTAGCT"

# We need to convert the DNA sequence to the format expected by BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
inputs = tokenizer(dna_sequence, return_tensors='pt')

# Forward pass through the model
outputs = model(**inputs)

In [4]:
inputs

{'input_ids': tensor([[  101, 12943, 25572, 18195, 15900,  6593,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [5]:
outputs

tensor([[ 0.5873,  0.1587,  0.2231, -0.0498,  0.5132, -0.6980, -0.0172, -0.1390,
          0.7385,  0.0981]], grad_fn=<AddmmBackward0>)

In [6]:
prob, u, aleatoric, epistemic = model.predict_uncertainty(**inputs)

In [7]:
prob

tensor([[0.1289, 0.0941, 0.0993, 0.0812, 0.1228, 0.0812, 0.0812, 0.0812, 0.1411,
         0.0891]], grad_fn=<DivBackward0>)

In [8]:
u

tensor([[0.8118]], grad_fn=<MulBackward0>)

In [9]:
aleatoric

tensor([[0.1038, 0.0788, 0.0827, 0.0690, 0.0997, 0.0690, 0.0690, 0.0690, 0.1121,
         0.0751]], grad_fn=<SubBackward0>)

In [10]:
epistemic

tensor([[0.0084, 0.0064, 0.0067, 0.0056, 0.0081, 0.0056, 0.0056, 0.0056, 0.0091,
         0.0061]], grad_fn=<DivBackward0>)

### Evidential loss

In [11]:
from mlguess.torch.class_losses import edl_digamma_loss, edl_log_loss, edl_mse_loss

In [12]:
loss = "digamma"
annealing_coefficient = 10.
epoch = 0
device = "cpu"

In [13]:
if loss == "digamma":
    criterion = edl_digamma_loss
elif loss == "log":
    criterion = edl_log_loss
elif loss == "mse":
    criterion = edl_mse_loss
else:
    logging.error("--uncertainty requires --mse, --log or --digamma.")

In [14]:
y_true_hot = torch.tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0])

loss = criterion(
    outputs,
    y_true_hot.float(), 
    epoch, 
    num_classes, 
    annealing_coefficient, 
    device
)

In [15]:
loss

tensor(2.3549, grad_fn=<MeanBackward0>)

In [17]:
# loss.backward

### Regression example

In [41]:
import numpy as np

In [42]:
class LinearNormalGamma(nn.Module):
    def __init__(self, in_chanels, out_channels):
        super().__init__()
        self.linear = nn.Linear(in_chanels, out_channels*4)

    def evidence(self, x):
        return  torch.log(torch.exp(x) + 1)

    def forward(self, x):
        pred = self.linear(x).view(x.shape[0], -1, 4)
        mu, logv, logalpha, logbeta = [w.squeeze(-1) for w in torch.split(pred, 1, dim=-1)]
        return mu, self.evidence(logv), self.evidence(logalpha) + 1, self.evidence(logbeta)

In [59]:
class DNABertRegressor(nn.Module):
    def __init__(self, n_tasks, training_var = [1.0]):
        super(DNABertRegressor, self).__init__()
        self.output_size = n_tasks
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = LinearNormalGamma(self.bert.config.hidden_size, self.output_size)
        self.training_var = training_var

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        
        # note how we only take one hidden state from the sequeunce, which corresponds with the CLS token
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        
        out = self.fc(cls_hidden_state)
        return out

    def predict_uncertainty(self, input_ids, attention_mask, token_type_ids=None, y_scaler=None):
        mu, v, alpha, beta = self(input_ids, attention_mask, token_type_ids)
        aleatoric = beta / (alpha - 1)
        epistemic = beta / (v * (alpha - 1))

        if len(mu.shape) == 1:
            mu = np.expand_dims(mu, 1)
            aleatoric = np.expand_dims(aleatoric, 1)
            epistemic = np.expand_dims(epistemic, 1)

        if y_scaler:
            mu = y_scaler.inverse_transform(mu)

        for i in range(mu.shape[-1]):
            aleatoric[:, i] *= self.training_var[i]
            epistemic[:, i] *= self.training_var[i]

        return mu, aleatoric, epistemic

In [60]:
# gamma, nu, alpha, beta = pred
# loss = nll_loss(gamma, nu, alpha, beta, labels)
# loss += reg(gamma, nu, alpha, beta, labels)
# loss += mmse_loss(gamma, nu, alpha, beta, labels)

##### Compute binder score training variance: np.var(scores)

In [67]:
model = DNABertRegressor(n_tasks=1, training_var = [0.0033433])

dna_sequence = "AGCTAGCTAGCT"

# We need to convert the DNA sequence to the format expected by BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
inputs = tokenizer(dna_sequence, return_tensors='pt')

# Forward pass through the model
gamma, nu, alpha, beta = model(**inputs)

In [65]:
mu, a, e = model.predict_uncertainty(**inputs)

In [66]:
mu

tensor([[0.2118]], grad_fn=<SqueezeBackward1>)

### Regression losses

In [83]:
tol = 1e-8

def modified_mse(gamma, nu, alpha, beta, target, reduction='mean'):
    """
    Lipschitz MSE loss of the "Improving evidential deep learning via multitask learning."

    Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
    Source: https://github.com/deargen/MT-ENet/tree/468822188f52e517b1ee8e386eea607b2b7d8829

    Args:
        gamma ([FloatTensor]): the output of the ENet.
        nu ([FloatTensor]): the output of the ENet.
        alpha ([FloatTensor]): the output of the ENet.
        beta ([FloatTensor]): the output of the ENet.
        target ([FloatTensor]): true labels.
        reduction (str, optional): . Defaults to 'mean'.
    Returns:
        [FloatTensor]: The loss value. 
    """
    mse = (gamma-target)**2
    c = get_mse_coef(gamma, nu, alpha, beta, target).detach()
    mod_mse = mse*c
    
    if reduction == 'mean': 
        return mod_mse.mean()
    elif reduction == 'sum':
        return mod_mse.sum()
    else:
        return mod_mse

def get_mse_coef(gamma, nu, alpha, beta, y):
    """
    Return the coefficient of the MSE loss for each prediction.
    By assigning the coefficient to each MSE value, it clips the gradient of the MSE
    based on the threshold values U_nu, U_alpha, which are calculated by check_mse_efficiency_* functions.

    Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
    Source: https://github.com/deargen/MT-ENet/tree/468822188f52e517b1ee8e386eea607b2b7d8829

    Args:
        gamma ([FloatTensor]): the output of the ENet.
        nu ([FloatTensor]): the output of the ENet.
        alpha ([FloatTensor]): the output of the ENet.
        beta ([FloatTensor]): the output of the ENet.
        y ([FloatTensor]): true labels.
    Returns:
        [FloatTensor]: [0.0-1.0], the coefficient of the MSE for each prediction.
    """
    alpha_eff = check_mse_efficiency_alpha(nu, alpha, beta)
    nu_eff = check_mse_efficiency_nu(gamma, nu, alpha, beta)
    delta = (gamma - y).abs()
    min_bound = torch.min(nu_eff, alpha_eff).min()
    c = (min_bound.sqrt()/(delta + tol)).detach()
    return torch.clip(c, min=False, max=1.)


def check_mse_efficiency_alpha(nu, alpha, beta):
    """
    Check the MSE loss (gamma - y)^2 can make negative gradients for alpha, which is
    a pseudo observation of the normal-inverse-gamma. We can use this to check the MSE
    loss can success(increase the pseudo observation, alpha).

    Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
    Source: https://github.com/deargen/MT-ENet/tree/468822188f52e517b1ee8e386eea607b2b7d8829

    Args:
        nu (torch.Tensor): nu output value of the evidential network
        alpha (torch.Tensor): alpha output value of the evidential network
        beta (torch.Tensor): beta output value of the evidential network

    Return:
        partial f / partial alpha(numpy.array) 
        where f => the NLL loss (BayesianDTI.loss.MarginalLikelihood)
    
    """
    right = (torch.exp((torch.digamma(alpha+0.5)-torch.digamma(alpha))) - 1)*2*beta*(1+nu) / (nu + 1e-8)
    return right.detach()


def check_mse_efficiency_nu(gamma, nu, alpha, beta):
    """
    Check the MSE loss (gamma - y)^2 can make negative gradients for nu, which is
    a pseudo observation of the normal-inverse-gamma. We can use this to check the MSE
    loss can success(increase the pseudo observation, nu).

    Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
    Source: https://github.com/deargen/MT-ENet/tree/468822188f52e517b1ee8e386eea607b2b7d8829

    Args:
        gamma (torch.Tensor): gamma output value of the evidential network
        nu (torch.Tensor): nu output value of the evidential network
        alpha (torch.Tensor): alpha output value of the evidential network
        beta (torch.Tensor): beta output value of the evidential network
    
    Return:
        partial f / partial nu(torch.Tensor) 
        where f => the NLL loss (BayesianDTI.loss.MarginalLikelihood)
    """
    gamma, nu, alpha, beta = gamma.detach(), nu.detach(), alpha.detach(), beta.detach()
    nu_1 = (nu + 1) / (nu + tol)
    return beta * nu_1 / (alpha + tol)

        
class EvidentialMarginalLikelihood(torch.nn.modules.loss._Loss):
    """
    Marginal likelihood error of prior network.
    The target value is not a distribution (mu, std), but a just value.
    
    This is a negative log marginal likelihood, with integral mu and sigma.

    Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
    Source: https://github.com/deargen/MT-ENet/tree/468822188f52e517b1ee8e386eea607b2b7d8829
    """
    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean'):
        super(EvidentialMarginalLikelihood, self).__init__(size_average, reduce, reduction)
    
    def forward(self, gamma: torch.Tensor, nu: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            gamma (torch.Tensor): gamma output value of the evidential network
            nu (torch.Tensor): nu output value of the evidential network
            alpha (torch.Tensor): alpha output value of the evidential network
            beta (torch.Tensor): beta output value of the evidential network
            target (torch.Tensor): target value
            
        Return:
            (Tensor) Negative log marginal likelihood of EvidentialNet
                p(y|m) = Student-t(y; gamma, (beta(1+nu))/(nu*alpha) , 2*alpha)
                then, the negative log likelihood is (CAUTION QUITE COMPLEX!)
                NLL = -log(p(y|m)) =
                    log(3.14/nu)*0.5 - alpha*log(2*beta*(1 + nu)) + (alpha + 0.5)*log( nu(target - gamma)^2 + 2*beta(1 + nu) )
                    + log(GammaFunc(alpha)/GammaFunc(alpha + 0.5))
        """
        pi = torch.tensor(np.pi)
        x1 = torch.log(pi/(nu + tol))*0.5
        x2 = -alpha*torch.log(2.*beta*(1.+ nu) + tol)
        x3 = (alpha + 0.5)*torch.log( nu*(target - gamma)**2 + 2.*beta*(1. + nu) + tol)
        x4 = torch.lgamma(alpha + tol) - torch.lgamma(alpha + 0.5 + tol)
        if self.reduction == 'mean': 
            return (x1 + x2 + x3 + x4).mean()
        elif self.reduction == 'sum':
            return (x1 + x2 + x3 + x4).sum()
        else:
            return x1 + x2 + x3 + x4

class EvidenceRegularizer(torch.nn.modules.loss._Loss):
    """
    Regularization for the regression prior network.
    If self.factor increases, the model output the wider(high confidence interval) predictions.

    Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
    Source: https://github.com/deargen/MT-ENet/tree/468822188f52e517b1ee8e386eea607b2b7d8829
    """
    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean', factor=0.1):
        super(EvidenceRegularizer, self).__init__(size_average, reduce, reduction)
        self.factor = factor
    
    def forward(self, gamma: torch.Tensor, nu: torch.Tensor, alpha: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            gamma (torch.Tensor): gamma output value of the evidential network
            nu (torch.Tensor): nu output value of the evidential network
            alpha (torch.Tensor): alpha output value of the evidential network
            target (torch.Tensor): target value

        Return:
            (Tensor) prior network regularization
            Loss = |y - gamma|*(2*nu + alpha) * factor
            
        """
        loss_value =  torch.abs(target - gamma)*(2*nu + alpha) * self.factor
        if self.reduction == 'mean': 
            return loss_value.mean()
        elif self.reduction == 'sum':
            return loss_value.sum()
        else:
            return loss_value
    

In [91]:
y_pred = torch.Tensor([1.45456457])

In [92]:
loss = EvidentialMarginalLikelihood()(gamma, nu, alpha, beta, y_pred)
loss += EvidenceRegularizer()(gamma, nu, alpha, y_pred)
loss += modified_mse(gamma, nu, alpha, beta, y_pred)

In [93]:
loss

tensor(4.5012, grad_fn=<AddBackward0>)

In [94]:
mae = torch.nn.L1Loss()(y_pred, gamma)

  return F.l1_loss(input, target, reduction=self.reduction)


In [95]:
mae

tensor(1.6892, grad_fn=<MeanBackward0>)