# Universal Information Extraction

In [None]:
import json
import os
import re
from typing import Dict, List, Set, Tuple

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

## Configuration

In [None]:
# Checkpoint identifier passed to AutoTokenizer / AutoModelForSeq2SeqLM.from_pretrained
MODEL_NAME: str = "luyaojie/uie-large-en"

Define the SEL marker tokens, reusing T5 sentinel tokens the way UIE does

In [None]:
# T5 sentinel tokens reused as the bracket/separator markers of the generated
# Structured Extraction Language (SEL) string.
RECORD_START = "<extra_id_0>"  # opens the SEL output and each record
RECORD_END = "<extra_id_1>"    # closes the SEL output and each record
# Spans share the record markers: the same pair also brackets a type/label
# name and its associations.
SPAN_START = RECORD_START
SPAN_END = RECORD_END
# Separates a type/label name from the span text that follows it.
TEXT_START = "<extra_id_5>"

Define label mappers from each dataset's raw annotation labels to natural-language type names

In [None]:
# Per-dataset mapping from raw annotation labels to the natural-language names
# that are placed in the SSI prompt and used as canonical types when scoring.
# Raw labels missing from a mapper are ignored when ground truth is extracted
# (get_ground_truth_sets only keeps types found in these dicts).
BIO_LABEL_MAPPERS = {
    "biored": {
        "entities": {
            "GeneOrGeneProduct": "gene or gene product",
            "DiseaseOrPhenotypicFeature": "disease or phenotypic feature",
            "ChemicalEntity": "chemical entity",
            "SequenceVariant": "sequence variant",
            # NOTE(review): BioRED may label organisms "OrganismTaxon" rather than
            # "Species" — confirm against the converted JSONL files.
            "Species": "species",
            "CellLine": "cell line",
        },
        "relations": {
            "Association": "is associated with",
            "Positive_Correlation": "positively correlates with",
            "Negative_Correlation": "negatively correlates with",
            "Bind": "binds to",
            "Cotreatment": "is cotreated with",
            "Comparison": "is compared to",
            "Conversion": "converts to",
            "Drug_Interaction": "interacts with drug",
        }
    },
    "ddi": {
        # NOTE(review): only DRUG and GROUP are mapped; any other raw entity types
        # in the data (if present) are dropped from scoring — confirm intended.
        "entities": {
            "DRUG": "drug",
            "GROUP": "group of drugs",
        },
        "relations": {
            "MECHANISM": "has mechanism",
            "EFFECT": "has effect",
            "ADVISE": "is advised against",
            "INT": "interacts with",
        }
    },
    "chemprot": {
         "entities": {
            "CHEMICAL": "chemical",
            # These three raw labels share one name, so the value list contains
            # duplicates and should be de-duplicated before building the prompt.
            "GENE": "gene or protein",
            "GENE-Y": "gene or protein",
            "GENE-N": "gene or protein",
        },
        "relations": {
            # Only include relations evaluated in BioCreative VI
            "CPR:3": "upregulates or activates",
            "CPR:4": "downregulates or inhibits",
            "CPR:5": "acts as agonist",
            "CPR:6": "acts as antagonist",
            "CPR:9": "is substrate or product of",
        }
    },
    "bc5cdr": {
        "entities": {
            "Chemical": "chemical",
            "Disease": "disease",
        },
        "relations": {
            "CID": "causes or induces",
        }
    }
}

## Helper Functions

