# Tasks 2, 3, 5

Extract parties, agreement types, and international organizations from Alabama agreements.


In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
import glob
import os
import warnings
import json
import re
from pathlib import Path
from datetime import datetime
import time
import tracemalloc
import resource
import statistics
import atexit

# Silence every warning category for the whole session to keep notebook output readable.
# NOTE(review): this also hides deprecation/runtime warnings from torch, transformers
# and llama_cpp that may matter; consider narrowing to specific categories/modules.
warnings.filterwarnings('ignore')


KeyboardInterrupt: 

## Initialize models

In [None]:
# Quantized (Q4_K_M) GGUF build of Mistral-7B-Instruct v0.2, served through llama.cpp.
mistral_model_name = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
mistral_file = "mistral-7b-instruct-v0.2.Q4_K_M.gguf"
mistral_path = hf_hub_download(repo_id=mistral_model_name, filename=mistral_file)

# n_gpu_layers=-1 offloads all layers to the GPU; n_ctx bounds prompt + completion tokens.
llm = Llama(
    model_path=mistral_path,
    n_gpu_layers=-1,
    n_ctx=2048,
    verbose=False,
)


In [None]:
# Zero-shot NLI classifier used for Task 3; runs on GPU 0 when CUDA is available, else CPU (-1).
_zero_shot_device = 0 if torch.cuda.is_available() else -1
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=_zero_shot_device,
)


In [None]:
# Token-classification checkpoint (presumably fine-tuned on legal text, per its name);
# tokenizer and weights are loaded from the same Hub repository.
model_name = "dandoune/legal-NER"
tokenizer, ner_model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForTokenClassification.from_pretrained(model_name),
)


## Helper functions

- fetching all .txt files within one directory
- saving results 
- time/memory profiling

In [None]:
def find_txt_files(directory):
    """Return the sorted paths of all ``*.txt`` files directly inside *directory*.

    The search is not recursive; paths are returned as strings joined onto *directory*.
    """
    pattern = os.path.join(directory, "*.txt")
    matches = glob.glob(pattern)
    matches.sort()
    return matches

# Resolve paths robustly (works when running from repo root or from NLP_2025W)
def _resolve_project_dir() -> Path:
    """Return the first candidate directory that contains a ``data`` folder.

    Candidates, in order: the current working directory; the directory of this
    file when ``__file__`` is defined (it typically is not inside a notebook
    kernel, in which case the working directory is re-checked); and finally the
    parent of the working directory, returned without checking.
    """
    cwd = Path.cwd()
    if (cwd / "data").exists():
        return cwd
    script_dir = Path(__file__).resolve().parent if "__file__" in globals() else cwd
    if (script_dir / "data").exists():
        return script_dir
    return Path.cwd().parent


PROJECT_DIR = _resolve_project_dir()

DATA_DIR = PROJECT_DIR / "data" / "Alabama"
OUTPUT_DIR = PROJECT_DIR / "outputs" / "tasks_2_3_5"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
RUN_ID = datetime.now().strftime("%Y%m%d_%H%M%S")


def _write_json(path: Path, obj):
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)


def _write_jsonl(path: Path, rows):
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")


def _try_parse_json_from_llm(text: str):
    if text is None:
        return None
    s = text.strip()
    m = re.search(r"\{.*\}", s, flags=re.DOTALL)
    if not m:
        return None
    candidate = m.group(0)
    try:
        return json.loads(candidate)
    except Exception:
        return None


def _max_rss_mb() -> float:
    return float(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss) / 1024.0


def _profile_start() -> dict:
    tracemalloc.start()
    return {
        "t0": time.perf_counter(),
        "rss0_mb": _max_rss_mb(),
    }


def _profile_end(start: dict) -> dict:
    elapsed = time.perf_counter() - start["t0"]
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    rss1 = _max_rss_mb()
    return {
        "seconds": float(elapsed),
        "py_peak_mb": float(peak) / 1024.0 / 1024.0,
        "max_rss_mb": float(rss1),
        "max_rss_delta_mb": float(rss1 - start["rss0_mb"]),
    }


