In [23]:
import pandas as pd
import numpy as np


In [24]:
tv_complete = pd.read_csv('data/TMDB_tv_dataset_v3.csv')
tv_complete.head()


Unnamed: 0,id,name,number_of_seasons,number_of_episodes,original_language,vote_count,vote_average,overview,adult,backdrop_path,...,tagline,genres,created_by,languages,networks,origin_country,spoken_languages,production_companies,production_countries,episode_run_time
0,1399,Game of Thrones,8,73,en,21857,8.442,Seven noble families fight for control of the ...,False,/2OMB0ynKlyIenMJWI2Dy9IWT4c.jpg,...,Winter Is Coming,"Sci-Fi & Fantasy, Drama, Action & Adventure","David Benioff, D.B. Weiss",en,HBO,US,English,"Revolution Sun Studios, Television 360, Genera...","United Kingdom, United States of America",0
1,71446,Money Heist,3,41,es,17836,8.257,"To carry out the biggest heist in history, a m...",False,/gFZriCkpJYsApPZEF3jhxL4yLzG.jpg,...,The perfect robbery.,"Crime, Drama",Álex Pina,es,"Netflix, Antena 3",ES,Español,Vancouver Media,Spain,70
2,66732,Stranger Things,4,34,en,16161,8.624,"When a young boy vanishes, a small town uncove...",False,/2MaumbgBlW1NoPo3ZJO38A6v7OS.jpg,...,Every ending has a beginning.,"Drama, Sci-Fi & Fantasy, Mystery","Matt Duffer, Ross Duffer",en,Netflix,US,English,"21 Laps Entertainment, Monkey Massacre Product...",United States of America,0
3,1402,The Walking Dead,11,177,en,15432,8.121,Sheriff's deputy Rick Grimes awakens from a co...,False,/x4salpjB11umlUOltfNvSSrjSXm.jpg,...,Fight the dead. Fear the living.,"Action & Adventure, Drama, Sci-Fi & Fantasy",Frank Darabont,en,AMC,US,English,"AMC Studios, Circle of Confusion, Valhalla Mot...",United States of America,42
4,63174,Lucifer,6,93,en,13870,8.486,"Bored and unhappy as the Lord of Hell, Lucifer...",False,/aDBRtunw49UF4XmqfyNuD9nlYIu.jpg,...,It's good to be bad.,"Crime, Sci-Fi & Fantasy",Tom Kapinos,en,"FOX, Netflix",US,English,"Warner Bros. Television, DC Entertainment, Jer...",United States of America,45


In [25]:
tv_complete.columns

Index(['id', 'name', 'number_of_seasons', 'number_of_episodes',
       'original_language', 'vote_count', 'vote_average', 'overview', 'adult',
       'backdrop_path', 'first_air_date', 'last_air_date', 'homepage',
       'in_production', 'original_name', 'popularity', 'poster_path', 'type',
       'status', 'tagline', 'genres', 'created_by', 'languages', 'networks',
       'origin_country', 'spoken_languages', 'production_companies',
       'production_countries', 'episode_run_time'],
      dtype='object')

In [26]:
# name, overview, tagline, 
# adult, type, genres, country
tv = tv_complete[['adult','type','genres','origin_country','name','overview']].dropna()
tv

Unnamed: 0,adult,type,genres,origin_country,name,overview
0,False,Scripted,"Sci-Fi & Fantasy, Drama, Action & Adventure",US,Game of Thrones,Seven noble families fight for control of the ...
1,False,Scripted,"Crime, Drama",ES,Money Heist,"To carry out the biggest heist in history, a m..."
2,False,Scripted,"Drama, Sci-Fi & Fantasy, Mystery",US,Stranger Things,"When a young boy vanishes, a small town uncove..."
3,False,Scripted,"Action & Adventure, Drama, Sci-Fi & Fantasy",US,The Walking Dead,Sheriff's deputy Rick Grimes awakens from a co...
4,False,Scripted,"Crime, Sci-Fi & Fantasy",US,Lucifer,"Bored and unhappy as the Lord of Hell, Lucifer..."
...,...,...,...,...,...,...
168621,False,Scripted,Drama,"MX, ES",The Dentist,Forensic dentist Nolasco Black investigates a ...
168624,True,Scripted,"Animation, Comedy",RU,Take My Muffin,The story is about a unicorn who has lost his ...
168625,False,Scripted,Comedy,CL,jappening with ja,It was broadcast for the first time in April 1...
168632,False,Scripted,Drama,TH,Born to Be Y,The story of 14 contestants who audition to co...


