In [None]:
import os
import re
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import autocast, GradScaler
import sentencepiece as spm
from sklearn.model_selection import train_test_split

# Root directory for all intermediate files, token arrays and checkpoints.
# Paths are built as mainPath + "filename", so the value must end with a separator;
# os.path.join(..., "") guarantees that. Set LYRICS_DATA_DIR to relocate the data.
mainPath = os.path.join(os.environ.get("LYRICS_DATA_DIR", "/virtual/akhtar79/"), "")

## Preprocess Song Lyrics

The code below provides a function to clean lyrics. It removes annotations (e.g. [Chorus], [Verse 1]) and parenthesized text, quotation marks, non-ASCII characters, and other special characters. It also converts line breaks into `<V>` tokens, lowercases the text, and normalizes whitespace.

In [None]:
def remove_bracketed_annotations(text):
    """Strip section annotations such as "[Chorus]" and any "(...)" asides.

    Square-bracket spans are removed first, then parenthesized spans; each match is
    non-greedy and does not cross a newline.
    """
    for annotation_pattern in (r"\[.*?\]", r"\(.*?\)"):
        text = re.sub(annotation_pattern, "", text)
    return text

def remove_quotation_marks(text):
    """Remove straight and typographic quotation marks / apostrophes.

    Deletes ", ', and the curly forms U+201C/U+201D (double) and U+2018/U+2019
    (single). The apostrophe is removed too, so "don't" becomes "dont".
    """
    return re.sub(r'["\'\u201c\u201d\u2018\u2019]', '', text)

def convert_newlines_to_v_tokens(text):
    """Replace each run of one or more newlines with a single " <V> " line marker."""
    # A single `\n+` pass already collapses blank-line runs into one marker.
    return re.sub(r"\n+", " <V> ", text)

def normalize_whitespace(text):
    """Collapse every whitespace run to one space and trim both ends."""
    # str.split() with no argument splits on the same whitespace set as regex \s
    # and discards empty pieces, so re-joining yields the trimmed, collapsed text.
    return " ".join(text.split())

def remove_non_ascii(text):
    """Drop every character outside the 7-bit ASCII range (code points >= 128)."""
    return "".join(ch for ch in text if ord(ch) < 128)

def remove_standalone_punctuation(text):
    """Tidy punctuation that stands alone between whitespace.

    Despite the name, a standalone . , ! ? ; or : is kept; only the whitespace
    around it is reduced to single spaces. A standalone ' or " (whitespace on both
    sides) is removed, along with the surrounding whitespace, leaving one space.
    """
    spaced = re.sub(r'\s+([.,!?;:])\s+', r' \1 ', text)
    without_lone_quotes = re.sub(r'\s+([\'"])\s+', ' ', spaced)
    return without_lone_quotes

def remove_duplicate_punctuation(text):
    """Collapse a run of two or more characters from [.,!?;:] into one.

    The character kept is the last one in the run (the regex group keeps its
    final repetition), e.g. "!!!" -> "!" and "?!" -> "!".
    """
    return re.sub(r'([.,!?;:]){2,}', r'\1', text)

def remove_special_characters(text):
    """Delete every bracket-like character: ( ) { } [ ] < >."""
    return text.translate(str.maketrans("", "", "(){}[]<>"))


def clean_lyrics(text):
    """Clean and normalize song lyrics by removing annotations, special characters, and normalizing whitespace.

    Pipeline: drop [..]/(..) annotations and quotation marks, strip bracket-like
    characters, drop non-ASCII, lowercase, turn line breaks into "<V>" markers,
    normalize whitespace and punctuation spacing, and collapse repeated punctuation.

    The "<V>" marker is inserted only after lowercasing and after bracket-like
    characters have been stripped. If it were inserted earlier, lower() would
    turn it into "<v>" and remove_special_characters would reduce it to a bare
    "v", so the "<V>" symbol registered with the tokenizer would never appear.

    Args:
        text: raw lyrics; missing values (NaN/None) are accepted.

    Returns:
        The cleaned lyric string, or "" for missing input.
    """
    if pd.isna(text):
        return ""
    text = str(text)  # guard against non-string cells (e.g. numbers) from the CSV

    text = remove_bracketed_annotations(text)
    text = remove_quotation_marks(text)
    # Remove ( ) { } [ ] < > from the raw lyrics before any "<V>" marker exists.
    text = remove_special_characters(text)
    text = remove_non_ascii(text)
    text = text.lower()
    # Insert line markers after lower() so they keep the exact "<V>" form.
    text = convert_newlines_to_v_tokens(text)
    text = normalize_whitespace(text)
    text = remove_standalone_punctuation(text)
    text = remove_duplicate_punctuation(text)
    return normalize_whitespace(text)

The code below filters the songs dataset to keep only English songs in the rap, pop, rock, R&B (`rb`), and country genres. It cleans the lyrics with the function above and drops songs that are too short or too long (by word count). It then prepends a genre token (e.g. `<RAP>`) to each lyric and saves the result to the CSV `english_cleaned_songs.csv`. The input is read in chunks of 100,000 rows to limit memory use.

