<a href="https://colab.research.google.com/github/shiftkey-labs/GenAI-Course/blob/main/gen_ai_book_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Fine Tuning

In [27]:
!pip install transformers datasets



In [29]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

In [31]:
print(device)

cuda


In [32]:
from datasets import load_dataset

dataset = load_dataset("cnn_dailymail", "3.0.0", split="train")

dataset_split = dataset.train_test_split(test_size=0.1)

small_train_dataset = dataset_split['train'].train_test_split(test_size=0.99)['train']

eval_dataset = dataset_split['test']

In [34]:
def preprocess_function(examples):
  inputs = [doc for doc in examples['article']]
  model_inputs = tokenizer(inputs, max_length=512, padding="max_length", truncation=True, return_tensors="pt")

  with tokenizer.as_target_tokenizer():
    labels= tokenizer(examples['highlights'], max_length=128, padding="max_length", truncation=True, return_tensors="pt")

  model_inputs["labels"] = labels["input_ids"]

  model_inputs = {k:v.to(device) for k, v in model_inputs.items()}

  return model_inputs

tokenized_train_dataset = small_train_dataset.map(preprocess_function, batched=True)
tokenized_eval_dataset = eval_dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/2584 [00:00<?, ? examples/s]



Map:   0%|          | 0/28712 [00:00<?, ? examples/s]

In [37]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    logging_dir="./logs"
)



In [38]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    tokenizer=tokenizer
)

In [39]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,No log,2.634745
2,5.687000,1.888724
3,5.687000,1.697963


TrainOutput(global_step=969, training_loss=3.997419642713171, metrics={'train_runtime': 1885.1766, 'train_samples_per_second': 4.112, 'train_steps_per_second': 0.514, 'total_flos': 1441023192465408.0, 'train_loss': 3.997419642713171, 'epoch': 3.0})

In [40]:
metrics = trainer.evaluate()
print(metrics)

{'eval_loss': 1.6979626417160034, 'eval_runtime': 476.6947, 'eval_samples_per_second': 60.231, 'eval_steps_per_second': 7.529, 'epoch': 3.0}


In [42]:
def summarize(text):
  inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(device)
  summary_ids = model.generate(inputs["input_ids"], max_length=128, num_beams=4, early_stopping=True)
  return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

text = ""
print(summarize(text))

    


In [43]:
print(summarize("The global pandemic has significantly impacted the way we work, leading to an unprecedented shift towards remote work. Companies around the world have had to adapt to new working conditions, often at a rapid pace. This transition has brought about both challenges and opportunities. While some employees enjoy the flexibility of working from home, others struggle with isolation and the blurring of work-life boundaries. Moreover, companies are now rethinking their long-term strategies, with many considering permanent remote work policies. However, this shift also raises concerns about maintaining company culture and ensuring effective collaboration among teams. As businesses navigate this new landscape, the ability to adapt and innovate will be key to their success."))

Companies around the world have had to adapt to new working conditions, often at a rapid pace.


In [45]:
print(summarize(
    """
Person A: Hey, did you hear about the new project management software our company is planning to implement?

Person B: Yeah, I heard a bit about it. What’s the deal with it?

Person A: It’s called "TaskFlow." The management thinks it’s going to streamline our workflow, especially with remote teams. It’s supposed to integrate all the tools we use, like Slack, Trello, and Google Drive, into one platform.

Person B: That sounds interesting. But I’m a bit concerned about the learning curve. Is it user-friendly?

Person A: From what I’ve seen, it looks pretty intuitive. They’re also planning to run a couple of training sessions to get everyone up to speed. The first one is next Monday.

Person B: Okay, that helps. I guess I’ll have to attend that session. How does it compare to what we’re using now?

Person A: It’s supposed to be much more efficient. We’ll be able to track project progress more easily and get real-time updates. Plus, it has built-in analytics to help us with performance tracking.

Person B: That sounds promising. I just hope it doesn’t come with too many bugs at launch.

Person A: Yeah, that’s always a concern with new software. But they’ve been testing it for a while now, so fingers crossed it goes smoothly.

Person B: Let’s hope for the best. Thanks for the info!

Person A: No problem. See you at the training!
"""
))

Project management software is going to streamline our workflow, especially with remote teams. It’s supposed to integrate all the tools we use, like Slack, Trello, and Google Drive, into one platform. It’s supposed to integrate all the tools we use, like Slack, Trello, and Google Drive, into one platform. It’s supposed to be much more efficient. They’ll be able to track project progress more easily and get real-time updates. It has built-in analytics to help us with performance tracking.