In [None]:
def load_jsonl(file_path):
    """Load a JSON Lines file into a list of dictionaries.

    Blank (whitespace-only) lines, such as a trailing empty line, are skipped
    instead of raising ``json.JSONDecodeError``.

    Args:
        file_path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        One decoded object per non-blank line, in file order.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def build_ssi(entity_types: List[str], relation_types: List[str]) -> str:
    """
    Build the Structured Schema Instructor (SSI) prompt prefix.

    Every entity label is preceded by ``<spot>`` and every relation label by
    ``<asoc>``. All pieces are joined with single spaces and terminated by the
    ``<extra_id_2>`` text separator, e.g.
    ``<spot> a <spot> b <asoc> r <extra_id_2>``. Empty label lists contribute
    nothing (no dangling markers). Repeated labels are emitted once, keeping
    first-seen order.

    Args:
        entity_types: Descriptive entity type names (callers pass them sorted).
        relation_types: Descriptive relation type names (callers pass them sorted).

    Returns:
        The SSI string to prefix to the input text.
    """
    parts: List[str] = []
    # dict.fromkeys de-duplicates while preserving the caller's ordering.
    for label in dict.fromkeys(entity_types):
        parts += ["<spot>", label]
    for label in dict.fromkeys(relation_types):
        parts += ["<asoc>", label]
    parts.append("<extra_id_2>")
    return " ".join(parts)

def _sel_marker_pattern():
    """Compile a regex that splits an SEL string on (and captures) its markers."""
    markers = {RECORD_START, RECORD_END, SPAN_START, SPAN_END, TEXT_START}
    # Longest first so a marker that is a prefix of another cannot shadow it.
    alternation = "|".join(re.escape(m) for m in sorted(markers, key=len, reverse=True))
    return re.compile(f"({alternation})")


def _new_sel_node() -> Dict:
    """Return an empty SEL tree node."""
    # label: text before TEXT_START; text: text after it; sep: TEXT_START seen;
    # closed: a matching end marker was found; children: nested nodes.
    return {"label": [], "text": [], "sep": False, "closed": False, "children": []}


def _build_sel_tree(sel_string: str) -> Dict:
    """Parse an SEL string into a tree of bracketed nodes and return the root.

    Each start marker opens a child of the innermost open node; each end marker
    closes it. A stack is required because record and span markers are the same
    tokens, so brackets nest.
    """
    openers = {RECORD_START, SPAN_START}
    closers = {RECORD_END, SPAN_END}
    root = _new_sel_node()
    stack = [root]
    for piece in _sel_marker_pattern().split(sel_string):
        if piece in openers:
            child = _new_sel_node()
            stack[-1]["children"].append(child)
            stack.append(child)
        elif piece in closers:
            if len(stack) > 1:
                stack.pop()["closed"] = True
            else:
                print(f"Warning: Unbalanced closing marker in SEL string: {sel_string}")
        elif piece == TEXT_START:
            stack[-1]["sep"] = True
        else:
            text = piece.strip()
            if text:
                node = stack[-1]
                (node["text"] if node["sep"] else node["label"]).append(text)
    if len(stack) > 1:
        # Typically a generation cut off by max_length; unclosed nodes are skipped.
        print(f"Warning: Malformed SEL string (unclosed record, possibly truncated): {sel_string}")
    return root


def _iter_sel_items(nodes: List[Dict]):
    """Yield content nodes, lifting the children of pure grouping nodes.

    A node with no label and no TEXT_START only groups other nodes (the outer
    SEL wrapper, or an explicitly wrapped association list), so its children
    are yielded in its place. Empty groups yield nothing.
    """
    for node in nodes:
        if not node["label"] and not node["sep"]:
            yield from _iter_sel_items(node["children"])
        else:
            yield node


def parse_sel_string(sel_string: str) -> Dict[str, List[Tuple[str, List[Tuple[str, str]]]]]:
    """
    Parse a cleaned, generated SEL string into a structured mapping.

    Accepted record layout (optionally wrapped in an outer
    RECORD_START ... RECORD_END pair)::

        SPAN_START type TEXT_START span
            [SPAN_START relation TEXT_START object_span SPAN_END ...]
        SPAN_END

    An extra SPAN_START/SPAN_END pair wrapping the association list is also
    tolerated. Malformed records/associations are reported and skipped, and
    records left unclosed (e.g. truncated output) are dropped.

    Args:
        sel_string: Decoded model output with padding/EOS already removed.

    Returns:
        ``{entity_type: [(span, [(relation, object_span), ...]), ...], ...}``
        with exact duplicate (span, relation-set) entries removed.
    """
    structured_output: Dict[str, List[Tuple[str, List[Tuple[str, str]]]]] = {}
    root = _build_sel_tree(sel_string)

    if root["label"] or root["text"]:
        # Text outside any record markers: the output does not follow SEL.
        print(f"Warning: SEL string may be malformed or missing record markers: {sel_string}")

    for record in _iter_sel_items(root["children"]):
        if not record["closed"]:
            continue  # truncated record; already reported by the tree builder

        entity_type = " ".join(record["label"])
        entity_span = " ".join(record["text"])
        if not (record["sep"] and entity_type):
            print(f"Warning: Could not find entity type in record: label={record['label']} text={record['text']}")
            continue

        relations: List[Tuple[str, str]] = []
        for assoc in _iter_sel_items(record["children"]):
            rel_type = " ".join(assoc["label"])
            obj_span = " ".join(assoc["text"])
            if assoc["closed"] and assoc["sep"] and rel_type and obj_span:
                relations.append((rel_type, obj_span))
            else:
                print(f"Warning: Malformed association in '{entity_type}' record: label={assoc['label']} text={assoc['text']}")

        if not entity_span:
            continue

        entries = structured_output.setdefault(entity_type, [])
        # Skip exact repeats (same span and same set of relations).
        is_duplicate = any(
            existing_span == entity_span and set(existing_rels) == set(relations)
            for existing_span, existing_rels in entries
        )
        if not is_duplicate:
            entries.append((entity_span, relations))

    return structured_output

def get_ground_truth_sets(data_sample: Dict, entity_map: Dict, relation_map: Dict) -> Tuple[Set[Tuple[str, str]], Set[Tuple[str, str, str]]]:
    """Extract ground-truth entities and relations as sets for comparison.

    Only annotations whose raw ``type`` appears in the corresponding map are
    kept, and their types are replaced by the mapped (canonical) names.

    Args:
        data_sample: One example; reads ``entities`` (items with ``type`` and
            ``text``) and ``relations`` (items with ``type``, ``head.text``
            and ``tail.text``). Missing lists are treated as empty.
        entity_map: Raw entity type -> canonical entity type name.
        relation_map: Raw relation type -> canonical relation type name.

    Returns:
        ``({(span, entity_type)}, {(head_span, relation_type, tail_span)})``.
    """
    gt_entities = {
        (entity['text'], entity_map[entity['type']])
        for entity in data_sample.get('entities', [])
        if entity['type'] in entity_map
    }
    gt_relations = {
        (relation['head']['text'], relation_map[relation['type']], relation['tail']['text'])
        for relation in data_sample.get('relations', [])
        if relation['type'] in relation_map
    }
    return gt_entities, gt_relations


def get_predicted_sets(parsed_sel: Dict) -> Tuple[Set[Tuple[str, str]], Set[Tuple[str, str, str]]]:
    """Flatten a parsed SEL mapping into comparable prediction sets.

    Args:
        parsed_sel: ``{entity_type: [(span, [(relation, object_span), ...]), ...]}``.

    Returns:
        ``({(span, entity_type)}, {(span, relation, object_span)})``; the
        subject of each relation is the entity span it was attached to.
    """
    entity_pairs = {
        (span, etype)
        for etype, records in parsed_sel.items()
        for span, _ in records
    }
    relation_triples = {
        (span, rel, obj)
        for records in parsed_sel.values()
        for span, rels in records
        for rel, obj in rels
    }
    return entity_pairs, relation_triples

def _overlap_counts(predicted: Set, gold: Set) -> Tuple[int, int, int]:
    """Return (true positives, false positives, false negatives) for two sets."""
    return (
        len(predicted.intersection(gold)),
        len(predicted.difference(gold)),
        len(gold.difference(predicted)),
    )


def _precision_recall_f1(tp: int, fp: int, fn: int):
    """Return (precision, recall, f1); each is 0 when its denominator is 0."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return precision, recall, f1


