In [None]:
import os
import re
import json
import warnings
from evaluate import load
import regex as reg
from tqdm import tqdm
from collections import Counter
from rouge_score import rouge_scorer
from bert_score import score as bert_score

In [None]:
# Submission folder to score; its basename must look like "<TEAM>_<TASK>_<TYPE>"
# (it is split on "_" to derive the team/task/type names).
# Edit the defaults here, or set the PRED_DIR / GT_DIR environment variables
# to run the notebook headless without editing it.
PRED_DIR = os.environ.get("PRED_DIR", "")
GT_DIR = os.environ.get(
    "GT_DIR", "../NLP4Health_Shared_task_benchmark/test_data_internal_with_answers"
)

In [None]:

# Derive TEAM/TASK/TYPE from the submission folder name "<TEAM>_<TASK>_<TYPE>[_...]".
# normpath drops a trailing separator; otherwise basename("x/Team_Task_Type/") is ""
# and the name parts would be missing.
_pred_dir_name = os.path.basename(os.path.normpath(PRED_DIR))
_name_parts = _pred_dir_name.split("_")
if len(_name_parts) < 3:
    raise ValueError(
        f"PRED_DIR basename {_pred_dir_name!r} must look like '<TEAM>_<TASK>_<TYPE>'; "
        "set PRED_DIR to the submission folder."
    )
TEAM_NAME, TASK_NAME, TASK_TYPE = _name_parts[:3]

RESULTS_DIR = "../Evaluation_Results/Submissions"
OUT_FILE = f"{TEAM_NAME}_{TASK_NAME}_{TASK_TYPE}_v1.json"

In [None]:
# Show the output filename derived from PRED_DIR as a sanity check before running the evaluation.
OUT_FILE

In [None]:
# COMET metric (wmt22-comet-da checkpoint), loaded once at module level and
# reused by compute_cometscore / compute_knv_cometscore.
comet_metric = load(
    path="comet",
    config_name="Unbabel/wmt22-comet-da",
)

In [None]:
warnings.filterwarnings("ignore", category=UserWarning, module="bert_score")

In [None]:
def get_lang_code(language):
    """Map a dataset language folder name (e.g. "Hindi") to a short language code.

    The codes are the ones the metric helpers in this notebook branch on:
    "en" versus the Indic codes in their `indic_langs` sets (as, bn, gu, hi,
    kn, ml, mr, or, pa, ta, te, doi).

    Args:
        language: language folder name as it appears under GT_DIR.

    Returns:
        The short code, or "multilingual" for names not in the table.
    """
    lang_map = {
        "Assamese": "as",
        "Bangla": "bn",
        "Dogri": "doi",
        "English": "en",
        "Gujarati": "gu",
        "Hindi": "hi",
        "Kannada": "kn",
        # ml / or / pa are listed in the metrics' indic_langs sets, so map them too.
        "Malayalam": "ml",
        "Marathi": "mr",
        "Odia": "or",
        "Punjabi": "pa",
        "Tamil": "ta",
        "Telugu": "te",
    }
    return lang_map.get(language, "multilingual")


In [None]:
def read_json(path):
    """Load and return the JSON document stored at `path` (decoded as UTF-8)."""
    with open(path, encoding="utf-8") as handle:
        document = json.load(handle)
    return document

def read_text(path):
    """Return the UTF-8 contents of `path` with leading/trailing whitespace removed."""
    with open(path, encoding="utf-8") as handle:
        contents = handle.read()
    return contents.strip()

def flatten_kv(d):
    """Render a dict as "key: value" pairs joined by spaces, in sorted key order."""
    parts = []
    for key in sorted(d):
        parts.append(f"{key}: {d[key]}")
    return " ".join(parts)

def normalize(text):
    """Canonicalize a string before metric comparison.

    Removes Unicode punctuation (regex class P), collapses every whitespace
    run to a single space, lower-cases, and trims both ends.

    Args:
        text: value to normalize.

    Returns:
        "" for falsy input (None, "", [], 0, ...); non-string truthy input is
        returned unchanged; otherwise the normalized string.
    """
    if not text:
        return ""
    if not isinstance(text, str):
        return text

    # Strip punctuation *before* collapsing whitespace. In the other order,
    # "a , b" becomes "a  b" (double space) and no longer matches "a, b" -> "a b".
    text = reg.sub(r'\p{P}+', '', text)
    text = reg.sub(r'\s+', ' ', text)
    return text.lower().strip()


