In [1]:
import pandas
import re, json
import csv

import torch
import torch.nn as nn
from datasets import load_metric,Dataset,DatasetDict, load_dataset, Sequence, Value
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, BartForConditionalGeneration
from transformers import AutoTokenizer, Trainer

import evaluate

import numpy as np
import nltk
import os
import random
from sklearn.model_selection import train_test_split
from typing import List, Optional, Tuple, Union, Dict, Any
from jointbart_lmhead_step2 import myBartForConditionalGeneration
from hg_utils import GenerationMixin

In [2]:
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
_numpy_rng = np.random.default_rng(seed)
random.seed(seed)
np.random.seed(seed)
torch.use_deterministic_algorithms(False)
os.environ['PYTHONHASHSEED'] = str(seed)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
max_input_length = 256
max_target_length = 128

In [5]:
model_checkpoint = "hallucination-tagging-classifier-lmhead"
metric = evaluate.load("rouge")
model = myBartForConditionalGeneration.from_pretrained(model_checkpoint).to(device)
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large", add_prefix_space=True)

In [6]:
for name, param in model.named_parameters():
    if name == 'classifier.weight' or name == 'classifier.bias':
        param.requires_grad=False

In [7]:
dataset = load_dataset('pvisnrt/special_samsum')
id2label =  {0: 'C', 1: 'M', 2: 'N', 3: 'O', 4: 'OB', 5: 'W'}
label2id = {'C': 0, 'M': 1, 'N': 2, 'O': 3, 'OB': 4, 'W': 5}

In [8]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'tags', 'tag_ids'],
        num_rows: 14732
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'tags', 'tag_ids'],
        num_rows: 819
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'tags', 'tag_ids'],
        num_rows: 818
    })
})

In [9]:
# dataset['train'] = dataset['train'].cast_column("tag_ids", Sequence(Value("int32")))
# dataset['validation'] = dataset['validation'].cast_column("tag_ids", Sequence(Value("int32")))
# dataset['test'] = dataset['test'].cast_column("tag", Sequence(Value("int32")))

In [10]:
def tokenize_and_align_labels(examples):
    inputs = [doc for doc in examples['dialogue']]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, is_split_into_words=True, return_tensors='pt', padding='max_length')

    with tokenizer.as_target_tokenizer():
        tokenized_inputs = tokenizer(examples["summary"], max_length=max_target_length, truncation=True, is_split_into_words=True, return_tensors='pt', padding='max_length')

    labels = []
    for i, label in enumerate(examples["tag_ids"]):
        
        
        word_ids = tokenized_inputs.word_ids(batch_index=i)# Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:  # Set the special tokens to -100.
            
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    model_inputs['labels'] = tokenized_inputs['input_ids']

    for i, t in zip(model_inputs['labels'], labels):
        if len(i) != len(t):
            print("Issue")

    model_inputs["decoder_tags"] = labels
     
    return model_inputs

In [11]:
tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=True)

In [12]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'tags', 'tag_ids', 'input_ids', 'attention_mask', 'labels', 'decoder_tags'],
        num_rows: 14732
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'tags', 'tag_ids', 'input_ids', 'attention_mask', 'labels', 'decoder_tags'],
        num_rows: 819
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'tags', 'tag_ids', 'input_ids', 'attention_mask', 'labels', 'decoder_tags'],
        num_rows: 818
    })
})

In [13]:
tokenized_datasets['train'] = tokenized_datasets['train'].remove_columns(['id', 'dialogue', 'summary', 'tags', 'tag_ids'])
tokenized_datasets['validation'] = tokenized_datasets['validation'].remove_columns(['id', 'dialogue', 'summary', 'tags', 'tag_ids'])
tokenized_datasets['test'] = tokenized_datasets['test'].remove_columns(['id', 'dialogue', 'summary', 'tags', 'tag_ids'])


#tokenized_datasets['train'] = tokenized_datasets['train'].select(range(100))
#tokenized_datasets['validation'] = tokenized_datasets['validation'].select(range(20))
#tokenized_datasets['test'] = tokenized_datasets['test'].select(range(20))


In [14]:
tokenized_datasets['train']

Dataset({
    features: ['input_ids', 'attention_mask', 'labels', 'decoder_tags'],
    num_rows: 100
})

