# Universal Information Extraction

In [None]:
import json
import os
import re
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BatchEncoding, PreTrainedTokenizerBase

## Configuration

In [None]:
# Hugging Face Hub checkpoint for the English UIE-large (T5-based seq2seq) model.
MODEL_NAME: str = "luyaojie/uie-large-en"

Define special tokens based on UIE's T5 usage

In [None]:
# T5 sentinel tokens used as structure markers in the generated SEL output:
#   START  opens a record (entity) or a nested association (relation),
#   END    closes it,
#   TARGET separates a label (entity type / relation name) from its text span.
START, END, TARGET = '<extra_id_0>', '<extra_id_1>', '<extra_id_5>'

#### Define label mappers for each dataset

In [None]:
# Per-dataset mapping from raw annotation labels to the natural-language names that are
# placed in the UIE schema prompt and used to normalize ground-truth labels.
# Several raw labels may share one name (e.g. chemprot GENE-N / GENE-Y).
BIO_LABEL_MAPPERS = {
    "biored": {
        "entities": {
            "GeneOrGeneProduct": "gene or gene product",
            "DiseaseOrPhenotypicFeature": "disease or phenotypic feature",
            "ChemicalEntity": "chemical or drug",
            "SequenceVariant": "genomic or protein variant",
            "OrganismTaxon": "species",
            "CellLine": "cell line",
        },
        "relations": {
            "Association": "is associated with",
            "Positive_Correlation": "positively correlates with",
            "Negative_Correlation": "negatively correlates with",
            "Bind": "binds to",
            "Cotreatment": "is cotreated with",
            "Comparison": "is compared to",
            "Conversion": "converts to",
            "Drug_Interaction": "interacts with drug",
        }
    },
    "ddi": {
        "entities": {
            "DRUG": "drug",
            "GROUP": "groups of drugs",
            "BRAND": "drug brand",
            "DRUG_N": "unapproved drug"
        },
        "relations": {
            "MECHANISM": "has mechanism",
            "EFFECT": "has effect",
            "ADVISE": "is advised against",
            "INT": "interacts with",
        }
    },
    "chemprot": {
        "entities": {
            "CHEMICAL": "chemical",
            "GENE-N": "gene or protein",
            "GENE-Y": "gene or protein",
        },
        "relations": {
            "Agonist": "is agonist of",
            "Antagonist": "is antagonist of",
            "Cofactor": "is cofactor of",
            "Downregulator": "downregulates",
            "Modulator": "modulates",
            "Not": "not related to",
            "Part_of": "is part of",
            "Regulator": "regulates",
            "Substrate": "is substrate of",
            # Spelling matters: this text is given to the model verbatim as a relation name.
            "Undefined": "unknown relation to",
            "Upregulator": "upregulates",
        }
    },
    "bc5cdr": {
        "entities": {
            "Chemical": "chemical",
            "Disease": "disease",
        },
        "relations": {
            "CID": "causes or induces",
        }
    }
}

## Helper Functions

### General

