## Install and import packages

In [None]:
!pip install transformers

In [None]:
pip install git+https://github.com/huggingface/transformers

In [None]:
!pip install tensorboard

In [None]:
import os

from google.colab import drive

import torch
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
# Fix PyTorch's global RNG seed for this kernel.
# NOTE(review): training and generation below are launched with `!python` as
# separate processes, so this seed presumably does not reach them - confirm.
torch.manual_seed(42)

from transformers import GPT2LMHeadModel,  GPT2Tokenizer, GPT2Config, GPT2LMHeadModel
from transformers import AdamW, get_linear_schedule_with_warmup

import pandas as pd
from sklearn.model_selection import train_test_split
import re

In [None]:
# Mount Google Drive at /content/drive; every data file and script used below
# is addressed under /content/drive/MyDrive.
# (`drive` is also imported in the imports cell, so this re-import is redundant
# but harmless.)
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install datasets

## Data preprocessing

In [None]:
# Load the two input tables from the mounted Drive.
# The merge cell below expects `lyrics` to have an 'ALink' column and
# `artists` to have 'Artist', 'Genres' and 'Link' columns.
lyrics = pd.read_csv('/content/drive/MyDrive/lyrics-data.csv')
artists = pd.read_csv('/content/drive/MyDrive/artists-data.csv')

In [None]:
# Attach the artist name and genres to every lyric row by matching the lyric's
# 'ALink' to the artist's 'Link' (inner join: unmatched rows are discarded),
# then drop the link columns, which are not used afterwards.
full_df = (
    lyrics
    .merge(
        artists[['Artist', 'Genres', 'Link']],
        left_on='ALink',
        right_on='Link',
        how='inner',
    )
    .drop(columns=['ALink', 'SLink', 'Link'])
)

In [None]:
# One frame per genre. Only rows whose 'Genres' value is exactly 'Rock' or
# exactly 'Hip Hop' are kept; any other value (for example a string listing
# several genres) is excluded.
rock_df = full_df.loc[full_df['Genres'] == 'Rock']
hiphop_df = full_df.loc[full_df['Genres'] == 'Hip Hop']

In [None]:
# Two-stage split per genre: first hold out the test set, then divide the
# remainder into train/validation. With the ratios below this gives roughly
# 54% train, 36% validation and 10% test of each genre.
train_test_ratio = 0.9   # share of a genre kept for train + validation
train_valid_ratio = 0.6  # share of that remainder used for training


def _split_genre(genre_df):
    """Return ``(train_full, train, val, test)`` for one genre's DataFrame."""
    train_full, test = train_test_split(
        genre_df, train_size=train_test_ratio, random_state=1
    )
    train, val = train_test_split(
        train_full, train_size=train_valid_ratio, random_state=1
    )
    return train_full, train, val, test


rock_train_full, rock_train, rock_val, rock_test = _split_genre(rock_df)
hiphop_train_full, hiphop_train, hiphop_val, hiphop_test = _split_genre(hiphop_df)

In [None]:
def create_data(df, dest_path):
    """Write the lyrics in ``df['Lyric']`` to a text file, one song per line.

    Each lyric is converted to ``str``, stripped of leading/trailing
    whitespace, and has every remaining whitespace character (newlines, tabs,
    ...) replaced one-for-one by a space so that a song never spans several
    lines. The line written is ``<BOS> <lyric> <EOS>``.

    Rows whose ``Lyric`` is missing (NaN/None) are skipped; otherwise they
    would be written as the literal text ``nan``.

    Args:
        df: DataFrame with a ``Lyric`` column.
        dest_path: Path of the text file to create; an existing file is
            overwritten. The file is written as UTF-8.
    """
    bos_token = '<BOS>'
    eos_token = '<EOS>'
    # The context manager guarantees the file is flushed and closed, even if a
    # row raises; writing line by line avoids building one huge string.
    with open(dest_path, 'w', encoding='utf-8') as f:
        for lyric in df['Lyric'].dropna().tolist():
            lyric = re.sub(r"\s", " ", str(lyric).strip())
            f.write(bos_token + ' ' + lyric + ' ' + eos_token + '\n')

