In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import numpy as np
from torch.optim import AdamW
from collections import Counter, defaultdict
import re
from tqdm import tqdm
import random

# Enhanced Tokenizer with Markov Chain for Better Word Prediction
class EnhancedTokenizer:
    def __init__(self, vocab_size=10000):
        self.vocab_size = vocab_size
        self.word_to_idx = {}
        self.idx_to_word = {}
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.markov_chains = defaultdict(lambda: defaultdict(int))

    def fit(self, texts):
        words = []
        for text in tqdm(texts, desc="Building vocabulary"):
            words.extend(re.findall(r'\b\w+\b', text.lower()))

        word_counts = Counter(words)
        common_words = ['<pad>', '<unk>'] + [word for word, _ in word_counts.most_common(self.vocab_size - 2)]
        self.word_to_idx = {word: idx for idx, word in enumerate(common_words)}
        self.idx_to_word = {idx: word for word, idx in self.word_to_idx.items()}

        # Build Markov chain for transition probabilities
        for text in texts:
            word_list = re.findall(r'\b\w+\b', text.lower())
            for i in range(len(word_list) - 1):
                self.markov_chains[word_list[i]][word_list[i + 1]] += 1

        # Normalize transition probabilities
        for word in self.markov_chains:
            total = sum(self.markov_chains[word].values())
            if total > 0:
                for next_word in self.markov_chains[word]:
                    self.markov_chains[word][next_word] /= total

    def encode(self, text, max_length=512):
        words = re.findall(r'\b\w+\b', text.lower())
        # Avoid unknown tokens by filtering out words not in vocab
        ids = [self.word_to_idx.get(word, self.pad_token_id) for word in words if word in self.word_to_idx]
        if not ids:
            ids = [self.pad_token_id]

        if len(ids) < max_length:
            ids = ids + [self.pad_token_id] * (max_length - len(ids))
        else:
            ids = ids[:max_length]

        return torch.tensor(ids)

    def decode(self, ids):
        words = [self.idx_to_word.get(id.item(), '') for id in ids if id != self.pad_token_id and id in self.idx_to_word]
        return ' '.join(words)

    def suggest_next_word(self, current_word):
        if current_word in self.markov_chains and self.markov_chains[current_word]:
            next_words = list(self.markov_chains[current_word].keys())
            probs = list(self.markov_chains[current_word].values())
            return random.choices(next_words, weights=probs, k=1)[0]
        return random.choice(list(self.word_to_idx.keys())[2:])  # Exclude pad and unk

# Transformer Components
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.size()
        qkv = self.qkv(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.reshape(batch_size, seq_len, d_model)
        return self.proj(context)

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = x + self.dropout(self.attn(self.norm1(x), mask))
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x

# Main Model
class LyricsTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=128, num_heads=4, num_layers=4, d_ff=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.zeros(1, 512, d_model))
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        seq_len = x.size(1)
        x = self.embedding(x) + self.pos_encoding[:, :seq_len]
        x = self.dropout(x)
        for block in self.transformer_blocks:
            x = block(x, mask)
        return self.fc(x)

# Dataset
class LyricsDataset(Dataset):
    def __init__(self, texts, input_ids, target_ids):
        self.texts = texts
        self.input_ids = input_ids
        self.target_ids = target_ids

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'target_ids': self.target_ids[idx]
        }

# Training function with gradient accumulation
def train_model(model, train_loader, val_loader, device, num_epochs=3, gradient_accumulation_steps=4):
    optimizer = AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.98))
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        optimizer.zero_grad()

        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            input_ids = batch['input_ids'].to(device)
            target_ids = batch['target_ids'].to(device)
            outputs = model(input_ids)
            loss = criterion(outputs.view(-1, outputs.size(-1)), target_ids.view(-1))
            loss = loss / gradient_accumulation_steps
            loss.backward()

            if (i + 1) % gradient_accumulation_steps == 0 or (i + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item() * gradient_accumulation_steps

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                input_ids = batch['input_ids'].to(device)
                target_ids = batch['target_ids'].to(device)
                outputs = model(input_ids)
                loss = criterion(outputs.view(-1, outputs.size(-1)), target_ids.view(-1))
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        avg_train_loss = total_loss / len(train_loader)
        scheduler.step(avg_val_loss)
        print(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, LR: {optimizer.param_groups[0]["lr"]:.6f}')

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss,
            }, 'best_lyrics_model.pth')
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")

# Generate text with repetition control and Markov guidance
def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7, device='cuda'):
    model.eval()
    input_ids = tokenizer.encode(prompt, max_length=512).unsqueeze(0).to(device)
    generated = []
    seen_words = set()
    max_seq_length = 512
    min_tokens = 10  # Ensure at least 10 tokens

    with torch.no_grad():
        for _ in range(max_length):
            if input_ids.size(1) > max_seq_length:
                input_ids = input_ids[:, -max_seq_length:]
            outputs = model(input_ids)
            next_token_logits = outputs[0, -1, :] / temperature
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, 1).to(device)

            if len(generated) >= min_tokens and next_token.item() == tokenizer.pad_token_id:
                break

            word = tokenizer.idx_to_word.get(next_token.item(), '<unk>')
            if word not in seen_words or len(generated) < min_tokens:  # Allow repeats only if below min_tokens
                generated.append(next_token.item())
                seen_words.add(word)
                next_word_suggestion = tokenizer.suggest_next_word(word)
                if next_word_suggestion and next_word_suggestion not in seen_words:
                    next_token = torch.tensor([tokenizer.word_to_idx.get(next_word_suggestion, tokenizer.pad_token_id)]).to(device)
            else:
                # Try to find an unseen word using Markov suggestion
                next_word = tokenizer.suggest_next_word(tokenizer.idx_to_word.get(generated[-1], ''))
                while next_word in seen_words and len(seen_words) < len(tokenizer.word_to_idx) - 2:
                    next_word = tokenizer.suggest_next_word(next_word)
                next_token = torch.tensor([tokenizer.word_to_idx.get(next_word, tokenizer.pad_token_id)]).to(device)

            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

    result = prompt + ' ' + tokenizer.decode(torch.tensor(generated))
    return result

# Data processing
def process_text_dataset(dataset, max_samples=500000):
    print(f"Processing up to {max_samples} samples from dataset...")
    texts = []
    counter = 0

    if hasattr(dataset, 'keys'):
        splits = list(dataset.keys())
    else:
        splits = ['dataset']
        dataset = {'dataset': dataset}

    for split in splits:
        split_data = dataset[split]
        if len(split_data) > 0:
            sample = split_data[0]
            text_fields = [key for key, value in sample.items() if isinstance(value, str) and len(value.split()) > 5]
            text_field = next((field for field in ['text', 'content', 'lyrics', 'sentence', 'article'] if field in text_fields), text_fields[0] if text_fields else None)

            if text_field:
                print(f"Using '{text_field}' as the text field")
                for item in tqdm(split_data, desc=f"Processing {split}"):
                    if text_field in item and item[text_field] and isinstance(item[text_field], str) and len(item[text_field].strip().split()) > 5:
                        texts.append(item[text_field])
                        counter += 1
                        if counter >= max_samples:
                            break
            else:
                print(f"Could not identify a text field in the dataset. Sample keys: {list(sample.keys())}")

        if counter >= max_samples:
            break

    print(f"Processed {len(texts)} texts")
    return texts

