# Import libraries

In [1]:
import json 
import torch
import os
import evaluate 
import wandb
import numpy as np
from transformers import T5Tokenizer, T5ForConditionalGeneration, get_scheduler
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from utils import save_checkpoint, read_json, get_data_stats, collote_train_fn, collote_valid_fn, MAX_TARGET_LENGTH
from dataset import MengziT5Dataset
from pathlib import Path
from datetime import datetime 
from tqdm import tqdm 
from dotenv import load_dotenv 
# Load environment variables from a local .env file, if one exists.
# NOTE(review): presumably this is where WANDB_API_KEY is expected to come from — confirm.
load_dotenv()

# Hugging Face model id; used below for both the tokenizer and the model weights.
checkpoint = "Langboat/mengzi-t5-base"

  from .autonotebook import tqdm as notebook_tqdm


# Preprocess data

In [2]:
def merge_qa_dataset(data, output_file_path):
    """
    Merge entries that share the same context and question into a single entry
    whose "answer" is a list of all distinct answers, and re-index ids from 0.

    Args:
        data: Iterable of dicts with optional 'context', 'question' and
            'answer' keys. 'answer' may be a single value or a list of values;
            answers must be hashable (they are de-duplicated).
        output_file_path: Destination file. One JSON object is written per
            line (JSON Lines layout), UTF-8 encoded.

    Returns:
        The list of merged entries, each of the form
        {"id": int, "context": str, "question": str, "answer": list}.
        Entries appear in order of first occurrence, and answers keep their
        first-seen order so the output is reproducible across runs.
    """
    print(f"Processing {len(data)} items...")

    # Group answers by (context, question); the stripped pair identifies one QA item.
    grouped_data = {}
    for item in data:
        context = item.get('context', '').strip()
        question = item.get('question', '').strip()
        answer = item.get('answer', '')

        answers = grouped_data.setdefault((context, question), [])

        # The input answer may already be a list, or a single value.
        if isinstance(answer, list):
            answers.extend(answer)
        else:
            answers.append(answer)

    # Rebuild the list with sequential ids.
    new_json_data = [
        {
            "id": new_id,
            "context": context,
            "question": question,
            # dict.fromkeys drops duplicates while keeping first-seen order;
            # set() would make the order depend on string-hash randomisation.
            "answer": list(dict.fromkeys(answers)),
        }
        for new_id, ((context, question), answers) in enumerate(grouped_data.items())
    ]

    # Save to new file
    with open(output_file_path, 'w', encoding='utf-8') as f:
        for obj in tqdm(new_json_data, desc="Writing to JSON file"):
            # ensure_ascii=False keeps Chinese characters readable in the file
            f.write(json.dumps(obj, ensure_ascii=False) + "\n")

    print(f"Success! Merged data saved to {output_file_path}")
    print(f"Original count: {len(data)} -> New count: {len(new_json_data)}")

    return new_json_data

In [3]:
# Dataset locations, each defined exactly once.
DATA_TRAIN_PATH = "data/train.json"
DATA_DEV_PATH = "data/dev.json"
# Output file for the dev set after duplicate (context, question) pairs are merged.
DATA_FDEV_PATH = "data/formatted_dev.json"

valid_data = read_json(DATA_DEV_PATH)
merged_valid_data = merge_qa_dataset(valid_data, DATA_FDEV_PATH)
# To skip re-merging, load the previously written file instead:
# merged_valid_data = read_json(DATA_FDEV_PATH)

tokenizer = T5Tokenizer.from_pretrained(checkpoint)

print("First valid data: ", merged_valid_data[0])
train_data = read_json(DATA_TRAIN_PATH)
print("First train data: ", train_data[0])


Reading JSON file: 984it [00:00, 136934.15it/s]


Processing 984 items...


Writing to JSON file: 100%|██████████| 700/700 [00:00<00:00, 82685.95it/s]


Success! Merged data saved to data/formatted_dev.json
Original count: 984 -> New count: 700


KeyboardInterrupt: 