In [None]:
# The training cells read '<genre>_train.txt' and '<genre>_valid.txt' directly
# from the Drive root, so the split files are written there (writing them to a
# subfolder would leave the training commands without their input files).
OUTPUT_DIR = '/content/drive/MyDrive'

genre_splits = {
    'rock_train': rock_train,
    'rock_valid': rock_val,
    'rock_test': rock_test,
    'hiphop_train': hiphop_train,
    'hiphop_valid': hiphop_val,
    'hiphop_test': hiphop_test,
}

for file_stem, split_df in genre_splits.items():
    create_data(split_df, f'{OUTPUT_DIR}/{file_stem}.txt')

## Training the Rock and Hip-hop Models

In [None]:
# Run run_clm.py with --model_name_or_path=gpt2, training on rock_train.txt and
# evaluating on rock_valid.txt for 5 epochs (batch size 2 per device, learning
# rate 5e-5). Output goes to /content/drive/MyDrive/rock_model_final.
# NOTE(review): run_clm.py is presumably the Hugging Face causal-LM example
# script copied to Drive - confirm it matches the installed transformers.
!python '/content/drive/MyDrive/run_clm.py' \
  --output_dir='/content/drive/MyDrive/rock_model_final'\
  --model_type=gpt2 \
  --model_name_or_path=gpt2 \
  --do_train \
  --train_file='/content/drive/MyDrive/rock_train.txt'\
  --do_eval \
  --validation_file='/content/drive/MyDrive/rock_valid.txt'\
  --per_device_train_batch_size=2 \
  --per_device_eval_batch_size=2 \
  --learning_rate 5e-5 \
  --num_train_epochs=5

In [None]:
# Run run_clm.py with --model_name_or_path=gpt2, training on hiphop_train.txt
# and evaluating on hiphop_valid.txt for 5 epochs (batch size 2 per device,
# learning rate 5e-5). Output goes to /content/drive/MyDrive/hiphop_model.
# NOTE(review): run_clm.py is presumably the Hugging Face causal-LM example
# script copied to Drive - confirm it matches the installed transformers.
!python '/content/drive/MyDrive/run_clm.py' \
  --output_dir='/content/drive/MyDrive/hiphop_model'\
  --model_type=gpt2 \
  --model_name_or_path=gpt2 \
  --do_train \
  --train_file='/content/drive/MyDrive/hiphop_train.txt'\
  --do_eval \
  --validation_file='/content/drive/MyDrive/hiphop_valid.txt'\
  --per_device_train_batch_size=2 \
  --per_device_eval_batch_size=2 \
  --learning_rate 5e-5 \
  --num_train_epochs=5

## Generate lyrics

In [None]:
# Run run_generation.py against the model directory written by the hip-hop
# training cell, asking for 5 sequences (--num_return_sequences) of --length=500
# from the prompt "I love deep learning".
# NOTE(review): --k is presumably the top-k sampling cutoff - confirm against
# the run_generation.py copy on Drive.
!python '/content/drive/MyDrive/run_generation.py' \
  --model_type gpt2 \
  --model_name_or_path '/content/drive/MyDrive/hiphop_model' \
  --prompt "I love deep learning" \
  --k 50 \
  --length=500 \
  --num_return_sequences 5

In [None]:
# Run run_generation.py against the model directory written by the rock
# training cell, asking for 5 sequences (--num_return_sequences) of --length=500
# from the prompt "I love deep learning".
# NOTE(review): --k is presumably the top-k sampling cutoff - confirm against
# the run_generation.py copy on Drive.
!python '/content/drive/MyDrive/run_generation.py' \
  --model_type gpt2 \
  --model_name_or_path '/content/drive/MyDrive/rock_model_final' \
  --prompt "I love deep learning" \
  --k 50 \
  --length=500 \
  --num_return_sequences 5