In [6]:
import os
import json
import torch
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
from collections import defaultdict
from difflib import SequenceMatcher

# -----------------------------
# CONFIG
# -----------------------------
MODEL_NAME = "facebook/nllb-200-distilled-600M"
DATASET_NAME = "masakhane/AfriDocMT"
CONFIG_NAME = "doc_health"

# Language codes for accessing the dataset
DATASET_SRC_LANG = "en"
DATASET_TGT_LANG = "sw"

# Language codes for NLLB tokenizer and model
NLLB_SRC_LANG = "eng_Latn"
NLLB_TGT_LANG = "swh_Latn"

MAX_DOCS = 50            # keep small for qualitative research
MAX_LENGTH = 512
OUTPUT_DIR = "afridoc_mt_case_studies"
os.makedirs(OUTPUT_DIR, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# ERROR TAXONOMY
# -----------------------------
class MTErrorType:
    OMISSION = "omission"
    REPETITION = "repetition"
    TERMINOLOGY_DRIFT = "terminology_drift"
    INCONSISTENT_ENTITY = "inconsistent_entity"
    LOW_CONFIDENCE = "low_confidence"
    HIGH_CONFIDENCE_ERR = "high_confidence_error"

# -----------------------------
# LOAD MODEL
# -----------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    output_attentions=True
).to(device)
model.eval()

# -----------------------------
# LOAD DATA
# -----------------------------
dataset = load_dataset(DATASET_NAME, CONFIG_NAME)["test"]

# -----------------------------
# HELPERS
# -----------------------------
def compute_logprob(output_scores, sequences):
    logprobs = []
    for step_scores, token_id in zip(output_scores, sequences[0][1:]):
        # Correctly access the log probability for the generated token
        logp = torch.log_softmax(step_scores, dim=-1)[0, token_id]
        logprobs.append(logp.item())
    return float(np.mean(logprobs))


def detect_repetition(text):
    tokens = text.split()
    return len(tokens) != len(set(tokens))


def similarity(a, b):
    return SequenceMatcher(None, a, b).ratio()


# -----------------------------
# MAIN ANALYSIS LOOP
# -----------------------------
case_studies = []

for i, example in enumerate(dataset):
    if i >= MAX_DOCS:
        break

    # Use dataset-specific language codes
    src_doc = example[DATASET_SRC_LANG]
    ref_doc = example[DATASET_TGT_LANG]

    # Use NLLB-specific language codes for tokenizer
    tokenizer.src_lang = NLLB_SRC_LANG
    inputs = tokenizer(
        src_doc,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # Use NLLB-specific language code for forced_bos_token_id
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(f"__{NLLB_TGT_LANG}__"),
            return_dict_in_generate=True,
            output_scores=True
        )

    hyp_doc = tokenizer.decode(
        outputs.sequences[0],
        skip_special_tokens=True
    )

    avg_logprob = compute_logprob(outputs.scores, outputs.sequences)

    error_tags = []

    # Omission
    if len(hyp_doc.split()) < 0.7 * len(ref_doc.split()):
        error_tags.append(MTErrorType.OMISSION)

    # Repetition
    if detect_repetition(hyp_doc):
        error_tags.append(MTErrorType.REPETITION)

    # Low / high confidence
    if avg_logprob < -6:
        error_tags.append(MTErrorType.LOW_CONFIDENCE)
    elif similarity(hyp_doc, ref_doc) < 0.5:
        error_tags.append(MTErrorType.HIGH_CONFIDENCE_ERR)

    if error_tags:
        case_studies.append({
            "doc_id": i,
            "source": src_doc,
            "reference": ref_doc,
            "hypothesis": hyp_doc,
            "avg_logprob": avg_logprob,
            "error_types": error_tags
        })

# -----------------------------
# SAVE RESULTS
# -----------------------------
output_path = os.path.join(
    OUTPUT_DIR,
    "health_swahili_document_errors.json"
)

with open(output_path, "w", encoding="utf-8") as f:
    json.dump(case_studies, f, indent=2, ensure_ascii=False)

