In [None]:
# Import libraries

In [None]:

import torch
import evaluate 
import numpy as np
from transformers import T5Tokenizer, T5ForConditionalGeneration, get_scheduler
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from utils import read_json, collote_fn, MAX_TARGET_LENGTH, merge_qa_dataset
from dataset import MengziT5Dataset
from pathlib import Path
from tqdm import tqdm 
from dotenv import load_dotenv 
# Load environment variables from a local `.env` file; variables that are
# already set in the environment are not overridden.
load_dotenv(override=False)

# Model id of the pretrained base model; the fine-tuned weights are loaded into it later.
checkpoint = "Langboat/mengzi-t5-base"

# Preprocess: build the held-out test split

In [None]:
DATA_TRAIN_PATH = "data/train.json"
DATA_DEV_PATH = "data/dev.json"

DATA_FDEV_PATH = "data/formatted_dev.json"
DATA_DEV_PATH = "data/dev.json"

valid_data = read_json(DATA_DEV_PATH)
merged_valid_data = merge_qa_dataset(valid_data, DATA_FDEV_PATH)
valid_dataset = MengziT5Dataset(merged_valid_data)
generator = torch.Generator().manual_seed(42)
valid_dataset, test_dataset = random_split(valid_dataset, [0.5, 0.5], generator=generator)

test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=valid_batch_size, collate_fn=lambda x: collote_fn(x, model, tokenizer))
test_data = next(iter(test_dataloader))
print("test input_ids: ", test_data['input_ids'])
print("test attention_mask: ", test_data['attention_mask'])
print("test decoder_input_ids: ", test_data['decoder_input_ids'])
print("test labels:", test_data['labels'])

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

best_model_name = "best_t5.pt"
foldername =  '???????????_ckpt'
checkpoint_path = Path(f"./checkpoint/{foldername}")
file_path = checkpoint_path / best_model_name

checkpoint = "Langboat/mengzi-t5-base"
model = T5ForConditionalGeneration.from_pretrained(checkpoint)
tokenizer = T5Tokenizer.from_pretrained(checkpoint)

model.load_state_dict(torch.load(file_path, weight_only=True))

In [None]:
def test_loop(dataloader, model, tokenizer):
    model.eval()
    bleu = evaluate.load("bleu")
    loss = []
    val_loss_sum = 0.0

    #cumulative_batch = (epoch-1) * len(dataloader)
    all_preds = []
    all_labels = []

    with tqdm(total=len(dataloader)) as pbar:
        with torch.no_grad():
            for batch_idx, batch_data in enumerate(dataloader, start=1):
                raw_references = batch_data.pop("answer", None)
                if raw_references is None:
                    print("No raw reference is found. Now create based on labels.")
                    temp_labels = torch.where(batch_data["labels"] != -100, batch_data["labels"], tokenizer.pad_token_id)
                    raw_references = [[ref] for ref in tokenizer.batch_decode(temp_labels, skip_special_tokens=True)]


                batch_data = batch_data.to(device)
                results = model(**batch_data)
                loss = results.loss
                val_loss_sum += loss.item() # Accumulate loss

                outputs = model.generate(
                    batch_data["input_ids"],
                    attention_mask=batch_data["attention_mask"],
                    max_new_tokens=MAX_TARGET_LENGTH,
                    num_beams=4
                    )
                decoded_outputs = tokenizer.batch_decode(
                    outputs,
                    skip_special_tokens=True
                    )

                batch_preds = []
                for pred in decoded_outputs:
                    if len(pred) == 0:
                        pred = " " # Prevent divided by zero during calculation of BLEU
                    batch_preds.append(pred)
                
                batch_labels = []
                for ref_list in raw_references: # ref_list: [ans1, ans2, ...]
                    processed_ref_list = []
                    for ref in ref_list:
                        cleaned_ref = ref.strip()
                        processed_ref_list.append(' '.join(cleaned_ref.strip()))
                    batch_labels.append(processed_ref_list)

                # batch_preds = [' '.join(pred.strip()) for pred in decoded_outputs]
                # batch_labels = [' '.join(label.strip()) for label in decoded_labels]

                all_preds.extend(batch_preds)
                all_labels.extend(batch_labels)

                pbar.update(1)

            bleu_result = bleu.compute(predictions=all_preds, references=all_labels)
            result = {f"bleu-{i}" : value for i, value in enumerate(bleu_result["precisions"], start=1)}
            result['avg'] = bleu_result['bleu']
            avg_val_loss = val_loss_sum / len(dataloader)
            log_dict = {
                "val_loss": avg_val_loss,
                "BLEU_avg": bleu_result['bleu'], # 'bleu' is the avg in huggingface evaluate
                "BLEU_1": bleu_result['precisions'][0],
                "BLEU_2": bleu_result['precisions'][1],
                "BLEU_3": bleu_result['precisions'][2],
                "BLEU_4": bleu_result['precisions'][3]
            }
            print(f"Test result: BLEU={result['avg']}, BLEU1={result['bleu-1']}, BLEU2={result['bleu-2']}, BLEU3={result['bleu-3']}, BLEU4={result['bleu-4']}")
            return result

In [None]:
# Final evaluation on the held-out test split; the returned metrics dict is shown as the cell output.
test_loop(dataloader=test_dataloader, model=model, tokenizer=tokenizer)