def tokenize(text):
    """Return a bag (Counter) of the *characters* of `str(text)`.

    Every character counts, including spaces. None, "" and [] yield an empty
    Counter.

    NOTE(review): this is character-level, not word-level, so "F1" in this
    notebook is a character-overlap F1. If word-level F1 (SQuAD-style) was
    intended, this should be `Counter(str(text).split())`. Confirm before
    changing, because it alters every reported F1.
    """
    is_empty = not text or text in (None, "", [])
    if is_empty:
        return Counter()
    chars = str(text)
    return Counter(chars)

def read_jsonl(path):
    """Parse a JSON Lines file into a list of decoded objects.

    Blank lines are ignored. A line that is not valid JSON is reported on
    stdout and skipped, so one bad line does not abort the whole file.
    """
    records = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped == "":
                continue
            try:
                record = json.loads(stripped)
            except json.JSONDecodeError:
                print(f"Skipping invalid line in {path}: {stripped}")
                continue
            records.append(record)
    return records

In [None]:
def fill_with_template(data, template):
    """Reshape a predicted structure `data` so it mirrors the reference `template`.

    The result has exactly the keys and list lengths of `template`. As a
    result, `flatten_knv(result)` and `flatten_knv(template)` produce the same
    number of values, and predicted and reference leaves can be compared
    pairwise.

    - dict template: one entry per template key. Keys only in `data` are
      dropped; missing keys are filled recursively from None.
    - list template: exactly `len(template)` items aligned by position. Surplus
      data items are dropped; missing ones are filled from None.
    - leaf template (str/int/float/bool/None): a predicted dict or list is
      collapsed into a flat list of its leaf values, so it flattens to a single
      value like the reference leaf. A missing (None) value becomes the string
      "Missing", and any other value is kept as-is.

    Args:
        data: predicted value (any JSON-like structure, or None).
        template: reference value whose shape the result must follow.

    Returns:
        A structure shaped like `template`.
    """
    if isinstance(template, dict):
        # Non-dict data cannot supply any keys; every template key is "missing".
        source = data if isinstance(data, dict) else {}
        return {
            key: fill_with_template(source.get(key), subtemplate)
            for key, subtemplate in template.items()
        }

    if isinstance(template, list):
        items = data if isinstance(data, list) else []
        return [
            fill_with_template(items[i] if i < len(items) else None, subtemplate)
            for i, subtemplate in enumerate(template)
        ]

    # Leaf template. Containers must be collapsed regardless of the template's
    # value. Otherwise a list-of-dicts (or a dict sitting in a list slot)
    # flattens into several values against a single reference value, and the
    # callers' length assertions fail.
    if data is None:
        return "Missing"
    if isinstance(data, dict):
        return flatten_dict_values(data)
    if isinstance(data, list):
        return flatten_value_v2(data)
    return data


def flatten_value_v2(value):
    """Flatten a possibly nested list into a flat list of leaf values.

    Nested lists are flattened recursively, and dicts found inside the list are
    replaced by their leaf values (via flatten_dict_values). A non-list value is
    returned unchanged.
    """
    if not isinstance(value, list):
        return value

    flat = []
    for element in value:
        if isinstance(element, list):
            flat += flatten_value_v2(element)
        elif isinstance(element, dict):
            flat += flatten_dict_values(element)
        else:
            flat.append(element)
    return flat


def flatten_dict_values(d):
    """Return the leaf values of a nested dict as a flat list (keys are discarded).

    Nested dicts recurse; list values are flattened with flatten_value_v2.
    Leaves keep the dict's insertion order.
    """
    def _leaves(value):
        if isinstance(value, dict):
            return flatten_dict_values(value)
        if isinstance(value, list):
            return flatten_value_v2(value)
        return [value]

    return [leaf for value in d.values() for leaf in _leaves(value)]



