In [None]:
# Restrict this process to GPU index 0 by setting CUDA_VISIBLE_DEVICES.
# NOTE(review): this presumably has to run before any CUDA-using library is
# initialised, so keep this cell first — confirm.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import blurr
import fastai

import datasets
import pandas as pd
from fastai.text.all import *
from transformers import *

from blurr.data.all import *
from blurr.modeling.all import *

In [None]:
# Load the 'train' split of the 'long' configuration of `reddit_tifu` and wrap it in a DataFrame.
# NOTE(review): the following cells read the columns tname/qType/col/row/data/stat and
# 'target' from `df`; confirm this dataset actually provides them — if it does not,
# `df` must have come from a different (hidden-state) source.
raw_data = datasets.load_dataset('reddit_tifu', 'long', split='train') 
df = pd.DataFrame(raw_data)
# Preview the first row.
df.head(1)

In [None]:
# Source columns that are flattened, in this order, into the single model-input string `x`.
X_FIELDS = ['tname', 'qType', 'col', 'row', 'data', 'stat']
# Marker appended after every field (including the last one).
X_SEP = " [X_SEP] "

# Build `x` row by row as "<tname> [X_SEP] <qType> [X_SEP] ... <stat> [X_SEP] ".
# Fields are looked up with row[...] rather than row.<name>: attribute access
# resolves to a Series attribute/method instead of the column whenever a column
# name collides with one, and names such as `data`, `row` or `col` are easy
# candidates for such a collision.
df['x'] = df.apply(lambda row: "".join(str(row[field]) + X_SEP for field in X_FIELDS), axis=1)

# The raw fields are now folded into `x`, so drop them.
# NOTE: this cell mutates `df`; re-running it requires re-creating `df` first,
# because the source columns no longer exist after the drop.
df = df.drop(columns=X_FIELDS)

df.head(1)

In [None]:
# Which pretrained seq2seq checkpoint to fine-tune. Must be one of:
# "t5", "bert", "pegasus", "bart", "prophetnet", "blenderbot".
model_choice="prophetnet"

# Notes on the models tried so far:
# - bert works; it comes already trained for summarization and we add on that.
# - bart works (but this might not be RXF); similarly, it seems to be already
#   trained for summarization, so we might just add on that.
# - t5 is apparently working, but it seems to be already trained for
#   summarization, so we might just add on that.
# - pegasus is working, but it seems it is precisely trained for summarization,
#   so we might just add on that.
# - blenderbot seems to be working and results in something like a summarizer
#   with a strong personality, with a lot of extractive behavior.
# - prophetnet seems to be working after some adaptation; overall good results.

# Map the choice to a checkpoint name and the Hugging Face model class used to load it.
# Kept as an if/elif chain so that only the selected class name is evaluated.
if model_choice=="t5":
  pretrained_model_name = "t5-base"
  m_cls= T5ForConditionalGeneration
elif model_choice=="bert":
  pretrained_model_name = "patrickvonplaten/bert2bert_cnn_daily_mail" # this is the only model we have that really is structured as an encoder_decoder in HF
  m_cls=EncoderDecoderModel
elif model_choice=="pegasus":
  pretrained_model_name = "google/pegasus-large" # large pegasus really uses a lot of RAM :(
  m_cls=PegasusForConditionalGeneration
elif model_choice=="bart":
  pretrained_model_name = "facebook/bart-base"
  m_cls=BartForConditionalGeneration
elif model_choice=="prophetnet":
  pretrained_model_name = "microsoft/prophetnet-large-uncased-cnndm"
  m_cls=ProphetNetForConditionalGeneration
elif model_choice=="blenderbot":
  pretrained_model_name = "facebook/blenderbot-90M"
  m_cls=BlenderbotForConditionalGeneration
else:
  # Fail loudly. Without this branch a mistyped choice either raises an unrelated
  # NameError or, when the cell is re-run, silently reuses the checkpoint and
  # class left over from the previous run.
  raise ValueError(
    f"Unknown model_choice {model_choice!r}; expected one of: "
    "t5, bert, pegasus, bart, prophetnet, blenderbot"
  )

hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, model_cls=m_cls)

# Override the detected architecture name for two of the choices.
if model_choice=="blenderbot": # we benefit from the similar code structure in Hugging Face
  hf_arch="bart"
if model_choice=="bert":
  hf_arch="bert_encoder_decoder"

# Display what was loaded.
hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)

In [None]:
# Start from the default generation settings for a summarization task.
text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task='summarization')

# For T5, discard the "prefix" entry if the defaults contain one.
if model_choice == "t5":
  text_gen_kwargs.pop("prefix", None)

# Length bounds for generated text; both constants are reused when the batch
# transform is built in a later cell.
MAX_LENGTH = 300
MIN_LENGTH = 30

# Apply the overrides in one step.
# NOTE(review): temperature usually only has an effect when sampling is
# enabled — confirm it is intended alongside beam search.
text_gen_kwargs.update({
  'max_length': MAX_LENGTH,
  'min_length': MIN_LENGTH,
  'num_beam_groups': 1,
  'num_beams': 4,
  'temperature': 0.6,
})





In [None]:
# Display the final generation settings for inspection.
text_gen_kwargs

In [None]:
# Fixed seed for the train/validation split: with an unseeded RandomSplitter the
# validation set changes on every run, so metrics are not comparable between runs.
SPLIT_SEED = 42

