In [None]:
# Mount Google Drive (Colab only). Every input JSON and output PDF used below
# lives under /content/drive/MyDrive/disertatie/error_analysis.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Root folder for the error analysis. aggregate_data() reads
# {INPUT_PATH}/{MODEL_KEY}/{dataset_prefix}{model_tag}_error_analysis_output.json
INPUT_PATH = "/content/drive/MyDrive/disertatie/error_analysis"

In [None]:
# Dataset file-name prefix -> display name used as the plot label/title.
# Each prefix is concatenated directly with the model tag when building the
# JSON file name, so the presence or absence of a trailing "_" matters
# (presumably it mirrors the real file names on Drive; verify before editing).
datasets_files_files_first = {
    "arabia_saudita_guvern_": "Saudi-Romania Econ.",
    "conventia_geneva_": "Geneva Convention",
    "drepturile_copilului": "Children Rights",
    "drepturile_economice_culturale": "Economic Rights",
    "drepturile_omului_": "Human Rights",
    "flores+": "FLORES+",
    "from_training_dataset_": "In-Domain",
    "ksaembassyrou": "Saudi Embassy",
    "statutul_refugiatilor_": "Refugee Status",
}


In [None]:
# Model folder name (under INPUT_PATH) -> model tag appended to each dataset
# prefix in the JSON file name.
# OPENNMT_BASELINE (tag "OPENNMT_BASELINE") is intentionally left out of this run.
models_translate = {
    "GEMINI": "gemini",
    "GPT4": "gpt4",
    "JAIS": "JAIS",
    "LLAMA_PRE": "llm",
    "LLAMA_TRAINED": "llm_trained",
    "MISTRAL_PRE": "mistral_pre",
    "MISTRAL_TRAINED": "mistral_trained",
    "NLLB": "nllb",
    "NLLB_TRAINED": "nllb_trained",
    "OPENNMT_GEMINI": "OPENNMT_GEMINI",
    "OPENNMT_LEALLA": "OPENNMT_LEALLA",
    "OPENNMT_WITHOUT_DIALECTS": "OPENNMT_WITHOUT_DIALECTS",
    "OPENNMT_WITHOUT_DUPLICATES": "OPENNMT_WITHOUT_DUPLICATES",
}

In [None]:
# Model folder name -> human-readable name used in plot titles and output
# file names. OPENNMT_BASELINE ("Baseline Transf.") is intentionally omitted.
models_named_in_table = {
    "GEMINI": "Gemini-flash-002",
    "GPT4": "GPT-4",
    "JAIS": "Jais+Transf.",
    "LLAMA_PRE": "Pre-Trained RoLLama3.1-8b",
    "LLAMA_TRAINED": "Fine-Tuned RoLLama3.1-8b",
    "MISTRAL_PRE": "Pre-Trained RoMistral-7b",
    "MISTRAL_TRAINED": "Fine-Tuned RoMistral-7b",
    "NLLB": "Pre-Trained NLLB-600M",
    "NLLB_TRAINED": "Fine-Tuned NLLB-600M",
    "OPENNMT_GEMINI": "Transf. (Gemini filter)",
    "OPENNMT_LEALLA": "Transf. (Lealla filter)",
    "OPENNMT_WITHOUT_DIALECTS": "Transf. (Dialects removed)",
    "OPENNMT_WITHOUT_DUPLICATES": "Transf. (Duplicates removed)",
}

In [None]:
import json
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

def file_read(input_path, metric="COMET"):
    """Load per-segment metric values from one error-analysis JSON file.

    The file must hold a JSON object. Only its first top-level key is read;
    that key's value is iterated as a list of per-segment records. Records
    that lack the ``metric`` field are skipped.

    Args:
        input_path (str): Path to a ``*_error_analysis_output.json`` file.
        metric (str): Record field to extract. Defaults to ``"COMET"``.

    Returns:
        list: The ``metric`` values, in file order.

    Raises:
        ValueError: If the top-level JSON object is empty.
    """
    with open(input_path, "r", encoding="utf-8") as file:
        data = json.load(file)
    if not data:
        # Previously surfaced as an opaque IndexError from list(...)[0].
        raise ValueError(f"No top-level key found in {input_path}")
    # json.load preserves key order, so this is the first key in the file.
    records = data[next(iter(data))]
    return [item[metric] for item in records if metric in item]