In [27]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import Dataset, DatasetDict

In [28]:
def make_input_prompt(row):
    return (f"adult: {row['adult']}\n"
            f"type: {row['type']}\n"
            f"genres: {row['genres']}\n"
            f"origin_country: {row['origin_country']}\n"
            "Generate name and overview:")

def make_target_text(row):
    return (f"name: {row['name']}\n"
            f"overview: {row['overview']}")

tv['input_text'] = tv.apply(make_input_prompt, axis=1)
tv['target_text'] = tv.apply(make_target_text, axis=1)
tv.head()

Unnamed: 0,adult,type,genres,origin_country,name,overview,input_text,target_text
0,False,Scripted,"Sci-Fi & Fantasy, Drama, Action & Adventure",US,Game of Thrones,Seven noble families fight for control of the ...,adult: False\ntype: Scripted\ngenres: Sci-Fi &...,name: Game of Thrones\noverview: Seven noble f...
1,False,Scripted,"Crime, Drama",ES,Money Heist,"To carry out the biggest heist in history, a m...","adult: False\ntype: Scripted\ngenres: Crime, D...",name: Money Heist\noverview: To carry out the ...
2,False,Scripted,"Drama, Sci-Fi & Fantasy, Mystery",US,Stranger Things,"When a young boy vanishes, a small town uncove...","adult: False\ntype: Scripted\ngenres: Drama, S...",name: Stranger Things\noverview: When a young ...
3,False,Scripted,"Action & Adventure, Drama, Sci-Fi & Fantasy",US,The Walking Dead,Sheriff's deputy Rick Grimes awakens from a co...,adult: False\ntype: Scripted\ngenres: Action &...,name: The Walking Dead\noverview: Sheriff's de...
4,False,Scripted,"Crime, Sci-Fi & Fantasy",US,Lucifer,"Bored and unhappy as the Lord of Hell, Lucifer...","adult: False\ntype: Scripted\ngenres: Crime, S...",name: Lucifer\noverview: Bored and unhappy as ...


In [None]:
# simple random split
train_df = tv.sample(frac=0.8, random_state=42)
val_df = tv.drop(train_df.index)

In [32]:
train_dataset = Dataset.from_pandas(train_df[['input_text','target_text']])
val_dataset = Dataset.from_pandas(val_df[['input_text','target_text']])

tv_dict = DatasetDict({
    "train": train_dataset,
    "validation": val_dataset
})

In [None]:
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)

def tokenize_function(examples):
    # we feed the model the input_text, and want to predict target_text
    inputs = tokenizer(examples["input_text"], padding="max_length", truncation=True, max_length=512)
    targets = tokenizer(examples["target_text"], padding="max_length", truncation=True, max_length=512)
    inputs["labels"] = targets["input_ids"]
    return inputs

tokenized_datasets = tv_dict.map(tokenize_function, batched=True)

Map:   0%|          | 0/47972 [00:00<?, ? examples/s]

Map:   0%|          | 0/11993 [00:00<?, ? examples/s]

In [34]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_text', 'target_text', '__index_level_0__', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 47972
    })
    validation: Dataset({
        features: ['input_text', 'target_text', '__index_level_0__', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 11993
    })
})

In [None]:
model = T5ForConditionalGeneration.from_pretrained(model_name)

training_args = TrainingArguments(
    output_dir="t5_tv_shows",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,  # 2-5
    weight_decay=0.01,
    logging_steps=100,
    save_steps=500,
    save_total_limit=2
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"]
)

In [38]:
trainer.train()
trainer.save_model("t5_tv_shows")
tokenizer.save_pretrained("t5_tv_shows")

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 