In [87]:
import re
import warnings
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import time
import fasttext
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")

In [88]:

print("Loading data...")
path = "data/best_200songs_perartist_nw.csv"
data = pd.read_csv(path)
print("Data loaded. Shape:", data.shape)

# Filter lyrics for Taylor Swift
data = data[data['artist'] == 'Taylor Swift']
print("Data filtered for Taylor Swift. Shape:", data.shape)

# Further cleaning
def preprocess_text(text):
    text = text.lower()
    text = re.sub(r"[^a-zA-Z0-9\s]", "", text)
    return text

print("Preprocessing lyrics...")
data['clean_lyrics'] = data['clean_lyrics'].apply(preprocess_text)
print("Lyrics preprocessed. Sample data:\n", data['clean_lyrics'].head())

# Creating a corpus
print("Creating corpus...")
corpus = [lyrics.split() for lyrics in data['clean_lyrics']]
corpus_flat = [" ".join(sublist) for sublist in corpus]
with open("corpus.txt", "w") as file:
    file.write("\n".join(corpus_flat))
print("Corpus created. Total documents:", len(corpus))

# Train FastText model
print("Training FastText model...")
fasttext_model = fasttext.train_unsupervised('corpus.txt', model='skipgram', dim=100)
print("FastText model trained.")

# Creating vocabulary and mappings
print("Creating vocabulary and mappings...")
unique_words = fasttext_model.words
word_to_index = {word: idx for idx, word in enumerate(unique_words)}
index_to_word = {idx: word for word, idx in word_to_index.items()}
embedding_dim = fasttext_model.get_dimension()
print("Vocabulary size:", len(unique_words))

# Add a special token for unknown words
unk_token = '<UNK>'
word_to_index[unk_token] = len(word_to_index)
index_to_word[len(index_to_word)] = unk_token

Loading data...
Data loaded. Shape: (3800, 5)
Data filtered for Taylor Swift. Shape: (190, 5)
Preprocessing lyrics...
Lyrics preprocessed. Sample data:
 27    \n i walked through the door with you the air ...
31    \n vintage tee brand new phone \n high heels o...
48    \n we could leave the christmas lights up til ...
50    \n i m doing good i m on some new shit \n been...
65    \n i walked through the door with you the air ...
Name: clean_lyrics, dtype: object
Creating corpus...
Corpus created. Total documents: 190
Training FastText model...


Read 0M words
Number of words:  1169
Number of labels: 0
Progress:  63.7% words/sec/thread:  162579 lr:  0.018167 avg.loss:  2.593716 ETA:   0h 0m 0s

FastText model trained.
Creating vocabulary and mappings...
Vocabulary size: 1169


Progress: 100.0% words/sec/thread:  172235 lr:  0.000000 avg.loss:  2.608968 ETA:   0h 0m 0s


In [89]:
# Create an embedding matrix
print("Creating embedding matrix...")
embedding_matrix = np.zeros((len(unique_words) + 1, embedding_dim))
for word, idx in word_to_index.items():
    if word in fasttext_model:
        embedding_matrix[idx] = fasttext_model.get_word_vector(word)
    else:
        embedding_matrix[idx] = np.random.normal(size=(embedding_dim,))

# Prepare sequences for training
sequence_length = 40
features = []
targets = []
print("Preparing sequences for training...")
for lyrics in corpus:
    for i in range(len(lyrics) - sequence_length):
        input_sequence = [word_to_index.get(word, word_to_index[unk_token]) for word in lyrics[i:i+sequence_length]]
        target_word = word_to_index.get(lyrics[i+sequence_length], word_to_index[unk_token])
        features.append(input_sequence)
        targets.append(target_word)

print("Total sequences prepared:", len(features))

# Convert to numpy arrays
features = np.array(features)
targets = np.array(targets)
print("Converted sequences to numpy arrays. Features shape:", features.shape, "Targets shape:", targets.shape)

# Split the dataset into train and test sets
print("Splitting data into train and test sets...")
X_train, X_test, y_train, y_test = train_test_split(features, targets, test_size=0.2, random_state=42)
print("Data split. Train shape:", X_train.shape, "Test shape:", X_test.shape)

# Convert to PyTorch tensors
print("Converting data to PyTorch tensors...")
X_train = torch.tensor(X_train, dtype=torch.long)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)
print("Data converted to tensors.")

Creating embedding matrix...
Preparing sequences for training...
Total sequences prepared: 66990
Converted sequences to numpy arrays. Features shape: (66990, 40) Targets shape: (66990,)
Splitting data into train and test sets...
Data split. Train shape: (53592, 40) Test shape: (13398, 40)
Converting data to PyTorch tensors...
Data converted to tensors.


In [91]:
class LyricsDataset(Dataset):
    def __init__(self, features, targets):
        self.features = features
        self.targets = targets

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

