# 引入依赖库

In [None]:
!pip install transformers
!pip install transformers sentencepiece
!pip install d2l
!pip instal matplotlib==3.0.0
!pip install matplotlib_inline
!pip install rouge

import json
import re

import pandas as pd
import torch
import torch.nn as nn
from d2l import torch as d2l
from rouge import Rouge
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [None]:
# Load the tokenizer and the seq2seq model from the same checkpoint id.
_CHECKPOINT = 'sshleifer/distilbart-xsum-12-6'
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
bart = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT)

In [None]:
# Hyperparameter settings

# Optimisation / data loading
num_epochs = 4
batch_size = 128
lr = 0.001

# Settings forwarded to the GRU head of the model
hidden_size = 512
n_layers = 2
bidirectional = True
dropout = 0.15

# Devices reported by the d2l helper (presumably a list of torch devices).
devices = d2l.try_all_gpus()


# 数据预处理 数据迭代器装载

In [None]:
# Locations of the tab-separated dataset splits.
train_path = "/content/news_summary/train_dataset.csv"
eval_path = "/content/news_summary/eval_dataset.csv"
test_path = "/content/news_summary/test_dataset.csv"

def read_data(root, is_test):
    """Load a tab-separated dataset file into a DataFrame.

    Every non-empty line is one record whose fields are separated by tabs:
    ``Index, Text`` when ``is_test`` is true, ``Index, Text, Summary``
    otherwise. Line terminators are stripped so the last field does not
    carry a trailing newline.

    Args:
        root: Path of the file to read (UTF-8, an optional BOM is skipped).
        is_test: True when the file has no ``Summary`` column.

    Returns:
        A DataFrame with one row per record, indexed 0..n-1; all values
        are strings.

    Raises:
        ValueError: If a line does not contain the expected number of fields.
    """
    columns = ["Index", "Text"] if is_test else ["Index", "Text", "Summary"]

    rows = []
    with open(root, 'r', encoding='utf-8-sig') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.rstrip("\r\n")
            if not line:
                # Blank lines (e.g. a trailing newline at EOF) carry no record.
                continue
            fields = line.split("\t")
            if len(fields) != len(columns):
                raise ValueError(
                    f"{root}, line {line_no}: expected {len(columns)} "
                    f"tab-separated fields, got {len(fields)}"
                )
            rows.append(fields)

    # Build the frame once instead of growing it row by row with df.loc.
    return pd.DataFrame(rows, columns=columns)

# Load the three splits; only the test split lacks the "Summary" column.
train = read_data(train_path, False)
# NOTE: `eval` shadows the builtin; the name is kept because later cells
# refer to it.
eval = read_data(eval_path, False)
test = read_data(test_path, True)

In [None]:
# Compiled once; a raw string avoids the invalid-escape warning for "\s".
_WHITESPACE_RE = re.compile(r'\s+')


def WHITESPACE_HANDLER(k):
    """Return *k* stripped, with every whitespace run (newlines included)
    collapsed to a single space."""
    return _WHITESPACE_RE.sub(' ', k.strip())