In [None]:
def load_jsonl(file_path):
    """Load a JSON Lines file into a list of dictionaries (one per non-blank line).

    Whitespace-only lines (e.g. a trailing empty line) are skipped rather than
    raising ``json.JSONDecodeError``.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

### UIE

In [None]:
def build_ssi(entity_types: List[str], relation_types: List[str]) -> str:
    """
    Assemble the Structured Schema Instructor (SSI) prompt prefix.

    Every entity type is introduced by ``<spot>``, every relation type by ``<asoc>``,
    and the string ends with ``<extra_id_2>`` (the input text is appended after it).
    """
    pieces = [
        "<spot> " + "<spot> ".join(entity_types),
        " <asoc> " + "<asoc> ".join(relation_types),
        " <extra_id_2>",
    ]
    return "".join(pieces)

def _tokenize_sel(text: str, markers: Tuple[str, ...]) -> List[str]:
    """Split SEL text into marker tokens and whitespace-stripped text chunks (empty chunks dropped)."""
    # A capturing group makes re.split keep the markers themselves as tokens.
    pattern = "(" + "|".join(re.escape(marker) for marker in markers) + ")"
    return [chunk.strip() for chunk in re.split(pattern, text) if chunk.strip()]


def _split_groups(tokens: List[str], start: str, end: str) -> Tuple[List[str], List[List[str]]]:
    """Separate depth-0 tokens from the contents of each top-level ``start ... end`` group.

    Nested groups stay (markers included) inside their parent's token list. A stray ``end``
    at depth 0 is ignored and an unterminated trailing group is dropped, so malformed
    generations degrade gracefully instead of looping or raising.
    """
    loose: List[str] = []
    groups: List[List[str]] = []
    current: List[str] = []
    depth = 0
    for tok in tokens:
        if tok == start:
            depth += 1
            if depth == 1:
                current = []
                continue
        elif tok == end:
            if depth == 0:
                continue
            depth -= 1
            if depth == 0:
                groups.append(current)
                continue
        (current if depth else loose).append(tok)
    return loose, groups


def _split_label_span(tokens: List[str], target: str) -> Optional[Tuple[str, str]]:
    """Split ``label <target> span`` tokens into ``(label, span)``; None if malformed or either side is empty."""
    if target not in tokens:
        return None
    sep = tokens.index(target)
    label = " ".join(tokens[:sep]).strip()
    span = " ".join(tokens[sep + 1:]).strip()
    if not label or not span:
        return None
    return label, span


def parse_sel_string(
    sel_string: str,
    start_marker: Optional[str] = None,
    end_marker: Optional[str] = None,
    target_marker: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """
    Parse a cleaned generated SEL string into structured records.

    Expected shape (whitespace between tokens is irrelevant)::

        START  START type TARGET span  [START rel TARGET obj END]*  END  ...  END

    Args:
        sel_string: Decoded model output with padding/EOS already removed.
        start_marker, end_marker, target_marker: Structure markers; default to the
            module-level ``START``, ``END`` and ``TARGET``.

    Returns:
        One dict per well-formed record: ``{'span': str, 'spot': str, 'asoc': [(rel_type, obj_span), ...]}``.
        Returns ``[]`` when the string is not wrapped in START/END. Malformed records or
        associations (missing the target marker, empty label/span) are skipped.
    """
    start = START if start_marker is None else start_marker
    end = END if end_marker is None else end_marker
    target = TARGET if target_marker is None else target_marker

    structured_output: List[Dict[str, Any]] = []

    # The entire SEL must be wrapped in one outer START/END pair.
    sel_string = sel_string.strip()
    if not (sel_string.startswith(start) and sel_string.endswith(end)):
        return structured_output
    body = sel_string[len(start):-len(end)]

    tokens = _tokenize_sel(body, (start, end, target))
    _, records = _split_groups(tokens, start, end)

    for record in records:
        # Depth-0 tokens form "type TARGET span"; nested groups are the associations.
        head, assoc_groups = _split_groups(record, start, end)
        head_parts = _split_label_span(head, target)
        if head_parts is None:
            continue
        entity_type, subj_span = head_parts

        relations = []
        for group in assoc_groups:
            rel_parts = _split_label_span(group, target)
            if rel_parts is not None:
                relations.append(rel_parts)

        structured_output.append({'span': subj_span, 'spot': entity_type, 'asoc': relations})

    return structured_output

### Inference

In [None]:
def _collate_uie_batch(items: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate UIEDataset items: stack the tensor fields and keep raw records as a plain list.

    PyTorch's default collate would try to merge the ``original_data`` dicts field by field,
    which fails when examples carry different numbers of entities/relations.
    """
    return {
        "input_ids": torch.stack([item["input_ids"] for item in items]),
        "attention_mask": torch.stack([item["attention_mask"] for item in items]),
        "original_data": [item["original_data"] for item in items],
    }


