In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DistilBertTokenizer, DistilBertModel

class DistilBertSeqMatch(nn.Module):
    """DistilBERT encoder followed by a linear classification head.

    The head is applied to the encoder's hidden state at sequence position 0.

    NOTE: the attribute names ``Backbone`` and ``Classifier`` become the
    prefixes of this module's ``state_dict`` keys. Renaming them would stop
    previously saved checkpoints from loading via ``load_state_dict``.
    """

    def __init__(self, num_labels, pretrained_model_name="distilbert-base-uncased"):
        """Create the encoder and classification head.

        Args:
            num_labels: Number of output classes (output width of the head).
            pretrained_model_name: Checkpoint name or local directory passed to
                ``DistilBertModel.from_pretrained``. Defaults to the original
                hard-coded ``"distilbert-base-uncased"``.
        """
        super().__init__()
        self.Backbone = DistilBertModel.from_pretrained(pretrained_model_name)
        # DistilBERT's config names its hidden width `dim` (not `hidden_size`).
        hidden_size = self.Backbone.config.dim
        self.Classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return unnormalized class scores (logits).

        Args:
            input_ids: Token id tensor; presumably (batch, seq_len) — TODO confirm.
            attention_mask: Optional mask of the same shape as ``input_ids``.
                When None, the backbone treats every position as real input.

        Returns:
            Logits from ``self.Classifier``, one row per batch element.
        """
        out = self.Backbone(input_ids, attention_mask=attention_mask)
        # Take the hidden state at position 0. With the standard DistilBERT
        # tokenizer (special tokens enabled) this is the [CLS] token.
        first_token_state = out.last_hidden_state[:, 0]
        return self.Classifier(first_token_state)

def predict(text, model, tokenizer, label_list, max_length=256, device="cpu"):
    """Classify a single text and return its label and class probabilities.

    Args:
        text: The input string to classify.
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns logits with one column per entry of ``label_list``.
        tokenizer: Hugging Face-style tokenizer. It is invoked via ``__call__``,
            the documented replacement for the deprecated ``encode_plus``.
        label_list: Labels, indexed by class position.
        max_length: Token length to truncate to and pad up to.
        device: Device the input tensors are moved to. It must match the
            device the model lives on.

    Returns:
        ``(label, probs)``. ``label`` is the highest-scoring entry of
        ``label_list``. ``probs`` is a NumPy array of softmax probabilities
        with shape ``(1, num_labels)``.

    Side effects:
        Switches ``model`` to eval mode and leaves it there.
    """
    encoding = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
        padding="max_length",
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    model.eval()
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probs = F.softmax(logits, dim=1)
    # Softmax is monotonic, so argmax over the raw logits picks the same class.
    # Using logits avoids ties introduced by rounding in the softmax output.
    prediction_idx = int(torch.argmax(logits, dim=1).item())
    return label_list[prediction_idx], probs.cpu().numpy()

# Use the GPU when one is available. Otherwise fall back to CPU, so the cell
# does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# NOTE(review): this order must match the class indices used when the
# checkpoint was trained (it looks alphabetically sorted) — TODO confirm.
label_list = ['BACKGROUND', 'CONCLUSIONS', 'METHODS', 'OBJECTIVE', 'RESULTS']
num_labels = len(label_list)
model = DistilBertSeqMatch(num_labels)
weights_path = r"C:\Users\clash\Downloads\mkc\Sequence Match\SequenceMatch_PubMedRCT_40000_labeled.pth"
# Only a state_dict is needed here. weights_only=True restricts unpickling to
# tensors and plain containers instead of running arbitrary pickled code from
# a downloaded file.
state_dict = torch.load(weights_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
eval_text = ("When mean hemodynamic profiles were compared in patients with abnormal versus normal LFTs , elevated total bilirubin was associated with a significantly lower cardiac index ( 1.80 vs 2.1 ; P < .001 ) and higher central venous pressure ( 14.2 vs 12.0 ; P = .03 ) .")
predicted_label, probabilities = predict(eval_text, model, tokenizer, label_list, device=device)
print("Input Text:")
print(eval_text)
print("\nPredicted Label:", predicted_label)

Input Text:
When mean hemodynamic profiles were compared in patients with abnormal versus normal LFTs , elevated total bilirubin was associated with a significantly lower cardiac index ( 1.80 vs 2.1 ; P < .001 ) and higher central venous pressure ( 14.2 vs 12.0 ; P = .03 ) .

Predicted Label: RESULTS