class News_summary(Dataset):
    """Dataset that cleans and tokenizes news articles (and their summaries).

    Args:
        dataset: Mapping/DataFrame with a "Text" column and, unless
            ``mode == 'test'``, a "Summary" column.
        mode: ``'test'`` yields only article ids; any other value yields
            ``(article ids, summary ids)`` pairs.

    Items:
        * test mode: the tokenizer's ``input_ids`` tensor for the single
          article, unpadded (batch dimension of 1 kept).
        * other modes: two 1-D id tensors padded/truncated to a fixed
          length, so the default DataLoader collation can stack them.

    The stored texts are never modified, so an item can be fetched any
    number of times.
    """

    # Upper bounds (in tokens) requested from the tokenizer.
    TEXT_MAX_LENGTH = 1768
    SUMMARY_MAX_LENGTH = 512

    def __init__(self, dataset, mode):
        self.mode = mode
        self.text = dataset["Text"]
        if mode != 'test':
            self.summary = dataset["Summary"]

    @staticmethod
    def _clean(text):
        """Drop the 'updated :' / CNN dateline prefixes and normalise spaces."""
        text = text.split('updated :')[-1].strip()
        text = text.split('-lrb- cnn -rrb- --')[-1].strip()
        return WHITESPACE_HANDLER(text)

    @staticmethod
    def _encode(text, max_length, padding):
        """Tokenize one string and return its ``input_ids`` tensor."""
        # NOTE(review): BART checkpoints presumably accept fewer positions
        # than 1768; never ask for more than the tokenizer's declared limit.
        max_length = min(max_length, tokenizer.model_max_length)
        return tokenizer(text, max_length=max_length,
                         return_tensors="pt",
                         padding=padding, truncation=True)["input_ids"]

    def __getitem__(self, idx):
        text = self._clean(self.text[idx])

        if self.mode == 'test':
            return self._encode(text, self.TEXT_MAX_LENGTH, True)

        # Fixed-length 1-D tensors: variable-length items cannot be stacked
        # into a batch by the default collate function.
        text_ids = self._encode(
            text, self.TEXT_MAX_LENGTH, "max_length").squeeze(0)
        summary_ids = self._encode(
            self.summary[idx], self.SUMMARY_MAX_LENGTH, "max_length").squeeze(0)
        return text_ids, summary_ids

    def __len__(self):
        return len(self.text)


In [None]:
# Wrap each split in a News_summary dataset.
test_set = News_summary(test, mode='test')
eval_set = News_summary(eval, mode='eval')
train_set = News_summary(train, mode='train')

# Both loaders share batch size and worker count; only training shuffles.
_loader_args = dict(batch_size=batch_size, num_workers=4)
train_loader = DataLoader(train_set, shuffle=True, **_loader_args)
eval_loader = DataLoader(eval_set, shuffle=False, **_loader_args)

# 模型

