In [1]:
from torch.utils.data import Dataset
import json

max_dataset_size = 20000


class DRQG(Dataset):
    """DuReaderQG split loaded from a JSON-lines file.

    Each non-blank line of ``data_file`` must be a JSON object containing at least
    the keys ``question``, ``context``, ``answer`` and ``id``; only those four keys
    are kept for each record.

    Args:
        data_file: path to the JSON-lines file.
        max_size: maximum number of records to load. ``None`` (the default) uses the
            module-level ``max_dataset_size``, read at load time.
    """

    def __init__(self, data_file, max_size=None):
        self.data = self.load_data(data_file, max_size)

    def load_data(self, data_file, max_size=None):
        """Read up to ``max_size`` records from ``data_file`` into a list of dicts.

        A list (not an int-keyed dict) is used so that out-of-range indexing raises
        IndexError, which is what ends Python's ``for sample in dataset`` iteration.
        Blank lines, e.g. a trailing empty line, are skipped instead of being fed
        to ``json.loads``.
        """
        limit = max_dataset_size if max_size is None else max_size
        records = []
        with open(data_file, 'rt', encoding='utf-8') as f:
            for line in f:
                if len(records) >= limit:
                    break
                if not line.strip():
                    continue
                data = json.loads(line)
                records.append({
                    'question': data['question'],
                    'context': data['context'],
                    'answer': data['answer'],
                    'id': data['id'],
                })
        return records

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Load the DuReaderQG train and dev splits (one JSON record per line).
train_data, valid_data = (
    DRQG('data/DuReaderQG/train.json'),
    DRQG('data/DuReaderQG/dev.json'),
)

In [2]:
# Sanity-check the split sizes and show the first training record.
print(
    f'train set size: {len(train_data)}',
    f'valid set size: {len(valid_data)}',
    sep='\n',
)
print(train_data[0])

train set size: 14520
valid set size: 984
{'question': '仙剑奇侠传3第几集上天界', 'context': '第35集雪见缓缓张开眼睛，景天又惊又喜之际，长卿和紫萱的仙船驶至，见众人无恙，也十分高兴。众人登船，用尽合力把自身的真气和水分输给她。雪见终于醒过来了，但却一脸木然，全无反应。众人向常胤求助，却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世，清微语带双关说一切上了天界便有答案。长卿驾驶仙船，众人决定立马动身，往天界而去。众人来到一荒山，长卿指出，魔界和天界相连。由魔界进入通过神魔之井，便可登天。众人至魔界入口，仿若一黑色的蝙蝠洞，但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦，模仿重楼的翅膀，制作数对翅膀状巨物。刚佩戴在身，便被吸入洞口。众人摔落在地，抬头发现魔界守卫。景天和众魔套交情，自称和魔尊重楼相熟，众魔不理，打了起来。', 'answer': '第35集', 'id': 0}


In [3]:
from transformers import AutoTokenizer

# Checkpoint id; the same name is reused later to load the model weights.
checkpoint = 'langboat/mengzi-t5-base'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)

  from .autonotebook import tqdm as notebook_tqdm
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [4]:
# Encode one sentence and show both the ids and the sub-word pieces they map to.
inputs = tokenizer("我叫张三，在苏州大学学习计算机。")
print(inputs)
print(tokenizer.convert_ids_to_tokens(inputs['input_ids']))

