In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
import logging
import warnings
from torch.utils.data import Dataset, DataLoader
import os
import csv

# Quiet the notebook: the root logger only reports CRITICAL records and
# every warning category is suppressed.
logging.getLogger().setLevel(level=logging.CRITICAL)
warnings.filterwarnings(action='ignore')

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
# Load the pretrained 'gpt2-medium' tokenizer and language-model head; the
# model is moved onto the selected device as part of the same statement.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium').to(device)

In [3]:
def choose_from_top(probs, n=5):
    """Sample one index from the ``n`` highest-scoring entries of ``probs``.

    The ``n`` largest entries are selected, renormalised so they sum to 1, and
    a single index is drawn from that truncated distribution using NumPy's
    global random state.

    Args:
        probs: 1-D array-like of non-negative scores (e.g. a softmax output).
        n: How many of the top entries to sample from. Values larger than
            ``len(probs)`` are clamped to ``len(probs)``.

    Returns:
        The chosen position in ``probs`` as a Python ``int``.

    Raises:
        ValueError: If ``n`` is less than 1 or ``probs`` is empty.
    """
    probs = np.asarray(probs)
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    if probs.size == 0:
        raise ValueError("probs must not be empty")

    # argpartition raises when asked for more entries than exist, so cap n.
    k = min(n, probs.size)
    ind = np.argpartition(probs, -k)[-k:]
    top_prob = probs[ind]
    top_prob = top_prob / np.sum(top_prob)  # Normalize
    choice = np.random.choice(k, 1, p=top_prob)
    token_id = ind[choice][0]
    return int(token_id)

In [4]:
class LyricsDataset(Dataset):
    """Dataset of lyric strings read from a CSV file.

    Each usable CSV row becomes one string of the form ``"SONG:<text><eos>"``,
    where ``<text>`` is the row's second column and ``<eos>`` is the
    tokenizer's end-of-text token. Every row is treated as data; a header row,
    if the file has one, is not skipped.
    """

    def __init__(self, lyrics_dataset_path='song_lyrics.csv'):
        """Read ``lyrics_dataset_path`` and build the list of lyric strings.

        Args:
            lyrics_dataset_path: Path to a UTF-8 encoded CSV file whose second
                column holds the lyric text.
        """
        super().__init__()
        self.lyrics_list = []
        self.end_of_text_token = tokenizer.eos_token

        # newline='' lets the csv module handle line endings itself, which it
        # needs in order to parse quoted fields containing embedded newlines.
        with open(lyrics_dataset_path, encoding='utf-8', newline='') as csv_file:
            for row in csv.reader(csv_file):
                # Blank lines come back as [] and malformed rows may lack the
                # lyric column; skip them instead of failing with IndexError.
                if len(row) < 2:
                    continue
                self.lyrics_list.append(f"SONG:{row[1]}{self.end_of_text_token}")

    def __len__(self):
        """Return the number of lyric strings loaded."""
        return len(self.lyrics_list)

    def __getitem__(self, item):
        """Return the lyric string at position ``item``."""
        return self.lyrics_list[item]

In [5]:
# Build the lyrics dataset from the artist CSV and wrap it in a loader that
# serves the lyric strings one at a time in shuffled order.
dataset = LyricsDataset(lyrics_dataset_path='SelenaGomez.csv')
lyrics_loader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True,
)

In [6]:
# Hyperparameters for fine-tuning
BATCH_SIZE = 16        # forward/backward passes accumulated per optimizer step
EPOCHS = 5             # full passes over the lyrics loader
LEARNING_RATE = 3e-5   # learning rate handed to the optimizer
WARMUP_STEPS = 5_000   # scheduler steps spent ramping the learning rate up
MAX_SEQ_LEN = 400      # maximum number of tokens in one training sequence

# Optimizer and scheduler
# transformers' own AdamW was deprecated in favour of the PyTorch optimizer and
# is not exported by recent releases, so AdamW is taken from torch.optim.
# NOTE: torch.optim.AdamW defaults to weight_decay=0.01, whereas the
# transformers implementation defaulted to 0.0.
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

# Same device selection as at the top of the notebook, repeated so this cell
# can be re-run on its own.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [7]:
# Fine-tune the model on packed lyric sequences with gradient accumulation.
#
# Lyrics are concatenated into sequences of at most MAX_SEQ_LEN tokens; each
# packed sequence is one forward/backward pass, and the optimizer steps once
# every BATCH_SIZE passes. Lyrics that alone exceed MAX_SEQ_LEN are skipped.

# The default weight decay differs between AdamW implementations
# (transformers: 0.0, torch.optim: 0.01), so it is pinned explicitly.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.0)

# The linear schedule needs a real step count: with num_training_steps=-1 its
# post-warmup factor is clamped to 0, i.e. the learning rate would drop to zero
# the moment warmup ends. Every optimizer step consumes at least BATCH_SIZE
# lyrics, so this value is an upper bound on the number of steps taken.
total_optimizer_steps = max(1, (EPOCHS * len(lyrics_loader)) // BATCH_SIZE)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=total_optimizer_steps,
)

