In [None]:
!wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz && tar -zxvf cnn_dm_v2.tgz && rm cnn_dm_v2.tgz

In [None]:
from transformers import BartForConditionalGeneration, BartTokenizer, TrainingArguments, Trainer
import datasets
import pandas as pd
import matplotlib.pyplot as plt
from nltk.corpus import stopwords
import tqdm
from typing import List
import os
import torch

# Some global variables
# Split files of the CNN/DailyMail data; presumably one example per line, with each
# `.source` file line-aligned to its `.target` file (they are zipped together below).
data_dir = 'cnn_cln'
train_source = os.path.join(data_dir, 'train.source')
train_target = os.path.join(data_dir, 'train.target')
valid_source = os.path.join(data_dir, 'val.source')
valid_target = os.path.join(data_dir, 'val.target')
test_source = os.path.join(data_dir, 'test.source')
test_target = os.path.join(data_dir, 'test.target')
# Directory the Hugging Face DatasetDict is saved to / loaded from.
dataset_dir = 'cnn_summary'

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

try:
    stop_words = set(stopwords.words('english'))
except LookupError:
    # The NLTK stop-word corpus is not installed by default; fetch it once and retry
    # so the notebook also runs on a fresh machine.
    import nltk
    nltk.download('stopwords')
    stop_words = set(stopwords.words('english'))

In [None]:
# Check the length of the training articles: whitespace-separated words per line of
# the source file (each line presumably holds one whole article, not one sentence).
with open(train_source) as f_in:
    lens = [len(sent.split()) for sent in f_in]

# Plot after the file is closed; label the figure so it stands alone when skimmed.
fig, ax = plt.subplots()
ax.hist(lens, bins=50)
ax.set(title='Train source length', xlabel='words per line', ylabel='number of lines')
plt.show()

In [None]:
def remove_stopwords(sents: List[str]):
    """Return a copy of `sents` with English stop words removed from each sentence.

    Each sentence is split on whitespace. A word is dropped when its lower-cased
    form is in the module-level `stop_words` set. The surviving words are re-joined
    with single spaces. A tqdm progress bar is shown while iterating.
    """
    cleaned = []
    for sentence in tqdm.tqdm(sents):
        kept_words = (word for word in sentence.split() if word.lower() not in stop_words)
        cleaned.append(' '.join(kept_words))
    return cleaned

# Same length distribution as above, but after removing stop words, to see how much
# shorter the articles become.
with open(train_source) as f_in:
    raw_sents = f_in.readlines()

# Process and plot after the file is closed; keep the raw lines under their own name.
sents = remove_stopwords(raw_sents)
lens = [len(sent.split()) for sent in sents]

fig, ax = plt.subplots()
ax.hist(lens, bins=50)
ax.set(title='Train source length without stop words', xlabel='words per line', ylabel='number of lines')
plt.show()

In [None]:
# Length distribution of the reference summaries. Summaries with at least
# `max_summary_words` words are left out so a few outliers do not squash the
# histogram; how many were hidden is reported in the title.
max_summary_words = 200
with open(train_target) as f_in:
    all_lens = [len(sent.split()) for sent in f_in]
lens = [n for n in all_lens if n < max_summary_words]

fig, ax = plt.subplots()
ax.hist(lens, bins=50)
ax.set(
    title=f'Train summary length ({len(all_lens) - len(lens)} with >= {max_summary_words} words hidden)',
    xlabel='words per line',
    ylabel='number of lines',
)
plt.show()

In [None]:
# [Build] huggingface dataset
# Only a small prefix of each split is used so the notebook runs quickly; raise these
# limits (or set them to None to use every line) for a full training run.
n_train, n_valid, n_test = 1000, 100, 100


def _load_split(source_path, target_path, limit):
    """Read the first `limit` aligned (source, summary) line pairs into a `datasets.Dataset`.

    Both files are opened in context managers so they are closed afterwards. Only
    `limit` lines are read (`limit=None` reads everything). The trailing newline of
    each line is stripped so it is not tokenized into the model inputs or labels.
    """
    from itertools import islice
    with open(source_path, encoding='utf-8') as f_src, open(target_path, encoding='utf-8') as f_tgt:
        sources = [line.rstrip('\n') for line in islice(f_src, limit)]
        summaries = [line.rstrip('\n') for line in islice(f_tgt, limit)]
    # Source and target files must be line-aligned; a length mismatch means a broken pair.
    assert len(sources) == len(summaries), f'{source_path} and {target_path} are not aligned'
    return datasets.Dataset.from_pandas(pd.DataFrame({'source': sources, 'summary': summaries}))


