In [None]:
import os
import json
import torch
import random
import openpyxl
from tqdm import tqdm
from rouge import Rouge
from transformers import BartForConditionalGeneration, AutoTokenizer

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

### 加载模型

In [None]:
# One checkpoint identifier is shared by the tokenizer and the model weights,
# so the two can never point at different checkpoints by accident.
checkpoint = 'IDEA-CCNL/Randeng-BART-139M-SUMMARY'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = BartForConditionalGeneration.from_pretrained(checkpoint)

### 数据集

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Training and validation samples served through a single Dataset object.

    The boolean ``train`` attribute decides which split ``__len__`` and
    ``__getitem__`` operate on; it is ``True`` right after construction.
    """

    def __init__(self, train_data, valid_data):
        """Normalize both splits in place and keep references to them.

        Expected format of each split:
            [{'title': '</s>....', 'content': '......'}, {}, ..., {}]

        For every sample the title loses any '<s>' / '</s>' markers, is
        stripped and then gets exactly one leading '</s>'; the content is
        stripped.  The dicts that are passed in are modified directly.
        """
        super().__init__()
        for split in (train_data, valid_data):
            for pos in tqdm(range(len(split))):
                sample = split[pos]
                bare_title = sample['title'].replace('<s>', '').replace('</s>', '').strip()
                sample['title'] = '</s>' + bare_title
                sample['content'] = sample['content'].strip()

        self.train_data = train_data
        self.valid_data = valid_data

        self.train = True

    def _active_split(self):
        """Return the split selected by the ``train`` flag."""
        return self.train_data if self.train else self.valid_data

    def __len__(self):
        """Number of samples in the currently selected split."""
        return len(self._active_split())

    def __getitem__(self, idx):
        """Sample ``idx`` of the currently selected split."""
        return self._active_split()[idx]

In [None]:
def collect_from_self_summary(root):
    """Collect labelled (content, title) pairs from the Excel workbooks in ``root``.

    For every workbook the active sheet is read from row 2 onwards (row 1 is
    skipped); column 1 supplies the content and column 2 the title.  Rows with
    a missing/empty content or a missing/blank title are dropped.

    Args:
        root: Directory that contains the workbooks.

    Returns:
        ``(train_data, valid_data)``: ``train_data`` is a list of
        ``{'title': str, 'content': str}`` dicts; ``valid_data`` is always an
        empty list (no rows are routed to it) and is returned only so callers
        can keep unpacking two values.
    """
    train_data = []
    valid_data = []
    workbook_suffixes = ('.xlsx', '.xlsm', '.xltx', '.xltm')

    # Sorted so the sample order does not depend on the OS directory listing order.
    for name in sorted(os.listdir(root)):
        path = os.path.join(root, name)
        # Skip sub-directories (e.g. .ipynb_checkpoints), Excel lock files
        # ("~$name.xlsx") and anything that is not a workbook; load_workbook
        # would raise on them.
        if not os.path.isfile(path) or name.startswith('~$'):
            continue
        if not name.lower().endswith(workbook_suffixes):
            continue

        workbook = openpyxl.load_workbook(path)
        try:
            table = workbook.active
            if table is None:
                continue
            # values_only=True yields plain (content, title) tuples instead of Cell objects.
            rows = table.iter_rows(min_row=2, max_col=2, values_only=True)
            for content, title in tqdm(rows, total=max(table.max_row - 1, 0)):
                if not content or not title:
                    continue
                # Cells may hold numbers or dates; downstream code calls str methods on both fields.
                content, title = str(content), str(title)
                if content.strip() and title.strip():
                    train_data.append({'title': title, 'content': content})
        finally:
            # Release the file even when a sheet cannot be parsed.
            workbook.close()
    return train_data, valid_data

In [None]:
# Read the Weibo JSON file through a context manager so the handle is closed
# as soon as parsing is done.
with open("./dataset/text_summary/news_title/weibo_data.json", encoding='utf-8') as weibo_file:
    weibo = json.load(weibo_file)
# Samples are used in file order (no shuffling).
train_data, _ = collect_from_self_summary('./dataset/text_summary/self_summary/labeled/')
# The first 256 Weibo samples are held out for validation, the next 768 are
# used for training together with the self-labelled samples.
valid_data = weibo[:256]
weibo = weibo[256:1024]
dataset = Dataset(weibo + train_data, valid_data)
print('微博: ', len(weibo), '自己打的: ', len(train_data), len(valid_data))

In [None]:
def collate_fn(data):
    """Tokenize a list of samples into one padded batch on ``device``.

    Args:
        data: List of ``{'title': str, 'content': str}`` samples.  The label
            text is the title without its first four characters, i.e. without
            the '</s>' prefix that ``Dataset`` puts in front of every title.

    Returns:
        Dict with ``input_ids`` / ``attention_mask`` (contents, truncated to
        512 tokens), ``decoder_input_ids`` (titles, no special tokens added)
        and ``labels`` (label text, special tokens added).  Padded label
        positions keep the tokenizer's pad id.
    """
    titles = [sample['title'] for sample in data]
    contents = [sample['content'] for sample in data]
    targets = [title[4:] for title in titles]

    decoder_batch = tokenizer.batch_encode_plus(
        titles,
        return_tensors='pt',
        return_attention_mask=False,
        padding=True,
        add_special_tokens=False,
    )
    encoder_batch = tokenizer.batch_encode_plus(
        contents,
        return_tensors="pt",
        padding=True,
        add_special_tokens=True,
        max_length=512,
        truncation=True,
    )
    label_batch = tokenizer.batch_encode_plus(
        targets,
        return_tensors="pt",
        padding=True,
        return_attention_mask=False,
        add_special_tokens=True,
    )

    return {
        'input_ids': encoder_batch['input_ids'].to(device),
        'attention_mask': encoder_batch['attention_mask'].to(device),
        'decoder_input_ids': decoder_batch['input_ids'].to(device),
        'labels': label_batch['input_ids'].to(device),
    }

# Batches of two samples in dataset order; a trailing incomplete batch is dropped.
# Which split is iterated depends on the `train` flag of `dataset`.
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collate_fn, drop_last=True)

### 测试

In [None]:
def _rouge_text(text):
    """Return ``text`` as space-separated characters with all whitespace removed first."""
    return ' '.join(''.join(text.split()))


def rouge_score(model, dataset, loader):
    """Generate a summary for every validation sample and return mean ROUGE F-scores.

    While this runs, ``dataset.train`` is set to ``False`` so ``loader``
    yields the validation split and the model is put in eval mode; both are
    put back to their previous state afterwards, even if generation fails.

    Args:
        model: Model exposing ``generate`` (uses the notebook-level ``tokenizer``).
        dataset: Object whose ``train`` flag selects the split served by ``loader``.
        loader: Iterable of batches with ``input_ids``, ``attention_mask`` and ``labels``.

    Returns:
        ``{'r-1': float, 'r-2': float, 'r-l': float}`` - mean ROUGE-1/2/L
        F-scores rounded to two decimals.  Pairs with an empty prediction or
        reference count as 0; an empty loader yields all zeros.
    """
    was_training = model.training
    previous_split_flag = dataset.train
    dataset.train = False
    model.eval()

    preds = []
    labels = []
    try:
        # Generate a summary for every sample of the validation split.
        for d in loader:
            for idx in range(len(d['input_ids'])):
                inputs = d['input_ids'][idx][None]
                # The batch is padded: pass the mask so padding is not attended to.
                mask = d['attention_mask'][idx][None]
                generated = model.generate(inputs, attention_mask=mask, max_length=128, do_sample=False)[0]

                label_ids = d['labels'][idx]
                if tokenizer.pad_token_id is not None:
                    # Ignored-label markers (-100) are not valid token ids; decode them as padding.
                    label_ids = label_ids.masked_fill(label_ids < 0, tokenizer.pad_token_id)

                pred = tokenizer.decode(generated, skip_special_tokens=True)
                label = tokenizer.decode(label_ids, skip_special_tokens=True)
                preds.append(pred.replace('</s>', '').replace('<pad>', '').strip())
                labels.append(label.replace('</s>', '').replace('<pad>', '').strip())
    finally:
        model.train(was_training)
        dataset.train = previous_split_flag

    res = {'r-1': 0, 'r-2': 0, 'r-l': 0}
    if not preds:
        return res

    # The rouge package tokenizes on whitespace, so unsegmented Chinese text
    # would be scored as one giant token; score at character level instead.
    pairs = [(_rouge_text(p), _rouge_text(l)) for p, l in zip(preds, labels)]
    # rouge raises on empty strings; such pairs simply contribute a score of 0.
    scorable = [(p, l) for p, l in pairs if p and l]
    if scorable:
        hyps = [p for p, _ in scorable]
        refs = [l for _, l in scorable]
        for item in Rouge().get_scores(hyps, refs):
            res['r-1'] += item['rouge-1']['f']
            res['r-2'] += item['rouge-2']['f']
            res['r-l'] += item['rouge-l']['f']

    for key in res:
        res[key] = round(res[key] / len(pairs), 2)
    return res

### 训练

In [17]:
def train(model, epoches, lr):
    """Fine-tune ``model`` on the notebook-level ``loader`` with AdamW.

    Uses the notebook globals ``loader`` (batches), ``device`` (where the
    model is trained) and ``tokenizer`` (to find the padding id).

    Args:
        model: Model whose forward pass returns a mapping with a ``'loss'`` entry.
        epoches: Number of passes over ``loader``.
        lr: AdamW learning rate.

    Returns:
        Tensor of shape ``(epoches, len(loader))`` with the loss of every step.
        The model is moved back to the CPU before returning.
    """
    lens = len(loader)
    model = model.to(device)
    # from_pretrained() hands the model back in eval mode; switch to train
    # mode so dropout is active during fine-tuning.
    model.train()
    optim = torch.optim.AdamW(model.parameters(), lr=lr)
    losses = torch.zeros((epoches, lens))

    # Padded label positions must not contribute to the loss: the model
    # ignores label id -100.  If padding and EOS share one id, masking would
    # also hide the real EOS target, so it is skipped in that case.
    pad_id = tokenizer.pad_token_id
    mask_padding = pad_id is not None and pad_id != tokenizer.eos_token_id

    for i in range(epoches):
        with tqdm(total=lens, ncols=150) as bar:
            bar.set_description('训练进度-epoch: {}/{}'.format(i+1,epoches))
            for n, d in enumerate(loader):
                batch = dict(d)
                if mask_padding and 'labels' in batch:
                    batch['labels'] = batch['labels'].masked_fill(batch['labels'] == pad_id, -100)

                loss = model(**batch)['loss']
                loss.backward()
                optim.step()
                optim.zero_grad()

                losses[i, n] = loss.item()
                bar.update(1)
            # Show the mean loss of the finished epoch next to the progress bar.
            bar.set_postfix(loss = '{:.4f}'.format(losses[i].mean().item()))
    model.cpu()
    return losses

# Create the output directory before the long training run, so the save
# below cannot fail on a missing folder once training has finished.
os.makedirs('./models', exist_ok=True)
losses = train(model, 2, 2e-5)
# train() has already moved the model back to the CPU, so the saved object is device-independent.
torch.save(model, './models/text_summary2.model')

### 使用

In [None]:
def generate(mdoel, text, max_new_tokens):
    """Greedy-decode a summary for ``text`` one token at a time.

    Args:
        mdoel: Seq2seq model whose forward pass returns a mapping with
            ``'logits'``.  (The misspelt parameter name is kept so existing
            keyword callers continue to work.)
        text: Source text to summarize.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded decoder sequence, including the leading '</s>' start
        token and, if it was produced, the closing end-of-sequence token.
    """
    # Run on whatever device the given model lives on.
    try:
        model_device = next(mdoel.parameters()).device
    except StopIteration:
        model_device = torch.device('cpu')

    # Encode the source the same way the training batches are encoded
    # (special tokens added, truncated to 512 tokens).
    encoded = tokenizer.encode_plus(text, return_tensors='pt', add_special_tokens=True,
                                    max_length=512, truncation=True)
    inputs = {key: value.to(model_device) for key, value in encoded.items()}
    start = tokenizer.encode_plus('</s>', return_tensors='pt', add_special_tokens=False)
    decoder_ids = start['input_ids'].to(model_device)
    eos_id = getattr(tokenizer, 'eos_token_id', None)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = mdoel(**inputs, decoder_input_ids=decoder_ids)['logits']
            # Only the last time step matters: (B, T, C) -> (B, C).  Softmax is
            # monotonic, so the arg-max of the logits is the greedy choice.
            idx_next = logits[:, -1, :].argmax(dim=1, keepdim=True)  # (B, 1)
            # Append the chosen token to the running sequence.
            decoder_ids = torch.cat((decoder_ids, idx_next), dim=1)
            # Stop once the model closes the sequence instead of decoding past it.
            if eos_id is not None and idx_next.item() == eos_id:
                break

    return tokenizer.decode(decoder_ids[0])

In [None]:
# Load the fine-tuned model written by the training cell
# ('./models/text_summary2.model'), so a fresh Restart & Run All evaluates the
# model this notebook just produced.
# NOTE: torch.load unpickles the whole model object - only load files you trust.
model = torch.load('./models/text_summary2.model')

In [None]:
model.eval()
text = '一加2要来了,搭载了满血版的骁龙8处理器,跑分高达'
inputs = tokenizer.encode_plus(text, return_tensors='pt')

# Hand-written greedy decoding loop, run with the model on the CPU.
model.cpu()
manual_summary = generate(model, text, max_new_tokens=15)
print('手动生成: ', manual_summary)

# The library's own generate() on the selected device, for comparison.
model = model.to(device)
api_token_ids = model.generate(inputs['input_ids'].to(device), max_length=128, do_sample=False)[0]
api_summary = tokenizer.decode(api_token_ids).replace('</s>', '').strip()
print('API生成: ', api_summary)

In [None]:
# Display the token ids the tokenizer produced for the demo text above.
inputs['input_ids']