<a href="https://colab.research.google.com/github/TIMEdilation584/JP_Loksatta_moving_hearts/blob/master/20-08-22_fastai_huggingface_study_group_part_2_session_5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install datasets -Uqq
! pip install transformers[sentencepiece] -Uqq
! pip install git+https://github.com/ohmeow/blurr.git@dev-2.0.0 -Uqq
! pip install bert_score -Uqq
! import nltk -Uqq

In [None]:
# fastai and blurr are designed around star-imports, so those stay. From
# transformers we only need one class; `from transformers import *` would pull
# in hundreds of names (e.g. its own `set_seed`) that can shadow fastai's.
from blurr.data.all import *
from blurr.modeling.all import *
from datasets import load_dataset
from fastai.data.all import *
from fastai.callback.all import *
from fastai.learner import *
from fastai.optimizer import *
from transformers import BartForConditionalGeneration

import nltk

# Sentence-tokenizer data (presumably used for sentence splitting in the ROUGE
# metrics — confirm). Newer nltk releases look up `punkt_tab` instead of
# `punkt`; downloading both keeps old and new nltk versions working.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)


## 01. Get a dataset

In [None]:
# Number of CNN/DailyMail training examples to pull; kept small so the
# study-group run finishes quickly. Raise it for a real training run.
N_TRAIN_SAMPLES = 1000

dataset = load_dataset("ccdv/cnn_dailymail", "3.0.0", split=f"train[:{N_TRAIN_SAMPLES}]")
# Arrow-backed conversion; avoids materializing a Python list of row dicts.
cnndm_df = dataset.to_pandas()
cnndm_df.head(2)

## 02. Get your Hugging Face objects

In [None]:
# Distilled BART checkpoint already fine-tuned on CNN/DailyMail.
pretrained_model_name = "sshleifer/distilbart-cnn-6-6"

# Unpack architecture name, config, tokenizer and model; request the
# conditional-generation head explicitly.
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(
    pretrained_model_name,
    model_cls=BartForConditionalGeneration,
)

# Sanity check: architecture name plus the concrete classes that were loaded
(hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model))

## 03. Preprocess the raw data (optional)

In [None]:
# Source/target columns of cnndm_df and the token budget for each side.
# The DataBlock below reads "proc_article"/"proc_highlights" from the result,
# so process_df presumably adds those processed-text columns.
preprocessor = SummarizationPreprocessor(
    hf_tokenizer,
    id_attr="id",
    text_attr="article",
    max_input_tok_length=256,
    target_text_attr="highlights",
    max_target_tok_length=130,
    min_summary_char_length=30,
)

proc_df = preprocessor.process_df(cnndm_df)
proc_df.head(2)

## 04. Define our DataBlock

In [None]:
# IPython introspection: show the source of blurr's batch tokenization transform
Seq2SeqBatchTokenizeTransform??

In [None]:
# Inspect the loaded model configuration
hf_config

In [None]:
# Generation settings for the summarization task, derived from the config/model
# (presumably the checkpoint's summarization defaults — confirm in blurr docs).
text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task="summarization")
text_gen_kwargs

In [None]:
# Batch-level tokenization transform; text_gen_kwargs are handed over so
# generation (presumably during metric computation/inference) uses them.
batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
    hf_arch, hf_config, hf_tokenizer, hf_model, text_gen_kwargs=text_gen_kwargs
)

# Fixed seed so the train/valid split (and therefore the metrics) are
# reproducible across runs.
SPLIT_SEED = 42

# The target block is `noop`: the Seq2SeqTextBlock's batch transform presumably
# tokenizes the targets as well — confirm in blurr docs.
blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=batch_tokenize_tfm), noop)
dblock = DataBlock(
    blocks=blocks,
    get_x=ColReader("proc_article"),
    get_y=ColReader("proc_highlights"),
    splitter=RandomSplitter(valid_pct=0.2, seed=SPLIT_SEED),
)

In [None]:
# Build train/valid DataLoaders; batch size 2 is presumably chosen to fit
# the model in limited (e.g. Colab) GPU memory.
dls = dblock.dataloaders(proc_df, bs=2)

