In [None]:
!pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2
!pip install -e .
 
!pip install pydantic-settings
!pip install peft
!pip install trl
!pip install bitsandbytes
# Restart the Python process to use the updated packages
dbutils.library.restartPython()

In [None]:
import torch
import json
import os
import pandas as pd
from pathlib import Path
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, MT5ForConditionalGeneration, T5TokenizerFast, AutoModelForSeq2SeqLM

In [None]:
from src.llm_research.inference import run_inference, run_inference_manual
from src.llm_research.evaluation import calculate_token_level_match_accuracy, calculate_exact_match_accuracy

In [None]:
from Reimer.constants import(
    MODEL_MAPPING,
    EVAL_LANGUAGES
)
# NOTE(review): this flag is not read by any cell visible in this notebook;
# run_task_lang_combos decides whether to persist results via its own `save` argument.
STORE_RESULTS = True


In [None]:
from peft import PeftModel, PeftConfig, get_peft_model
from typing import Dict
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    StoppingCriteria,
    StoppingCriteriaList,
    BitsAndBytesConfig,
    MaxLengthCriteria,
    AutoConfig,
    MT5ForConditionalGeneration,
)
# Probe for the optional flash_attn package: record whether the import succeeded
# instead of failing the cell when it is missing.
# NOTE(review): `flash_attn_available` is not read by any cell visible here.
try:
    import flash_attn

    flash_attn_available = True
except ImportError:
    flash_attn_available = False
    print("flash_attn not installed. Using default attention implementation.")


class StopAtLineEndCriterion(StoppingCriteria):
    """Stopping criterion that halts generation once the decoded text ends with a newline.

    Useful for tasks where each output should be a single line of text.

    Only the first sequence of the batch (``input_ids[0]``) is inspected, so the
    decision is driven by that sequence alone when several are generated together.

    Attributes:
        tokenizer: The tokenizer used to decode the token IDs back to text.

    Args:
        tokenizer: An instance of a tokenizer that converts token IDs to text.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Checks if the generation should be stopped.

        Args:
            input_ids: Token IDs generated so far; only the first row is decoded.
            scores: The generation scores for the current set of tokens (unused).
            **kwargs: Extra arguments supplied by the generation loop (ignored).

        Returns:
            True when the decoded text of the first sequence ends with a newline
            character, i.e. generation should stop; False otherwise.
        """
        # Decode the whole first sequence, not just the newest token: the check
        # is on the trailing character of the full decoded text.
        text = self.tokenizer.decode(input_ids[0])
        return text.endswith("\n")


In [None]:
def call_run_inference(
    model,
    model_path,
    tokenizer,
    task,
    language,
    is_peft,
    token_level_eval,
    generation_config,
    is_qlora=None,
    subset=None
):
    """Run inference for one task/language pair and score the predictions.

    The validation file is looked up under ``/dbfs/Reimer/data/<task>/``; when
    ``subset`` is truthy the ``-filtered-samples<subset>`` variant is used.
    Predictions are written to a file in the current working directory that is
    named after the last component of ``model_path``, the task and the language.

    Args:
        model: Loaded model, forwarded to ``run_inference_manual``.
        model_path: Model location; forwarded as ``model_id`` and used to name
            the predictions file.
        tokenizer: Tokenizer forwarded to ``run_inference_manual``.
        task: Task name used in the dataset and prediction file names.
        language: Language code used in the dataset and prediction file names.
        is_peft: Forwarded to ``run_inference_manual``.
        token_level_eval: If true, score with
            ``calculate_token_level_match_accuracy``; otherwise with
            ``calculate_exact_match_accuracy``.
        generation_config: Generation settings forwarded to ``run_inference_manual``.
        is_qlora: Forwarded to ``run_inference_manual``.
        subset: Sample-count suffix selecting the subsampled validation file.

    Returns:
        A ``(accuracy, mismatch_percentage)`` tuple; the second element is
        ``None`` for exact-match evaluation.
    """
    dataset_test_path = (
        Path(f"/dbfs/Reimer/data/{task}/{task}_{language}_validation-filtered-samples{subset}.json")
        if subset
        else Path(f"/dbfs/Reimer/data/{task}/{task}_{language}_validation.json")
    )
    pred_file_path = Path(f"{Path(model_path).name}-{task}-{language}-predictions.json")

    # BATCH_SIZE is a notebook-level constant resolved when this function is called.
    run_inference_manual(
        dataset_path=dataset_test_path,
        model=model,
        tokenizer=tokenizer,
        model_id=model_path,
        is_peft=is_peft,
        is_qlora=is_qlora,
        batch_size=BATCH_SIZE,
        generation_config=generation_config,
        out_path=pred_file_path,
        dataset_split='train',
    )

    scoring_inputs = dict(
        ground_truth_file=dataset_test_path,
        prediction_file=pred_file_path,
    )

    if not token_level_eval:
        # Exact-match scoring has no mismatch figure to report.
        return calculate_exact_match_accuracy(**scoring_inputs), None

    # Token-level scoring yields both the accuracy and a mismatch figure.
    exact_match_accuracy, mismatch_percentage = calculate_token_level_match_accuracy(**scoring_inputs)
    return exact_match_accuracy, mismatch_percentage

