# Call library

In [1]:
import json 
import torch
import os
import evaluate 
import wandb
import numpy as np
from transformers import T5Tokenizer, T5ForConditionalGeneration, get_scheduler
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from utils import save_checkpoint, read_json, get_data_stats, collote_train_fn, collote_valid_fn, MAX_TARGET_LENGTH
from dataset import MengziT5Dataset
from pathlib import Path
from datetime import datetime 
from tqdm import tqdm 
from dotenv import load_dotenv 
load_dotenv()

checkpoint = "Langboat/mengzi-t5-base"

  from .autonotebook import tqdm as notebook_tqdm


# Preprocess data

In [2]:
def merge_qa_dataset(data, output_file_path):
    """
    Merges JSON entries with the same Context and Question into a single entry
    with a list of answers. Re-indexes IDs sequentially.
    """
    # Grouping Logic
    # We use a dictionary to group items.
    # Key: Tuple of (context, question) -> This ensures unique QA pairs
    # Value: List of answers
    grouped_data = {}

    print(f"Processing {len(data)} items...")

    for item in data:
        context = item.get('context', '').strip()
        question = item.get('question', '').strip()
        answer = item.get('answer', '')
        
        # Create a unique key for this specific question context
        key = (context, question)

        if key not in grouped_data:
            grouped_data[key] = []

        # Handle cases where input answer might already be a list or a string
        if isinstance(answer, list):
            grouped_data[key].extend(answer)
        else:
            grouped_data[key].append(answer)

    # 3. Reconstruct the List with new IDs
    new_json_data = []
    new_id_counter = 0

    for (context, question), answers in grouped_data.items():
        # Remove duplicate answers if you want unique references only
        answers = list(set(answers)) 
        
        entry = {
            "id": new_id_counter,
            "context": context,
            "question": question,
            "answer": answers  # This is now a list ["Ans1", "Ans2"]
        }
        new_json_data.append(entry)
        new_id_counter += 1

    # Save to new file
    with open(output_file_path, 'w', encoding='utf-8') as f:
        for obj in tqdm(new_json_data, desc="Writing to JSON file"):
            # ensure_ascii=False is crucial for Chinese characters to be readable
            f.write(json.dumps(obj, ensure_ascii=False) + "\n")

    print(f"Success! Merged data saved to {output_file_path}")
    print(f"Original count: {len(data)} -> New count: {len(new_json_data)}")

    return new_json_data

In [3]:
DATA_TRAIN_PATH = "data/train.json"
DATA_DEV_PATH = "data/dev.json"

DATA_FDEV_PATH = "data/formatted_dev.json"
DATA_DEV_PATH = "data/dev.json"

valid_data = read_json(DATA_DEV_PATH)
merged_valid_data = merge_qa_dataset(valid_data, DATA_FDEV_PATH)
# merged_valid_data = read_json(DATA_FDEV_PATH)

tokenizer = T5Tokenizer.from_pretrained(checkpoint) 

print("First valid data: ", merged_valid_data[0])
train_data = read_json(DATA_TRAIN_PATH)
print("First train data: ", train_data[0])


Reading JSON file: 984it [00:00, 144580.51it/s]


Processing 984 items...


Writing to JSON file: 100%|██████████| 700/700 [00:00<00:00, 93206.76it/s]


Success! Merged data saved to data/formatted_dev.json
Original count: 984 -> New count: 700
First valid data:  {'id': 0, 'context': '年基准利率4.35%。 从实际看,贷款的基本条件是: 一是中国大陆居民,年龄在60岁以下; 二是有稳定的住址和工作或经营地点; 三是有稳定的收入来源; 四是无不良信用记录,贷款用途不能作为炒股,赌博等行为; 五是具有完全民事行为能力。', 'question': '2017年银行贷款基准利率', 'answer': ['年基准利率4.35%', '4.35%']}


Reading JSON file: 14520it [00:00, 158519.51it/s]