# Main function
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    if device.type == 'cpu':
        print("Warning: Running on CPU. Training will be slow.")

    max_samples = 500000
    dataset_options = [
        "bookcorpus",
        "wikitext",
        "Abirate/english_quotes",
        "imdb",
        "ag_news",
        "glue/sst2"
    ]

    texts = []
    for dataset_name in dataset_options:
        try:
            print(f"Attempting to load dataset: {dataset_name}")
            dataset = load_dataset(dataset_name, split="train")
            texts = process_text_dataset(dataset, max_samples=max_samples)
            if texts:
                print(f"Successfully loaded {len(texts)} samples from {dataset_name}")
                break
        except Exception as e:
            print(f"Failed to load {dataset_name}: {e}")
            continue

    if not texts:
        print("All dataset attempts failed. Using wikitext-103-v1 as a fallback.")
        try:
            dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
            texts = [item["text"] for item in dataset if item["text"] and len(item["text"].strip()) > 50]
            if len(texts) > max_samples:
                texts = texts[:max_samples]
        except Exception as e:
            print(f"Fallback failed: {e}")
            return

    print("Initializing tokenizer...")
    tokenizer = EnhancedTokenizer(vocab_size=10000)
    tokenizer.fit(texts)

    print("Preparing datasets...")
    max_length = 512
    all_input_ids = []
    all_target_ids = []

    for text in tqdm(texts, desc="Encoding texts"):
        input_id = tokenizer.encode(text, max_length)
        target_id = input_id.clone()
        target_id[:-1] = input_id[1:]
        target_id[-1] = input_id[0]
        all_input_ids.append(input_id)
        all_target_ids.append(target_id)

    indices = list(range(len(texts)))
    train_indices, val_indices = train_test_split(indices, test_size=0.1, random_state=42)

    train_dataset = LyricsDataset(
        [texts[i] for i in train_indices],
        [all_input_ids[i] for i in train_indices],
        [all_target_ids[i] for i in train_indices]
    )
    val_dataset = LyricsDataset(
        [texts[i] for i in val_indices],
        [all_input_ids[i] for i in val_indices],
        [all_target_ids[i] for i in val_indices]
    )

    batch_size = 16 if device.type == 'cuda' else 4
    if len(train_dataset) > 100000:
        batch_size = min(batch_size, 8)
    num_workers = 2 if device.type == 'cuda' else 0

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)

    print(f"Training with batch size: {batch_size}")
    print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

    model = LyricsTransformer(
        vocab_size=10000,
        d_model=128,
        num_heads=4,
        num_layers=4,
        d_ff=512
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model size: {total_params/1e6:.2f}M parameters")

    print("Starting training...")
    train_model(
        model,
        train_loader,
        val_loader,
        device,
        num_epochs=3,
        gradient_accumulation_steps=4
    )

    torch.save({
        'model_state_dict': model.state_dict(),
        'tokenizer': tokenizer
    }, 'final_lyrics_model.pth')
    print("Training completed and model saved!")

    print("\nGenerating sample text...")
    prompt = "Write a song about love and romance of my life:"
    generated_text = generate_text(model, tokenizer, prompt, device=device)
    print(generated_text)

if __name__ == "__main__":
    main()

Using device: cuda
Attempting to load dataset: bookcorpus


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/18.5k [00:00<?, ?B/s]

bookcorpus.py:   0%|          | 0.00/3.25k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.18G [00:00<?, ?B/s]

Generating train split:   0%|          | 0/74004228 [00:00<?, ? examples/s]

Processing up to 500000 samples from dataset...
Using 'text' as the text field


Processing dataset:   1%|          | 571703/74004228 [00:11<23:56, 51123.98it/s]


Processed 500000 texts
Successfully loaded 500000 samples from bookcorpus
Initializing tokenizer...


Building vocabulary: 100%|██████████| 500000/500000 [00:02<00:00, 194326.85it/s]


Preparing datasets...


Encoding texts: 100%|██████████| 500000/500000 [01:12<00:00, 6870.50it/s]


Training with batch size: 8
Training samples: 450000, Validation samples: 50000
Model size: 3.43M parameters
Starting training...


Epoch 1/3: 100%|██████████| 56250/56250 [35:54<00:00, 26.10it/s]
Validation: 100%|██████████| 6250/6250 [01:26<00:00, 72.52it/s]


Epoch 1/3 - Train Loss: 3.8106, Val Loss: 1.3019, LR: 0.000100
Saved best model with validation loss: 1.3019


Epoch 2/3: 100%|██████████| 56250/56250 [36:06<00:00, 25.97it/s]
Validation: 100%|██████████| 6250/6250 [01:26<00:00, 72.52it/s]


Epoch 2/3 - Train Loss: 0.8826, Val Loss: 0.4150, LR: 0.000100
Saved best model with validation loss: 0.4150


Epoch 3/3: 100%|██████████| 56250/56250 [36:04<00:00, 25.99it/s]
Validation: 100%|██████████| 6250/6250 [01:27<00:00, 71.80it/s]

Epoch 3/3 - Train Loss: 0.5198, Val Loss: 0.2891, LR: 0.000100
Saved best model with validation loss: 0.2891





AttributeError: Can't pickle local object 'EnhancedTokenizer.__init__.<locals>.<lambda>'

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import numpy as np
from torch.optim import AdamW
from collections import Counter, defaultdict
import re
from tqdm import tqdm
import uuid

# Common words for lyrics (curated to reduce <unk> and improve lyrical quality)
COMMON_LYRICS_WORDS = [
    '<pad>', '<unk>', 'love', 'heart', 'dream', 'night', 'day', 'life', 'time', 'way',
    'feel', 'know', 'go', 'see', 'world', 'soul', 'sky', 'star', 'moon', 'sun',
    'baby', 'darling', 'forever', 'always', 'never', 'together', 'apart', 'home', 'road', 'fire',
    'dance', 'sing', 'song', 'melody', 'rhyme', 'beat', 'rhythm', 'free', 'run', 'fly',
    'hold', 'touch', 'kiss', 'smile', 'cry', 'tears', 'pain', 'joy', 'hope', 'fear',
    'light', 'dark', 'shadow', 'shine', 'burn', 'break', 'fall', 'rise', 'stay', 'leave',
    'come', 'gone', 'back', 'memory', 'dreams', 'eyes', 'hands', 'voice', 'mind', 'spirit'
]

# Custom Tokenizer with Lyrics-Specific Vocabulary
class LyricsTokenizer:
    def __init__(self, vocab_size=10000):
        self.vocab_size = vocab_size
        self.word_to_idx = {word: idx for idx, word in enumerate(COMMON_LYRICS_WORDS)}
        self.idx_to_word = {idx: word for word, idx in self.word_to_idx.items()}
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.vocab_size = len(self.word_to_idx)

    def fit(self, texts):
        words = []
        for text in tqdm(texts, desc="Building vocabulary"):
            words.extend(re.findall(r'\b\w+\b', text.lower()))

        word_counts = Counter(words)
        additional_words = [word for word, _ in word_counts.most_common(self.vocab_size - len(COMMON_LYRICS_WORDS))]
        for word in additional_words:
            if word not in self.word_to_idx and len(self.word_to_idx) < self.vocab_size:
                self.word_to_idx[word] = len(self.word_to_idx)
                self.idx_to_word[len(self.idx_to_word)] = word
        self.vocab_size = len(self.word_to_idx)

    def encode(self, text, max_length=256):
        words = re.findall(r'\b\w+\b', text.lower())
        ids = [self.word_to_idx.get(word, self.unk_token_id) for word in words]
        if len(ids) < max_length:
            ids = ids + [self.pad_token_id] * (max_length - len(ids))
        else:
            ids = ids[:max_length]
        return torch.tensor(ids)

    def decode(self, ids):
        words = [self.idx_to_word.get(id.item(), '<unk>') for id in ids if id != self.pad_token_id]
        return ' '.join(words)

# Markov Chain for Coherence
class MarkovChain:
    def __init__(self, order=2):
        self.order = order
        self.transitions = defaultdict(list)

    def fit(self, texts, tokenizer):
        for text in tqdm(texts, desc="Building Markov chain"):
            ids = tokenizer.encode(text).tolist()
            for i in range(len(ids) - self.order):
                state = tuple(ids[i:i + self.order])
                next_token = ids[i + self.order]
                self.transitions[state].append(next_token)

    def get_next_token(self, state, vocab_size):
        state = tuple(state[-self.order:])
        if state in self.transitions and self.transitions[state]:
            return np.random.choice(self.transitions[state])
        return np.random.randint(2, vocab_size)  # Avoid <pad> and <unk>

# Transformer Components
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.size()
        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)
        context = context.permute(0, 2, 1, 3).reshape(batch_size, seq_len, d_model)
        return self.proj(context)

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = x + self.dropout(self.attn(self.norm1(x), mask))
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x

# Simplified Transformer Model
class LyricsTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, num_heads=4, num_layers=2, d_ff=256, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.zeros(1, 256, d_model))
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        seq_len = x.size(1)
        x = self.embedding(x) + self.pos_encoding[:, :seq_len]
        x = self.dropout(x)
        for block in self.transformer_blocks:
            x = block(x, mask)
        return self.fc(x)

# Dataset
class LyricsDataset(Dataset):
    def __init__(self, texts, input_ids, target_ids):
        self.texts = texts
        self.input_ids = input_ids
        self.target_ids = target_ids

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'target_ids': self.target_ids[idx]
        }

