In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DistilBertTokenizer, DistilBertModel

class DistilBertReFixMatch(nn.Module):
    """DistilBERT encoder with a linear classification head on the first token.

    The submodule attribute names ``Model`` and ``Classifier`` are intentionally
    kept as-is: they become the ``state_dict`` key prefixes (``Model.*``,
    ``Classifier.*``), so renaming them would break loading saved checkpoints.

    Args:
        num_labels: Number of output classes (output width of the linear head).
        pretrained_name: Checkpoint name or local path passed to
            ``DistilBertModel.from_pretrained``. Defaults to
            ``"distilbert-base-uncased"`` (the previously hard-coded value).
    """

    def __init__(self, num_labels, pretrained_name="distilbert-base-uncased"):
        super().__init__()
        self.Model = DistilBertModel.from_pretrained(pretrained_name)
        # DistilBERT's config exposes the hidden width as ``dim``.
        hidden_size = self.Model.config.dim
        self.Classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class scores for a batch of token sequences.

        Args:
            input_ids: Token id tensor — assumed ``(batch, seq_len)``; TODO confirm.
            attention_mask: Mask tensor forwarded unchanged to the encoder.

        Returns:
            Logits with last dimension ``num_labels`` (``(batch, num_labels)``
            under the shape assumption above).
        """
        out = self.Model(input_ids, attention_mask=attention_mask)
        # Use the hidden state at sequence position 0 (presumably the [CLS]
        # token added by the tokenizer) as the pooled sentence representation.
        first_token_state = out.last_hidden_state[:, 0]
        return self.Classifier(first_token_state)

def predict(text, model, tokenizer, label_list, max_length=256, device="cpu"):
    """Classify a single text and return the top label with class probabilities.

    Args:
        text: Input string to classify.
        model: Module called as ``model(input_ids, attention_mask)`` that returns
            logits with classes on dim 1.
        tokenizer: Hugging Face tokenizer used to encode ``text``.
        label_list: Class names; output position ``i`` maps to ``label_list[i]``.
        max_length: Truncation and padding length in tokens.
        device: Device the input tensors are moved to; should match the model's.

    Returns:
        ``(label, probs)``: ``label`` is ``label_list[argmax]`` and ``probs`` is a
        NumPy array of softmax probabilities with shape ``(1, num_classes)``
        (batch size must be 1 for ``.item()`` to succeed).

    Note:
        Calls ``model.eval()`` and leaves the model in eval mode afterwards.
    """
    # Calling the tokenizer directly is the supported replacement for the
    # deprecated ``encode_plus``; the keyword arguments are the same.
    encoding = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Switch off training-only behaviour (e.g. dropout) for inference.
    model.eval()
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probs = F.softmax(logits, dim=1)
        prediction_idx = torch.argmax(probs, dim=1).item()
    return label_list[prediction_idx], probs.cpu().numpy()

# Prefer the GPU but fall back to CPU: a hard-coded "cuda" device (also used as
# map_location below) fails outright on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# predict() maps model output position i to label_list[i], so this order must
# match the class indices the checkpoint was trained with — confirm against the
# training script.
label_list = ['BACKGROUND', 'CONCLUSIONS', 'METHODS', 'OBJECTIVE', 'RESULTS']
num_labels = len(label_list)
model = DistilBertReFixMatch(num_labels)
weights_path = r"C:\Users\clash\Downloads\mkc\Refix Match\ReFixMatch_PubMed_RCT_80000_labeled.pth"
# weights_only=True (torch >= 1.13) restricts unpickling to tensors and plain
# containers, which is all a state_dict needs, instead of running arbitrary
# pickled code from the file.
model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
model.to(device)
eval_text = ("Minimally invasive endovascular aneurysm repair ( EVAR ) could be a surgical technique that improves outcome of patients with ruptured abdominal aortic aneurysm ( rAAA ) .")
predicted_label, probabilities = predict(eval_text, model, tokenizer, label_list, device=device)
print("Input Text:")
print(eval_text)
print("\nPredicted Label:", predicted_label)

Input Text:
Minimally invasive endovascular aneurysm repair ( EVAR ) could be a surgical technique that improves outcome of patients with ruptured abdominal aortic aneurysm ( rAAA ) .

Predicted Label: BACKGROUND