def calculate_extraction_metrics(preds: List[Tuple[Set, Set]], golds: List[Tuple[Set, Set]]) -> Dict:
    """Compute micro-averaged P/R/F1 for entities and for relations.

    Counts are summed over all examples before the ratios are taken.

    Args:
        preds: Per-example ``(entity_set, relation_set)`` predictions.
        golds: Per-example ``(entity_set, relation_set)`` ground truth, aligned
            with ``preds`` (pairs are formed with ``zip``).

    Returns:
        Dict with ``entity_*`` and ``relation_*`` precision/recall/f1 values.
    """
    ent_totals = [0, 0, 0]
    rel_totals = [0, 0, 0]
    for (pred_ents, pred_rels), (gold_ents, gold_rels) in zip(preds, golds):
        ent_totals = [t + c for t, c in zip(ent_totals, _overlap_counts(pred_ents, gold_ents))]
        rel_totals = [t + c for t, c in zip(rel_totals, _overlap_counts(pred_rels, gold_rels))]

    ent_p, ent_r, ent_f = _precision_recall_f1(*ent_totals)
    rel_p, rel_r, rel_f = _precision_recall_f1(*rel_totals)
    return {
        "entity_precision": ent_p,
        "entity_recall": ent_r,
        "entity_f1": ent_f,
        "relation_precision": rel_p,
        "relation_recall": rel_r,
        "relation_f1": rel_f,
    }

