# Masked Clause Input Task – Structured Legal Text Completion in Greek
### Introduction
In this notebook, we tackle a specialized Masked Language Modeling (MLM) task in the legal domain: Masked Clause Input. The task is to fill in missing words in legal clauses written in Greek, with the ultimate goal of **evaluating open-source language models** on their **semantic fidelity** to the Greek Civil Code (Αστικός Κώδικας).

---

### Task Description
We are provided with incomplete legal clauses where one or more words are masked. Our goal is to:
- Predict the missing words in Greek
- Use open-source language models 
- Evaluate and compare their performance

This task is classified under **Masked Clause Input**, a variant of classic masked language modeling with emphasis on clause-level legal text.

---

### Models and Papers Used

- [BERT Multilingual Cased](https://huggingface.co/google-bert/bert-base-multilingual-cased)  
  [Paper: BERT – Devlin et al. (2018)](https://arxiv.org/pdf/1810.04805)

- [DistilBERT Multilingual Cased](https://huggingface.co/distilbert/distilbert-base-multilingual-cased)  
  [Paper: DistilBERT – Sanh et al. (2019)](https://arxiv.org/pdf/1910.01108)

- [XLM-RoBERTa Base](https://huggingface.co/FacebookAI/xlm-roberta-base)  
  [Paper: XLM-R – Conneau et al. (2019)](https://arxiv.org/pdf/1911.02116)

- [InfoXLM Base](https://huggingface.co/microsoft/infoxlm-base)  
  [Paper: InfoXLM – Chi et al. (2020)](https://arxiv.org/pdf/2007.07834)

- [GreekBERT (NLPAUEB)](https://huggingface.co/nlpaueb/bert-base-greek-uncased-v1)  
  [Paper: Koutsikakis et al. (2020)](https://arxiv.org/pdf/2008.12014)

- [GreekSocialBERT](https://huggingface.co/gealexandri/greeksocialbert-base-greek-uncased-v1)  
  [Paper: Alexandri et al. (2021)](https://www.mdpi.com/2078-2489/12/8/331)

- [GreekLegalRoBERTa v3](https://huggingface.co/AI-team-UoA/GreekLegalRoBERTa_v3)  
  [Paper: Chronopoulou et al. (2024)](https://arxiv.org/pdf/2410.12852)

- [Llama-Krikri 8B](https://huggingface.co/ilsp/Llama-Krikri-8B-Base)  
  [Paper: ILSP Krikri LLaMA (2024)](https://arxiv.org/pdf/2502.01534)

- [mT5 Base](https://huggingface.co/google/mt5-base)  
  [Paper: Xue et al. (2020)](https://arxiv.org/pdf/2010.11934)

- [Meltemi-7B v1.5](https://huggingface.co/ilsp/Meltemi-7B-v1.5)  
  [Paper: ILSP Meltemi (2024)](https://arxiv.org/pdf/2407.20743)

---

## Evaluation Metrics for Masked Token Prediction in Greek Legal Clauses

In evaluating masked token predictions on Greek Civil Code clauses, we focus on two aspects:
1. **Exact-match accuracy** (did the model restore the exact original legal term?), and  
2. **Semantic alignment** (does the clause retain the same meaning even if a different token is used?).  

Below we adapt standard NLP metrics to this task, emphasizing precision in legal terminology and overall clause meaning. All metrics are computed only over the **masked tokens**, since the unmasked context remains unchanged. In legal texts, small wording differences can alter interpretation, so our definitions consider the strict requirements of legal language while also allowing measures of semantic similarity for acceptable paraphrases.

---

### Token-Level Precision, Recall & F₁ (Masked-Token Exact Match)

We treat each masked token prediction as an attempt to recover the ground-truth legal word. Define:  
- $TP$ = number of **correctly predicted** tokens (exact match with the ground truth).  
- $FP$ = number of **incorrect predictions** (predicted token ≠ ground truth).  
- $FN$ = number of ground-truth tokens the model failed to predict correctly.  

Then:

- **Precision**  
  $$
    P \;=\; \frac{TP}{TP + FP}
  $$

- **Recall**  
  $$
    R \;=\; \frac{TP}{TP + FN}
  $$

- **F₁ Score**  
  $$
    F_{1} \;=\; 2 \cdot \frac{P \cdot R}{P + R}
  $$

Because each mask corresponds to exactly one ground-truth token, $FP = FN$ whenever a single token is predicted per mask. In that scenario, precision $=$ recall $=$ exact-match accuracy, and thus $F_1 = P = R$. However, if multiple valid synonyms or morphological variants are allowed (rare in legal evaluation), recall will reflect coverage of all acceptable predictions.  
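
As a minimal sketch of the counts defined above, assuming one prediction per mask and verbatim string comparison, the function below tallies $TP$, $FP$ and $FN$ over the masked positions only. The name `masked_token_prf` and the sample predictions are illustrative and are not taken from the evaluation code.

```python
def masked_token_prf(predictions, references):
    """Exact-match precision, recall and F1 over masked tokens.

    Assumes one predicted token per mask, so both lists have equal length.
    """
    if len(predictions) != len(references):
        raise ValueError("expected exactly one prediction per masked token")
    tp = sum(p == r for p, r in zip(predictions, references))
    fp = len(predictions) - tp   # wrong predictions
    fn = len(references) - tp    # ground-truth tokens not recovered
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Hypothetical predictions for three masked words; the last one is wrong
refs = ["πράγματος", "εξ", "μέρη"]
preds = ["πράγματος", "εξ", "μερίδια"]
p, r, f1 = masked_token_prf(preds, refs)
print(p, r, f1)
```

With two of three masks restored, all three values coincide at 2/3, which illustrates why precision, recall and F₁ collapse to exact-match accuracy in the one-token-per-mask setting.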

> **Legal Note:** In legal text, **exact matches are typically required**: a predicted token is correct only if it exactly matches the ground truth term (e.g. “σύμβαση” vs. “συμφωνία” would be considered different, even if semantically related). For strict legal evaluation, we count only verbatim matches.  
>  
> *See Rajpurkar et al. (2016) for token-level F₁ in QA benchmarks*  
> [Paper: SQuAD (Rajpurkar et al., 2016)](https://arxiv.org/abs/1606.05250)  
> *See Chalkidis et al. (2020) for LegalBERT exact-match evaluation in legal NLP*  
> [Paper: LegalBERT (Chalkidis et al., 2020)](https://arxiv.org/abs/2010.02559)  

---

### Word Error Rate (WER) for Masked-Token Errors

**WER** measures the normalized edit distance between the **predicted clause** (with filled-in tokens) and the **ground truth clause**. Let:  
- $S$ = number of substitutions (wrong token for the correct one),  
- $D$ = number of deletions (ground truth token not predicted),  
- $I$ = number of insertions (extra token added),  
- $N$ = total number of words in the ground truth clause.  

Then:
$$
  \mathrm{WER} \;=\; \frac{S + D + I}{N}
$$

In masked token prediction, the clause length usually remains the same (one predicted token per mask), so the only errors are substitutions at the masked positions. Consistent with the other metrics, we compute WER over the masked tokens only, taking $N$ as the number of masked tokens rather than the full clause length. For example, if 3 tokens were masked and the model got 1 wrong,  
$$
  \mathrm{WER} \;=\; \frac{1}{3} \approx 0.33.
$$
A WER of 0 indicates all masked tokens were correct; a WER of 1.0 indicates every masked token was wrong.  
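
The formula can be sketched as a word-level edit distance. The function below is a minimal illustration, assuming whitespace tokenization; the name `word_error_rate` is hypothetical. Applied to the masked words alone, it reproduces the 1/3 example.

```python
def word_error_rate(reference, hypothesis):
    """Word-level edit distance (S + D + I) divided by the reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] = edit distance between the reference prefix seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1] / len(ref)


# Only the masked words, one per mask: one of three is wrong
print(word_error_rate("πράγματος εξ μέρη", "πράγματος εξ μερίδια"))
```

Because insertions count against a fixed reference length, this general form can exceed 1.0 when a model emits extra words, which is one reason to restrict the computation to one token per mask.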

> *See WER definition and use in speech/MT evaluation*  
> [Paper: Word Error Rate](https://www.researchgate.net/publication/271429169_Word_error_rates)  

---

### BERTScore for Masked Token Semantic Similarity

**BERTScore** uses contextual embeddings to compare the predicted clause and the ground truth clause on a token level. Instead of strict string matches, it finds **embedding-based** matches, giving partial credit if the predicted token is semantically close. Let $\mathbf{e}(t)$ be the contextual embedding of token $t$. Define:  

- $\lvert \mathrm{cand}\rvert$ = number of tokens in the **predicted clause** (usually same length as reference).  
- $\lvert \mathrm{ref}\rvert$ = number of tokens in the **ground truth clause**.  

Compute:

- **Precision₍BERT₎**  
  $$
    P_{\mathrm{BERT}} 
    = \frac{1}{\lvert \mathrm{cand} \rvert}
      \sum_{j=1}^{\lvert \mathrm{cand} \rvert}
      \max_{\,i\,}\;\cos\bigl(\mathbf{e}(\mathrm{ref}_i),\,\mathbf{e}(\mathrm{cand}_j)\bigr)
  $$

- **Recall₍BERT₎**  
  $$
    R_{\mathrm{BERT}} 
    = \frac{1}{\lvert \mathrm{ref} \rvert}
      \sum_{i=1}^{\lvert \mathrm{ref} \rvert}
      \max_{\,j\,}\;\cos\bigl(\mathbf{e}(\mathrm{ref}_i),\,\mathbf{e}(\mathrm{cand}_j)\bigr)
  $$

- **F₁₍BERT₎**  
  $$
    F_{\mathrm{BERT}} 
    = 2 \cdot \frac{P_{\mathrm{BERT}}\;\,R_{\mathrm{BERT}}}
                         {P_{\mathrm{BERT}} + R_{\mathrm{BERT}}}
  $$

Since most tokens in the clause are identical except for the masked ones, differences in BERTScore between models are driven by how the **predicted tokens** align with the **ground truth tokens** in embedding space. If the model predicts a legal synonym or a morphologically close variant (e.g. “δικαίωμα” vs. “δικαιώματος”), BERTScore will yield a high similarity even if the exact surface form differs.  
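
The greedy matching in the three formulas can be sketched without a language model by substituting small hand-made vectors for the contextual embeddings $\mathbf{e}(t)$. Everything below (the function names and the 2-d toy vectors) is an illustrative assumption. The `bert-score` package additionally offers options such as IDF weighting and baseline rescaling, which this sketch leaves out.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def greedy_bertscore(cand_emb, ref_emb):
    """Greedy-matching precision, recall and F1 from per-token embeddings."""
    # Precision: each candidate token is matched to its most similar reference token
    precision = sum(max(cosine(r, c) for r in ref_emb) for c in cand_emb) / len(cand_emb)
    # Recall: each reference token is matched to its most similar candidate token
    recall = sum(max(cosine(r, c) for c in cand_emb) for r in ref_emb) / len(ref_emb)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Toy 2-d vectors standing in for contextual embeddings (made up for illustration)
ref_emb = [[1.0, 0.0], [0.0, 1.0]]
cand_emb = [[1.0, 0.0], [0.6, 0.8]]   # second token is close to, but not the same as, the reference
p, r, f1 = greedy_bertscore(cand_emb, ref_emb)
print(round(p, 3), round(r, 3), round(f1, 3))
```

The unchanged first token contributes a similarity of 1 and the near-miss contributes 0.8, so the average stays high. The same effect keeps clause-level BERTScore close to 1 when only a few masked tokens differ.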

> *See Zhang et al. (2019) for original BERTScore formulation*  
> [Paper: BERTScore (Zhang et al., 2019)](https://arxiv.org/abs/1904.09675)  

---

### Sentence Embedding Cosine Similarity (Clause-Level Semantic Alignment)

To evaluate the **overall meaning** of the filled clause, we can encode entire clauses into fixed vectors using **Sentence-BERT (SBERT)** or **Universal Sentence Encoder (USE)**, then compute cosine similarity between the **original clause** and the **predicted clause**. Formally, let $E(\cdot)$ map a sentence to a dense vector. For the ground truth clause $r$ and predicted clause $c$:
$$
  \mathrm{CosSim}_{\mathrm{sent}}(c, r)
  = \frac{E(c)\;\cdot\;E(r)}
         {\,\lVert E(c)\rVert \;\lVert E(r)\rVert\,}.
$$
A score near 1.0 means the two clauses are nearly identical in meaning, whereas a low score indicates the substituted token altered the legal proposition.  

> *See Reimers & Gurevych (2019) for Sentence-BERT and STS evaluation*  
> [Paper: SBERT (Reimers & Gurevych, 2019)](https://arxiv.org/abs/1908.10084)  

---

### TF-IDF Cosine Similarity (Lexical Overlap Baseline)

As a simpler, lexicon-based baseline, represent each clause as a **TF-IDF-weighted bag-of-words vector**. Let $\mathbf{v}_c$ and $\mathbf{v}_r$ be the TF-IDF vectors for the **predicted** and **ground truth** clauses, respectively. Compute:
$$
  \cos\bigl(\mathbf{v}_c,\mathbf{v}_r\bigr)
  = \frac{\mathbf{v}_c \;\cdot\;\mathbf{v}_r}
         {\,\lVert \mathbf{v}_c\rVert \;\lVert \mathbf{v}_r\rVert\,}.
$$
If the model predicts the **exact** token, the vectors will match (cosine = 1). Any deviation in the masked token creates a gap in the corresponding TF-IDF dimension, reducing cosine similarity. Because the masked tokens are often high-IDF legal terms, TF-IDF cosine is a conservative measure: only nearly exact or highly overlapping clauses score close to 1.  
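
A minimal pure-Python sketch of this baseline follows. It assumes whitespace tokenization and fits the weights on the two clauses alone, using the smoothed IDF $\ln\frac{1+n}{1+\mathrm{df}} + 1$ that scikit-learn documents as the `TfidfVectorizer` default. The function name and sample clauses are illustrative.

```python
import math
from collections import Counter


def tfidf_cosine(text_a, text_b):
    """Cosine similarity of two texts under a smoothed TF-IDF weighting."""
    docs = [text_a.lower().split(), text_b.lower().split()]
    vocab = sorted(set(docs[0]) | set(docs[1]))
    n_docs = len(docs)
    # Smoothed inverse document frequency: ln((1 + n) / (1 + df)) + 1
    idf = {w: math.log((1 + n_docs) / (1 + sum(w in d for d in docs))) + 1 for w in vocab}
    vectors = []
    for doc in docs:
        counts = Counter(doc)
        vectors.append([counts[w] * idf[w] for w in vocab])
    dot = sum(a * b for a, b in zip(*vectors))
    norms = math.sqrt(sum(a * a for a in vectors[0])) * math.sqrt(sum(b * b for b in vectors[1]))
    return dot / norms if norms else 0.0


truth = "εφαρμόζονται οι διατάξεις για την κοινωνία"
exact = "εφαρμόζονται οι διατάξεις για την κοινωνία"
wrong = "εφαρμόζονται οι διατάξεις για την εταιρεία"   # hypothetical wrong prediction
print(tfidf_cosine(exact, truth))
print(tfidf_cosine(wrong, truth))
```

With only two documents, any word that appears in just one clause receives the highest IDF, so a single wrong token lowers the score noticeably even when the rest of the clause is identical.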

> *See Salton et al. (1975) for the Vector Space Model and TF-IDF cosine*  
> [Paper: Vector Space Model (Salton et al., 1975)](https://doi.org/10.1145/361219.361220)  

---

### Text to be masked

```plaintext
Άρθρο 1113. Κοινό πράγμα. — Αν η κυριότητα του [MASK] ανήκει σε περισσότερους  
[MASK] αδιαιρέτου κατ΄ιδανικά [MASK], εφαρμόζονται οι διατάξεις για την κοινωνία.

Άρθρο 1114. Πραγματική δουλεία σε [MASK] ή υπέρ του κοινού ακινήτου. — Στο κοινό  
[MASK] μπορεί να συσταθεί πραγματική δουλεία υπέρ του [MASK] κύριου άλλου ακινήτου  
και αν ακόμη αυτός είναι [MASK] του ακινήτου που βαρύνεται με τη δουλεία. Το ίδιο ισχύει  
και για την [MASK] δουλεία πάνω σε ακίνητο υπέρ των εκάστοτε κυρίων κοινού ακινήτου,  
αν [MASK] από αυτούς είναι κύριος του [MASK] που βαρύνεται με τη δουλεία.
```
#### Ground Truth
```plaintext
Άρθρο 1113. Κοινό πράγμα: Αν η κυριότητα του πράγματος ανήκει σε περισσότερους  
εξ αδιαιρέτου κατ' ιδανικά μέρη, εφαρμόζονται οι διατάξεις για την κοινωνία.

Άρθρο 1114. Πραγματική δουλεία σε βάρος ή υπέρ του κοινού ακινήτου: Στο κοινό  
ακίνητο μπορεί να συσταθεί πραγματική δουλεία υπέρ του εκάστοτε κυρίου άλλου ακινήτου  
και αν ακόμη αυτός είναι συγκύριος του ακινήτου που βαρύνεται με τη δουλεία. Το ίδιο ισχύει  
και για πραγματική δουλεία πάνω σε ακίνητο υπέρ των εκάστοτε κυρίων κοινού ακινήτου,  
αν κάποιος από αυτούς είναι κύριος του ακινήτου που βαρύνεται με τη δουλεία.
```
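
One way to recover the word hidden behind each mask is to turn the masked text into a regular expression in which every `[MASK]` becomes a one-word capture group. This is a sketch of an alternative technique under stated assumptions, not necessarily how the evaluation code aligns the texts: the function name is hypothetical, and the surrounding context must be character-for-character identical in both strings. Differences such as `. —` versus `:` or `κατ΄ιδανικά` versus `κατ' ιδανικά` in the excerpts above would need to be normalized first.

```python
import re


def extract_mask_targets(masked_text, ground_truth, mask="[MASK]"):
    """Return the ground-truth word behind each mask, or None if the texts do not line up."""
    # Escape the literal context and turn every mask into a one-word capture group
    pieces = [re.escape(part) for part in masked_text.split(mask)]
    pattern = r"(\S+)".join(pieces)
    match = re.fullmatch(pattern, ground_truth)
    return list(match.groups()) if match else None


masked = "Αν η κυριότητα του [MASK] ανήκει σε περισσότερους [MASK] αδιαιρέτου"
truth = "Αν η κυριότητα του πράγματος ανήκει σε περισσότερους εξ αδιαιρέτου"
print(extract_mask_targets(masked, truth))
```

Returning `None` on a mismatch makes context drift between the two strings visible instead of silently dropping a mask from the evaluation.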

In [None]:
# Import necessary libraries
!pip install bert-score sentence-transformers python-dotenv tabulate
!pip install --upgrade transformers  accelerate bitsandbytes -q
import os
from difflib import SequenceMatcher
import re
import torch
import bitsandbytes as bnb
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, pipeline
from bert_score import score as bert_score
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import accelerate
import numpy as np
import dotenv
import json
# Load environment variables
dotenv.load_dotenv()
hf_token = os.getenv("HF_TOKEN")

# Define models and their types
models = {
    "ilsp/Llama-Krikri-8B-Base": "clm",
    "google/mt5-base": "enc-dec",
    "google-bert/bert-base-multilingual-cased": "mlm",
    "distilbert/distilbert-base-multilingual-cased": "mlm",
    "FacebookAI/xlm-roberta-base": "mlm",
    "microsoft/infoxlm-base": "mlm",
    "nlpaueb/bert-base-greek-uncased-v1": "mlm",
    "gealexandri/greeksocialbert-base-greek-uncased-v1": "mlm",
    "AI-team-UoA/GreekLegalRoBERTa_v3": "mlm",
    "ilsp/Meltemi-7B-v1.5": "clm"
}

# Prepare masked texts and ground truths
examples = [
    {
        "masked_text": "Άρθρο 1113. Κοινό πράγμα. — Αν η κυριότητα του [MASK] ανήκει σε περισσότερους [MASK] αδιαιρέτου κατ΄ιδανικά [MASK], εφαρμόζονται οι διατάξεις για την κοινωνία.",
        "ground_truth": "Άρθρο 1113. Κοινό πράγμα: Αν η κυριότητα του πράγματος ανήκει σε περισσότερους εξ αδιαιρέτου κατ' ιδανικά μέρη, εφαρμόζονται οι διατάξεις για την κοινωνία."
    },
    {
        "masked_text": "Άρθρο 1114. Πραγματική δουλεία σε [MASK] ή υπέρ του κοινού ακινήτου. — Στο κοινό [MASK] μπορεί να συσταθεί πραγματική δουλεία υπέρ του [MASK] κύριου άλλου ακινήτου και αν ακόμη αυτός είναι [MASK] του ακινήτου που βαρύνεται με τη δουλεία. Το ίδιο ισχύει και για την [MASK] δουλεία πάνω σε ακίνητο υπέρ των εκάστοτε κυρίων κοινού ακινήτου, αν [MASK] από αυτούς είναι κύριος του [MASK] που βαρύνεται με τη δουλεία.",
        "ground_truth": "Άρθρο 1114. Πραγματική δουλεία σε βάρος ή υπέρ του κοινού ακινήτου: Στο κοινό ακίνητο μπορεί να συσταθεί πραγματική δουλεία υπέρ του εκάστοτε κύριου άλλου ακινήτου και αν ακόμη αυτός είναι συγκύριος του ακινήτου που βαρύνεται με τη δουλεία. Το ίδιο ισχύει και για πραγματική δουλεία πάνω σε ακίνητο υπέρ των εκάστοτε κυρίων κοινού ακινήτου, αν κάποιος από αυτούς είναι κύριος του ακινήτου που βαρύνεται με τη δουλεία."
    }
]


In [None]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from accelerate import infer_auto_device_map
from accelerate import init_empty_weights
from transformers import LlamaConfig, AutoConfig

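# Function to normalize text (standardize punctuation)
def normalize_text(text):
    # Unify apostrophe variants so "κατ΄ιδανικά" and "κατ' ιδανικά" tokenize identically
    text = re.sub(r"[΄'’´]\s*", "' ", text)
    # Replace punctuation (including commas, so "[MASK]," becomes "[MASK]") with a space
    text = re.sub(r'[.,:;—]', ' ', text)
    # Collapse multiple spaces
    text = re.sub(r'\s+', ' ', text)
    return text.strip()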



# Function to get masked positions and ground truth words
def get_masked_positions(masked_text, ground_truth):
    # Normalize texts to avoid punctuation issues
    masked_text_norm = normalize_text(masked_text)
    ground_truth_norm = normalize_text(ground_truth)
    
    masked_words = masked_text_norm.split()
    gt_words = ground_truth_norm.split()
    
    # Find [MASK] positions and align with ground truth
    masked_positions = []
    mask_count = masked_text.count("[MASK]")
    if mask_count == 0:
        return masked_positions
    
    # Use SequenceMatcher to align sequences, allowing for multi-word matches
    matcher = SequenceMatcher(None, masked_words, gt_words)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'replace' and i2 - i1 == 1 and masked_words[i1] == '[MASK]':
            # [MASK] corresponds to ground truth words from j1 to j2
            masked_positions.append((i1, ' '.join(gt_words[j1:j2])))
    
    return masked_positions

# Prediction functions for different model types
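def predict_masks_mlm(model, tokenizer, masked_text):
    # Use the model's own mask token (e.g. "<mask>" for RoBERTa-style tokenizers);
    # a literal "[MASK]" is not recognized by every tokenizer
    model_input = masked_text.replace("[MASK]", tokenizer.mask_token)
    inputs = tokenizer(model_input, return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = outputs.logits
    predicted_text = masked_text
    # Fill the masks left to right with the top-1 token at each mask position
    for idx in mask_token_index:
        predicted_id = torch.argmax(predictions[0, idx.item()]).item()
        predicted_token = tokenizer.decode([predicted_id]).strip()
        predicted_text = predicted_text.replace("[MASK]", predicted_token, 1)
    return predicted_text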

def predict_masks_clm(model, tokenizer, masked_text):
    parts = masked_text.split("[MASK]")
    current_text = parts[0]
    
    for i in range(1, len(parts)):
        # 1. Tokenize without moving to any device
        inputs = tokenizer(current_text, return_tensors="pt")
        
        # 2. Identify the device of the first layer
        first_device = list(model.hf_device_map.values())[0]
        
        # 3. Move each input tensor to that device
        inputs = {k: v.to(first_device) for k, v in inputs.items()}
        
        # 4. Run the forward pass (Accelerate will move weights/activations as needed)
        with torch.no_grad():
            outputs = model(**inputs)
            # 5. Get top prediction for the [MASK] position
            next_token_logits = outputs.logits[0, -1]
            predicted_id = torch.argmax(next_token_logits).item()
            predicted_token = tokenizer.decode([predicted_id]).strip()
        
        current_text += predicted_token + parts[i]
    
    return current_text
#def predict_masks_clm(model, tokenizer, masked_text):
#    parts = masked_text.split("[MASK]")
#    current_text = parts[0]
#    for i in range(1, len(parts)):
#        inputs = tokenizer(current_text, return_tensors="pt")
#        with torch.no_grad():
#            outputs = model(**inputs)
#            # Allow multiple tokens to be generated
#            predicted_ids = torch.topk(outputs.logits[0, -1], k=5).indices
#            predicted_tokens = [tokenizer.decode([pid]).strip() for pid in predicted_ids]
#            predicted_token = predicted_tokens[0]  # Take the most likely
#        current_text += predicted_token + parts[i]
#    return current_text

def predict_masks_encdec(model, tokenizer, masked_text):
    mask_count = masked_text.count("[MASK]")
    input_text = masked_text
    # T5-style models mark each gap with a numbered sentinel token
    for i in range(mask_count):
        input_text = input_text.replace("[MASK]", f"<extra_id_{i}>", 1)
    inputs = tokenizer(input_text, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=512)
    # Keep special tokens so the sentinels survive decoding, then drop pad/eos by hand
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
    for special in (tokenizer.pad_token, tokenizer.eos_token):
        if special:
            generated_text = generated_text.replace(special, "")
    predicted_text = masked_text
    for i in range(mask_count):
        # The span for mask i runs from its sentinel up to the next sentinel (or the end);
        # a missing sentinel yields an empty prediction instead of an IndexError
        match = re.search(rf"<extra_id_{i}>(.*?)(?=<extra_id_\d+>|$)", generated_text, flags=re.S)
        pred_token = match.group(1).strip() if match else ""
        predicted_text = predicted_text.replace("[MASK]", pred_token, 1)
    return predicted_text

# Load models and tokenizers
def load_model_and_tokenizer(model_name, model_type):
    cache_dir = "/kaggle/working/models"
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_enable_fp32_cpu_offload=True
    )
    if model_type == "mlm":
        # -- MODEL --
        model = AutoModelForMaskedLM.from_pretrained(
            model_name,
            cache_dir=cache_dir,
            token=hf_token,
        )
        # -- TOKENIZER --
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=cache_dir,
            token=hf_token,    
        )
    
    elif model_type == "clm":
         # 2. Create an empty model skeleton to compute device_map weights
        with init_empty_weights():
            dummy_config = AutoConfig.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                token=hf_token,
                trust_remote_code=True
            )
            dummy_model = AutoModelForCausalLM.from_config(dummy_config)
        
        # 3. Define per-device memory budget
        max_memory = {
            0: "13GB",    # GPU 0
            1: "13GB",    # GPU 1
            "cpu": "29GB" # CPU RAM
        }
        # 4. Compute device_map under these constraints
        no_split = ["LlamaDecoderLayer", "MistralDecoderLayer"]
        device_map = infer_auto_device_map(
            dummy_model,
            max_memory=max_memory,
            no_split_module_classes=no_split  # avoid splitting between GPUs if needed
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            cache_dir=cache_dir,
            token=hf_token,
            trust_remote_code=True,
            quantization_config=bnb_config,
            device_map=device_map,
            torch_dtype="auto",         # float16 on GPU if available
            low_cpu_mem_usage=True       # reduce CPU peak-memory when deserializing
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=cache_dir,
            token=hf_token,
        )
    
    elif model_type == "enc-dec":
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            cache_dir=cache_dir,
            token=hf_token,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=cache_dir,
            token=hf_token,
        )
    
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
    
    return model, tokenizer

# Evaluation metrics
def compute_token_metrics(predicted_text, ground_truth, masked_positions):
    pred_words = normalize_text(predicted_text).split()
    gt_words = normalize_text(ground_truth).split()
    correct = 0
    for pos, gt_phrase in masked_positions:
        # Check if predicted phrase at position matches ground truth phrase
        pred_phrase = pred_words[pos] if pos < len(pred_words) else ''
        if pred_phrase == gt_phrase:
            correct += 1
    total_masks = len(masked_positions)
    accuracy = correct / total_masks if total_masks > 0 else 0
    return {"Precision": accuracy, "Recall": accuracy, "F1": accuracy}

def compute_wer(predicted_text, ground_truth, masked_positions):
    pred_words = normalize_text(predicted_text).split()
    incorrect = 0
    for pos, gt_phrase in masked_positions:
        pred_phrase = pred_words[pos] if pos < len(pred_words) else ''
        if pred_phrase != gt_phrase:
            incorrect += 1
    total_masks = len(masked_positions)
    return incorrect / total_masks if total_masks > 0 else 0

def compute_bertscore(predicted_text, ground_truth):
    P, R, F1 = bert_score([predicted_text], [ground_truth], lang="el", verbose=False)
    return {"BERTScore_P": P.item(), "BERTScore_R": R.item(), "BERTScore_F1": F1.item()}

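_sbert_model = None

def compute_sentence_similarity(predicted_text, ground_truth):
    # Load the sentence encoder once and reuse it; reloading it on every call is slow
    global _sbert_model
    if _sbert_model is None:
        _sbert_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
    embeddings = _sbert_model.encode([predicted_text, ground_truth])   # returns a NumPy array
    sim = cosine_similarity([embeddings[0]], [embeddings[1]])[0][0]
    return float(sim)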


def compute_tfidf_similarity(predicted_text, ground_truth):
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform([predicted_text, ground_truth])
    sim = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]
    return float(sim)

In [None]:
# Main evaluation loop
import shutil
import time

model_dir = "/kaggle/working/models"
results = {}
for model_name, model_type in models.items():
    print(f"Evaluating {model_name} ({model_type})")
    model, tokenizer = load_model_and_tokenizer(model_name, model_type)
    model_results = []
    for example in examples:
        masked_text = example["masked_text"]
        ground_truth = example["ground_truth"]
        masked_positions = get_masked_positions(masked_text, ground_truth)
        
        # Predict based on model type
        start = time.time()
        if model_type == "mlm":
            predicted_text = predict_masks_mlm(model, tokenizer, masked_text)
        elif model_type == "clm":
            predicted_text = predict_masks_clm(model, tokenizer, masked_text)
        else:  # enc-dec
            predicted_text = predict_masks_encdec(model, tokenizer, masked_text)
        inference_time = time.time() - start
        
        # Compute metrics
        token_metrics = compute_token_metrics(predicted_text, ground_truth, masked_positions)
        wer = compute_wer(predicted_text, ground_truth, masked_positions)
        bertscore = compute_bertscore(predicted_text, ground_truth)
        sent_sim = compute_sentence_similarity(predicted_text, ground_truth)
        tfidf_sim = compute_tfidf_similarity(predicted_text, ground_truth)
        
        model_results.append({
            "Example": masked_text,
            "Predicted": predicted_text,
            "Ground Truth": ground_truth,
            "Token Metrics": token_metrics,
            "WER": wer,
            "BERTScore": bertscore,
            "Sentence Similarity": sent_sim,
            "TF-IDF Similarity": tfidf_sim,
            "Inference Time": inference_time
        })
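    # Drop references to the model so its memory can be reclaimed before the next load
    del model, tokenizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Delete the model directory to free up space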
    if os.path.exists(model_dir):
        try:
            shutil.rmtree(model_dir)
            print(f"Deleted model directory: {model_dir}")
        except Exception as e:
            print(f"Error deleting model directory {model_dir}: {e}")
    results[model_name] = model_results

# Display results
for model_name, model_results in results.items():
    print(f"\nResults for {model_name}:")
    for i, res in enumerate(model_results):
        print(f"\nExample {i+1}:")
        print(f"Masked Text: {res['Example']}")
        print(f"Predicted: {res['Predicted']}")
        print(f"Ground Truth: {res['Ground Truth']}")
        print(f"Token Metrics: {res['Token Metrics']}")
        print(f"WER: {res['WER']:.2f}")
        print(f"BERTScore: {res['BERTScore']}")
        print(f"Sentence Similarity: {res['Sentence Similarity']:.2f}")
        print(f"TF-IDF Similarity: {res['TF-IDF Similarity']:.2f}")

In [None]:
# Extract data for plotting (with inference times)
import os
import json
import pandas as pd
import seaborn as sns
import time
import matplotlib.pyplot as plt

# Build a flat list of all metrics, including the new Inference Time
data = []
for model_name, model_results in results.items():
    for i, res in enumerate(model_results):
        entry = {
            'model': model_name,
            'example': i+1,
            'Token F1': res['Token Metrics']['F1'],
            'WER': res['WER'],
            'BERTScore F1': res['BERTScore']['BERTScore_F1'],
            'Sentence Similarity': res['Sentence Similarity'],
            'TF-IDF Similarity': res['TF-IDF Similarity'],
            'Inference Time': res['Inference Time']
        }
        data.append(entry)

# Create DataFrame
df = pd.DataFrame(data)
df_melted = pd.melt(df, id_vars=['model', 'example'], var_name='metric', value_name='value')

# Ensure output dir
output_dir = '/kaggle/working/Results'
os.makedirs(output_dir, exist_ok=True)

# Save JSON and CSV
with open(os.path.join(output_dir, 'evaluation_results.json'), 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=4)
df.to_csv(os.path.join(output_dir, 'evaluation_metrics.csv'), index=False)

print(f"Saved JSON and CSV (including Inference Time) under {output_dir}")


In [None]:
# Define model categories
model_categories = {
    'google/mt5-base': 'General multilingual',
    'google-bert/bert-base-multilingual-cased': 'General multilingual',
    'distilbert/distilbert-base-multilingual-cased': 'General multilingual',
    'FacebookAI/xlm-roberta-base': 'General multilingual',
    'microsoft/infoxlm-base': 'General multilingual',
    'ilsp/Llama-Krikri-8B-Base': 'Greek general',
    'nlpaueb/bert-base-greek-uncased-v1': 'Greek general',
    'gealexandri/greeksocialbert-base-greek-uncased-v1': 'Greek general',
    'ilsp/Meltemi-7B-v1.5': 'Greek general',
    'AI-team-UoA/GreekLegalRoBERTa_v3': 'Greek legal'
}

df['category'] = df['model'].map(model_categories)

# 2) Save Markdown table of the full DataFrame:
with open(os.path.join(output_dir, 'df_metrics.md'), 'w', encoding='utf-8') as f:
    f.write(df.to_markdown(index=False))

# 3) Compute averages per model:
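# Average only the metric columns; the integer 'example' index is an identifier, not a metric
numeric_cols = df.select_dtypes(include='number').columns.drop('example', errors='ignore')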
avg_model = df.groupby('model')[numeric_cols].mean().reset_index()
with open(os.path.join(output_dir, 'average_metrics_per_model.md'), 'w', encoding='utf-8') as f:
    f.write(avg_model.to_markdown(index=False))

# 4) Plot average metrics per model:
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(18, 12))
axes = axes.flatten()
metrics = ['Token F1','WER','BERTScore F1','Sentence Similarity','TF-IDF Similarity','Inference Time']
for ax, metric in zip(axes, metrics):
    sns.barplot(data=avg_model, x='model', y=metric, ax=ax)     
    ax.set_title(f'Avg {metric} per Model')
    ax.tick_params(axis='x', rotation=45, labelsize=8)
fig.tight_layout()
fig.savefig(os.path.join(output_dir, 'Average_Metrics.png'), dpi=300)

# 5) Plot inference time breakdown:
plt.figure(figsize=(12,6))
sns.barplot(
    data=df, x='model', y='Inference Time', hue='example',
    orient='v'
)                                                                   
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'Inference_Time_per_Model_and_Example.png'))

# 6) Plot all metrics per model and example in a single figure
n_metrics = len(metrics)
cols = 3
rows = -(-n_metrics // cols)  # ceiling division
fig, axes = plt.subplots(rows, cols, figsize=(18, 6 * rows))
axes = axes.flatten()

for idx, metric in enumerate(metrics):
    ax = axes[idx]
    sns.barplot(data=df, x='model', y=metric, hue='example', ax=ax)
    ax.set_title(f'{metric} per Model and Example', fontsize=12)
    ax.set_xlabel('Model', fontsize=10)
    ax.set_ylabel(metric, fontsize=10)
    ax.tick_params(axis='x', rotation=45, labelsize=8)
    if idx == 0:
        ax.legend(loc='upper right')
    else:
        ax.get_legend().remove()

# Remove unused subplots
for j in range(n_metrics, len(axes)):
    fig.delaxes(axes[j])

fig.tight_layout()
fig.savefig(os.path.join(output_dir, 'Per_Model_Metrics.png'))
plt.close(fig)