In [None]:
# Grab one batch and check shapes: b[0] is a dict of model inputs,
# b[1] presumably the tokenized target ids.
b = dls.one_batch()
len(b), b[0]["input_ids"].shape, b[1].shape

(2, torch.Size([2, 257]), torch.Size([2, 49]))

In [None]:
# Display a couple of decoded examples, truncating long input/target text
# via input_trunc_at / target_trunc_at.
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=500, target_trunc_at=250)

## 05. Define our Learner and metrics

In [None]:
# Metrics for Seq2SeqMetricsCallback. Each entry maps a metric name to the
# kwargs presumably passed when computing it ("compute_kwargs") and the result
# keys to report ("returns").
ROUGE_TYPES = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

summarization_metrics = {
    "rouge": {
        # Single source of truth for the ROUGE variants; copies so the two lists
        # can never be mutated through each other.
        "compute_kwargs": {"rouge_types": list(ROUGE_TYPES), "use_stemmer": True},
        "returns": list(ROUGE_TYPES),
    },
    "bertscore": {"compute_kwargs": {"lang": "en"}, "returns": ["precision", "recall", "f1"]},
}

# Not used in this summarization notebook; kept as a reference configuration
# for translation tasks.
translation_metrics = {"bleu": {"returns": "bleu"}, "meteor": {"returns": "meteor"}, "sacrebleu": {"returns": "score"}}

In [None]:
# Wrap the Hugging Face model for fastai; BaseModelCallback presumably adapts
# the model's outputs to what the training loop expects (see blurr docs).
model = BaseModelWrapper(hf_model)
learn_cbs = [BaseModelCallback]

# calc_every options: 'epoch', 'other_epoch', 'last_epoch'.
# Only compute the (presumably expensive, generation-based) metrics after the
# final epoch. These callbacks are passed to fit, not attached to the Learner.
fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=summarization_metrics, calc_every="last_epoch")]

learn = Learner(
    dls,
    model,
    opt_func=Adam,
    # The model's forward pass returns its own "loss"; this loss_func
    # presumably just passes that value through.
    loss_func=PreCalculatedCrossEntropyLoss(),
    cbs=learn_cbs,
    # Architecture-aware parameter groups, used by freeze()/unfreeze().
    splitter=partial(blurr_seq2seq_splitter, arch=hf_arch),
)

learn = learn.to_fp16()
# Freeze all but the last parameter group for the first round of training.
learn.freeze()

In [None]:
# Sanity-check a single forward pass on one batch. Nothing is back-propagated
# here, so skip building the autograd graph (saves GPU memory before training).
b = dls.one_batch()
with torch.no_grad():
    preds = learn.model(b[0])

In [None]:
# Inspect the raw output of the forward pass above
preds

In [None]:
# The output exposes "loss" and "logits" entries; check their shapes
len(preds), preds["loss"].shape, preds["logits"].shape

In [None]:
# Number of parameter groups produced by blurr_seq2seq_splitter
# (freeze() above leaves only the last one trainable).
len(learn.opt.param_groups)

## 06. Train

In [None]:
# Learning-rate range test, reporting the suggestion from each listed heuristic
learn.lr_find(suggest_funcs=[minimum, steep, valley, slide])

In [None]:
# One epoch of one-cycle training; lr_max presumably picked from the lr_find
# plot above. fit_cbs (the metrics callback) apply to this fit only.
learn.fit_one_cycle(1, lr_max=4e-5, cbs=fit_cbs)

In [None]:
# Show model outputs next to the reference highlights (presumably on the
# validation set), truncating long text.
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)


## 07. Inference

In [None]:
# Prepare the learner for export: clear metrics (presumably so no metric
# objects are pickled), switch back from mixed precision to fp32, then save.
learn.metrics = None
learn = learn.to_fp32()
learn.export(fname="article_summary_export.pkl")


In [None]:
# Pick an article to summarize. NOTE: it comes from the same "train[:1000]"
# slice used to build the train/valid split, so it is not a held-out example.
test_article = cnndm_df.iloc[10].article
test_article

In [None]:
# Reload the exported learner and summarize the article.
# load_learner unpickles the file — only load exports you trust.
inf_learn = load_learner(fname="article_summary_export.pkl")
inf_learn.blurr_summarize(test_article)