In [1]:
# !python reassembler_BERT.py
from transformers import AutoTokenizer
import torch
from transformers import BertModel
import torch.nn.functional as F
from safetensors.torch import load_file

# BERT model definition
class BertLSTMClassifier(torch.nn.Module):
    """Sequence classifier: a BERT encoder followed by an LSTM and a linear head.

    Args:
        model_name: Identifier or path handed to ``BertModel.from_pretrained``.
        num_labels: Number of logits produced per input sequence.
        hidden_size: Per-token feature size fed to the LSTM; it must equal the
            width of the encoder's ``last_hidden_state`` (768 by default).
        lstm_hidden_size: Size of the LSTM hidden state, which is also the
            input width of the linear head.
        num_lstm_layers: Number of stacked LSTM layers.
    """

    def __init__(self, model_name, num_labels=2, hidden_size=768, lstm_hidden_size=256, num_lstm_layers=1):
        super().__init__()
        self.num_labels = num_labels

        # Sub-modules are created in the order bert -> lstm -> classifier so
        # that parameter registration order stays the same.
        self.bert = BertModel.from_pretrained(model_name)
        self.lstm = torch.nn.LSTM(
            input_size=hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=num_lstm_layers,
            batch_first=True,
            bidirectional=False,
        )
        self.classifier = torch.nn.Linear(lstm_hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return the classification logits for a batch of token ids.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Optional mask forwarded to the encoder only; the
                LSTM itself does not receive it.

        Returns:
            The output of the linear head applied to the LSTM output at the
            final sequence position.
        """
        token_states = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        lstm_states, _ = self.lstm(token_states)

        # batch_first=True, so dimension 1 is the sequence axis and index -1 is
        # its final position.
        # NOTE(review): with padded inputs that position is a padding slot, not
        # the last real token -- confirm this matches how the checkpoint was
        # trained before changing it.
        final_step = lstm_states[:, -1]
        return self.classifier(final_step)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The tokenizer files and the fine-tuned weights are both read from this directory.
model_dir = './'
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Build the architecture on top of 'bert-base-uncased', then load the
# checkpoint tensors stored in the safetensors file into it.
model = BertLSTMClassifier(model_name='bert-base-uncased', num_labels=2)
model_state_dict = load_file(f'{model_dir}BERTLSTM.safetensors')
model.load_state_dict(model_state_dict)

# Move the model to the chosen device and switch it to evaluation mode.
model.to(device)
model.eval()

# prediction
def predict(text):
    """Classify one piece of text with the module-level model.

    The text is tokenized with truncation and padding to exactly 256 tokens,
    run through ``model`` without gradient tracking, and the logits are turned
    into probabilities with a softmax over dimension 1.

    Args:
        text: The input passed to the tokenizer. Because the predicted index
            is extracted with ``.item()``, the batch must hold one sequence.

    Returns:
        A ``(predicted_class, probabilities)`` tuple: the index of the largest
        probability as a Python number, and the probabilities as a NumPy array
        on the CPU.
    """
    batch = tokenizer(
        text,
        return_tensors='pt',
        padding='max_length',
        max_length=256,
        truncation=True,
    )
    ids = batch['input_ids'].to(device)
    mask = batch['attention_mask'].to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        probabilities = F.softmax(model(input_ids=ids, attention_mask=mask), dim=1)

    predicted_class = probabilities.argmax(dim=1).item()
    return predicted_class, probabilities.cpu().numpy()

# Example: classify a single sentence and show the result.
text = "The product did not meet my expectations."
predicted_class, probabilities = predict(text)
# One print call; sep="\n" puts each message on its own line.
print(
    f"Predicted class: {predicted_class}",
    f"Probabilities: {probabilities}",
    sep="\n",
)







Predicted class: 0
Probabilities: [[0.99386686 0.00613318]]
