# To train the GPT-2 model, upload this notebook and the `run_clm.py` script to Google Drive so that they are in the same directory. Then enable the TPU runtime in the notebook settings and run all cells.

In [None]:
# pip install required packages
!pip install transformers datasets accelerate
!pip install cloud-tpu-client==0.10 torch==1.11.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl

# import statements
import os
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
import torch_xla
assert os.environ['COLAB_TPU_ADDR']
from random import randint

In [None]:
# run the GPT2 model with required flags
# uses perplexity as the metric to evaluate the model

!python3 run_clm.py \
--model_type gpt2 \
--model_name_or_path gpt2 \
--train_file "songs_merged.txt" \
--validation_file "rihanna.txt" \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--num_train_epochs 3 \
--output_dir "./gpt" \
--learning_rate 0.00001 \

In [None]:
# Load the fine-tuned model from ./gpt (the training cell's --output_dir).
# from_pt=True is passed to the TF model class; presumably the directory holds
# PyTorch weights written by run_clm.py — confirm.
model = TFGPT2LMHeadModel.from_pretrained("./gpt", from_pt=True)
# The tokenizer is the stock "gpt2" one, not one loaded from ./gpt.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Encode the prompt that seeds generation, as TensorFlow tensors.
input_ids = tokenizer.encode("I love deep learning,", return_tensors='tf')

# Generate text with 5 beams and sampling enabled (do_sample=True), so this is
# not plain deterministic beam search. top_k=0 is passed alongside
# temperature=0.85; a single sequence is requested (num_return_sequences=1).
beam_output = model.generate(input_ids,
                             max_length=150,
                             num_return_sequences=1,
                             no_repeat_ngram_size=2,
                             repetition_penalty=2.5,
                             temperature=0.85,
                             do_sample=True,
                             top_k=0,
                             num_beams=5,
                             early_stopping=True)

# Order the returned sequences by token count, longest first.
output = sorted(beam_output, key=len, reverse=True)

# Decode every candidate to text, dropping special tokens.
decoded_songs = [tokenizer.decode(x, skip_special_tokens=True) for x in output]

# Pick the longest prediction by decoded text length.
# NOTE(review): generate() appears to return one padded batch, in which case
# token counts are equal across candidates and only the decoded length can
# tell them apart — confirm against the transformers generate() docs.
longest_song = max(decoded_songs, key=len)

# Save only the longest prediction. The encoding is explicit so that
# non-ASCII characters are written the same way on every platform.
with open("final_song.txt", "w", encoding="utf-8") as f:
    f.write(longest_song)