# Batch transform built from the loaded Hugging Face objects, the length bounds
# and the generation settings.
hf_batch_tfm = HF_Seq2SeqBeforeBatchTransform(
    hf_arch, hf_config, hf_tokenizer, hf_model,
    max_length=MAX_LENGTH, min_length=MIN_LENGTH,
    text_gen_kwargs=text_gen_kwargs,
)

# Inputs go through the seq2seq block; targets are passed through untouched (`noop`).
blocks = (HF_Seq2SeqBlock(before_batch_tfm=hf_batch_tfm), noop)

# Inputs are read from column 'x' and targets from column 'target'.
dblock = DataBlock(
    blocks=blocks,
    get_x=ColReader('x'),
    get_y=ColReader('target'),
    splitter=RandomSplitter(seed=SPLIT_SEED),
)

In [None]:
# Build the DataLoaders from the first 1000 rows of `df` only, with batch size 8.
dls = dblock.dataloaders(df[:1000], bs=8)

In [None]:
# Number of items in the training and validation sets.
len(dls.train.items), len(dls.valid.items)

In [None]:
# Pull a single batch and inspect it: number of elements in the batch tuple,
# shape of the input ids, and shape of the targets.
b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[1].shape

In [None]:
# Preview up to two samples from a batch.
dls.show_batch(dataloaders=dls, max_n=2)

In [None]:
import torch

# Metric configuration handed to the seq2seq metrics callback: compute ROUGE-1,
# ROUGE-2 and ROUGE-L (with stemming) and report all three.
seq2seq_metrics = dict(
    rouge=dict(
        compute_kwargs=dict(rouge_types=["rouge1", "rouge2", "rougeL"], use_stemmer=True),
        returns=["rouge1", "rouge2", "rougeL"],
    )
)

# Wrap the Hugging Face model for use inside a fastai Learner.
model = HF_BaseModelWrapper(hf_model)

# Callbacks attached to the Learner itself ...
learn_cbs = [HF_BaseModelCallback]
# ... and callbacks that are only passed to the fit call (metrics + CSV logging).
fit_cbs = [HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics), CSVLogger]



def sum_split(m, arch):
    """Split a summarization model's parameters into optimizer groups.

    Args:
        m: the model, or a wrapper exposing the model as ``m.hf_model``.
        arch: architecture name; only ``'bert_encoder_decoder'`` and
            ``'prophetnet'`` are handled.

    Returns:
        The parameter lists of three module groups, in order
        (embeddings + encoder + decoder, encoder, decoder), with empty
        lists removed.

    Raises:
        ValueError: if ``arch`` is not one of the handled architectures.
    """
    # Unwrap when a wrapper was passed in.
    hf = getattr(m, 'hf_model', m)

    # Pick the three building blocks for the requested architecture.
    if arch == 'bert_encoder_decoder':
        word_embeds = hf.encoder.embeddings.word_embeddings
        encoder = hf.encoder
        decoder = hf.decoder.cls.predictions.decoder
    elif arch == 'prophetnet':
        word_embeds = hf.prophetnet.word_embeddings
        encoder = hf.prophetnet.encoder
        decoder = hf.prophetnet.decoder
    else:
        raise ValueError('Invalid architecture')

    # NOTE(review): the first group wraps the whole encoder and decoder in
    # addition to the word embeddings, so its parameters overlap with the
    # other two groups — confirm this is intended.
    # NOTE(review): no caller of this helper is visible in this notebook.
    first_group = nn.Sequential(word_embeds, encoder, decoder)
    module_groups = L(first_group, encoder, decoder)
    return module_groups.map(params).filter(lambda el: len(el) > 0)

# Assemble the Learner: ranger optimizer, flattened cross-entropy loss, the
# base-model callback, and fp16 training via `to_fp16()`.
# NOTE(review): parameter groups come from `seq2seq_splitter`; the `sum_split`
# helper defined in this cell (which handles 'bert_encoder_decoder' and
# 'prophetnet') is never used — confirm which splitter is intended for the
# selected architecture.
learn = Learner(dls, 
                model,
                opt_func=ranger,
                loss_func=CrossEntropyLossFlat(),
                cbs=learn_cbs,
                splitter=partial(seq2seq_splitter, arch=hf_arch)).to_fp16()

# Create the optimizer explicitly, before freezing.
learn.create_opt() 

# Freeze the model; presumably only the last parameter group stays trainable — TODO confirm.
learn.freeze()

In [None]:
# Run the learning-rate finder and ask it for suggested values.
# NOTE(review): the `suggestions` argument appears tied to older fastai
# releases — confirm it matches the installed version.
learn.lr_find(suggestions=True)

In [None]:
# Show one result before any fitting has been run, as a baseline.
learn.show_results(learner=learn, max_n=1)

In [None]:
# Train with fit_one_cycle: 25 as the first argument, lr_max=3e-3, and the
# fit-time callbacks in `fit_cbs`.
learn.fit_one_cycle(25, lr_max=3e-3,cbs=fit_cbs)

In [None]:
# Plot the losses recorded during training.
learn.recorder.plot_loss()

In [None]:
# Show up to five results after training.
learn.show_results(learner=learn, max_n=5)

In [None]:
# Save the trained learner's state to a hard-coded absolute path.
# NOTE(review): the path is machine-specific and the name says "PN" regardless of
# `model_choice`. Also, fastai's `Learner.save` presumably appends its own '.pth'
# extension, so the '.pkl' suffix may be misleading — confirm the resulting filename.
learn.save("/vol3/bertpro/models/PN_save.pkl")