In [None]:
# Summary statistics for the raw (un-merged) dev entries; the returned dict
# holds per-field counts and mean/max lengths (see the output below).
get_data_stats(valid_data, tokenizer)

{'question_num': 984,
 'context_num': 984,
 'answer_num': 984,
 'question_mean_length': 6.5426829268292686,
 'context_mean_length': 192.15243902439025,
 'answer_mean_length': 4.774390243902439,
 'question_max_length': 18,
 'context_max_length': 728,
 'answer_max_length': 26}

In [None]:
# Same summary statistics for the raw training entries.
get_data_stats(train_data, tokenizer)

{'question_num': 14520,
 'context_num': 14520,
 'answer_num': 14520,
 'question_mean_length': 6.488154269972452,
 'context_mean_length': 182.3798209366391,
 'answer_mean_length': 4.257782369146006,
 'question_max_length': 28,
 'context_max_length': 1180,
 'answer_max_length': 95}

In [None]:
# Wrap the merged dev entries and the raw training entries in the project dataset class.
# NOTE(review): the "Total data filtered away" lines printed below presumably come from
# MengziT5Dataset; the filtering rule is defined in dataset.py, not here.
valid_dataset = MengziT5Dataset(merged_valid_data)
train_dataset = MengziT5Dataset(train_data)

Total data filtered away: 165
Total data filtered away: 1906


# Load model and build dataloaders

In [None]:
# Batch sizes passed to the training and validation DataLoaders below;
# train_batch_size is also recorded in the wandb run config.
train_batch_size = 8
valid_batch_size = 8
#test_batch_size = 8

In [None]:
# Use the GPU when available; train_loop and valid_loop read this module-level `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = T5ForConditionalGeneration.from_pretrained(checkpoint)
model = model.to(device)

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, collate_fn=lambda x: collote_train_fn(x, model, tokenizer))
# Preview one collated batch. It gets its own name so the raw `train_data`
# list stays intact and the earlier cells remain re-runnable.
train_batch = next(iter(train_dataloader))
print("train input_ids: ", train_batch['input_ids'])
print("train attention_mask: ", train_batch['attention_mask'])
print("train decoder_input_ids", train_batch['decoder_input_ids'])
print("train labels", train_batch['labels'])
print("----------")

# Seeded 50/50 split of the dev set into validation and test halves.
# NOTE(review): this rebinds `valid_dataset`, so re-running this cell without
# re-running the dataset cell splits the already-halved set again.
generator = torch.Generator().manual_seed(42)
valid_dataset, test_dataset = random_split(valid_dataset, [0.5, 0.5], generator=generator)

valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=valid_batch_size, collate_fn=lambda x: collote_valid_fn(x, model, tokenizer))
# Same preview for validation, again without overwriting the raw `valid_data`.
valid_batch = next(iter(valid_dataloader))
print("valid input_ids: ", valid_batch['input_ids'])
print("valid attention_mask: ", valid_batch['attention_mask'])
print("valid decoder_input_ids: ", valid_batch['decoder_input_ids'])
print("valid labels:", valid_batch['labels'])

# test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=valid_batch_size, collate_fn=lambda x: collote_fn(x, model, tokenizer))
# test_data = next(iter(test_dataloader))
# print("test input_ids: ", test_data['input_ids'])
# print("test attention_mask: ", test_data['attention_mask'])
# print("test decoder_input_ids: ", test_data['decoder_input_ids'])
# print("test labels:", test_data['labels'])


Loading weights: 100%|██████████| 282/282 [00:00<00:00, 500.08it/s, Materializing param=shared.weight]                                                       


train input_ids:  tensor([[  7, 143,  13,  ...,   0,   0,   0],
        [  7, 143,  13,  ...,   0,   0,   0],
        [  7, 143,  13,  ...,   0,   0,   0],
        ...,
        [  7, 143,  13,  ...,   0,   0,   0],
        [  7, 143,  13,  ...,   0,   0,   0],
        [  7, 143,  13,  ...,   0,   0,   0]])