def generate_box_plots(results, model_name, output_dir=None):
    """Save a PDF box plot of COMET scores for one model, one box per dataset.

    Args:
        results (dict): Maps dataset display name -> list of COMET scores.
            Non-list values are reported and skipped. Datasets with an empty
            list get no box.
        model_name (str): Model display name, used in the title and file name.
        output_dir (str, optional): Destination folder. Defaults to
            ``f"{INPUT_PATH}/box_plots"``. It is created if missing.

    Returns:
        str | None: Path of the written PDF, or None if there was nothing to plot.
    """
    import os

    labels, groups = [], []
    for dataset, comet_scores in results.items():
        if not isinstance(comet_scores, list):
            print(f"Unexpected data format for dataset {dataset}: {comet_scores}")
            continue
        if comet_scores:
            labels.append(dataset)
            groups.append(comet_scores)

    if not groups:
        print(f"No COMET scores to plot for {model_name}")
        return None

    if output_dir is None:
        output_dir = f"{INPUT_PATH}/box_plots"
    os.makedirs(output_dir, exist_ok=True)
    # savefig uses the name verbatim and never appends an extension, so add it.
    output_file = os.path.join(output_dir, f"comet_xl_da_{model_name}.pdf")

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.boxplot(groups)
        # Label the ticks explicitly rather than passing `labels=` to boxplot,
        # which Matplotlib 3.9 deprecated in favour of `tick_labels`.
        ax.set_xticks(range(1, len(labels) + 1))
        ax.set_xticklabels(labels, rotation=45)
        ax.set_title(f"Comet-xl-da scores for {model_name} model", fontsize=16)
        ax.set_xlabel("Dataset", fontsize=14)
        ax.set_ylabel("COMET Score", fontsize=14)
        fig.tight_layout()
        fig.savefig(output_file, format="pdf")
    finally:
        plt.close(fig)  # free the figure even if plotting/saving raised
    return output_file

def generate_error_bins_plots_per_dataset(results, model_name, dataset_filter="In-Domain", output_dir=None):
    """Save one error-category bar chart (PDF) per selected dataset for a model.

    Each COMET score is binned as:
      * "No Errors"       -- score >= 0.85
      * "Minor Errors"    -- 0.45 <= score < 0.85
      * "Major Errors"    -- 0.0 <= score < 0.45
      * "Critical Errors" -- everything else (negative scores, and also NaN,
        because every comparison with NaN is False)
    Each chart shows the percentage of the dataset's scores that fall in each bin.

    Args:
        results (dict): Maps dataset display name -> list of COMET scores.
        model_name (str): Model display name, used in the title and file name.
        dataset_filter (str | None): Only datasets whose name contains this
            substring are plotted. Defaults to "In-Domain". Pass None to plot
            every dataset.
        output_dir (str, optional): Destination folder. Defaults to
            ``f"{INPUT_PATH}/error_analysis"``. It is created if missing.

    Returns:
        None
    """
    import os

    categories = ["No Errors", "Minor Errors", "Major Errors", "Critical Errors"]

    def categorize_error(score):
        if score >= 0.85:
            return "No Errors"
        elif 0.45 <= score < 0.85:
            return "Minor Errors"
        elif 0.0 <= score < 0.45:
            return "Major Errors"
        else:
            return "Critical Errors"

    if output_dir is None:
        output_dir = f"{INPUT_PATH}/error_analysis"

    for dataset_name, comet_scores in results.items():
        if dataset_filter is not None and dataset_filter not in dataset_name:
            continue
        if not comet_scores:
            # Avoid a ZeroDivisionError when a dataset has no scored segments.
            print(f"No COMET scores for {dataset_name} ({model_name}); skipping")
            continue

        error_counts = {category: 0 for category in categories}
        for score in comet_scores:
            error_counts[categorize_error(score)] += 1
        total_scores = len(comet_scores)
        percentages = [error_counts[c] / total_scores * 100 for c in categories]

        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            sns.barplot(x=categories, y=percentages, ax=ax)
            # Percentage labels just above each bar.
            for i, percentage in enumerate(percentages):
                ax.text(i, percentage + 1, f"{percentage:.1f}%", ha="center", fontsize=12)
            ax.set_title(f"Error Analysis for {dataset_name} ({model_name})", fontsize=16)
            ax.set_xlabel("Error Category", fontsize=14)
            ax.set_ylabel("Percentage (%)", fontsize=14)
            ax.set_ylim(0, 100)
            fig.tight_layout()

            # savefig never appends an extension, so add ".pdf" explicitly.
            output_file = os.path.join(output_dir, f"{dataset_name}_{model_name}_error_bins.pdf")
            try:
                os.makedirs(output_dir, exist_ok=True)
                fig.savefig(output_file, format="pdf")
                print(f"Bar chart saved to {output_file}")
            except Exception as e:
                # Best-effort save: report and continue with the next dataset.
                print(f"Error saving chart for {dataset_name}: {e}")
        finally:
            plt.close(fig)