In [None]:
# flatten_knv: turn a (possibly nested) KnV structure into the list of values that get scored.
def flatten_knv(knv_list):
    """Flatten a key/value structure into the list of values that get scored.

    - dict: the flattened values of each entry, in insertion order (keys dropped).
    - empty list: a single "" so an empty field still yields one comparable value.
    - list of scalars only: one string with the non-None, non-"" items joined by spaces.
    - list containing dicts/lists: each item flattened recursively.
    - anything else: wrapped as a single value (unchanged, even if None).
    """
    if isinstance(knv_list, dict):
        return [leaf for value in knv_list.values() for leaf in flatten_knv(value)]

    if isinstance(knv_list, list):
        if not knv_list:
            return [""]
        if any(isinstance(item, (dict, list)) for item in knv_list):
            return [leaf for item in knv_list for leaf in flatten_knv(item)]
        return [" ".join(str(item) for item in knv_list if item not in [None, ""])]

    return [knv_list]


In [None]:
def compute_f1_score(preds, refs):
    """Micro-averaged overlap F1 between aligned predictions and references.

    Each string is normalized and turned into a bag via tokenize. Overlap
    counts (tp), surplus prediction counts (fp) and missed reference counts
    (fn) are summed over *all* pairs before precision/recall are computed.

    Returns:
        F1 rounded to 4 decimals; 0.0 when there is nothing to compare.
    """
    assert len(preds) == len(refs), "Preds and refs must be same length."

    tp = fp = fn = 0
    for pred, ref in zip(preds, refs):
        pred_bag = tokenize(normalize(pred))
        ref_bag = tokenize(normalize(ref))
        overlap = sum((pred_bag & ref_bag).values())
        tp += overlap
        fp += sum(pred_bag.values()) - overlap
        fn += sum(ref_bag.values()) - overlap

    precision = 0.0 if not (tp + fp) else tp / (tp + fp)
    recall = 0.0 if not (tp + fn) else tp / (tp + fn)
    if not (precision + recall):
        return round(0.0, 4)
    return round(2 * precision * recall / (precision + recall), 4)


In [None]:

def compute_knv_f1_score(preds_knv_list, refs_knv_list):

    assert len(preds_knv_list) == len(refs_knv_list), "Preds knv and refs knv must be same length."
    if all(p in["", None] or p == {}  for p in preds_knv_list):
        return 0.0

    all_f1s = []

    for pred_knv, ref_knv in zip(preds_knv_list, refs_knv_list):
        pred_filled = fill_with_template(pred_knv, ref_knv)
        pred_values = flatten_knv(pred_filled)
        ref_values  = flatten_knv(ref_knv)

        if len(pred_values) != len(ref_values):
            print(pred_filled)
            print(ref_knv)
            print(pred_values)
            print(ref_values)

        assert len(pred_values) == len(ref_values), "Preds and refs must be same length."

        pred_values = [normalize(str(p)) if p is not None else "None" for p in pred_values]
        ref_values = [normalize(str(r)) if r is not None else "None" for r in ref_values]

        for pred_val, ref_val in zip(pred_values, ref_values):

            # Handle both empty or None
            if pred_val in [None, "", []] and ref_val in [None, "", []]:
                f1 = 1.0
            else:
                pred_tokens = tokenize(pred_val)
                ref_tokens  = tokenize(ref_val)

                common = sum((pred_tokens & ref_tokens).values())
                tp = common
                fp = sum(pred_tokens.values()) - common
                fn = sum(ref_tokens.values()) - common

                precision = tp / (tp + fp) if (tp + fp) else 0.0
                recall = tp / (tp + fn) if (tp + fn) else 0.0
                f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

            all_f1s.append(f1)

    return round(sum(all_f1s) / len(all_f1s), 4) if all_f1s else 0.0


In [None]:
# Exact match score: share of pairs whose normalized strings are identical.
def compute_exact_match(preds, refs):
    """Fraction of aligned (pred, ref) pairs equal after normalize.

    Returns:
        The fraction rounded to 4 decimals; 0.0 for empty input.
    """
    assert len(preds) == len(refs), "Preds and refs must be same length."

    if len(refs) == 0:
        return round(0.0, 4)

    hits = sum(1 for pred, ref in zip(preds, refs) if normalize(pred) == normalize(ref))
    return round(hits / len(refs), 4)

