In [None]:
import pandas as pd
from question_answering.utils import core_qa_utils, generative_qa_utils
from question_answering.paths import generative_qa_paths
from transformers import (
    BartTokenizerFast,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    TFAutoModelForSeq2SeqLM,
)
import os
from tqdm import tqdm
import numpy as np

In [None]:
# Manual test split for each language, read from that language's own
# sub-directory of the code-QA dataset.
java_samples_path = generative_qa_paths.code_qa_dataset_dir / "java" / "manual_test.csv"
# Previously this pointed at the "java" directory too (copy-paste slip), so the
# "python" models were silently evaluated on Java samples.
python_samples_path = generative_qa_paths.code_qa_dataset_dir / "python" / "manual_test.csv"

java_samples = pd.read_csv(java_samples_path)
python_samples = pd.read_csv(python_samples_path)

In [None]:
# Turn both manual-test DataFrames into datasets with a single helper call.
# The result is unpacked positionally, so the list order (java, python) must
# match the names on the left.
sample_frames = [java_samples, python_samples]
java_dataset, python_dataset = core_qa_utils.convert_dataframes_to_datasets(sample_frames)

In [None]:
# Evaluate every saved fine-tuned model on the manual test split and write its
# predictions to `model_evaluation_dir / <model_name> / predictions_for_manual_check.csv`.
# Architecture ("bart"/"t5"), code representation ("original"/"normalised") and
# language ("java"/"python") are all inferred from substrings of the saved model's name.
# Sorted so the run order (and therefore prediction_dataframe_list order) is reproducible;
# os.listdir returns entries in arbitrary order.
model_names = sorted(os.listdir(generative_qa_paths.saved_models_dir))
batch_size = 8
prediction_dataframe_list = []  # one predictions DataFrame per evaluated model, in run order
prediction_dataframe_dict = {}  # model_name -> predictions DataFrame

for model_name in tqdm(model_names):
    # --- Resolve everything the name encodes before loading any weights. ---
    # Each chain needs an explicit fallback: without it an unrecognised name
    # (e.g. a stray file in saved_models_dir) would silently reuse the previous
    # iteration's tokenizer/model/code_type/dataset, or raise NameError on the
    # first iteration.
    if "bart" in model_name:
        model_checkpoint = "facebook/bart-base"
        tokenizer_class = BartTokenizerFast
        max_length = 256
    elif "t5" in model_name:
        model_checkpoint = "t5-base"
        tokenizer_class = AutoTokenizer
        max_length = 512
    else:
        tqdm.write(f"Skipping {model_name!r}: name contains neither 'bart' nor 't5'.")
        continue

    if "original" in model_name:
        code_type = "original_code"
    elif "normalised" in model_name:
        code_type = "code"
    else:
        tqdm.write(f"Skipping {model_name!r}: name contains neither 'original' nor 'normalised'.")
        continue

    if "java" in model_name:
        dataset = java_dataset
    elif "python" in model_name:
        dataset = python_dataset
    else:
        tqdm.write(f"Skipping {model_name!r}: name contains neither 'java' nor 'python'.")
        continue

    # Only download/instantiate the base checkpoint once the name is known to be valid.
    tokenizer = tokenizer_class.from_pretrained(model_checkpoint)
    model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, from_pt=True)
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="tf")

    def preprocess_dataset(examples):
        """Tokenize question/code pairs, with the stripped answers as seq2seq targets.

        Called by `dataset.map(..., batched=True)`, so `examples` is a batch whose
        columns are lists. Reads `tokenizer`, `code_type` and `max_length` from
        the current loop iteration.
        """
        questions = [q.strip() for q in examples["questions"]]
        contexts = [c.strip() for c in examples[code_type]]
        answers = [a.strip() for a in examples["answers"]]

        return tokenizer(
            questions,
            contexts,
            text_target=answers,
            max_length=max_length,
            truncation=True,
        )

    tokenized_dataset = dataset.map(
        preprocess_dataset,
        batched=True,
        remove_columns=dataset.column_names,
    )

    tf_dataset = core_qa_utils.prepare_tf_dataset(
        model=model,
        hf_dataset=tokenized_dataset,
        collator=data_collator,
        batch_size=batch_size,
    )

    # NOTE(review): generation below uses `model`, which assumes
    # load_weights_into_model loads the fine-tuned weights into `model` in place.
    # If it instead returns a separate model, generate with `loaded_weights_model`
    # — confirm in generative_qa_utils.
    loaded_weights_model = generative_qa_utils.load_weights_into_model(
        model=model,
        model_name=model_name,
    )

    # Accumulate plain lists and build the DataFrame once at the end; repeated
    # pd.concat inside the loop is quadratic in the number of batches.
    question_contexts_list = []
    labels_list = []
    predictions_list = []

    for batch, labels in tf_dataset:
        predictions = generative_qa_utils.generate_predictions(model, batch, max_length)
        decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)

        # DataCollatorForSeq2Seq pads labels with -100 by default; map those back
        # to the pad token so the ids can be decoded.
        label_ids = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        # Decode the inputs once per batch (previously decoded twice).
        decoded_question_contexts = tokenizer.batch_decode(
            batch["input_ids"], skip_special_tokens=True
        )

        predictions_list.extend(prediction.strip() for prediction in decoded_predictions)
        labels_list.extend(label.strip() for label in decoded_labels)
        question_contexts_list.extend(decoded_question_contexts)

    questions_and_answers_df = pd.DataFrame(
        {
            "question_contexts": question_contexts_list,
            "labels": labels_list,
            "predictions": predictions_list,
        }
    )

    prediction_dataframe_list.append(questions_and_answers_df)
    prediction_dataframe_dict[model_name] = questions_and_answers_df

    model_evaluation_dir = generative_qa_paths.model_evaluation_dir / model_name
    # to_csv does not create missing parent directories.
    model_evaluation_dir.mkdir(parents=True, exist_ok=True)
    questions_and_answers_df.to_csv(
        model_evaluation_dir / "predictions_for_manual_check.csv", index=False
    )