In [None]:
!pip install GPUtil  

In [None]:
!pip install pandas==1.5.3
!pip install transformers
!pip install datasets==2.11
!pip install wandb

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import os
import wandb
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Read the processed lyrics CSV into a pandas DataFrame.
# NOTE(review): `df` is not referenced by any later cell visible here; the same file is
# loaded again with `load_dataset` in the next cell — confirm this read is still needed.
df = pd.read_csv("/kaggle/input/processed-taylorswift-df/processed_df.csv")

In [None]:
# Load the processed lyrics CSV through the `datasets` CSV builder, requesting the
# "train" split so that `ds` can be mapped and split directly.
ds = load_dataset(
    "csv",
    data_files="/kaggle/input/processed-taylorswift-df/processed_df.csv",
    split="train",
)

In [None]:
# Display the dataset object (its notebook repr) to inspect what was loaded.
ds

In [None]:
# Report how many rows the loaded dataset contains.
print("Train dataset size:", len(ds))


In [None]:
# Show the first lyric. Index the row first: ds[0]['lyrics'] reads a single row,
# whereas ds['lyrics'][0] materialises the whole column just to take its first element.
print(f"TRAINING SAMPLE: \n{ds[0]['lyrics']}")

In [None]:
# Choose the checkpoint and load its tokenizer (the text itself is tokenized further down)

model_id="gpt2"

# Load the tokenizer that belongs to the `model_id` checkpoint (GPT-2)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
# Sanity-check the tokenizer on the first lyric.
# Index the row first: ds[0]["lyrics"] reads a single row, whereas ds["lyrics"][0]
# materialises the whole column just to take its first element.

tokenizer(ds[0]["lyrics"])

In [None]:
# Display the dataset object again; the tokenizer test above does not reassign `ds`.
ds

In [None]:

# Reuse the end-of-sequence token as the padding token so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token


def tokenize_lyrics(batch):
    """Tokenize a batch of rows by their "lyrics" field, with truncation enabled.

    No padding is applied here on purpose: the data collator created below pads every
    batch to that batch's own longest sequence. Padding at map time would instead store
    pad tokens up to the longest lyric of each map batch and make every training step
    pay for that length.
    """
    return tokenizer(batch["lyrics"], truncation=True)


tokenized_dataset = ds.map(
    tokenize_lyrics,
    batched=True,
    remove_columns=["Tracks", "Album_ID", "Album", "Album_Path"],
)

# mlm=False selects the causal (next-token) language-modelling collator behaviour.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Only expose the model inputs, returned as torch tensors.
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

In [None]:
# Inspect the output format configured by set_format in the previous cell.
tokenized_dataset.format

In [None]:
# Hold out 10% of the rows for evaluation. The seed fixes the shuffle so the same
# train/test partition is produced on every run.
# NOTE(review): this cell rebinds `tokenized_dataset` to the split result, so it is
# presumably not safe to re-run without re-running the tokenization cell first.
tokenized_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)

In [None]:
# Display the result of the train/test split.
tokenized_dataset

In [None]:
# Peek at the token ids of the first two training rows only; displaying the whole
# column would print every token id in the training split.
tokenized_dataset["train"][:2]["input_ids"]

In [None]:
# Load the pretrained causal language model for the `model_id` checkpoint; this is the
# model object handed to the Trainer further down.
model = AutoModelForCausalLM.from_pretrained(model_id)

In [None]:
from transformers import pipeline

In [None]:
# Authenticate with Weights & Biases and choose the project for this run.
wandb.login()

# Environment variable naming the W&B project the run should be logged to.
os.environ["WANDB_PROJECT"] = "song-generator" # log to your project 

In [None]:
# Set WANDB_LOG_MODEL for this kernel via the %env magic.
# NOTE(review): presumably this makes the W&B integration upload the trained model —
# confirm the accepted values against the installed transformers version.
%env WANDB_LOG_MODEL=true

In [None]:
# Print current GPU utilisation before training (GPUtil is installed in the first cell).
from GPUtil import showUtilization as gpu_usage
gpu_usage()  

In [None]:
# Release cached GPU memory held by PyTorch's CUDA allocator before training starts.
import torch
torch.cuda.empty_cache()

In [None]:
# Fine-tune GPT-2 on the tokenized lyrics, evaluating on the held-out split every epoch.
training_args = TrainingArguments(
    output_dir="/kaggle/working/finetuned_gpt2",
    # NOTE(review): newer transformers releases rename this argument to `eval_strategy`;
    # confirm against the installed version.
    evaluation_strategy="epoch",
    # load_best_model_at_end needs a checkpoint for every evaluation, so the save
    # strategy has to match the evaluation strategy. With save_strategy="no"
    # TrainingArguments raises a ValueError before training can start.
    save_strategy="epoch",
    # Bound disk usage in /kaggle/working: older checkpoints are removed as new ones
    # are written.
    save_total_limit=2,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_steps=250,
    num_train_epochs=10,
    per_device_train_batch_size=4,
    report_to="wandb",
    run_name="baseline_gpt2_finetune",
    # Reload the best checkpoint once training has finished.
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)

trainer.train()

In [None]:
# Write the trained model to the Trainer's configured output directory.
trainer.save_model(training_args.output_dir)

In [None]:
# Save the tokenizer next to the model so the directory can be loaded as one checkpoint.
# The path is taken from training_args so it cannot drift from the Trainer's output dir.
tokenizer.save_pretrained(training_args.output_dir)

In [None]:
# Prompt shared by the two generation cells below (base checkpoint and fine-tuned model)
test_prompt = "End of passion play, crumbling away\nI'm your source of self-destruction\nVeins that pump with fear, sucking darkest clear"

In [None]:
# Baseline: text-generation pipeline built from the original `model_id` checkpoint,
# placed on the first CUDA device.
model = pipeline(task="text-generation", model=model_id, device="cuda:0")

# Continue the test prompt with the decoding settings below and print the text.
result = model(
    test_prompt,
    penalty_alpha=0.7,
    top_k=5,
    max_new_tokens=300,
)

print(result[0]["generated_text"])

In [None]:
# Inference with the fine-tuned checkpoint

# Build a text-generation pipeline from the saved directory on the first CUDA device.
model = pipeline(
    task="text-generation",
    model="/kaggle/working/finetuned_gpt2",
    device="cuda:0",
)

# Continue the same test prompt with the same decoding settings and print the text.
result = model(
    test_prompt,
    penalty_alpha=0.7,
    top_k=5,
    max_new_tokens=300,
)

print(result[0]["generated_text"])