In [None]:
def filter_by_language_and_genre(chunk):
    """Keep English-language rows whose `tag` is one of the five modelled genres.

    Returns a filtered frame built from a copy of `chunk`; the input is not modified.
    """
    genres_to_keep = {"rap", "pop", "rock", "rb", "country"}
    is_english = chunk["language"] == "en"
    is_not_misc = chunk["tag"] != "misc"
    subset = chunk.loc[is_english & is_not_misc].copy()
    return subset[subset["tag"].isin(genres_to_keep)]


def apply_lyrics_cleaning(chunk):
    """Add a `clean_lyrics` column computed by running clean_lyrics on `lyrics`.

    The column is assigned on `chunk` itself, and the same frame is returned.
    """
    cleaned_column = chunk["lyrics"].apply(clean_lyrics)
    chunk["clean_lyrics"] = cleaned_column
    return chunk


def filter_lyrics_by_word_count(chunk, min_words, max_words):
    """Keep rows whose `clean_lyrics` has between min_words and max_words words.

    Words are whitespace-separated tokens, and both bounds are inclusive. The input
    frame is left untouched. The count is computed in a local Series rather
    than a temporary column on `chunk`. The result is an explicit copy, so later
    column assignments on it act on an independent frame.

    Args:
        chunk: DataFrame with a string `clean_lyrics` column.
        min_words: minimum word count (inclusive).
        max_words: maximum word count (inclusive).

    Returns:
        A new DataFrame with the same columns as `chunk`.
    """
    word_counts = chunk["clean_lyrics"].str.split().str.len()
    in_range = (word_counts >= min_words) & (word_counts <= max_words)
    return chunk.loc[in_range].copy()


def add_genre_token_to_lyrics(chunk):
    """Prefix each lyric with its upper-cased genre tag, e.g. "rap" -> "<RAP> ...".

    `chunk["clean_lyrics"]` is overwritten in place, and the same frame is returned.
    """
    genre_tags = "<" + chunk["tag"].str.upper() + "> "
    chunk["clean_lyrics"] = genre_tags + chunk["clean_lyrics"]
    return chunk


def save_chunk_to_csv(chunk, output_csv, is_first_chunk):
    """Write one processed chunk to `output_csv` without the index.

    The first chunk truncates the file and writes the header row. Later chunks are
    appended without a header, so the file reads back as one table.
    """
    write_options = {"index": False}
    if not is_first_chunk:
        write_options.update(mode='a', header=False)
    else:
        write_options["mode"] = 'w'
    chunk.to_csv(output_csv, **write_options)


def clean_and_filter_lyrics_dataset(
    input_csv='/content/genius_dataset/song_lyrics.csv',
    output_csv='english_cleaned_songs.csv',
    min_words=50,
    max_words=1000
):
    """Process lyrics dataset in chunks: filter by language/genre, clean text, add genre tokens, and save to CSV.

    The input is streamed 100,000 rows at a time. Each chunk is filtered to English
    songs in the allowed genres, cleaned, and limited to [min_words, max_words] words.
    It is then prefixed with its genre token and written to `output_csv`. The first
    chunk overwrites the file; later chunks are appended.
    """
    print("Processing dataset in chunks...")

    for chunk_number, raw_chunk in enumerate(pd.read_csv(input_csv, chunksize=100000), start=1):
        processed = filter_by_language_and_genre(raw_chunk)
        processed = apply_lyrics_cleaning(processed)
        processed = filter_lyrics_by_word_count(processed, min_words, max_words)
        processed = add_genre_token_to_lyrics(processed)

        save_chunk_to_csv(processed, output_csv, chunk_number == 1)
        print(f"Processed {chunk_number} chunk(s)")

    print("Cleaning complete!")

Balance the dataset by capping each genre at `max_per_genre` songs so that no genre dominates (genres with fewer songs keep all of theirs):

In [None]:
def balance_genres(input_csv="english_cleaned_songs.csv", 
                   output_csv="english_cleaned_reduced.csv", 
                   max_per_genre=85000):
    """
    Balance dataset to have equal samples per genre

    Streams `input_csv` in chunks and keeps at most `max_per_genre` rows per genre
    (rap, pop, rock, rb, country). Rows are taken in file order. When a chunk would
    overshoot a genre's quota, the remainder is sampled randomly (seed 42). Genres
    with fewer rows than the cap keep all of them. Reading stops as soon as every
    genre is full.

    Args:
        input_csv: CSV with a `tag` column (output of clean_and_filter_lyrics_dataset).
        output_csv: destination CSV; overwritten on the first write.
        max_per_genre: per-genre cap.
    """
    allowed_genres = ["rap", "pop", "rock", "rb", "country"]
    counts = {genre: 0 for genre in allowed_genres}
    chunksize = 100000
    first_write = True
    
    print("Balancing genres...")
    # The context manager closes the underlying file even when we `break` early.
    with pd.read_csv(input_csv, chunksize=chunksize) as reader:
        for chunk in reader:
            chunk = chunk[chunk["tag"].isin(allowed_genres)]
            # Drop rows for genres whose quota is already full (vectorized lookup).
            chunk = chunk[chunk["tag"].map(counts) < max_per_genre]

            if chunk.empty:
                continue

            sampled_chunks = []
            for genre, group in chunk.groupby("tag"):
                remaining = max_per_genre - counts[genre]
                if len(group) > remaining:
                    group = group.sample(n=remaining, random_state=42)
                counts[genre] += len(group)
                sampled_chunks.append(group)

            final_chunk = pd.concat(sampled_chunks)
            final_chunk.to_csv(output_csv, mode='w' if first_write else 'a',
                               index=False, header=first_write)
            first_write = False

            if all(count >= max_per_genre for count in counts.values()):
                break
    
    print(f"Balanced dataset saved: {output_csv}")
    print("Final counts per genre:", counts)