print(f"Saved {len(case_studies)} document-level error cases to:")
print(output_path)




Saved 50 document-level error cases to:
afridoc_mt_case_studies/health_swahili_document_errors.json


In [None]:
import os
import json
import torch
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from difflib import SequenceMatcher
from collections import defaultdict
from tqdm import tqdm

# ======================================================
# 1. GLOBAL CONFIGURATION
# ======================================================
MODEL_NAME = "facebook/nllb-200-distilled-600M"
DATASET_NAME = "masakhane/AfriDocMT"

CONFIGS = [
    "tech", "health",
    "doc_tech", "doc_health",
    "doc_tech_5", "doc_health_5",
    "doc_tech_10", "doc_health_10",
    "doc_tech_25", "doc_health_25"
]

# Maps dataset language codes (keys) to NLLB language codes (values)
LANG_MAP = {
    "am": "amh_Ethi",
    "ha": "hau_Latn",
    "sw": "swh_Latn",
    "yo": "yor_Latn",
    "zu": "zul_Latn"
}

DATASET_SRC_LANG = "en" # Dataset source language code
NLLB_SRC_LANG = "eng_Latn" # NLLB source language code

MAX_SAMPLES = 500         # keep small → qualitative research
MAX_LENGTH = 512
OUTPUT_ROOT = "results_complete"

device = "cuda" if torch.cuda.is_available() else "cpu"
os.makedirs(OUTPUT_ROOT, exist_ok=True)

# ======================================================
# 2. DATA QUALITY ERROR TAXONOMY (MT-SPECIFIC)
# ======================================================
ERROR_TYPES = [
    "OMISSION",
    "REPETITION",
    "TERMINOLOGY_DRIFT",
    "LOW_CONFIDENCE",
    "HIGH_CONFIDENCE_ERROR"
]

# ======================================================
# 3. LOAD MODEL
# ======================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    output_attentions=True,
    output_hidden_states=True
).to(device)
model.eval()

# ======================================================
# 4. HELPER FUNCTIONS
# ======================================================
def similarity(a, b):
    return SequenceMatcher(None, a, b).ratio()

def detect_repetition(text):
    tokens = text.split()
    return len(tokens) > len(set(tokens))

def avg_logprob(scores, seq):
    logps = []
    for s, t in zip(scores, seq[0][1:]):
        # Correctly access the log probability for the generated token
        logp = torch.log_softmax(s, dim=-1)[0, t]
        logps.append(logp.item())
    return float(np.mean(logps))

def extract_terms(text, min_len=4):
    return {w.lower() for w in text.split() if len(w) >= min_len}

# ======================================================
# 5. GRADIENT × INPUT (SOURCE TOKENS)
# ======================================================
def gradient_x_input(model, inputs, target_ids):
    embeddings = model.get_input_embeddings()(inputs["input_ids"])
    embeddings.requires_grad_(True)
    embeddings.retain_grad() # Added to ensure gradients are retained

    outputs = model(
        inputs_embeds=embeddings,
        attention_mask=inputs["attention_mask"],
        decoder_input_ids=target_ids[:, :-1]
    )

    logits = outputs.logits
    loss = logits.max(dim=-1).values.mean()
    loss.backward()

    grads = embeddings.grad
    attributions = (grads * embeddings).sum(dim=-1)
    attributions = attributions.detach().cpu().numpy()[0]

    return attributions