In [None]:
def compute_knv_exact_match(preds_lists, refs_lists):
    """Fraction of KnV leaf values that match exactly after normalization.

    Each prediction is reshaped to its reference's structure
    (fill_with_template) and both sides are flattened with flatten_knv. Matches
    are then counted over all aligned leaf pairs of all dialogues (pooled, not
    averaged per dialogue).

    Returns:
        The fraction rounded to 4 decimals; 0.0 when every prediction is empty
        or there are no leaf pairs.
    """
    assert len(preds_lists) == len(refs_lists), "Preds knv and refs knv must be same length."
    if all(p in ["", None] or p == {} for p in preds_lists):
        return 0.0

    def _canon(value):
        return "None" if value is None else normalize(str(value))

    matches = pairs = 0
    for pred_knv, ref_knv in zip(preds_lists, refs_lists):
        pred_vals = [_canon(v) for v in flatten_knv(fill_with_template(pred_knv, ref_knv))]
        ref_vals = [_canon(v) for v in flatten_knv(ref_knv)]

        assert len(pred_vals) == len(ref_vals), "Preds and refs must be same length."
        matches += sum(p == r for p, r in zip(pred_vals, ref_vals))
        pairs += len(ref_vals)

    return round(matches / pairs, 4) if pairs else round(0.0, 4)

In [None]:
def compute_bertscore(preds,refs,lang='en'):

    print(lang)

    assert len(preds) == len(refs), "preds and refs should be of same size"
    
    # Handle empty refs
    if not preds or not refs or all(not p for p in preds):
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
    
    indic_langs = {"as","bn","gu","hi","kn","ml","mr","or","pa","ta","te","doi"}

    bert_lang = None
    model_type = None

    if lang in indic_langs:
        model_type = "xlm-roberta-large"
        bert_lang = 'xx'
    if lang == "en":
        model_type = "roberta-large"  
        bert_lang = lang

    # Flatten list of lists
    pred_texts = [normalize(p) for p in preds]
    ref_texts  = [normalize(r) for r in refs]

    # Batch compute BERTScore
    P, R, F1 = bert_score(
        cands=pred_texts,
        refs=ref_texts,
        lang=bert_lang,
        model_type=model_type,
        verbose=False,
    )

    return {
        "precision": float(P.mean()),
        "recall": float(R.mean()),
        "f1": float(F1.mean())
    }

In [None]:
def compute_knv_bertscore(preds_knv_list, refs_knv_list, lang="en"):
    """Macro-averaged BERTScore over KnV leaf values: one mean per dialogue, then averaged.

    Each prediction is reshaped to its reference's structure
    (fill_with_template), both sides are flattened (flatten_knv) and
    normalized, and every aligned leaf pair is scored. The per-dialogue means
    of P/R/F1 are averaged across dialogues. A dialogue with no leaf values
    contributes 0.0.

    All pairs are scored in a single bert_score call, so the model is set up
    once rather than once per dialogue. Each dialogue's mean is then taken
    over its slice of the returned score tensors.

    Args:
        preds_knv_list: predicted KnV structures ("" / None / {} when missing).
        refs_knv_list: reference KnV structures, same length.
        lang: "en" selects roberta-large; any other code selects
            xlm-roberta-large with bert_score lang "xx".

    Returns:
        {"precision", "recall", "f1"} as floats; all 0.0 when every
        prediction is empty/missing.
    """
    assert len(preds_knv_list) == len(refs_knv_list), "preds and refs must be of same length"

    # Every prediction missing or empty -> nothing to score.
    if all(p in ["", None] or p == {} for p in preds_knv_list):
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}

    if lang == "en":
        model_type, bert_lang = "roberta-large", "en"
    else:
        # Indic codes and any other code (e.g. "multilingual"); bert_score
        # rejects a call where both model_type and lang are None.
        model_type, bert_lang = "xlm-roberta-large", "xx"

    # Collect every leaf pair; spans[i] is dialogue i's [start, end) slice.
    cands, refs, spans = [], [], []
    for pred_knv, ref_knv in zip(preds_knv_list, refs_knv_list):
        pred_values = flatten_knv(fill_with_template(pred_knv, ref_knv))
        ref_values = flatten_knv(ref_knv)
        assert len(pred_values) == len(ref_values), "Preds and refs must be same length."

        start = len(cands)
        cands.extend(normalize(str(p)) if p is not None else "None" for p in pred_values)
        refs.extend(normalize(str(r)) if r is not None else "None" for r in ref_values)
        spans.append((start, len(cands)))

    if cands:
        P, R, F1 = bert_score(
            cands=cands,
            refs=refs,
            lang=bert_lang,
            model_type=model_type,
            verbose=False,
        )

    all_precision, all_recall, all_f1 = [], [], []
    for start, end in spans:
        if start == end:
            # Empty reference structure -> no pairs; score the dialogue as 0.
            all_precision.append(0.0)
            all_recall.append(0.0)
            all_f1.append(0.0)
            continue
        all_precision.append(float(P[start:end].mean()))
        all_recall.append(float(R[start:end].mean()))
        all_f1.append(float(F1[start:end].mean()))

    n_dialogues = len(spans)
    return {
        "precision": sum(all_precision) / n_dialogues if n_dialogues else 0.0,
        "recall": sum(all_recall) / n_dialogues if n_dialogues else 0.0,
        "f1": sum(all_f1) / n_dialogues if n_dialogues else 0.0,
    }
    