# Training Function
def train_model(model, train_loader, val_loader, device, markov_chain, tokenizer, num_epochs=3, gradient_accumulation_steps=4):
    optimizer = AdamW(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            input_ids = batch['input_ids'].to(device)
            target_ids = batch['target_ids'].to(device)
            outputs = model(input_ids)
            loss = criterion(outputs.view(-1, outputs.size(-1)), target_ids.view(-1))
            loss = loss / gradient_accumulation_steps
            loss.backward()
            if (i + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                optimizer.zero_grad()
            total_loss += loss.item() * gradient_accumulation_steps

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                input_ids = batch['input_ids'].to(device)
                target_ids = batch['target_ids'].to(device)
                outputs = model(input_ids)
                val_loss += criterion(outputs.view(-1, outputs.size(-1)), target_ids.view(-1)).item()

        avg_val_loss = val_loss / len(val_loader)
        avg_train_loss = total_loss / len(train_loader)
        scheduler.step()
        print(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss,
            }, 'best_lyrics_model.pth')

# Hybrid Generation Function
def generate_text(model, tokenizer, markov_chain, prompt, max_length=100, temperature=0.8, device='cuda'):
    model.eval()
    input_ids = tokenizer.encode(prompt, max_length=256).unsqueeze(0).to(device)
    generated = input_ids[0].tolist()
    max_seq_length = 256

    with torch.no_grad():
        for _ in range(max_length):
            if len(generated) > max_seq_length:
                input_ids = torch.tensor(generated[-max_seq_length:]).unsqueeze(0).to(device)
            else:
                input_ids = torch.tensor(generated).unsqueeze(0).to(device)

            outputs = model(input_ids)
            logits = outputs[0, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)

            # Use Markov chain to guide sampling
            markov_suggestion = markov_chain.get_next_token(generated, tokenizer.vocab_size)
            markov_boost = torch.zeros_like(probs)
            markov_boost[markov_suggestion] += 0.3  # Boost probability of Markov suggestion
            probs = F.softmax(probs + markov_boost, dim=-1)

            next_token = torch.multinomial(probs, 1).item()
            if next_token == tokenizer.pad_token_id:
                break
            generated.append(next_token)

    return prompt + ' ' + tokenizer.decode(torch.tensor(generated[len(tokenizer.encode(prompt)):]))

# Data Processing
def process_text_dataset(dataset, max_samples=100000):
    texts = []
    counter = 0
    splits = dataset.keys() if hasattr(dataset, 'keys') else ['dataset']
    dataset = dataset if hasattr(dataset, 'keys') else {'dataset': dataset}

    for split in splits:
        split_data = dataset[split]
        sample = split_data[0]
        text_fields = [key for key, value in sample.items() if isinstance(value, str) and len(value.split()) > 5]
        text_field = next((field for field in ['text', 'content', 'lyrics', 'sentence', 'article'] if field in text_fields), text_fields[0] if text_fields else None)

        if text_field:
            for item in tqdm(split_data, desc=f"Processing {split}"):
                if text_field in item and item[text_field] and isinstance(item[text_field], str) and len(item[text_field].strip().split()) >= 5:
                    texts.append(item[text_field])
                    counter += 1
                    if counter >= max_samples:
                        break
        if counter >= max_samples:
            break

    return texts

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    max_samples = 100000
    dataset_options = ["rotten_tomatoes", "bookcorpus", "wikitext", "Abirate/english_quotes", "imdb", "ag_news", "glue/sst2"]

    texts = []
    for dataset_name in dataset_options:
        try:
            print(f"Attempting to load dataset: {dataset_name}")
            dataset = load_dataset(dataset_name, split="train")
            texts = process_text_dataset(dataset, max_samples=max_samples)
            if texts:
                print(f"Successfully loaded {len(texts)} samples from {dataset_name}")
                break
        except Exception as e:
            print(f"Failed to load {dataset_name}: {e}")
            continue

    if not texts:
        print("All dataset attempts failed.")
        return

    tokenizer = LyricsTokenizer(vocab_size=10000)
    tokenizer.fit(texts)

    markov_chain = MarkovChain(order=2)
    markov_chain.fit(texts, tokenizer)

    max_length = 256
    all_input_ids = []
    all_target_ids = []
    for text in tqdm(texts, desc="Encoding texts"):
        input_id = tokenizer.encode(text, max_length)
        target_id = input_id.clone()
        target_id[:-1] = input_id[1:]
        target_id[-1] = input_id[0]
        all_input_ids.append(input_id)
        all_target_ids.append(target_id)

    train_indices, val_indices = train_test_split(list(range(len(texts))), test_size=0.1, random_state=42)
    train_dataset = LyricsDataset(
        [texts[i] for i in train_indices],
        [all_input_ids[i] for i in train_indices],
        [all_target_ids[i] for i in train_indices]
    )
    val_dataset = LyricsDataset(
        [texts[i] for i in val_indices],
        [all_input_ids[i] for i in val_indices],
        [all_target_ids[i] for i in val_indices]
    )

    batch_size = 8 if device.type == 'cuda' else 4
    num_workers = 2 if device.type == 'cuda' else 0
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)

    print(f"Training with batch size: {batch_size}")
    print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

    model = LyricsTransformer(
        vocab_size=tokenizer.vocab_size,
        d_model=64,
        num_heads=4,
        num_layers=2,
        d_ff=256
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model size: {total_params/1e6:.2f}M parameters")

    train_model(
        model,
        train_loader,
        val_loader,
        device,
        markov_chain,
        tokenizer,
        num_epochs=3,
        gradient_accumulation_steps=4
    )

    torch.save({
        'model_state_dict': model.state_dict(),
        'tokenizer': tokenizer,
        'markov_chain': markov_chain
    }, 'final_lyrics_model.pth')
    print("Training completed and model saved!")

    print("\nGenerating sample text...")
    prompt = "Write a song about love and romance of my life:"
    generated_text = generate_text(model, tokenizer, markov_chain, prompt, device=device)
    print(generated_text)

if __name__ == "__main__":
    main()

Using device: cpu
Attempting to load dataset: rotten_tomatoes


README.md:   0%|          | 0.00/7.46k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


train.parquet:   0%|          | 0.00/699k [00:00<?, ?B/s]

validation.parquet:   0%|          | 0.00/90.0k [00:00<?, ?B/s]

test.parquet:   0%|          | 0.00/92.2k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8530 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1066 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1066 [00:00<?, ? examples/s]

Processing dataset: 100%|███████████████████████████████████████████████████████| 8530/8530 [00:00<00:00, 10972.76it/s]


Successfully loaded 8388 samples from rotten_tomatoes


Building vocabulary: 100%|██████████████████████████████████████████████████████| 8388/8388 [00:00<00:00, 45570.36it/s]
Building Markov chain: 100%|█████████████████████████████████████████████████████| 8388/8388 [00:02<00:00, 3061.09it/s]
Encoding texts: 100%|████████████████████████████████████████████████████████████| 8388/8388 [00:02<00:00, 3284.73it/s]


Training with batch size: 4
Training samples: 7549, Validation samples: 839
Model size: 0.13M parameters


Epoch 1/3: 100%|███████████████████████████████████████████████████████████████████| 1888/1888 [00:58<00:00, 32.45it/s]
Validation: 100%|████████████████████████████████████████████████████████████████████| 210/210 [00:02<00:00, 90.07it/s]


Epoch 1/3 - Train Loss: 0.3556, Val Loss: 0.1595


Epoch 2/3: 100%|███████████████████████████████████████████████████████████████████| 1888/1888 [00:55<00:00, 34.22it/s]
Validation: 100%|████████████████████████████████████████████████████████████████████| 210/210 [00:02<00:00, 93.00it/s]


Epoch 2/3 - Train Loss: 0.1532, Val Loss: 0.1368


Epoch 3/3: 100%|███████████████████████████████████████████████████████████████████| 1888/1888 [00:54<00:00, 34.35it/s]
Validation: 100%|████████████████████████████████████████████████████████████████████| 210/210 [00:02<00:00, 99.18it/s]


Epoch 3/3 - Train Loss: 0.1399, Val Loss: 0.1338
Training completed and model saved!

Generating sample text...
Write a song about love and romance of my life: always rise shadow <unk> <unk> pain hold smile <unk> come leave never touch fire sky voice light <unk> time night sun soul dreams joy smile joy free <unk> song run go leave mind joy sing always sing hands rhyme life song night time fly hold break sing sing shine never never eyes come rhythm fall <unk> rhythm day hold heart road break night forever memory star smile darling together always light moon rise sky hands together apart dreams fall hold <unk> baby fire smile way cry hands <unk> dark road hands back eyes <unk> heart fall fall know <unk> break


In [11]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
from collections import Counter, defaultdict
from gtts import gTTS
from pydub import AudioSegment
from tqdm import tqdm

# Define the themes and the prompt for each theme
themes = {
    "love": "Write a beautiful song about love and relationships.",
    "nature": "Write a song about the beauty of nature and the environment.",
    "friendship": "Write a song about the power and importance of friendship.",
    "hope": "Write a song about hope and positivity for the future.",
    "freedom": "Write a song about freedom and independence."
}

# Lyrics-specific vocabulary
COMMON_LYRICS_WORDS = [
    '<pad>', '<unk>', 'love', 'heart', 'dream', 'night', 'day', 'life', 'time', 'way',
    'feel', 'know', 'go', 'see', 'world', 'soul', 'sky', 'star', 'moon', 'sun',
    'baby', 'darling', 'forever', 'always', 'never', 'together', 'apart', 'home', 'road', 'fire',
    'dance', 'sing', 'song', 'melody', 'rhyme', 'beat', 'rhythm', 'free', 'run', 'fly',
    'hold', 'touch', 'kiss', 'smile', 'cry', 'tears', 'pain', 'joy', 'hope', 'fear',
    'light', 'dark', 'shadow', 'shine', 'burn', 'break', 'fall', 'rise', 'stay', 'leave',
    'come', 'gone', 'back', 'memory', 'dreams', 'eyes', 'hands', 'voice', 'mind', 'spirit'
]

# Custom Tokenizer (matching the training code)
class LyricsTokenizer:
    def __init__(self, vocab_size=10000):
        self.vocab_size = vocab_size
        self.word_to_idx = {word: idx for idx, word in enumerate(COMMON_LYRICS_WORDS)}
        self.idx_to_word = {idx: word for word, idx in self.word_to_idx.items()}
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.vocab_size = len(self.word_to_idx)

    def fit(self, texts):
        words = []
        for text in tqdm(texts, desc="Building vocabulary"):
            words.extend(re.findall(r'\b\w+\b', text.lower()))

        word_counts = Counter(words)
        additional_words = [word for word, _ in word_counts.most_common(self.vocab_size - len(COMMON_LYRICS_WORDS))]
        for word in additional_words:
            if word not in self.word_to_idx and len(self.word_to_idx) < self.vocab_size:
                self.word_to_idx[word] = len(self.word_to_idx)
                self.idx_to_word[len(self.idx_to_word)] = word
        self.vocab_size = len(self.word_to_idx)

    def encode(self, text, max_length=256):
        words = re.findall(r'\b\w+\b', text.lower())
        ids = [self.word_to_idx.get(word, self.unk_token_id) for word in words]
        if len(ids) < max_length:
            ids = ids + [self.pad_token_id] * (max_length - len(ids))
        else:
            ids = ids[:max_length]
        return torch.tensor(ids)

    def decode(self, ids):
        words = [self.idx_to_word.get(id.item(), '<unk>') for id in ids if id != self.pad_token_id]
        return ' '.join(words)

# Markov Chain (matching the training code)
class MarkovChain:
    def __init__(self, order=2):
        self.order = order
        self.transitions = defaultdict(list)

    def fit(self, texts, tokenizer):
        for text in tqdm(texts, desc="Building Markov chain"):
            ids = tokenizer.encode(text).tolist()
            for i in range(len(ids) - self.order):
                state = tuple(ids[i:i + self.order])
                next_token = ids[i + self.order]
                self.transitions[state].append(next_token)

    def get_next_token(self, state, vocab_size):
        state = tuple(state[-self.order:])
        if state in self.transitions and self.transitions[state]:
            return np.random.choice(self.transitions[state])
        return np.random.randint(2, vocab_size)

# Model components
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.size()
        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)
        context = context.permute(0, 2, 1, 3).reshape(batch_size, seq_len, d_model)
        return self.proj(context)

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = x + self.dropout(self.attn(self.norm1(x), mask))
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x

