问答模型

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
from torch import functional
from torch.utils.data import DataLoader, Dataset
# Load the tokenizer and model
from sacrebleu.metrics import BLEU

tokenizer = T5Tokenizer.from_pretrained("Langboat/mengzi-t5-base")
model = T5ForConditionalGeneration.from_pretrained("Langboat/mengzi-t5-base")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
batch_size = 8

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): Linear(in_features=768, out_features=2048, bias=False)
              (wo):

In [2]:
print(device)
model

cuda


T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): Linear(in_features=768, out_features=2048, bias=False)
              (wo):

数据集处理

In [3]:
import json
class QADataSet(Dataset):
    def __init__(self, file_path):
        super().__init__()
        self.data_path = file_path
        self.data = self.load_data()
    
    def load_data(self):
        data = []
        with open(self.data_path, 'rt', encoding='utf-8') as f:
            for line in f:
                json_data = json.loads(line)
                question = json_data["question"]
                context = json_data["context"]
                data.append({
                    "input": f"问题是:{question},文章:{context}",
                    "answer": json_data["answer"]
                })
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    

In [4]:
from torch.utils.data import random_split
dataset = QADataSet('DuReaderQG/train.json')
train_dataset, valid_dataset = random_split(dataset, [int(len(dataset) * 0.8), len(dataset) - int(len(dataset) * 0.8)])
test_dataset = QADataSet('DuReaderQG/dev.json')
print(len(train_dataset), len(test_dataset))

11616 984


In [5]:
maxlength = 0
for data in train_dataset:
    inputs = tokenizer(data["input"],truncation=True, padding=True, max_length=10240, return_tensors="pt")
    maxlength = max(maxlength, inputs.input_ids.shape[1])
    inputs = tokenizer(data["answer"],truncation=True, padding=True, max_length=10240, return_tensors="pt")
    maxlength = max(maxlength, inputs.input_ids.shape[1])

for data in test_dataset:
    inputs = tokenizer(data["input"],truncation=True, padding=True, max_length=10240, return_tensors="pt")
    maxlength = max(maxlength, inputs.input_ids.shape[1])
    inputs = tokenizer(data["answer"],truncation=True, padding=True, max_length=10240, return_tensors="pt")
    maxlength = max(maxlength, inputs.input_ids.shape[1])
maxlength

1139

In [5]:
inputs = tokenizer(train_dataset[0]["input"], truncation=True, padding=True, max_length=1280, return_tensors="pt")
print(inputs)
print(tokenizer.convert_ids_to_tokens(inputs.input_ids[0]))