def _summarize_times(times_s: list[float]) -> dict:
    if not times_s:
        return {"count": 0}
    return {
        "count": int(len(times_s)),
        "mean_s": float(statistics.mean(times_s)),
        "median_s": float(statistics.median(times_s)),
        "min_s": float(min(times_s)),
        "max_s": float(max(times_s)),
    }


# Run-level profiling record; each measured cell adds an entry under "methods".
PROFILE = {
    "run_id": RUN_ID,
    "data_dir": str(DATA_DIR),
    "output_dir": str(OUTPUT_DIR),
    "methods": {},
}

PROFILE_PATH = OUTPUT_DIR / f"profiling_tasks_2_3_5_{RUN_ID}.json"


def _save_profile():
    """Persist PROFILE to PROFILE_PATH; report (but never raise) on failure.

    Called after every measured cell and again at interpreter exit, so it is
    deliberately best-effort. Failures are printed to stderr instead of being
    silently discarded, so lost profiling data is noticed.
    """
    import sys

    try:
        _write_json(PROFILE_PATH, PROFILE)
    except Exception as exc:  # best-effort by design: a failed save must not abort the run
        print(f"[profiling] could not save {PROFILE_PATH}: {exc!r}", file=sys.stderr)


atexit.register(_save_profile)


## Task 2 - extract agreement parties

### 2.1 Legal NER for parties extraction

In [None]:
def extract_entities(text, model, tokenizer, max_length=512, verbose=False):
    """Run a token-classification model over *text* and group tokens into entity spans.

    Consecutive tokens with the same predicted label are merged into one entity.
    Tokens labelled ``"O"`` or ``"LABEL_8"`` end the current span; both are treated
    as the "outside" class here (NOTE(review): confirm LABEL_8 against the model card).
    Special tokens (``[CLS]``/``[SEP]``/``[PAD]`` plus the tokenizer's own cls/sep/pad/
    bos/eos tokens, e.g. ``<s>``/``</s>``) are hard span boundaries and never part of
    an entity. The input is truncated, so only the first *max_length* tokens are seen.

    Args:
        text: Raw document text.
        model: Token-classification model exposing ``config.id2label`` and returning
            ``logits`` indexed as (batch, seq, label).
        tokenizer: Tokenizer matching *model*.
        max_length: Maximum number of tokens fed to the model.
        verbose: If True, print every ``token label`` pair (debugging aid; previously
            printed unconditionally as a dict, which collapsed repeated tokens).

    Returns:
        List of ``(entity_text, label)`` tuples in document order. Spans that decode
        to an empty string or start with a ``##`` sub-word piece are dropped.
    """
    outside_labels = {"LABEL_8", "O"}
    boundary_tokens = {"[CLS]", "[SEP]", "[PAD]"}
    special_map = getattr(tokenizer, "special_tokens_map", None) or {}
    boundary_tokens.update(
        str(tok)
        for key, tok in special_map.items()
        if key in ("cls_token", "sep_token", "pad_token", "bos_token", "eos_token") and tok
    )

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=max_length)

    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=2)[0].cpu().tolist()

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # id2label may be keyed by int or by str depending on how the config was loaded.
    label_map = model.config.id2label
    labels = [label_map.get(str(pred), label_map.get(pred, f"LABEL_{pred}")) for pred in predictions]

    if verbose:
        for token, label in zip(tokens, labels):
            print(token, label)

    entities = []

    def _emit(span_tokens, span_label):
        """Append the decoded span unless it is empty, outside-labelled or a dangling sub-word."""
        if not span_tokens or not span_label or span_label in outside_labels:
            return
        entity_text = tokenizer.convert_tokens_to_string(span_tokens).strip()
        if entity_text and not entity_text.startswith("##"):
            entities.append((entity_text, span_label))

    current_tokens = []
    current_label = None
    for token, label in zip(tokens, labels):
        if token in boundary_tokens or label in outside_labels:
            _emit(current_tokens, current_label)
            current_tokens, current_label = [], None
        elif label == current_label:
            current_tokens.append(token)
        else:
            _emit(current_tokens, current_label)
            current_tokens, current_label = [token], label

    _emit(current_tokens, current_label)
    return entities

