In [None]:
import torch
from torch import nn
from transformers import BertModel, BertTokenizer

from mlguess.torch.class_losses import relu_evidence



### Example usage for K-class problem

In [None]:
class DNABert(nn.Module):
    """BERT encoder with a linear classification head and evidential uncertainty outputs.

    The head reads the encoder's hidden state at sequence position 0 (the CLS
    token) and maps it to ``n_classes`` raw outputs. ``predict_uncertainty``
    turns those raw outputs into class probabilities and uncertainty terms.

    Args:
        n_classes: Number of classes, i.e. the width of the linear head.
        pretrained_name: Checkpoint identifier handed to
            ``BertModel.from_pretrained``. Defaults to ``'bert-base-uncased'``,
            the checkpoint this class always loaded before the parameter existed.
    """

    def __init__(self, n_classes, pretrained_name='bert-base-uncased'):
        super().__init__()
        self.n_classes = n_classes
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.fc = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return the raw head outputs for a batch of tokenized sequences.

        Args:
            input_ids: Token ids, forwarded to the encoder.
            attention_mask: Attention mask, forwarded to the encoder.
            token_type_ids: Optional segment ids, forwarded to the encoder.
                ``None`` lets the encoder fall back to its own default.

        Returns:
            Tensor of shape ``(batch, n_classes)`` holding the linear head's
            output for each sequence.
        """
        # token_type_ids is forwarded as well; it used to be accepted by this
        # method but silently dropped before reaching the encoder.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Only one hidden state of the sequence is used: position 0, which
        # corresponds to the CLS token.
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]

        return self.fc(cls_hidden_state)

    def predict_uncertainty(self, input_ids, attention_mask, token_type_ids=None):
        """Run a forward pass and derive probabilities and uncertainty terms.

        Args:
            input_ids: Token ids, as for ``forward``.
            attention_mask: Attention mask, as for ``forward``.
            token_type_ids: Optional segment ids, as for ``forward``.

        Returns:
            Tuple ``(prob, u, aleatoric, epistemic)``:

            * ``prob`` -- ``alpha / S``, shape ``(batch, n_classes)``.
            * ``u`` -- ``n_classes / S``, shape ``(batch, 1)``.
            * ``aleatoric`` -- ``prob - prob**2 - epistemic``, same shape as ``prob``.
            * ``epistemic`` -- ``prob * (1 - prob) / (S + 1)``, same shape as ``prob``.

            Here ``alpha = evidence + 1`` and ``S`` is the sum of ``alpha``
            over the class dimension.
        """
        y_pred = self(input_ids, attention_mask, token_type_ids)

        # Dempster-Shafer theory: turn raw outputs into evidence, then into
        # alpha and its per-sample total S.
        evidence = relu_evidence(y_pred)  # can also try softplus and exp evidence schemes
        alpha = evidence + 1
        S = torch.sum(alpha, dim=1, keepdim=True)  # keepdim so S broadcasts against alpha
        u = self.n_classes / S
        prob = alpha / S

        # Law of total uncertainty: prob - prob**2 (= prob * (1 - prob)) is
        # split into an epistemic part and the aleatoric remainder.
        epistemic = prob * (1 - prob) / (S + 1)
        aleatoric = prob - prob**2 - epistemic
        return prob, u, aleatoric, epistemic

In [None]:
# Seed PyTorch before building the model: the linear head is randomly
# initialised, so without a seed every value shown below changes on each
# Restart & Run All.
torch.manual_seed(0)

# Initialize the model
num_classes = 10

model = DNABert(n_classes=num_classes)

dna_sequence = "AGCTAGCTAGCT"

# We need to convert the DNA sequence to the format expected by BERT.
# NOTE(review): 'bert-base-uncased' looks like a general-purpose checkpoint
# rather than a DNA-specific one -- confirm that is intended for this demo.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
inputs = tokenizer(dna_sequence, return_tensors='pt')

# Forward pass through the model; `outputs` holds the raw head outputs and is
# reused by the loss example further down.
outputs = model(**inputs)

In [None]:
# Show the tokenizer output that is unpacked into the model call.
inputs

In [None]:
# Show the forward-pass result: the linear head applied to the hidden state at
# sequence position 0.
outputs

In [None]:
# Re-run the model on the same inputs and unpack the four tensors returned by
# DNABert.predict_uncertainty, in the order it returns them.
prob, u, aleatoric, epistemic = model.predict_uncertainty(**inputs)

In [None]:
# Per-class probabilities, computed as alpha / S in predict_uncertainty.
prob

In [None]:
# Overall uncertainty, computed as n_classes / S in predict_uncertainty.
u

In [None]:
# Aleatoric term, computed as prob - prob**2 - epistemic in predict_uncertainty.
aleatoric

In [None]:
# Epistemic term, computed as prob * (1 - prob) / (S + 1) in predict_uncertainty.
epistemic

### Evidential loss

In [None]:
from mlguess.torch.class_losses import edl_digamma_loss, edl_log_loss, edl_mse_loss

In [None]:
# Settings for the evidential-loss example; every value here is read by the
# cells that follow.
device = "cpu"                # passed as the last positional argument of the loss call
epoch = 0                     # passed as the epoch argument of the loss call
annealing_coefficient = 10.0  # passed as the annealing argument of the loss call
loss = "digamma"              # loss selector: "digamma", "log" or "mse"

In [None]:
# Map each supported loss name to its evidential loss function.
available_criteria = {
    "digamma": edl_digamma_loss,
    "log": edl_log_loss,
    "mse": edl_mse_loss,
}

# Fail right here on an unknown name. The previous fallback called `logging`,
# which this notebook never imports, and would in any case have left
# `criterion` undefined for the cells below.
if loss not in available_criteria:
    raise ValueError(
        f"Unknown loss {loss!r}; expected one of {sorted(available_criteria)}."
    )

criterion = available_criteria[loss]

In [None]:
# One-hot target for class 0. It is built from num_classes so it always matches
# the width of the model head, and it has shape (1, num_classes) so it lines up
# with the single-sequence batch in `outputs`.
y_true_hot = torch.nn.functional.one_hot(torch.tensor([0]), num_classes=num_classes)

# NOTE: `loss` is rebound here from the selector string to the value returned
# by the loss function. Re-running the loss-selection cell after this one will
# no longer match any loss name; restart the kernel (or re-run the
# configuration cell) first.
# NOTE(review): all arguments are positional -- confirm the order matches the
# signature of the mlguess loss functions, in particular that the sixth
# positional parameter is the device.
loss = criterion(
    outputs,
    y_true_hot.float(),
    epoch,
    num_classes,
    annealing_coefficient,
    device,
)

In [None]:
# Show the value returned by the selected loss function.
loss

In [None]:
# Uncomment to backpropagate the loss through the model:
# loss.backward()