def calculate_metrics_from_files(predictions_dir: str):
    """
    Read every ``*_predictions.jsonl`` file in a directory and score it.

    Files are processed in sorted order so the printed report is reproducible.
    Each line must be a JSON object with ``predicted_entities``,
    ``predicted_relations``, ``ground_truth_entities`` and
    ``ground_truth_relations`` lists. Blank lines are ignored; malformed lines
    are reported and skipped.

    Args:
        predictions_dir: Directory containing the prediction files.

    Returns:
        Dict mapping dataset name (file name without the suffix) to its metrics
        dict from ``calculate_extraction_metrics``. Empty if the path is missing,
        is not a directory, or holds no prediction files.
    """
    suffix = "_predictions.jsonl"
    print(f"\n--- Calculating Metrics from Saved Files in '{predictions_dir}' ---")
    all_dataset_metrics = {}

    # isdir rather than exists: a path to a regular file would make listdir raise.
    if not os.path.isdir(predictions_dir):
        print(f"Predictions directory '{predictions_dir}' not found.")
        return {}

    # os.listdir order is arbitrary; sort for a deterministic report.
    filenames = sorted(name for name in os.listdir(predictions_dir) if name.endswith(suffix))

    for filename in filenames:
        # Slice off the suffix (str.replace could also remove it mid-name).
        dataset_name = filename[:-len(suffix)]
        filepath = os.path.join(predictions_dir, filename)
        print(f"\nCalculating metrics for: {dataset_name} from {filepath}")

        all_preds_sets = []
        all_labels_sets = []

        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                for line_count, line in enumerate(f, start=1):
                    if not line.strip():
                        continue  # tolerate blank lines (e.g. trailing newline)
                    try:
                        data = json.loads(line)
                        # Saved lists are converted back to hashable tuples in sets.
                        pred_entities = set(tuple(e) for e in data['predicted_entities'])
                        pred_relations = set(tuple(r) for r in data['predicted_relations'])
                        gt_entities = set(tuple(e) for e in data['ground_truth_entities'])
                        gt_relations = set(tuple(r) for r in data['ground_truth_relations'])
                    except json.JSONDecodeError:
                        print(f"Skipping malformed JSON line {line_count} in {filename}")
                        continue
                    except KeyError as e:
                        print(f"Skipping line {line_count} due to missing key {e} in {filename}")
                        continue
                    all_preds_sets.append((pred_entities, pred_relations))
                    all_labels_sets.append((gt_entities, gt_relations))

            if not all_preds_sets:
                print(f"No valid predictions found in file for {dataset_name}.")
                continue

            metrics = calculate_extraction_metrics(all_preds_sets, all_labels_sets)
            all_dataset_metrics[dataset_name] = metrics
            print(f"Metrics for {dataset_name} ({len(all_preds_sets)} examples):")
            print(json.dumps(metrics, indent=2))

        except FileNotFoundError:
            print(f"File not found during metric calculation: {filepath}")
        except Exception as e:
            # Reporting boundary: one bad file must not abort scoring of the others.
            print(f"Error reading or processing file {filepath}: {e}")

    if not filenames:
        print(f"No prediction files found in '{predictions_dir}'.")

    print("-------------------------------------------------------")
    return all_dataset_metrics