In [None]:
def compute_rouge(preds, refs, lang='en'):
    use_stemmer = True
    indic_langs = {"as","bn","gu","hi","kn","ml","mr","or","pa","ta","te","doi"}
    if lang in indic_langs:
        use_stemmer = False

    scorer = rouge_scorer.RougeScorer(['rouge1','rouge2','rougeL'], use_stemmer = use_stemmer)

    metrics = {
        "rouge1": {"precision": [], "recall": [], "f1": []},
        "rouge2": {"precision": [], "recall": [], "f1": []},
        "rougeL": {"precision": [], "recall": [], "f1": []},
    }

    assert len(preds) == len(refs), "Preds and Refs should be same size"
    preds = [normalize(str(p)) if p else "" for p in preds]
    refs = [normalize(str(r)) if r else "" for r in refs]

    for p, r in zip(preds, refs):
        score = scorer.score(r, p)
        for metric in ["rouge1", "rouge2", "rougeL"]:
            metrics[metric]["precision"].append(score[metric].precision)
            metrics[metric]["recall"].append(score[metric].recall)
            metrics[metric]["f1"].append(score[metric].fmeasure)

    avg_metrics = {
        metric: {
            "precision": sum(values["precision"]) / len(values["precision"]) if values['precision'] else 0,
            "recall": sum(values["recall"]) / len(values["recall"])if values['recall'] else 0,
            "f1": sum(values["f1"]) / len(values["f1"]) if values['f1'] else 0,
        }
        for metric, values in metrics.items()
    }

    return avg_metrics