train attention_mask:  tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])
train decoder_input_ids tensor([[    0,     7,   644,  ...,     0,     0,     0],
        [    0,     7,   390,  ...,     0,     0,     0],
        [    0, 12598,   419,  ...,     0,     0,     0],
        ...,
        [    0,     7,  2687,  ...,     0,     0,     0],
        [    0,     7,  8316,  ...,     0,     0,     0],
        [    0,     7,  4265,  ...,     0,     0,     0]])
train labels tensor([[    7,   644,  6891,  ...,  -100,  -100,  -

# Train Model  

In [None]:
def train_loop(dataloader, model, optimizer, scheduler, epoch, global_step, use_wandb=False):
    """Run one training pass over `dataloader`.

    Every batch is moved to the module-level `device`, fed through `model`,
    and the resulting loss is back-propagated, followed by one optimizer step
    and one scheduler step.

    Args:
        dataloader: Sized iterable of batches; each batch must support
            `.to(device)` and `**` unpacking into the model call.
        model: Model whose forward output exposes a `.loss`.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        epoch: Epoch number, used only in the progress-bar text.
        global_step: Step count before this epoch; incremented per batch.
        use_wandb: When True, log the per-batch loss to wandb as "train_loss".

    Returns:
        (average loss over the batches seen this epoch, updated global_step).
        The average is 0.0 when the dataloader yields no batches.
    """
    model.train()
    # Per-epoch accumulators: both start from zero on every call.
    running_total = 0.0
    running_avg = 0.0

    with tqdm(total=len(dataloader)) as progress:
        for seen, batch in enumerate(dataloader, start=1):
            batch = batch.to(device)
            loss = model(**batch).loss

            # backward propagation, then parameter and learning-rate updates
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            global_step += 1
            if use_wandb:
                wandb.log({"train_loss": loss.item()}, step=global_step)

            running_total += loss.item()
            running_avg = running_total / seen

            progress.set_description(f"Epoch {epoch} | Avg Loss: {running_avg:.4f}")
            progress.update(1)

    return running_avg, global_step

def valid_loop(dataloader, model, tokenizer, epoch, global_step, use_wandb=False):
    """Evaluate `model` on `dataloader`: mean loss plus character-level BLEU.

    For each batch the teacher-forced loss is accumulated, answers are
    generated with beam search (4 beams, at most MAX_TARGET_LENGTH new
    tokens) and decoded. Predictions and references are both split into
    space-separated characters before scoring, so BLEU n-grams are counted
    over characters on both sides.

    Args:
        dataloader: Sized iterable of batches. A batch may carry an "answer"
            entry holding, per example, a list of reference strings; it is
            removed before the batch is passed to the model. When absent,
            references are rebuilt by decoding `labels`.
        model: Seq2seq model providing a forward pass with `.loss` and `generate`.
        tokenizer: Tokenizer providing `batch_decode` and `pad_token_id`.
        epoch: Epoch number, stored in the logged metrics.
        global_step: Step value used for the wandb log entry.
        use_wandb: When True, log validation loss and BLEU scores to wandb.

    Returns:
        dict with keys "bleu-1" .. "bleu-4" (n-gram precisions) and "avg"
        (the overall BLEU score).
    """
    model.eval()
    bleu = evaluate.load("bleu")
    val_loss_sum = 0.0

    all_preds = []
    all_labels = []

    with tqdm(total=len(dataloader)) as pbar:
        with torch.no_grad():
            for batch_data in dataloader:
                # Raw references are not model inputs; take them out of the batch first.
                raw_references = batch_data.pop("answer", None)
                if raw_references is None:
                    print("No raw reference is found. Now create based on labels.")
                    # -100 marks ignored label positions; swap in pad ids so decoding works.
                    temp_labels = torch.where(batch_data["labels"] != -100, batch_data["labels"], tokenizer.pad_token_id)
                    raw_references = [[ref] for ref in tokenizer.batch_decode(temp_labels, skip_special_tokens=True)]

                batch_data = batch_data.to(device)
                results = model(**batch_data)
                val_loss_sum += results.loss.item()  # Accumulate loss

                outputs = model.generate(
                    batch_data["input_ids"],
                    attention_mask=batch_data["attention_mask"],
                    max_new_tokens=MAX_TARGET_LENGTH,
                    num_beams=4
                )
                decoded_outputs = tokenizer.batch_decode(
                    outputs,
                    skip_special_tokens=True
                )

                for pred in decoded_outputs:
                    # Space-separate characters exactly as is done for the references;
                    # otherwise a whole prediction counts as a single token.
                    spaced_pred = ' '.join(pred.strip())
                    # Keep a non-empty placeholder for empty generations.
                    all_preds.append(spaced_pred if spaced_pred else " ")

                for ref_list in raw_references:  # ref_list: [ans1, ans2, ...]
                    all_labels.append([' '.join(ref.strip()) for ref in ref_list])

                pbar.update(1)

            bleu_result = bleu.compute(predictions=all_preds, references=all_labels)
            result = {f"bleu-{i}" : value for i, value in enumerate(bleu_result["precisions"], start=1)}
            result['avg'] = bleu_result['bleu']
            avg_val_loss = val_loss_sum / len(dataloader)
            log_dict = {
                "val_loss": avg_val_loss,
                "BLEU_avg": bleu_result['bleu'], # 'bleu' is the avg in huggingface evaluate
                "BLEU_1": bleu_result['precisions'][0],
                "BLEU_2": bleu_result['precisions'][1],
                "BLEU_3": bleu_result['precisions'][2],
                "BLEU_4": bleu_result['precisions'][3],
                "epoch": epoch
            }
            if use_wandb:
                wandb.log(
                    log_dict,
                    step=global_step
                )
            print(f"Test result: BLEU1={result['bleu-1']}, BLEU2={result['bleu-2']}, BLEU3={result['bleu-3']}, BLEU4={result['bleu-4']}")
            return result

In [None]:
# --- Run configuration ---
learning_rate = 2e-5
epoch_num = 3
best_model_name = "best_t5.pt"
# A timestamp names both the checkpoint folder and the wandb run.
current_t = datetime.now().strftime('%d-%m-%y-%H_%M')
foldername =  current_t + '_ckpt'
checkpoint_path = Path(f"./checkpoint/{foldername}")
checkpoint_path.mkdir(parents=True, exist_ok=True)
file_path = checkpoint_path / best_model_name
# Passed to save_checkpoint on every epoch (see utils.save_checkpoint for how it is used).
recent_checkpoints = []
use_wandb = True

if use_wandb:
    wandb.init(
        project="mengzi-t5-qa",   # The name of project on the website
        name=f"{current_t}",  # Name of this specific training run
        config={
            "learning_rate": learning_rate,
            "batch_size": train_batch_size,
            "epochs": epoch_num,
            "model": "mengzi-t5-base"
        }
    )

# The linear schedule decays over every optimizer step of the whole run (no warmup).
num_training_steps = epoch_num * len(train_dataloader)
optimizer = AdamW(model.parameters(), lr=learning_rate)
scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

global_step = 0
best_bleu = 0
for epoch in range(epoch_num):
    avg_loss, global_step = train_loop(train_dataloader, model, optimizer, scheduler, epoch, global_step, use_wandb=use_wandb)
    valid_bleu = valid_loop(valid_dataloader, model, tokenizer, epoch, global_step, use_wandb=use_wandb)
    bleu_avg = valid_bleu['avg']
    save_checkpoint(model, epoch, checkpoint_path, recent_checkpoints)
    # Keep a separate copy of the weights with the best validation BLEU so far.
    if bleu_avg > best_bleu:
        best_bleu = bleu_avg
        print("Saving new best weights ...")
        # state_dict() is the nn.Module API; the previous `static_dict` does not exist.
        torch.save(model.state_dict(), file_path)
        print("Finish saving.")


print("Finish training")

AuthenticationError: WANDB_API_KEY invalid: API key must have 40+ characters, has 20.