def get_predictions(dataset_name: str, test_file: str, output_path: str) -> Optional[int]:
    """Run zero-shot UIE inference on one dataset and write per-example predictions as JSONL.

    Uses the notebook globals ``model``, ``tokenizer``, ``device``, ``BATCH_SIZE``,
    ``MAX_SOURCE_LENGTH`` and ``MAX_TARGET_LENGTH``.

    Args:
        dataset_name: Key into ``BIO_LABEL_MAPPERS``; also names the output file.
        test_file: JSONL test split whose records carry ``text``, ``entities`` and ``relations``.
        output_path: An existing directory (output goes to ``<output_path>/<dataset_name>.jsonl``)
            or the path of the output file itself.

    Returns:
        ``-1`` if the test file cannot be loaded or no label mapper exists, otherwise ``None``.
    """
    try:
        test_data = load_jsonl(test_file)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error loading {test_file}: {e}.")
        return -1
    print(f"Loaded {len(test_data)} examples.")

    # Get mapper and build SSI
    mapper = BIO_LABEL_MAPPERS.get(dataset_name)
    if not mapper:
        print(f"Label mapper not found for {dataset_name}. Skipping.")
        return -1

    entity_map = mapper['entities']
    relation_map = mapper['relations']
    # Distinct raw labels may share one verbalized name (e.g. chemprot GENE-N / GENE-Y);
    # deduplicate so the schema lists each name once.
    entity_types_sorted = sorted(set(entity_map.values()))
    relation_types_sorted = sorted(set(relation_map.values()))
    ssi_string = build_ssi(entity_types_sorted, relation_types_sorted)
    print("Generated SSI string:", ssi_string)

    inference_dataset = UIEDataset(test_data, tokenizer, ssi_string, MAX_SOURCE_LENGTH)
    dataloader = DataLoader(inference_dataset, batch_size=BATCH_SIZE, collate_fn=_collate_uie_batch)

    if os.path.isdir(output_path):
        output_filepath = os.path.join(output_path, f"{dataset_name}.jsonl")
    else:
        output_filepath = output_path
        parent_dir = os.path.dirname(output_filepath)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Run inference and write predictions
    with torch.no_grad(), open(output_filepath, 'w', encoding='utf-8') as outfile:
        for batch in tqdm(dataloader, desc=f"Inferring & Saving {dataset_name}"):
            generated_ids = model.generate(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                max_length=MAX_TARGET_LENGTH,
            )

            # The SEL markers are T5 sentinel (special) tokens, so decode with them kept
            # and strip only padding / end-of-sequence tokens.
            sel_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
            cleaned_sels = [
                s.replace(tokenizer.pad_token, "").replace(tokenizer.eos_token, "").strip()
                for s in sel_outputs
            ]

            for item, sel in zip(batch['original_data'], cleaned_sels):
                pred_ents, pred_rels = get_predicted_sets(parse_sel_string(sel))
                gt_ents, gt_rels = get_ground_truth_sets(item, entity_map, relation_map)
                output_record = {
                    "text": item['text'],
                    "sel_output": sel,
                    "predicted_entities": sorted(list(e) for e in pred_ents),
                    "predicted_relations": sorted(list(r) for r in pred_rels),
                    "ground_truth_entities": sorted(list(e) for e in gt_ents),
                    "ground_truth_relations": sorted(list(r) for r in gt_rels),
                }
                outfile.write(json.dumps(output_record) + "\n")

    print(f"Finished processing and saved predictions for {dataset_name} to {output_filepath}")
    return None

### Evaluation

In [None]:
def get_ground_truth_sets(data_sample: Dict, entity_map: Dict, relation_map: Dict) -> Tuple[Set[Tuple[str, str]], Set[Tuple[str, str, str]]]:
    """
    Collect a sample's gold annotations as comparable sets.

    Entities become ``(span_text, label)`` pairs and relations become
    ``(head_text, label, tail_text)`` triples. Labels are translated through the
    given maps, falling back to the raw label when it has no mapping. Missing
    ``entities`` / ``relations`` keys are treated as empty.
    """
    gt_entities = {
        (ent['text'], entity_map.get(ent['type'], ent['type']))
        for ent in data_sample.get('entities', [])
    }

    gt_relations = set()
    for rel in data_sample.get('relations', []):
        head, tail = rel['head']['text'], rel['tail']['text']
        gt_relations.add((head, relation_map.get(rel['type'], rel['type']), tail))

    return gt_entities, gt_relations

