In [1]:
import csv
import os
import pandas as pd

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    GenerationMixin,
)
from transformers import GPT2Tokenizer, GPT2Model, GPT2PreTrainedModel, GPT2Config
from transformers import GenerationConfig
from transformers import EncodecModel, AutoProcessor

import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
import librosa

from IPython.display import Audio, display

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
audio_folder = "./LJSpeech-1.1/wavs/"
metadata_path = "./LJSpeech-1.1/metadata.csv"

In [3]:
metadata = pd.read_csv(metadata_path, sep='|', header=None, names=['ID', 'Transcription', 'Normalized_Transcription'])
print("Metadata loaded. First few rows:")
print(metadata.head())

Metadata loaded. First few rows:
           ID                                      Transcription  \
0  LJ001-0001  Printing, in the only sense with which we are ...   
1  LJ001-0002                     in being comparatively modern.   
2  LJ001-0003  For although the Chinese took impressions from...   
3  LJ001-0004  produced the block books, which were the immed...   
4  LJ001-0005  the invention of movable metal letters in the ...   

                            Normalized_Transcription  
0  Printing, in the only sense with which we are ...  
1                     in being comparatively modern.  
2  For although the Chinese took impressions from...  
3  produced the block books, which were the immed...  
4  the invention of movable metal letters in the ...  


In [4]:
def load_audio_file(audio_id, sr=16000):
    path = os.path.join(audio_folder, f"{audio_id}.wav")
    audio, _ = librosa.load(path, sr=sr)
    return audio


def play_audio(audio_transcription_pairs, index, sr=16000):
    print(audio_transcription_pairs[index][1])
    audio = audio_transcription_pairs[index][0]
    display(Audio(audio, rate=sr))

In [5]:
audio_transcription_pairs = []
for _, row in tqdm(metadata.iterrows()):
    audio = load_audio_file(row['ID'])
    transcription = row['Transcription']
    audio_transcription_pairs.append((audio, transcription))

13100it [01:04, 202.07it/s]


In [6]:
play_audio(audio_transcription_pairs, 0)

Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition


# Modeling