In [None]:
def run_task_lang_combos(
    model_id,
    tokenizer_path,
    generation_config,
    tasks,
    langs,
    is_qlora=False,
    device="cuda",
    file_count = 1,
    skip_existing = False,
    subset_wiki=None,
    use_existing=True,
    save=True
):
    """Evaluate one model on every requested task/language pair and record accuracies.

    Results are accumulated in a dict and, when ``save`` is true, written as JSON
    to ``accuracies/<model name>_accuracies.json`` (or a numbered sibling) after
    every task/language pair so partial progress survives an interruption.

    Args:
        model_id: Path or hub id of the model (or PEFT adapter) to evaluate.
            A model is treated as a PEFT adapter when "lora" occurs in this value.
        tokenizer_path: Only recorded in the results dict; the tokenizer itself
            is loaded from the resolved inference model id.
        generation_config: Generation settings. The dict is copied, never
            mutated; ``eos_token_id`` is added to the copy.
        tasks: Task names to evaluate.
        langs: Languages to evaluate for every task, or ``None`` to use
            ``EVAL_LANGUAGES[task]`` for each task individually.
        is_qlora: Load the base model in 4-bit NF4 and keep adapters unmerged.
        device: Requested device; falls back to "cpu" when CUDA is unavailable.
        file_count: First numeric suffix tried when a fresh results file is needed.
        skip_existing: When a results file exists and ``use_existing`` is false,
            return immediately instead of picking a new file name.
        subset_wiki: Sample-count suffix of the subsampled wikiann file; only
            applied to the 'wikiann' task.
        use_existing: Resume from an existing results file, skipping tasks that
            already have an entry in it.
        save: Write the results dict to disk.

    Returns:
        ``0`` when the run is skipped or no tasks remain, otherwise ``None``.
    """
    print(f"Saving results: {save}")

    # Work from the argument rather than from a notebook-level `model_path` global.
    model_path = model_id
    model_name = Path(model_path).name

    # The EOS id is read from the mt0-large tokenizer. Copy the config so the
    # caller's dict is untouched, and store a plain int rather than a 1-tuple.
    eos_tokenizer = T5TokenizerFast.from_pretrained("Bigscience/mt0-large")
    generation_config = {**generation_config, "eos_token_id": eos_tokenizer.eos_token_id}

    results_dict = {
        'model' : str(model_path),
        'tokenizer' : str(tokenizer_path),
        'subset wikiann' : subset_wiki,
        'config' : generation_config,
    }

    # Locate the results file; older runs may have stored it without ".json".
    results_dir = Path("accuracies")
    results_dict_path = results_dir / f"{model_name}_accuracies.json"
    results_dict_stem = results_dir / f"{model_name}_accuracies"

    if results_dict_stem.exists():
        results_dict_path = results_dict_stem

    if results_dict_path.exists() and use_existing:
        print(f"Found existing results dict at {results_dict_path}, using existing.")
        with open(results_dict_path, 'r', encoding='utf-8') as f:
            results_dict = json.load(f)
        new_tasks = [task for task in tasks if task not in results_dict]
        print(f"New tasks: {new_tasks}")
        tasks = new_tasks
    else:
        # Probe numbered names until a free one is found. Only the candidate
        # path is tested: also testing the suffix-less file would never
        # terminate once that file exists.
        while results_dict_path.exists():
            print(f"Found existing file: {results_dict_path}")
            if skip_existing:
                print(f"Skipping {results_dict_path}")
                return 0
            results_dict_path = results_dir / f"{model_name}_accuracies{file_count}.json"
            file_count += 1

    if not tasks:
        print(f"No tasks")
        return 0

    if save:
        results_dict_path.parent.mkdir(parents=True, exist_ok=True)
        with open(results_dict_path, 'w', encoding='utf-8') as f:
            json.dump(results_dict, f, indent=4)

    is_peft = "lora" in str(model_path)

    # For adapters, the weights to load come from the adapter's recorded base model.
    if is_peft:
        print("Converting to PEFT model...")
        inference_model_id = PeftConfig.from_pretrained(model_id).base_model_name_or_path
    else:
        inference_model_id = model_id

    bnb_config = None
    if is_qlora:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    if not (device == "cuda" and torch.cuda.is_available()):
        device = "cpu"

    # NOTE(review): the resolved device is only reported here; the model is not
    # moved explicitly. Confirm that run_inference_manual handles placement.
    print(f"Using device: {device}")

    # Initialize model and tokenizer
    model = MT5ForConditionalGeneration.from_pretrained(
        inference_model_id,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        inference_model_id,
        padding_side="left",
        trust_remote_code=True,
    )

    if is_peft:
        # add the adapters to the base model
        model = PeftModel.from_pretrained(model, model_id)
        if not is_qlora:
            model = model.merge_and_unload()

    for task in tasks:
        # wikiann is the only task scored at token level and the only one that
        # uses the subsampled validation file.
        token_level_eval = task == 'wikiann'
        subset = subset_wiki if token_level_eval else None

        print(f"subset for {task} is {subset}")

        # Resolve languages per task without overwriting `langs`, so a None
        # argument keeps meaning "use this task's default languages".
        task_langs = EVAL_LANGUAGES[task] if langs is None else langs

        for language in task_langs:
            exact_match_accuracy, mismatch_count = call_run_inference(
                model=model,
                tokenizer=tokenizer,
                task=task,
                language=language,
                token_level_eval=token_level_eval,
                model_path=model_path,
                is_peft=is_peft,
                is_qlora=is_qlora,
                generation_config=generation_config,
                subset=subset
            )

            print(f"Mismatch count for {Path(model_path).name} - {task} - {language} : {mismatch_count}")

            # Results are kept in memory, so save=False works without a file on disk.
            entry = results_dict.setdefault(task, {}).setdefault(language, {})
            entry['token_level_eval'] = token_level_eval
            entry['accuracy'] = exact_match_accuracy

            if save:
                with open(results_dict_path, 'w', encoding='utf-8') as f:
                    json.dump(results_dict, f, indent=4)

            print(f"Accuracy for {Path(model_path).name} - {task} - {language} : {exact_match_accuracy} (token_level = {token_level_eval})")