def generate_error_bins_plot(results, model_name, excluded=("FLORES+", "In-Domain"), output_dir=None):
    """Save one aggregated out-of-domain (OOD) error-category bar chart (PDF).

    Scores from every dataset whose name does not contain any entry of
    ``excluded`` are pooled together and binned as:
      * "No Errors"       -- score >= 0.85
      * "Minor Errors"    -- 0.45 <= score < 0.85
      * "Major Errors"    -- 0.0 <= score < 0.45
      * "Critical Errors" -- everything else (negative scores, and also NaN)
    The chart shows the percentage of pooled scores in each bin.

    Args:
        results (dict): Maps dataset display name -> list of COMET scores.
        model_name (str): Model display name, used in the title and file name.
        excluded (tuple[str, ...]): Substrings that mark datasets to leave
            out of the OOD aggregate. Defaults to FLORES+ and In-Domain.
        output_dir (str, optional): Destination folder. Defaults to
            ``f"{INPUT_PATH}/error_analysis"``. It is created if missing.

    Returns:
        None
    """
    import os

    categories = ["No Errors", "Minor Errors", "Major Errors", "Critical Errors"]

    def categorize_error(score):
        if score >= 0.85:
            return "No Errors"
        elif 0.45 <= score < 0.85:
            return "Minor Errors"
        elif 0.0 <= score < 0.45:
            return "Major Errors"
        else:
            return "Critical Errors"

    error_counts = {category: 0 for category in categories}
    total_scores = 0
    for dataset_name, comet_scores in results.items():
        if any(tag in dataset_name for tag in excluded):
            # Deliberately excluded from the OOD pool (this is not a data
            # format problem), so skip it without dumping the scores.
            continue
        for score in comet_scores:
            error_counts[categorize_error(score)] += 1
            total_scores += 1

    if total_scores == 0:
        print(f"No out-of-domain COMET scores for {model_name}; nothing to plot")
        return

    percentages = [error_counts[c] / total_scores * 100 for c in categories]

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.barplot(x=categories, y=percentages, ax=ax)
        for i, percentage in enumerate(percentages):
            ax.text(i, percentage + 1, f"{percentage:.1f}%", ha="center", fontsize=12)
        ax.set_title(f"Aggregated Error Analysis for OOD ({model_name})", fontsize=16)
        ax.set_xlabel("Error Category", fontsize=14)
        ax.set_ylabel("Percentage (%)", fontsize=14)
        ax.set_ylim(0, 100)
        fig.tight_layout()

        if output_dir is None:
            output_dir = f"{INPUT_PATH}/error_analysis"
        # savefig never appends an extension, so add ".pdf" explicitly.
        output_file = os.path.join(output_dir, f"comet_xl_da_{model_name}_error_bins.pdf")
        try:
            os.makedirs(output_dir, exist_ok=True)
            fig.savefig(output_file, format="pdf")
            print(f"Bar chart saved to {output_file}")
        except Exception as e:
            # Best-effort save: report and let the caller move on to the next model.
            print(f"Error saving chart: {e}")
    finally:
        plt.close(fig)

def aggregate_data():
    """Load every model/dataset score file and generate all plots per model.

    For each entry of ``models_translate`` it reads the scores of every dataset
    in ``datasets_files_files_first`` from
    ``{INPUT_PATH}/{MODEL_KEY}/{dataset_prefix}{model_tag}_error_analysis_output.json``.
    The scores are keyed by the dataset's display name. They are then passed to
    the box-plot function, the per-dataset error-bin function and the aggregate
    error-bin function, in that order, under the model's display name from
    ``models_named_in_table``.

    Any exception from ``file_read`` (e.g. a missing file) propagates and stops
    the whole run.
    """
    for model_key, file_tag in models_translate.items():
        folder = f"{INPUT_PATH}/{model_key}/"
        model_results = {
            datasets_files_files_first[dataset_key]: file_read(
                f"{folder}{dataset_key}{file_tag}_error_analysis_output.json"
            )
            for dataset_key in datasets_files_files_first
        }
        display_name = models_named_in_table[model_key]
        for plot in (generate_box_plots,
                     generate_error_bins_plots_per_dataset,
                     generate_error_bins_plot):
            plot(model_results, display_name)

In [None]:
# Run the full pipeline. For every model, load each dataset's COMET scores, then
# write the box plot, the per-dataset error-bin chart(s) and the aggregated OOD
# error-bin chart to Drive.
aggregate_data()

  plt.boxplot(grouped_data, labels=df["Dataset"].unique(), vert=True)


[0.73681640625, 0.1976318359375, 0.41845703125, 0.426025390625, 0.1636962890625, 0.2337646484375, 0.4677734375, 0.463623046875, 0.1812744140625, 0.40283203125, 0.310546875, 0.556640625, 0.32861328125, 0.8564453125, 0.358154296875, 0.7900390625, 0.76171875, 0.59912109375, 0.39697265625, 0.27392578125, 0.6689453125, 0.92626953125, 0.11761474609375, 0.8642578125, 0.81640625, 0.438720703125, 0.441162109375, 0.76318359375, 0.32568359375, 0.41455078125, 0.7998046875, 0.34228515625, 0.60595703125, 0.6201171875, 0.277587890625, 0.91357421875, 0.9599609375, 0.7392578125, 0.049957275390625, 0.2225341796875, 0.268798828125, 0.60400390625, 0.587890625, 0.712890625, 0.0282135009765625, 0.2435302734375, 0.64013671875, 0.7724609375, 0.7529296875, 0.9306640625, 0.63623046875, 0.908203125, 0.06982421875, 0.17578125, 0.765625, 0.20751953125, 0.771484375, 0.626953125, 0.09381103515625, 0.7939453125, 0.70703125, 0.3642578125, 0.85498046875, 0.76416015625, 0.462890625, 0.9609375, 0.908203125, 0.63720703125