In [7]:
class GPT2TTS(GPT2PreTrainedModel, GenerationMixin):
    def __init__(self, n_special_tokens=2, **kwargs):
        super().__init__(config=GPT2Config())
        self.h = None

        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2') # 124M
        self.model = GPT2Model.from_pretrained('gpt2')

        self.codec = EncodecModel.from_pretrained("facebook/encodec_24khz")
        self.processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")

        self.loss_fn = nn.CrossEntropyLoss()

        self.n_special_tokens = n_special_tokens

        # we will use the same embedding table for all codebooks (2)
        self.codebook_size = self.codec.config.codebook_size + self.n_special_tokens
        self.audio_emb = nn.Embedding(self.codebook_size, self.model.config.hidden_size)

        self.lm_head = nn.Linear(self.model.config.hidden_size, self.codebook_size, bias=False)

        self.bos_token_id = self.codebook_size - 2
        self.eos_token_id = self.codebook_size - 1

    def can_generate(self):
        return True

    def preprocess_text_and_audio(self, text: List, audio: List) -> tuple[List, List]:

        B = len(text)
        device = audio[0].device if len(audio) > 0 and isinstance(audio[0], torch.Tensor) else 'cpu'

        text_tokens = []
        audio_tokens = []

        # tokenize text with gpt2 tokenizer, tokenize audio with encodec and flatten
        # audio tokens sequences from [2, T] to [2 * T] such that tokens from two codebooks neighboring:
        # (0, 0), (1 ,0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3) etc

        # ↓↓↓ YOUR CODE HERE ↓↓↓
        for i in range(B):
            encoded_text = self.tokenizer(text[i], return_tensors='pt')
            text_tokens.append(encoded_text['input_ids'][0].to(device))

            if isinstance(audio[i], torch.Tensor) and audio[i].numel() > 0:
                audio_sample = audio[i].cpu().numpy()
                if hasattr(audio[i], 'sample_rate') and audio[i].sample_rate != self.processor.sampling_rate:
                    audio_sample = librosa.resample(
                        audio_sample,
                        orig_sr=audio[i].sample_rate,
                        target_sr=self.processor.sampling_rate
                    )

                inputs = self.processor(
                    raw_audio=audio_sample,
                    sampling_rate=self.processor.sampling_rate,
                    return_tensors="pt"
                )
                encoder_outputs = self.codec.encode(
                    inputs["input_values"].to(device),
                    inputs["padding_mask"].to(device) if "padding_mask" in inputs else None
                )

                codes = encoder_outputs.audio_codes[:, :2, :].to(device)  # [B, 2, T]

                # Flatten the codes to get [2*T] as described in the task
                # We want (0,0), (1,0), (0,1), (1,1), etc.
                flattened_codes = torch.zeros(2 * codes.shape[2], dtype=torch.long, device=device)
                flattened_codes[0::2] = codes[0, 0, :]
                flattened_codes[1::2] = codes[0, 1, :]

                audio_tokens.append(flattened_codes)
            else:
                # Empty placeholder for demonstration
                audio_tokens.append(torch.tensor([], device=device, dtype=torch.long))
        # ↑↑↑ YOUR CODE HERE ↑↑↑

        return text_tokens, audio_tokens

    def forward(
        self,
        text: List = None,
        audio: List = None,
        return_loss: bool = False,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        if not return_loss:
            # inference mode
            if kwargs.get("input_ids") is not None:
                kwargs["inputs_embeds"] = self.audio_emb(kwargs["input_ids"])
                kwargs.pop("input_ids")
            outputs = self.model(**kwargs)

            hidden_states = outputs.last_hidden_state
            outputs.logits = self.lm_head(hidden_states).float()

            return outputs

        text_tokens, audio_tokens  = \
            self.preprocess_text_and_audio(text, audio)

        B = len(text_tokens)
        H = self.model.config.hidden_size
        device = text_tokens[0].device

        L = max(len(text_tokens[i]) + len(audio_tokens[i]) + 2 for i in range(B))

        bos_token = self.audio_emb.weight[self.bos_token_id, :].unsqueeze(0)
        eos_token = self.audio_emb.weight[self.eos_token_id, :].unsqueeze(0)

        x = torch.zeros(B, L, H, device=device)

        for i in range(B):
            end_text = len(text_tokens[i])
            end_audio = end_text + 1 + len(audio_tokens[i])

            x[i, 0:end_text, :] = self.model.wte(text_tokens[i])
            x[i, end_text : end_text + 1, :] = bos_token
            x[i, end_text + 1 : end_audio, :] = self.audio_emb(audio_tokens[i])
            x[i, end_audio : end_audio + 1, :] = eos_token

        outputs = self.model(
            inputs_embeds=x,
            return_dict=True,
        )
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states).float()

        loss = None
        labels = torch.full((B, L), -100, device=device, dtype=torch.long)
        for i in range(B):
            prompt_end = len(text_tokens[i])
            audio_end = prompt_end + len(audio_tokens[i])
            labels[i, prompt_end : audio_end] = audio_tokens[i]
            labels[i, audio_end] = self.eos_token_id

        loss = self.loss_fn(
            logits.view(-1, self.codebook_size), labels.view(-1)
        )

        return {"loss": loss, "logits": logits}

    def prepare_prompt(self, text: List[str], **kwargs):
        assert len(text) == 1, "Inference supports batch size 1 only"

        # implement prompt creation from text
        # you need to tokenize text, create text embeddings and add bos token
        # from audio embeddings

        # ↓↓↓ YOUR CODE HERE ↓↓↓
        device = next(self.parameters()).device

        encoded_text = self.tokenizer(text[0], return_tensors='pt')
        text_tokens = encoded_text['input_ids'][0].to(device)

        text_embeddings = self.model.wte(text_tokens)

        bos_token_embedding = self.audio_emb.weight[self.bos_token_id, :].unsqueeze(0)

        inputs_embeds = torch.cat([text_embeddings, bos_token_embedding], dim=0)

        inputs_embeds = inputs_embeds.unsqueeze(0)
        # ↑↑↑ YOUR CODE HERE ↑↑↑

        return inputs_embeds

    def pad_audio_tokens(self, audio_tokens: torch.Tensor, n_q: int, pad_id: int = 0):
        seq_len = audio_tokens.size(-1)
        remainder = seq_len % n_q
        if remainder != 0:
            pad_len = n_q - remainder
            pad = torch.full((1, pad_len), pad_id, dtype=audio_tokens.dtype, device=audio_tokens.device)
            audio_tokens = torch.cat([audio_tokens, pad], dim=-1)
        return audio_tokens


    def decode(self, audio_tokens: torch.Tensor):
        # unflatten audio_tokens back into 2 codebooks
        # if length is odd drop last token
        # decode with codec.decode into the waveform

        # ↓↓↓ YOUR CODE HERE ↓↓↓
        if len(audio_tokens.shape) == 1:
            audio_tokens = audio_tokens.unsqueeze(0)  # add batch dimension if missing

        # Remove the last token if the length is odd
        if audio_tokens.size(-1) % 2 != 0:
            audio_tokens = audio_tokens[..., :-1]

        # Reshape from [B, 2*T] to [B, 2, T]
        codes = audio_tokens.view(audio_tokens.size(0), 2, -1)

        # Add the required dimensions for Encodec (batch, codebooks, time)
        codes = codes.unsqueeze(0)  # [1, B, 2, T]

        # Decode with Encodec
        with torch.no_grad():
            waveforms = self.codec.decode(codes, torch.ones_like(codes))

        # Return the first (and only) batch of waveforms
        waveform = waveforms[0]
        # ↑↑↑ YOUR CODE HERE ↑↑↑

        return waveform