Create the training and validation sets (an 80/20 split) and write their lyrics to `train.txt` and `val.txt`, one song per line:

In [None]:
def create_train_val_splits(csv_file="english_cleaned_reduced.csv", output_dir=None):
    """
    Split data into train and val sets

    Shuffles the balanced CSV, splits it 80/20 (seed 42), and writes the
    `clean_lyrics` column to `train.txt` / `val.txt` with one song per line.

    The lines are written as plain text rather than with DataFrame.to_csv. CSV
    output wraps any lyric containing a comma in double quotes. Those quotes would
    then be tokenized as part of the song and could push the genre token away from
    the start of the sequence.

    Args:
        csv_file: balanced CSV with a `clean_lyrics` column.
        output_dir: directory for the text files; defaults to mainPath.
    """
    if output_dir is None:
        output_dir = mainPath

    df = pd.read_csv(csv_file)
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    
    train, val = train_test_split(df, test_size=0.2, random_state=42)

    def _write_lines(texts, path):
        # One song per line; any stray line breaks are folded so a song never spans lines.
        with open(path, "w", encoding="utf-8") as handle:
            for text in texts:
                handle.write(" ".join(str(text).splitlines()) + "\n")

    _write_lines(train["clean_lyrics"], os.path.join(output_dir, "train.txt"))
    _write_lines(val["clean_lyrics"], os.path.join(output_dir, "val.txt"))
    
    print(f"Train: {len(train)}, Val: {len(val)}")

## Tokenization

In [None]:
def train_tokenizer(vocab_size=10000):
    """
    Train a BPE SentencePiece tokenizer on mainPath + 'train.txt'.

    Any previous lyric_tokenizer.model / .vocab in the working directory is deleted
    first. The genre tags and the "<V>" line marker are registered as user-defined
    symbols so they are never split. Special ids: pad=0, unk=1, bos=2, eos=3.
    """
    print(f"Training tokenizer with vocab_size={vocab_size}...")
    for stale_file in ("lyric_tokenizer.model", "lyric_tokenizer.vocab"):
        if os.path.exists(stale_file):
            os.remove(stale_file)

    reserved_symbols = ["<RAP>", "<POP>", "<ROCK>", "<RB>", "<COUNTRY>", "<V>"]
    trainer_options = dict(
        input=mainPath+'train.txt',
        model_prefix='lyric_tokenizer',
        vocab_size=vocab_size,
        model_type='bpe',
        character_coverage=1.0,
        user_defined_symbols=reserved_symbols,
        pad_id=0,
        unk_id=1,
        bos_id=2,
        eos_id=3,
        pad_piece='<pad>',
        unk_piece='<unk>',
        bos_piece='<s>',
        eos_piece='</s>',
        normalization_rule_name='nmt_nfkc',
        remove_extra_whitespaces=True,
        split_by_whitespace=True,
        split_by_number=False,
        byte_fallback=False,
        treat_whitespace_as_suffix=False,
        allow_whitespace_only_pieces=False,
        max_sentence_length=4192,
        num_threads=16,
    )
    spm.SentencePieceTrainer.Train(**trainer_options)

    print("Tokenizer training complete!")

In [None]:
def tokenize_files(output_dir="./"):
    """
    Encode mainPath + train.txt / val.txt with the trained tokenizer.

    Each non-blank line becomes a list of token ids. The lists are saved as an
    object array to <output_dir>/train_tokens.npy and <output_dir>/val_tokens.npy.
    """
    sp = spm.SentencePieceProcessor()
    sp.load("lyric_tokenizer.model")

    def tokenize_file(input_file, output_file):
        with open(input_file, "r", encoding="utf-8") as handle:
            stripped_lines = (raw.strip() for raw in handle)
            encoded_songs = [sp.encode(text, out_type=int) for text in stripped_lines if text]

        np.save(output_file, np.array(encoded_songs, dtype=object))
        print(f"Saved {output_file}, total sequences: {len(encoded_songs)}")

    for split_name in ("train", "val"):
        tokenize_file(mainPath + f"{split_name}.txt",
                      os.path.join(output_dir, f"{split_name}_tokens.npy"))


## Lyrics Dataset