# Profile the legal-NER model on a single sample agreement and save its entities.
_sample_filename = "Alabama_3.txt"
_profile = _profile_start()
# Resolve the sample relative to DATA_DIR instead of a machine-specific absolute
# path (/home/ubuntu/...), so the cell runs wherever the data directory was found.
with (DATA_DIR / _sample_filename).open("r", encoding="utf-8") as f:
    alabama_text = f.read()

ner_entities = extract_entities(alabama_text, ner_model, tokenizer)
profile_stats = _profile_end(_profile)

PROFILE["methods"]["task2_ner_single"] = {
    "input": {"filename": _sample_filename, "chars": len(alabama_text)},
    "performance": profile_stats,
    "output": {"n_entities": len(ner_entities)},
}
_save_profile()

_write_json(
    OUTPUT_DIR / f"task2_parties_ner_single_{RUN_ID}.json",
    {
        "filename": _sample_filename,
        "entities": [{"text": ent_text, "label": ent_label} for ent_text, ent_label in ner_entities],
    },
)


In [None]:
from collections import defaultdict

# Bucket entity strings by predicted label, skipping fragments of 2 characters or fewer.
entities_by_label = defaultdict(list)
for entity_text, label in ner_entities:
    cleaned = entity_text.strip()
    if len(cleaned) > 2:
        entities_by_label[label].append(cleaned)

# Show at most 10 distinct, alphabetically sorted entities per label.
print(f"Total entities found: {len(ner_entities)}\n")
for label, texts in sorted(entities_by_label.items()):
    print(f"\n{label}:")
    for entity in sorted(set(texts))[:10]:
        print(f"{entity}")



#### Parties extraction results for legal NER model

Several NER models have been tested, but the task seems too complex for this approach. Agreement parties often have very long names that themselves contain the names of other countries, institutions, or people. As a result, the model often fails to recognize the full name of an agreement party.

NER models fine-tuned for legal purposes were hard to find, though — publicly available checkpoints were often broken. More trustworthy models could turn out better.

### 2.2 Prompting Mistral model for agreement parties

In [None]:
def extract_parties(text_chunk):
    """Ask the Mistral model for the direct contracting parties named in *text_chunk*.

    The prompt requests a JSON object with a "parties" list of {"name", "role"}
    objects, but the reply is returned as raw completion text, unparsed.
    """
    instruction = f"""[INST] You are a legal AI. Extract the contracting parties from the text below and only the direct agreement parties.
    Return ONLY a JSON object with keys: "parties" (list of objects with "name", "role").
    
    TEXT:
    {text_chunk}
    [/INST]"""
    completion = llm(instruction, max_tokens=300, temperature=0.1, stop=["[/INST]"])
    first_choice = completion["choices"][0]
    return first_choice["text"]


In [None]:
# Task 2.2: ask Mistral for the contracting parties of every agreement.
# Only a prefix of each document is sent (presumably enough to cover the preamble
# naming the parties while keeping the prompt inside the model's 2048-token n_ctx).
# A single constant drives both the read and the recorded profiling metadata.
PARTIES_CHARS_PER_FILE = 1500

txt_files = find_txt_files(str(DATA_DIR))
parties_results = []
per_doc_times = []

_profile = _profile_start()
for txt_file in txt_files:
    with open(txt_file, "r", encoding="utf-8") as f:
        text = f.read(PARTIES_CHARS_PER_FILE)
    t0 = time.perf_counter()
    result = extract_parties(text)
    per_doc_times.append(time.perf_counter() - t0)
    parties_results.append({
        'filename': os.path.basename(txt_file),
        'parties': result
    })
profile_stats = _profile_end(_profile)

# Keep the raw completion and attach a best-effort parsed JSON alongside it.
parties_results_serialized = [
    {**r, "parsed": _try_parse_json_from_llm(r.get("parties"))}
    for r in parties_results
]

PROFILE["methods"]["task2_parties_mistral"] = {
    "input": {"n_files": len(txt_files), "chars_per_file": PARTIES_CHARS_PER_FILE},
    "performance": {
        **profile_stats,
        "per_doc": _summarize_times(per_doc_times),
    },
    "output": {
        "n_rows": len(parties_results_serialized),
        "parse_ok": int(sum(1 for r in parties_results_serialized if isinstance(r.get("parsed"), dict))),
    },
}
_save_profile()