In [None]:
# List the entries under the merged "mono" model folder, then strip the colon from
# each "dbfs:/..." URI and prepend "/" to get a "/dbfs/..." style path.
model_dir = dbutils.fs.ls("dbfs:/Reimer/merged_models/mono")
model_list = list(
    map(lambda p: f'/{str(p.path).replace(":", "")}', model_dir)
)

In [None]:
# Notebook-level settings read by the inference cells.
BATCH_SIZE = 1  # passed as `batch_size` inside call_run_inference
MAX_SEQ_LEN = 512  # used as "max_length" in generation_config
TEMPERATURE = 0.9  # used as "temperature" in generation_config
TOP_K = 50  # used as "top_k" in generation_config
TOP_P = 0.95  # used as "top_p" in generation_config
NUM_RETURN_SEQ = 1  # used as "num_return_sequences" in generation_config

In [None]:
def create_model_names(
    task2,
    lang2,
    merge_method,
    base = None,
    task1 = None,
    lang1 = None,
    ft1 = "ft",
    ft2 = "ft",
    bm1 = "mt0-large",
    bm2 = "mt0-large",
    base_path = "/dbfs/Reimer/merged_models",
    connecter = "--"
):
    """Build the path of a merged model from the names of its two constituents.

    The first constituent is either the bare base model ``bm1`` (when ``base``
    is truthy) or ``<bm1>_<task1>_<ft1>_<lang1>``; the second is always
    ``<bm2>_<task2>_<ft2>_<lang2>``.

    Args:
        task2: Task of the second model.
        lang2: Language of the second model.
        merge_method: Merge method; used as sub-folder and as name suffix.
        base: If truthy, the first model is the bare base model ``bm1``.
        task1: Task of the first model (needed when ``base`` is falsy).
        lang1: Language of the first model (needed when ``base`` is falsy).
        ft1: Fine-tuning tag of the first model.
        ft2: Fine-tuning tag of the second model.
        bm1: Base model name of the first model.
        bm2: Base model name of the second model.
        base_path: Root folder of the merged models.
        connecter: Separator placed between the two model names.

    Returns:
        ``<base_path>/<merge_method>/<model1><connecter><model2>-<merge_method>``.

    Raises:
        AssertionError: If neither ``base`` nor both ``task1`` and ``lang1`` are given.
    """
    first_is_finetuned = task1 and lang1
    assert base or first_is_finetuned, "Need base or finetuned model for model1"

    model1 = bm1 if base else f"{bm1}_{task1}_{ft1}_{lang1}"
    model2 = f"{bm2}_{task2}_{ft2}_{lang2}"

    return f"{base_path}/{merge_method}/{model1}{connecter}{model2}-{merge_method}"