{'input_ids': [2900, 538, 232, 105, 3, 8, 4574, 278, 191, 2074, 4, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
['▁我', '叫', '张', '三', ',', '在', '苏州', '大学', '学习', '计算机', '。', '</s>']


In [5]:
# Tokenize the first four (question, context) pairs as one padded batch;
# only the context side is truncated when a pair exceeds the length limit.
question = [train_data[i]['question'] for i in range(4)]
context = [train_data[i]['context'] for i in range(4)]

inputs = tokenizer(
    question,
    context,
    padding=True,
    truncation='only_second',
    max_length=512,
    return_tensors='pt',
)
print(inputs)

{'input_ids': tensor([[    7,  1707,  1467,  ...,  6453,     4,     1],
        [    7,  9147, 14702,  ...,     0,     0,     0],
        [    7, 19918,   176,  ...,     0,     0,     0],
        [    7,  2747,  4403,  ...,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}


In [6]:
# Show the first encoded (question, context) pair as tokens; '</s>' appears between question and context.
print(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

['▁', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上天', '界', '</s>', '▁第', '35', '集', '雪', '见', '缓缓', '张开', '眼睛', ',', '景', '天', '又', '惊', '又', '喜', '之际', ',', '长', '卿', '和', '紫', '萱', '的', '仙', '船', '驶', '至', ',', '见', '众人', '无', '恙', ',', '也十分', '高兴', '。', '众人', '登', '船', ',', '用', '尽', '合力', '把', '自身的', '真', '气', '和', '水分', '输给', '她', '。', '雪', '见', '终于', '醒', '过来', '了', ',', '但却', '一脸', '木', '然', ',', '全', '无', '反应', '。', '众人', '向', '常', '胤', '求助', ',', '却发现', '人', '世界', '竟', '没有', '雪', '见', '的', '身世', '纪录', '。', '长', '卿', '询问', '清', '微', '的', '身世', ',', '清', '微', '语', '带', '双', '关', '说', '一切', '上了', '天', '界', '便', '有', '答案', '。', '长', '卿', '驾驶', '仙', '船', ',', '众人', '决定', '立马', '动', '身', ',', '往', '天', '界', '而去', '。', '众人', '来到', '一', '荒', '山', ',', '长', '卿', '指出', ',', '魔', '界', '和', '天', '界', '相连', '。', '由', '魔', '界', '进入', '通过', '神', '魔', '之', '井', ',', '便可', '登', '天', '。', '众人', '至', '魔', '界', '入口', ',', '仿', '若', '一', '黑色的', '蝙蝠', '洞', ',', '但', '始终', '无法', '进入', '。', '后来', 

In [7]:
from transformers import AutoModelForSeq2SeqLM
import torch
# Token budgets: inputs are (question, context) pairs, targets are the answers.
max_input_length = 512
max_target_length = 256

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device} device')

# nn.Module.to returns the module itself, so loading and moving can be chained.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)

Using cpu device


In [8]:
from torch.utils.data import DataLoader

def collate_fn(batch_samples):
    """Collate DRQG samples into a seq2seq training batch.

    Inputs are encoded as (question, context) pairs, truncated on the context side
    to ``max_input_length``. Answers become ``labels``, truncated to
    ``max_target_length``, with padding positions set to -100 so the loss ignores
    them.

    Args:
        batch_samples: list of dicts with 'question', 'context' and 'answer' keys.

    Returns:
        The tokenizer's encoding of the inputs plus two extra entries:
        'decoder_input_ids' (from ``model.prepare_decoder_input_ids_from_labels``)
        and 'labels'.
    """
    batch_question = [sample['question'] for sample in batch_samples]
    batch_context = [sample['context'] for sample in batch_samples]
    batch_target = [sample['answer'] for sample in batch_samples]

    batch_data = tokenizer(
        batch_question,
        batch_context,
        padding=True,
        max_length=max_input_length,
        truncation='only_second',
        return_tensors='pt'
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    target = tokenizer(
        text_target=batch_target,
        max_length=max_target_length,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )
    labels = target['input_ids']
    # Decoder inputs are built from the unmasked labels.
    batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)
    # Mask padding through the target attention mask. Locating EOS with
    # torch.where(...)[1] and enumerating the result assumed exactly one EOS per
    # row and would silently mask the wrong rows otherwise.
    batch_data['labels'] = labels.masked_fill(target['attention_mask'] == 0, -100)
    return batch_data

# Same collate_fn for both splits; only the training loader shuffles.
train_dataloader = DataLoader(train_data, collate_fn=collate_fn, shuffle=True, batch_size=4)
valid_dataloader = DataLoader(valid_data, collate_fn=collate_fn, shuffle=False, batch_size=4)

In [9]:
# Pull one collated training batch to see which tensors collate_fn produces and their shapes.
batch = next(iter(train_dataloader))
print(batch.keys())
print('batch shape:', dict((name, tensor.shape) for name, tensor in batch.items()))
print(batch)

dict_keys(['input_ids', 'attention_mask', 'decoder_input_ids', 'labels'])
batch shape: {'input_ids': torch.Size([4, 207]), 'attention_mask': torch.Size([4, 207]), 'decoder_input_ids': torch.Size([4, 13]), 'labels': torch.Size([4, 13])}
{'input_ids': tensor([[    7,  8613, 18784, 10912,  5349,     1,     7,   838,    68,  2946,
            90,     3,   354, 10526,   125,    43,  4057,  7698, 25018,     3,
          6318,  3162,  8210, 18784, 12611, 10047,     3,   354,   883,     5,
             3,   112,   851,    10,  2324,  7999,    10,  4734,  8613, 18784,
         10526,  5349,    66,  7698, 25018,  1449,     4, 11307,   600,  1619,
           366,   993,   495,   177,   364,  4932,  7372,    74,   364,  4932,
          7372,  8395,     1,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,   



In [11]:
from tqdm.auto import tqdm

def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
    """Run one training epoch and return the updated running loss sum.

    Args:
        dataloader: sized iterable of batches; each batch must support
            ``.to(device)`` and is passed to ``model`` as keyword arguments.
        model: module whose output exposes a ``.loss`` attribute.
        optimizer: stepped once per batch.
        lr_scheduler: stepped once per batch, after the optimizer.
        epoch: 1-based epoch number, used to show the mean loss over every batch
            seen so far (assuming each epoch has ``len(dataloader)`` batches).
        total_loss: loss sum carried over from previous epochs.

    Returns:
        ``total_loss`` plus the sum of this epoch's batch losses.
    """
    finish_batch_num = (epoch - 1) * len(dataloader)

    model.train()
    # Context manager so the bar is closed even if the epoch is interrupted.
    with tqdm(total=len(dataloader)) as progress_bar:
        progress_bar.set_description(f'loss: {0:>7f}')
        for batch, batch_data in enumerate(dataloader, start=1):
            batch_data = batch_data.to(device)
            outputs = model(**batch_data)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            total_loss += loss.item()
            progress_bar.set_description(f'loss: {total_loss / (finish_batch_num + batch):>7f}')
            progress_bar.update(1)
    return total_loss

In [12]:
from rouge import Rouge

rouge = Rouge()

generated_summary = "我在苏州大学学习计算机，苏州大学很美丽。"
reference_summary = "我在环境优美的苏州大学学习计算机。"


def TOKENIZE_CHINESE(x):
    """Space-separate every character so a whitespace-tokenizing scorer sees one token per character."""
    return ' '.join(x)


scores = rouge.get_scores(
    hyps=[TOKENIZE_CHINESE(generated_summary)],
    refs=[TOKENIZE_CHINESE(reference_summary)],
)
print('Rouge', scores)

Rouge [{'rouge-1': {'r': 0.75, 'p': 0.8, 'f': 0.7741935433922998}, 'rouge-2': {'r': 0.5625, 'p': 0.5625, 'f': 0.562499995}, 'rouge-l': {'r': 0.6875, 'p': 0.7333333333333333, 'f': 0.7096774143600416}}]


In [13]:
import numpy as np
from rouge import Rouge

# Re-create the scorer so this cell stands alone; test_loop below reads this module-level `rouge`.
rouge = Rouge()

def test_loop(dataloader, model):
    preds, labels = [], []

    model.eval()
    for batch_data in tqdm(dataloader):
        batch_data = batch_data.to(device)

        with torch.no_grad():
            generated_tokens = model.generate(
                batch_data['input_ids'],
                attention_mask=batch_data['attention_mask'],
                max_length=max_target_length,
                num_beams=4,
                no_repeat_ngram_size=2
            ).cpu().numpy()
        if isinstance(generated_tokens, tuple):
            generated_tokens = generated_tokens[0]
        label_tokens = batch_data['labels'].cpu().numpy()

        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        preds += [' '.join(pred.strip()) for pred in decoded_preds]
        labels += [' '.join(label.strip()) for label in decoded_labels]
    
    scores = rouge.get_scores(hyps=preds, refs=labels, avg=True)
    result = {key: value['f'] * 100 for key, value in scores.items()}
    print(f"Rouge1: {result['rouge-1']:>0.2f} Rouge2: {result['rouge-2']:>0.2f} RougeL: {result['rouge-l']:>0.2f}\n")
    return  result
    # return preds, labels

In [None]:
# test_loop returns a dict of ROUGE scores (not a (preds, labels) pair), so bind it
# to a single name; unpacking the dict into two names fails.
result = test_loop(valid_dataloader, model)

In [None]:
# Import under an alias so this cell does not rebind the transformers
# `AutoModelForSeq2SeqLM` imported earlier; re-running the model-loading cell after
# this one would otherwise silently go through modelscope.
from modelscope import AutoModelForSeq2SeqLM as MSAutoModelForSeq2SeqLM

model1 = MSAutoModelForSeq2SeqLM.from_pretrained(checkpoint)
model1

In [None]:
# Spot-check raw generations of `model1` on the first four validation batches.
# The iterator is named `valid_iter`, so it walks the validation loader (it
# previously iterated `train_dataloader`).
valid_iter = iter(valid_dataloader)
for _ in range(4):
    batch_data = next(valid_iter)
    print(batch_data)
    generated = model1.generate(
        batch_data['input_ids'],
        attention_mask=batch_data['attention_mask'],
        max_length=max_target_length,
        num_beams=4,
        no_repeat_ngram_size=2
    )
    print('generated', generated)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    print(decoded)

In [None]:
# Display the decoded validation predictions.
# NOTE(review): relies on a module-level `preds`, which test_loop (returning a scores
# dict) does not provide to any cell here — fails on a fresh kernel; TODO confirm.
preds

In [None]:
# Peek at the first collated validation batch (same collate_fn as training, so
# input_ids / attention_mask / decoder_input_ids / labels).
next(iter(valid_dataloader))

In [None]:
# Corpus-average ROUGE over saved predictions/references.
# NOTE(review): needs module-level `preds`/`labels`, which test_loop's returned dict
# does not provide — TODO confirm. rouge.get_scores presumably fails on empty
# strings, which is what the next cell searches for.
scores = rouge.get_scores(hyps=preds, refs=labels, avg=True)

In [None]:
for idx, (pred, label) in enumerate(zip(preds, labels)):
    # Report only pairs where either side decoded to an empty string.
    if label and pred:
        continue
    print('pred: ', pred)
    print('label:', label)
    print('idx:', idx)

In [14]:
from transformers import get_scheduler
from torch.optim import AdamW

learning_rate = 2e-5
epoch_num = 1

optimizer = AdamW(model.parameters(), lr=learning_rate)
# The scheduler length covers every batch of every epoch (train_loop steps it once per batch).
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num*len(train_dataloader),
)

total_loss = 0.
best_avg_rouge = 0.
for t in range(epoch_num):
    print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
    total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, t+1, total_loss)
    valid_rouge = test_loop(valid_dataloader, model)
    print(valid_rouge)
    # Average the three ROUGE F1 scores explicitly rather than relying on an 'avg'
    # entry, which test_loop's result dict does not necessarily contain.
    rouge_avg = sum(valid_rouge[key] for key in ('rouge-1', 'rouge-2', 'rouge-l')) / 3
    if rouge_avg > best_avg_rouge:
        best_avg_rouge = rouge_avg
        print('saving new weights...\n')
        torch.save(model.state_dict(), f'epoch_{t+1}_valid_rouge_{rouge_avg:0.4f}_model_weights.bin')
print("Done!")


Epoch 1/1
-------------------------------


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
loss: 0.000000:   0%|          | 0/3630 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
loss: 7.795273:   0%|          | 6/3630 [00:14<1:52:23,  1.86s/it] 

KeyboardInterrupt: 