_write_json(OUTPUT_DIR / f"task2_parties_mistral_{RUN_ID}.json", parties_results_serialized)


#### Parties extraction results for Mistral model

The following results show all contracting parties extracted from each agreement using the Mistral 7B Instruct model, which is specifically fine-tuned to follow user instructions and generate structured text. When appropriately prompted, the model returned the correct parties of the agreement. In more complex scenarios errors occurred — e.g. Canadian provinces/US states listed as members of one party were treated as separate agreement parties. An additional "role" field explains each entity's role in the agreement.


In [None]:
# Dump the raw Mistral completion for every agreement under a per-file banner.
rule = "=" * 80
for result in parties_results:
    print(f"\n{rule}")
    print(f"File: {result['filename']}")
    print(rule)
    print(result['parties'])


## Task 3 - agreement type classification

### 3.1 Zero shot classification with BART model

Some of the most popular types of international/interstate agreements were given as candidate classes.

In [None]:
# Candidate labels for zero-shot agreement-type classification.
agreement_types = [
    "Economic Cooperation Agreement",
    "Trade Agreement",
    "Investment Agreement",
    "Memorandum of Understanding (MOU)",
    "Partnership Agreement",
    "Educational Exchange Agreement",
    "Cultural Exchange Agreement",
    "Technology Transfer Agreement",
    "Waterway Development Agreement",
    "Interstate Compact",
    "Sister State Agreement",
    "Sister City Agreement",
    "Development Cooperation Agreement"
]


def classify_agreement_type(text, max_length=512):
    """Zero-shot classify *text* against ``agreement_types`` (single-label).

    Despite its name, *max_length* is used as a character budget: texts longer
    than ``4 * max_length`` characters are reduced to their first and last
    ``2 * max_length`` characters, joined by a blank line. Returns the raw
    pipeline result (``labels``/``scores`` sorted by score).
    """
    window = max_length * 2
    if len(text) <= max_length * 4:
        excerpt = text
    else:
        excerpt = f"{text[:window]}\n\n{text[-window:]}"
    return classifier(excerpt, agreement_types, multi_label=False)


In [None]:
# Task 3.1: zero-shot classify every agreement (the whole file is read; the
# classifier helper trims long texts to a head + tail excerpt itself).
classification_results = []
txt_files = find_txt_files(str(DATA_DIR))
per_doc_times = []

_profile = _profile_start()
for txt_file in txt_files:
    text = Path(txt_file).read_text(encoding="utf-8")
    started = time.perf_counter()
    result = classify_agreement_type(text)
    per_doc_times.append(time.perf_counter() - started)
    ranked = list(zip(result['labels'], result['scores']))
    classification_results.append({
        'filename': Path(txt_file).name,
        'predicted_type': ranked[0][0],
        'confidence': ranked[0][1],
        'top_3': ranked[:3],
    })
profile_stats = _profile_end(_profile)

PROFILE["methods"]["task3_type_bart"] = {
    "input": {"n_files": len(txt_files), "chars_per_file": None},
    "performance": {
        **profile_stats,
        "per_doc": _summarize_times(per_doc_times),
    },
    "output": {"n_rows": len(classification_results)},
}
_save_profile()

_write_json(OUTPUT_DIR / f"task3_type_bart_{RUN_ID}.json", classification_results)


#### Agreement Type Classification Results

The following results show the agreement type classification for each document using zero-shot classification. Expert knowledge is needed to verify the classification because some agreements do not state their goal clearly. There are, however, several Memoranda of Understanding, and they were recognized.

In [None]:
from collections import Counter

# Distribution of predicted agreement types across all classified files.
agreement_type_counts = Counter(r['predicted_type'] for r in classification_results)
heavy_rule = "=" * 80

print(f"Total files classified: {len(classification_results)}\n")
print("Agreement type distribution:")
print("-" * 80)
for agreement_type, count in agreement_type_counts.most_common():
    print(f"  {agreement_type:50} : {count:2} files")

# Per-file breakdown with the three highest-scoring candidate types.
print(f"\n{heavy_rule}")
print("Detailed Results:")
print(heavy_rule)
for result in classification_results:
    print(f"\n{result['filename']}")
    print(f"  Type: {result['predicted_type']}")
    print(f"  Confidence: {result['confidence']:.3f}")
    print("  Top 3 predictions:")
    for label, score in result['top_3']:
        print(f"    - {label}: {score:.3f}")