In [None]:
# Names of the models whose first constituent is the bare base model (base=True)
# and whose second constituent is a fine-tuned model, for every
# task / merge method / language combination below.
languages = ['ar', 'de', 'el', 'es']
base = 'mt0-large'
tasks = ['wikiann']
merging_methods = ['mono', 'pooling']

base_list = [
    create_model_names(task2=second_task, lang2=second_lang, merge_method=method, base=True)
    for second_task in tasks
    for method in merging_methods
    for second_lang in languages
]

# One model path per line.
print("".join(f"{model}\n" for model in base_list))


In [None]:
task = "xnli"
output_dir = f"/dbfs/Reimer/output/{task}"

# Only the languages from index 2 onward are kept.
xnli_list = [
    f"{output_dir}/experiment_mt0-large_{task}_ft_{lang}/"
    for lang in languages[2:]
]

In [None]:
from itertools import combinations_with_replacement, combinations

# Same-task pairs: two sib200 models fine-tuned on two different languages
# (unordered pairs, no language paired with itself).
languages = ['ar', 'de', 'el', 'es']
task1 = 'sib200'
task2 = 'sib200'
# merging_methods = ['mono', 'pooling']
merging_methods = ['mono']

wt_model_list_sib = [
    create_model_names(
        task2=task2,
        lang2=second_lang,
        merge_method=method,
        task1=task1,
        lang1=first_lang,
    )
    for method in merging_methods
    for first_lang, second_lang in combinations(languages, 2)
]

