In [1]:
!pip install ohmeow-blurr -q
!pip install bert-score -q

In [2]:
import pandas as pd
from fastai.text.all import *
from transformers import *
from blurr.data.all import *
from blurr.modeling.all import *

In [3]:
df = pd.read_csv('NYT_Dataset.csv', error_bad_lines=False)
df = df[['abstract', 'title']]
df.columns = ['content', 'title']
df['content'] = df['content'].astype(str)
df['title'] = df['title'].astype(str)


In [4]:
#Clean text
df['content'] = df['content'].apply(lambda x: x.replace('/',''))
df['content'] = df['content'].apply(lambda x: x.replace('\xa0',''))
df.head()

Unnamed: 0,content,title
0,Pakistan’s ambassador to the U.S. said his government would not endorse a separate inquiry modeled after one carried out by the U.N. after the assassination of Rafik Hariri of Lebanon in 2005.,"In Reversal, Pakistan Welcomes Outside Help With Inquiry on Bhutto"
1,"Kenya sank deeper into trouble, with a curfew imposed in Kisumu, the country’s third-largest city, ethnic fighting intensifying and more than 100 people killed in election-related violence.",Fighting Intensifies After Election in Kenya
2,"Prime Minister Ehud Olmert has sent a letter to defense, housing and agriculture ministers, saying that his and the defense minister’s authorization would be required for any new building, planning or land expropriation for Jewish settlements in the West Bank.",Israel: Olmert Curbs Settlements
3,The monthly club night known as Gayhane is an all-too-rare opportunity for gay Muslims to merge their immigrant cultures and their sexual identities.,Gay Muslims Pack a Dance Floor of Their Own
4,"But even as partygoers embraced the New Year, a surge of attacks on Monday served as a potent reminder that 2007 was the bloodiest on record.",Iraqi Revelers Embrace the New Year


In [5]:
from sklearn.model_selection import train_test_split
X_train, y_test = train_test_split(df,  test_size=.2)

In [None]:
#Truncate text to make it fit into the model
#articles['content'] = articles['content'].apply(lambda x: ' '.join(x.split()[:700]))

## Import model and set up data

In [6]:
pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, 
                                                                  model_cls=BartForConditionalGeneration)

In [7]:
hf_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model, task='summarization',
text_gen_kwargs={'max_length': 25,
 'min_length': 0,
 'do_sample': False,
 'early_stopping': True,
 'num_beams': 4,
 'temperature': 1.0,
 'top_k': 50,
 'top_p': 1.0,
 'repetition_penalty': 1.0,
 'bad_words_ids': None,
 'bos_token_id': 0,
 'pad_token_id': 1,
 'eos_token_id': 2,
 'length_penalty': 2.0,
 'no_repeat_ngram_size': 3,
 'encoder_no_repeat_ngram_size': 0,
 'num_return_sequences': 1,
 'decoder_start_token_id': 2,
 'use_cache': True,
 'num_beam_groups': 1,
 'diversity_penalty': 0.0,
 'output_attentions': False,
 'output_hidden_states': False,
 'output_scores': False,
 'return_dict_in_generate': False,
 'forced_bos_token_id': 0,
 'forced_eos_token_id': 2,
 'remove_invalid_values': False})

blocks = (HF_Seq2SeqBlock(before_batch_tfm=hf_batch_tfm), noop)

dblock = DataBlock(blocks=blocks, get_x=ColReader('content'), get_y=ColReader('title'), splitter=RandomSplitter())

In [8]:
dls = dblock.dataloaders(X_train, bs=2)

### Training


In [9]:
seq2seq_metrics = {
        'rouge': {
            'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },
            'returns': ["rouge1", "rouge2", "rougeL"]
        },
        'bertscore': {
            'compute_kwargs': { 'lang': 'fr' },
            'returns': ["precision", "recall", "f1"]
        }
    }

In [10]:
model = HF_BaseModelWrapper(hf_model)
learn_cbs = [HF_BaseModelCallback]
fit_cbs = [HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

learn = Learner(dls, 
                model,
                opt_func=ranger,
                loss_func=CrossEntropyLossFlat(),
                cbs=learn_cbs,
                splitter=partial(seq2seq_splitter, arch=hf_arch)).to_fp16()

learn.create_opt() 
learn.freeze()

In [None]:
learn.fit_one_cycle(3, lr_max=3e-5, cbs=fit_cbs)

epoch,train_loss,valid_loss,rouge1,rouge2,rougeL,bertscore_precision,bertscore_recall,bertscore_f1,time
0,2.191402,2.292068,0.278419,0.102319,0.24686,0.724287,0.722391,0.722798,1:58:07
1,1.928625,2.243559,0.298798,0.112967,0.264995,0.734925,0.737088,0.735593,1:57:46


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe959e65ef0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1301, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.7/multiprocessing/popen_fork.py", line 45, in wait
    if not wait([self.sentinel], timeout):
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 921, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.7/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 


KeyboardInterrupt: ignored

### Generate predictions

In [None]:
df['content'][0]

In [19]:
outputs = learn.blurr_generate(df['content'][0], early_stopping=False, num_return_sequences=1)

for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')

=== Prediction 1 ===
 Pakistan: No to U.S. Inquiry on Hariri’s Killing, No to Inquiry on Pakistan”s Killings’ Victims’ Families’ Deaths’ Killings, No. 2 Confronts U.N. Inquiry’, No 2 Killings in Pakistan, No Inquiry on Killings of Killers in Pakistan. Pakistan: Inquiry Would Be Based on Inquiry on Killing Victims in Lebanon, No Involvement in Killings There, No Enquiries There. No Inquiry Involves Killings Of Killers In Pakistan or Killings In Killings All Over the World, No Intervenes There, But Inquiry Involvering Killings