### 3.2 Prompting Mistral model for agreement type

In [None]:
def extract_agreement_type_mistral(text_chunk):
    """Ask the Mistral model to name the agreement type of *text_chunk* and describe it.

    The prompt requests JSON with "agreement_type" and "description" keys; the raw
    completion text is returned without parsing.
    """
    instruction = f"""[INST] You are a legal AI. Analyze the legal agreement text below and identify type of agreement. Return only one type of agreement.
    Return ONLY a JSON object with keys: "agreement_type" (the specific type of agreement, e.g., "Economic Cooperation Agreement", "Trade Agreement", "Memorandum of Understanding", etc.) and "description" (a brief description of what the agreement is about).
    
    TEXT:
    {text_chunk}
    [/INST]"""
    completion = llm(instruction, max_tokens=200, temperature=0.1, stop=["[/INST]"])
    first_choice = completion["choices"][0]
    return first_choice["text"]


In [None]:
# Task 3.2: ask Mistral for the agreement type of every document.
# One constant drives both the prefix read and the recorded profiling metadata,
# so the two cannot drift apart.
AGREEMENT_TYPE_CHARS_PER_FILE = 2000

agreement_type_results = []
txt_files = find_txt_files(str(DATA_DIR))
per_doc_times = []

_profile = _profile_start()
for txt_file in txt_files:
    with open(txt_file, "r", encoding="utf-8") as f:
        text = f.read(AGREEMENT_TYPE_CHARS_PER_FILE)
    t0 = time.perf_counter()
    result = extract_agreement_type_mistral(text)
    per_doc_times.append(time.perf_counter() - t0)
    agreement_type_results.append({
        'filename': os.path.basename(txt_file),
        'extraction': result
    })
profile_stats = _profile_end(_profile)

# Keep the raw completion and attach a best-effort parsed JSON alongside it.
agreement_type_results_serialized = [
    {**r, "parsed": _try_parse_json_from_llm(r.get("extraction"))}
    for r in agreement_type_results
]

PROFILE["methods"]["task3_type_mistral"] = {
    "input": {"n_files": len(txt_files), "chars_per_file": AGREEMENT_TYPE_CHARS_PER_FILE},
    "performance": {
        **profile_stats,
        "per_doc": _summarize_times(per_doc_times),
    },
    "output": {
        "n_rows": len(agreement_type_results_serialized),
        "parse_ok": int(sum(1 for r in agreement_type_results_serialized if isinstance(r.get("parsed"), dict))),
    },
}
_save_profile()

_write_json(OUTPUT_DIR / f"task3_type_mistral_{RUN_ID}.json", agreement_type_results_serialized)


#### Agreement type extraction results (Mistral)

The following results show the agreement types extracted directly from the text. Unlike classification, this approach can identify agreement types without predefined categories. Surprisingly, almost all of the agreements were recognized as Memoranda of Understanding — a type that was also present in the class set used for zero-shot classification.

In [None]:
# Show the raw agreement-type completion returned for each file.
print(f"Total files processed: {len(agreement_type_results)}\n")
print("=" * 80)
divider = "-" * 80
for result in agreement_type_results:
    print(f"\n{result['filename']}")
    print(divider)
    print(result['extraction'])


## Task 5 - Extracting international organizations

This task was supposed to be solved based on NER results, but they turned out to be poor, just like in Task 2. The Mistral model has been used again to find all such entities. The model has been instructed to ignore domestic/local entities and to explain, in the "type" field, the context in which the organization appears in the agreement.

In [None]:
def extract_international_orgs(text_chunk):
    """Ask the Mistral model to list international organizations mentioned in *text_chunk*.

    The prompt requests JSON with an "international_organizations" list of
    {"name", "type"} objects; the raw completion text is returned without parsing.
    """
    instruction = f"""[INST] You are a legal AI. From the legal agreement text below, extract all INTERNATIONAL ORGANIZATIONS mentioned.
    
    - Focus on organizations that operate across national borders (e.g., international organizations, intergovernmental bodies, multinational commissions).
    - Ignore purely national ministries, state agencies, or local entities unless they are clearly part of an international body.
    - Return ONLY a JSON object with the key "international_organizations" which is a list of objects with fields:
        - "name": the full official name as it appears or can be reasonably completed
        - "type": short description (e.g., "international organization", "intergovernmental commission", "binational authority").
    
    TEXT:
    {text_chunk}
    [/INST]"""
    completion = llm(instruction, max_tokens=300, temperature=0.1, stop=["[/INST]"])
    first_choice = completion["choices"][0]
    return first_choice["text"]


