In [1]:
from torch.utils.data import Dataset

# Upper bound on how many lines LCSTS.load_data reads from each data file.
max_dataset_size = 2000

class LCSTS(Dataset):
    """Map-style dataset of title/content pairs read from a text file.

    Every line of the file must have the form ``<title>!=!<content>``.

    Args:
        data_file: Path of the UTF-8 encoded file to read.
        max_size: Maximum number of lines to load. ``None`` (the default)
            falls back to the module-level ``max_dataset_size``.
    """

    def __init__(self, data_file, max_size=None):
        self.data = self.load_data(data_file, max_size)

    def load_data(self, data_file, max_size=None):
        """Read the file and return ``{line_index: {'title', 'content'}}``.

        Args:
            data_file: Path of the UTF-8 encoded file to read.
            max_size: Maximum number of lines to load; ``None`` uses the
                module-level ``max_dataset_size``.

        Returns:
            A dict keyed by 0-based line index whose values are dicts with
            the keys ``'title'`` and ``'content'``.

        Raises:
            AssertionError: If a line does not split into exactly two
                fields on the ``!=!`` separator.
        """
        # Only look up the global when no explicit limit was supplied.
        limit = max_dataset_size if max_size is None else max_size
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx >= limit:
                    break
                items = line.strip().split("!=!")
                assert len(items) == 2, (
                    f"{data_file}, line {idx + 1}: expected 2 fields "
                    f"separated by '!=!', got {len(items)}"
                )
                Data[idx] = {
                    'title': items[0],
                    'content': items[1]
                }
        return Data

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored under ``idx``.

        Raises:
            IndexError: If ``idx`` is not a loaded line index. Raising
                IndexError (not the dict's KeyError) lets plain iteration
                such as ``for sample in dataset`` or ``list(dataset)``
                stop cleanly after the last sample.
        """
        try:
            return self.data[idx]
        except KeyError:
            raise IndexError(f"index {idx} is out of range for dataset of size {len(self.data)}") from None
# Load the three splits; LCSTS.load_data caps each at max_dataset_size lines.
train_data = LCSTS('data/lcsts_tsv/data1.tsv')
valid_data = LCSTS('data/lcsts_tsv/data2.tsv')
test_data = LCSTS('data/lcsts_tsv/data3.tsv')

In [2]:
# Report how many samples each split holds, then show the first training sample.
print(f'train set size: {len(train_data)}')
print(f'valid set size: {len(valid_data)}')
print(f'test set size: {len(test_data)}')
print(next(iter(train_data)))

train set size: 2000
valid set size: 2000
test set size: 1106
{'title': '修改后的立法法全文公布', 'content': '新华社受权于18日全文播发修改后的《中华人民共和国立法法》，修改后的立法法分为“总则”“法律”“行政法规”“地方性法规、自治条例和单行条例、规章”“适用与备案审查”“附则”等6章，共计105条。'}


In [3]:
from transformers import AutoTokenizer

# Checkpoint identifier passed to AutoTokenizer.from_pretrained here and to
# AutoModelForSeq2SeqLM.from_pretrained in a later cell.
checkpoint = "csebuetnlp/mT5_multilingual_XLSum"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

  from .autonotebook import tqdm as notebook_tqdm
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [4]:
# Tokenize one sample sentence, print the resulting encoding, then print the
# token strings that correspond to its input ids.
inputs = tokenizer("我在苏州大学学计算机")
print(inputs)
print(tokenizer.convert_ids_to_tokens(inputs.input_ids))

{'input_ids': [259, 165389, 117707, 9792, 6696, 123553, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}
['▁', '我在', '苏州', '大学', '学', '计算机', '</s>']


In [5]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM

# Truncation limits (in tokens) applied by collote_fn to the article text and
# to the target title; max_target_length is also the generation limit in test_loop.
max_input_length = 512
max_target_length = 64

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {device} device")

# Load the seq2seq model for the same checkpoint as the tokenizer and move it
# to the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
model = model.to(device)

def collote_fn(batch_samples):
    """Collate raw samples into one padded tensor batch for the seq2seq model.

    Args:
        batch_samples: Iterable of dicts that each provide ``'content'``
            (the source text) and ``'title'`` (the target summary).

    Returns:
        The tokenizer output for the contents (truncated to
        ``max_input_length``), extended with ``'decoder_input_ids'`` and
        ``'labels'``. Label positions that are padding are set to -100.
    """
    batch_inputs = [sample['content'] for sample in batch_samples]
    batch_targets = [sample['title'] for sample in batch_samples]
    batch_data = tokenizer(
        batch_inputs,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors='pt'
    )
    # `text_target=` is the replacement for the deprecated
    # `tokenizer.as_target_tokenizer()` context manager.
    target_encoding = tokenizer(
        text_target=batch_targets,
        padding=True,
        max_length=max_target_length,
        truncation=True,
        return_tensors='pt'
    )
    labels = target_encoding['input_ids']
    # Build the decoder inputs first, while padding positions still hold
    # real pad ids rather than -100.
    batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)
    # Mask padding using the attention mask instead of searching for EOS ids:
    # the search assumed exactly one EOS per row and would mis-align rows
    # if any target contained more than one.
    labels[target_encoding['attention_mask'] == 0] = -100
    batch_data['labels'] = labels
    return batch_data

# Batches of 4 samples assembled by collote_fn; only the training loader shuffles.
train_dataloader = DataLoader(train_data, batch_size=4, shuffle=True, collate_fn=collote_fn)
valid_dataloader = DataLoader(valid_data, batch_size=4, shuffle=False, collate_fn=collote_fn)

Using cpu device


In [6]:
# Pull one batch from the training loader and print its keys, the shape of
# every tensor in it, and finally the full batch contents.
batch = next(iter(train_dataloader))
print(batch.keys())
print('batch shape:', {k: v.shape for k, v in batch.items()})
print(batch)

dict_keys(['input_ids', 'attention_mask', 'decoder_input_ids', 'labels'])
batch shape: {'input_ids': torch.Size([4, 72]), 'attention_mask': torch.Size([4, 72]), 'decoder_input_ids': torch.Size([4, 17]), 'labels': torch.Size([4, 17])}
{'input_ids': tensor([[   259, 109505, 120788,   4002,    591,  99441, 200719,   3541,    261,
          10389,  33420,  19115,  24705,    365,   2037,  74090,    261,  15268,
         188757,  71650, 126455, 133373,  16436,  31761,    292,  14428,   5718,
         122670,   7973, 195050,    292,  97540,  36385,    292, 124916,   5811,
          20760,  14735,    292, 135926, 139877,    292,  31651,  37488,    292,
          11764,  62369,   1107, 125876,  56203,   2395,    449,   1146,  17723,
            306,  21239, 122153,  13061,  28169,   8773,   2811,  15268,   4775,
           6642,    306,      1,      0,      0,      0,      0,      0,      0],
        [   809,    848,  12267, 195305,  32063,  46601, 172901,  71813,  36633,
          46601,   708



In [7]:
from tqdm.auto import tqdm

def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        dataloader: Iterable of batches; each batch must support
            ``.to(device)`` and be unpackable as keyword arguments of ``model``.
        model: Model whose output exposes a ``.loss`` attribute.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Learning-rate scheduler stepped once per batch.
        epoch: 1-based epoch number, used to compute how many batches were
            processed in earlier epochs for the running average.
        total_loss: Sum of the batch losses accumulated before this call.

    Returns:
        ``total_loss`` plus the sum of this epoch's batch losses, as a
        Python number.
    """
    finish_batch_num = (epoch - 1) * len(dataloader)

    model.train()
    # The context manager closes the progress bar when the epoch ends.
    with tqdm(range(len(dataloader))) as progress_bar:
        progress_bar.set_description(f"loss: {0:>7f}")
        for batch, batch_data in enumerate(dataloader, start=1):
            batch_data = batch_data.to(device)
            outputs = model(**batch_data)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            # Accumulate the Python value, not the tensor: adding the tensor
            # itself would chain autograd history across every iteration.
            total_loss += loss.item()
            progress_bar.set_description(f"loss: {total_loss / (finish_batch_num + batch):>7f}")
            progress_bar.update(1)
    return total_loss


In [8]:
from rouge import Rouge

# Score one English candidate summary against one reference with Rouge and
# print the result for that single pair.
generated_summary = "I absolutely loved reading the Hunger Games"
reference_summary = "I loved reading the Hunger Games"

rouge = Rouge()

scores = rouge.get_scores(
    hyps=[generated_summary], refs=[reference_summary]
)[0]
print(scores)

{'rouge-1': {'r': 1.0, 'p': 0.8571428571428571, 'f': 0.9230769181065088}, 'rouge-2': {'r': 0.8, 'p': 0.6666666666666666, 'f': 0.7272727223140496}, 'rouge-l': {'r': 1.0, 'p': 0.8571428571428571, 'f': 0.9230769181065088}}


In [9]:
from rouge import Rouge

# Same comparison for Chinese text, scored twice: once with the characters
# separated by spaces and once with the raw, unsegmented strings.
generated_summary = "我在苏州大学学计算机，苏州大学很美丽"
reference_summary = "我在环境优美的苏州大学学计算机"

rouge = Rouge()

# Insert a space between every character so each character is its own token.
TOKENIZE_CHINESE = lambda x: ' '.join(x)

scores = rouge.get_scores(
    hyps=[TOKENIZE_CHINESE(generated_summary)],
    refs=[TOKENIZE_CHINESE(reference_summary)]
)[0]
print("ROUGE", scores)
# Raw strings without spaces; printed under the label "wrong ROUGE".
# NOTE(review): presumably Rouge splits on whitespace, so each unsegmented
# sentence is treated as a single token — confirm against the rouge package.
scores = rouge.get_scores(
    hyps=[generated_summary],
    refs=[reference_summary]
)[0]
print("wrong ROUGE:", scores)

ROUGE {'rouge-1': {'r': 0.7142857142857143, 'p': 0.7692307692307693, 'f': 0.7407407357475996}, 'rouge-2': {'r': 0.5714285714285714, 'p': 0.5714285714285714, 'f': 0.5714285664285715}, 'rouge-l': {'r': 0.6428571428571429, 'p': 0.6923076923076923, 'f': 0.6666666616735254}}
wrong ROUGE: {'rouge-1': {'r': 0.0, 'p': 0.0, 'f': 0.0}, 'rouge-2': {'r': 0.0, 'p': 0.0, 'f': 0.0}, 'rouge-l': {'r': 0.0, 'p': 0.0, 'f': 0.0}}


In [10]:
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {device} device")

# Load the tokenizer and model for the checkpoint and move the model to the device.
checkpoint = "csebuetnlp/mT5_multilingual_XLSum"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
model = model.to(device)

article_text = """
受众在哪里，媒体就应该在哪里，媒体的体制、内容、技术就应该向哪里转变。
媒体融合关键是以人为本，即满足大众的信息需求，为受众提供更优质的服务。
这就要求媒体在融合发展的过程中，既注重技术创新，又注重用户体验。
"""

# Move the encoded inputs to the same device as the model; leaving them on
# the CPU makes generate() fail with a device mismatch whenever CUDA is used.
input_ids = tokenizer(
    article_text,
    return_tensors='pt',
    truncation=True,
    max_length=512
).to(device)
# Beam search with 4 beams; no 2-gram may repeat within the generated text.
generated_tokens = model.generate(
    input_ids['input_ids'],
    attention_mask=input_ids['attention_mask'],
    max_length=512,
    no_repeat_ngram_size=2,
    num_beams=4
)
summary = tokenizer.decode(
    generated_tokens[0],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)
print(summary)

Using cpu device




媒体融合发展是当下中国面临的一大难题。


In [13]:
import numpy as np
from rouge import Rouge

# Shared Rouge scorer used by test_loop.
rouge = Rouge()

def _to_rouge_text(text):
    """Return ``text`` as space-separated characters, never an empty string."""
    spaced = ' '.join(text.strip())
    # Rouge raises on an empty hypothesis/reference, which would abort the
    # whole evaluation; a multi-character placeholder cannot match any of
    # the single-character tokens produced for non-empty text.
    return spaced if spaced else '<empty>'


def test_loop(dataloader, model):
    """Generate summaries for every batch and score them with ROUGE.

    Args:
        dataloader: Iterable of batches providing ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``; each batch must support
            ``.to(device)``.
        model: Model used to generate the summaries.

    Returns:
        A dict with the ROUGE F-scores scaled to 0-100 under the keys
        ``'rouge-1'``, ``'rouge-2'`` and ``'rouge-l'``, plus their mean
        under ``'avg'``.
    """
    preds, labels = [], []

    model.eval()
    for batch_data in tqdm(dataloader):
        batch_data = batch_data.to(device)
        with torch.no_grad():
            generated_tokens = model.generate(
                batch_data['input_ids'],
                attention_mask=batch_data['attention_mask'],
                max_length=max_target_length,
                num_beams=4,
                no_repeat_ngram_size=2
            ).cpu().numpy()
        label_tokens = batch_data['labels'].cpu().numpy()

        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        # Put pad ids back where the labels hold -100 so they can be decoded.
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        preds += [_to_rouge_text(pred) for pred in decoded_preds]
        labels += [_to_rouge_text(label) for label in decoded_labels]

    scores = rouge.get_scores(hyps=preds, refs=labels, avg=True)
    result = {key: value['f'] * 100 for key, value in scores.items()}
    result['avg'] = np.mean(list(result.values()))
    print(f"Rouge1: {result['rouge-1']:>0.2f} Rouge2: {result['rouge-2']:>0.2f} RougeL: {result['rouge-l']:>0.2f}\n")
    return result

In [None]:
from transformers import get_scheduler
from torch.optim import AdamW

# Optimisation hyper-parameters.
lr = 2e-5
epoch_num = 1

optimizer = AdamW(model.parameters(), lr=lr)
# "linear" schedule with no warm-up steps, sized for one scheduler step per
# training batch across all epochs.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num * len(train_dataloader)
)

