In [1]:
from functools import partial

from blurr.data.seq2seq.core import HF_Seq2SeqBeforeBatchTransform, HF_Seq2SeqBlock
from blurr.modeling.core import HF_BaseModelWrapper, HF_BaseModelCallback
from blurr.modeling.seq2seq.core import HF_Seq2SeqMetricsCallback, seq2seq_splitter
from blurr.utils import BLURR
from datasets import load_dataset
from fastai.data.block import DataBlock
from fastai.data.transforms import RandomSplitter
from fastai.imports import noop
from fastai.learner import Learner
from fastai.losses import CrossEntropyLossFlat
from fastai.optimizer import ranger
from fastcore.transform import Pipeline
from transformers import BartForConditionalGeneration



In [2]:
dataset = load_dataset('../bart_large_cnn/multi_news', split='train',)

pretrained_model = '../bart_large_cnn'
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model,
                                                                  model_cls=BartForConditionalGeneration)

hf_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
    task='summarization',
    text_gen_kwargs=
 {'max_length': 248,'min_length': 56,'do_sample': False, 'early_stopping': True, 'num_beams': 4, 'temperature': 1.0,
  'top_k': 50, 'top_p': 1.0, 'repetition_penalty': 1.0, 'bad_words_ids': None, 'bos_token_id': 0, 'pad_token_id': 1,
 'eos_token_id': 2, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'encoder_no_repeat_ngram_size': 0,
 'num_return_sequences': 1, 'decoder_start_token_id': 2, 'use_cache': True, 'num_beam_groups': 1,
 'diversity_penalty': 0.0, 'output_attentions': False, 'output_hidden_states': False, 'output_scores': False,
 'return_dict_in_generate': False, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2, 'remove_invalid_values': False})

Found cached dataset multi_news (C:/Users/Andrew/.cache/huggingface/datasets/multi_news/default/1.0.0/2f1f69a2bedc8ad1c5d8ae5148e4755ee7095f465c1c01ae8f85454342065a72)


In [3]:
blocks = (HF_Seq2SeqBlock(before_batch_tfm=hf_batch_tfm), noop)
x = Pipeline(dataset["document"], dataset["summary"])
dblock = DataBlock(blocks=blocks, get_x=x, splitter=RandomSplitter())
dls = dblock.dataloaders(dataset, bs=2)

KeyboardInterrupt: 

In [None]:
seq2seq_metrics = {
        'rouge': {
            'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },
            'returns': ["rouge1", "rouge2", "rougeL"]
        },
        'bertscore': {
            'compute_kwargs': { 'lang': 'fr' },
            'returns': ["precision", "recall", "f1"]}}

model = HF_BaseModelWrapper(hf_model)
learn_cbs = [HF_BaseModelCallback]
fit_cbs = [HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

In [None]:
learn = Learner(dls, model,
                opt_func=ranger,loss_func=CrossEntropyLossFlat(),
                cbs=learn_cbs,splitter=partial(seq2seq_splitter, arch=hf_arch)).to_fp16()

learn.create_opt()
learn.freeze()

learn.fit_one_cycle(3, lr_max=3e-5, cbs=fit_cbs)

In [None]:
learn.save('../finetuned-bart_large_cnn')