In [None]:
class LyricsDataset(Dataset):
    """
    Create pairs of input-output sequences with genre labels for lyrics generation

    Each tokenized song should contain one genre token (<RAP>, <POP>, <ROCK>,
    <RB>, <COUNTRY>) near its start, as written by add_genre_token_to_lyrics.
    Everything after that token is lyric content. The content is cut into windows
    of seq_len + 1 tokens, one every `stride` tokens, giving x = window[:-1] and
    y = window[1:] (next-token targets) plus the genre's class index (0-4). A song
    contributes no samples if it has no genre token among its first few ids, or if
    its content is not longer than seq_len tokens.
    """
    # The genre token is not always the first id (e.g. a leading whitespace /
    # dummy-prefix piece may precede it), so search a few positions rather than
    # relying on a fixed index.
    GENRE_SEARCH_WINDOW = 3

    def __init__(self, token_list, sp, seq_len, stride=None):
        self.samples = []
        self.seq_len = seq_len
        self.stride = stride or seq_len
        
        self.sp_genre_to_idx = {
            sp.piece_to_id('<RAP>'): 0,
            sp.piece_to_id('<POP>'): 1,
            sp.piece_to_id('<ROCK>'): 2,
            sp.piece_to_id('<RB>'): 3,
            sp.piece_to_id('<COUNTRY>'): 4
        }

        for song in token_list:
            if len(song) < 3: 
                continue

            genre_pos = next(
                (pos for pos, tok in enumerate(song[:self.GENRE_SEARCH_WINDOW])
                 if tok in self.sp_genre_to_idx),
                None,
            )
            if genre_pos is None:
                continue
            genre_idx = self.sp_genre_to_idx[song[genre_pos]]

            # Keep only the lyrics after the genre token; the genre is conditioned
            # through genre_idx instead.
            song_content = song[genre_pos + 1:]
            L = len(song_content)
            
            for i in range(0, L - seq_len, self.stride):
                chunk = song_content[i:i + seq_len + 1]
                
                if len(chunk) == seq_len + 1:
                    self.samples.append((chunk[:-1], chunk[1:], genre_idx))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        x, y, g = self.samples[idx]
        return (
            torch.tensor(x, dtype=torch.long),
            torch.tensor(y, dtype=torch.long),
            torch.tensor(g, dtype=torch.long)
        )

## LSTM Model

In [None]:
class LSTM_LyricGenerator(nn.Module):
    """
    LSTM-based lyrics generation model with genre conditioning

    Each token embedding is concatenated with a learned genre embedding (the same
    genre vector at every time step) and layer-normalized. The result goes through
    a stacked LSTM, is layer-normalized again, and is projected to vocabulary logits.
    """
    def __init__(self, vocab_size, genre_size,
                 embed_dim=384, genre_embed_dim=64,
                 hidden_size=768, num_layers=3, dropout=0.2):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.genre_embedding = nn.Embedding(genre_size, genre_embed_dim)
        
        self.embed_norm = nn.LayerNorm(embed_dim + genre_embed_dim)
        
        self.lstm = nn.LSTM(
            embed_dim + genre_embed_dim,
            hidden_size,
            num_layers,
            # nn.LSTM applies dropout only between layers, so it is meaningless for 1 layer.
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        
        self.output_norm = nn.LayerNorm(hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)
        
        self._init_weights()
    
    def _init_weights(self):
        """
        Initialize model weights

        Matrices (except LayerNorm weights) get Xavier-uniform init, and all biases
        are zeroed. The padding row of the token embedding is then re-zeroed,
        because the Xavier init above overwrote the zero row that nn.Embedding
        reserves for padding_idx.
        """
        for name, param in self.named_parameters():
            if 'weight' in name and 'norm' not in name:
                if param.dim() >= 2:
                    nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

        pad_idx = self.embedding.padding_idx
        if pad_idx is not None:
            with torch.no_grad():
                self.embedding.weight[pad_idx].zero_()
    
    def forward(self, x, genre_id, hidden=None):
        """
        Forward pass of the model.

        Args:
            x: (batch, seq_len) token ids.
            genre_id: (batch,) genre class indices.
            hidden: optional LSTM (h, c) state to continue from.

        Returns:
            (logits of shape (batch, seq_len, vocab_size), new LSTM (h, c) state).
        """
        _, seq_len = x.shape
        
        word_embed = self.embedding(x)
        # expand() broadcasts the per-sequence genre vector over time without copying.
        genre_embed = self.genre_embedding(genre_id).unsqueeze(1).expand(-1, seq_len, -1)
        
        lstm_input = torch.cat([word_embed, genre_embed], dim=2)
        lstm_input = self.embed_norm(lstm_input)
        
        lstm_output, hidden = self.lstm(lstm_input, hidden)
        
        lstm_output = self.output_norm(lstm_output)
        logits = self.fc(lstm_output)
        
        return logits, hidden

## Training Model

Train for one epoch (mixed precision with gradient accumulation):

In [None]:
def compute_loss(model, x, y, genre, criterion, accumulation_steps):
    """Compute loss with mixed precision.

    Runs the forward pass under autocast and returns the flattened token-level loss
    divided by `accumulation_steps`. The division makes gradients summed over
    `accumulation_steps` micro-batches match one larger batch.

    torch.autocast replaces the deprecated torch.cuda.amp.autocast(). Autocast is
    enabled only for CUDA tensors, so CPU runs stay in full precision as before.
    """
    device_type = x.device.type
    with torch.autocast(device_type=device_type, enabled=(device_type == "cuda")):
        logits, _ = model(x, genre)
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            y.reshape(-1)
        )
    return loss / accumulation_steps


