In [21]:
import torch
from torch import nn
from transformers import BertModel, BertTokenizer

from mlguess.torch.class_losses import relu_evidence

### Example usage for K-class problem

In [41]:
class DNABert(nn.Module):
    """BERT encoder with a linear head that emits one logit per class.

    The logits are meant to be turned into evidence (see
    ``predict_uncertainty``) for evidential classification.

    Args:
        n_classes: Number of classes, i.e. the width of the linear head.
        pretrained_name: Checkpoint identifier passed to
            ``BertModel.from_pretrained``. Defaults to ``'bert-base-uncased'``,
            the value that was previously hard-coded.
    """

    def __init__(self, n_classes, pretrained_name='bert-base-uncased'):
        super().__init__()
        self.n_classes = n_classes
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.fc = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Compute the raw class logits for a batch of tokenized sequences.

        Args:
            input_ids: Token ids, forwarded to the BERT encoder.
            attention_mask: Attention mask, forwarded to the BERT encoder.
            token_type_ids: Optional segment ids, forwarded to the BERT
                encoder (``None`` lets the encoder use its own default).

        Returns:
            The output of the linear head applied to the hidden state at
            sequence position 0; its last dimension has size ``n_classes``.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Only one hidden state of the sequence is used: position 0, which
        # corresponds to the CLS token.
        cls_hidden_state = encoded.last_hidden_state[:, 0, :]

        return self.fc(cls_hidden_state)

    def predict_uncertainty(self, input_ids, attention_mask, token_type_ids=None):
        """Run a forward pass and derive probabilities and uncertainties.

        Args:
            input_ids: Token ids, passed through to ``forward``.
            attention_mask: Attention mask, passed through to ``forward``.
            token_type_ids: Optional segment ids, passed through to ``forward``.

        Returns:
            A tuple ``(prob, u, aleatoric, epistemic)`` where, with
            ``alpha = evidence + 1`` and ``S = sum(alpha)`` over dim 1:
            ``prob = alpha / S``, ``u = n_classes / S``,
            ``epistemic = prob * (1 - prob) / (S + 1)`` and
            ``aleatoric = prob - prob**2 - epistemic``.
        """
        y_pred = self(input_ids, attention_mask, token_type_ids)

        # Dempster-Shafer theory. The evidence must come from this call's own
        # logits (y_pred), never from a tensor living in the notebook's global
        # namespace. Softplus and exp evidence schemes can also be tried.
        evidence = relu_evidence(y_pred)
        alpha = evidence + 1
        S = torch.sum(alpha, dim=1, keepdim=True)
        u = self.n_classes / S
        prob = alpha / S

        # Law of total uncertainty: prob * (1 - prob) is split into an
        # epistemic part and the remaining aleatoric part.
        epistemic = prob * (1 - prob) / (S + 1)
        aleatoric = prob - prob**2 - epistemic
        return prob, u, aleatoric, epistemic

In [55]:
# --- Tokenizer and classifier ------------------------------------------------
num_classes = 10
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = DNABert(n_classes=num_classes)

# --- One example sequence, converted to the format expected by BERT ----------
dna_sequence = "AGCTAGCTAGCT"
inputs = tokenizer(dna_sequence, return_tensors='pt')

# --- Forward pass: raw per-class outputs of the linear head ------------------
outputs = model(**inputs)

In [43]:
inputs

{'input_ids': tensor([[  101, 12943, 25572, 18195, 15900,  6593,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [44]:
outputs

tensor([[-0.7508, -0.6081, -0.0026, -0.0115,  0.1004,  0.1924, -0.4315, -0.0052,
          0.0900,  0.8016]], grad_fn=<AddmmBackward0>)

In [45]:
# Unpack the four tensors returned by DNABert.predict_uncertainty for the
# same tokenized inputs used in the forward pass above.
prob, u, aleatoric, epistemic = model.predict_uncertainty(**inputs)

In [46]:
prob

tensor([[0.0894, 0.0894, 0.0894, 0.0894, 0.0984, 0.1066, 0.0894, 0.0894, 0.0975,
         0.1611]], grad_fn=<DivBackward0>)

In [47]:
u

tensor([[0.8941]], grad_fn=<MulBackward0>)

In [48]:
aleatoric

tensor([[0.0747, 0.0747, 0.0747, 0.0747, 0.0814, 0.0874, 0.0747, 0.0747, 0.0807,
         0.1240]], grad_fn=<SubBackward0>)

In [50]:
epistemic

tensor([[0.0067, 0.0067, 0.0067, 0.0067, 0.0073, 0.0078, 0.0067, 0.0067, 0.0072,
         0.0111]], grad_fn=<DivBackward0>)

### Evidential loss

In [52]:
from mlguess.torch.class_losses import edl_digamma_loss, edl_log_loss, edl_mse_loss

In [59]:
# Settings for the evidential-loss demo below.
device = "cpu"
epoch = 0
annealing_coefficient = 10.0

# Name of the evidential criterion to use: "digamma", "log" or "mse".
loss = "digamma"

In [54]:
# Map each supported loss name to its evidential criterion.
_EDL_CRITERIA = {
    "digamma": edl_digamma_loss,
    "log": edl_log_loss,
    "mse": edl_mse_loss,
}

# Fail loudly on an unsupported name. Merely logging would leave `criterion`
# unbound, so the error would only surface later in an unrelated cell.
try:
    criterion = _EDL_CRITERIA[loss]
except KeyError:
    raise ValueError(
        f"Unknown loss {loss!r}; expected one of {sorted(_EDL_CRITERIA)}."
    ) from None

In [60]:
# One-hot target for class 0, sized from num_classes so it always matches the
# width of the model output instead of relying on a hand-written 10-element list.
y_true_hot = torch.zeros(num_classes, dtype=torch.long)
y_true_hot[0] = 1

# NOTE(review): this rebinds `loss` from the criterion *name* (a string) to
# the loss *value* (a tensor). Re-running the criterion-selection cell after
# this one will therefore not see the configured name any more.
loss = criterion(
    outputs,
    y_true_hot.float(),
    epoch,
    num_classes,
    annealing_coefficient,
    device,
)

In [61]:
loss

tensor(2.8403, grad_fn=<MeanBackward0>)

In [None]:
# loss.backward()  # uncomment to back-propagate the evidential loss through the model