{'input_ids': tensor([[ 8080,    13, 16981,  5019,  3832,  1327,  4325,     3,  1385,    13,
          5019,  3832,  1562,   105,    83,  5555,     3, 10112, 14047,     3,
          5019,  7605,   202,   990,     3,  4405,     3,  3085,  3260,     5,
             3,   176, 26346,  5807,   176,  4398,    37,     3,    33,    32,
          4758,   371, 16202,  2424,    32,    26,     3,  1562,  1143,    35,
         22008,    10,     3, 15999,  4346,     3,   287,   208,   252,     3,
           287,    54,   119,   452,     3,  1191,   760,    55,  1359,  3985,
          1562,     3,  1562,    38,  6041,    37,   371,  2756,    35,  1258,
             3,  3690,    32,   162,  6734,   162,     3,  6308,   142,  4734,
         16981,  5019,  3832,  1084,   232,  1435,     3,  1937,   138,   762,
           108,   394,  7506,     3, 10994,    35,  2637,  1370,   285,    10,
             4,  2149,  1143,  3260,   489,   291,  3719,  1441,     3,  2269,
             5,   403,  1262,     4,  

In [6]:
def collate_fn(batch):
    inputs = []
    answers = []
    for b in batch:
        inputs.append(b["input"])
        answers.append(b["answer"])

    batch_data = tokenizer(inputs, truncation=True, padding=True, max_length=1280, return_tensors="pt")

    with tokenizer.as_target_tokenizer():
        answer_token = tokenizer(answers, truncation=True, padding=True, max_length=1280, return_tensors="pt").input_ids
        
        batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(answer_token)
        eos_token_id = torch.where(answer_token == tokenizer.eos_token_id)[1]
        for idx, eos_id in enumerate(eos_token_id):
            answer_token[idx][eos_id + 1:] = -100  # Mask out the tokens after the EOS token
        batch_data['labels'] = answer_token
    
    return batch_data

In [None]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [8]:
batch = next(iter(train_loader))
print(batch.keys())
print('batch shape:', {k: v.shape for k, v in batch.items()})
print(batch['decoder_input_ids'][0])
print(batch['labels'][0])
print(batch['input_ids'][0])

dict_keys(['input_ids', 'attention_mask', 'decoder_input_ids', 'labels'])
batch shape: {'input_ids': torch.Size([8, 337]), 'attention_mask': torch.Size([8, 337]), 'decoder_input_ids': torch.Size([8, 9]), 'labels': torch.Size([8, 9])}
tensor([    0,  4989,   280,  1084,  9181,   165, 18767,  1152,   443])
tensor([ 4989,   280,  1084,  9181,   165, 18767,  1152,   443,     1])
tensor([ 8080,    13,  1095,   224,    24,  6196,  1257,   443,     3,  1385,
           13,    39,   224,    24,    18,    13,  6596,   138,  1359,  8952,
         2294,    42,   165,     5,   672,   394,   316,     7, 27500,  1733,
         1073,   753,  1095,   224,    24,    18,    13,  6596,   138,  1359,
         8952,  3757,  2835,  9914,  9914,  9914,  9914,  3757,  1095,     7,
        27500,  1733,  1073,   753,  3479, 20769,  1084,  7830,  5159,  2294,
           42,   165,     5,  1238,   394,   316,     3,  2007,  4734,   100,
         9181,   165, 18767,   218,    22,   443,   742,  2424,    25,  9181



训练

In [9]:
from tqdm.auto import tqdm
def train_loop(dataloader,model,optimizer,epoch, lr_scheduler,total_loss,device):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = (epoch-1) * len(dataloader)
    model.train()
    
    for batch,data in enumerate(dataloader,start=1):
        data = data.to(device)
        output = model(**data)
        loss = output.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        progress_bar.set_description(f'loss: {total_loss / (finish_batch_num + batch):>7f}')
        progress_bar.update(1)
    
    return total_loss


In [10]:
# bleu1
blue_1 = BLEU(max_ngram_order=1)
blue_2 = BLEU(max_ngram_order=2)
blue_3 = BLEU(max_ngram_order=3)
blue_4 = BLEU()

In [11]:
blue_2

<sacrebleu.metrics.bleu.BLEU at 0x1be7915bc90>

In [None]:
import numpy as np
def test_loop(dataloader,tokenizer,model,device):
    labels = []
    predictions = []
    model.eval()
    for batch, data in enumerate(dataloader):
        data = data.to(device)
        with torch.no_grad():
            output = model.generate(data["input_ids"],
                attention_mask=data["attention_mask"],
                max_length=1280,
                num_beams=4,
                no_repeat_ngram_size=2,
            )
        if isinstance(output, tuple):
            output = output[0]
        
        decoded_preds = tokenizer.batch_decode(output, skip_special_tokens=True)
        predictions += [' '.join(pred.strip()) for pred in decoded_preds]

        label_token = data["labels"].cpu().numpy()
        label_token = np.where(label_token == -100, tokenizer.pad_token_id, label_token)
        decoded_label = tokenizer.batch_decode(label_token, skip_special_tokens=True)

        labels += [' '.join(label.strip()) for label in decoded_label]
        print(f"batch: {batch}")
    
    bleu1 = blue_1.corpus_score(predictions, [labels]).score
    bleu2 = blue_2.corpus_score(predictions, [labels]).score
    bleu3 = blue_3.corpus_score(predictions, [labels]).score
    bleu4 = blue_4.corpus_score(predictions, [labels]).score
    print(f"BLEU-1: {bleu1:.2f}, BLEU-2: {bleu2:.2f}, BLEU-3: {bleu3:.2f}, BLEU-4: {bleu4:.2f}")
    return bleu1, bleu2, bleu3, bleu4



In [None]:
from matplotlib import pyplot as plt

lr = 1e-4
epochs = 10
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
from transformers import get_scheduler
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_loader) * epochs,
)
loss_history = []
maxbleusum = 0.0
for epoch in range(epochs):
    print(f"Epoch {epoch}/{epochs}")
    total_loss = 0.0
    # total_loss = train_loop(train_loader, model, optimizer, epoch + 1, lr_scheduler, total_loss, device)
    loss_history.append(total_loss)
    bleu1, bleu2, bleu3, bleu4 = test_loop(valid_loader, tokenizer, model, device)
    bleusum = (bleu1 + bleu2 + bleu3 + bleu4) / 4
    if bleusum > maxbleusum or epoch == 0:
        maxbleusum = bleusum
        model.save_pretrained(f"mengzi-t5-base-finetuned-epoch-{epoch}")
        print(f"Model saved at epoch {epoch} with BLEU sum: {bleusum:.2f}")

# 绘制 loss 曲线
plt.figure(figsize=(10, 5))
plt.plot(loss_history, marker='o')
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.title("Training Loss Curve")
plt.show()

Epoch 0/10




batch: 0
batch: 1
batch: 2
batch: 3
batch: 4
batch: 5
batch: 6
batch: 7
batch: 8
batch: 9
batch: 10
batch: 11
batch: 12
batch: 13
batch: 14
batch: 15
batch: 16
batch: 17
batch: 18
batch: 19
batch: 20
batch: 21
batch: 22
batch: 23
batch: 24
batch: 25
batch: 26
batch: 27
batch: 28
batch: 29
batch: 30
batch: 31
batch: 32
batch: 33
batch: 34
batch: 35
batch: 36
batch: 37
batch: 38
batch: 39
batch: 40
batch: 41
batch: 42
batch: 43
batch: 44
batch: 45
batch: 46
batch: 47
batch: 48
batch: 49
batch: 50
batch: 51
batch: 52
batch: 53
batch: 54
batch: 55
batch: 56
batch: 57
batch: 58
batch: 59
batch: 60
batch: 61
batch: 62
batch: 63
batch: 64
batch: 65
batch: 66
batch: 67
batch: 68
batch: 69
batch: 70
batch: 71
batch: 72
batch: 73
batch: 74
batch: 75
batch: 76
batch: 77
batch: 78
batch: 79
batch: 80
batch: 81
batch: 82
batch: 83
batch: 84
batch: 85
batch: 86
batch: 87
batch: 88
batch: 89
batch: 90
batch: 91
batch: 92
batch: 93
batch: 94
batch: 95
batch: 96
batch: 97
batch: 98
batch: 99
batch: 100



batch: 0
batch: 1
batch: 2
batch: 3
batch: 4
batch: 5
batch: 6
batch: 7
batch: 8
batch: 9
batch: 10
batch: 11
batch: 12
batch: 13
batch: 14
batch: 15
batch: 16
batch: 17
batch: 18
batch: 19
batch: 20
batch: 21
batch: 22
batch: 23
batch: 24
batch: 25
batch: 26
batch: 27
batch: 28
batch: 29
batch: 30
batch: 31
batch: 32
batch: 33
batch: 34
batch: 35
batch: 36
batch: 37
batch: 38
batch: 39
batch: 40
batch: 41
batch: 42
batch: 43
batch: 44
batch: 45
batch: 46
batch: 47
batch: 48
batch: 49
batch: 50
batch: 51
batch: 52
batch: 53
batch: 54
batch: 55
batch: 56
batch: 57
batch: 58
batch: 59
batch: 60
batch: 61
batch: 62
batch: 63
batch: 64
batch: 65
batch: 66
batch: 67
batch: 68
batch: 69
batch: 70
batch: 71
batch: 72
batch: 73
batch: 74
batch: 75
batch: 76
batch: 77
batch: 78
batch: 79
batch: 80
batch: 81
batch: 82
batch: 83
batch: 84
batch: 85
batch: 86
batch: 87
batch: 88
batch: 89
batch: 90
batch: 91
batch: 92
batch: 93
batch: 94
batch: 95
batch: 96
batch: 97
batch: 98
batch: 99
batch: 100

KeyboardInterrupt: 

测试

In [None]:
def test_loop_final(dataloader,tokenizer,model,device,save_path):
    labels = []
    predictions = []
    sources = []
    model.eval()
    for batch, data in enumerate(dataloader):
        data = data.to(device)
        with torch.no_grad():
            output = model.generate(data["input_ids"],
                attention_mask=data["attention_mask"],
                max_length=1280,
                num_beams=4,
                no_repeat_ngram_size=2,
            )
        if isinstance(output, tuple):
            output = output[0]
        
        decoded_sources = tokenizer.batch_decode(
            data["input_ids"].cpu().numpy(), 
            skip_special_tokens=True, 
            use_source_tokenizer=True
        )
        
        sources += [source.strip() for source in decoded_sources]

        decoded_preds = tokenizer.batch_decode(output, skip_special_tokens=True)
        predictions += [' '.join(pred.strip()) for pred in decoded_preds]

        label_token = data["labels"].cpu().numpy()
        label_token = np.where(label_token == -100, tokenizer.pad_token_id, label_token)
        decoded_label = tokenizer.batch_decode(label_token, skip_special_tokens=True)

        labels += [' '.join(label.strip()) for label in decoded_label]
        print(f"batch: {batch}")
    
    bleu1 = blue_1.corpus_score(predictions, [labels]).score
    bleu2 = blue_2.corpus_score(predictions, [labels]).score
    bleu3 = blue_3.corpus_score(predictions, [labels]).score
    bleu4 = blue_4.corpus_score(predictions, [labels]).score
    print(f"BLEU-1: {bleu1:.2f}, BLEU-2: {bleu2:.2f}, BLEU-3: {bleu3:.2f}, BLEU-4: {bleu4:.2f}")
    results = []
    for source, pred, label in zip(sources, predictions, labels):
        results.append({
            "context": source, 
            "prediction": pred, 
            "labels": label
        })
    with open(save_path, 'w', encoding='utf-8') as f:
        for result in results:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')
    return bleu1, bleu2, bleu3, bleu4


In [17]:

bleu1, bleu2, bleu3, bleu4 = test_loop_final(test_loader, tokenizer, model, device, "test_predictions.json")
print(f"Test BLEU-1: {bleu1:.2f}, BLEU-2: {bleu2:.2f}, BLEU-3: {bleu3:.2f}, BLEU-4: {bleu4:.2f}")




batch: 0
batch: 1
batch: 2
batch: 3
batch: 4
batch: 5
batch: 6
batch: 7
batch: 8
batch: 9
batch: 10
batch: 11
batch: 12
batch: 13
batch: 14
batch: 15
batch: 16
batch: 17
batch: 18
batch: 19
batch: 20
batch: 21
batch: 22
batch: 23
batch: 24
batch: 25
batch: 26
batch: 27
batch: 28
batch: 29
batch: 30
batch: 31
batch: 32
batch: 33
batch: 34
batch: 35
batch: 36
batch: 37
batch: 38
batch: 39
batch: 40
batch: 41
batch: 42
batch: 43
batch: 44
batch: 45
batch: 46
batch: 47
batch: 48
batch: 49
batch: 50
batch: 51
batch: 52
batch: 53
batch: 54
batch: 55
batch: 56
batch: 57
batch: 58
batch: 59
batch: 60
batch: 61
batch: 62
batch: 63
batch: 64
batch: 65
batch: 66
batch: 67
batch: 68
batch: 69
batch: 70
batch: 71
batch: 72
batch: 73
batch: 74
batch: 75
batch: 76
batch: 77
batch: 78
batch: 79
batch: 80
batch: 81
batch: 82
batch: 83
batch: 84
batch: 85
batch: 86
batch: 87
batch: 88
batch: 89
batch: 90
batch: 91
batch: 92
batch: 93
batch: 94
batch: 95
batch: 96
batch: 97
batch: 98
batch: 99
batch: 100

NameError: name 'preds' is not defined