def update_model_weights(optimizer, scaler, model):
    """Apply one optimizer step from the accumulated (loss-scaled) gradients.

    Gradients are unscaled first so that clipping sees their true values. The
    global gradient norm is clipped to 5.0, the step goes through the scaler, the
    scale factor is updated, and gradients are cleared for the next accumulation
    window.
    """
    max_grad_norm = 5.0
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


def train_one_epoch(model, train_loader, criterion, optimizer, scaler, device, 
                    accumulation_steps=4):
    """Train model for one epoch.

    Uses gradient accumulation: the optimizer steps once every `accumulation_steps`
    batches, and also after the last batch, so the gradients of a trailing partial
    window are applied rather than silently dropped. Progress is printed every 500
    batches.

    Args:
        model: model in training mode after this call.
        train_loader: iterable of (x, y, genre) batches with a defined len().
        criterion: token-level loss.
        optimizer, scaler: optimizer and AMP GradScaler.
        device: device to move batches to.
        accumulation_steps: micro-batches per optimizer step.

    Returns:
        Mean (unscaled) per-batch training loss over the epoch.

    Raises:
        ValueError: if train_loader is empty.
    """
    model.train()
    total_loss = 0.0
    total_batches = len(train_loader)
    if total_batches == 0:
        raise ValueError("train_loader is empty")
    
    optimizer.zero_grad()
    
    for i, (x, y, genre) in enumerate(train_loader):
        x, y, genre = x.to(device), y.to(device), genre.to(device)
        
        loss = compute_loss(model, x, y, genre, criterion, accumulation_steps)
        scaler.scale(loss).backward()
        
        is_last_batch = (i + 1) == total_batches
        if (i + 1) % accumulation_steps == 0 or is_last_batch:
            update_model_weights(optimizer, scaler, model)
        
        # compute_loss divided by accumulation_steps; undo that for reporting.
        total_loss += loss.item() * accumulation_steps
        
        if (i + 1) % 500 == 0:
            avg_loss = total_loss / (i + 1)
            percent = (i + 1) / total_batches * 100
            print(f"  Batch {i+1}/{total_batches} ({percent:.1f}%) - Loss: {avg_loss:.4f}")
    
    return total_loss / total_batches

Evaluate on the validation set (average loss and perplexity):

In [None]:
def eval_one_epoch(model, val_loader, criterion, device):
    """
    Validation loop with perplexity calculation

    avg_loss is the mean per-token loss over the whole loader. Each batch's loss
    is weighted by its number of scored targets (those not equal to the criterion's
    `ignore_index`), so a short final batch does not count as much as a full one.
    This assumes `criterion` averages over scored targets (reduction='mean', as
    used by train_model).

    Returns:
        (avg_loss, perplexity) where perplexity = exp(min(avg_loss, 10)); the
        exponent is capped so the reported perplexity never exceeds e**10.

    Raises:
        ValueError: if the loader yields no scored targets.
    """
    model.eval()
    ignore_index = getattr(criterion, "ignore_index", None)
    total_loss = 0.0
    total_tokens = 0
    
    with torch.no_grad():
        for x, y, genre in val_loader:
            x, y, genre = x.to(device), y.to(device), genre.to(device)
            
            logits, _ = model(x, genre)
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                y.reshape(-1)
            )

            if ignore_index is None:
                n_tokens = y.numel()
            else:
                n_tokens = int((y != ignore_index).sum())
            if n_tokens == 0:
                continue  # a mean over zero targets is NaN; nothing to score
            total_loss += loss.item() * n_tokens
            total_tokens += n_tokens

    if total_tokens == 0:
        raise ValueError("val_loader produced no target tokens to evaluate")
    
    avg_loss = total_loss / total_tokens
    perplexity = math.exp(min(avg_loss, 10)) 
    
    return avg_loss, perplexity

Train Model:

In [None]:
def save_best_model(epoch, model, optimizer, val_loss, val_perplexity, save_path):
    """Write a checkpoint for the best model seen so far and announce it.

    The saved dict holds the epoch (0-based), the model and optimizer state dicts,
    and the validation loss and perplexity that made this checkpoint the best.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    checkpoint['val_loss'] = val_loss
    checkpoint['val_perplexity'] = val_perplexity
    torch.save(checkpoint, save_path)
    print("Saved best model")
    print("Saved best model")


def save_periodic_checkpoint(epoch, model, optimizer, save_path):
    """Write a plain training checkpoint: epoch plus model/optimizer state dicts."""
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(state, save_path)


def should_stop_early(patience_counter, patience):
    """True once the number of epochs without improvement reaches `patience`."""
    return patience <= patience_counter


def train_model(model, train_loader, val_loader, device, epochs=80, 
                lr=0.001, weight_decay=1e-5, patience=5):
    """Training loop with early stopping and checkpointing.

    Each epoch does the following:
      - trains with train_one_epoch and evaluates with eval_one_epoch;
      - steps a ReduceLROnPlateau scheduler on the validation loss (factor 0.5, patience 2);
      - saves mainPath + "best_model.pt" whenever the validation loss improves;
      - saves a periodic checkpoint every 10 epochs.
    Training stops early after `patience` consecutive epochs without improvement.

    Args:
        model: model to train, already on `device`.
        train_loader, val_loader: loaders yielding (x, y, genre) batches.
        device: "cuda", "cpu", or a torch.device.
        epochs: maximum number of epochs.
        lr, weight_decay: Adam hyper-parameters.
        patience: epochs without validation improvement before stopping (default 5).
    """
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # `verbose` is deprecated in recent PyTorch; the LR is printed every epoch below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    # Loss scaling is only needed for CUDA mixed precision; disabling it elsewhere
    # avoids the "CUDA not available" warning and makes it a pass-through.
    scaler = GradScaler(enabled=torch.device(device).type == "cuda")

    best_val_loss = float('inf')
    patience_counter = 0
    
    print("\n" + "="*70)
    print("Starting Training")
    print("="*70)
    
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        print("-" * 70)
        
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device)
        val_loss, val_perplexity = eval_one_epoch(model, val_loader, criterion, device)
        
        scheduler.step(val_loss)
        print(f"\n  Train Loss:      {train_loss:.4f}")
        print(f"  Val Loss:        {val_loss:.4f}")
        print(f"  Val Perplexity:  {val_perplexity:.2f}")
        print(f"  Learning Rate:   {optimizer.param_groups[0]['lr']:.6f}")
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            save_best_model(epoch, model, optimizer, val_loss, val_perplexity, 
                          mainPath + "best_model.pt")
        else:
            patience_counter += 1
            if should_stop_early(patience_counter, patience):
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                break
        
        if (epoch + 1) % 10 == 0:
            save_periodic_checkpoint(epoch, model, optimizer, 
                                    f"{mainPath}checkpoint_epoch_{epoch+1}.pt")
    
    print("\n" + "="*70)
    print("Training Complete!")
    print("="*70)

## Generation of Lyrics

In [None]:
def convert_v_tokens_to_newlines(text):
    """Replace each "<V>" line marker, and the whitespace around it, with a newline.

    The match is case-insensitive. The tokenizer's symbol is "<V>", but the
    lowercase "<v>" spelling is also accepted.
    """
    return re.sub(r'\s*<v>\s*', '\n', text, flags=re.IGNORECASE)


def remove_empty_lines(text):
    """Strip every line and drop the ones that end up empty."""
    kept_lines = []
    for raw_line in text.split('\n'):
        stripped = raw_line.strip()
        if stripped:
            kept_lines.append(stripped)
    return '\n'.join(kept_lines)


def clean_generated_text(text):
    """Post-process generated lyrics text.

    Turns decoded model output into line-broken lyrics:
      1. Remove genre tags such as "<RAP>" carried over from the prompt.
         Otherwise, bracket stripping would leave the bare word "RAP".
      2. Split on "<V>" line markers before any bracket characters are removed.
         Otherwise, "<V>" would be reduced to "V" and no line breaks would survive.
      3. Clean each line separately, so that whitespace normalization cannot merge
         lines back together.
      4. Drop empty lines.
    """
    text = re.sub(r'<(?:RAP|POP|ROCK|RB|COUNTRY)>', ' ', text, flags=re.IGNORECASE)
    # The tokenizer emits "<V>"; normalize its case for the marker-to-newline helper.
    text = convert_v_tokens_to_newlines(text.replace('<V>', '<v>'))

    cleaned_lines = []
    for line in text.split('\n'):
        line = remove_standalone_punctuation(line)
        line = normalize_whitespace(line)
        line = remove_special_characters(line)
        line = remove_duplicate_punctuation(line)
        line = remove_quotation_marks(line)
        cleaned_lines.append(normalize_whitespace(line))

    return remove_empty_lines('\n'.join(cleaned_lines)).strip()

To generate lyrics, we can pass in the first few words of a song as well as the genre:

In [None]:
def get_genre_id(genre_name):
    """Map a genre name (case-insensitive) to the model's genre index 0-4.

    Raises KeyError for names other than rap, pop, rock, rb, country.
    """
    ordered_genres = ("rap", "pop", "rock", "rb", "country")
    index_of = {name: position for position, name in enumerate(ordered_genres)}
    return index_of[genre_name.lower()]


def prepare_initial_tokens(genre_name, prompt, sp, device):
    """Encode "<GENRE> prompt" into a (1, n) long tensor on `device`.

    The prompt is lowercased and its whitespace collapsed, because the training
    lyrics were lowercased and whitespace-normalized by clean_lyrics. Otherwise a
    prompt such as "I am the" would be encoded into pieces the model never saw.
    A blank prompt leaves only the genre tag.
    """
    initial_text = f"<{genre_name.upper()}>"
    normalized_prompt = " ".join(prompt.lower().split()) if prompt else ""
    if normalized_prompt:
        initial_text += f" {normalized_prompt}"
    
    tokens = sp.encode(initial_text, out_type=int)
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
    return tokens


def apply_repetition_penalty(logits, generated_tokens, recent_tokens, penalty):
    """Discourage repeated tokens by lowering their logits in place.

    1. Every distinct token among the last 50 generated tokens is penalized by `penalty`.
    2. Once at least 20 tokens have been sampled, each distinct token in the last
       10 that occurs more than 3 times in the last 20 is penalized once more,
       by a factor of 2.

    A penalty shrinks positive logits (divide) and pushes negative ones further
    down (multiply), so it always lowers the token's probability. Dividing a
    negative logit would instead raise it.

    Returns:
        The same `logits` tensor, modified in place.
    """
    def _discourage(token, factor):
        if logits[token] > 0:
            logits[token] /= factor
        else:
            logits[token] *= factor

    for token in set(generated_tokens[-50:]):
        _discourage(token, penalty)
    
    if len(recent_tokens) >= 20:
        window = recent_tokens[-20:]
        # Use a set so a token appearing several times in the last 10 is penalized once.
        for token in set(recent_tokens[-10:]):
            if window.count(token) > 3:
                _discourage(token, 2.0)
    
    return logits


def apply_top_k_filtering(logits, top_k):
    """Set every logit below the k-th largest to -inf (in place); top_k <= 0 disables.

    `top_k` is clamped to the vocabulary size, so a k larger than the vocabulary
    keeps everything instead of making torch.topk raise.
    """
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_largest] = float('-inf')
    return logits


def apply_top_p_filtering(logits, top_p):
    """Nucleus filtering: keep the smallest set of top tokens whose probability mass exceeds top_p.

    All other logits are set to -inf in place. A top_p of 1.0 or more leaves the logits untouched.
    """
    if not (top_p < 1.0):
        return logits

    ranked_logits, ranked_ids = torch.sort(logits, descending=True)
    cumulative_mass = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)

    drop = cumulative_mass > top_p
    # Shift right by one so the token that crosses the threshold is still kept,
    # and always keep the single most likely token.
    drop[..., 1:] = drop[..., :-1].clone()
    drop[..., 0] = 0

    logits[ranked_ids[drop]] = float('-inf')
    return logits


def sample_next_token(logits):
    """Draw one token id from softmax(logits) and return it as a Python int."""
    distribution = F.softmax(logits, dim=-1)
    drawn = torch.multinomial(distribution, num_samples=1)
    return drawn.item()


def should_stop_generation(token, sp):
    """True when `token` is the tokenizer's end-of-sequence id or id 0 (the pad id)."""
    stop_ids = (sp.eos_id(), 0)
    return token in stop_ids