train_df = _load_split(train_source, train_target, n_train)
valid_df = _load_split(valid_source, valid_target, n_valid)
test_df = _load_split(test_source, test_target, n_test)

ds = datasets.DatasetDict()
ds['train'] = train_df
ds['valid'] = valid_df
ds['test'] = test_df
ds.save_to_disk(dataset_dir)

In [None]:
# [Load] the DatasetDict (train / valid / test splits) written by the build cell above.
ds = datasets.load_from_disk(dataset_dir)

In [None]:
# Define tokenization function

def tokenize_function(examples, max_source_length=750, max_target_length=150):
    """Tokenize a batch of examples for seq2seq fine-tuning.

    Args:
        examples: a batch from `datasets.Dataset.map(batched=True)` with a 'source'
            column (article text) and a 'summary' column (reference summary).
        max_source_length: inputs are padded/truncated to exactly this many tokens.
        max_target_length: labels are padded/truncated to exactly this many tokens.

    Returns:
        The tokenizer output for the sources (input_ids, attention_mask), plus a
        'labels' entry holding the summary token ids. Padding positions are set to
        -100, the index the model's cross-entropy loss ignores, so the model is not
        trained to predict padding.
    """
    ret = tokenizer(examples['source'], padding='max_length', max_length=max_source_length, truncation=True)
    # NOTE(review): `as_target_tokenizer` is deprecated in recent transformers in favour of
    # `tokenizer(text_target=...)`; kept for compatibility with older versions.
    with tokenizer.as_target_tokenizer():
        label_ids = tokenizer(examples['summary'], padding='max_length', max_length=max_target_length, truncation=True)['input_ids']
    pad_id = tokenizer.pad_token_id
    ret['labels'] = [[-100 if tok == pad_id else tok for tok in seq] for seq in label_ids]
    return ret

# Tokenize every split. Padding is fixed-length, so the result does not depend on the
# batch size; use the `datasets` default batch size (1000) instead of 2 to avoid
# thousands of tiny tokenizer calls.
tokenized_datasets = ds.map(tokenize_function, batched=True)

In [None]:
# Pretrained BART-base with its LM head; this is the starting point for fine-tuning.
model = BartForConditionalGeneration.from_pretrained(pretrained_model_name_or_path='facebook/bart-base')

In [None]:
# Training checkpoints go to their own directory rather than `dataset_dir`, so they are
# not mixed into the saved dataset. All other settings are the TrainingArguments defaults.
model_output_dir = 'bart-base-cnn_summary'
train_arg = TrainingArguments(model_output_dir)

In [None]:
# Trainer over the tokenized train split. The validation split is attached as the eval
# set. NOTE(review): with default TrainingArguments no evaluation is expected to run
# during training; it is used by `trainer.evaluate()`. Confirm if periodic eval is wanted.
trainer = Trainer(
    model=model,
    args=train_arg,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['valid'],
)

In [None]:
# Fine-tune the model with the settings in `train_arg`; the cell displays the value
# returned by `trainer.train()`.
trainer.train()

In [None]:
# Save the fine-tuned model to the Trainer's output directory instead of the dataset
# directory. No tokenizer was passed to the Trainer, so `save_model` would not write
# tokenizer files; save them explicitly so the directory can be reloaded with
# `from_pretrained`.
trainer.save_model(train_arg.output_dir)
tokenizer.save_pretrained(train_arg.output_dir)

In [None]:
# Debug check: show which torch device the training arguments resolved to.
train_arg.device

In [None]:
# Show which GPUs are visible to this process (None when the variable is unset).
# Read it from Python rather than `!echo ${...}`: IPython evaluates `{...}` in shell
# commands as a Python expression, so the shell expansion only worked via a fallback.
os.environ.get('CUDA_VISIBLE_DEVICES')