In [None]:
def compute_knv_rouge(preds_knv_list, refs_knv_list, lang="en"):
    """Mean ROUGE-1/2/L over aligned KnV leaf values, pooled across dialogues.

    Each prediction is reshaped to its reference's structure
    (fill_with_template), both sides are flattened (flatten_knv) and
    normalized, and every aligned leaf pair contributes one score. A pair
    where both leaves are empty scores 1.0; a pair where exactly one is empty
    scores 0.0.

    "en" uses rouge_score's default tokenizer with stemming. Any other code
    uses whitespace tokenization, because the default tokenizer keeps only
    `[a-z0-9]` and would reduce non-Latin text to nothing.

    Returns:
        {"rouge1"|"rouge2"|"rougeL": {"precision", "recall", "f1"}}; all 0
        when every prediction is empty/missing.
    """
    assert len(preds_knv_list) == len(refs_knv_list), "Preds knv list and Refs knv list should be same size"

    rouge_types = ["rouge1", "rouge2", "rougeL"]
    if all(p in ["", None] or p == {} for p in preds_knv_list):
        return {m: {"precision": 0, "recall": 0, "f1": 0} for m in rouge_types}

    if lang == "en":
        scorer = rouge_scorer.RougeScorer(rouge_types, use_stemmer=True)
    else:
        class _SplitTokenizer:
            """Whitespace tokenizer accepted by RougeScorer(tokenizer=...)."""

            def tokenize(self, text):
                return text.split()

        scorer = rouge_scorer.RougeScorer(rouge_types, tokenizer=_SplitTokenizer())

    metrics = {m: {"precision": [], "recall": [], "f1": []} for m in rouge_types}

    for pred_knv, ref_knv in zip(preds_knv_list, refs_knv_list):
        pred_values = flatten_knv(fill_with_template(pred_knv, ref_knv))
        ref_values = flatten_knv(ref_knv)

        pred_values = [normalize(str(p)) if p is not None else "None" for p in pred_values]
        ref_values = [normalize(str(r)) if r is not None else "None" for r in ref_values]

        assert len(pred_values) == len(ref_values), "Preds and refs must be same length."
        for p, r in zip(pred_values, ref_values):
            if not p or not r:
                # Both empty -> perfect agreement; exactly one empty -> no overlap.
                fixed = 1.0 if (not p and not r) else 0.0
                for m in rouge_types:
                    metrics[m]["precision"].append(fixed)
                    metrics[m]["recall"].append(fixed)
                    metrics[m]["f1"].append(fixed)
                continue

            score = scorer.score(r, p)  # rouge_score takes (target, prediction)
            for m in rouge_types:
                metrics[m]["precision"].append(score[m].precision)
                metrics[m]["recall"].append(score[m].recall)
                metrics[m]["f1"].append(score[m].fmeasure)

    return {
        m: {kind: (sum(vals) / len(vals) if vals else 0) for kind, vals in per_kind.items()}
        for m, per_kind in metrics.items()
    }

In [None]:
def compute_knv_cometscore(preds_knv_list, refs_knv_list):
    """COMET mean score over all aligned KnV leaf values, pooled across dialogues.

    Each prediction is reshaped to its reference's structure
    (fill_with_template), both sides are flattened (flatten_knv) and
    normalized. All leaf pairs are then sent to comet_metric in one call with
    empty source strings.

    Returns:
        COMET's "mean_score"; 0.0 when every prediction is empty/missing or
        there are no leaf pairs at all.
    """
    assert len(preds_knv_list) == len(refs_knv_list), "preds knv and refs knv should be same len"
    if all(p in ["", None] or p == {} for p in preds_knv_list):
        return 0.0

    hypothesis = []
    reference = []
    for pred_knv, ref_knv in zip(preds_knv_list, refs_knv_list):
        pred_values = flatten_knv(fill_with_template(pred_knv, ref_knv))
        ref_values = flatten_knv(ref_knv)
        assert len(pred_values) == len(ref_values), "Preds and refs must be same length."

        hypothesis.extend(normalize(str(p)) if p is not None else "None" for p in pred_values)
        reference.extend(normalize(str(r)) if r is not None else "None" for r in ref_values)

    # Only empty reference structures -> nothing to send to COMET.
    if not hypothesis:
        return 0.0

    results = comet_metric.compute(
        predictions=hypothesis,
        references=reference,
        sources=[""] * len(hypothesis),
    )
    return results['mean_score']
    

In [None]:
def compute_cometscore(preds_list, refs_list):
    """COMET mean score for aligned prediction/reference strings.

    Both sides are normalized (None becomes the literal "None") and scored
    with empty source strings.

    Returns:
        COMET's "mean_score"; 0.0 when every prediction is "" or None.
    """
    assert len(preds_list) == len(refs_list), "Preds and refs should be of same size"

    has_content = any(p != "" and p is not None for p in preds_list)
    if not has_content:
        return 0.0

    def _prep(value):
        return "None" if value is None else normalize(str(value))

    hypothesis = list(map(_prep, preds_list))
    reference = list(map(_prep, refs_list))

    scored = comet_metric.compute(
        predictions=hypothesis,
        references=reference,
        sources=[""] * len(preds_list),
    )
    return scored['mean_score']