In [None]:
# Task 5: ask Mistral for the international organizations mentioned in every document.
# Read only the prefix that is actually sent to the model (the original read the
# whole file and then sliced it); the same constant is recorded in PROFILE.
INTL_ORGS_CHARS_PER_FILE = 2000

intl_orgs_results = []
txt_files = find_txt_files(str(DATA_DIR))
per_doc_times = []

_profile = _profile_start()
for txt_file in txt_files:
    with open(txt_file, "r", encoding="utf-8") as f:
        text = f.read(INTL_ORGS_CHARS_PER_FILE)
    t0 = time.perf_counter()
    result = extract_international_orgs(text)
    per_doc_times.append(time.perf_counter() - t0)
    intl_orgs_results.append({
        'filename': os.path.basename(txt_file),
        'extraction': result
    })
profile_stats = _profile_end(_profile)

# Keep the raw completion and attach a best-effort parsed JSON alongside it.
intl_orgs_results_serialized = [
    {**r, "parsed": _try_parse_json_from_llm(r.get("extraction"))}
    for r in intl_orgs_results
]

PROFILE["methods"]["task5_orgs_mistral"] = {
    "input": {"n_files": len(txt_files), "chars_per_file": INTL_ORGS_CHARS_PER_FILE},
    "performance": {
        **profile_stats,
        "per_doc": _summarize_times(per_doc_times),
    },
    "output": {
        "n_rows": len(intl_orgs_results_serialized),
        "parse_ok": int(sum(1 for r in intl_orgs_results_serialized if isinstance(r.get("parsed"), dict))),
    },
}
_save_profile()

_write_json(OUTPUT_DIR / f"task5_international_orgs_mistral_{RUN_ID}.json", intl_orgs_results_serialized)


In [None]:
# Notebook display: as the cell's last expression, this renders the serialized
# results (raw completion plus the "parsed" JSON, or None) in the cell output.
intl_orgs_results_serialized

#### International organizations extraction results

The following results show all international or supranational organizations mentioned in each agreement. These are organizations that operate across national borders, such as intergovernmental bodies, multinational commissions, or international organizations.

In [None]:
# Save profiling summary
_save_profile()

# Build one compact row per profiled method. Missing or non-dict sections are
# treated as empty, so every metric falls back to None.
rows = []
for name, rec in PROFILE.get("methods", {}).items():
    perf = rec.get("performance", {})
    if not isinstance(perf, dict):
        perf = {}
    per_doc = perf.get("per_doc", {})
    if not isinstance(per_doc, dict):
        per_doc = {}
    rows.append({
        "method": name,
        "total_s": perf.get("seconds"),
        "per_doc_mean_s": per_doc.get("mean_s"),
        "per_doc_median_s": per_doc.get("median_s"),
        "py_peak_mb": perf.get("py_peak_mb"),
        "max_rss_mb": perf.get("max_rss_mb"),
    })

# Prefer a pandas table; fall back to the raw rows only for the expected cases
# (pandas missing, or nothing profiled yet — an empty frame has no "total_s"
# column to sort on). Other errors are no longer swallowed.
try:
    import pandas as pd
except ImportError:
    pd = None

if pd is not None and rows:
    df_profile = pd.DataFrame(rows).sort_values("total_s", ascending=False, na_position="last")
    try:
        display(df_profile)
    except NameError:  # `display` only exists inside IPython/Jupyter
        print(df_profile.to_string(index=False))
else:
    print(rows)
print(f"Profiling saved to: {PROFILE_PATH}")


In [None]:
# Show the raw international-organization completion for each agreement.
for entry in intl_orgs_results:
    print(f"\n{entry['filename']}")
    print("-" * 80)
    print(entry["extraction"])
