# Universal Information Extraction

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import torch

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [None]:
# Add "../" to the module search path before the imports below.
# NOTE(review): presumably the `common` and `models` packages live one
# directory above this notebook; the entry is relative to the kernel's working
# directory, so confirm the notebook is launched from its own folder.
sys.path.append("../")

from common.utils import load_jsonl, save_predictions
from common.metrics import get_metrics
from models.uie.data import UIEDataset
from models.uie.inference import run_inference

Initialize model and tokenizer

In [None]:
# Load the pretrained UIE checkpoint and the tokenizer that goes with it.
tokenizer = AutoTokenizer.from_pretrained("luyaojie/uie-large-en")
model = AutoModelForSeq2SeqLM.from_pretrained("luyaojie/uie-large-en")

# Keep using GPU index 3 when the machine has it, but do not crash on hosts
# with fewer GPUs (or none): fall back to the default CUDA device, then CPU.
if torch.cuda.is_available() and torch.cuda.device_count() > 3:
    device = torch.device('cuda:3')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)

## ChemProt

### Zero-shot Testing

Create dataset and dataloader

In [None]:
# Read the ChemProt test records from the JSONL file.
data = load_jsonl("data/chemprot/test.jsonl")

# Wrap the records in the project's UIE dataset, using the tokenizer loaded above.
cp_test_dataset = UIEDataset(
    data=data,
    dataset_name='chemprot',
    tokenizer=tokenizer,
)

# Batches of 32 with shuffling disabled, so examples keep the dataset's order.
cp_test_loader = DataLoader(
    cp_test_dataset,
    batch_size=32,
    shuffle=False,
)

In [None]:
# Produce predictions and write them out for the metrics cell.
# NOTE(review): run_inference() is called with no arguments, so the model,
# tokenizer, device and cp_test_loader built in the earlier cells are not
# passed in explicitly — confirm this matches run_inference's signature in
# models/uie/inference.py; it most likely needs at least the model and loader.
predictions = run_inference()
# Presumably writes prediction/uie/zero-shot/chemprot.jsonl, the path the
# metrics cell reads — verify against save_predictions.
save_predictions(predictions, "prediction/uie/zero-shot", "chemprot.jsonl")

In [None]:
# Compute metrics from the saved zero-shot ChemProt prediction file.
# (Plain string literal: the path has no placeholders, so no f-prefix is needed.)
metrics = get_metrics("prediction/uie/zero-shot/chemprot.jsonl")

# Report entity-level and relation-level F1 to four decimal places.
print(f"\nEntity F1: {metrics['entity_f1']:.4f}")
print(f"Relation F1: {metrics['relation_f1']:.4f}")