In [None]:
# Find the file whose name starts with the dialogue ID.
def find_matching_file(folder, prefix, suffix=None):
    """Return the path of a file in `folder` whose name starts with `prefix`.

    Candidates (optionally also required to end with `suffix`) are examined
    in sorted order, so the result does not depend on os.listdir order. A
    name where `prefix` is followed by a non-alphanumeric character (or
    nothing) is preferred. This way the prefix "d1" picks "d1_summary.txt"
    rather than "d10_summary.txt". If no such boundary match exists, the first
    sorted candidate is returned.

    Returns:
        The joined path, or None when no file matches.
    """
    candidates = sorted(
        name for name in os.listdir(folder)
        if name.startswith(prefix) and (suffix is None or name.endswith(suffix))
    )
    for name in candidates:
        rest = name[len(prefix):]
        if not rest or not rest[0].isalnum():
            return os.path.join(folder, name)
    return os.path.join(folder, candidates[0]) if candidates else None

In [None]:
# Per-language scores, filled by the evaluation loop: language -> {qna, summary text, summary KnV}.
results = dict()

In [None]:
def _submission_path(pred_dir, dialogue_id, suffix):
    """Path of one prediction file: <pred_dir>/<TEAM>_<TASK>_<TYPE>_<dialogue_id>_<suffix>."""
    return os.path.join(pred_dir, f"{TEAM_NAME}_{TASK_NAME}_{TASK_TYPE}_{dialogue_id}_{suffix}")


def _answer_text(entry):
    """Stripped "answer" of one QnA entry; "" when the entry is not a dict or has no answer."""
    if not isinstance(entry, dict):
        return ""
    answer = entry.get("answer")
    return "" if answer is None else str(answer).strip()


