In [1]:
!pip install -q transformers datasets evaluate transformers[torch] py7zr wandb==0.17.9

In [2]:
!pip install -U "transformers>=4.44" "accelerate>=0.34" "datasets>=2.21"



In [3]:
import os
os.environ["WANDB_DISABLED"] = "true"   # hard-disable wandb
os.environ["WANDB_MODE"] = "disabled"   # belt-and-suspenders

In [21]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
from datasets import load_dataset
from huggingface_hub import notebook_login
from transformers import TrainingArguments, Trainer

In [5]:
model_name = 'facebook/bart-large-cnn'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [6]:
dataset = load_dataset("knkarthick/samsum")
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 14731
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 818
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 819
    })
})

In [7]:
sample = dataset['test'][0]['dialogue']
label = dataset['test'][0]['summary']

def generate_summary(input, llm):
  input_prompt = f"""Summarize the following conversation.

                  {input}

                  Summary:
                  """

  input_ids = tokenizer(sample, return_tensors='pt')
  tokenized_output = llm.generate(input_ids['input_ids'], min_length=30, max_length=200)
  output = tokenizer.decode(tokenized_output[0], skip_special_tokens=True)

  return output

In [8]:
output = generate_summary(sample, llm=model)
print('Sample')
print(sample)
print("-"*20)
print("Model Generated Summary:")
print(output)
print("Correct summary:")
print(label)

Sample
Hannah: Hey, do you have Betty's number?
Amanda: Lemme check
Hannah: <file_gif>
Amanda: Sorry, can't find it.
Amanda: Ask Larry
Amanda: He called her last time we were at the park together
Hannah: I don't know him well
Hannah: <file_gif>
Amanda: Don't be shy, he's very nice
Hannah: If you say so..
Hannah: I'd rather you texted him
Amanda: Just text him 🙂
Hannah: Urgh.. Alright
Hannah: Bye
Amanda: Bye bye
--------------------
Model Generated Summary:
Hannah: Hey, do you have Betty's number? Amanda: Lemme check. Hannah: Ask Larry. Amanda: He called her last time we were at the park together.
Correct summary:
Hannah needs Betty's number but Amanda doesn't have it. She needs to contact Larry.


# Prepare our dataset

In [18]:
MAX_INPUT_LEN = 512
MAX_TARGET_LEN = 128

def tokenize_inputs(examples):
    start_prompt = "Summarize the following conversation.\n\n"
    end_prompt = "\n\nSummary: "
    prompts = [start_prompt + d + end_prompt for d in examples["dialogue"]]

    model_inputs = tokenizer(
        prompts,
        max_length=MAX_INPUT_LEN,
        truncation=True,
        padding=False,
    )

    labels = tokenizer(
        text_target=examples["summary"],
        max_length=MAX_TARGET_LEN,
        truncation=True,
        padding=False,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


tokenized_datasets = dataset.map(
    tokenize_inputs,
    batched=True,
    remove_columns=["id", "dialogue", "summary"],
)

Map:   0%|          | 0/14731 [00:00<?, ? examples/s]

Map:   0%|          | 0/818 [00:00<?, ? examples/s]

Map:   0%|          | 0/819 [00:00<?, ? examples/s]

In [19]:
print(tokenized_datasets['train'].shape)
print(tokenized_datasets['validation'].shape)
print(tokenized_datasets['test'].shape)

(14731, 3)
(818, 3)
(819, 3)


In [20]:
tokenized_datasets['train'][0].keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [22]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [42]:
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [23]:
training_args = TrainingArguments(
    output_dir="./bart-cc-samsum-finetuned",
    learning_rate=1e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    auto_find_batch_size=True,
    eval_strategy="epoch",
    logging_steps=10,
    report_to="none",
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
)

  trainer = Trainer(


In [24]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,1.3034,1.385201




TrainOutput(global_step=1842, training_loss=1.3684937350783621, metrics={'train_runtime': 2741.7894, 'train_samples_per_second': 5.373, 'train_steps_per_second': 0.672, 'total_flos': 1.0508789665529856e+16, 'train_loss': 1.3684937350783621, 'epoch': 1.0})

In [26]:
loaded_model = AutoModelForSeq2SeqLM.from_pretrained('ingeniumacademy/bart-cnn-samsum-finetuned')

output = generate_summary(sample, llm=loaded_model)

print('Sample')
print(sample)
print("-"*20)
print("Model Generated Summary:")
print(output)
print("Correct summary:")
print(label)

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/358 [00:00<?, ?B/s]

Sample
Hannah: Hey, do you have Betty's number?
Amanda: Lemme check
Hannah: <file_gif>
Amanda: Sorry, can't find it.
Amanda: Ask Larry
Amanda: He called her last time we were at the park together
Hannah: I don't know him well
Hannah: <file_gif>
Amanda: Don't be shy, he's very nice
Hannah: If you say so..
Hannah: I'd rather you texted him
Amanda: Just text him 🙂
Hannah: Urgh.. Alright
Hannah: Bye
Amanda: Bye bye
--------------------
Model Generated Summary:
Amanda can't find Betty's number. She'll ask Larry. He called her last time they were at the park together. He's very nice.
Correct summary:
Hannah needs Betty's number but Amanda doesn't have it. She needs to contact Larry.


In [44]:
trainer.save_model("./bart-cc-samsum-finetuned-final")
tokenizer.save_pretrained("./bart-cc-samsum-finetuned-final")

('./bart-cc-samsum-finetuned-final/tokenizer_config.json',
 './bart-cc-samsum-finetuned-final/special_tokens_map.json',
 './bart-cc-samsum-finetuned-final/vocab.json',
 './bart-cc-samsum-finetuned-final/merges.txt',
 './bart-cc-samsum-finetuned-final/added_tokens.json',
 './bart-cc-samsum-finetuned-final/tokenizer.json')

In [45]:
!zip -r bart-cc-samsum-finetuned-final.zip bart-cc-samsum-finetuned-final

  adding: bart-cc-samsum-finetuned-final/ (stored 0%)
  adding: bart-cc-samsum-finetuned-final/tokenizer.json (deflated 82%)
  adding: bart-cc-samsum-finetuned-final/merges.txt (deflated 53%)
  adding: bart-cc-samsum-finetuned-final/config.json (deflated 62%)
  adding: bart-cc-samsum-finetuned-final/generation_config.json (deflated 47%)
  adding: bart-cc-samsum-finetuned-final/model.safetensors (deflated 7%)
  adding: bart-cc-samsum-finetuned-final/tokenizer_config.json (deflated 75%)
  adding: bart-cc-samsum-finetuned-final/special_tokens_map.json (deflated 54%)
  adding: bart-cc-samsum-finetuned-final/training_args.bin (deflated 53%)
  adding: bart-cc-samsum-finetuned-final/vocab.json (deflated 59%)