In [None]:
class Distilbart_gru(nn.Module):
    """GRU + linear head on top of a frozen seq2seq encoder.

    Args:
        bart: Pretrained model exposing ``config.hidden_size``,
            ``parameters()`` and ``get_encoder()`` (e.g. a Hugging Face
            BART model). It is used as a fixed feature extractor.
        hidden_size: Hidden size of the GRU.
        vocab_size: Number of output classes of the final linear layer.
        n_layers: Number of stacked GRU layers.
        bidirectional: Whether the GRU is bidirectional.
        dropout: Dropout probability (between GRU layers and before the
            linear layer).

    NOTE(review): ``forward`` produces ONE logit vector per input sequence
    (shape ``(batch, vocab_size)``), not one per output token — confirm this
    is what the training/evaluation code expects.
    """

    def __init__(self, bart, hidden_size, vocab_size, n_layers,
                 bidirectional, dropout, **kwargs):
        super().__init__(**kwargs)

        self.bart = bart
        # The backbone only runs under no_grad in forward(); mark its weights
        # as frozen so optimizers and gradient clipping skip them.
        for param in self.bart.parameters():
            param.requires_grad_(False)

        # Attribute access (not to_dict()) so config aliases are honoured.
        embedding_size = bart.config.hidden_size

        self.gru = nn.GRU(embedding_size, hidden_size, num_layers=n_layers,
                          bidirectional=bidirectional,
                          # encoder output is (batch, seq, dim)
                          batch_first=True,
                          # inter-layer dropout is meaningless for one layer
                          dropout=dropout if n_layers > 1 else 0.0)
        self.Linear = nn.Linear(
            hidden_size * 2 if bidirectional else hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map token ids ``x`` (assumed ``(batch, seq_len)``) to logits of
        shape ``(batch, vocab_size)``."""
        with torch.no_grad():
            # Use the encoder's hidden states: index 0 of the full seq2seq
            # model output would be LM logits, whose last dimension does not
            # match the GRU input size.
            embed = self.bart.get_encoder()(input_ids=x)[0]

        # hidden: (num_layers * num_directions, batch, hidden_size)
        _, hidden = self.gru(embed)

        if self.gru.bidirectional:
            # Last layer's forward and backward final states, side by side.
            hidden = self.dropout(
                torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        else:
            hidden = self.dropout(hidden[-1, :, :])

        output = self.Linear(hidden)
        return output

# vocab

In [None]:
# Build a d2l vocabulary from the checkpoint's vocab.json.
with open('/content/news_summary/distilbart-xsum-12-6/vocab.json',
          encoding='utf-8') as _vocab_file:
    _loaded_vocab = json.load(_vocab_file)

vocab = d2l.Vocab()
if isinstance(_loaded_vocab, dict):
    # A token -> id mapping: order the tokens by id so that
    # idx_to_token[i] really is the token with id i.
    vocab.idx_to_token = sorted(_loaded_vocab, key=_loaded_vocab.get)
else:
    # Already an id-ordered sequence of tokens.
    vocab.idx_to_token = list(_loaded_vocab)
vocab.token_to_idx = {token: idx for idx, token in
                      enumerate(vocab.idx_to_token)}

# 定义LOSS和Rouge

In [None]:
# Keyword arguments keep every value attached to the right parameter
# (the constructor also expects n_layers).
model = Distilbart_gru(bart, hidden_size=hidden_size, vocab_size=len(vocab),
                       n_layers=n_layers, bidirectional=bidirectional,
                       dropout=dropout)

# Unreduced loss: one value per element; callers sum it themselves.
loss = nn.CrossEntropyLoss(reduction="none")
# NOTE(review): train_ builds its own optimizer, so `trainer` is unused there.
trainer = torch.optim.Adam(model.parameters(), lr=lr)


def print_rouge_L(output, label):
    """Print the mean ROUGE-L F1, precision and recall over all pairs.

    Args:
        output: Hypothesis string, or an iterable of hypothesis strings.
        label: Reference string, or an iterable of reference strings aligned
            with ``output``.

    Both arguments are converted to plain lists first, so any iterable
    (e.g. a pandas Series) can be passed. Nothing is scored when there are
    no samples.
    """
    if isinstance(output, str):
        hyps, refs = [output], [label]
    else:
        hyps, refs = list(output), list(label)

    if not hyps:
        # Avoid dividing by zero below.
        print("rouge: no samples to score")
        return

    rouge = Rouge()
    rouge_score = rouge.get_scores(hyps, refs)

    count = len(rouge_score)
    rouge_L_f1 = sum(d["rouge-l"]["f"] for d in rouge_score)
    rouge_L_p = sum(d["rouge-l"]["p"] for d in rouge_score)
    rouge_L_r = sum(d["rouge-l"]["r"] for d in rouge_score)
    print("rouge_f1:%.2f" % (rouge_L_f1 / count))
    print("rouge_p:%.2f" % (rouge_L_p / count))
    print("rouge_r:%.2f" % (rouge_L_r / count))


# train ＆ eval

In [None]:
def train_(net, data_iter, lr, num_epochs, device):
    """Train ``net`` with Adam and plot the mean loss per epoch.

    Args:
        net: Model to train. If it has a ``bart`` sub-module, that sub-module
            is treated as a pretrained backbone: it is not re-initialised
            and is kept in eval mode.
        data_iter: Iterable of ``(X, Y)`` tensor batches.
        lr: Learning rate.
        num_epochs: Number of passes over ``data_iter``.
        device: A torch device, or a list/tuple of devices (as returned by
            ``d2l.try_all_gpus()``), of which the first is used.

    Uses the module-level ``loss`` (expected to return unreduced values).

    NOTE(review): ``loss(Y_hat, Y)`` requires shapes the loss accepts; with
    a model that emits one logit vector per sequence and targets that are
    token sequences this needs confirming.
    """
    if isinstance(device, (list, tuple)):
        device = device[0]

    # Modules of the pretrained backbone must keep their loaded weights.
    backbone = getattr(net, 'bart', None)
    if not isinstance(backbone, nn.Module):
        backbone = None
    pretrained = (
        {id(m) for m in backbone.modules()} if backbone is not None else set()
    )

    def xavier_init_weights(m):
        if id(m) in pretrained:
            return
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
        if isinstance(m, nn.GRU):
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])
    net.apply(xavier_init_weights)
    net.to(device)

    optimizer = torch.optim.Adam(
        [p for p in net.parameters() if p.requires_grad], lr=lr)
    net.train()
    if backbone is not None:
        # The backbone is a fixed feature extractor: no dropout inside it.
        backbone.eval()
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs])

    metric = timer = None
    for epoch in range(num_epochs):
        timer = d2l.Timer()
        metric = d2l.Accumulator(2)  # Sum of training loss, no. of loss terms
        for batch in data_iter:
            optimizer.zero_grad()
            X, Y = [x.to(device) for x in batch]
            Y_hat = net(X)
            l = loss(Y_hat, Y)
            l.sum().backward()  # Make the loss scalar for `backward`
            # Global-norm clipping; parameters without gradients are skipped.
            nn.utils.clip_grad_norm_(net.parameters(), 1)
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), l.numel())
        if metric[1]:
            animator.add(epoch + 1, (metric[0] / metric[1],))

    if metric is not None and metric[1]:
        print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '
              f'tokens/sec on {str(device)}')
    

def eval_(net, eval_iter):
    """Run ``net`` over ``eval_iter`` and print ROUGE-L against the
    reference summaries in the module-level ``eval`` frame.

    Args:
        net: Trained model; inputs are moved to the device of its
            parameters.
        eval_iter: Iterable of batches whose first element is the input
            tensor. It must not be shuffled, so predictions line up with
            ``eval["Summary"]``.

    NOTE(review): predictions are joined vocabulary tokens (sub-word
    pieces), while the references are raw text — confirm this is the
    intended comparison.
    """
    net.eval()
    device = next(net.parameters()).device
    summary = []
    # no_grad replaces the attempt to set requires_grad on the parameter
    # generator, which is not possible.
    with torch.no_grad():
        for batch in eval_iter:
            X = batch[0].to(device)
            Y_hat = net(X)
            # Last dimension holds the vocabulary scores.
            pred_ids = Y_hat.argmax(dim=-1)
            # One row of token ids per sample, whatever the output rank.
            for row in pred_ids.reshape(pred_ids.shape[0], -1).tolist():
                summary.append(' '.join(vocab.idx_to_token[i] for i in row))

    print_rouge_L(summary, list(eval["Summary"]))

In [None]:
# `devices` is the list of available devices; train on the first one.
train_(model, train_loader, lr, num_epochs, devices[0])

In [None]:
# Score the trained model on the evaluation split.
eval_(model, eval_loader)

# Save the model
_weights = model.state_dict()
torch.save(_weights, '/content/news_summary/model_weights.txt')

# 预测

In [None]:
# Write one "<index>\t<predicted summary>" line per test article.
try:
    from tqdm import tqdm
except ImportError:  # the progress bar is optional
    def tqdm(iterable, **kwargs):
        return iterable

_device = devices[0]

# Load the weights from the same file the training cell saved them to.
params = torch.load('/content/news_summary/model_weights.txt',
                    map_location=_device)
model.load_state_dict(params)
model.to(_device)
model.eval()

# Open the output once; `with` closes it even if prediction fails.
with open('/content/news_summary/submission.csv', 'w',
          encoding='utf-8-sig') as fw, torch.no_grad():
    for idx, text in tqdm(enumerate(test_set), total=len(test_set)):
        if text.dim() == 1:
            # The model expects a batch dimension.
            text = text.unsqueeze(0)
        Y_hat = model(text.to(_device))
        # Last dimension holds the vocabulary scores.
        pred_ids = Y_hat.argmax(dim=-1).reshape(-1).tolist()
        summary = ' '.join(vocab.idx_to_token[i] for i in pred_ids)

        fw.write(f'{idx}\t{summary}\n')