In [1]:
import json
import torch
import random
from tqdm import tqdm
from transformers import BartForConditionalGeneration, AutoTokenizer

In [2]:
device = 'cuda' if(torch.cuda.is_available()) else 'cpu'
device

'cuda'

### 加载模型

In [3]:
model=BartForConditionalGeneration.from_pretrained('IDEA-CCNL/Randeng-BART-139M-SUMMARY')
tokenizer=AutoTokenizer.from_pretrained('IDEA-CCNL/Randeng-BART-139M-SUMMARY')

### 数据集

In [4]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data):
        """输入data格式:
            [{'title': '</s>....', 'content': '......'}, {}, ..., {}]
        """
        random.shuffle(data)
        for i in tqdm(range(len(data))):
            data[i]['title'] = '</s>' + data[i]['title'].replace('<s>', '').replace('</s>', '').strip()
            data[i]['content'] = data[i]['content'].strip()
        
        data = data[:2000]
        train_len = int(0.7 * len(data))
        self.train_data = data[:train_len]
        self.test_data = data[train_len:]
        
        self.train = True
    
    def __len__(self):
        if(self.train):
            return len(self.train_data)
        else:
            return len(self.test_data)
    
    def __getitem__(self, idx):
        if(self.train):
            return self.train_data[idx]
        else:
            return self.test_data[idx]
        
weibo = json.load(open("./dataset/text_summary/news_title/weibo_data.json"))
dataset = Dataset(weibo)

100%|██████████| 450295/450295 [00:00<00:00, 667048.85it/s]


In [5]:
def collate_fn(data):
    titles = [i['title'] for i in data]
    contents = [i['content'] for i in data]
    labels = [i['title'][4:] for i in data]
    
    decoder_input_ids = tokenizer.batch_encode_plus(titles, return_tensors='pt', return_attention_mask=False, padding=True, add_special_tokens=False)
    contents = tokenizer.batch_encode_plus(contents, return_tensors="pt", padding=True, add_special_tokens=True)
    labels = tokenizer.batch_encode_plus(labels, return_tensors="pt", padding=True, return_attention_mask=False, add_special_tokens=True)
    
    data = {}
    data['input_ids'] = contents['input_ids'].to(device)
    data['attention_mask'] = contents['attention_mask'].to(device)
    data['decoder_input_ids'] = decoder_input_ids['input_ids'].to(device)
    data['labels'] = labels['input_ids'].to(device)

    return data

loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=4,
                                     collate_fn=collate_fn,
                                     drop_last=True)

### 训练

In [6]:
def train(model, epoches, lr):
    lens = len(loader)
    model = model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=lr)
    losses = torch.zeros((epoches, lens))
    for i in range(epoches):
        with tqdm(total=lens, ncols=80) as bar:
            bar.set_description('训练进度-epoch: {}/{}'.format(i+1,epoches))
            for n,d in enumerate(loader):
                loss = model(**d)['loss']
                loss.backward()
                optim.step()
                optim.zero_grad()
                
                losses[i,n] += loss.item()
                bar.update(1)
                
            bar.set_postfix(loss = '{:.4f}'.format(losses[i].mean().item()))
    
    model.cpu()
    return losses

losses = train(model, 5, 2e-5)
torch.save(model, './models/text_summary.model')

训练进度-epoch: 1/5: 100%|███████| 350/350 [01:28<00:00,  3.94it/s, loss=1.7687]t/s]
训练进度-epoch: 2/5: 100%|███████| 350/350 [01:18<00:00,  4.47it/s, loss=1.0384]t/s]
训练进度-epoch: 3/5: 100%|███████| 350/350 [01:17<00:00,  4.50it/s, loss=0.5588]t/s]
训练进度-epoch: 4/5: 100%|███████| 350/350 [01:17<00:00,  4.52it/s, loss=0.2603]t/s]
训练进度-epoch: 5/5: 100%|███████| 350/350 [01:16<00:00,  4.57it/s, loss=0.1080]t/s]


### 使用

In [7]:
def generate(mdoel, text, max_new_tokens):
    inputs = tokenizer.encode_plus(text, return_tensors='pt', add_special_tokens=False)
    decoder_input_ids = tokenizer.encode_plus('</s>', return_tensors='pt', add_special_tokens=False)
    for _ in range(max_new_tokens):
        logits = model(**inputs, decoder_input_ids=decoder_input_ids['input_ids'])['logits']
        # focus only the last time step
        logits = logits[:, -1, :]   # becomes (B, C)
        # apply softmax to get probabilities
        probs = torch.nn.functional.softmax(logits, dim=1)
        # sample from the distribution
        # idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
        idx_next = probs.argmax(dim=1, keepdim=True)  # (B, 1)
        # append sampled index to the running squence
        decoder_input_ids['input_ids'] = torch.cat((decoder_input_ids['input_ids'], idx_next), dim=1)
    
    text = tokenizer.decode(decoder_input_ids['input_ids'][0])
    return text

In [8]:
# 导入模型
model = torch.load('./models/text_summary.model')

In [10]:
model.eval()
text = '几百架歼7加速淘汰，今年可能全部退役随着航空工业实力的增强，中国空军也进入了快速换装期，有消息称，在年内，所有歼7战斗机就将退出现役。日前，有国内军事专家在央视节目中对外透露，在解放军所剩不多的歼7战斗机将在年内全部退出空军作战序列。如果真是如此的话，就代表着中国空军向着全面三代化，迈出了十分重要的一步。我们必须清楚地是，解放军加速退役这几百架歼7战机，并不是因为这些战机都已经到了服役年限，机体老化到无法继续执行任务，而是在中国航空工业实力提升，新锐战机产能达到一定水平后，为了强化部队战斗力作出的决定。要知道，虽然歼7在中国空军中的服役时间已经超过了50年，但这款飞机真正停产的时间还不到10年。中国空军现役的歼7基本也都是后期换装的型号，如果继续使用，还 能服役一段时间。此外，虽然是一款二代机，但歼7战机并不是一无是处，这位老将也有着自己的优点。' 
inputs = tokenizer.encode_plus(text, return_tensors='pt')
model.cpu()
print('手动生成: ', generate(model, text, max_new_tokens=15))
model = model.to(device)
print('API生成: ', tokenizer.decode(model.generate(inputs['input_ids'].to(device), max_length=128, do_sample=False)[0]).replace('</s>', '').strip())

手动生成:  </s> 歼7战机年内全部退出空军作战序列</s><pad><pad><pad><pad>
API生成:  歼7战机年内全部退出空军作战序列