First train data:  {'context': '第35集雪见缓缓张开眼睛，景天又惊又喜之际，长卿和紫萱的仙船驶至，见众人无恙，也十分高兴。众人登船，用尽合力把自身的真气和水分输给她。雪见终于醒过来了，但却一脸木然，全无反应。众人向常胤求助，却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世，清微语带双关说一切上了天界便有答案。长卿驾驶仙船，众人决定立马动身，往天界而去。众人来到一荒山，长卿指出，魔界和天界相连。由魔界进入通过神魔之井，便可登天。众人至魔界入口，仿若一黑色的蝙蝠洞，但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦，模仿重楼的翅膀，制作数对翅膀状巨物。刚佩戴在身，便被吸入洞口。众人摔落在地，抬头发现魔界守卫。景天和众魔套交情，自称和魔尊重楼相熟，众魔不理，打了起来。', 'answer': '第35集', 'question': '仙剑奇侠传3第几集上天界', 'id': 0}





In [4]:
get_data_stats(valid_data, tokenizer)

{'question_num': 984,
 'context_num': 984,
 'answer_num': 984,
 'question_mean_length': 6.5426829268292686,
 'context_mean_length': 192.15243902439025,
 'answer_mean_length': 4.774390243902439,
 'question_max_length': 18,
 'context_max_length': 728,
 'answer_max_length': 26}

In [5]:
get_data_stats(train_data, tokenizer)

{'question_num': 14520,
 'context_num': 14520,
 'answer_num': 14520,
 'question_mean_length': 6.488154269972452,
 'context_mean_length': 182.3798209366391,
 'answer_mean_length': 4.257782369146006,
 'question_max_length': 28,
 'context_max_length': 1180,
 'answer_max_length': 95}

In [6]:
valid_dataset = MengziT5Dataset(merged_valid_data, tokenizer)
train_dataset = MengziT5Dataset(train_data, tokenizer)

Total data filtered away: 19
Total data filtered away: 538


# Retrieve Model 

In [7]:
train_batch_size = 8
valid_batch_size = 8
#test_batch_size = 8

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = T5ForConditionalGeneration.from_pretrained(checkpoint)
model = model.to(device)

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, collate_fn=lambda x: collote_train_fn(x, model, tokenizer))
train_data = next(iter(train_dataloader))
print("train input_ids: ", train_data['input_ids'])
print("train attention_mask: ", train_data['attention_mask'])
print("train decoder_input_ids", train_data['decoder_input_ids'])
print("train labels", train_data['labels'])
print("----------")

generator = torch.Generator().manual_seed(42)
valid_dataset, _ = random_split(valid_dataset, [0.5, 0.5], generator=generator)

valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=valid_batch_size, collate_fn=lambda x: collote_valid_fn(x, model, tokenizer))
valid_data = next(iter(valid_dataloader))
print("valid input_ids: ", valid_data['input_ids'])
print("valid attention_mask: ", valid_data['attention_mask'])
print("valid decoder_input_ids: ", valid_data['decoder_input_ids'])
print("valid labels:", valid_data['labels'])

# test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=valid_batch_size, collate_fn=lambda x: collote_fn(x, model, tokenizer))
# test_data = next(iter(test_dataloader))
# print("test input_ids: ", test_data['input_ids'])
# print("test attention_mask: ", test_data['attention_mask'])
# print("test decoder_input_ids: ", test_data['decoder_input_ids'])
# print("test labels:", test_data['labels'])


Loading weights: 100%|██████████| 282/282 [00:00<00:00, 567.66it/s, Materializing param=shared.weight]                                                       


train input_ids:  tensor([[  7, 143,  13,  ...,   0,   0,   0],
        [  7, 143,  13,  ...,   0,   0,   0],
        [  7, 143,  13,  ...,   0,   0,   0],
        ...,
        [  7, 143,  13,  ...,   0,   0,   0],
        [  7, 143,  13,  ...,   0,   0,   0],
        [  7, 143,  13,  ...,   0,   0,   0]])