# Training loop
proc_seq_count = 0      # backward passes accumulated since the last optimizer step
sum_loss = 0.0          # running loss total since the last report
batch_count = 0         # optimizer steps since the last report
tmp_lyrics_tens = None  # packed sequence currently being filled
models_folder = "trained_models"

os.makedirs(models_folder, exist_ok=True)

# from_pretrained() returns the model in evaluation mode (dropout disabled);
# switch to training mode before fine-tuning.
model.train()
optimizer.zero_grad()

for epoch in range(EPOCHS):
    print(f"EPOCH {epoch} started" + '=' * 30)

    for lyric in lyrics_loader:
        # With batch_size=1 the loader yields a one-element batch of strings.
        lyric_tens = torch.tensor(tokenizer.encode(lyric[0])).unsqueeze(0).to(device)

        # A lyric longer than the limit cannot be packed at all.
        if lyric_tens.size()[1] > MAX_SEQ_LEN:
            continue

        # First usable lyric: it starts the packed sequence.
        if not torch.is_tensor(tmp_lyrics_tens):
            tmp_lyrics_tens = lyric_tens
            continue

        if tmp_lyrics_tens.size()[1] + lyric_tens.size()[1] > MAX_SEQ_LEN:
            # The packed sequence is full: train on it and start a new one.
            work_lyrics_tens = tmp_lyrics_tens
            tmp_lyrics_tens = lyric_tens
        else:
            # Append the whole lyric. The GPT-2 tokenizer does not prepend a
            # special start token, so slicing off position 0 would discard the
            # first real token of the "SONG:" prefix.
            tmp_lyrics_tens = torch.cat([tmp_lyrics_tens, lyric_tens], dim=1)
            continue

        outputs = model(work_lyrics_tens, labels=work_lyrics_tens)
        loss = outputs[0]
        loss.backward()
        # .item() stores a plain float rather than a tensor on the device.
        sum_loss += loss.item()

        proc_seq_count += 1
        if proc_seq_count == BATCH_SIZE:
            proc_seq_count = 0
            batch_count += 1
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            # Report the accumulated loss every 100 optimizer steps.
            if batch_count == 100:
                print(f"sum loss {sum_loss}")
                batch_count = 0
                sum_loss = 0.0

    # Save model after each epoch
    try:
        torch.save(model.state_dict(), os.path.join(models_folder, f"gpt2_medium_lyrics_{epoch}.pt"))
    except Exception as e:
        # Without a checkpoint there is little point in continuing to train.
        print(f"Error saving model: {e}")
        break



In [10]:
# Generate lyrics with the checkpoint saved after epoch MODEL_EPOCH and write
# every sample that reaches the end-of-text token to a text file.
MODEL_EPOCH = 4
NUM_SAMPLES = 1000      # generation attempts
MAX_NEW_TOKENS = 100    # tokens sampled per attempt before giving up

models_folder = "trained_models"

model_path = os.path.join(models_folder, f"gpt2_medium_lyrics_{MODEL_EPOCH}.pt")
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))

lyrics_output_file_path = f'generated_{MODEL_EPOCH}.lyrics'

model.eval()
# Start from a clean file; samples are appended one by one below.
if os.path.exists(lyrics_output_file_path):
    os.remove(lyrics_output_file_path)

# Both encodings are loop-invariant, so compute them once.
end_of_text_ids = set(tokenizer.encode('<|endoftext|>'))
prompt_ids = tokenizer.encode("SONG:")

lyrics_num = 0  # number of finished samples written so far
with torch.no_grad():
    for _ in range(NUM_SAMPLES):
        lyrics_finished = False

        cur_ids = torch.tensor(prompt_ids).unsqueeze(0).to(device)

        for i in range(MAX_NEW_TOKENS):
            # Only the logits are needed for sampling, so no labels are passed
            # and no loss is computed; the logits are then the first output.
            outputs = model(cur_ids)
            logits = outputs[0]
            # Distribution over the vocabulary for the last position of the
            # single sequence in the batch.
            softmax_logits = torch.softmax(logits[0, -1], dim=0)
            # Sample more widely for the first few tokens, then narrow down.
            n = 20 if i < 3 else 3
            # Randomly pick the next token from the top-n distribution.
            next_token_id = choose_from_top(softmax_logits.to('cpu').numpy(), n=n)
            # Append the sampled token to the running sequence.
            next_token = torch.tensor([[next_token_id]], dtype=torch.long, device=device)
            cur_ids = torch.cat([cur_ids, next_token], dim=1)

            if next_token_id in end_of_text_ids:
                lyrics_finished = True
                break

        # Samples that never produced end-of-text are discarded.
        if lyrics_finished:
            lyrics_num = lyrics_num + 1

            output_list = cur_ids[0].tolist()
            output_text = tokenizer.decode(output_list)

            # Explicit UTF-8 so generated non-ASCII characters can always be
            # written regardless of the platform's default encoding.
            with open(lyrics_output_file_path, 'a', encoding='utf-8') as f:
                f.write(f"{output_text} \n\n")
                    

AttributeError: module 'tensorflow.core.framework.types_pb2' has no attribute 'SerializedDType'