# In [None]:
from model import JointModel
from transformers import XLMRobertaTokenizerFast
import torch, pickle, os

import torch

def predict_single_instance(model, tokenizer, sentence_tokens, dataset_encoders, max_length=128):
    """Run the joint model on one pre-tokenized sentence and decode its labels.

    The words are re-tokenized into subwords by ``tokenizer``. Each word takes
    the trigger/argument prediction made at its FIRST subword, so the returned
    tag lists line up one-to-one with ``sentence_tokens``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns ``(trigger_logits, arg_logits, event_logits)``. The last
            dimension is argmaxed and batch element 0 is used.
        tokenizer: Fast Hugging Face tokenizer (must support ``word_ids()``).
        sentence_tokens: List of words forming the sentence.
        dataset_encoders: Mapping with keys ``'trigger'``, ``'arg'`` and
            ``'event'``. Each value must provide ``inverse_transform`` to map
            class ids back to labels.
        max_length: Subword budget passed to the tokenizer. Words that fall
            past it are truncated away.

    Returns:
        dict with ``'tokens'`` (the input words), ``'trigger_tags'`` and
        ``'argument_tags'`` (one label per input word), and ``'event_type'``
        (one sentence-level label). Words lost to truncation are tagged
        ``'O'``, and a warning is printed.
    """
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    encoding = tokenizer(sentence_tokens, is_split_into_words=True, return_tensors='pt',
                         padding='max_length', truncation=True, max_length=max_length)
    word_ids = encoding.word_ids()
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        trigger_logits, arg_logits, event_logits = model(input_ids, attention_mask)
        trigger_preds = torch.argmax(trigger_logits, dim=-1).cpu().numpy()[0]
        arg_preds = torch.argmax(arg_logits, dim=-1).cpu().numpy()[0]
        event_pred = torch.argmax(event_logits, dim=-1).cpu().item()

    # Map each word to the position of its first subword. Special tokens and
    # padding have word id None. Continuation subwords repeat the previous word
    # id and must not add extra labels, otherwise every later tag shifts onto
    # the wrong word.
    first_subword = {}
    for position, word_idx in enumerate(word_ids):
        if word_idx is None or word_idx >= len(sentence_tokens):
            continue
        if word_idx not in first_subword:
            first_subword[word_idx] = position
    positions = [pos for _, pos in sorted(first_subword.items())]

    def decode(key, preds):
        # One batched inverse_transform call instead of one call per token.
        if not positions:
            return []
        return list(dataset_encoders[key].inverse_transform(preds[positions]))

    trigger_labels = decode('trigger', trigger_preds)
    arg_labels = decode('arg', arg_preds)

    # Words beyond max_length never reach the model. Pad them with 'O' so the
    # tag lists always have one entry per input word.
    missing = len(sentence_tokens) - len(positions)
    if missing > 0:
        print(f"Warning: Sentence truncated. Predicted tags length ({len(positions)}) is less than original tokens length ({len(sentence_tokens)}); padding remaining tags with 'O'")
        trigger_labels += ['O'] * missing
        arg_labels += ['O'] * missing

    event_label = dataset_encoders['event'].inverse_transform([event_pred])[0]

    return {
        'tokens': sentence_tokens,
        'trigger_tags': trigger_labels,
        'argument_tags': arg_labels,
        'event_type': event_label
    }

model_path = "best_model.pt"
encoder_dir = "oneie_encoders"

# Size the classifier heads from the checkpoint itself instead of hard-coding
# class counts. A mismatch between the constructor arguments and the trained
# weights is what makes load_state_dict raise "size mismatch".
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
num_trigger_labels = state_dict['trigger_classifier.4.weight'].shape[0]
num_arg_labels = state_dict['arg_classifier.4.weight'].shape[0]
num_event_labels = state_dict['event_classifier.4.weight'].shape[0]

# NOTE(review): assumes JointModel(encoder_name, n_trigger, n_arg, n_event)
# positional order. Confirm against model.py if a size mismatch persists.
model = JointModel('xlm-roberta-base', num_trigger_labels, num_arg_labels, num_event_labels)
model.load_state_dict(state_dict)
model.eval()

# pickle.load can execute arbitrary code: only load encoder files produced by
# your own training run.
encoders = {}
for key, filename in (('trigger', "trigger_labels.pkl"),
                      ('arg', "arg_labels.pkl"),
                      ('event', "event_labels.pkl")):
    with open(os.path.join(encoder_dir, filename), "rb") as f:
        encoders[key] = pickle.load(f)
trigger_encoder = encoders['trigger']
arg_encoder = encoders['arg']
event_encoder = encoders['event']

# Each encoder must decode exactly as many classes as its checkpoint head emits.
for key, expected in (('trigger', num_trigger_labels),
                      ('arg', num_arg_labels),
                      ('event', num_event_labels)):
    classes = getattr(encoders[key], 'classes_', None)
    if classes is not None and len(classes) != expected:
        print(f"Warning: {key} encoder has {len(classes)} labels but the checkpoint head has {expected} outputs")

tokenizer = XLMRobertaTokenizerFast.from_pretrained('xlm-roberta-base')

sentence = "Tiếp tục chương trình Kỳ họp thứ 9, chiều 20/6, Quốc hội họp phiên toàn thể tại Hội trường..."
tokens = sentence.split()

result = predict_single_instance(model, tokenizer, tokens, encoders)

print("\n===== PREDICTION =====")
for token, trigger_tag, arg_tag in zip(result['tokens'], result['trigger_tags'], result['argument_tags']):
    print(f"{token:<10} Trigger: {trigger_tag:<15} Argument: {arg_tag}")
print(f"➡️ Event Type: {result['event_type']}")


# Recorded output of an earlier run of the cell above:
# RuntimeError: Error(s) in loading state_dict for JointModel:
#     size mismatch for trigger_classifier.4.weight: copying a param with shape torch.Size([21, 384]) from checkpoint, the shape in current model is torch.Size([20, 384]).
#     size mismatch for trigger_classifier.4.bias: copying a param with shape torch.Size([21]) from checkpoint, the shape in current model is torch.Size([20]).
#     size mismatch for arg_classifier.4.weight: copying a param with shape torch.Size([8, 384]) from checkpoint, the shape in current model is torch.Size([20, 384]).
#     size mismatch for arg_classifier.4.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([20]).
#     size mismatch for event_classifier.4.weight: copying a param with shape torch.Size([13, 384]) from checkpoint, the shape in current model is torch.Size([10, 384]).
#     size mismatch for event_classifier.4.bias: copying a param with shape torch.Size([13]) from checkpoint, the shape in current model is torch.Size([10]).
# In other words, the checkpoint was trained with 21 trigger, 8 argument and 13 event
# classes, but the model it was loaded into had 20, 20 and 10 output units.