## Zero-Shot Testing

Dataset Class for Batching

In [None]:
class InferenceDataset(Dataset):
    """Serve SSI-prefixed, tokenized inputs alongside their source records.

    Each item is padded/truncated to ``max_source_length`` tokens. It also
    carries the untouched input record under ``original_data``, so predictions
    can be matched to the ground truth later.
    """

    def __init__(self, data, tokenizer, ssi_string, max_source_length):
        self.data, self.tokenizer = data, tokenizer
        self.ssi_string, self.max_source_length = ssi_string, max_source_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        # Schema prompt first, then one space, then the raw text.
        prompt = f"{self.ssi_string} {record['text']}"
        encoding = self.tokenizer(
            prompt,
            max_length=self.max_source_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # return_tensors="pt" yields shape (1, L); index 0 drops the batch axis.
        return {
            "input_ids": encoding["input_ids"][0],
            "attention_mask": encoding["attention_mask"][0],
            "original_data": record,
        }

Set up the device and load the model and tokenizer

In [None]:
# Prefer CUDA when it is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Fetch the pretrained tokenizer and seq2seq checkpoint named by MODEL_NAME.
print(f"Loading tokenizer: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print(f"Loading model: {MODEL_NAME}")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

Add special tokens if necessary

In [None]:
# The SSI prompt relies on the <spot>/<asoc> markers. Check the full vocabulary
# (not only the tokens flagged as special) to see whether they are known.
vocab = tokenizer.get_vocab()
needed_markers = ["<spot>", "<asoc>"]
tokens_to_add = [tok for tok in needed_markers if tok not in vocab]

if tokens_to_add:
    print(f"Adding special tokens: {tokens_to_add}")
    # Pass the existing additional special tokens too: add_special_tokens can
    # replace that list, which would un-flag the <extra_id_*> sentinels.
    existing_additional = list(tokenizer.additional_special_tokens)
    tokenizer.add_special_tokens({
        'additional_special_tokens': existing_additional + [t for t in tokens_to_add if t not in existing_additional]
    })
    # Only grow the embedding matrix when the new ids do not fit; resizing down
    # to len(tokenizer) could shrink an embedding table that has spare rows.
    if len(tokenizer) > model.get_input_embeddings().num_embeddings:
        model.resize_token_embeddings(len(tokenizer))
        print("Resized model embeddings.")
    print("Warning: newly added marker tokens have untrained embeddings; zero-shot output may be unreliable.")
else:
    print("Special tokens already present in the tokenizer.")

Move the model to the selected device (GPU if available, otherwise CPU) and set it to evaluation mode

In [None]:
# nn.Module.to and .eval both return the module itself, so the calls chain.
model.to(device).eval()

Dataset Configuration

In [None]:
# Test files: one "<name>_test.jsonl" per dataset under this placeholder folder.
_DATA_DIR = "/path/to/your"
DATASET_FILES = {
    name: f"{_DATA_DIR}/{name}_test.jsonl"
    for name in ("biored", "ddi", "chemprot", "bc5cdr")
}

# Tokenization / generation limits (tune to the data and model).
MAX_SOURCE_LENGTH, MAX_TARGET_LENGTH = 512, 512
BATCH_SIZE = 16
PREDICTIONS_DIR = "./predictions/uie"

Loop through datasets and save predictions

In [None]:
# Create output directory
os.makedirs(PREDICTIONS_DIR, exist_ok=True)
print(f"Saving predictions to: {PREDICTIONS_DIR}")


def _collate_inference_batch(examples):
    """Collate InferenceDataset items into one batch.

    Tensor fields are stacked, and the raw ``original_data`` records are kept
    as a plain list. PyTorch's default collate would instead merge those dicts
    key by key (raising on variable-length entity/relation lists), which breaks
    the per-example ground-truth lookup below.
    """
    return {
        "input_ids": torch.stack([ex["input_ids"] for ex in examples]),
        "attention_mask": torch.stack([ex["attention_mask"] for ex in examples]),
        "original_data": [ex["original_data"] for ex in examples],
    }


# Loop through datasets for inference
for dataset_name, test_file in DATASET_FILES.items():
    print(f"\n--- Processing Dataset: {dataset_name} ---")

    # Load data
    try:
        test_data = load_jsonl(test_file)
        print(f"Loaded {len(test_data)} examples.")
    except FileNotFoundError:
        print(f"Test file not found: {test_file}. Skipping dataset.")
        continue
    except Exception as e:
        print(f"Error loading {test_file}: {e}. Skipping dataset.")
        continue

    # Get mapper and build SSI
    mapper = BIO_LABEL_MAPPERS.get(dataset_name)
    if not mapper:
        print(f"Label mapper not found for {dataset_name}. Skipping.")
        continue

    entity_map = mapper['entities']
    relation_map = mapper['relations']
    # Several raw labels can share one name (e.g. ChemProt's GENE/GENE-Y/GENE-N),
    # so de-duplicate before building the prompt.
    entity_types_sorted = sorted(set(entity_map.values()))
    relation_types_sorted = sorted(set(relation_map.values()))
    ssi_string = build_ssi(entity_types_sorted, relation_types_sorted)
    print("Generated SSI string:", ssi_string)

    # Create dataLoader
    inference_dataset = InferenceDataset(test_data, tokenizer, ssi_string, MAX_SOURCE_LENGTH)
    dataloader = DataLoader(inference_dataset, batch_size=BATCH_SIZE, collate_fn=_collate_inference_batch)

    # The "_predictions.jsonl" suffix is what calculate_metrics_from_files scans for.
    output_filepath = os.path.join(PREDICTIONS_DIR, f"{dataset_name}_predictions.jsonl")

    # Run inference and write predictions
    with torch.no_grad(), open(output_filepath, 'w', encoding='utf-8') as outfile:
        for batch in tqdm(dataloader, desc=f"Inferring & Saving {dataset_name}"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=MAX_TARGET_LENGTH,
            )

            # Decode with special tokens kept so the <extra_id_*> SEL markers
            # survive; only padding and EOS are stripped afterwards.
            sel_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
            cleaned_sels = [
                s.replace(tokenizer.pad_token, "").replace(tokenizer.eos_token, "").strip()
                for s in sel_outputs
            ]

            # Extract comparable sets from parsed predictions and ground truth
            batch_preds_sets = [get_predicted_sets(parse_sel_string(s)) for s in cleaned_sels]
            batch_gts_sets = [
                get_ground_truth_sets(item, entity_map, relation_map)
                for item in batch['original_data']
            ]

            # Write each item in the batch to the output file
            for item, sel, (pred_ents, pred_rels), (gt_ents, gt_rels) in zip(
                batch['original_data'], cleaned_sels, batch_preds_sets, batch_gts_sets
            ):
                output_record = {
                    "text": item['text'],
                    "sel_output": sel,
                    "predicted_entities": sorted([list(e) for e in pred_ents]),
                    "predicted_relations": sorted([list(r) for r in pred_rels]),
                    "ground_truth_entities": sorted([list(e) for e in gt_ents]),
                    "ground_truth_relations": sorted([list(r) for r in gt_rels])
                }
                outfile.write(json.dumps(output_record, ensure_ascii=False) + "\n")

    print(f"Finished processing and saved predictions for {dataset_name} to {output_filepath}")

print("\n=== Inference and Saving Complete ===")

Compute metrics from the saved prediction files

In [None]:
# Score every saved prediction file, then print a combined summary.
final_metrics = calculate_metrics_from_files(PREDICTIONS_DIR)

print("\n=== Final Zero-Shot Results Summary ===")
summary = json.dumps(final_metrics, indent=2) if final_metrics else "No metrics were calculated."
print(summary)
print("=======================================")