### **Initialization**

In [1]:
!pip install evaluate
!pip install rouge_score
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
from datasets import load_dataset
import evaluate
import numpy as np
import gc



  from .autonotebook import tqdm as notebook_tqdm





In [7]:
# Load the model
checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Output directory for saving trained model
out_dir = "models"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

### **Load data**

In [None]:
sample_count = 17886 # max is 17886

# Load the dataset
dataset = load_dataset('csv', data_files='YT-titles-transcripts-clean.csv', split='train[0:' + str(sample_count) + ']')
dataset = dataset.train_test_split(test_size=0.2)

source_text = dataset['train'][:]['transcript']
target_text = dataset['train'][:]['title']

# Tokenize the source and target text
tokenized_source_text = tokenizer(list(source_text), truncation=True, padding=False, max_length=512)
tokenized_target_text = tokenizer(list(target_text), truncation=True, padding=False, max_length=512)

# Find maximum lengths for source and target sequences
max_source = max(len(item) for item in tokenized_source_text['input_ids'])
max_target = max(len(item) for item in tokenized_target_text['input_ids'])

# Preprocess function for mapping dataset
def preprocess_function(unit):
    # Prepend "summarize: " to each chat for summarization task
    inputs = ["summarize: " + con for con in unit["transcript"]]
    # Tokenize inputs and labels
    model_inputs = tokenizer(inputs, padding='max_length', truncation=True, max_length=max_source)
    labels = tokenizer(text_target=unit["title"], padding='max_length', truncation=True, max_length=max_target)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Map preprocess function to dataset
tokenized_data = dataset.map(preprocess_function, batched=True)

### **Training**

In [None]:
# Data collator
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

# Load Rouge for evaluation
rouge = evaluate.load("rouge")

# Compute metrics function
def compute_metrics(preds):
    predictions, labels = preds
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

# Load model
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
model.generation_config.max_new_tokens = 20

batch_size = 16

# Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir=out_dir,
    eval_strategy="epoch",
    learning_rate=0.00005,
    # auto_find_batch_size=True, # added this
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # gradient_accumulation_steps=4, # added this
    # eval_accumulation_steps=1, # added this
    # eval_steps=10, # added this
    weight_decay=0.01,
    save_total_limit=5,
    num_train_epochs=20,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=False
)

# Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)


#gc.collect()
#torch.cuda.empty_cache()
# Train the model
trainer.train()

# Save the trained model
trainer.save_model(os.path.join(out_dir, 'trained-model'))

### **Inference**

In [10]:
# Load trained model
model = AutoModelForSeq2SeqLM.from_pretrained(os.path.join(out_dir, 'trained-model')).to('cuda')

# Example input text for inference (transcription is from https://www.youtube.com/watch?v=dHy-qfkO54E)
input_text = 'summarize: ' + """I'm in Minecraft to discover one of the world's best
 kept food secrets.
 Come on, look at this.
 Oh, my Lord.
 I'm here in search of the most incredible apple,
 and there's only one way to get them.
 I've never made love to a tree.
 I haven't climbed a tree for decades, and I mean decades.
 That is delicious.
 It's so juicy.
 An amazing discovery.
 With enough apples from the jungle,
 I'm off to meet a guy who's been revolutionizing
 farming technology.
 Now, this farmer produces the most amazing potatoes.
 In fact, he's actually known as the mad potato scientist.
 After getting scammed for 15 emeralds,
 I've even managed to get my hands on a potato.
 That is amazing.
 Wow.
 This is the kind of discovery that sends me straight back
 to my kitchen.
 Because I've just tasted the most amazing potatoes.
 Next, I'm off to a place that's supposed
 to have the finest ingredient in the land.
 Oh, Lord, you are kidding me.
 Getting down there was an absolute fucking nightmare.
 And at the bottom, another surprise was waiting for me.
 Oh, my God.
 At first, I thought I found treasure,
 but then I made a horrifying discovery.
 What the fuck is this?
 Oh, dear.
 It's raw.
 Oh.
 Having survived my partial poisoning,
 it's time to face today's real challenge.
 challenge. This is it. Everything I've understood got up to speed with this week. It all comes
 down to this. For my main dish, I'm making a potato and apple soup. Capable of replenishing
 eight hunger bars. I'm also going to have a touch of honey in there as well. Seems as
 though word got out about my incredible meal. But right now, I can't afford to get distracted.
 But with just moments of spare, things took a turn for the worse. Hold on a minute. Oh
 for God's sake man. Let's get out of here before we get shot. And if you liked this
 video, don't forget to subscribe to my amazing YouTube channel for more. Good luck."""

# Perform inference
with torch.no_grad():
    tokenized_text = tokenizer(input_text, truncation=True, padding=True, return_tensors='pt')

    source_ids = tokenized_text['input_ids'].to('cuda', dtype=torch.long)
    source_mask = tokenized_text['attention_mask'].to('cuda', dtype=torch.long)

    generated_ids = model.generate(
        input_ids=source_ids,
        attention_mask=source_mask,
        max_length=512,
        num_beams=5,
        repetition_penalty=1,
        length_penalty=1,
        early_stopping=True,
        no_repeat_ngram_size=2
    )

    pred = tokenizer.decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

print("\noutput:\n" + pred)


output:
What's the Best-case scenario?