In [15]:
class MySeq2SeqTrainer(Seq2SeqTrainer):
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.
        Subclass and override to inject custom behavior.
        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.
                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """
        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)
        
        # print("prediction_step inputs: {}".format(inputs.keys()))

        # XXX: adapt synced_gpus for fairscale as well
        gen_kwargs = self._gen_kwargs.copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.model.config.max_length
        
        # disable beam search
        #gen_kwargs["num_beams"] = (
        #    gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
        #)
        
        # enable greedy search
        gen_kwargs["num_beams"] = 1
        gen_kwargs['early_stopping'] = False
        
        # default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
        default_synced_gpus = False
        gen_kwargs["synced_gpus"] = (
            gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
        )

        if "attention_mask" in inputs:
            gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
        if "global_attention_mask" in inputs:
            gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)

        # prepare generation inputs
        # some encoder-decoder models can have varying encoder's and thus
        # varying model input names
        if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
            generation_inputs = inputs[self.model.encoder.main_input_name]
        else:
            generation_inputs = inputs[self.model.main_input_name]

        tags = inputs["decoder_tags"]
        gen_kwargs.update({"decoder_tags": tags})
        # print(f"Gen kwargs: {gen_kwargs}")
        # print(f"Gen inputs:{generation_inputs}")
         #generated_tokens = self.model.generate(
        #    generation_inputs,
        #    **gen_kwargs,
        #)
        
        gen_mix = GenerationMixin(model)
        generated_tokens, classification_ids = gen_mix.generate(generation_inputs, **gen_kwargs)
        
        dialog = tokenizer.batch_decode(generation_inputs, skip_special_tokens=True)
        print('-'*89)
        print('dialog:\n', dialog)
        
        generated_summaries = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        
        print('\n\nGenerated Summaries:\n',*generated_summaries, sep='\n')
        print(f'Generated summary length: {generated_tokens.shape}')
        
        classification_labels = []
        classification_ids_lst = classification_ids.cpu().detach().tolist()
        for batch_classification_ids in classification_ids_lst:
            batch_classification_labels = []
            for classification_id in batch_classification_ids:
                classification_id = classification_id - 3
                if classification_id >= 0 and classification_id < len(id2label):
                    batch_classification_labels.append(id2label[classification_id])
            
            classification_labels.append(' '.join(batch_classification_labels))
        
        print('\nGenerated Classification Labels:\n',*classification_labels, sep='\n')
        print(f'Generated classification tag length: {classification_ids.shape}')
        
       
        # in case the batch is shorter than max length, the output should be padded
        if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
        elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
            gen_kwargs["max_new_tokens"] + 1
        ):
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)

        with torch.no_grad():
            if has_labels:
                with self.compute_loss_context_manager():
                    outputs = model(**inputs) # lm_logits as output
                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return (loss, None, None)

        if has_labels:
            labels = inputs["labels"]
            if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
                labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
            elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
                gen_kwargs["max_new_tokens"] + 1
            ):
                labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
        else:
            labels = None
        # print(labels)

        return (loss, generated_tokens, labels)

In [16]:
# training_args = Seq2SeqTrainingArguments(
#     output_dir="checkpoints/",
#     evaluation_strategy="epoch",
#     learning_rate=2e-5,
#     per_device_train_batch_size=2,
#     per_device_eval_batch_size=2,
#     weight_decay=0.01,
#     save_total_limit=4,
#     num_train_epochs=10,
#     predict_with_generate=True,
#     do_train=True,
#     do_eval=True,
#     fp16=True,
#     logging_steps=1,
#     save_strategy="epoch",
#     greater_is_better=True,
#     metric_for_best_model='Rouge1',
#     load_best_model_at_end=True,
#     seed=42,
#     generation_max_length=max_target_length,
# )

In [17]:
training_args = Seq2SeqTrainingArguments(
    output_dir="checkpoints_lmhead/",
    logging_steps=1,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=10,
    predict_with_generate=True,
    fp16=True,
    load_best_model_at_end=True,
    metric_for_best_model="rouge1",
    seed=42,
    generation_max_length=max_target_length,
    dataloader_drop_last=True
)

In [18]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [19]:
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    print(f"Generated summary: {decoded_preds[0]}")

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    print(f"Gold summary: {decoded_labels[0]}")

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    result = metric.compute(predictions=decoded_preds, references=decoded_labels)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

In [20]:
trainer = MySeq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [21]:
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdevavratj[0m. Use [1m`wandb login --relogin`[0m to force relogin


You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,5.6229,4.733024,0.3889,0.1911,0.3195,0.3193,26.25
2,3.8813,3.887948,0.3987,0.1918,0.315,0.3146,35.05


-----------------------------------------------------------------------------------------
dialog:
 [" A : Hi Tom, are you busy tomorrow’s afternoon? B : I’m pretty sure I am. What’s up? A : Can you go with me to the animal shelter?. B : What do you want to do? A : I want to get a puppy for my son. B : That will make him so happy. A : Yeah, we’ve discussed it many times. I think he’s ready now. B : That’s good. Raising a dog is a tough issue. Like having a baby ; -) A : I'll get him one of those little dogs. B : One that won't grow up too big ; -) A : And eat too much ; -)) B : Do you know which one he would like? A : Oh, yes, I took him there last Monday. He showed me one that he really liked. B : I bet you had to drag him away. A : He wanted to take it home right away ; -). B : I wonder what he'll name it. A : He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan : -)))"]


Generated Summaries:

 A wants to get a puppy for her son. She will take him to the a

-----------------------------------------------------------------------------------------
dialog:
 [' Nancy : Howdy, how y\'all doin\'? Tina : Is that a Texan drawl, girl? Nancy : Yes ma\'am! Loving it out here! Tina : How\'s the job going? Kids behaving themselves? Nancy : Mostly! They laugh at my accent though! Tina : Well, they probably haven\'t met a Welsh person before! Nancy : No shit! They ask me to repeat everything! Best one is "Water", course, it\'s mostly "Waarderr" here! Tina : LOL. I\'d love to hear that, you picked up the accent yet? Nancy : Nah, 21 years in Cardiff isn\'t easily removed! Tina : We\'re missing you here, the pub is quiet these days without your laugh! Nancy : Miss you too! I\'m coming home in 6 weeks, though. Last fortnight I\'m going travelling with 3 other Brits working here, a Geordie girl, a guy from Belfast and Annie, who\'s from Glasgow. Tina : My God, I\'m so jealous! I bet they had even more trouble being understood out there! See you after your tr

-----------------------------------------------------------------------------------------
dialog:
 [" Sash : need to see u Caron : y Caron : i'm out from 12 Sash : will be before Sash : then Caron : k Sash : open the door : Caron : what time u coming I need to go out Sash : soon Caron : hurry up I need to go out"]


Generated Summaries:

 Caron is going out from 12. She will be out from 11.
Generated summary length: torch.Size([1, 18])

Generated Classification Labels:

O O O O O O O O O O O O O O O O O
Generated classification tag length: torch.Size([1, 18])
-----------------------------------------------------------------------------------------
dialog:
 [" Giuseppe : Hi man Matteo : Yo Giuseppe : How's it going with Gosia? Matteo : I don't know, she's a little strange Giuseppe : Why? Matteo : She always criticizes me because I like football and video games Giuseppe : Damn Matteo : Yeah... Giuseppe : Ok, I don't like games either, but... Matteo : You boring guy Giuseppe : Lol Matteo 

-----------------------------------------------------------------------------------------
dialog:
 [" Keith : Meg, pls buy some milk and cereals, I see now we've run out of them Megan : hm, sure, I can do that Megan : but did you check in the drawer next to the fridge? Keith : nope, let me have a look Keith : ok, false alarm, we have cereal and milk : D Megan : <file_gif>"]


Generated Summaries:

 Megan has run out of milk and cereals. Keith will buy them for her.
Generated summary length: torch.Size([1, 20])

Generated Classification Labels:

O O O O O O O O O O O O O O O O O O O
Generated classification tag length: torch.Size([1, 20])
-----------------------------------------------------------------------------------------
dialog:
 [" Samantha : <file_video> Evelyn : LOL Holly : Is SHE making that noise?? Samatha : Yes (＾▽＾) Holly : How possible?? : o Samantha : Idk, I'm also surprised!! Evelyn : xD"]


Generated Summaries:

 Samatha makes a noise and Holly is surprised.
Generated s

-----------------------------------------------------------------------------------------
dialog:
 [" Julia : What is your biggest dream Julia : I mean the kind that can be achieved James : Everyone say I have nice voice James : My mom liked very much when I was reading outloud James : I've had this dream for some time now, to become a voice actor James : Be a part of cartoon or video game as a voice actor reading a character Julia : Wow. Nice one. Julia : Btw you do have a nice voice Julia : I could listen to you as a radio speaker. James : Thanks James : I've worked in radio, but it was during college so I had little time for this Julia : Shame. James : I know. But nothing is lost. I still have microphone at home and with a bit of help I could make homemade radio station Julia : That's actually a great idea Julia : I cheer for you!"]


Generated Summaries:

 James has a dream to become a voice actor. He wants to be a part of cartoon or video game as a voice artist. He has a microphon

TrainOutput(global_step=200, training_loss=6.089315139055252, metrics={'train_runtime': 456.2163, 'train_samples_per_second': 0.438, 'train_steps_per_second': 0.438, 'total_flos': 108494205542400.0, 'train_loss': 6.089315139055252, 'epoch': 2.0})

In [22]:
!nvidia-smi

Sun Dec  3 21:16:12 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 537.13                 Driver Version: 537.13       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3070 Ti   WDDM  | 00000000:01:00.0 Off |                  N/A |
|  0%   42C    P2              36W / 310W |   7944MiB /  8192MiB |     46%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [23]:
trainer.evaluate(tokenized_datasets['test'])


|    0   N/A  N/A     14228    C+G   ...5n1h2txyewy\ShellExperienceHost.exe    N/A      |
|    0   N/A  N/A     14496    C+G   ...oogle\Chrome\Application\chrome.exe    N/A      |
|    0   N/A  N/A     14892      C   ...\anaconda3\envs\cap_proj\python.exe    N/A      |
|    0   N/A  N/A     17444    C+G   ...t.LockApp_cw5n1h2txyewy\LockApp.exe    N/A      |
|    0   N/A  N/A     18124    C+G   ...crosoft\Edge\Application\msedge.exe    N/A      |
|    0   N/A  N/A     18196      C   ...\anaconda3\envs\cap_proj\python.exe    N/A      |
|    0   N/A  N/A     19052      C   ...\anaconda3\envs\cap_proj\python.exe    N/A      |
+---------------------------------------------------------------------------------------+
-----------------------------------------------------------------------------------------
dialog:
 [" Hannah : Hey, do you have Betty's number? Amanda : Lemme check Hannah : <file_gif> Amanda : Sorry, can't find it. Amanda : Ask Larry Amanda : He called her last time we were at 

-----------------------------------------------------------------------------------------
dialog:
 [" Eric : MACHINE! Rob : That's so gr8! Eric : I know! And shows how Americans see Russian ; ) Rob : And it's really funny! Eric : I know! I especially like the train part! Rob : Hahaha! No one talks to the machine like that! Eric : Is this his only stand-up? Rob : Idk. I'll check. Eric : Sure. Rob : Turns out no! There are some of his stand-ups on youtube. Eric : Gr8! I'll watch them now! Rob : Me too! Eric : MACHINE! Rob : MACHINE! Eric : TTYL? Rob : Sure : )"]


Generated Summaries:

 Eric will watch Gr8's stand-up. He especially likes the train part.
Generated summary length: torch.Size([1, 20])

Generated Classification Labels:

O O O O O O O O O O O O O O O O O O O
Generated classification tag length: torch.Size([1, 20])
-----------------------------------------------------------------------------------------
dialog:
 [" Lenny : Babe, can you help me with something? Bob : Sure, what

-----------------------------------------------------------------------------------------
dialog:
 [" Beatrice : I am in town, shopping. They have nice scarfs in the shop next to the church. Do you want one? Leo : No, thanks Beatrice : But you don't have a scarf. Leo : Because I don't need it. Beatrice : Last winter you had a cold all the time. A scarf could help. Leo : I don't like them. Beatrice : Actually, I don't care. You will get a scarf. Leo : How understanding of you! Beatrice : You were complaining the whole winter that you're going to die. I've had enough. Leo : Eh."]


Generated Summaries:

 Leo doesn't have a scarf. He has a cold and doesn't like scarfs. Beatrice will get him one.
Generated summary length: torch.Size([1, 28])

Generated Classification Labels:

O O O O O O O O O O O O O O O O O O O O O O O O O O O
Generated classification tag length: torch.Size([1, 28])
-----------------------------------------------------------------------------------------
dialog:
 [" Ivan

-----------------------------------------------------------------------------------------
dialog:
 [" Clara : Hi, what you up to? Neela : Not much, chilling out. Clara : Just rewatching Dear White People on Netflix, love it! 😍 Neela : Oh yeah, heard of it, but not seen it yet? Any good? Clara : Well, yes! I just said it was, LOL. It's about a fictional Ivy League University and the students in one House of Residence. Neela : Why is it called Dear White People? Clara : That's the name of the radio show the main character, Sam, presents on college radio. Neela : Yeah, but why is it so good? Clara : Well, it's mainly stories from the perspective of black students there, which I find very interesting. The characters are strong and likeable too. Neela : I suppose it's rather different from the UK, then? Clara : It seems so, as there is a lot more racial awareness and discrimination there than here. It all kicks off when there is a Blackface party held by an elite group of white students, wh

{'eval_loss': 3.7750210762023926,
 'eval_rouge1': 0.4314,
 'eval_rouge2': 0.1833,
 'eval_rougeL': 0.3326,
 'eval_rougeLsum': 0.3321,
 'eval_gen_len': 37.15,
 'eval_runtime': 50.9579,
 'eval_samples_per_second': 0.392,
 'eval_steps_per_second': 0.392,
 'epoch': 2.0}

In [24]:
model.save_pretrained("summarizer_w_classifier_loss_frozen_lmhead")