total_loss = 0.
best_avg_rouge = 0.
for t in range(epoch_num):
    print(f"Epoch {t + 1} / {epoch_num}\n-------------------------------")
    # train_loop adds this epoch's losses onto the running total it receives
    # and returns the updated total, so assign the result; `+=` would count
    # the losses of earlier epochs twice.
    total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, t + 1, total_loss)
    valid_rouge = test_loop(valid_dataloader, model)
    print(valid_rouge)
    rouge_avg = valid_rouge['avg']
    # Keep a checkpoint only when the averaged validation ROUGE improves.
    if rouge_avg > best_avg_rouge:
        best_avg_rouge = rouge_avg
        print('saving new weights...\n')
        torch.save(model.state_dict(), f"epoch_{t+1}_valid_rouge_{rouge_avg:0.4f}_model_weights.bin")
print('Done!')

Epoch 1 / 1
-------------------------------


loss: 3.590162: 100%|██████████| 500/500 [14:17<00:00,  1.72s/it]
100%|██████████| 500/500 [11:46<00:00,  1.41s/it]


Rouge1: 30.82 Rouge2: 18.84 RougeL: 28.09

{'rouge-1': 30.822807969169, 'rouge-2': 18.841270642392395, 'rouge-l': 28.08649294438796, 'avg': np.float64(25.91685718531645)}
saving new weights...

Done!