# ======================================================
# 6. MAIN PIPELINE
# ======================================================
for config in CONFIGS:
    print(f"\n┘ CONFIG: {config}")
    dataset = load_dataset(DATASET_NAME, config)["test"]

    # lang: dataset language code (e.g., 'am')
    # nllb_tgt_lang: NLLB language code (e.g., 'amh_Ethi')
    for lang, nllb_tgt_lang in LANG_MAP.items():
        print(f"   → Language: {lang}")

        out_dir = os.path.join(OUTPUT_ROOT, config, lang)
        os.makedirs(out_dir, exist_ok=True)

        case_studies = []
        summary = defaultdict(int)
        doc_term_memory = defaultdict(set)

        for idx, example in enumerate(tqdm(dataset)):
            if idx >= MAX_SAMPLES:
                break

            source = example[DATASET_SRC_LANG] # Use dataset-specific source language code
            reference = example[lang] # Use dataset-specific target language code

            tokenizer.src_lang = NLLB_SRC_LANG # Use NLLB-specific source language code
            inputs = tokenizer(
                source,
                return_tensors="pt",
                truncation=True,
                max_length=MAX_LENGTH
            ).to(device)

            with torch.no_grad():
                generation = model.generate(
                    **inputs,
                    forced_bos_token_id=tokenizer.convert_tokens_to_ids(f"__{nllb_tgt_lang}__"), # Use NLLB-specific target language code
                    return_dict_in_generate=True,
                    output_scores=True,
                    output_attentions=True
                )

            hypothesis = tokenizer.decode(
                generation.sequences[0],
                skip_special_tokens=True
            )

            score = avg_logprob(
                generation.scores,
                generation.sequences
            )

            errors = []

            # -----------------------------
            # ERROR DETECTION HEURISTICS
            # -----------------------------
            if len(hypothesis.split()) < 0.7 * len(reference.split()):
                errors.append("OMISSION")

            if detect_repetition(hypothesis):
                errors.append("REPETITION")

            if score < -6:
                errors.append("LOW_CONFIDENCE")
            elif similarity(hypothesis, reference) < 0.5:
                errors.append("HIGH_CONFIDENCE_ERROR")

            # Terminology drift (document-level memory)
            src_terms = extract_terms(source)
            tgt_terms = extract_terms(hypothesis)

            for term in src_terms:
                if term in doc_term_memory and tgt_terms:
                    if doc_term_memory[term] != tgt_terms:
                        errors.append("TERMINOLOGY_DRIFT")
                doc_term_memory[term] |= tgt_terms

            if not errors:
                continue

            for e in errors:
                summary[e] += 1
            summary["TOTAL_ERRORS"] += 1

            # -----------------------------
            # ATTENTION (decoder → encoder)
            # -----------------------------
            # Stack attention tensors from each head, then average
            # generation.cross_attentions[-1] is a tuple of tensors (one per head)
            cross_attn_per_head = torch.stack(generation.cross_attentions[-1], dim=1)
            cross_attn = cross_attn_per_head.mean(dim=1).mean(dim=1)
            cross_attn = cross_attn.detach().cpu().numpy()[0]

            # -----------------------------
            # GRADIENT ATTRIBUTION
            # -----------------------------
            grad_attr = gradient_x_input(
                model,
                inputs,
                generation.sequences
            )

            tokens = tokenizer.convert_ids_to_tokens(
                inputs["input_ids"][0]
            )

            case_studies.append({
                "id": idx,
                "source": source,
                "reference": reference,
                "hypothesis": hypothesis,
                "avg_logprob": score,
                "error_types": errors,
                "tokens": tokens,
                "gradient_attribution": grad_attr.tolist(),
                "cross_attention": cross_attn.tolist()
            })

        # ==================================================
        # SAVE OUTPUTS
        # ==================================================
        with open(os.path.join(out_dir, "case_studies.json"), "w", encoding="utf-8") as f:
            json.dump(case_studies, f, indent=2, ensure_ascii=False)

        with open(os.path.join(out_dir, "summary.json"), "w") as f:
            json.dump(summary, f, indent=2)

        print(f"     ✔ Saved {len(case_studies)} explainability cases")

print("\n✅ COMPLETE AFRIDOC-MT QUALITY ANALYSIS FINISHED")