def get_predicted_sets(parsed_sel: List[Dict['str', any]]) -> Tuple[Set[Tuple[str, str]], Set[Tuple[str, str, str]]]:
    """
    Flatten parsed SEL records into comparable sets.

    Each record contributes an entity ``(span, spot)`` and, for every
    ``(rel_type, obj_span)`` in its ``asoc`` list, a relation triple
    ``(span, rel_type, obj_span)``.
    """
    entities, relations = set(), set()
    for rec in parsed_sel:
        span, spot = rec['span'], rec['spot']
        entities.add((span, spot))
        relations.update((span, rel, obj) for rel, obj in rec['asoc'])
    return entities, relations

def _precision_recall_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Return (precision, recall, F1) from micro counts; each is 0 when its denominator is 0."""
    predicted = tp + fp
    actual = tp + fn
    precision = tp / predicted if predicted > 0 else 0
    recall = tp / actual if actual > 0 else 0
    denom = precision + recall
    f1 = 2 * (precision * recall) / denom if denom > 0 else 0
    return precision, recall, f1


def calculate_metrics(preds: List[Tuple[Set, Set]], golds: List[Tuple[Set, Set]]) -> Dict:
    """
    Compute micro-averaged precision/recall/F1 for entities and relations.

    ``preds`` and ``golds`` are aligned per example as ``(entity_set, relation_set)``
    pairs; true/false positives and false negatives are summed over all examples
    before the scores are computed.
    """
    ent_tp = ent_fp = ent_fn = 0
    rel_tp = rel_fp = rel_fn = 0

    for (pred_ents, pred_rels), (gold_ents, gold_rels) in zip(preds, golds):
        ent_tp += len(pred_ents.intersection(gold_ents))
        ent_fp += len(pred_ents.difference(gold_ents))
        ent_fn += len(gold_ents.difference(pred_ents))

        rel_tp += len(pred_rels.intersection(gold_rels))
        rel_fp += len(pred_rels.difference(gold_rels))
        rel_fn += len(gold_rels.difference(pred_rels))

    ent_p, ent_r, ent_f = _precision_recall_f1(ent_tp, ent_fp, ent_fn)
    rel_p, rel_r, rel_f = _precision_recall_f1(rel_tp, rel_fp, rel_fn)

    return {
        "entity_precision": ent_p,
        "entity_recall": ent_r,
        "entity_f1": ent_f,
        "relation_precision": rel_p,
        "relation_recall": rel_r,
        "relation_f1": rel_f,
    }

def _tuple_set(record: Dict[str, Any], *keys: str) -> Set[tuple]:
    """Rebuild a set of tuples from the first of *keys* present in *record* (JSON stores tuples as lists)."""
    for key in keys:
        if key in record:
            return {tuple(item) for item in record[key]}
    raise KeyError(keys[0])


def calculate_metrics_from_file(file_path: str) -> Dict[str, float]:
    """
    Read one predictions JSONL file and compute entity/relation P/R/F1 over all of its lines.

    Each line must hold predicted and ground-truth entity/relation lists. The key names
    written by get_predictions ("predicted_entities", "ground_truth_entities", ...) are
    accepted, as are the short names ("pred_entities", "gt_entities", ...).
    Blank lines are skipped; malformed lines (bad JSON or missing fields) are reported
    and skipped.

    Returns:
        The dict produced by calculate_metrics, or {} if the file cannot be read.
    """
    all_preds_sets = []
    all_labels_sets = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                    # Convert saved lists back to sets of tuples
                    pred_entities = _tuple_set(data, 'predicted_entities', 'pred_entities')
                    pred_relations = _tuple_set(data, 'predicted_relations', 'pred_relations')
                    gt_entities = _tuple_set(data, 'ground_truth_entities', 'gt_entities')
                    gt_relations = _tuple_set(data, 'ground_truth_relations', 'gt_relations')
                except (json.JSONDecodeError, KeyError, TypeError) as e:
                    print(f"Skipping malformed line {line_no} in {file_path}: {e!r}")
                    continue
                all_preds_sets.append((pred_entities, pred_relations))
                all_labels_sets.append((gt_entities, gt_relations))
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading or processing {file_path}: {e}")
        return {}

    return calculate_metrics(all_preds_sets, all_labels_sets)

## Zero-Shot Testing

Dataset Class for Batching

In [None]:
class UIEDataset(Dataset):
    """Map-style dataset pairing each tokenized "SSI + text" prompt with its source record.

    ``__getitem__`` returns a dict with:
      * ``input_ids`` / ``attention_mask``: the tokenizer output padded/truncated to
        ``max_source_length`` with the leading batch dimension removed;
      * ``original_data``: the unmodified input record, kept for ground-truth comparison.
    """

    def __init__(self, data: List[Dict[str, any]], tokenizer: AutoTokenizer, ssi_string: str, max_source_length: int = 512):
        # Stored as-is; tokenization happens lazily, one item at a time.
        self.data = data
        self.tokenizer = tokenizer
        self.ssi_string = ssi_string
        self.max_source_length = max_source_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        prompt = f"{self.ssi_string} {record['text']}"
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=self.max_source_length,
        )
        # Tokenizing a single string yields a batch of one; drop that dimension.
        input_ids = encoded.input_ids.squeeze(0)
        attention_mask = encoded.attention_mask.squeeze(0)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "original_data": record,
        }

Setting up model and device

In [None]:
# Prefer the 4th GPU (cuda:3) as before, but fall back to the first GPU or the CPU on
# machines that have fewer GPUs, instead of failing at the first `.to(device)`.
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > 3:
    device = torch.device("cuda:3")
else:
    device = torch.device("cuda:0")
print("Using device:", device)

In [None]:
# Load the tokenizer and the seq2seq model for the checkpoint named by MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

Add the `<spot>` and `<asoc>` schema markers to the tokenizer if they are missing

In [None]:
# The SSI schema prompt (see build_ssi) uses "<spot>" and "<asoc>" markers; register any
# that the tokenizer does not already list among its special tokens.
# NOTE(review): if this branch ever adds tokens, the new embedding rows created by
# resize_token_embeddings are not learned UIE weights, so zero-shot output would be
# unreliable. Confirm the checkpoint already ships these markers.
special_tokens = tokenizer.all_special_tokens
needed_markers = ["<spot>", "<asoc>"]
tokens_to_add = list(filter(lambda marker: marker not in special_tokens, needed_markers))

if not tokens_to_add:
    print("Special tokens already present in the tokenizer.")
else:
    print(f"Adding special tokens: {tokens_to_add}")
    tokenizer.add_special_tokens({'additional_special_tokens': tokens_to_add})
    model.resize_token_embeddings(len(tokenizer))
    print("Resized model embeddings.")

Move the model to the selected device and set it to evaluation mode

In [None]:
# Move the weights to the selected device and put the model in evaluation mode.
# nn.Module.to() moves the module in place and returns it, so chaining is equivalent.
model.to(device).eval()

Dataset Configuration

In [None]:
BATCH_SIZE = 16
# Max tokenized length of "SSI schema + sentence"; matches UIEDataset's default.
MAX_SOURCE_LENGTH = 512
# Upper bound passed to model.generate(max_length=...) for the SEL output.
# NOTE(review): assumed value; raise it if long SEL outputs come back truncated.
MAX_TARGET_LENGTH = 256
# Directory holding the per-dataset prediction files that are scored at the end.
PREDICTIONS_DIR = "predictions"
os.makedirs(PREDICTIONS_DIR, exist_ok=True)

## Inference

### ChemProt

Load dataset

Get metrics

In [None]:
# PREDICTIONS_DIR holds one "<dataset>.jsonl" file per dataset, while
# calculate_metrics_from_file scores a single file, so score each file separately.
final_metrics = {}
if os.path.isdir(PREDICTIONS_DIR):
    for file_name in sorted(os.listdir(PREDICTIONS_DIR)):
        if not file_name.endswith(".jsonl"):
            continue
        dataset_metrics = calculate_metrics_from_file(os.path.join(PREDICTIONS_DIR, file_name))
        if dataset_metrics:
            final_metrics[os.path.splitext(file_name)[0]] = dataset_metrics
else:
    print(f"Predictions directory not found: {PREDICTIONS_DIR}")

print("\n=== Final Zero-Shot Results Summary ===")
if final_metrics:
    print(json.dumps(final_metrics, indent=2))
else:
    print("No metrics were calculated.")
print("=======================================")