train attention_mask:  tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])
train decoder_input_ids tensor([[    0,  7973,  5946,  ...,     0,     0,     0],
        [    0,  7973,  2056,  ...,     0,     0,     0],
        [    0, 12598,    50,  ...,     0,     0,     0],
        ...,
        [    0,     7,   170,  ...,     0,     0,     0],
        [    0,  8921, 21976,  ...,     0,     0,     0],
        [    0,     7,  6390,  ...,     0,     0,     0]])
train labels tensor([[ 7973,  5946,   212,  ...,  -100,  -100,  -

# Train Model  

In [None]:
def train_loop(dataloader, model, optimizer, scheduler, epoch, global_step, use_wandb=False):
    model.train()
    # Reset loss counter at the start of the epoch
    epoch_loss_sum = 0.0 
    current_avg_loss = 0.0
    #cumulative_batch = len(dataloader) * (epoch - 1)
    
    with tqdm(total=len(dataloader)) as pbar:
        for batch_idx, batch_data in enumerate(dataloader, start=1):
            batch_data = batch_data.to(device)
            results = model(**batch_data)
            loss = results.loss

            # backward popagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            global_step += 1
            if use_wandb:
                wandb.log(
                    {"train_loss": loss.item()},
                    step=global_step
                )

            epoch_loss_sum += loss.item()
            current_avg_loss = epoch_loss_sum / batch_idx

            pbar.set_description(f"Epoch {epoch} | Avg Loss: {current_avg_loss:.4f}")
            pbar.update(1)


    return current_avg_loss, global_step 

def valid_loop(dataloader, model, tokenizer, epoch, global_step, use_wandb=False):
    model.eval()
    bleu = evaluate.load("bleu")
    loss = []
    val_loss_sum = 0.0

    #cumulative_batch = (epoch-1) * len(dataloader)
    all_preds = []
    all_labels = []

    with tqdm(total=len(dataloader)) as pbar:
        with torch.no_grad():
            for batch_idx, batch_data in enumerate(dataloader, start=1):
                raw_references = batch_data.pop("answer", None)
                if raw_references is None:
                    print("No raw reference is found. Now create based on labels.")
                    temp_labels = torch.where(batch_data["labels"] != -100, batch_data["labels"], tokenizer.pad_token_id)
                    raw_references = [[ref] for ref in tokenizer.batch_decode(temp_labels, skip_special_tokens=True)]


                batch_data = batch_data.to(device)
                results = model(**batch_data)
                loss = results.loss
                val_loss_sum += loss.item() # Accumulate loss

                outputs = model.generate(
                    batch_data["input_ids"],
                    attention_mask=batch_data["attention_mask"],
                    max_new_tokens=MAX_TARGET_LENGTH,
                    num_beams=4
                    )
                decoded_outputs = tokenizer.batch_decode(
                    outputs,
                    skip_special_tokens=True
                    )
                # labels = batch_data['labels']
                # labels = torch.where(labels != -100, labels, tokenizer.pad_token_id)
                # decoded_labels = tokenizer.batch_decode(
                #     labels,
                #     skip_special_tokens=True
                # )

                batch_preds = []
                for pred in decoded_outputs:
                    if len(pred) == 0:
                        pred = " " # Prevent divided by zero during calculation of BLEU
                    batch_preds.append(pred)
                
                batch_labels = []
                for ref_list in raw_references: # ref_list: [ans1, ans2, ...]
                    processed_ref_list = []
                    for ref in ref_list:
                        cleaned_ref = ref.strip()
                        processed_ref_list.append(' '.join(cleaned_ref.strip()))
                    batch_labels.append(processed_ref_list)

                # batch_preds = [' '.join(pred.strip()) for pred in decoded_outputs]
                # batch_labels = [' '.join(label.strip()) for label in decoded_labels]
                print(f"First data: decoded output: {decoded_outputs[0]}, ref: {raw_references[0]}")
                all_preds.extend(batch_preds)
                all_labels.extend(batch_labels)

                pbar.update(1)

            bleu_result = bleu.compute(predictions=all_preds, references=all_labels)
            result = {f"bleu-{i}" : value for i, value in enumerate(bleu_result["precisions"], start=1)}
            result['avg'] = bleu_result['bleu']
            avg_val_loss = val_loss_sum / len(dataloader)
            log_dict = {
                "val_loss": avg_val_loss,
                "BLEU_avg": bleu_result['bleu'], # 'bleu' is the avg in huggingface evaluate
                "BLEU_1": bleu_result['precisions'][0],
                "BLEU_2": bleu_result['precisions'][1],
                "BLEU_3": bleu_result['precisions'][2],
                "BLEU_4": bleu_result['precisions'][3],
                "epoch": epoch
            }
            if use_wandb:
                wandb.log(
                    log_dict,
                    step=global_step
                )
            print(f"Test result: BLEU_avg={result['avg']}, BLEU1={result['bleu-1']}, BLEU2={result['bleu-2']}, BLEU3={result['bleu-3']}, BLEU4={result['bleu-4']}")
            return result

In [10]:
learning_rate = 2e-5
epoch_num = 5
best_model_name = "best_t5.pt"
current_t = datetime.now().strftime('%d-%m-%y-%H_%M')
foldername =  current_t + '_ckpt'
checkpoint_path = Path(f"./checkpoint/{foldername}")
checkpoint_path.mkdir(parents=True, exist_ok=True)
file_path = checkpoint_path / best_model_name
recent_checkpoints = []
use_wandb = True

if use_wandb:
    wandb.init(
        project="mengzi-t5-qa",   # The name of project on the website
        name=f"{current_t}",  # Name of this specific training run
        config={        
            "learning_rate": learning_rate,
            "batch_size": train_batch_size,
            "epochs": epoch_num,
            "model": "mengzi-t5-base"
        }
    )

num_training_steps = epoch_num * len(train_dataloader)
optimizer = AdamW(model.parameters(), lr=learning_rate)
scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

global_step = 0
best_bleu = 0
for epoch in range(epoch_num):
    avg_loss, global_step = train_loop(train_dataloader, model, optimizer, scheduler, epoch, global_step, use_wandb=use_wandb)
    valid_bleu = valid_loop(valid_dataloader, model, tokenizer, epoch, global_step, use_wandb=use_wandb)
    bleu_avg = valid_bleu['avg']
    save_checkpoint(model, epoch, checkpoint_path, recent_checkpoints)
    if bleu_avg > best_bleu:
        best_bleu = bleu_avg 
        print("Saving new best weights ...")
        torch.save(model.static_dict() , file_path)
        print("Finish saving.")
    

print("Finish training")

[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from WANDB_API_KEY.
[34m[1mwandb[0m: Currently logged in as: [33mlamyeungkong0108[0m ([33mlamyeungkong0108-the-hong-kong-university-of-science-and[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0 | Avg Loss: 8.0900: 100%|██████████| 1748/1748 [05:53<00:00,  4.95it/s]
100%|██████████| 43/43 [01:16<00:00,  1.77s/it]


Test result: BLEU_avg=0.0, BLEU1=0.023484434735117424, BLEU2=0.003209242618741977, BLEU3=0.0, BLEU4=0.0
Saving checkpoint to checkpoint/02-02-26-14_04_ckpt/ckpt-epoch0.pt


Epoch 1 | Avg Loss: 6.8229: 100%|██████████| 1748/1748 [05:52<00:00,  4.96it/s]
100%|██████████| 43/43 [01:15<00:00,  1.76s/it]


Test result: BLEU_avg=0.0, BLEU1=0.01790710688304421, BLEU2=0.0031308703819661866, BLEU3=0.0006548788474132286, BLEU4=0.0
Saving checkpoint to checkpoint/02-02-26-14_04_ckpt/ckpt-epoch1.pt


Epoch 2 | Avg Loss: 6.0911:  83%|████████▎ | 1451/1748 [04:53<01:00,  4.95it/s]
socket.send() raised exception.


KeyboardInterrupt: 

socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7f986c3f96c0>> (for post_run_cell), with arguments args (<ExecutionResult object at 7f9673112590, execution_count=10 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 7f9673111960, raw_cell="learning_rate = 2e-5
epoch_num = 5
best_model_name.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2B7b22686f73744e616d65223a22737368202d7020333436363320726f6f7440636f6e6e6563742e637161312e7365657461636c6f75642e636f6d227d/home/mengzi-t5-QuestionAnswering-/mengzi-t5-base/train.ipynb#X16sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


BrokenPipeError: [Errno 32] Broken pipe

In [None]:
valid_data = next(iter(valid_dataloader))
batch_data = valid_data.to(device)
outputs = model.generate(
    batch_data["input_ids"],
    attention_mask=batch_data["attention_mask"],
    max_new_tokens=MAX_TARGET_LENGTH,
    num_beams=4
    )
decoded_outputs = tokenizer.batch_decode(
    outputs,
    skip_special_tokens=True
    )
print(tokenizer.batch_decode(
    batch_data["input_ids"],
    skip_special_tokens=True
    ))
print(decoded_outputs)

['问题:防晒伞什么牌子好 上下文:喜途汽车安全伞就不错啊伞用起来不错做工挺精致的布料也比较厚实外形特别的漂亮。|防晒品的作用是阻止阳光中紫外线的照射牌子有相宜本草兰芝欧莱雅等。防晒的方法是多吃蔬菜和水果戴帽子、打遮阳伞、穿长袖外衣。', '问题:苹果6换苹果7多少钱呀 上下文:如果你要用苹果6补差换购全新正品苹果7可以到下面苹果实体店补差换购哈,他们哪里差不只需要补4000左右就可以换全新正品苹果7手机了哈。 苹果专卖店(赛格店) 地址:太升南路222号赛格广场4楼1032号', '问题:屠呦呦获诺贝尔奖奖金是多少 上下文:约46万美元。今年生理学或医学奖奖金共800万瑞典克朗约合92万美元屠呦呦将获得奖金的一半另外两名科学家将共享奖金的另一半。诺贝尔奖的奖金总是以瑞典的货币瑞典克朗颁发在同一年里各项奖金的数额是相同的不同的年份奖金数额有所变动其幅度主要取决于市场行情。每年的奖金金额视诺贝尔基金的投资收益而定1901年第一次颁奖的时候每单项的奖金为15万瑞典克朗当时相当于瑞典一个教授工作20年的薪金。1980年诺贝尔奖的单项奖金达到100万瑞典克朗1991年为600万瑞典克朗1992年为650万瑞典克朗1993年为670万瑞典克朗2000年单项奖金达到了900万瑞典克朗当时约折合100万美元。从2001年到2011年单项奖金均为1000万瑞典克朗在2011年折合约145万美元。', '问题:盾构机一台多少钱 上下文:一般直径的地铁盾构(也就是直径6米的)大约4000万左右(国内盾构生产厂商),国外的(如德国、美国)会稍微贵一些(日本可能会便宜一些)。大直径的盾构机会更贵,具体价格视盾构直径而定,当然土压平衡盾构和泥水平衡盾构的价钱也不一样,简单来说泥水平衡盾构会比相同直径的土压平衡盾构贵一些。', '问题:宁夏旅游几月份去好 上下文:宁夏的温差很大春季沙尘冬季寒冷干燥。所以建议你夏天去宁夏吧。最好是五一至十一之间。1、五月和九月白天的温度可以穿裙子但早晚凉。不适合下水游泳。冬泳选手例外。六月时可以摘樱桃等七月八月瓜果成熟尤其是西瓜包沙包甜哈密瓜、香瓜等都甜的很九月的葡萄绝对甜而且皮薄宁夏的葡萄可是酿酒的原材料啊地理条件和气候都类似法国所以葡萄质量绝对可靠。还有丰收的枸杞是特产是良药可入菜和粥。沙枣子呵呵和普通的大枣绝对不同包你印象深刻。2、代表性的玩沙湖、镇北堡西