# Main model (matching the training code)
class LyricsTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, num_heads=4, num_layers=2, d_ff=256, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.zeros(1, 256, d_model))
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        seq_len = x.size(1)
        x = self.embedding(x) + self.pos_encoding[:, :seq_len]
        x = self.dropout(x)
        for block in self.transformer_blocks:
            x = block(x, mask)
        return self.fc(x)

# Generation function
def generate_with_model(model, prompt, tokenizer, markov_chain, max_length=200, temperature=0.8, device='cpu'):
    model.eval()
    input_ids = tokenizer.encode(prompt, max_length=256).unsqueeze(0).to(device)
    generated = input_ids[0].tolist()
    max_seq_length = 256

    with torch.no_grad():
        for _ in range(max_length):
            if len(generated) > max_seq_length:
                input_ids = torch.tensor(generated[-max_seq_length:]).unsqueeze(0).to(device)
            else:
                input_ids = torch.tensor(generated).unsqueeze(0).to(device)

            outputs = model(input_ids)
            logits = outputs[0, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)

            # Use Markov chain to guide sampling
            markov_suggestion = markov_chain.get_next_token(generated, tokenizer.vocab_size)
            markov_boost = torch.zeros_like(probs)
            markov_boost[markov_suggestion] += 0.3
            probs = F.softmax(probs + markov_boost, dim=-1)

            next_token = torch.multinomial(probs, 1).item()
            if next_token == tokenizer.pad_token_id:
                break
            generated.append(next_token)

    generated_lyrics = tokenizer.decode(torch.tensor(generated[len(tokenizer.encode(prompt)):]))
    formatted_lyrics = format_lyrics(generated_lyrics)
    return formatted_lyrics