In [8]:
tts = GPT2TTS()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [9]:
# train call
tts.forward(["asdf", "asdf"], [torch.tensor([]), torch.tensor([])], return_loss=True)

{'loss': tensor(6.8703, grad_fn=<NllLossBackward0>),
 'logits': tensor([[[ 0.3005, -0.4037, -1.3916,  ..., -2.9117,  2.9984,  0.5271],
          [ 2.7070, -1.5329, -4.6891,  ..., -8.8858,  7.9476,  1.5672],
          [-0.1389, -0.0707, -0.6594,  ..., -0.9649,  1.3987,  0.2044],
          [-0.1170, -0.1072, -0.5582,  ..., -1.0209,  1.3744,  0.1260]],
 
         [[ 0.3005, -0.4037, -1.3916,  ..., -2.9117,  2.9984,  0.5271],
          [ 2.7070, -1.5329, -4.6891,  ..., -8.8858,  7.9476,  1.5672],
          [-0.1389, -0.0707, -0.6594,  ..., -0.9649,  1.3987,  0.2044],
          [-0.1170, -0.1072, -0.5582,  ..., -1.0209,  1.3744,  0.1260]]],
        grad_fn=<UnsafeViewBackward0>)}

In [10]:
text = "The quick brown fox jumps over the lazy dog."

sampling_params = {
    "temperature": 0.4,
    "repetition_penalty": 1.25,
    "top_p": 0.8,
    "do_sample": True
}

inputs_embeds = tts.prepare_prompt([text])
audio_tokens = tts.generate(
    inputs_embeds=inputs_embeds,
    generation_config=GenerationConfig(
        **sampling_params,
        bos_token_id=tts.bos_token_id,
        eos_token_id=tts.eos_token_id,
        pad_token_id=tts.eos_token_id,
    ),
)
waveform = tts.decode(audio_tokens)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [15]:
audio_tokens

tensor([[911, 988, 792, 197, 137, 842, 776, 615, 395]])

In [None]:
class TTSDataset(Dataset):
    def __init__(self, texts, audio_files, tokenizer, processor, codec, max_length=1024):
        self.texts = texts
        self.audio_files = audio_files
        self.tokenizer = tokenizer
        self.processor = processor
        self.codec = codec
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        text_tokens = self.tokenizer(
            text,
            return_tensors='pt',
            max_length=self.max_length,
            truncation=True,
            padding='max_length'
        )['input_ids'][0]

        audio = self.audio_files[idx]
        if isinstance(audio, str):
            # If audio is a filepath
            import torchaudio
            audio, sr = torchaudio.load(audio)
            if sr != self.processor.sampling_rate:
                audio = torchaudio.functional.resample(audio, sr, self.processor.sampling_rate)
            audio = audio.squeeze(0)

        inputs = self.processor(
            raw_audio=audio,
            sampling_rate=self.processor.sampling_rate,
            return_tensors="pt"
        )

        encoder_outputs = self.codec.encode(
            inputs["input_values"],
            inputs.get("padding_mask", None)
        )

        audio_tokens = encoder_outputs.audio_codes[0, :self.max_length]

        return {
            'text_tokens': text_tokens,
            'audio_tokens': audio_tokens,
            'text': text,
            'audio': audio
        }