# Evaluate every language folder of the ground truth. The folders are sorted
# so the order of `results` (and of the written JSON) is reproducible.
for lang in sorted(os.listdir(GT_DIR)):
    gt_lang_dir = os.path.join(GT_DIR, lang)
    if not os.path.isdir(gt_lang_dir):
        continue  # ignore stray files next to the language folders

    # Ground-truth Summary text for Dogri is missing; skip the language for now.
    if lang == "Dogri":
        print(f"Skipping {lang}")
        continue

    print(f"\nEvaluating {PRED_DIR} {lang}")
    lang_code = get_lang_code(lang)

    pred_lang_dir = os.path.join(PRED_DIR, lang)
    dialogue_dir = os.path.join(gt_lang_dir, "Dialogues")

    qna_pred_dir = os.path.join(pred_lang_dir, "QnA")
    qna_gt_dir = os.path.join(gt_lang_dir, "QnA")

    summary_text_pred_dir = os.path.join(pred_lang_dir, "Summary_Text")
    summary_text_gt_dir = os.path.join(gt_lang_dir, "Summary_Text")

    summary_knv_pred_dir = os.path.join(pred_lang_dir, "Summary_KnV")
    summary_knv_gt_dir = os.path.join(gt_lang_dir, "Summary_KnV")

    qna_suffix = "_questions.json"
    dialogue_files = sorted(f for f in os.listdir(qna_gt_dir) if f.endswith(qna_suffix))

    # All data for this language is collected first, then scored in one batch.
    all_qna_preds, all_qna_refs = [], []
    all_text_preds, all_text_refs = [], []
    all_knv_preds, all_knv_refs = [], []
    all_dialogues = []

    for qna_file in tqdm(dialogue_files, desc=f"{lang}"):
        dialogue_id = qna_file[: -len(qna_suffix)]
        dialogue_file = read_jsonl(os.path.join(dialogue_dir, f"{dialogue_id}.jsonl"))

        # --- QnA ---
        # A missing or malformed prediction file is reported and scored as empty.
        qna_pred_path = _submission_path(qna_pred_dir, dialogue_id, "QnA.json")
        try:
            pred_qna = read_json(qna_pred_path)["questions"]
        except Exception as exc:
            pred_qna = None
            print(f"File not found or unreadable: {qna_pred_path} ({exc!r})")
        gt_qna = read_json(os.path.join(qna_gt_dir, qna_file))["questions"]

        gt_answers = [_answer_text(g) for g in gt_qna]
        pred_answers = [_answer_text(p) for p in pred_qna] if isinstance(pred_qna, list) else []
        if pred_qna is not None and len(pred_answers) != len(gt_answers):
            print(
                f"QnA answer count mismatch for {dialogue_id}: "
                f"{len(pred_answers)} predicted vs {len(gt_answers)} expected"
            )
        # Align answers to questions by position: pad missing answers with ""
        # and drop surplus ones, so each reference answer gets exactly one prediction.
        pred_answers = (pred_answers + [""] * len(gt_answers))[: len(gt_answers)]
        all_qna_preds.extend(pred_answers)
        all_qna_refs.extend(gt_answers)

        # --- Summary text ---
        pred_text_path = _submission_path(summary_text_pred_dir, dialogue_id, "SummaryText.txt")
        try:
            pred_text = read_text(pred_text_path)
        except Exception as exc:
            pred_text = ""
            print(f"File not found or unreadable: {pred_text_path} ({exc!r})")
        gt_text = read_text(find_matching_file(summary_text_gt_dir, dialogue_id))

        # --- KnV pairs ---
        knv_pred_path = _submission_path(summary_knv_pred_dir, dialogue_id, "SummaryKnV.json")
        try:
            pred_knv = read_json(knv_pred_path)
        except Exception as exc:
            pred_knv = {}
            print(f"File not found or unreadable: {knv_pred_path} ({exc!r})")
        gt_knv = read_json(find_matching_file(summary_knv_gt_dir, dialogue_id))

        all_text_preds.append(pred_text)
        all_text_refs.append(gt_text)

        all_knv_preds.append(pred_knv)
        all_knv_refs.append(gt_knv)

        all_dialogues.append(dialogue_file)

    print(f"QnA Eval for {lang}")
    qna_scores = {
        "f1": compute_f1_score(all_qna_preds, all_qna_refs),
        "exact_match": compute_exact_match(all_qna_preds, all_qna_refs),
        "rouge": compute_rouge(all_qna_preds, all_qna_refs, lang=lang_code),
        "bertscore": compute_bertscore(all_qna_preds, all_qna_refs, lang=lang_code),
        "cometscore": compute_cometscore(all_qna_preds, all_qna_refs),
    }

    # Summary text and KnV are scored with English settings (lang="en").
    print(f"Summary Eval for {lang}")
    text_scores = {
        "f1": compute_f1_score(all_text_preds, all_text_refs),
        "exact_match": compute_exact_match(all_text_preds, all_text_refs),
        "rouge": compute_rouge(all_text_preds, all_text_refs, lang="en"),
        "bertscore": compute_bertscore(all_text_preds, all_text_refs, lang="en"),
        "cometscore": compute_cometscore(all_text_preds, all_text_refs),
    }

    print(f"KnV Eval for {lang}")
    knv_scores = {
        "f1": compute_knv_f1_score(all_knv_preds, all_knv_refs),
        "exact_match": compute_knv_exact_match(all_knv_preds, all_knv_refs),
        "rouge": compute_knv_rouge(all_knv_preds, all_knv_refs),
        "bertscore": compute_knv_bertscore(all_knv_preds, all_knv_refs, lang="en"),
        "cometscore": compute_knv_cometscore(all_knv_preds, all_knv_refs),
    }

    results[lang] = {
        'qna_scores': qna_scores,
        'summary_text_scores': text_scores,
        'summary_knv_scores': knv_scores,
    }
    print(len(all_qna_preds), len(all_text_preds), len(all_knv_preds))

In [None]:
# Display the per-language scores: language -> {'qna_scores', 'summary_text_scores', 'summary_knv_scores'}.
results

In [None]:
# Persist the per-language scores to <RESULTS_DIR>/<OUT_FILE>.
out_dir = RESULTS_DIR
# exist_ok avoids the exists()/makedirs() race and is a no-op when the folder is already there.
os.makedirs(out_dir, exist_ok=True)

out_file_path = os.path.join(out_dir, OUT_FILE)

with open(out_file_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=4)