def format_lyrics(text):
    words = text.split()
    if len(words) < 20:
        return text

    lines = []
    line_length = 0
    current_line = []
    for word in words:
        current_line.append(word)
        line_length += 1
        if line_length >= min(4, max(2, len(current_line)//2)) and word[-1] in '.,:;?!':
            lines.append(' '.join(current_line))
            current_line = []
            line_length = 0
        elif line_length >= 8:
            lines.append(' '.join(current_line))
            current_line = []
            line_length = 0
    if current_line:
        lines.append(' '.join(current_line))

    formatted_lyrics = []
    formatted_lyrics.append("Verse 1:")
    verse_length = min(8, len(lines) // 2)
    for i in range(verse_length):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    formatted_lyrics.append("Chorus:")
    chorus_start = verse_length
    chorus_end = min(chorus_start + 4, len(lines))
    for i in range(chorus_start, chorus_end):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    if chorus_end + 3 < len(lines):
        formatted_lyrics.append("Verse 2:")
        verse2_end = min(chorus_end + 6, len(lines))
        for i in range(chorus_end, verse2_end):
            formatted_lyrics.append(lines[i])
        formatted_lyrics.append("")
        formatted_lyrics.append("Chorus:")
        for i in range(chorus_start, chorus_end):
            formatted_lyrics.append(lines[i])

    return '\n'.join(formatted_lyrics)

def load_custom_lyrics_model(model_path="/content/final_lyrics_model.pth"):
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found")

        # Add safe globals for secure loading
        torch.serialization.add_safe_globals([LyricsTokenizer, MarkovChain])

        # Load checkpoint
        try:
            checkpoint = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
            print("Successfully loaded checkpoint")
        except Exception as e:
            print(f"Failed to load checkpoint: {e}")
            raise

        # Extract tokenizer and Markov chain
        tokenizer = checkpoint.get('tokenizer')
        if tokenizer is None:
            raise ValueError("Tokenizer not found in checkpoint")
        print("Loaded tokenizer from checkpoint")

        markov_chain = checkpoint.get('markov_chain')
        if markov_chain is None:
            raise ValueError("MarkovChain not found in checkpoint")
        print("Loaded MarkovChain from checkpoint")

        # Initialize model with correct vocab_size
        model = LyricsTransformer(
            vocab_size=tokenizer.vocab_size,
            d_model=64,
            num_heads=4,
            num_layers=2,
            d_ff=256
        )

        # Load model weights
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
            print("Loaded model weights from 'model_state_dict' key")
        else:
            raise ValueError("Model state dictionary not found in checkpoint")

        model.eval()
        print(f"Model loaded successfully from {model_path}")
        return model, tokenizer, markov_chain

    except Exception as e:
        print(f"Error loading model: {e}")
        raise

def generate_lyrics(theme, model, tokenizer, markov_chain):
    try:
        prompt = themes.get(theme.lower(), "Write a beautiful song.")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")
        model = model.to(device)
        lyrics = generate_with_model(model, prompt, tokenizer, markov_chain, device=device)
        return lyrics
    except Exception as e:
        print("An error occurred during lyrics generation:", e)
        return None

def text_to_speech(lyrics, output_file="lyrics_audio.mp3"):
    try:
        tts = gTTS(text=lyrics, lang='en')
        tts.save(output_file)
        print(f"Speech generated and saved as '{output_file}'")
    except Exception as e:
        print("Error in text-to-speech conversion:", e)

def extend_melody_to_match_lyrics(lyrics_audio_file, melody_file, output_file):
    try:
        melody = AudioSegment.from_wav(melody_file)
        lyrics = AudioSegment.from_mp3(lyrics_audio_file)
        lyrics_duration = len(lyrics)
        melody_duration = len(melody)
        if melody_duration < lyrics_duration:
            loop_count = (lyrics_duration // melody_duration) + 1
            extended_melody = melody * loop_count
            extended_melody = extended_melody[:lyrics_duration]
        else:
            extended_melody = melody[:lyrics_duration]
        melody_volume = -6
        extended_melody = extended_melody + melody_volume
        final_output = extended_melody.overlay(lyrics)
        final_output.export(output_file, format="mp3")
        print(f"Final combined song saved as '{output_file}'")
    except Exception as e:
        print(f"Error combining audio: {e}")

def create_fallback_model_and_tokenizer():
    print("\nWARNING: Creating fallback model and tokenizer for testing purposes.")
    print("This will not generate coherent lyrics without proper training.\n")
    tokenizer = LyricsTokenizer(vocab_size=10000)
    sample_words = ["the", "a", "and", "in", "of", "to", "is", "was", "it",
                    "you", "i", "love", "nature", "sky", "tree", "river", "mountain",
                    "friend", "hope", "dream", "sun", "moon", "star", "heart", "soul",
                    "wind", "rain", "song", "music", "dance", "life", "time", "day",
                    "night", "light", "dark", "beautiful", "happy", "sad", "free"]
    tokenizer.word_to_idx = {word: idx+2 for idx, word in enumerate(sample_words)}
    tokenizer.word_to_idx["<pad>"] = 0
    tokenizer.word_to_idx["<unk>"] = 1
    tokenizer.idx_to_word = {idx: word for word, idx in tokenizer.word_to_idx.items()}
    tokenizer.vocab_size = len(tokenizer.word_to_idx)
    model = LyricsTransformer(vocab_size=tokenizer.vocab_size)
    markov_chain = MarkovChain(order=2)
    return model, tokenizer, markov_chain

if __name__ == "__main__":
    theme = input("Enter a theme for your song (love, nature, friendship, hope, freedom): ").strip().lower()
    if theme not in themes:
        print(f"Theme '{theme}' not recognized. Using 'love' as default.")
        theme = "love"

    print("\nLoading custom lyrics generator model...\n")
    model = None
    tokenizer = None
    markov_chain = None
    try:
        model, tokenizer, markov_chain = load_custom_lyrics_model("/content/final_lyrics_model.pth")
    except Exception as e:
        print(f"\nFatal error loading model: {e}")
        print("\nWould you like to create a fallback model for testing purposes?")
        response = input("This won't generate good lyrics but will test the code flow (y/n): ").strip().lower()
        if response == 'y' or response == 'yes':
            model, tokenizer, markov_chain = create_fallback_model_and_tokenizer()
        else:
            print("Exiting program.")
            exit()

    print("\nGenerating Lyrics... Please wait.\n")
    lyrics = generate_lyrics(theme, model, tokenizer, markov_chain)

    if lyrics:
        print("Generated Lyrics:\n")
        print(lyrics)

        lyrics_audio_file = "lyrics_audio.mp3"
        text_to_speech(lyrics, lyrics_audio_file)

        melody_file = input("\nEnter path to melody file (default: generated_melody.wav): ").strip()
        if not melody_file:
            melody_file = "/content/drive/MyDrive/GenAI/generated_melody.wav"

        if not os.path.exists(melody_file):
            print(f"Warning: Melody file '{melody_file}' not found. Please check the path.")
        else:
            output_file = f"{theme}_song.mp3"
            print("\nCombining lyrics with melody...\n")
            extend_melody_to_match_lyrics(lyrics_audio_file, melody_file, output_file)
            print(f"\nSong generation complete! Your song is saved as '{output_file}'")
    else:
        print("Failed to generate lyrics.")

Enter a theme for your song (love, nature, friendship, hope, freedom): love

Loading custom lyrics generator model...

Successfully loaded checkpoint
Loaded tokenizer from checkpoint
Loaded MarkovChain from checkpoint
Loaded model weights from 'model_state_dict' key
Model loaded successfully from /content/final_lyrics_model.pth

Generating Lyrics... Please wait.

Using device: cuda
Generated Lyrics:

Verse 1:
<unk> joy tears life song leave melody hope
light eyes baby cry pain gone voice dance
together come sing sun sun sing kiss touch
break touch rhythm come time go life darling

Chorus:
leave home voice dream forever see kiss home
life apart leave dreams joy rise fly come
sing song dream rhythm apart love break dreams
know heart hope eyes feel world apart <unk>

Speech generated and saved as 'lyrics_audio.mp3'

Enter path to melody file (default: generated_melody.wav): /content/generated_melody.wav

Combining lyrics with melody...

Final combined song saved as 'love_song.mp3'

Song g

In [4]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import numpy as np
from collections import Counter, defaultdict
from gtts import gTTS
from pydub import AudioSegment
from tqdm import tqdm

# Define the themes and the prompt for each theme
themes = {
    "love": "Write a beautiful song about love and relationships.",
    "nature": "Write a song about the beauty of nature and the environment.",
    "friendship": "Write a song about the power and importance of friendship.",
    "hope": "Write a song about hope and positivity for the future.",
    "freedom": "Write a song about freedom and independence."
}

# Lyrics-specific vocabulary with theme-specific words
COMMON_LYRICS_WORDS = [
    '<pad>', '<unk>', 'love', 'heart', 'dream', 'night', 'day', 'life', 'time', 'way',
    'feel', 'know', 'go', 'see', 'world', 'soul', 'sky', 'star', 'moon', 'sun',
    'baby', 'darling', 'forever', 'always', 'never', 'together', 'apart', 'home', 'road', 'fire',
    'dance', 'sing', 'song', 'melody', 'rhyme', 'beat', 'rhythm', 'free', 'run', 'fly',
    'hold', 'touch', 'kiss', 'smile', 'cry', 'tears', 'pain', 'joy', 'hope', 'fear',
    'light', 'dark', 'shadow', 'shine', 'burn', 'break', 'fall', 'rise', 'stay', 'leave',
    'come', 'gone', 'back', 'memory', 'dreams', 'eyes', 'hands', 'voice', 'mind', 'spirit',
    # Love
    'passion', 'devotion', 'eternal', 'cherish', 'adore', 'sweet', 'romance', 'embrace', 'lover', 'dear',
    # Nature
    'forest', 'river', 'mountain', 'ocean', 'wind', 'tree', 'flower', 'earth', 'valley', 'meadow',
    'stream', 'breeze', 'horizon', 'dawn', 'twilight', 'rain', 'mist', 'lake', 'pine', 'bloom',
    # Friendship
    'friend', 'bond', 'trust', 'loyal', 'share', 'laughter', 'support', 'care', 'companion', 'unity',
    # Hope
    'future', 'promise', 'vision', 'aspire', 'uplift', 'believe', 'shine', 'tomorrow', 'wish', 'dreamer',
    # Freedom
    'liberty', 'wings', 'open', 'skyward', 'unbound', 'journey', 'release', 'soar', 'freebird', 'escape'
]

# Theme-specific word lists for boosting
THEME_WORDS = {
    "love": [
        'love', 'heart', 'kiss', 'passion', 'devotion', 'eternal', 'cherish', 'adore', 'sweet', 'romance',
        'embrace', 'lover', 'dear', 'baby', 'darling', 'forever', 'always', 'together', 'soul', 'dream'
    ],
    "nature": [
        'forest', 'river', 'mountain', 'ocean', 'wind', 'tree', 'flower', 'earth', 'valley', 'meadow',
        'stream', 'breeze', 'horizon', 'dawn', 'twilight', 'rain', 'mist', 'lake', 'pine', 'bloom'
    ],
    "friendship": [
        'friend', 'bond', 'trust', 'loyal', 'share', 'laughter', 'support', 'care', 'companion', 'unity'
    ],
    "hope": [
        'hope', 'future', 'promise', 'vision', 'aspire', 'uplift', 'believe', 'shine', 'tomorrow', 'dreamer'
    ],
    "freedom": [
        'free', 'liberty', 'wings', 'open', 'skyward', 'unbound', 'journey', 'release', 'soar', 'freebird'
    ]
}

# Custom Tokenizer
class LyricsTokenizer:
    def __init__(self, vocab_size=15000):
        self.vocab_size = vocab_size
        self.word_to_idx = {word: idx for idx, word in enumerate(COMMON_LYRICS_WORDS)}
        self.idx_to_word = {idx: word for word, idx in self.word_to_idx.items()}
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.vocab_size = len(self.word_to_idx)

    def fit(self, texts):
        words = []
        for text in tqdm(texts, desc="Building vocabulary"):
            words.extend(re.findall(r'\b\w+\b', text.lower()))

        word_counts = Counter(words)
        additional_words = [word for word, _ in word_counts.most_common(self.vocab_size - len(COMMON_LYRICS_WORDS))]
        for word in additional_words:
            if word not in self.word_to_idx and len(self.word_to_idx) < self.vocab_size:
                self.word_to_idx[word] = len(self.word_to_idx)
                self.idx_to_word[len(self.idx_to_word)] = word
        self.vocab_size = len(self.word_to_idx)

    def encode(self, text, max_length=256):
        words = re.findall(r'\b\w+\b', text.lower())
        ids = [self.word_to_idx.get(word, self.unk_token_id) for word in words]
        if len(ids) < max_length:
            ids = ids + [self.pad_token_id] * (max_length - len(ids))
        else:
            ids = ids[:max_length]
        print(f"Encoded prompt: {words} -> {ids[:10]}... (length: {len(ids)})")
        return torch.tensor(ids)

    def decode(self, ids):
        words = [self.idx_to_word.get(id.item(), '<unk>') for id in ids if id != self.pad_token_id]
        return ' '.join(words)

# Markov Chain
class MarkovChain:
    def __init__(self, order=3):
        self.order = order
        self.transitions = defaultdict(list)

    def fit(self, texts, tokenizer):
        for text in tqdm(texts, desc="Building Markov chain"):
            ids = tokenizer.encode(text).tolist()
            for i in range(len(ids) - self.order):
                state = tuple(ids[i:i + self.order])
                next_token = ids[i + self.order]
                self.transitions[state].append(next_token)

    def get_next_token(self, state, vocab_size):
        state = tuple(state[-self.order:])
        if state in self.transitions and self.transitions[state]:
            return np.random.choice(self.transitions[state])
        return np.random.randint(2, vocab_size)

# Model components
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.size()
        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)
        context = context.permute(0, 2, 1, 3).reshape(batch_size, seq_len, d_model)
        return self.proj(context)

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = x + self.dropout(self.attn(self.norm1(x), mask))
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x

# Main model
class LyricsTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, num_heads=4, num_layers=2, d_ff=256, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.zeros(1, 256, d_model))  # Match checkpoint
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        seq_len = x.size(1)
        x = self.embedding(x) + self.pos_encoding[:, :seq_len]
        x = self.dropout(x)
        for block in self.transformer_blocks:
            x = block(x, mask)
        return self.fc(x)

# Generation function with theme-specific boosting
def generate_with_model(model, prompt, tokenizer, markov_chain, max_length=600, temperature=0.85, device='cpu', theme='love'):
    model.eval()
    input_ids = tokenizer.encode(prompt, max_length=256).unsqueeze(0).to(device)
    generated = input_ids[0].tolist()
    max_seq_length = 256

    # Get indices of theme-specific words
    theme_word_ids = [tokenizer.word_to_idx[word] for word in THEME_WORDS.get(theme, []) if word in tokenizer.word_to_idx]

    with torch.no_grad():
        for _ in range(max_length):
            if len(generated) > max_seq_length:
                input_ids = torch.tensor(generated[-max_seq_length:]).unsqueeze(0).to(device)
            else:
                input_ids = torch.tensor(generated).unsqueeze(0).to(device)

            outputs = model(input_ids)
            logits = outputs[0, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)

            # Boost theme-specific words and Markov suggestions
            markov_suggestion = markov_chain.get_next_token(generated, tokenizer.vocab_size)
            markov_boost = torch.zeros_like(probs)
            markov_boost[markov_suggestion] += 0.7
            for theme_id in theme_word_ids:
                markov_boost[theme_id] += 0.6
            probs = F.softmax(probs + markov_boost, dim=-1)
            probs[tokenizer.unk_token_id] *= 0.01  # Penalize <unk> tokens
            probs[tokenizer.pad_token_id] *= 0.01  # Avoid early termination

            next_token = torch.multinomial(probs, 1).item()
            generated.append(next_token)
            if len(generated) >= max_length:
                break

    generated_lyrics = tokenizer.decode(torch.tensor(generated[len(tokenizer.encode(prompt)):]))
    formatted_lyrics = format_lyrics(generated_lyrics, theme)
    return formatted_lyrics

def format_lyrics(text, theme):
    words = text.split()
    if len(words) < 40:
        words.extend(THEME_WORDS.get(theme, LOVE_WORDS) * 3)
        print(f"Warning: Generated text too short ({len(words)} words). Padded with {theme}-themed words.")

    lines = []
    line_length = 0
    current_line = []
    for word in words:
        current_line.append(word)
        line_length += 1
        if line_length >= min(5, max(3, len(current_line)//2)) and word[-1] in '.,:;?!':
            lines.append(' '.join(current_line))
            current_line = []
            line_length = 0
        elif line_length >= 7:
            lines.append(' '.join(current_line))
            current_line = []
            line_length = 0
    if current_line:
        lines.append(' '.join(current_line))

    formatted_lyrics = []
    total_lines = len(lines)
    formatted_lyrics.append("Verse 1:")
    verse_length = min(8, total_lines // 4)
    for i in range(min(verse_length, total_lines)):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    formatted_lyrics.append("Chorus:")
    chorus_start = verse_length
    chorus_end = min(chorus_start + 4, total_lines)
    for i in range(chorus_start, min(chorus_end, total_lines)):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    formatted_lyrics.append("Verse 2:")
    verse2_start = chorus_end
    verse2_end = min(verse2_start + 8, total_lines)
    for i in range(verse2_start, min(verse2_end, total_lines)):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    formatted_lyrics.append("Bridge:")
    bridge_start = verse2_end
    bridge_end = min(bridge_start + 4, total_lines)
    for i in range(bridge_start, min(bridge_end, total_lines)):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    formatted_lyrics.append("Verse 3:")
    verse3_start = bridge_end
    verse3_end = min(verse3_start + 8, total_lines)
    for i in range(verse3_start, min(verse3_end, total_lines)):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    formatted_lyrics.append("Chorus:")
    for i in range(chorus_start, min(chorus_end, total_lines)):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    formatted_lyrics.append("Outro:")
    outro_start = verse3_end
    outro_end = min(outro_start + 3, total_lines)
    for i in range(outro_start, min(outro_end, total_lines)):
        formatted_lyrics.append(lines[i])

    return '\n'.join(formatted_lyrics)

def load_custom_lyrics_model(model_path="/content/final_lyrics_model.pth"):
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found")

        torch.serialization.add_safe_globals([LyricsTokenizer, MarkovChain])

        checkpoint = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
        print("Successfully loaded checkpoint")

        tokenizer = checkpoint.get('tokenizer')
        if tokenizer is None:
            raise ValueError("Tokenizer not found in checkpoint")
        print("Loaded tokenizer from checkpoint")

        markov_chain = checkpoint.get('markov_chain')
        if markov_chain is None:
            raise ValueError("MarkovChain not found in checkpoint")
        print("Loaded MarkovChain from checkpoint")

        model = LyricsTransformer(
            vocab_size=tokenizer.vocab_size,
            d_model=64,
            num_heads=4,
            num_layers=2,
            d_ff=256
        )

        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
            print("Loaded model weights from 'model_state_dict' key")
        else:
            raise ValueError("Model state dictionary not found in checkpoint")

        model.eval()
        print(f"Model loaded successfully from {model_path}")
        return model, tokenizer, markov_chain

    except Exception as e:
        print(f"Error loading model: {e}")
        raise

def generate_lyrics(theme, model, tokenizer, markov_chain):
    try:
        prompt = themes.get(theme.lower(), "Write a beautiful song.")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")
        model = model.to(device)
        lyrics = generate_with_model(model, prompt, tokenizer, markov_chain, device=device, theme=theme)
        return lyrics
    except Exception as e:
        print("An error occurred during lyrics generation:", e)
        return None

def text_to_speech(lyrics, output_file="lyrics_audio.mp3"):
    try:
        tts = gTTS(text=lyrics, lang='en')
        tts.save(output_file)
        print(f"Speech generated and saved as '{output_file}'")
    except Exception as e:
        print("Error in text-to-speech conversion:", e)

def extend_melody_to_match_lyrics(lyrics_audio_file, melody_file, output_file):
    try:
        melody = AudioSegment.from_wav(melody_file)
        lyrics = AudioSegment.from_mp3(lyrics_audio_file)
        lyrics_duration = len(lyrics)
        melody_duration = len(melody)
        if melody_duration < lyrics_duration:
            loop_count = (lyrics_duration // melody_duration) + 1
            extended_melody = melody * loop_count
            extended_melody = extended_melody[:lyrics_duration]
        else:
            extended_melody = melody[:lyrics_duration]
        melody_volume = -6
        extended_melody = extended_melody + melody_volume
        final_output = extended_melody.overlay(lyrics)
        final_output.export(output_file, format="mp3")
        print(f"Final combined song saved as '{output_file}'")
    except Exception as e:
        print(f"Error combining audio: {e}")

def create_fallback_model_and_tokenizer(theme="love"):
    print("\nWARNING: Creating fallback model and tokenizer for testing purposes.")
    print("This will not generate coherent lyrics without proper training.\n")
    tokenizer = LyricsTokenizer(vocab_size=15000)
    sample_words = [
        "the", "a", "and", "in", "of", "to", "is", "was", "it", "you", "i",
        "life", "time", "way", "feel", "know", "go", "see", "world", "soul", "sky"
    ] + THEME_WORDS.get(theme, THEME_WORDS["love"])
    tokenizer.word_to_idx = {word: idx+2 for idx, word in enumerate(sample_words)}
    tokenizer.word_to_idx["<pad>"] = 0
    tokenizer.word_to_idx["<unk>"] = 1
    tokenizer.idx_to_word = {idx: word for word, idx in tokenizer.word_to_idx.items()}
    tokenizer.vocab_size = len(tokenizer.word_to_idx)
    model = LyricsTransformer(vocab_size=tokenizer.vocab_size)
    markov_chain = MarkovChain(order=3)
    return model, tokenizer, markov_chain

if __name__ == "__main__":
    theme = input("Enter a theme for your song (love, nature, friendship, hope, freedom): ").strip().lower()
    if theme not in themes:
        print(f"Theme '{theme}' not recognized. Using 'love' as default.")
        theme = "love"

    print("\nLoading custom lyrics generator model...\n")
    model = None
    tokenizer = None
    markov_chain = None
    try:
        model, tokenizer, markov_chain = load_custom_lyrics_model("final_lyrics_model.pth")
    except Exception as e:
        print(f"\nFatal error loading model: {e}")
        print("\nWould you like to create a fallback model for testing purposes?")
        response = input("This won't generate good lyrics but will test the code flow (y/n): ").strip().lower()
        if response == 'y' or response == 'yes':
            model, tokenizer, markov_chain = create_fallback_model_and_tokenizer(theme)
        else:
            print("Exiting program.")
            exit()

    print("\nGenerating Lyrics... Please wait.\n")
    lyrics = generate_lyrics(theme, model, tokenizer, markov_chain)

    if lyrics:
        print("Generated Lyrics:\n")
        print(lyrics)

        lyrics_audio_file = "lyrics_audio.mp3"
        text_to_speech(lyrics, lyrics_audio_file)

        melody_file = input("\nEnter path to melody file (default: generated_melody.wav): ").strip()
        if not melody_file:
            melody_file = "/content/drive/MyDrive/GenAI/generated_melody.wav"

        if not os.path.exists(melody_file):
            print(f"Warning: Melody file '{melody_file}' not found. Please check the path.")
        else:
            output_file = f"{theme}_song.mp3"
            print("\nCombining lyrics with melody...\n")
            extend_melody_to_match_lyrics(lyrics_audio_file, melody_file, output_file)
            print(f"\nSong generation complete! Your song is saved as '{output_file}'")
    else:
        print("Failed to generate lyrics.")

Enter a theme for your song (love, nature, friendship, hope, freedom):  friendship



Loading custom lyrics generator model...

Successfully loaded checkpoint
Loaded tokenizer from checkpoint
Loaded MarkovChain from checkpoint
Loaded model weights from 'model_state_dict' key
Model loaded successfully from final_lyrics_model.pth

Generating Lyrics... Please wait.

Using device: cpu
Encoded prompt: ['write', 'a', 'song', 'about', 'the', 'power', 'and', 'importance', 'of', 'friendship'] -> [1, 1, 32, 1, 1, 1, 1, 1, 1, 1]... (length: 256)
Encoded prompt: ['write', 'a', 'song', 'about', 'the', 'power', 'and', 'importance', 'of', 'friendship'] -> [1, 1, 32, 1, 1, 1, 1, 1, 1, 1]... (length: 256)
Generated Lyrics:

Verse 1:
light eyes darling baby shine forever hope
come life sky always fear fly life
touch baby baby run know star eyes
back run see song stay rhythm heart
mind sing light hold moon rhyme baby
darling time always light spirit dark tears
hands know mind dreams always eyes home
feel shadow sky gone dreams eyes free

Chorus:
light road kiss see road life apart
joy jo


Enter path to melody file (default: generated_melody.wav):  "D:\PROJECT\GenAI_prj\FINAL\generated_melody.wav"




In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import numpy as np
from collections import Counter, defaultdict
from gtts import gTTS
from pydub import AudioSegment
from tqdm import tqdm
from difflib import get_close_matches

# Define the themes and the prompt for each theme
themes = {
    "love": "Write a beautiful song about love and relationships.",
    "nature": "Write a song about the beauty of nature and the environment.",
    "friendship": "Write a song about the power and importance of friendship.",
    "hope": "Write a song about hope and positivity for the future.",
    "freedom": "Write a song about freedom and independence."
}

# Lyrics-specific vocabulary with theme-specific words
COMMON_LYRICS_WORDS = [
    '<pad>', '<unk>', 'love', 'heart', 'dream', 'night', 'day', 'life', 'time', 'way',
    'feel', 'know', 'go', 'see', 'world', 'soul', 'sky', 'star', 'moon', 'sun',
    'baby', 'darling', 'forever', 'always', 'never', 'together', 'apart', 'home', 'road', 'fire',
    'dance', 'sing', 'song', 'melody', 'rhyme', 'beat', 'rhythm', 'free', 'run', 'fly',
    'hold', 'touch', 'kiss', 'smile', 'cry', 'tears', 'pain', 'joy', 'hope', 'fear',
    'light', 'dark', 'shadow', 'shine', 'burn', 'break', 'fall', 'rise', 'stay', 'leave',
    'come', 'gone', 'back', 'memory', 'dreams', 'eyes', 'hands', 'voice', 'mind', 'spirit',
    'passion', 'devotion', 'eternal', 'cherish', 'adore', 'sweet', 'romance', 'embrace', 'lover', 'dear',
    'forest', 'river', 'mountain', 'ocean', 'wind', 'tree', 'flower', 'earth', 'valley', 'meadow',
    'stream', 'breeze', 'horizon', 'dawn', 'twilight', 'rain', 'mist', 'lake', 'pine', 'bloom',
    'friend', 'bond', 'trust', 'loyal', 'share', 'laughter', 'support', 'care', 'companion', 'unity',
    'future', 'promise', 'vision', 'aspire', 'uplift', 'believe', 'shine', 'tomorrow', 'wish', 'dreamer',
    'liberty', 'wings', 'open', 'skyward', 'unbound', 'journey', 'release', 'soar', 'freebird', 'escape',
    'write', 'beautiful', 'about', 'and', 'relationships', 'environment', 'power', 'importance', 'positivity', 'independence'
]

# Theme-specific word lists for boosting
THEME_WORDS = {
    "love": [
        'love', 'heart', 'kiss', 'passion', 'devotion', 'eternal', 'cherish', 'adore', 'sweet', 'romance',
        'embrace', 'lover', 'dear', 'baby', 'darling', 'forever', 'always', 'together', 'soul', 'dream'
    ],
    "nature": [
        'forest', 'river', 'mountain', 'ocean', 'wind', 'tree', 'flower', 'earth', 'valley', 'meadow',
        'stream', 'breeze', 'horizon', 'dawn', 'twilight', 'rain', 'mist', 'lake', 'pine', 'bloom'
    ],
    "friendship": [
        'friend', 'bond', 'trust', 'loyal', 'share', 'laughter', 'support', 'care', 'companion', 'unity'
    ],
    "hope": [
        'hope', 'future', 'promise', 'vision', 'aspire', 'uplift', 'believe', 'shine', 'tomorrow', 'dreamer'
    ],
    "freedom": [
        'free', 'liberty', 'wings', 'open', 'skyward', 'unbound', 'journey', 'release', 'soar', 'freebird'
    ]
}

# Custom Tokenizer
class LyricsTokenizer:
    def __init__(self, vocab_size=15000):
        self.vocab_size = vocab_size
        self.word_to_idx = {word: idx for idx, word in enumerate(COMMON_LYRICS_WORDS)}
        self.idx_to_word = {idx: word for word, idx in self.word_to_idx.items()}
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.vocab_size = len(self.word_to_idx)

    def fit(self, texts):
        words = []
        for text in tqdm(texts, desc="Building vocabulary"):
            words.extend(re.findall(r'\b\w+\b', text.lower()))

        word_counts = Counter(words)
        additional_words = [word for word, _ in word_counts.most_common(self.vocab_size - len(COMMON_LYRICS_WORDS))]
        for word in additional_words:
            if word not in self.word_to_idx and len(self.word_to_idx) < self.vocab_size:
                self.word_to_idx[word] = len(self.word_to_idx)
                self.idx_to_word[len(self.idx_to_word)] = word
        self.vocab_size = len(self.word_to_idx)

    def encode(self, text, max_length=256):
        words = re.findall(r'\b\w+\b', text.lower())
        ids = [self.word_to_idx.get(word, self.unk_token_id) for word in words]
        if len(ids) < max_length:
            ids = ids + [self.pad_token_id] * (max_length - len(ids))
        else:
            ids = ids[:max_length]
        print(f"Encoded prompt: {words} -> {ids[:10]}... (length: {len(ids)})")
        return torch.tensor(ids)

    def decode(self, ids):
        words = [self.idx_to_word.get(id.item(), '<unk>') for id in ids if id != self.pad_token_id]
        return ' '.join(words)

# Markov Chain
class MarkovChain:
    def __init__(self, order=3):
        self.order = order
        self.transitions = defaultdict(list)

    def fit(self, texts, tokenizer):
        for text in tqdm(texts, desc="Building Markov chain"):
            ids = tokenizer.encode(text).tolist()
            for i in range(len(ids) - self.order):
                state = tuple(ids[i:i + self.order])
                next_token = ids[i + self.order]
                self.transitions[state].append(next_token)

    def get_next_token(self, state, vocab_size):
        state = tuple(state[-self.order:])
        if state in self.transitions and self.transitions[state]:
            return np.random.choice(self.transitions[state])
        return np.random.randint(2, vocab_size)

# Model components
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.size()
        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)
        context = context.permute(0, 2, 1, 3).reshape(batch_size, seq_len, d_model)
        return self.proj(context)

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = x + self.dropout(self.attn(self.norm1(x), mask))
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x

# Main model
class LyricsTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, num_heads=4, num_layers=2, d_ff=256, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.zeros(1, 256, d_model))  # Match checkpoint
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        seq_len = x.size(1)
        x = self.embedding(x) + self.pos_encoding[:, :seq_len]
        x = self.dropout(x)
        for block in self.transformer_blocks:
            x = block(x, mask)
        return self.fc(x)

# Generation function with theme-specific boosting
def generate_with_model(model, prompt, tokenizer, markov_chain, max_length=600, temperature=0.85, device='cpu', theme='love'):
    model.eval()
    input_ids = tokenizer.encode(prompt, max_length=256).unsqueeze(0).to(device)
    generated = input_ids[0].tolist()
    max_seq_length = 256

    # Get indices of theme-specific words
    theme_word_ids = [tokenizer.word_to_idx[word] for word in THEME_WORDS.get(theme, []) if word in tokenizer.word_to_idx]

    with torch.no_grad():
        for _ in range(max_length):
            if len(generated) > max_seq_length:
                input_ids = torch.tensor(generated[-max_seq_length:]).unsqueeze(0).to(device)
            else:
                input_ids = torch.tensor(generated).unsqueeze(0).to(device)

            outputs = model(input_ids)
            logits = outputs[0, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)

            # Boost theme-specific words and Markov suggestions
            markov_suggestion = markov_chain.get_next_token(generated, tokenizer.vocab_size)
            markov_boost = torch.zeros_like(probs)
            markov_boost[markov_suggestion] += 0.7
            for theme_id in theme_word_ids:
                markov_boost[theme_id] += 0.6
            probs = F.softmax(probs + markov_boost, dim=-1)
            probs[tokenizer.unk_token_id] *= 0.01  # Penalize <unk> tokens
            probs[tokenizer.pad_token_id] *= 0.01  # Avoid early termination

            next_token = torch.multinomial(probs, 1).item()
            generated.append(next_token)
            if len(generated) >= max_length:
                break

    generated_lyrics = tokenizer.decode(torch.tensor(generated[len(tokenizer.encode(prompt)):]))
    formatted_lyrics = format_lyrics(generated_lyrics, theme)
    return formatted_lyrics

def format_lyrics(text, theme):
    words = text.split()
    if len(words) < 40:
        words.extend(THEME_WORDS.get(theme, THEME_WORDS["love"]) * 3)
        print(f"Warning: Generated text too short ({len(words)} words). Padded with {theme}-themed words.")

    lines = []
    line_length = 0
    current_line = []
    for word in words:
        current_line.append(word)
        line_length += 1
        if line_length >= min(5, max(3, len(current_line)//2)) and word[-1] in '.,:;?!':
            lines.append(' '.join(current_line))
            current_line = []
            line_length = 0
        elif line_length >= 7:
            lines.append(' '.join(current_line))
            current_line = []
            line_length = 0
    if current_line:
        lines.append(' '.join(current_line))

    formatted_lyrics = []
    total_lines = len(lines)
    formatted_lyrics.append("Verse 1:")
    verse_length = min(8, total_lines // 4)
    for i in range(min(verse_length, total_lines)):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    formatted_lyrics.append("Chorus:")
    chorus_start = verse_length
    chorus_end = min(chorus_start + 4, total_lines)
    for i in range(chorus_start, min(chorus_end, total_lines)):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    formatted_lyrics.append("Verse 2:")
    verse2_start = chorus_end
    verse2_end = min(verse2_start + 8, total_lines)
    for i in range(verse2_start, min(verse2_end, total_lines)):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    formatted_lyrics.append("Bridge:")
    bridge_start = verse2_end
    bridge_end = min(bridge_start + 4, total_lines)
    for i in range(bridge_start, min(bridge_end, total_lines)):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    formatted_lyrics.append("Verse 3:")
    verse3_start = bridge_end
    verse3_end = min(verse3_start + 8, total_lines)
    for i in range(verse3_start, min(verse3_end, total_lines)):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    formatted_lyrics.append("Chorus:")
    for i in range(chorus_start, min(chorus_end, total_lines)):
        formatted_lyrics.append(lines[i])
    formatted_lyrics.append("")

    formatted_lyrics.append("Outro:")
    outro_start = verse3_end
    outro_end = min(outro_start + 3, total_lines)
    for i in range(outro_start, min(outro_end, total_lines)):
        formatted_lyrics.append(lines[i])

    return '\n'.join(formatted_lyrics)

def load_custom_lyrics_model(model_path="final_lyrics_model.pth"):
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found")

        torch.serialization.add_safe_globals([LyricsTokenizer, MarkovChain])

        checkpoint = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
        print("Successfully loaded checkpoint")

        tokenizer = checkpoint.get('tokenizer')
        if tokenizer is None:
            raise ValueError("Tokenizer not found in checkpoint")
        print("Loaded tokenizer from checkpoint")

        markov_chain = checkpoint.get('markov_chain')
        if markov_chain is None:
            raise ValueError("MarkovChain not found in checkpoint")
        print("Loaded MarkovChain from checkpoint")

        model = LyricsTransformer(
            vocab_size=tokenizer.vocab_size,
            d_model=64,
            num_heads=4,
            num_layers=2,
            d_ff=256
        )

        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
            print("Loaded model weights from 'model_state_dict' key")
        else:
            raise ValueError("Model state dictionary not found in checkpoint")

        model.eval()
        print(f"Model loaded successfully from {model_path}")
        return model, tokenizer, markov_chain

    except Exception as e:
        print(f"Error loading model: {e}")
        raise

def generate_lyrics(theme, model, tokenizer, markov_chain):
    try:
        prompt = themes.get(theme.lower(), "Write a beautiful song.")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")
        model = model.to(device)
        lyrics = generate_with_model(model, prompt, tokenizer, markov_chain, device=device, theme=theme)
        return lyrics
    except Exception as e:
        print("An error occurred during lyrics generation:", e)
        return None

def text_to_speech(lyrics, output_file="lyrics_audio.mp3"):
    try:
        # Ensure output_file is in the current working directory
        output_file = os.path.join(os.getcwd(), output_file)
        tts = gTTS(text=lyrics, lang='en')
        tts.save(output_file)
        print(f"Speech generated and saved as '{output_file}'")
        return output_file
    except Exception as e:
        print(f"Error in text-to-speech conversion: {e}")
        return None

def extend_melody_to_match_lyrics(lyrics_audio_file, melody_file, output_file):
    try:
        if not os.path.exists(melody_file):
            raise FileNotFoundError(f"Melody file '{melody_file}' not found")
        if not os.path.exists(lyrics_audio_file):
            raise FileNotFoundError(f"Lyrics audio file '{lyrics_audio_file}' not found")

        # Ensure output_file is in the current working directory
        output_file = os.path.join(os.getcwd(), output_file)
        melody = AudioSegment.from_wav(melody_file)
        lyrics = AudioSegment.from_mp3(lyrics_audio_file)
        lyrics_duration = len(lyrics)
        melody_duration = len(melody)
        if melody_duration < lyrics_duration:
            loop_count = (lyrics_duration // melody_duration) + 1
            extended_melody = melody * loop_count
            extended_melody = extended_melody[:lyrics_duration]
        else:
            extended_melody = melody[:lyrics_duration]
        melody_volume = -6
        extended_melody = extended_melody + melody_volume
        final_output = extended_melody.overlay(lyrics)
        final_output.export(output_file, format="mp3")
        print(f"Final combined song saved as '{output_file}'")
        return output_file
    except Exception as e:
        print(f"Error combining audio: {e}")
        return None

def create_fallback_model_and_tokenizer(theme="love"):
    print("\nWARNING: Creating fallback model and tokenizer for testing purposes.")
    print("This will not generate coherent lyrics without proper training.\n")
    tokenizer = LyricsTokenizer(vocab_size=15000)
    sample_words = [
        "the", "a", "and", "in", "of", "to", "is", "was", "it", "you", "i",
        "life", "time", "way", "feel", "know", "go", "see", "world", "soul", "sky",
        "write", "beautiful", "about", "relationships", "environment", "power", "importance"
    ] + THEME_WORDS.get(theme, THEME_WORDS["love"])
    tokenizer.word_to_idx = {word: idx+2 for idx, word in enumerate(sample_words)}
    tokenizer.word_to_idx["<pad>"] = 0
    tokenizer.word_to_idx["<unk>"] = 1
    tokenizer.idx_to_word = {word: idx for idx, word in tokenizer.word_to_idx.items()}
    tokenizer.vocab_size = len(tokenizer.word_to_idx)
    model = LyricsTransformer(vocab_size=tokenizer.vocab_size)
    markov_chain = MarkovChain(order=3)
    return model, tokenizer, markov_chain

def correct_theme(theme):
    """Correct potential typos in theme input using fuzzy matching."""
    theme = theme.lower().strip()
    if theme in themes:
        return theme
    possible_matches = get_close_matches(theme, themes.keys(), n=1, cutoff=0.6)
    if possible_matches:
        corrected_theme = possible_matches[0]
        print(f"Corrected theme '{theme}' to '{corrected_theme}'")
        return corrected_theme
    print(f"Theme '{theme}' not recognized. Using 'love' as default.")
    return "love"

if __name__ == "__main__":
    # Print current working directory
    print(f"Current working directory: {os.getcwd()}")

    theme = input("Enter a theme for your song (love, nature, friendship, hope, freedom): ").strip()
    theme = correct_theme(theme)  # Correct typos

    print("\nLoading custom lyrics generator model...\n")
    model = None
    tokenizer = None
    markov_chain = None
    try:
        model, tokenizer, markov_chain = load_custom_lyrics_model("final_lyrics_model.pth")
    except Exception as e:
        print(f"\nFatal error loading model: {e}")
        print("\nWould you like to create a fallback model for testing purposes?")
        response = input("This won't generate good lyrics but will test the code flow (y/n): ").strip().lower()
        if response == 'y' or response == 'yes':
            model, tokenizer, markov_chain = create_fallback_model_and_tokenizer(theme)
        else:
            print("Exiting program.")
            exit()

    print("\nGenerating Lyrics... Please wait.\n")
    lyrics = generate_lyrics(theme, model, tokenizer, markov_chain)

    if lyrics:
        print("Generated Lyrics:\n")
        print(lyrics)

        lyrics_audio_file = "lyrics_audio.mp3"
        lyrics_audio_file = text_to_speech(lyrics, lyrics_audio_file)
        if not lyrics_audio_file:
            print("Failed to generate lyrics audio. Skipping melody combination.")
        else:
            melody_file = input("\nEnter path to melody file (default: generated_melody.wav): ").strip()
            if not melody_file:
                melody_file = "generated_melody.wav"

            output_file = f"{theme}_song.mp3"
            if os.path.exists(melody_file):
                print("\nCombining lyrics with melody...\n")
                result = extend_melody_to_match_lyrics(lyrics_audio_file, melody_file, output_file)
                if result:
                    print(f"\nSong generation complete! Your song is saved as '{result}'")
                else:
                    print("\nFailed to combine lyrics with melody. Check the error above.")
            else:
                print(f"Melody file '{melody_file}' not found. Please provide a valid path.")
                print(f"Lyrics audio is saved as '{lyrics_audio_file}'")
    else:
        print("Failed to generate lyrics.")

Current working directory: D:\PROJECT\GenAI_prj\FINAL


Enter a theme for your song (love, nature, friendship, hope, freedom):  love



Loading custom lyrics generator model...

Error loading model: Model file final_lyrics_model.pth not found

Fatal error loading model: Model file final_lyrics_model.pth not found

Would you like to create a fallback model for testing purposes?