def generate_lyrics(model, sp, prompt="", genre_name="rap", max_len=250, 
                   temperature=0.7, top_k=40, top_p=0.85, repetition_penalty=1.2,
                   device='cuda', context_len=128):
    """Generate lyrics with top-k/top-p sampling and repetition penalty.

    At each step, the last `context_len` tokens are re-encoded from a fresh LSTM
    state. This matches training, where every window started with hidden=None.
    Feeding the whole window together with the previous step's hidden state
    would push tokens the LSTM had already consumed through it again.

    Args:
        model: module called as model(tokens, genre_ids) -> (logits, hidden).
        sp: loaded SentencePieceProcessor for prompt encoding and decoding.
        prompt: optional opening words.
        genre_name: rap, pop, rock, rb, or country (case-insensitive).
        max_len: maximum number of new tokens to sample.
        temperature: logits are divided by this before penalties and filtering; must be > 0.
        top_k, top_p: sampling filters.
        repetition_penalty: factor passed to apply_repetition_penalty.
        device: device the model lives on.
        context_len: number of trailing tokens fed to the model at each step.

    Returns:
        Post-processed lyrics text.

    Raises:
        ValueError: if temperature is not positive.
        KeyError: if genre_name is not a known genre.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    model.eval()
    
    genre_id = get_genre_id(genre_name)
    tokens = prepare_initial_tokens(genre_name, prompt, sp, device)
    
    genre_tensor = torch.tensor([genre_id], dtype=torch.long, device=device)
    generated = tokens[0].tolist()
    recent_tokens = []
    
    with torch.no_grad():
        for _ in range(max_len):
            context = tokens[:, -context_len:]
            logits, _ = model(context, genre_tensor)
            next_token_logits = logits[0, -1, :] / temperature
            
            next_token_logits = apply_repetition_penalty(
                next_token_logits, generated, recent_tokens, repetition_penalty
            )
            next_token_logits = apply_top_k_filtering(next_token_logits, top_k)
            next_token_logits = apply_top_p_filtering(next_token_logits, top_p)
            
            next_token = sample_next_token(next_token_logits)
            
            if should_stop_generation(next_token, sp):
                break
            
            generated.append(next_token)
            recent_tokens.append(next_token)
            next_column = torch.tensor([[next_token]], dtype=torch.long, device=device)
            tokens = torch.cat([tokens, next_column], dim=1)
    
    raw_text = sp.decode(generated)
    return clean_generated_text(raw_text)

## Main Execution Pipeline

In [None]:
def main():
    """
    Run the whole pipeline end to end.

    The steps are: clean, filter, and balance the raw CSV; split it into train/val
    text files; train and apply the SentencePiece tokenizer; build windowed datasets
    and loaders; train the LSTM; then reload the best checkpoint and print one
    sample per genre.
    """
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    use_pinned_memory = DEVICE == "cuda"

    # ---- data preparation ----
    print("Preprocessing Data")
    cleaned_csv = mainPath + 'english_cleaned_songs.csv'
    balanced_csv = mainPath + "english_cleaned_reduced.csv"
    clean_and_filter_lyrics_dataset(
        input_csv=mainPath + 'song_lyrics.csv',
        output_csv=cleaned_csv
    )
    balance_genres(input_csv=cleaned_csv, output_csv=balanced_csv, max_per_genre=85000)
    create_train_val_splits(balanced_csv)

    # ---- tokenizer ----
    print("Training Tokenizer")
    train_tokenizer(vocab_size=16000)
    tokenize_files(output_dir=mainPath)

    print("Loading Data")
    sp = spm.SentencePieceProcessor()
    sp.load("lyric_tokenizer.model")
    train_tokens, val_tokens = (
        np.load(mainPath + f"{split}_tokens.npy", allow_pickle=True)
        for split in ("train", "val")
    )
    print(f"Loaded {len(train_tokens)} training songs")
    print(f"Loaded {len(val_tokens)} validation songs")
    print(f"Vocabulary size: {sp.get_piece_size()}")

    # ---- datasets & loaders ----
    print("Creating Datasets")
    seq_len = 128
    train_dataset = LyricsDataset(train_tokens, sp, seq_len=seq_len, stride=64)
    val_dataset = LyricsDataset(val_tokens, sp, seq_len=seq_len, stride=seq_len)

    shared_loader_args = dict(
        num_workers=8,
        persistent_workers=True,
        prefetch_factor=4,
        pin_memory=use_pinned_memory,
    )
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, **shared_loader_args)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, **shared_loader_args)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    # ---- model ----
    print("Initializing Model")
    model_config = dict(
        vocab_size=sp.get_piece_size(),
        genre_size=5,
        embed_dim=384,
        genre_embed_dim=64,
        hidden_size=768,
        num_layers=3,
        dropout=0.2,
    )
    model = LSTM_LyricGenerator(**model_config).to(DEVICE)

    parameters = list(model.parameters())
    total_params = sum(p.numel() for p in parameters)
    trainable_params = sum(p.numel() for p in parameters if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # ---- training ----
    print("Training Model")
    train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=DEVICE,
        epochs=70,
        lr=0.001,
        weight_decay=1e-5
    )

    # ---- sampling ----
    print("Generating Sample Lyrics")
    checkpoint = torch.load(mainPath + "best_model.pt", map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
    print(f"Best validation loss: {checkpoint['val_loss']:.4f}")
    print(f"Best validation perplexity: {checkpoint['val_perplexity']:.2f}")

    banner = '=' * 70
    for genre in ("rap", "pop", "rock", "rb", "country"):
        print(f"\n{banner}")
        print(f"Generated {genre.upper()} Lyrics:")
        print(banner)

        lyrics = generate_lyrics(
            model=model,
            sp=sp,
            prompt="",
            genre_name=genre,
            max_len=200,
            temperature=0.7,
            top_k=40,
            top_p=0.85,
            repetition_penalty=1.2,
            device=DEVICE
        )
        print(lyrics)
        print()


# Inside Jupyter, __name__ is "__main__", so running this cell executes the full pipeline.
if __name__ == "__main__":
    main()

## Usage Examples

Generate lyrics after training:

In [None]:
# Pick the device once and use it for both the model and generation; otherwise
# generate_lyrics (which defaults to device='cuda') and the model could disagree.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

sp = spm.SentencePieceProcessor()
sp.load("lyric_tokenizer.model")

model = LSTM_LyricGenerator(vocab_size=sp.get_piece_size(), genre_size=5).to(DEVICE)
# train_model saves the best checkpoint under mainPath. map_location lets a
# GPU-trained checkpoint load on a CPU-only machine.
checkpoint = torch.load(mainPath + "best_model.pt", map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

lyrics = generate_lyrics(
    model=model,
    sp=sp,
    prompt="I am the",
    genre_name="rap",
    max_len=300,
    device=DEVICE,
)
print(lyrics)