# Universal Information Extraction Benchmark on DDI Dataset

In [2]:
# Enable IPython's autoreload extension in mode 2 so that imported modules
# are re-imported before each cell executes; edits to project code are then
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import sys
import torch

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from pprint import pprint

In [4]:
# Put the directory two levels up on the import path so the `src` package
# below can be resolved.
# NOTE(review): the path is relative, so this presumably assumes the kernel's
# working directory is the notebook's own folder — confirm before running
# from elsewhere.
sys.path.append("../../")

# Project helpers: JSONL I/O, metric computation, batch collation, and the
# UIE-specific dataset wrapper and inference loop.
from src.common.utils import load_jsonl, save_predictions
from src.common.metrics import get_metrics
from src.common.data import generative_collate_fn
from src.models.uie.data import UIEDataset
from src.models.uie.inference import run_inference

Initialize model and tokenizer

In [5]:
# Load the pretrained UIE checkpoint and its matching tokenizer.
tokenizer = AutoTokenizer.from_pretrained("luyaojie/uie-large-en")
model = AutoModelForSeq2SeqLM.from_pretrained("luyaojie/uie-large-en")

# The original run pinned GPU index 5. Keep that choice when the machine
# actually has such a device, but fall back to the first GPU or to the CPU so
# the notebook still runs under Restart & Run All on other hardware.
preferred_gpu_index = 5
if torch.cuda.device_count() > preferred_gpu_index:
    device = torch.device(f'cuda:{preferred_gpu_index}')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Last expression of the cell: moves the weights and displays the model repr.
model.to(device)

T5ForConditionalGeneration(
  (shared): Embedding(32102, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32102, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=1024, bias=False)
              (k): Linear(in_features=1024, out_features=1024, bias=False)
              (v): Linear(in_features=1024, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=1024, out_features=2816, bias=False)
              (wi_1): Linear(in_features=1024, out_features=2816, bias=False)
       

### Zero-shot Testing

Create dataset and dataloader

In [8]:
# Read the DDI test split and wrap it in the UIE dataset.
data = load_jsonl("../../data/ddi/test.jsonl")
test_dataset = UIEDataset(
    data=data,
    dataset_name='ddi',
    tokenizer=tokenizer,
)

# Batches of 32 in the original file order (no shuffling), collated with the
# project's generative collate function.
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=generative_collate_fn,
)

In [15]:
# Show the model input text for the second test example (index 1).
sample_with_text = test_dataset[1]
pprint(sample_with_text['text'])

('<spot> drug <spot> groups of drugs <spot> drug brand <spot> unapproved drug '
 '<asoc> has mechanism <asoc> has effect <asoc> is advised against <asoc> '
 'interacts with <extra_id_2> Tuberculosis and HIV co-infection: screening and '
 'treatment strategies. Globally, tuberculosis (TB) and HIV interact in deadly '
 'synergy. The high burden of TB among HIV-infected individuals underlies the '
 'importance of TB diagnosis, treatment and prevention for clinicians involved '
 'in HIV care. Despite expanding access to antiretroviral therapy (ART) to '
 'treat HIV infection in resource-limited settings, many individuals in need '
 'of therapy initiate ART too late and have already developed clinically '
 'significant TB by the time they present for care. Many co-infected '
 'individuals are in need of concurrent ART and anti-TB therapy, which '
 'dramatically improves survival, but also raises several management '
 'challenges, including drug interactions, shared drug toxicities and TB '


In [16]:
# Show the target label string for the second test example (index 1).
sample_with_label = test_dataset[1]
pprint(sample_with_label['label'])

('<extra_id_0> <extra_id_0> groups of drugs <extra_id_5> antiretroviral '
 '<extra_id_1> <extra_id_1>')


In [10]:
# Generate predictions for every batch in the test loader.
predictions = run_inference(
    model=model,
    dataloader=test_loader,
    tokenizer=tokenizer,
)

# Persist the predictions under the zero-shot output directory.
zero_shot_dir = "../../predictions/uie/zero-shot"
zero_shot_file = "ddi.jsonl"
save_predictions(predictions, zero_shot_dir, zero_shot_file)

Running inference:   0%|          | 0/10 [00:00<?, ?it/s]

Predictions saved to ../../predictions/uie/zero-shot/ddi.jsonl


In [14]:
# Score the saved zero-shot predictions. The path is a plain string literal:
# the previous f-prefix had no placeholders and was dropped.
# NOTE(review): the recorded run printed 0.0000 for both scores — worth
# confirming that the prediction format matches what get_metrics expects.
metrics = get_metrics("../../predictions/uie/zero-shot/ddi.jsonl")

print(f"\nEntity F1: {metrics['entity_f1']:.4f}")
print(f"Relation F1: {metrics['relation_f1']:.4f}")


Entity F1: 0.0000
Relation F1: 0.0000