In [35]:
class GPT2TTSTrainer(pl.LightningModule):
    def __init__(self, model: GPT2TTS, lr: float = 1e-4):
        super().__init__()
        self.model = model
        self.lr = lr

    def training_step(self, batch, batch_idx):
        text, audio = batch['text'], batch['audio']
        print(f"\n[Batch {batch_idx}]")
        print(f"Text: {text}")
        print(f"Audio: {audio}")
        out = self.model(text=text, audio=audio, return_loss=True)
        self.log("train_loss", out["loss"])
        return out["loss"]

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=self.lr)

In [36]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
codec = EncodecModel.from_pretrained("facebook/encodec_24khz")
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")

In [37]:
from torch.utils.data import DataLoader

def train_model():
    texts = []
    audio_files = []
    for audio, text in audio_transcription_pairs:
        texts.append(text)
        audio_files.append(audio)

    model = GPT2TTS()
    dataset = TTSDataset(texts, audio_files, tokenizer, processor, codec)
    train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

    trainer = pl.Trainer(
        max_epochs=10,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        log_every_n_steps=10
    )

    trainer.fit(GPT2TTSTrainer(model), train_loader)

In [38]:
train_model()

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type    | Params | Mode 
------------------------------------------
0 | model | GPT2TTS | 140 M  | train
------------------------------------------
140 M     Trainable params
0         Non-trainable params
140 M     Total params
563.470   Total estimated model params size (MB)
4         Modules in train mode
459       Modules in eval mode
e:\UCU Third Year\Term 2\Audio Processing\Audio_Processing_Labs\Lab5\venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 0:   0%|          | 0/3275 [00:00<?, ?it/s] 

RuntimeError: stack expects each tensor to be equal size, but got [25] at entry 0 and [7] at entry 1

In [None]:
def generate_audio(model, text, max_length=500, strategy="greedy", temperature=1.0, top_k=50, top_p=0.9):
    model.eval()

    inputs_embeds = model.prepare_prompt([text])

    generation_config = {
        "max_length": max_length,
        "eos_token_id": tts.eos_token_id,
        "pad_token_id": 0,
        "do_sample": strategy != "greedy",
        "temperature": temperature,
        "top_k": top_k if strategy == "top_k" else None,
        "top_p": top_p if strategy == "nucleus" else None,
    }

    with torch.no_grad():
        outputs = model.generate(
            inputs_embeds=inputs_embeds,
            **generation_config
        )

    prompt_len = inputs_embeds.shape[1]
    audio_tokens = outputs[0, prompt_len:]

    eos_pos = (audio_tokens == model.eos_token_id).nonzero()
    if len(eos_pos) > 0:
        audio_tokens = audio_tokens[:eos_pos[0]]

    waveform = model.decode(audio_tokens)
    return waveform

texts = [
    "This is a test sentence for TTS evaluation.",
    "The quick brown fox jumps over the lazy dog.",
    "Hello world, this is a text-to-speech system."
]

# strategies = ["greedy", "top_k", "nucleus"]
strategies = ["greedy", "top_k"]
res = []
for text in texts:
    for strategy in strategies:
        audio = generate_audio(tts, text, strategy=strategy)
        audio_np = audio.detach().cpu().numpy().squeeze()
        res.append((audio_np, text, strategy))
        print(audio_np.shape)
        display(Audio(audio_np, rate=processor.sampling_rate))
        print(f"Generated audio for: '{text}' with {strategy} sampling")

In [None]:
display(Audio(res[0][0][2], rate=processor.sampling_rate))

In [None]:
np.all(res[0][0][1] == res[0][0][90])

np.True_

In [None]:
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer

model = GPT2TTS()
pl_model = GPT2TTSTrainer(model)

train_set = TTSDataset(train_texts, train_audios)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)

trainer = Trainer(accelerator='gpu', devices=1, max_epochs=10)
trainer.fit(pl_model, train_loader)
