# Odin Medical NER/RE — Inference Demo

This notebook demonstrates how to use the **Odin Medical NER/RE** model to extract medical entities and their relationships from clinical text.

**Model**: [odin-deus/odin-llama3.1-medical-ner-v14](https://huggingface.co/odin-deus/odin-llama3.1-medical-ner-v14)  
**Base**: Meta Llama 3.1 8B (4-bit quantized)  
**Task**: Named Entity Recognition + Relation Extraction  

| Entity Types | Relation Types |
|---|---|
| Disease, Drug, Symptom | causes, treats, associated_with, interacts_with |

| Metric | F1 |
|---|---|
| Entity (micro) | 0.911 |
| Relation (micro) | 0.832 |

## 1. Setup

Install Unsloth, which handles model loading and 4-bit quantization and installs compatible versions of its dependencies.

> **Runtime**: This notebook requires a GPU runtime. In Colab: *Runtime → Change runtime type → T4 GPU*.

In [None]:
!pip install -q unsloth
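
Before loading the model, it can help to confirm that the runtime actually exposes a GPU. The check below is a minimal sketch rather than part of the original notebook: `gpu_status()` is a hypothetical helper name, and the sketch assumes PyTorch is importable once the install cell has run.

```python
import importlib.util


def gpu_status():
    """Return a short, human-readable description of the CUDA runtime."""
    # Hypothetical helper: avoids an ImportError if torch is missing.
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed"
    import torch

    if not torch.cuda.is_available():
        return "no CUDA device visible; switch the runtime to a GPU"
    props = torch.cuda.get_device_properties(0)
    return f"{props.name}, {props.total_memory / 1024**3:.1f} GiB"


print(gpu_status())
```

If the message reports no CUDA device, change the runtime type before running the cells that follow.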

## 2. Load Model

The model is a LoRA adapter on top of Llama 3.1 8B. Unsloth loads the base model in 4-bit
quantization (~4 GB VRAM) and applies the adapter automatically.

In [None]:
from unsloth import FastLanguageModel

MODEL_ID = "odin-deus/odin-llama3.1-medical-ner-v14"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_ID,
    max_seq_length=512,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)
print(f"Model loaded: {MODEL_ID}")

## 3. Helper Functions

Three utilities:
- `extract()`: runs inference on a clinical text and returns the raw model output
- `parse_output()`: parses the structured output into Python dicts for entities and relations
- `show()`: runs both steps, prints the entities and relations, and returns the parsed dict

In [None]:
import re
import torch

INSTRUCTION = (
    "Extract all medical entities and their relations from the following clinical text. "
    "Identify diseases, symptoms, drugs, procedures, and lab tests. "
    "For each entity found, specify its type. Then identify relations between entities."
)


def extract(text, max_new_tokens=256):
    """Run NER/RE extraction on a clinical text. Returns the raw model output."""
    prompt = f"### Instruction:\n{INSTRUCTION}\n\n### Input:\n{text}\n\n### Output:\n"
    # Prompts longer than 512 tokens are truncated; with the default right-side
    # truncation this also drops the trailing "### Output:" marker, so keep inputs short.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding for reproducible output
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens so the prompt is not echoed back.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()


def parse_output(text):
    """Parse model output into structured entities and relations."""
    entities, relations = [], []
    section = None
    for line in text.strip().split("\n"):
        line = line.strip()
        if not line:
            continue
        if "entities" in line.lower() and line.startswith("#"):
            section = "entities"
            continue
        elif "relations" in line.lower() and line.startswith("#"):
            section = "relations"
            continue
        if section == "entities" and re.match(r"^\d+\.", line):
            m = re.match(r"^\d+\.\s*\[(\w+)\]\s*(.+)", line)
            if m:
                entities.append({"type": m.group(1), "text": m.group(2).strip()})
        elif section == "relations" and re.match(r"^\d+\.", line):
            m = re.match(r"^\d+\.\s*(.+?)\s*--\[(.+?)\]-->\s*(.+)", line)
            if m:
                relations.append({"head": m.group(1).strip(), "relation": m.group(2).strip(), "tail": m.group(3).strip()})
    return {"entities": entities, "relations": relations}


def show(text):
    """Extract, parse, and display results for a clinical text."""
    raw = extract(text)
    parsed = parse_output(raw)
    print(f"Input:  {text}")
    print("\nEntities:")
    for e in parsed["entities"]:
        print(f"  [{e['type']}] {e['text']}")
    print("\nRelations:")
    for r in parsed["relations"]:
        print(f"  {r['head']} --[{r['relation']}]--> {r['tail']}")
    print("─" * 80)
    return parsed
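
The regular expressions in `parse_output()` imply the layout the parser expects: a `#`-prefixed heading containing "entities" followed by numbered `[Type] text` lines, then a `#`-prefixed heading containing "relations" followed by numbered `head --[relation]--> tail` lines. The sketch below applies the same two patterns to a hand-written sample in that layout. The sample is an illustration built to satisfy the patterns, not recorded model output, and the exact heading text the model emits is an assumption.

```python
import re

# Hand-written sample in the layout the parser's patterns accept
# (illustrative only; not real model output).
SAMPLE = """### Entities:
1. [Drug] enalapril
2. [Disease] acute renal failure

### Relations:
1. enalapril --[causes]--> acute renal failure
"""

# The same two patterns used inside parse_output().
ENTITY_RE = re.compile(r"^\d+\.\s*\[(\w+)\]\s*(.+)")
RELATION_RE = re.compile(r"^\d+\.\s*(.+?)\s*--\[(.+?)\]-->\s*(.+)")

# Apply each pattern line by line; section tracking is omitted for brevity.
entities = [m.groups() for line in SAMPLE.splitlines() if (m := ENTITY_RE.match(line.strip()))]
relations = [m.groups() for line in SAMPLE.splitlines() if (m := RELATION_RE.match(line.strip()))]

print(entities)
print(relations)
```

Lines that match neither pattern are silently skipped, so an empty result usually means the raw output deviated from this layout.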

## 4. Single Example

Run the model on a single clinical sentence.

In [None]:
result = show("The patient developed acute renal failure after treatment with enalapril.")

## 5. Batch Inference

Process multiple clinical texts and collect the structured results. The loop calls `show()` on one text at a time; the inputs are not batched on the GPU.

In [None]:
clinical_texts = [
    "Rhabdomyolysis following clarithromycin monotherapy.",
    "Three diabetic cases of acute dizziness due to initial administration of voglibose.",
    "A 39-year-old schizophrenic man treated with olanzapine developed an elevated serum CK concentration.",
    "Metformin is commonly used for the treatment of type 2 diabetes mellitus.",
    "Concurrent use of warfarin and aspirin increases the risk of gastrointestinal bleeding.",
    "The patient was diagnosed with hypertension and prescribed lisinopril, which caused a persistent dry cough.",
]

results = []
for text in clinical_texts:
    r = show(text)
    results.append({"input": text, **r})
    print()
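
To keep the collected results beyond the notebook session, they can be written to disk. The following is a minimal sketch using JSON Lines (one record per input text). The helpers `save_jsonl()` and `load_jsonl()` and the file name `extractions.jsonl` are assumptions for illustration, and `demo_results` is a hand-written stand-in for the `results` list built above, not model output.

```python
import json
import tempfile
from pathlib import Path


def save_jsonl(records, path):
    """Write one JSON object per line and return the number of records written."""
    with Path(path).open("w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    return len(records)


def load_jsonl(path):
    """Read a JSON Lines file back into a list of dicts."""
    with Path(path).open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Hand-written record with the same keys as the entries in `results`.
demo_results = [
    {
        "input": "Metformin is commonly used for the treatment of type 2 diabetes mellitus.",
        "entities": [
            {"type": "Drug", "text": "Metformin"},
            {"type": "Disease", "text": "type 2 diabetes mellitus"},
        ],
        "relations": [
            {"head": "Metformin", "relation": "treats", "tail": "type 2 diabetes mellitus"}
        ],
    }
]

out_path = Path(tempfile.gettempdir()) / "extractions.jsonl"
written = save_jsonl(demo_results, out_path)
print(f"Wrote {written} record(s) to {out_path}")
```

In the notebook itself, passing `results` instead of `demo_results` would persist the actual extractions.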

## 6. Results as a Table

View all extracted entities and relations in a tabular format.

In [None]:
# Entities table
print(f"{'Input (truncated)':<55} {'Type':<10} {'Entity'}")
print("═" * 100)
for r in results:
    short_input = (r["input"][:52] + "...") if len(r["input"]) > 55 else r["input"]
    for i, e in enumerate(r["entities"]):
        label = short_input if i == 0 else ""
        print(f"{label:<55} {e['type']:<10} {e['text']}")
    print("─" * 100)

In [None]:
# Relations table
print(f"{'Head':<30} {'Relation':<20} {'Tail'}")
print("═" * 80)
for r in results:
    for rel in r["relations"]:
        print(f"{rel['head']:<30} {rel['relation']:<20} {rel['tail']}")
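
A frequency summary can complement the row-level tables. The sketch below counts entity types and relation types across all results with `collections.Counter`. The function name `summarize()` is hypothetical, and `demo_results` is hand-written sample data in the same shape as `results`, not model output.

```python
from collections import Counter


def summarize(results):
    """Count entity types and relation types across all parsed results."""
    entity_types = Counter(e["type"] for r in results for e in r["entities"])
    relation_types = Counter(rel["relation"] for r in results for rel in r["relations"])
    return entity_types, relation_types


# Hand-written sample in the same shape as `results`.
demo_results = [
    {
        "input": "text A",
        "entities": [{"type": "Drug", "text": "warfarin"}, {"type": "Drug", "text": "aspirin"}],
        "relations": [{"head": "warfarin", "relation": "interacts_with", "tail": "aspirin"}],
    },
    {
        "input": "text B",
        "entities": [{"type": "Drug", "text": "metformin"}, {"type": "Disease", "text": "diabetes"}],
        "relations": [{"head": "metformin", "relation": "treats", "tail": "diabetes"}],
    },
]

entity_counts, relation_counts = summarize(demo_results)
for name, count in entity_counts.most_common():
    print(f"{name:<10} {count}")
for name, count in relation_counts.most_common():
    print(f"{name:<20} {count}")
```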

## 7. Try Your Own

Edit the text below and run the cell to extract entities and relations from your own clinical text.

In [None]:
your_text = "Enter your clinical text here."

# ─── Run extraction ───
result = show(your_text)

## 8. Raw Output Inspection

To debug parsing problems, print the raw model output before it is passed to `parse_output()`:

In [None]:
raw = extract("The patient developed acute renal failure after treatment with enalapril.")
print(raw)
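
When debugging, one further check is whether every relation refers to entities that were actually extracted. The sketch below flags relations whose head or tail does not match any entity text (case-insensitive exact match). Whether the model is expected to list every relation endpoint as an entity is an assumption here; `dangling_relations()` is a hypothetical helper, and the sample dict is hand-written rather than model output.

```python
def dangling_relations(parsed):
    """Return relations whose head or tail is not among the extracted entity texts."""
    known = {e["text"].lower() for e in parsed["entities"]}
    return [
        rel
        for rel in parsed["relations"]
        if rel["head"].lower() not in known or rel["tail"].lower() not in known
    ]


# Hand-written sample in the shape returned by parse_output().
sample = {
    "entities": [
        {"type": "Drug", "text": "enalapril"},
        {"type": "Disease", "text": "acute renal failure"},
    ],
    "relations": [
        {"head": "enalapril", "relation": "causes", "tail": "acute renal failure"},
        {"head": "enalapril", "relation": "treats", "tail": "hypertension"},
    ],
}

flagged = dangling_relations(sample)
for rel in flagged:
    print(f"Unmatched endpoint: {rel['head']} --[{rel['relation']}]--> {rel['tail']}")
```

A flagged relation is a prompt to read the raw output, not proof of a model error; the mismatch may come from a small wording difference between the entity list and the relation line.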