print("Creating DataLoader...")
train_dataset = LyricsDataset(X_train, y_train)
test_dataset = LyricsDataset(X_test, y_test)

batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
print("DataLoader created.")

class LyricsLSTMModel(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, embedding_matrix):
        super(LyricsLSTMModel, self).__init__()
        self.embedding = nn.Embedding.from_pretrained(torch.tensor(embedding_matrix, dtype=torch.float32), freeze=False)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        h0 = torch.zeros(num_layers, x.size(0), hidden_size).to(x.device)
        c0 = torch.zeros(num_layers, x.size(0), hidden_size).to(x.device)
        out, _ = self.lstm(x, (h0, c0))
        out = out[:, -1, :]
        out = self.fc(out)
        return out

# Hyperparameters
vocab_size = len(unique_words) + 1  # Add 1 for the UNK token
embed_size = embedding_dim
hidden_size = 256
num_layers = 2
num_epochs = 150
learning_rate = 0.001

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
model = LyricsLSTMModel(vocab_size, embed_size, hidden_size, num_layers, embedding_matrix).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

Creating DataLoader...
DataLoader created.


In [92]:
# Train the model
def train_model(model, train_loader, criterion, optimizer, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        start_time = time.time()
        total_loss = 0
        with tqdm(total=len(train_loader), desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch') as pbar:
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                pbar.set_postfix({'loss': total_loss / (pbar.n + 1)})
                pbar.update(1)
        end_time = time.time()
        epoch_time = end_time - start_time
        avg_loss = total_loss / len(train_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Time: {epoch_time:.2f} sec')

print("Starting training...")
train_model(model, train_loader, criterion, optimizer, num_epochs)
print("Training completed.")

Starting training...


Epoch 1/150: 100%|██████████| 210/210 [00:16<00:00, 12.83batch/s, loss=5.56]


Epoch [1/150], Loss: 5.5561, Time: 16.37 sec


Epoch 2/150: 100%|██████████| 210/210 [00:13<00:00, 15.78batch/s, loss=5.26]


Epoch [2/150], Loss: 5.2635, Time: 13.31 sec


Epoch 3/150: 100%|██████████| 210/210 [00:13<00:00, 15.79batch/s, loss=4.79]


Epoch [3/150], Loss: 4.7899, Time: 13.30 sec


Epoch 4/150: 100%|██████████| 210/210 [00:13<00:00, 15.48batch/s, loss=4.52]


Epoch [4/150], Loss: 4.5181, Time: 13.57 sec


Epoch 5/150: 100%|██████████| 210/210 [00:14<00:00, 14.92batch/s, loss=4.33]


Epoch [5/150], Loss: 4.3251, Time: 14.08 sec


Epoch 6/150: 100%|██████████| 210/210 [00:14<00:00, 14.57batch/s, loss=4.15]


Epoch [6/150], Loss: 4.1540, Time: 14.41 sec


Epoch 7/150: 100%|██████████| 210/210 [00:14<00:00, 14.13batch/s, loss=4]   


Epoch [7/150], Loss: 4.0012, Time: 14.87 sec


Epoch 8/150: 100%|██████████| 210/210 [00:15<00:00, 13.84batch/s, loss=3.86]


Epoch [8/150], Loss: 3.8625, Time: 15.18 sec


Epoch 9/150: 100%|██████████| 210/210 [00:15<00:00, 13.57batch/s, loss=3.73]


Epoch [9/150], Loss: 3.7325, Time: 15.47 sec


Epoch 10/150: 100%|██████████| 210/210 [00:14<00:00, 14.38batch/s, loss=3.61]


Epoch [10/150], Loss: 3.6120, Time: 14.61 sec


Epoch 11/150: 100%|██████████| 210/210 [00:14<00:00, 14.31batch/s, loss=3.5] 


Epoch [11/150], Loss: 3.5006, Time: 14.68 sec


Epoch 12/150: 100%|██████████| 210/210 [00:14<00:00, 14.35batch/s, loss=3.39]


Epoch [12/150], Loss: 3.3914, Time: 14.63 sec


Epoch 13/150: 100%|██████████| 210/210 [00:14<00:00, 14.12batch/s, loss=3.29]


Epoch [13/150], Loss: 3.2907, Time: 14.87 sec


Epoch 14/150: 100%|██████████| 210/210 [00:14<00:00, 14.25batch/s, loss=3.19]


Epoch [14/150], Loss: 3.1939, Time: 14.74 sec


Epoch 15/150: 100%|██████████| 210/210 [00:15<00:00, 13.56batch/s, loss=3.1] 


Epoch [15/150], Loss: 3.0995, Time: 15.49 sec


Epoch 16/150: 100%|██████████| 210/210 [00:14<00:00, 14.09batch/s, loss=3.01]


Epoch [16/150], Loss: 3.0070, Time: 14.91 sec


Epoch 17/150: 100%|██████████| 210/210 [00:15<00:00, 13.98batch/s, loss=2.92]


Epoch [17/150], Loss: 2.9243, Time: 15.02 sec


Epoch 18/150: 100%|██████████| 210/210 [00:15<00:00, 13.32batch/s, loss=2.84]


Epoch [18/150], Loss: 2.8378, Time: 15.77 sec


Epoch 19/150: 100%|██████████| 210/210 [00:15<00:00, 13.38batch/s, loss=2.76]


Epoch [19/150], Loss: 2.7585, Time: 15.70 sec


Epoch 20/150: 100%|██████████| 210/210 [00:16<00:00, 12.97batch/s, loss=2.68]


Epoch [20/150], Loss: 2.6849, Time: 16.20 sec


Epoch 21/150: 100%|██████████| 210/210 [00:16<00:00, 13.03batch/s, loss=2.61]


Epoch [21/150], Loss: 2.6065, Time: 16.12 sec


Epoch 22/150: 100%|██████████| 210/210 [00:17<00:00, 12.26batch/s, loss=2.53]


Epoch [22/150], Loss: 2.5337, Time: 17.13 sec


Epoch 23/150: 100%|██████████| 210/210 [00:18<00:00, 11.56batch/s, loss=2.46]


Epoch [23/150], Loss: 2.4618, Time: 18.17 sec


Epoch 24/150: 100%|██████████| 210/210 [00:18<00:00, 11.46batch/s, loss=2.39]


Epoch [24/150], Loss: 2.3912, Time: 18.32 sec


Epoch 25/150: 100%|██████████| 210/210 [00:18<00:00, 11.59batch/s, loss=2.32]


Epoch [25/150], Loss: 2.3216, Time: 18.12 sec


Epoch 26/150: 100%|██████████| 210/210 [00:18<00:00, 11.63batch/s, loss=2.26]


Epoch [26/150], Loss: 2.2552, Time: 18.05 sec


Epoch 27/150: 100%|██████████| 210/210 [00:18<00:00, 11.65batch/s, loss=2.19]


Epoch [27/150], Loss: 2.1897, Time: 18.03 sec


Epoch 28/150: 100%|██████████| 210/210 [00:18<00:00, 11.56batch/s, loss=2.13]


Epoch [28/150], Loss: 2.1265, Time: 18.16 sec


Epoch 29/150: 100%|██████████| 210/210 [00:18<00:00, 11.52batch/s, loss=2.06]


Epoch [29/150], Loss: 2.0631, Time: 18.23 sec


Epoch 30/150: 100%|██████████| 210/210 [00:18<00:00, 11.67batch/s, loss=2]   


Epoch [30/150], Loss: 2.0040, Time: 18.00 sec


Epoch 31/150: 100%|██████████| 210/210 [00:18<00:00, 11.61batch/s, loss=1.94]


Epoch [31/150], Loss: 1.9433, Time: 18.10 sec


Epoch 32/150: 100%|██████████| 210/210 [00:17<00:00, 11.67batch/s, loss=1.89]


Epoch [32/150], Loss: 1.8853, Time: 18.00 sec


Epoch 33/150: 100%|██████████| 210/210 [00:17<00:00, 11.70batch/s, loss=1.83]


Epoch [33/150], Loss: 1.8312, Time: 17.95 sec


Epoch 34/150: 100%|██████████| 210/210 [00:18<00:00, 11.63batch/s, loss=1.78]


Epoch [34/150], Loss: 1.7761, Time: 18.05 sec


Epoch 35/150: 100%|██████████| 210/210 [00:18<00:00, 11.63batch/s, loss=1.72]


Epoch [35/150], Loss: 1.7221, Time: 18.05 sec


Epoch 36/150: 100%|██████████| 210/210 [00:18<00:00, 11.59batch/s, loss=1.67]


Epoch [36/150], Loss: 1.6701, Time: 18.13 sec


Epoch 37/150: 100%|██████████| 210/210 [00:18<00:00, 11.49batch/s, loss=1.62]


Epoch [37/150], Loss: 1.6184, Time: 18.27 sec


Epoch 38/150: 100%|██████████| 210/210 [00:17<00:00, 11.68batch/s, loss=1.57]


Epoch [38/150], Loss: 1.5701, Time: 17.98 sec


Epoch 39/150: 100%|██████████| 210/210 [00:18<00:00, 11.58batch/s, loss=1.52]


Epoch [39/150], Loss: 1.5199, Time: 18.14 sec


Epoch 40/150: 100%|██████████| 210/210 [00:18<00:00, 11.60batch/s, loss=1.47]


Epoch [40/150], Loss: 1.4715, Time: 18.11 sec


Epoch 41/150: 100%|██████████| 210/210 [00:18<00:00, 11.56batch/s, loss=1.43]


Epoch [41/150], Loss: 1.4267, Time: 18.16 sec


Epoch 42/150: 100%|██████████| 210/210 [00:18<00:00, 11.63batch/s, loss=1.38]


Epoch [42/150], Loss: 1.3788, Time: 18.06 sec


Epoch 43/150: 100%|██████████| 210/210 [00:18<00:00, 11.51batch/s, loss=1.34]


Epoch [43/150], Loss: 1.3358, Time: 18.24 sec


Epoch 44/150: 100%|██████████| 210/210 [00:18<00:00, 11.60batch/s, loss=1.29]


Epoch [44/150], Loss: 1.2923, Time: 18.10 sec


Epoch 45/150: 100%|██████████| 210/210 [00:18<00:00, 11.53batch/s, loss=1.25]


Epoch [45/150], Loss: 1.2482, Time: 18.21 sec


Epoch 46/150: 100%|██████████| 210/210 [00:18<00:00, 11.54batch/s, loss=1.21]


Epoch [46/150], Loss: 1.2062, Time: 18.20 sec


Epoch 47/150: 100%|██████████| 210/210 [00:18<00:00, 11.58batch/s, loss=1.17]


Epoch [47/150], Loss: 1.1664, Time: 18.14 sec


Epoch 48/150: 100%|██████████| 210/210 [00:18<00:00, 11.55batch/s, loss=1.13]


Epoch [48/150], Loss: 1.1261, Time: 18.19 sec


Epoch 49/150: 100%|██████████| 210/210 [00:18<00:00, 11.23batch/s, loss=1.09]


Epoch [49/150], Loss: 1.0882, Time: 18.70 sec


Epoch 50/150: 100%|██████████| 210/210 [00:19<00:00, 10.95batch/s, loss=1.05]


Epoch [50/150], Loss: 1.0523, Time: 19.18 sec


Epoch 51/150: 100%|██████████| 210/210 [00:20<00:00, 10.40batch/s, loss=1.01] 


Epoch [51/150], Loss: 1.0143, Time: 20.20 sec


Epoch 52/150: 100%|██████████| 210/210 [00:18<00:00, 11.12batch/s, loss=0.977]


Epoch [52/150], Loss: 0.9766, Time: 18.89 sec


Epoch 53/150: 100%|██████████| 210/210 [00:18<00:00, 11.22batch/s, loss=0.941]


Epoch [53/150], Loss: 0.9414, Time: 18.72 sec


Epoch 54/150: 100%|██████████| 210/210 [00:18<00:00, 11.16batch/s, loss=0.907]


Epoch [54/150], Loss: 0.9067, Time: 18.81 sec


Epoch 55/150: 100%|██████████| 210/210 [00:18<00:00, 11.28batch/s, loss=0.875]


Epoch [55/150], Loss: 0.8751, Time: 18.62 sec


Epoch 56/150: 100%|██████████| 210/210 [00:18<00:00, 11.39batch/s, loss=0.841]


Epoch [56/150], Loss: 0.8413, Time: 18.44 sec


Epoch 57/150: 100%|██████████| 210/210 [00:19<00:00, 10.94batch/s, loss=0.806]


Epoch [57/150], Loss: 0.8055, Time: 19.20 sec


Epoch 58/150: 100%|██████████| 210/210 [00:18<00:00, 11.48batch/s, loss=0.776]


Epoch [58/150], Loss: 0.7758, Time: 18.30 sec


Epoch 59/150: 100%|██████████| 210/210 [00:18<00:00, 11.24batch/s, loss=0.745]


Epoch [59/150], Loss: 0.7445, Time: 18.69 sec


Epoch 60/150: 100%|██████████| 210/210 [00:18<00:00, 11.29batch/s, loss=0.714]


Epoch [60/150], Loss: 0.7143, Time: 18.61 sec


Epoch 61/150: 100%|██████████| 210/210 [00:18<00:00, 11.24batch/s, loss=0.685]


Epoch [61/150], Loss: 0.6849, Time: 18.68 sec


Epoch 62/150: 100%|██████████| 210/210 [00:18<00:00, 11.20batch/s, loss=0.656]


Epoch [62/150], Loss: 0.6561, Time: 18.75 sec


Epoch 63/150: 100%|██████████| 210/210 [00:18<00:00, 11.17batch/s, loss=0.628]


Epoch [63/150], Loss: 0.6282, Time: 18.79 sec


Epoch 64/150: 100%|██████████| 210/210 [00:18<00:00, 11.20batch/s, loss=0.602]


Epoch [64/150], Loss: 0.6018, Time: 18.75 sec


Epoch 65/150: 100%|██████████| 210/210 [00:18<00:00, 11.31batch/s, loss=0.575]


Epoch [65/150], Loss: 0.5750, Time: 18.57 sec


Epoch 66/150: 100%|██████████| 210/210 [00:18<00:00, 11.06batch/s, loss=0.55] 


Epoch [66/150], Loss: 0.5497, Time: 18.98 sec


Epoch 67/150: 100%|██████████| 210/210 [00:18<00:00, 11.13batch/s, loss=0.524]


Epoch [67/150], Loss: 0.5240, Time: 18.88 sec


Epoch 68/150: 100%|██████████| 210/210 [00:18<00:00, 11.29batch/s, loss=0.504]


Epoch [68/150], Loss: 0.5040, Time: 18.60 sec


Epoch 69/150: 100%|██████████| 210/210 [00:19<00:00, 10.90batch/s, loss=0.478]


Epoch [69/150], Loss: 0.4784, Time: 19.28 sec


Epoch 70/150: 100%|██████████| 210/210 [00:19<00:00, 10.94batch/s, loss=0.456]


Epoch [70/150], Loss: 0.4562, Time: 19.21 sec


Epoch 71/150: 100%|██████████| 210/210 [00:19<00:00, 10.75batch/s, loss=0.433]


Epoch [71/150], Loss: 0.4330, Time: 19.53 sec


Epoch 72/150: 100%|██████████| 210/210 [00:19<00:00, 10.62batch/s, loss=0.413]


Epoch [72/150], Loss: 0.4131, Time: 19.77 sec


Epoch 73/150: 100%|██████████| 210/210 [00:19<00:00, 10.50batch/s, loss=0.392]


Epoch [73/150], Loss: 0.3918, Time: 20.00 sec


Epoch 74/150: 100%|██████████| 210/210 [00:19<00:00, 10.53batch/s, loss=0.372]


Epoch [74/150], Loss: 0.3724, Time: 19.95 sec


Epoch 75/150: 100%|██████████| 210/210 [00:20<00:00, 10.26batch/s, loss=0.353]


Epoch [75/150], Loss: 0.3534, Time: 20.48 sec


Epoch 76/150: 100%|██████████| 210/210 [00:20<00:00, 10.21batch/s, loss=0.339]


Epoch [76/150], Loss: 0.3390, Time: 20.57 sec


Epoch 77/150: 100%|██████████| 210/210 [00:21<00:00,  9.89batch/s, loss=0.324]


Epoch [77/150], Loss: 0.3240, Time: 21.23 sec


Epoch 78/150: 100%|██████████| 210/210 [00:21<00:00,  9.86batch/s, loss=0.305]


Epoch [78/150], Loss: 0.3050, Time: 21.30 sec


Epoch 79/150: 100%|██████████| 210/210 [00:21<00:00,  9.97batch/s, loss=0.283]


Epoch [79/150], Loss: 0.2834, Time: 21.06 sec


Epoch 80/150: 100%|██████████| 210/210 [00:21<00:00,  9.96batch/s, loss=0.27] 


Epoch [80/150], Loss: 0.2699, Time: 21.08 sec


Epoch 81/150: 100%|██████████| 210/210 [00:20<00:00, 10.25batch/s, loss=0.257]


Epoch [81/150], Loss: 0.2569, Time: 20.50 sec


Epoch 82/150: 100%|██████████| 210/210 [00:20<00:00, 10.06batch/s, loss=0.241]


Epoch [82/150], Loss: 0.2412, Time: 20.89 sec


Epoch 83/150: 100%|██████████| 210/210 [00:21<00:00,  9.79batch/s, loss=0.228]


Epoch [83/150], Loss: 0.2280, Time: 21.44 sec


Epoch 84/150: 100%|██████████| 210/210 [00:21<00:00,  9.75batch/s, loss=0.216]


Epoch [84/150], Loss: 0.2155, Time: 21.54 sec


Epoch 85/150: 100%|██████████| 210/210 [00:22<00:00,  9.37batch/s, loss=0.204]


Epoch [85/150], Loss: 0.2043, Time: 22.42 sec


Epoch 86/150: 100%|██████████| 210/210 [00:21<00:00,  9.81batch/s, loss=0.2]  


Epoch [86/150], Loss: 0.1997, Time: 21.41 sec


Epoch 87/150: 100%|██████████| 210/210 [00:22<00:00,  9.36batch/s, loss=0.186]


Epoch [87/150], Loss: 0.1858, Time: 22.44 sec


Epoch 88/150: 100%|██████████| 210/210 [00:24<00:00,  8.68batch/s, loss=0.174]


Epoch [88/150], Loss: 0.1741, Time: 24.19 sec


Epoch 89/150: 100%|██████████| 210/210 [00:23<00:00,  8.90batch/s, loss=0.165]


Epoch [89/150], Loss: 0.1653, Time: 23.60 sec


Epoch 90/150: 100%|██████████| 210/210 [00:23<00:00,  8.87batch/s, loss=0.154]


Epoch [90/150], Loss: 0.1542, Time: 23.69 sec


Epoch 91/150: 100%|██████████| 210/210 [00:23<00:00,  8.80batch/s, loss=0.141]


Epoch [91/150], Loss: 0.1409, Time: 23.87 sec


Epoch 92/150: 100%|██████████| 210/210 [00:23<00:00,  8.86batch/s, loss=0.138]


Epoch [92/150], Loss: 0.1375, Time: 23.70 sec


Epoch 93/150: 100%|██████████| 210/210 [00:23<00:00,  8.75batch/s, loss=0.128]


Epoch [93/150], Loss: 0.1284, Time: 24.00 sec


Epoch 94/150: 100%|██████████| 210/210 [00:24<00:00,  8.65batch/s, loss=0.121]


Epoch [94/150], Loss: 0.1212, Time: 24.29 sec


Epoch 95/150: 100%|██████████| 210/210 [00:23<00:00,  8.81batch/s, loss=0.116]


Epoch [95/150], Loss: 0.1156, Time: 23.84 sec


Epoch 96/150: 100%|██████████| 210/210 [00:24<00:00,  8.67batch/s, loss=0.114]


Epoch [96/150], Loss: 0.1137, Time: 24.22 sec


Epoch 97/150: 100%|██████████| 210/210 [00:23<00:00,  9.08batch/s, loss=0.146]


Epoch [97/150], Loss: 0.1458, Time: 23.13 sec


Epoch 98/150: 100%|██████████| 210/210 [00:22<00:00,  9.39batch/s, loss=0.174]


Epoch [98/150], Loss: 0.1736, Time: 22.37 sec


Epoch 99/150: 100%|██████████| 210/210 [00:21<00:00,  9.56batch/s, loss=0.111]


Epoch [99/150], Loss: 0.1106, Time: 21.97 sec


Epoch 100/150: 100%|██████████| 210/210 [00:21<00:00,  9.69batch/s, loss=0.0852]


Epoch [100/150], Loss: 0.0852, Time: 21.67 sec


Epoch 101/150: 100%|██████████| 210/210 [00:21<00:00,  9.83batch/s, loss=0.0786]


Epoch [101/150], Loss: 0.0786, Time: 21.36 sec


Epoch 102/150: 100%|██████████| 210/210 [00:21<00:00,  9.99batch/s, loss=0.0705]


Epoch [102/150], Loss: 0.0705, Time: 21.02 sec


Epoch 103/150: 100%|██████████| 210/210 [00:20<00:00, 10.11batch/s, loss=0.0669]


Epoch [103/150], Loss: 0.0669, Time: 20.78 sec


Epoch 104/150: 100%|██████████| 210/210 [00:20<00:00, 10.06batch/s, loss=0.0654]


Epoch [104/150], Loss: 0.0654, Time: 20.88 sec


Epoch 105/150: 100%|██████████| 210/210 [00:21<00:00,  9.97batch/s, loss=0.0624]


Epoch [105/150], Loss: 0.0624, Time: 21.07 sec


Epoch 106/150: 100%|██████████| 210/210 [00:20<00:00, 10.03batch/s, loss=0.0619]


Epoch [106/150], Loss: 0.0619, Time: 20.94 sec


Epoch 107/150: 100%|██████████| 210/210 [00:20<00:00, 10.03batch/s, loss=0.0731]


Epoch [107/150], Loss: 0.0731, Time: 20.94 sec


Epoch 108/150: 100%|██████████| 210/210 [00:21<00:00, 10.00batch/s, loss=0.107] 


Epoch [108/150], Loss: 0.1071, Time: 21.01 sec


Epoch 109/150: 100%|██████████| 210/210 [00:21<00:00,  9.96batch/s, loss=0.131]


Epoch [109/150], Loss: 0.1312, Time: 21.08 sec


Epoch 110/150: 100%|██████████| 210/210 [00:21<00:00,  9.74batch/s, loss=0.0735]


Epoch [110/150], Loss: 0.0735, Time: 21.56 sec


Epoch 111/150: 100%|██████████| 210/210 [00:21<00:00,  9.79batch/s, loss=0.0534]


Epoch [111/150], Loss: 0.0534, Time: 21.46 sec


Epoch 112/150: 100%|██████████| 210/210 [00:21<00:00,  9.82batch/s, loss=0.0449]


Epoch [112/150], Loss: 0.0449, Time: 21.39 sec


Epoch 113/150: 100%|██████████| 210/210 [00:21<00:00,  9.91batch/s, loss=0.0415]


Epoch [113/150], Loss: 0.0415, Time: 21.19 sec


Epoch 114/150: 100%|██████████| 210/210 [00:20<00:00, 10.03batch/s, loss=0.0395]


Epoch [114/150], Loss: 0.0395, Time: 20.95 sec


Epoch 115/150: 100%|██████████| 210/210 [00:20<00:00, 10.08batch/s, loss=0.0444]


Epoch [115/150], Loss: 0.0444, Time: 20.83 sec


Epoch 116/150: 100%|██████████| 210/210 [00:20<00:00, 10.10batch/s, loss=0.0489]


Epoch [116/150], Loss: 0.0489, Time: 20.79 sec


Epoch 117/150: 100%|██████████| 210/210 [00:20<00:00, 10.09batch/s, loss=0.046] 


Epoch [117/150], Loss: 0.0460, Time: 20.82 sec


Epoch 118/150: 100%|██████████| 210/210 [00:20<00:00, 10.09batch/s, loss=0.0426]


Epoch [118/150], Loss: 0.0426, Time: 20.81 sec


Epoch 119/150: 100%|██████████| 210/210 [00:20<00:00, 10.05batch/s, loss=0.041] 


Epoch [119/150], Loss: 0.0410, Time: 20.91 sec


Epoch 120/150: 100%|██████████| 210/210 [00:20<00:00, 10.05batch/s, loss=0.0582]


Epoch [120/150], Loss: 0.0582, Time: 20.90 sec


Epoch 121/150: 100%|██████████| 210/210 [00:21<00:00, 10.00batch/s, loss=0.249]


Epoch [121/150], Loss: 0.2486, Time: 21.01 sec


Epoch 122/150: 100%|██████████| 210/210 [00:21<00:00,  9.99batch/s, loss=0.0844]


Epoch [122/150], Loss: 0.0844, Time: 21.03 sec


Epoch 123/150: 100%|██████████| 210/210 [00:21<00:00,  9.90batch/s, loss=0.0407]


Epoch [123/150], Loss: 0.0407, Time: 21.22 sec


Epoch 124/150: 100%|██████████| 210/210 [00:21<00:00,  9.79batch/s, loss=0.0319]


Epoch [124/150], Loss: 0.0319, Time: 21.46 sec


Epoch 125/150: 100%|██████████| 210/210 [00:21<00:00,  9.74batch/s, loss=0.0295]


Epoch [125/150], Loss: 0.0295, Time: 21.57 sec


Epoch 126/150: 100%|██████████| 210/210 [00:21<00:00,  9.62batch/s, loss=0.0278]


Epoch [126/150], Loss: 0.0278, Time: 21.84 sec


Epoch 127/150: 100%|██████████| 210/210 [00:21<00:00,  9.57batch/s, loss=0.0267]


Epoch [127/150], Loss: 0.0267, Time: 21.94 sec


Epoch 128/150: 100%|██████████| 210/210 [00:22<00:00,  9.48batch/s, loss=0.0261]


Epoch [128/150], Loss: 0.0261, Time: 22.16 sec


Epoch 129/150: 100%|██████████| 210/210 [00:22<00:00,  9.44batch/s, loss=0.0257]


Epoch [129/150], Loss: 0.0257, Time: 22.26 sec


Epoch 130/150: 100%|██████████| 210/210 [00:22<00:00,  9.39batch/s, loss=0.0257]


Epoch [130/150], Loss: 0.0257, Time: 22.36 sec


Epoch 131/150: 100%|██████████| 210/210 [00:22<00:00,  9.33batch/s, loss=0.0254]


Epoch [131/150], Loss: 0.0254, Time: 22.51 sec


Epoch 132/150: 100%|██████████| 210/210 [00:22<00:00,  9.27batch/s, loss=0.0271]


Epoch [132/150], Loss: 0.0271, Time: 22.67 sec


Epoch 133/150: 100%|██████████| 210/210 [00:22<00:00,  9.47batch/s, loss=0.0276]


Epoch [133/150], Loss: 0.0276, Time: 22.18 sec


Epoch 134/150: 100%|██████████| 210/210 [00:22<00:00,  9.46batch/s, loss=0.047] 


Epoch [134/150], Loss: 0.0470, Time: 22.20 sec


Epoch 135/150: 100%|██████████| 210/210 [00:21<00:00,  9.57batch/s, loss=0.327]


Epoch [135/150], Loss: 0.3266, Time: 21.94 sec


Epoch 136/150: 100%|██████████| 210/210 [00:21<00:00,  9.63batch/s, loss=0.0836]


Epoch [136/150], Loss: 0.0836, Time: 21.81 sec


Epoch 137/150: 100%|██████████| 210/210 [00:21<00:00,  9.67batch/s, loss=0.0373]


Epoch [137/150], Loss: 0.0373, Time: 21.72 sec


Epoch 138/150: 100%|██████████| 210/210 [00:21<00:00,  9.68batch/s, loss=0.0275]


Epoch [138/150], Loss: 0.0275, Time: 21.70 sec


Epoch 139/150: 100%|██████████| 210/210 [00:21<00:00,  9.60batch/s, loss=0.0245]


Epoch [139/150], Loss: 0.0245, Time: 21.87 sec


Epoch 140/150: 100%|██████████| 210/210 [00:21<00:00,  9.63batch/s, loss=0.0223]


Epoch [140/150], Loss: 0.0223, Time: 21.81 sec


Epoch 141/150: 100%|██████████| 210/210 [00:21<00:00,  9.64batch/s, loss=0.0218]


Epoch [141/150], Loss: 0.0218, Time: 21.78 sec


Epoch 142/150: 100%|██████████| 210/210 [00:21<00:00,  9.72batch/s, loss=0.0209]


Epoch [142/150], Loss: 0.0209, Time: 21.60 sec


Epoch 143/150: 100%|██████████| 210/210 [00:21<00:00,  9.81batch/s, loss=0.0208]


Epoch [143/150], Loss: 0.0208, Time: 21.40 sec


Epoch 144/150: 100%|██████████| 210/210 [00:21<00:00,  9.78batch/s, loss=0.0208]


Epoch [144/150], Loss: 0.0208, Time: 21.48 sec


Epoch 145/150: 100%|██████████| 210/210 [00:21<00:00,  9.76batch/s, loss=0.0209]


Epoch [145/150], Loss: 0.0209, Time: 21.51 sec


Epoch 146/150: 100%|██████████| 210/210 [00:21<00:00,  9.76batch/s, loss=0.0208]


Epoch [146/150], Loss: 0.0208, Time: 21.52 sec


Epoch 147/150: 100%|██████████| 210/210 [00:21<00:00,  9.80batch/s, loss=0.0242]


Epoch [147/150], Loss: 0.0242, Time: 21.43 sec


Epoch 148/150: 100%|██████████| 210/210 [00:21<00:00,  9.89batch/s, loss=0.027] 


Epoch [148/150], Loss: 0.0270, Time: 21.23 sec


Epoch 149/150: 100%|██████████| 210/210 [00:21<00:00,  9.95batch/s, loss=0.192] 


Epoch [149/150], Loss: 0.1915, Time: 21.11 sec


Epoch 150/150: 100%|██████████| 210/210 [00:20<00:00, 10.00batch/s, loss=0.14] 

Epoch [150/150], Loss: 0.1399, Time: 20.99 sec
Training completed.





In [93]:
# Function to generate lyrics with temperature sampling and formatting
def generate_lyrics(model, start_text, length, temperature=1.0):
    model.eval()
    generated = start_text
    input_sequence = [word_to_index.get(word, word_to_index[unk_token]) for word in start_text.split()]
    generated_words = start_text.split()
    for _ in range(length):
        input_tensor = torch.tensor(input_sequence).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(input_tensor)
        output = output.squeeze().div(temperature).exp()
        word_weights = output.cpu().numpy()
        next_word_idx = np.random.choice(len(word_weights), p=word_weights/word_weights.sum())
        next_word = index_to_word[next_word_idx]
        generated_words.append(next_word)
        input_sequence.append(next_word_idx)
        input_sequence = input_sequence[1:]

    return generated_words

In [94]:
# Function to format lyrics into verses and choruses
def format_lyrics(generated_words):
    formatted_lyrics = "\nVerse 1:\n"
    line_length = 0
    verse_count = 1
    chorus_count = 1
    is_chorus = False

    for word in generated_words:
        formatted_lyrics += word + " "
        line_length += 1
        if line_length >= 7:
            formatted_lyrics = formatted_lyrics.strip() + "\n"
            line_length = 0
            if is_chorus:
                chorus_count += 1
                if chorus_count > 4:
                    is_chorus = False
                    chorus_count = 1
                    verse_count += 1
                    formatted_lyrics += f"\nVerse {verse_count}:\n"
            else:
                verse_count += 1
                if verse_count > 4:
                    is_chorus = True
                    verse_count = 1
                    formatted_lyrics += f"\nChorus:\n"

    return formatted_lyrics.strip()

# Example usage for lyrics generation
start_text = "Once upon a time"
generated_words = generate_lyrics(model, start_text, 72, temperature=1.2)
formatted_lyrics = format_lyrics(generated_words)
print("Generated lyrics:\n", formatted_lyrics)

Generated lyrics:
 Verse 1:
Once upon a time break anyone away
oh you had a sound like me
when let what that i i sweet
ta been to again had but away

Chorus:
i never did <UNK> to just you
and me little running soul shine thinking
knowing ll say is she we talk
of a dance need first face but

Verse 2:
we saw everyone is we ll gon
gon together you see the feeling every
time his way on a <UNK>