# One model path per line.
print("".join(f"{model}\n" for model in wt_model_list_sib))

In [None]:
from itertools import combinations_with_replacement, combinations

# Cross-task pairs: a sib200 model combined with an xnli model; language pairs
# are unordered and may repeat the same language.
languages = ['ar', 'de', 'el', 'es']
task1 = 'sib200'
task2 = 'xnli'
# merging_methods = ['mono', 'pooling']
merging_methods = ['mono']

ct_model_list_sib_xnli = [
    create_model_names(
        task2=task2,
        lang2=second_lang,
        merge_method=method,
        task1=task1,
        lang1=first_lang,
    )
    for method in merging_methods
    for first_lang, second_lang in combinations_with_replacement(languages, 2)
]

# One model path per line.
print("".join(f"{model}\n" for model in ct_model_list_sib_xnli))

In [None]:
task = "sib200"
output_dir = f"/dbfs/Reimer/output/sib200"

# One fine-tuned sib200 experiment folder per language.
sib_list = [
    f"{output_dir}/experiment_mt0-large_{task}_ft_{lang}"
    for lang in languages
]

## Running inference

In [None]:
# Generation settings shared by every evaluated model.
generation_config = {
        "max_length": MAX_SEQ_LEN,
        "temperature": TEMPERATURE,
        "top_k": TOP_K,
        "top_p": TOP_P,
        "num_return_sequences": NUM_RETURN_SEQ,
    }

tasks = ['xnli', 'sib200', 'wikiann']

subset = 500          #only used in case of wikiann

# Evaluate the first five sib200/xnli cross-task models.
model_list = ct_model_list_sib_xnli[:5]

for model_path in model_list:
    print(f"Testing for {model_path}")
    # Derive the tokenizer path per model; building it once before the loop
    # from an empty model path recorded "/tokenizer_config.json" for every model.
    tokenizer_path = f"{model_path}/tokenizer_config.json"
    run_task_lang_combos(
        model_id=model_path,
        tokenizer_path=tokenizer_path,
        tasks=tasks,
        langs=None,
        generation_config=generation_config,
        skip_existing=False,
        subset_wiki=subset,
        use_existing=True,
        save=True
    )

## Subsampling data

In [None]:
import random
def subsample_data(
    data_path,
    N,
    seed = 34
):
    """Write a reproducible random subsample of a JSON list next to the source file.

    Records are drawn without replacement, so the subsample contains no
    duplicated records. A private ``random.Random(seed)`` is used, leaving the
    global random state untouched.

    Args:
        data_path: Path (``str`` or ``Path``) of a JSON file holding a list of records.
        N: Number of records to keep. When the file holds fewer than ``N``
            records, all of them are kept (in shuffled order).
        seed: Seed of the random generator used for the draw.

    Returns:
        The path of the written file as a string: the source path with
        ``-samples<N>`` inserted before the ``.json`` extension.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Honour the requested size; never ask for more records than exist.
    sample_size = min(N, len(data))
    if sample_size < N:
        print(f"Only {len(data)} records available, keeping all of them.")

    rng = random.Random(seed)
    sampled = rng.sample(data, k=sample_size)

    # str() makes pathlib.Path inputs work as well as plain strings.
    source = str(data_path)
    stem = source[:-len(".json")] if source.endswith(".json") else source
    new_path = f"{stem}-samples{N}.json"
    print(f"Saving subsampled data under {new_path}")

    with open(new_path, 'wt', encoding='utf-8') as f:
        json.dump(sampled, f, indent=4)

    return new_path


In [None]:
# data_list = [p.path for p in dbutils.fs.ls("dbfs:/Reimer/data/wikiann") if 'filtered' in p.name and 'samples' not in p.name]

# for p in data_list:
#     subsample_data(f"/{p.replace(':', '')}", N=500, seed=32)