tokenizer_config.json:   0%|          | 0.00/564 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/4.85M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.3M [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/846 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.46G [00:00<?, ?B/s]

The following generation flags are not valid and may be ignored: ['output_attentions', 'output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


model.safetensors:   0%|          | 0.00/2.46G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]


┘ CONFIG: tech


README.md: 0.00B [00:00, ?B/s]

train.csv: 0.00B [00:00, ?B/s]

test.csv: 0.00B [00:00, ?B/s]

dev.csv: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/7048 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1982 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/970 [00:00<?, ? examples/s]

   → Language: am


 25%|██▌       | 500/1982 [20:52<1:01:52,  2.51s/it]


     ✔ Saved 500 explainability cases
   → Language: ha


 25%|██▌       | 500/1982 [20:11<59:51,  2.42s/it]


     ✔ Saved 499 explainability cases
   → Language: sw


 25%|██▌       | 500/1982 [20:10<59:49,  2.42s/it]


     ✔ Saved 498 explainability cases
   → Language: yo


 25%|██▌       | 500/1982 [20:09<59:44,  2.42s/it]


     ✔ Saved 498 explainability cases
   → Language: zu


 25%|██▌       | 500/1982 [20:08<59:41,  2.42s/it]


     ✔ Saved 498 explainability cases

┘ CONFIG: health


train.csv: 0.00B [00:00, ?B/s]

test.csv: 0.00B [00:00, ?B/s]

dev.csv: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/7041 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1982 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/977 [00:00<?, ? examples/s]

   → Language: am


 25%|██▌       | 500/1982 [22:27<1:06:35,  2.70s/it]


     ✔ Saved 500 explainability cases
   → Language: ha


 25%|██▌       | 500/1982 [20:58<1:02:11,  2.52s/it]


     ✔ Saved 500 explainability cases
   → Language: sw


 25%|██▌       | 500/1982 [20:51<1:01:49,  2.50s/it]


     ✔ Saved 497 explainability cases
   → Language: yo


 25%|██▌       | 500/1982 [21:00<1:02:15,  2.52s/it]


     ✔ Saved 500 explainability cases
   → Language: zu


 25%|██▌       | 500/1982 [21:03<1:02:24,  2.53s/it]


     ✔ Saved 496 explainability cases

┘ CONFIG: doc_tech


train.csv: 0.00B [00:00, ?B/s]

test.csv: 0.00B [00:00, ?B/s]

dev.csv: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/187 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/59 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/25 [00:00<?, ? examples/s]

   → Language: am


100%|██████████| 59/59 [03:59<00:00,  4.06s/it]


     ✔ Saved 59 explainability cases
   → Language: ha


100%|██████████| 59/59 [04:00<00:00,  4.07s/it]


     ✔ Saved 59 explainability cases
   → Language: sw


100%|██████████| 59/59 [04:00<00:00,  4.08s/it]


     ✔ Saved 59 explainability cases
   → Language: yo


100%|██████████| 59/59 [04:02<00:00,  4.11s/it]


     ✔ Saved 59 explainability cases
   → Language: zu


100%|██████████| 59/59 [04:00<00:00,  4.08s/it]


     ✔ Saved 59 explainability cases

┘ CONFIG: doc_health


train.csv: 0.00B [00:00, ?B/s]

test.csv: 0.00B [00:00, ?B/s]

dev.csv: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/240 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/61 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/33 [00:00<?, ? examples/s]

   → Language: am


100%|██████████| 61/61 [04:10<00:00,  4.10s/it]


     ✔ Saved 61 explainability cases
   → Language: ha


100%|██████████| 61/61 [04:13<00:00,  4.15s/it]


     ✔ Saved 61 explainability cases
   → Language: sw


100%|██████████| 61/61 [04:09<00:00,  4.10s/it]


     ✔ Saved 61 explainability cases
   → Language: yo


100%|██████████| 61/61 [04:10<00:00,  4.10s/it]


     ✔ Saved 61 explainability cases
   → Language: zu


100%|██████████| 61/61 [04:10<00:00,  4.11s/it]


     ✔ Saved 61 explainability cases

┘ CONFIG: doc_tech_5


train.csv: 0.00B [00:00, ?B/s]

test.csv: 0.00B [00:00, ?B/s]

dev.csv: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/1483 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/418 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/204 [00:00<?, ? examples/s]

   → Language: am


 61%|██████    | 255/418 [14:02<05:06,  1.88s/it]