The next two cells handle the initial setup and import the dependencies for the SSL-TTS framework:
- Clones the TTS repository from coqui-ai
- Installs required packages: TTS, transformers, torchaudio, mutagen
- Imports core deep learning libraries (torch, torchaudio)
- Imports WavLM model for SSL feature extraction
- Imports GlowTTS components for the text-to-SSL model
- Sets up other essential utilities like torch.nn.functional

In [None]:
!git clone https://github.com/coqui-ai/TTS.git
%cd TTS


!pip install TTS transformers torchaudio
!pip install mutagen
# Do not restart the runtime after running this cell.
# pip may print a few dependency warnings or errors; they can be ignored.

Cloning into 'TTS'...
remote: Enumerating objects: 32844, done.
remote: Counting objects: 100% (3521/3521), done.
remote: Compressing objects: 100% (126/126), done.
remote: Total 32844 (delta 3433), reused 3395 (delta 3395), pack-reused 29323 (from 1)
Receiving objects: 100% (32844/32844), 166.12 MiB | 12.86 MiB/s, done.
Resolving deltas: 100% (23858/23858), done.
/content/TTS
Collecting TTS
  Downloading TTS-0.22.0-cp310-cp310-manylinux1_x86_64.whl.metadata (21 kB)
Collecting anyascii>=0.3.0 (from TTS)
  Downloading anyascii-0.3.2-py3-none-any.whl.metadata (1.5 kB)
Collecting pysbd>=0.3.4 (from TTS)
  Downloading pysbd-0.3.4-py3-none-any.whl.metadata (6.1 kB)
Collecting umap-learn>=0.5.1 (from TTS)
  Downloading umap_learn-0.5.7-py3-none-any.whl.metadata (21 kB)
Collecting pandas<2.0,>=1.4 (from TTS)
  Downloading pandas-1.5.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting trainer>=0.0.32 (from TTS)
  Downloading trainer-0.0.36-py3-n

Collecting mutagen
  Downloading mutagen-1.47.0-py3-none-any.whl.metadata (1.7 kB)
Downloading mutagen-1.47.0-py3-none-any.whl (194 kB)
Installing collected packages: mutagen
Successfully installed mutagen-1.47.0


In [None]:
# Import required modules
import torch
import torchaudio
import torch.optim as optim
from transformers import WavLMModel
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
from torch import nn
from google.colab import drive
from datetime import datetime
import json

# Import the GlowTTS config and model
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.models.glow_tts import GlowTTS



import torch.nn.functional as F
from typing import Tuple, Dict

import pandas as pd

import multiprocessing
multiprocessing.set_start_method("spawn", force=True)



# SSL Encoder Implementation (WavLM Integration)

Implements the Self-Supervised Learning encoder component using WavLM-Large:

### Key Features
1. Model Initialization:
   - Loads WavLM-Large model from HuggingFace
   - Automatically selects GPU if available
   - Sets model to evaluation mode

2. Feature Extraction:
   - Uses WavLM's 6th layer for optimal feature representation
   - Handles automatic resampling to 16kHz
   - Manages proper tensor dimensions and device placement
   - Outputs 1024-dimensional feature vectors

3. Audio Processing:
   - Supports variable-length inputs
   - Expects mono audio: a 2-D input is read as [B, T], so a stereo [2, T] tensor is treated as a batch of two waveforms
   - Adds a batch dimension to 1-D inputs

### Technical Details
- Input: Audio waveform tensor [B, T] or [1, T]
- Output: SSL features [B, T', 1024]
- Uses @torch.no_grad() for efficient inference
- Includes sample rate verification and conversion
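
The relationship between the input length T and the output length T' can be estimated without loading the model. The sketch below assumes the wav2vec 2.0-style convolutional front end (seven unpadded 1-D convolutions) whose kernel sizes and strides WavLM checkpoints report as `conv_kernel` and `conv_stride` in their config; the function name `wavlm_num_frames` is hypothetical.

```python
# Kernel sizes and strides of the convolutional feature extractor.
# Assumption: the wav2vec 2.0-style defaults exposed by WavLM configs
# as config.conv_kernel and config.conv_stride.
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)


def wavlm_num_frames(num_samples: int) -> int:
    """Return T', the number of feature frames for T samples of 16 kHz audio."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNEL, CONV_STRIDE):
        # Unpadded 1-D convolution: floor((L - kernel) / stride) + 1
        length = (length - kernel) // stride + 1
    return length


print(wavlm_num_frames(16000))  # one second of audio -> 49
```

The strides multiply to 320, so each frame advances by 20 ms of 16 kHz audio, which gives roughly 50 feature vectors per second.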




In [None]:
class SSLEncoder:
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        print(f"Loading WavLM model to {device}...")
        self.model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device)
        self.model.eval()
        print("WavLM model loaded successfully!")

    @torch.no_grad()
    def extract_features(self, waveform, sample_rate=16000):
        """Extract WavLM features from the 6th layer"""
        # Resample if sample rate is not 16000 Hz
        if sample_rate != 16000:
            waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)

        # Ensure waveform is properly batched
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # Move waveform to the specified device
        waveform = waveform.to(self.device)
        outputs = self.model(waveform, output_hidden_states=True)

        # Extract features from the 6th layer
        features = outputs.hidden_states[6]
        return features

# Example usage (uncomment to run):
# ssl_encoder = SSLEncoder()
#
# # Load a sample audio file (replace the path with your own file)
# waveform, sample_rate = torchaudio.load('/content/harvard.wav')
#
# # Extract features
# features = ssl_encoder.extract_features(waveform, sample_rate)
# print("Extracted features shape:", features.shape)



# Text-to-SSL Model (GlowTTS Adaptation)

Implements the text-to-SSL conversion using a modified GlowTTS architecture:

### Architecture Overview
1. Configuration Setup:
   - num_chars: 128, covering the ASCII codes used as character IDs
   - out_channels: 1024 to match WavLM features
   - hidden_channels: 192 for encoder/decoder
   - encoder_type: "rel_pos_transformer"

2. Model Components:
   - Transformer-based text encoder
   - Duration predictor
   - Flow-based decoder
   - Speaker-independent design

### Key Features
- Non-autoregressive architecture
- Flow-based feature generation
- Duration prediction for proper alignment
- Batch processing support
- Device-agnostic implementation

In [None]:
# Experimental variant: this definition is overridden by the TextToSSL class in the next cell
class TextToSSL(nn.Module):
    def __init__(self):
        super().__init__()
        # Initialize GlowTTS config with modified output size
        config = GlowTTSConfig(
            num_chars=128,  # Standard English character set size
            out_channels=1024,  # Match WavLM-Large output dimension
        )

        # Initialize GlowTTS model with modified config
        self.glow_tts = GlowTTS(config)

    def forward(self, text, text_lengths, y=None, gen=False):
        """
        Forward pass through the Text-to-SSL model

        Args:
            text (torch.Tensor): Text input tensor [B, T]
            text_lengths (torch.Tensor): Length of each text sequence [B]
            y (torch.Tensor, optional): Target SSL features [B, T, 1024]. Used only during training
            gen (bool): Whether in generation mode or not

        Returns:
            During training (gen=False):
                Dictionary containing:
                - z: Transformed features
                - z_m: Mean of transformed features
                - z_logs: Log standard deviation of transformed features
                - logdet: Log determinant of transformation
                - z_mask: Mask for valid frames
                - logw: Log durations
                - logw_: Target log durations
                - attn: Attention alignments
            During inference (gen=True):
                Dictionary containing:
                - ssl_features: Generated SSL features
                - attn: Attention alignments
        """
        if gen:
            return self.generate(text, text_lengths)

        # Calculate y_lengths if y is provided
        if y is not None:
            y_lengths = torch.tensor([y.size(1)] * y.size(0), dtype=torch.long, device=y.device)
        else:
            y_lengths = None

        # Process through GlowTTS
        outputs = self.glow_tts(
            text,
            text_lengths,
            y,
            y_lengths,
            aux_input={"d_vectors": None, "speaker_ids": None}
        )

        # Return all components needed for MLE loss
        return {
            'z': outputs.get('z', None),
            'z_m': outputs.get('z_m', None),
            'z_logs': outputs.get('z_logs', None),
            'logdet': outputs.get('logdet', None),
            'z_mask': outputs.get('z_mask', None),
            'logw': outputs.get('logw', None),
            'logw_': outputs.get('logw_', None),
            'attn': outputs.get('alignments', None),
            'ssl_features': outputs.get('y_mean', None)  # For compatibility with previous code
        }

    def generate(self, text, text_lengths):
        """
        Generate SSL features from text input during inference
        """
        outputs = self.glow_tts.inference(
            text,
            aux_input={"x_lengths": text_lengths}
        )

        return {
            'ssl_features': outputs['model_outputs'],
            'attn': outputs['alignments']
        }
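
The docstring above lists `z`, `z_m`, `z_logs`, and `logdet` as the ingredients of the maximum-likelihood loss. As a minimal sketch of how such terms are usually combined, assuming a diagonal Gaussian prior with mean `z_m` and log standard deviation `z_logs`, the negative log-likelihood can be written with plain scalars standing in for tensors. The function names here are hypothetical and not part of the TTS library.

```python
import math


def gaussian_nll(z: float, z_m: float, z_logs: float) -> float:
    """Negative log-density of z under N(z_m, exp(z_logs) ** 2) for one element."""
    return (
        0.5 * math.log(2 * math.pi)
        + z_logs
        + 0.5 * math.exp(-2 * z_logs) * (z - z_m) ** 2
    )


def flow_nll(z, z_m, z_logs, logdet: float) -> float:
    """Per-element NLL of one sequence after subtracting the flow log-determinant."""
    total = sum(gaussian_nll(a, m, s) for a, m, s in zip(z, z_m, z_logs))
    return (total - logdet) / len(z)


# A latent that sits exactly on the prior mean with unit variance
# only pays the constant 0.5 * log(2 * pi) per element.
print(flow_nll([0.0, 0.0], [0.0, 0.0], [0.0, 0.0], logdet=0.0))
```

Subtracting `logdet` is the change-of-variables correction: the flow maps features to `z`, and the log-determinant accounts for how that mapping stretches or compresses volume.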

In [None]:
class TextToSSL(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize GlowTTS config with modified output size
        config = GlowTTSConfig(
            num_chars=128,  # Standard English character set size
            out_channels=1024,  # Match WavLM-Large output dimension
        ) # Mostly default configs

        # Initialize GlowTTS model with modified config
        self.glow_tts = GlowTTS(config)

    def forward(self, text, text_lengths, y=None):
        """
        Forward pass through the Text-to-SSL model

        Args:
            text (torch.Tensor): Text input tensor [B, T]
            text_lengths (torch.Tensor): Length of each text sequence [B]
            y (torch.Tensor, optional): Target SSL features [B, T, 1024]. Used only during training.

        Returns:
            Dict containing:
                - ssl_features: Predicted SSL features [B, T, 1024]
                - alignments: Alignment matrix between text and features
                - durations_log: Log durations for each input token
        """
        if y is not None:
            y_lengths = torch.tensor([y.size(1)] * y.size(0), dtype=torch.long).to(y.device)  # Calculate lengths of y
        else:
            y_lengths = None

        # Process through GlowTTS
        outputs = self.glow_tts(
            text,
            text_lengths,
            y,  # Pass the target features
            y_lengths,  # Pass the lengths of the target features
            aux_input={"d_vectors": None, "speaker_ids": None}
        )

        return {
            "ssl_features": outputs["y_mean"],  # [B, T, 1024]
            "alignments": outputs["alignments"],
            "durations_log": outputs["durations_log"]
        }



    def generate(self, text, text_lengths):
        """
        Generate SSL features from text input during inference

        Args:
            text (torch.Tensor): Text input tensor [B, T]
            text_lengths (torch.Tensor): Length of each text sequence [B]

        Returns:
            Dict containing:
                - ssl_features: Generated SSL features [B, T, 1024]
                - alignments: Alignment matrix between text and features
        """
        outputs = self.glow_tts.inference(
            text,
            aux_input={"x_lengths": text_lengths}
        )

        return {
            "ssl_features": outputs["model_outputs"],  # [B, T, 1024]
            "alignments": outputs["alignments"]
        }


## GlowTTS Training

In [None]:
# Download the LJSpeech dataset
!wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2

# Extract the dataset
!tar -xjf LJSpeech-1.1.tar.bz2

# Verify the extraction by listing the contents
!ls LJSpeech-1.1


--2024-11-13 22:39:27--  https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
Resolving data.keithito.com (data.keithito.com)... 24.199.73.137
Connecting to data.keithito.com (data.keithito.com)|24.199.73.137|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2748572632 (2.6G) [text/plain]
Saving to: ‘LJSpeech-1.1.tar.bz2’


2024-11-13 22:39:50 (116 MB/s) - ‘LJSpeech-1.1.tar.bz2’ saved [2748572632/2748572632]

metadata.csv  README  wavs


### Data Prep / Pre-process

This code sets up the data processing pipeline for training SSL-TTS. It performs three key operations:

1. **Audio Loading and Resampling**
   - Loads audio files from LJSpeech dataset
   - Resamples them from 22.05kHz to 16kHz (required by WavLM)

2. **Feature Extraction**
   - Uses WavLM to convert raw audio into high-level speech features
   - Instead of using mel-spectrograms, we get 1024-dimensional WavLM features
   - These features contain rich information about speech content and speaker characteristics

3. **Batch Processing**
   - Handles variable-length audio files by padding them to the same length
   - Creates batches of features and their corresponding text transcriptions
   - Makes the data ready for training the GlowTTS model

This pipeline transforms raw audio into the format needed for training our SSL-TTS system, where GlowTTS will learn to predict WavLM features from text.
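
The padding step in item 3 can be illustrated with plain Python lists standing in for `[T', 1024]` feature tensors. This is a sketch, not the notebook's collate function: the helper name `pad_with_lengths` is hypothetical, and returning the unpadded lengths is an added assumption about what a downstream model might need in order to ignore padded frames.

```python
from typing import List, Sequence, Tuple


def pad_with_lengths(
    sequences: Sequence[Sequence[float]], pad_value: float = 0.0
) -> Tuple[List[List[float]], List[int]]:
    """Right-pad sequences to a common length and return their true lengths."""
    lengths = [len(seq) for seq in sequences]
    max_length = max(lengths)
    padded = [
        list(seq) + [pad_value] * (max_length - len(seq)) for seq in sequences
    ]
    return padded, lengths


batch, lengths = pad_with_lengths([[0.1, 0.2, 0.3], [0.4]])
print(batch)    # [[0.1, 0.2, 0.3], [0.4, 0.0, 0.0]]
print(lengths)  # [3, 1]
```

Keeping the true lengths next to the padded batch makes it possible to build a mask later, so that padded frames do not contribute to a loss.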

In [None]:
# Custom Dataset class for LJSpeech with resampling and feature extraction
class LJSpeechDataset(Dataset):
    def __init__(self, root_dir, metadata_file, ssl_encoder, transform=None):
        self.root_dir = root_dir
        self.metadata = pd.read_csv(metadata_file, sep="|", header=None, names=["file", "text", "normalized_text"])
        self.resampler = torchaudio.transforms.Resample(orig_freq=22050, new_freq=16000)
        self.ssl_encoder = ssl_encoder # WavLM
        self.transform = transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Get the audio file path
        wav_file = os.path.join(self.root_dir, "wavs", self.metadata.iloc[idx, 0] + ".wav")

        # Load the audio file
        waveform, sample_rate = torchaudio.load(wav_file)

        # Resample the audio to 16 kHz if needed (LJSpeech is recorded at 22.05 kHz)
        if sample_rate != 16000:
            waveform = self.resampler(waveform)

        # Apply any waveform transform before feature extraction so it affects the features
        if self.transform:
            waveform = self.transform(waveform)

        # Extract WavLM features
        features = self.ssl_encoder.extract_features(waveform)

        # Get the corresponding text
        text = self.metadata.iloc[idx, 1]

        return features, text  # Return WavLM features and the text

# Custom collate function to handle variable-length feature sequences
def collate_fn(batch):
    # Separate the components of the batch
    features, texts = zip(*batch)

    # Find the maximum length in the current batch
    max_length = max(feature.size(1) for feature in features)

    # Pad the features to the maximum length and remove the extra channel dimension
    padded_features = [torch.nn.functional.pad(feature.squeeze(0), (0, 0, 0, max_length - feature.size(1))) for feature in features]

    # Stack the features into a single tensor
    features_tensor = torch.stack(padded_features)

    return features_tensor, texts

# Initialize the SSLEncoder
ssl_encoder = SSLEncoder()

# Directory and file paths
root_dir = "LJSpeech-1.1"
metadata_file = os.path.join(root_dir, "metadata.csv")

# Initialize the dataset and data loader
dataset = LJSpeechDataset(root_dir, metadata_file, ssl_encoder)
data_loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, collate_fn=collate_fn)

# Example: Iterate through the data loader and print a sample
for features, text in data_loader:
    print("Features shape:", features.shape)
    print("Text:", text)
    break

Loading WavLM model to cuda...


WavLM model loaded successfully!



Features shape: torch.Size([4, 414, 1024])
Text: ('The Bureau had no earlier information suggesting that Oswald had left the United States.', 'according to the discretion of the court before whom the prisoners might be tried.', 'Agent Fain retired from the FBI in October 1962, and the closed Oswald case was not reassigned.', 'could not be carried out till then. It is to be feared that long after the opening of White Cross Street prison,')


In [None]:
print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3)} GB")

### Training Cell 1
Saves the final weights to the runtime's local disk only; no intermediate checkpoints are written.

In [None]:
# Initialize the TextToSSL model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TextToSSL().to(device=device)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop with a fixed number of steps
num_steps = 650000  # Total number of steps
current_step = 0

for epoch in range(1000000):  # Large number to ensure we cover all steps
    model.train()  # Set the model to training mode
    running_loss = 0.0  # Reset running loss at the start of each epoch

    for features, text in data_loader:
        if current_step >= num_steps:
            break

        # Move features to device
        features = features.to(device=device)

        # Convert text to tensor (pad sequences to the same length)

        max_length = max(len(t) for t in text)
        text_tensor = torch.tensor(
            [[ord(c) for c in t.ljust(max_length)] for t in text],  # Pad with spaces (or any padding character)
            dtype=torch.long
        ).to(device=device)

        max_index = 128 - 1
        text_tensor = torch.clamp(text_tensor, max=max_index)

        # Calculate text lengths
        text_lengths = torch.tensor([len(t) for t in text], dtype=torch.long).to(device=device)



        # Forward pass
        optimizer.zero_grad()
        outputs = model(text_tensor, text_lengths, features)  # Pass features as the target for training

        # Adjust for sequence length mismatch
        predicted_features = outputs["ssl_features"]
        target_length = features.size(1)
        predicted_length = predicted_features.size(1)

        if predicted_length > target_length:
            # Truncate the predicted features
            predicted_features = predicted_features[:, :target_length, :]
        elif predicted_length < target_length:
            # Truncate the target features
            features = features[:, :predicted_length, :]

        # Compute the loss between the predicted and actual WavLM features
        loss = criterion(predicted_features, features)
        loss.backward()  # Backpropagation
        optimizer.step()  # Update the model parameters

        running_loss += loss.item()  # Accumulate the loss
        current_step += 1  # Increment the step count

        if current_step % 100 == 0:
            print(f"Step {current_step}/{num_steps}, Loss: {loss.item()}")

    # Calculate the average loss for the epoch
    average_loss = running_loss / len(data_loader)
    print(f"Epoch {epoch+1}, Average Loss: {average_loss}")

    if current_step >= num_steps:
        break

print("Training completed!")
torch.save(model.state_dict(), 'GlowTTS_state.pth')

Step 100/650000, Loss: 10.282491683959961
Step 200/650000, Loss: 9.229987144470215
Step 300/650000, Loss: 10.324702262878418
Step 400/650000, Loss: 10.808058738708496
Step 500/650000, Loss: 11.408522605895996
Step 600/650000, Loss: 11.445104598999023
Step 700/650000, Loss: 8.919713973999023
Step 800/650000, Loss: 11.184569358825684
Step 900/650000, Loss: 8.18460750579834
Step 1000/650000, Loss: 10.558945655822754
Step 1100/650000, Loss: 11.150132179260254
Step 1200/650000, Loss: 9.41183853149414
Step 1300/650000, Loss: 9.975044250488281
Step 1400/650000, Loss: 11.638545989990234
Step 1500/650000, Loss: 12.243314743041992
Step 1600/650000, Loss: 10.34382438659668
Step 1700/650000, Loss: 11.027650833129883
Step 1800/650000, Loss: 10.62975025177002
Step 1900/650000, Loss: 9.639217376708984
Step 2000/650000, Loss: 12.009521484375
Step 2100/650000, Loss: 11.1110200881958
Step 2200/650000, Loss: 8.596921920776367
Step 2300/650000, Loss: 11.377479553222656
Step 2400/650000, Loss: 9.7774791717
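
The loop above turns each transcript into integer IDs by taking `ord` of every character, padding with spaces, and clamping to `num_chars - 1`. The same logic is sketched below in plain Python so its side effects are easy to see; `encode_batch` is a hypothetical helper name, and `NUM_CHARS` mirrors the `num_chars=128` setting from the config.

```python
from typing import List, Tuple

NUM_CHARS = 128  # mirrors num_chars in the GlowTTSConfig used by TextToSSL


def encode_batch(
    texts: List[str], pad_char: str = " "
) -> Tuple[List[List[int]], List[int]]:
    """Encode strings as clamped code points, right-padded to the longest string."""
    lengths = [len(t) for t in texts]
    max_length = max(lengths)
    ids = [
        [min(ord(c), NUM_CHARS - 1) for c in t.ljust(max_length, pad_char)]
        for t in texts
    ]
    return ids, lengths


ids, lengths = encode_batch(["Hi", "\u00e9!?"])
print(ids)      # [[72, 105, 32], [127, 33, 63]]
print(lengths)  # [2, 3]
```

Two consequences follow from this scheme: every non-ASCII character collapses to ID 127, and padding shares ID 32 with real spaces, so the true lengths are the only way to tell them apart.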

### Training Cell 2

Saves periodic checkpoints to Google Drive. It can also resume training: it grabs a checkpoint file from Google Drive and continues from that step (see the driver cell below).
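
Resuming needs the path of a previously saved checkpoint. One way to pick the most recent one is sketched below, assuming the files follow the `GlowTTS_checkpoint_step_{step}_{timestamp}.pth` naming used by `CheckpointManager` in this notebook. `latest_checkpoint` is a hypothetical helper and the example filenames are made up; in practice the names would come from something like `os.listdir(save_dir)`.

```python
import re
from typing import Iterable, Optional

# Filename pattern written by CheckpointManager.save_checkpoint
CHECKPOINT_RE = re.compile(r"GlowTTS_checkpoint_step_(\d+)_\d{8}_\d{6}\.pth$")


def latest_checkpoint(filenames: Iterable[str]) -> Optional[str]:
    """Return the checkpoint filename with the highest step, or None if none match."""
    best_step, best_name = -1, None
    for name in filenames:
        match = CHECKPOINT_RE.search(name)
        if match and int(match.group(1)) > best_step:
            best_step, best_name = int(match.group(1)), name
    return best_name


# Illustrative filenames only
names = [
    "GlowTTS_checkpoint_step_10000_20241113_224500.pth",
    "GlowTTS_checkpoint_step_20000_20241114_013000.pth",
    "training_history.json",
]
print(latest_checkpoint(names))  # GlowTTS_checkpoint_step_20000_20241114_013000.pth
```

Comparing the parsed step as an integer avoids the pitfall of sorting filenames as strings, where step 9000 would sort after step 10000.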

In [None]:
class CheckpointManager:
    def __init__(self, model, optimizer, save_dir, save_interval=10000):
        """
        Initialize checkpoint manager
        Args:
            model: The model to save
            optimizer: The optimizer to save
            save_dir: Directory path in Google Drive to save checkpoints
            save_interval: Save checkpoint every N steps
        """
        self.model = model
        self.optimizer = optimizer
        self.save_dir = save_dir
        self.save_interval = save_interval

        # Create save directory if it doesn't exist
        os.makedirs(self.save_dir, exist_ok=True)

        # Initialize training history
        self.history = {
            'steps': [],
            'losses': [],
            'timestamp': []
        }

    def save_checkpoint(self, step, loss):
        """Save checkpoint and update history"""
        if step % self.save_interval == 0:
            # Create checkpoint filename with timestamp
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f'GlowTTS_checkpoint_step_{step}_{timestamp}.pth'
            filepath = os.path.join(self.save_dir, filename)

            # Save checkpoint with all necessary information
            checkpoint = {
                'step': step,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': loss
            }

            try:
                # Save checkpoint
                torch.save(checkpoint, filepath)
                print(f"\nCheckpoint saved successfully at step {step}: {filepath}")

                # Update history
                self.history['steps'].append(step)
                self.history['losses'].append(loss)
                self.history['timestamp'].append(timestamp)

                # Save history to JSON
                history_file = os.path.join(self.save_dir, 'training_history.json')
                with open(history_file, 'w') as f:
                    json.dump(self.history, f, indent=4)

            except Exception as e:
                print(f"\nError saving checkpoint at step {step}: {str(e)}")

def train_model_with_checkpoints(model, data_loader, save_dir, num_steps=650000, save_interval=10000,
                               device=None, resume_from=None):

    """
    Training function with optimizer settings matching the Glow-TTS paper
    """
    # Mount Google Drive first if save_dir points there:
    # drive.mount('/content/drive')
    # Setup device
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device=device)

    # Initialize the RAdam optimizer
    from torch.optim import RAdam
    optimizer = RAdam(
        model.parameters(),
        lr=1e-3,
        betas=(0.9, 0.998),
        eps=1e-9,
        weight_decay=1e-6
    )

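    # Noam learning rate scheduler: linear warmup followed by inverse-sqrt decay.
    # Note: as written, the rate peaks at init_lr * warmup_steps ** -0.5
    # (about 1.6e-5 for init_lr=1e-3), not at init_lr itself.
    def noam_learning_rate_decay(init_lr, global_step, warmup_steps=4000):
        step = global_step + 1.
        lr = init_lr * min(step ** -0.5, step * warmup_steps ** -1.5)
        return lr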

    # Initialize tracking variables
    current_step = 0
    best_loss = float('inf')
    running_loss = 0.0
    loss_window = []

    # Resume from checkpoint if specified
    if resume_from is not None:
        print(f"Loading checkpoint from {resume_from}")
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        current_step = checkpoint['step']
        if 'best_loss' in checkpoint:
            best_loss = checkpoint['best_loss']
        print(f"Resumed from step {current_step}")

    # Initialize checkpoint manager
    checkpoint_manager = CheckpointManager(
        model=model,
        optimizer=optimizer,
        save_dir=save_dir,
        save_interval=save_interval
    )

    # Training loop
    model.train()
    for epoch in range(1000000):
        for features, text in data_loader:
            if current_step >= num_steps:
                break

            # Adjust learning rate using Noam decay
            lr = noam_learning_rate_decay(1e-3, current_step)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # Process inputs
            features = features.to(device=device)

            # Convert text to tensor
            max_length = max(len(t) for t in text)
            text_tensor = torch.tensor(
                [[ord(c) for c in t.ljust(max_length)] for t in text],
                dtype=torch.long
            ).to(device=device)

            max_index = 128 - 1
            text_tensor = torch.clamp(text_tensor, max=max_index)

            text_lengths = torch.tensor([len(t) for t in text], dtype=torch.long).to(device=device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(text_tensor, text_lengths, features)

            # Handle sequence length mismatch
            predicted_features = outputs["ssl_features"]
            target_length = features.size(1)
            predicted_length = predicted_features.size(1)

            if predicted_length > target_length:
                predicted_features = predicted_features[:, :target_length, :]
            elif predicted_length < target_length:
                features = features[:, :predicted_length, :]

            # Compute loss - try to use negative log-likelihood if available
            if "log_likelihood" in outputs:
                loss = -outputs["log_likelihood"].mean()
            else:
                # Scale MSE loss to be in similar range as NLL would be
                mse_loss = torch.nn.functional.mse_loss(predicted_features, features)
                loss = mse_loss * 1000  # Scale factor might need tuning

            if torch.isnan(loss):
                print("Warning: NaN loss detected")
                continue

            # Backward pass
            loss.backward()

            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

            # Skip step if gradient norm is too high
            if grad_norm > 1000:
                print(f"Warning: High gradient norm {grad_norm}, skipping step")
                continue

            optimizer.step()

            # Update tracking
            running_loss += loss.item()
            loss_window.append(loss.item())
            if len(loss_window) > 100:
                loss_window.pop(0)

            current_step += 1

            # Logging
            if current_step % 250 == 0:
                moving_avg_loss = sum(loss_window) / len(loss_window)
                print(f"Step {current_step}/{num_steps}, "
                      f"Loss: {loss.item():.4f}, "
                      f"Moving Avg Loss: {moving_avg_loss:.4f}, "
                      f"Gradient Norm: {grad_norm:.2f}, "
                      f"LR: {lr:.6f}")

                checkpoint_manager.save_checkpoint(current_step, loss.item())

                # Save best model
                if moving_avg_loss < best_loss:
                    best_loss = moving_avg_loss
                    best_model_path = os.path.join(save_dir, 'best_model.pth')
                    torch.save({
                        'step': current_step,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': best_loss
                    }, best_model_path)

        # Epoch metrics
        average_loss = running_loss / len(data_loader)
        print(f"Epoch {epoch+1}, Average Loss: {average_loss:.6f}")
        running_loss = 0.0

        if current_step >= num_steps:
            break

    print("Training completed!")
    final_path = os.path.join(save_dir, 'GlowTTS_final_model.pth')
    torch.save({
        'step': current_step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': average_loss
    }, final_path)
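
The training function above adjusts the learning rate with a Noam-style schedule. For comparison, the sketch below shows a peak-normalized variant in which the rate climbs linearly to `peak_lr` at `warmup_steps` and then decays with the inverse square root of the step. The extra `warmup_steps ** 0.5` factor and the name `noam_lr` are assumptions made for illustration; they differ from `noam_learning_rate_decay` above.

```python
def noam_lr(peak_lr: float, step: int, warmup_steps: int = 4000) -> float:
    """Peak-normalized Noam schedule: linear warmup, then inverse-sqrt decay."""
    step = max(step, 1)  # avoid raising 0 to a negative power
    scale = warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)
    return peak_lr * scale


# Print the rate during warmup, at the peak, and during decay
for s in (1000, 4000, 16000):
    print(s, f"{noam_lr(1e-3, s):.6f}")
```

With this scaling, the rate equals `peak_lr` exactly at `step == warmup_steps` and halves each time the step count quadruples afterwards.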


In [None]:
# CheckpointManager is identical to the definition in the previous cell and is reused here.

def train_model_with_checkpoints(model, data_loader, save_dir, num_steps=650000, save_interval=10000,
                               device=None, resume_from=None):
    """Training function with safe attention handling and gradient scaling"""
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device=device)

    # Initialize optimizer with RAdam
    from torch.optim import RAdam
    optimizer = RAdam(
        model.parameters(),
        lr=0.001,
        betas=(0.9, 0.998),
        eps=1e-9,
        weight_decay=1e-6
    )

    # Learning rate scheduler
    from torch.optim.lr_scheduler import StepLR
    scheduler = StepLR(optimizer, step_size=50000, gamma=0.5)

    def compute_loss(outputs, features, text_lengths):
        """Compute loss with safe handling of all components"""
        predicted_features = outputs["ssl_features"]
        z = outputs.get('z')
        z_m = outputs.get('z_m')
        z_logs = outputs.get('z_logs')
        logdet = outputs.get('logdet')
        z_mask = outputs.get('z_mask')

        # Basic reconstruction loss
        target_length = features.size(1)
        predicted_length = predicted_features.size(1)

        if predicted_length > target_length:
            predicted_features = predicted_features[:, :target_length, :]
        elif predicted_length < target_length:
            features = features[:, :predicted_length, :]

        # Compute reconstruction loss (plain MSE over the length-aligned frames; no padding mask is applied)
        recon_loss = F.mse_loss(predicted_features, features)

        # Initialize total loss with reconstruction loss
        total_loss = recon_loss * 10.0  # Scale factor for stability

        # Add flow loss components if available
        if all(x is not None for x in [z, z_m, z_logs]):
            # Normalize z_logs for stability
            z_logs = torch.clamp(z_logs, min=-30, max=30)

            # Compute flow loss
            if z_mask is not None:
                # Masked flow loss
                flow_loss = torch.sum(z_logs + 0.5 * torch.exp(-2 * z_logs) * ((z - z_m) ** 2))
                flow_loss = flow_loss / torch.sum(z_mask)  # Normalize by mask sum
            else:
                # Unmasked flow loss
                flow_loss = torch.mean(z_logs + 0.5 * torch.exp(-2 * z_logs) * ((z - z_m) ** 2))

            # Add logdet if available
            if logdet is not None:
                flow_loss = flow_loss - torch.mean(logdet)

            total_loss = total_loss + flow_loss

        return total_loss

    # Initialize tracking variables
    current_step = 0
    best_loss = float('inf')
    running_loss = 0.0
    loss_window = []
    average_loss = float('nan')  # Defined up front so the final save works even if no step succeeds

    # Resume from checkpoint if specified
    if resume_from is not None:
        print(f"Loading checkpoint from {resume_from}")
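        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Only some checkpoints store the scheduler state, so restore it when present
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        current_step = checkpoint['step']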
        if 'best_loss' in checkpoint:
            best_loss = checkpoint['best_loss']
        print(f"Resumed from step {current_step}")

    # Initialize checkpoint manager
    checkpoint_manager = CheckpointManager(
        model=model,
        optimizer=optimizer,
        save_dir=save_dir,
        save_interval=save_interval
    )

    # Training loop
    model.train()
    for epoch in range(1000000):
        for features, text in data_loader:
            if current_step >= num_steps:
                break

            try:
                # Process inputs
                features = features.to(device=device)

                # Convert text to tensor
                max_length = max(len(t) for t in text)
                text_tensor = torch.tensor(
                    [[ord(c) for c in t.ljust(max_length)] for t in text],
                    dtype=torch.long
                ).to(device=device)

                max_index = 128 - 1
                text_tensor = torch.clamp(text_tensor, max=max_index)
                text_lengths = torch.tensor([len(t) for t in text], dtype=torch.long).to(device=device)

                # Forward pass
                optimizer.zero_grad()

                # Print debug info at first step
                if current_step == 0:
                    print("Features shape:", features.shape)
                    print("Text tensor shape:", text_tensor.shape)
                    print("Text lengths:", text_lengths)

                outputs = model(text_tensor, text_lengths, features, gen=False)

                # Print output keys for debugging (only first step)
                if current_step == 0:
                    print("Available output keys:", outputs.keys())
                    for k, v in outputs.items():
                        if v is not None:
                            print(f"{k} shape: {v.shape if hasattr(v, 'shape') else None}")

                # Compute loss
                loss = compute_loss(outputs, features, text_lengths)

                if torch.isnan(loss):
                    print("Warning: NaN loss detected, skipping batch")
                    continue

                # Scale loss for better gradient flow
                loss = loss / 100.0  # Scale down for stability

                # Backward pass (the loss was already scaled down above)
                loss.backward()

                # Gradient clipping
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Skip step if gradient norm is too high
                if grad_norm > 100:
                    print(f"Warning: High gradient norm {grad_norm}, skipping step")
                    continue

                optimizer.step()
                scheduler.step()

                # Update tracking
                running_loss += loss.item()
                loss_window.append(loss.item())
                if len(loss_window) > 100:
                    loss_window.pop(0)

                current_step += 1

                # Logging
                if current_step % 250 == 0:
                    moving_avg_loss = sum(loss_window) / len(loss_window)
                    current_lr = optimizer.param_groups[0]['lr']

                    print(f"Step {current_step}/{num_steps}, "
                          f"Loss: {loss.item():.4f}, "
                          f"Moving Avg Loss: {moving_avg_loss:.4f}, "
                          f"LR: {current_lr:.6f}, "
                          f"Grad Norm: {grad_norm:.2f}")

                    checkpoint_manager.save_checkpoint(current_step, loss.item())

                    # Save best model
                    if moving_avg_loss < best_loss:
                        best_loss = moving_avg_loss
                        best_model_path = os.path.join(save_dir, 'best_model.pth')
                        torch.save({
                            'step': current_step,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'scheduler_state_dict': scheduler.state_dict(),
                            'loss': best_loss
                        }, best_model_path)

            except Exception as e:
                print(f"Error in training step: {str(e)}")
                continue  # Skip this batch and continue with the next one

        # Epoch metrics
        if running_loss > 0:  # Only if we had successful steps
            average_loss = running_loss / len(data_loader)
            print(f"Epoch {epoch+1}, Average Loss: {average_loss:.6f}")
        running_loss = 0.0

        if current_step >= num_steps:
            break

    print("Training completed!")
    final_path = os.path.join(save_dir, 'GlowTTS_final_model.pth')
    torch.save({
        'step': current_step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': average_loss
    }, final_path)

In [None]:
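# Mount Google Drive first so that checkpoints written under /content/drive persist
from google.colab import drive
drive.mount('/content/drive')

# Initialize model and checkpoint directory (data_loader is defined in an earlier cell)
model = TextToSSL()
save_dir = '/content/drive/MyDrive/SSL_TTS_Checkpoints'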

# Train from scratch
train_model_with_checkpoints(model, data_loader, save_dir)

# Or resume from a checkpoint
# train_model_with_checkpoints(
#     model,
#     data_loader,
#     save_dir,
#     resume_from='/content/drive/MyDrive/SSL_TTS_Checkpoints/checkpoint_name.pth'
# )

Features shape: torch.Size([4, 351, 1024])
Text tensor shape: torch.Size([4, 111])
Text lengths: tensor([ 84, 100, 111, 100], device='cuda:0')
Available output keys: dict_keys(['z', 'z_m', 'z_logs', 'logdet', 'z_mask', 'logw', 'logw_', 'attn', 'ssl_features'])
z shape: torch.Size([4, 350, 1024])
logdet shape: torch.Size([4])
attn shape: torch.Size([4, 350, 111])
ssl_features shape: torch.Size([4, 350, 1024])
Step 250/650000, Loss: 1.3957, Moving Avg Loss: 1.2058, LR: 0.001000, Grad Norm: 0.07
Step 500/650000, Loss: 1.0130, Moving Avg Loss: 1.0456, LR: 0.001000, Grad Norm: 0.03
Step 750/650000, Loss: 1.1419, Moving Avg Loss: 1.0292, LR: 0.001000, Grad Norm: 0.04
Step 1000/650000, Loss: 1.0044, Moving Avg Loss: 1.0425, LR: 0.001000, Grad Norm: 0.02
Step 1250/650000, Loss: 1.0063, Moving Avg Loss: 1.0483, LR: 0.001000, Grad Norm: 0.04
Step 1500/650000, Loss: 1.0249, Moving Avg Loss: 1.0325, LR: 0.001000, Grad Norm: 0.01
Step 1750/650000, Loss: 0.8973, Moving Avg Loss: 1.0491, LR: 0.001000

KeyboardInterrupt: 
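
The log above shows the learning rate fixed at 0.001 because `StepLR(step_size=50000, gamma=0.5)` only halves it every 50,000 scheduler steps. As a rough sanity check, the closed-form value can be sketched in plain Python. `lr_at_step` is a hypothetical helper, not part of the training code, and it assumes `scheduler.step()` runs exactly once per counted training step.

```python
def lr_at_step(step: int, base_lr: float = 1e-3,
               step_size: int = 50_000, gamma: float = 0.5) -> float:
    """Closed form of a step-decay schedule: multiply by gamma every step_size steps."""
    return base_lr * gamma ** (step // step_size)


for step in (1_750, 50_000, 325_000, 650_000):
    print(f"step {step:>7}: lr = {lr_at_step(step):.3e}")
```

Under these assumptions the rate has been halved 13 times by step 650,000, so the late part of a full run makes very small updates; it may be worth checking whether that is intended before committing to the full schedule.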

# k-NN Retrieval System

Implements the k-Nearest Neighbors retrieval mechanism for voice conversion:

### Technical Features
1. Efficient Batch Processing:
   - Handles multiple sequences simultaneously
   - Optimized matrix operations
   - Memory-efficient implementation

2. Distance Calculation (more below):
   - Cosine similarity metric
   - Numerical stability handling
   - Batch matrix multiplication

3. Feature Averaging:
   - Uniform weighting of k-nearest neighbors
   - Proper dimension handling
   - Gradient-free operations

### Parameters
- k: Number of neighbors (default: 4)
- device: Computation device
- input dimensions: [B, T, D] for both source and target
- output dimensions: [B, T, D] for selected features
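
To make the retrieval step concrete, here is a minimal pure-Python sketch for a single source frame: rank the target frames by cosine distance, keep the `k` closest, and average them with uniform weights. The helper names (`cosine_distance`, `knn_average`) and the toy 2-D vectors are illustrative assumptions; the notebook's implementation applies the same idea to batched `[B, T, D]` tensors.

```python
import math
from typing import List, Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float], eps: float = 1e-8) -> float:
    """1 - cosine similarity, with eps guarding against zero-norm vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b + eps)


def knn_average(query: Sequence[float], database: List[Sequence[float]], k: int = 4) -> List[float]:
    """Average the k database frames closest to the query (uniform weights)."""
    ranked = sorted(database, key=lambda frame: cosine_distance(query, frame))
    nearest = ranked[:k]
    dim = len(query)
    return [sum(frame[d] for frame in nearest) / len(nearest) for d in range(dim)]


# Toy example: one source frame and five target frames in 2-D.
source_frame = [1.0, 0.0]
target_frames = [[2.0, 0.0], [1.0, 1.0], [0.0, 1.0], [-1.0, 0.0], [3.0, 0.5]]
print(knn_average(source_frame, target_frames, k=2))
```

In this toy case the two closest frames are `[2.0, 0.0]` and `[3.0, 0.5]`, so the result is `[2.5, 0.25]`. Note that the output's magnitude comes from the target frames, not from the query, which is the point of replacing source features with target-speaker features.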

### Cosine Similarity
For two feature vectors $\mathbf{a}$ and $\mathbf{b}$ in a high-dimensional space (in our case, $\mathbb{R}^{1024}$), the cosine similarity is defined as:

$
\cos(\mathbf{a}, \mathbf{b}) = \frac{\mathbf{a} \cdot \mathbf{b}}{\|\mathbf{a}\| \|\mathbf{b}\|}
$

Where:
- $\mathbf{a} \cdot \mathbf{b}$ is the dot product
- $\|\mathbf{a}\|$ and $\|\mathbf{b}\|$ are the L2 norms (Euclidean norms)

For batched computation with source features $\mathbf{S} \in \mathbb{R}^{B \times T_s \times D}$ and target features $\mathbf{T} \in \mathbb{R}^{B \times T_t \times D}$, we compute:

$
\text{Similarity}_{batch} = \frac{\mathbf{S}\mathbf{T}^T}{\|\mathbf{S}\|_2 \|\mathbf{T}\|_2^T}
$

### Cosine Distance
The cosine distance is derived from the cosine similarity:

$
d_{cos}(\mathbf{a}, \mathbf{b}) = 1 - \cos(\mathbf{a}, \mathbf{b})
$

In our implementation, we compute this in steps:

1. **Dot Product**:
   $\text{dot}_{batch} = \mathbf{S}\mathbf{T}^T \in \mathbb{R}^{B \times T_s \times T_t}$

2. **L2 Norms**:
   $\|\mathbf{S}\|_2 \in \mathbb{R}^{B \times T_s \times 1}$ and $\|\mathbf{T}\|_2 \in \mathbb{R}^{B \times T_t \times 1}$

3. **Norm Product**:
   $\text{norm\_prod} = \|\mathbf{S}\|_2\|\mathbf{T}\|_2^T \in \mathbb{R}^{B \times T_s \times T_t}$

4. **Final Distance**:
   $d_{cos} = 1 - \frac{\text{dot}_{batch}}{\text{norm\_prod} + \epsilon}$

where $\epsilon = 10^{-8}$ is added to the denominator for numerical stability.

This distance metric has several advantageous properties for our SSL-TTS framework:

1. **Bounded Range**: $d_{cos} \in [0, 2]$ where:
   - 0 indicates identical direction
   - 1 indicates orthogonal vectors
   - 2 indicates opposite directions

2. **Scale Invariance**: The distance is invariant to the magnitude of the vectors, making it suitable for comparing SSL features that may have different magnitudes but similar patterns.

3. **Batch Efficiency**: The formulation allows efficient computation across batches using matrix operations, crucial for processing multiple time steps simultaneously.
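
The reference points of the range and the scale-invariance property can be checked numerically. The snippet below is a small pure-Python sketch (the function name and test vectors are illustrative, and it uses the same $\epsilon$ as above), not the batched implementation.

```python
import math


def cosine_distance(a, b, eps=1e-8):
    """Cosine distance 1 - (a.b) / (||a|| ||b|| + eps) for two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_prod = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_prod + eps)


a = [1.0, 2.0, 3.0]
print(round(cosine_distance(a, a), 6))                      # same direction
print(round(cosine_distance([1.0, 0.0], [0.0, 1.0]), 6))    # orthogonal
print(round(cosine_distance(a, [-x for x in a]), 6))        # opposite direction
print(round(cosine_distance(a, [10.0 * x for x in a]), 6))  # rescaled copy of a
```

The four values come out at approximately 0, 1, 2, and 0. The small deviations from the exact endpoints are due to $\epsilon$ in the denominator.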


In [None]:
class KNNRetrieval:
    def __init__(self, k: int = 4, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Initialize KNN retrieval with specified k neighbors.
        Args:
            k (int): Number of nearest neighbors to retrieve (default: 4 as per paper)
            device (str): Device to perform computations on
        """
        self.k = k
        self.device = device

    def _compute_cosine_distance(self, source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute cosine distance between source and target features.

        Input shapes:
        source: [B, T_s, D] where D is feature dimension (e.g., 1024 for WavLM)
        target: [B, T_t, D]

        Returns:
            torch.Tensor: Cosine distance matrix [B, T_s, T_t]
        """
        # 1. Compute dot product: (A·B)
        dot_product = torch.bmm(source, target.transpose(1, 2))  # [B, T_s, T_t]

        # 2. Compute L2 norms: ||A|| and ||B||
        source_norm = torch.norm(source, p=2, dim=2, keepdim=True)  # [B, T_s, 1]
        target_norm = torch.norm(target, p=2, dim=2, keepdim=True)  # [B, T_t, 1]

        # 3. Compute product of norms: ||A||·||B||
        norm_product = torch.bmm(source_norm, target_norm.transpose(1, 2))  # [B, T_s, T_t]

        # 4. Compute cosine similarity: (A·B)/(||A||·||B||)
        cosine_similarity = dot_product / (norm_product + 1e-8)  # [B, T_s, T_t]

        # 5. Convert to distance: 1 - similarity
        cosine_distance = 1 - cosine_similarity  # [B, T_s, T_t]

        return cosine_distance

    def retrieve(self, source_features: torch.Tensor, target_database: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieve k nearest neighbor features from target database.
        Args:
            source_features (torch.Tensor): Source speaker features [B, T_s, D]
            target_database (torch.Tensor): Target speaker features database [B, T_t, D]
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Selected features [B, T_s, D] and distances [B, T_s, K]
        """
        # Move tensors to device if needed
        source_features = source_features.to(self.device)
        target_database = target_database.to(self.device)

        # Compute distances between source and all target frames
        distances = self._compute_cosine_distance(source_features, target_database)  # [B, T_s, T_t]

        # Find k nearest neighbors
        topk_distances, topk_indices = torch.topk(distances, k=self.k, dim=-1, largest=False)  # [B, T_s, K]

        # Get dimensions
        B, T_s, _ = source_features.shape
        D = target_database.shape[-1]

        # Expand indices for batch gathering
        batch_indices = torch.arange(B, device=self.device).view(B, 1, 1, 1)
        batch_indices = batch_indices.expand(B, T_s, self.k, 1)

        # Reshape indices for gathering
        topk_indices = topk_indices.unsqueeze(-1)  # [B, T_s, K, 1]

        # Combine batch and topk indices
        gather_indices = torch.cat([batch_indices, topk_indices], dim=-1)  # [B, T_s, K, 2]

        # Gather target features
        selected_features = target_database[gather_indices[..., 0], gather_indices[..., 1]]  # [B, T_s, K, D]

        # Average the k nearest features (uniform weighting as per paper)
        selected_features = selected_features.mean(dim=2)  # [B, T_s, D]

        return selected_features, topk_distances

    def __call__(self, source_features: torch.Tensor, target_database: torch.Tensor) -> torch.Tensor:
        """
        Convenience method to directly get selected features.
        """
        selected_features, _ = self.retrieve(source_features, target_database)
        return selected_features

# $\lambda$ function

### Core Implementation
1. Interpolation Formula:
```python
converted_features = (
    lambda_value * selected_features
    + (1 - lambda_value) * source_features
)
```

2. Parameter Management:
   - Lambda value bounds checking
   - Device handling
   - Tensor dimension verification

### Features
1. Input Validation:
   - Lambda range enforcement
   - Tensor dimension checking
   - Device consistency

2. Computation Efficiency:
   - Vectorized element-wise tensor operations
   - Memory-efficient implementation
   - Batch processing support

3. Interface Options:
   - Direct method call
   - Callable interface
   - Flexible parameter passing
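
As a quick illustration of the formula, the following plain-Python sketch interpolates two short feature vectors for a few values of λ, clamping λ to [0, 1] first. `interpolate` and the sample numbers are hypothetical stand-ins for the tensor version.

```python
def interpolate(selected, source, lambda_value=1.0):
    """Element-wise lambda * selected + (1 - lambda) * source, with lambda clamped to [0, 1]."""
    lam = max(0.0, min(1.0, lambda_value))
    return [lam * s + (1.0 - lam) * x for s, x in zip(selected, source)]


selected = [4.0, 8.0]  # stand-in for averaged k-NN target features
source = [0.0, 2.0]    # stand-in for text-to-SSL output features

for lam in (0.0, 0.5, 1.0, 1.7):
    print(lam, interpolate(selected, source, lam))
```

With λ = 0 the source features pass through unchanged, with λ = 1 the output equals the selected target features, and an out-of-range value such as 1.7 behaves like 1.0 because of the clamp.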

In [None]:
class LinearInterpolation:
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Initialize linear interpolation module.
        Args:
            device (str): Device to perform computations on
        """
        self.device = device

    def interpolate(self, selected_features: torch.Tensor, source_features: torch.Tensor, lambda_value: float = 1.0) -> torch.Tensor:
        """
        Linearly interpolate between selected target features and source features.
        Formula: y_converted = λ * y_selected + (1 - λ) * y_source

        Args:
            selected_features (torch.Tensor): Selected target features from kNN [B, T, D]
            source_features (torch.Tensor): Original source features [B, T, D]
            lambda_value (float): Interpolation parameter (0 to 1)
                                0 = pure source, 1 = pure target (default: 1.0)

        Returns:
            torch.Tensor: Interpolated features [B, T, D]
        """
        # Ensure tensors are on correct device
        selected_features = selected_features.to(self.device)
        source_features = source_features.to(self.device)

        # Ensure lambda is in valid range
        lambda_value = max(0.0, min(1.0, lambda_value))

        # Perform linear interpolation
        converted_features = lambda_value * selected_features + (1 - lambda_value) * source_features

        return converted_features

    def __call__(self, selected_features: torch.Tensor, source_features: torch.Tensor, lambda_value: float = 1.0) -> torch.Tensor:
        """
        Convenience method to directly perform interpolation.
        """
        return self.interpolate(selected_features, source_features, lambda_value)

# Test Without Vocoder


In [None]:
def test_complete_framework(
    ssl_encoder,
    text_to_ssl,
    knn_retrieval,
    linear_interp,
    sample_rate: int = 16000,
    lambda_values: list = [0.0, 0.25, 0.5, 0.75, 1.0]
) -> Dict:
    """
    Test the entire SSL-TTS framework including linear interpolation.

    Args:
        ssl_encoder: SSL encoder model (WavLM)
        text_to_ssl: Text-to-SSL model
        knn_retrieval: KNN retrieval module
        linear_interp: Linear interpolation module
        sample_rate: Audio sample rate (default: 16000)
        lambda_values: List of lambda values to test (default: [0.0, 0.25, 0.5, 0.75, 1.0])

    Returns:
        Dict: Dictionary containing test results and metrics
    """
    device = next(text_to_ssl.parameters()).device
    print(f"\nRunning complete framework test on device: {device}")

    try:
        # 1. Create synthetic target audio (2 seconds of audio at 16kHz)
        print("\nStep 1: Creating synthetic target audio...")
        target_waveform = torch.randn(1, 2 * sample_rate, device=device)

        # 2. Extract SSL features from target audio
        print("Step 2: Extracting SSL features from target audio...")
        with torch.no_grad():
            target_features = ssl_encoder.extract_features(target_waveform, sample_rate)
        print(f"Target features shape: {target_features.shape}")

        # 3. Generate source SSL features from text
        print("\nStep 3: Generating source SSL features from text...")
        text_tokens = torch.randint(0, 148, (1, 100), device=device)  # Simulate tokenized text
        text_lengths = torch.tensor([100], device=device)

        with torch.no_grad():
            source_outputs = text_to_ssl.generate(text_tokens, text_lengths)
            source_features = source_outputs["ssl_features"]
        print(f"Source features shape: {source_features.shape}")

        # 4. Perform k-NN retrieval
        print("\nStep 4: Performing k-NN retrieval...")
        selected_features, knn_distances = knn_retrieval.retrieve(source_features, target_features)
        print(f"Selected features shape: {selected_features.shape}")
        print(f"k-NN distances shape: {knn_distances.shape}")

        # 5. Test different lambda values for interpolation
        print("\nStep 5: Testing linear interpolation with different λ values...")
        interpolation_results = {}

        for lambda_val in lambda_values:
            print(f"\nTesting λ = {lambda_val}")
            converted_features = linear_interp(
                selected_features=selected_features,
                source_features=source_features,
                lambda_value=lambda_val
            )

            # Compute metrics for this lambda value
            metrics = {
                "converted_shape": converted_features.shape,
                "converted_mean": converted_features.mean().item(),
                "converted_std": converted_features.std().item(),
                "source_similarity": torch.nn.functional.cosine_similarity(
                    converted_features.view(-1, converted_features.size(-1)),
                    source_features.view(-1, source_features.size(-1)),
                    dim=-1
                ).mean().item(),
                # Source and target differ in length, so compare time-averaged vectors rather than individual frames
                "target_similarity": torch.nn.functional.cosine_similarity(
                    converted_features.mean(dim=1),  # Average over time dimension first
                    target_features.mean(dim=1),     # Average over time dimension first
                    dim=-1
                ).mean().item(),
                "min_knn_distance": knn_distances.min().item(),
                "max_knn_distance": knn_distances.max().item(),
                "mean_knn_distance": knn_distances.mean().item()
            }

            interpolation_results[lambda_val] = metrics

            print(f"Shape: {metrics['converted_shape']}")
            print(f"Mean: {metrics['converted_mean']:.3f}, Std: {metrics['converted_std']:.3f}")
            print(f"Similarity to source: {metrics['source_similarity']:.3f}")
            print(f"Similarity to target: {metrics['target_similarity']:.3f}")

        # 6. Compile final results
        final_results = {
            "feature_shapes": {
                "source": source_features.shape,
                "target": target_features.shape,
                "selected": selected_features.shape
            },
            "knn_metrics": {
                "min_distance": knn_distances.min().item(),
                "max_distance": knn_distances.max().item(),
                "mean_distance": knn_distances.mean().item()
            },
            "interpolation_results": interpolation_results
        }

        print("\nTest completed successfully!")
        return final_results

    except Exception as e:
        print(f"\nTest failed with error: {str(e)}")
        raise

def run_complete_framework_test():
    """
    Run the complete framework test with all components
    """
    # Initialize models
    device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        print("Initializing components...")
        ssl_encoder = SSLEncoder(device=device)
        text_to_ssl = TextToSSL().to(device)
        knn_retrieval = KNNRetrieval(k=4, device=device)
        linear_interp = LinearInterpolation(device=device)

        # Run complete test
        results = test_complete_framework(
            ssl_encoder=ssl_encoder,
            text_to_ssl=text_to_ssl,
            knn_retrieval=knn_retrieval,
            linear_interp=linear_interp
        )

        # Print summary
        print("\nFinal Summary:")
        print("-" * 50)
        print("\nFeature Shapes:")
        for name, shape in results["feature_shapes"].items():
            print(f"{name}: {shape}")

        print("\nk-NN Metrics:")
        for name, value in results["knn_metrics"].items():
            print(f"{name}: {value:.3f}")

        print("\nInterpolation Results Summary:")
        for lambda_val, metrics in results["interpolation_results"].items():
            print(f"\nλ = {lambda_val}:")
            print(f"Source similarity: {metrics['source_similarity']:.3f}")
            print(f"Target similarity: {metrics['target_similarity']:.3f}")

    except Exception as e:
        print(f"Framework test failed: {str(e)}")
        raise

if __name__ == "__main__":
    run_complete_framework_test()

# Vocoder

Implements the HiFi-GAN vocoder for waveform generation:

### Technical Details
1. Model Components:
   - Residual blocks
   - Upsampling layers
   - Convolutional layers

2. Audio Generation:
   - Feature conditioning
   - Multi-scale processing
   - Waveform synthesis

3. Current Status:
   - Checkpoint loading issue
   - Needs path configuration
   - Testing infrastructure ready
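
One thing worth checking when configuring the generator is how many waveform samples each feature frame expands to, which is the product of the upsampling factors. The sketch below does that arithmetic. It assumes WavLM emits one frame every 20 ms at 16 kHz (320 samples), and it takes the second set of factors from the strides printed for the knn-vc generator later in this notebook. `samples_per_frame` is a hypothetical helper.

```python
import math


def samples_per_frame(upsample_factors):
    """Total upsampling of a HiFi-GAN-style generator: the product of its per-layer factors."""
    return math.prod(upsample_factors)


wavlm_hop = 320  # assumed: 20 ms frame stride at 16 kHz

candidates = {
    "config used here": [8, 8, 2, 2],
    "knn-vc generator strides": [10, 8, 2, 2],
}
for name, factors in candidates.items():
    hop = samples_per_frame(factors)
    print(f"{name}: {hop} samples per frame, matches assumed WavLM hop: {hop == wavlm_hop}")
```

If the 320-sample assumption holds, `[8, 8, 2, 2]` would yield 256 samples per frame, so the generated audio would be shorter than the feature sequence implies, independently of any checkpoint issue.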

In [None]:
'''
from TTS.vocoder.models.hifigan_generator import HifiganGenerator

class Vocoder:
    def __init__(self, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Initialize HiFi-GAN vocoder.
        Args:
            device (str): Device to perform computations on
        """
        self.device = device
        print(f"Initializing HiFi-GAN vocoder on {device}...")

        # Initialize HiFi-GAN with configuration from the paper
        self.model = HifiganGenerator(
            in_channels=1024,  # WavLM feature dimension
            out_channels=1,    # Single channel audio output
            resblock_type="1", # TODO: the arguments below may just restate the defaults; remove them if so
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            resblock_kernel_sizes=[3, 7, 11],
            upsample_kernel_sizes=[16, 16, 4, 4],
            upsample_initial_channel=512,
            upsample_factors=[8, 8, 2, 2],
        ).to(device)

        # Load pre-trained weights
        self._load_pretrained_weights()

        # Set to evaluation mode
        self.model.eval()
        print("Vocoder initialized successfully!")

    def _load_pretrained_weights(self):
        """Load pre-trained HiFi-GAN weights"""
        checkpoint_path = "path/to/hifigan_checkpoint.pth"  # Replace with actual path
        self.model.load_checkpoint(config=None, checkpoint_path=checkpoint_path, eval=True)

    @torch.no_grad()
    def generate(self, features: torch.Tensor) -> torch.Tensor:
        """
        Generate waveform from SSL features.

        Args:
            features (torch.Tensor): SSL features [B, T, D]

        Returns:
            torch.Tensor: Generated waveform [B, 1, T*hop_length]
        """
        # Ensure features are on correct device
        features = features.to(self.device)

        # Move channel dimension to match HiFi-GAN input requirements
        features = features.transpose(1, 2)  # [B, T, D] -> [B, D, T]

        # Generate audio
        waveform = self.model.inference(features)

        return waveform

    def __call__(self, features: torch.Tensor) -> torch.Tensor:
        """Convenience method to directly generate waveform."""
        return self.generate(features)

def test_vocoder():
    """Test the vocoder with synthetic features."""
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize vocoder
    vocoder = Vocoder(device=device)

    # Create synthetic feature input
    B, T, D = 1, 100, 1024  # Batch, Time, Features
    features = torch.randn(B, T, D, device=device)

    try:
        # Generate waveform
        waveform = vocoder(features)

        # Print results
        print("\nVocoder Test Results:")
        print("-" * 50)
        print(f"Input features shape: {features.shape}")
        print(f"Output waveform shape: {waveform.shape}")
        print(f"Waveform statistics:")
        print(f"  Mean: {waveform.mean().item():.3f}")
        print(f"  Std: {waveform.std().item():.3f}")
        print(f"  Min: {waveform.min().item():.3f}")
        print(f"  Max: {waveform.max().item():.3f}")

        # Optional: Save generated audio
        # torchaudio.save('test_output.wav', waveform.cpu(), sample_rate=16000)

        print("\nTest completed successfully!")

    except Exception as e:
        print(f"\nTest failed with error: {str(e)}")
        raise
'''



In [None]:
from TTS.vocoder.models.hifigan_generator import HifiganGenerator

import torch
import os
import urllib.request
from tqdm import tqdm

class DownloadProgressBar(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)

class Vocoder:
    def __init__(self, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Initialize HiFi-GAN vocoder.
        Args:
            device (str): Device to perform computations on
        """
        self.device = device
        print(f"Initializing HiFi-GAN vocoder on {device}...")

        # Initialize HiFi-GAN with configuration from the paper
        self.model = HifiganGenerator(
            in_channels=1024,  # WavLM feature dimension
            out_channels=1,    # Single channel audio output
            resblock_type="1",
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            resblock_kernel_sizes=[3, 7, 11],
            upsample_kernel_sizes=[16, 16, 4, 4],
            upsample_initial_channel=512,
            upsample_factors=[8, 8, 2, 2],
        ).to(device)

        # Load pre-trained weights
        self._load_pretrained_weights()

        # Set to evaluation mode
        self.model.eval()
        print("Vocoder initialized successfully!")

    def _download_checkpoint(self, url, filename):
        """Download checkpoint if it doesn't exist"""
        if not os.path.exists(filename):
            print(f"Downloading {filename}...")
            with DownloadProgressBar(unit='B', unit_scale=True,
                                   miniters=1, desc=filename) as t:
                urllib.request.urlretrieve(url, filename=filename,
                                         reporthook=t.update_to)

    def _load_pretrained_weights(self):
        """Load pre-trained HiFi-GAN weights"""
        # URL for the prematch generator weights from knn-vc release
        url = "https://github.com/bshall/knn-vc/releases/download/v0.1/prematch_g_02500000.pt"
        checkpoint_path = "prematch_g_02500000.pt"

        # Download the checkpoint if it doesn't exist
        self._download_checkpoint(url, checkpoint_path)

        # Load the state dict
        print("Loading checkpoint...")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # If the checkpoint contains a 'model' or 'generator' key, extract the state dict
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif 'generator' in checkpoint:
            state_dict = checkpoint['generator']
        else:
            state_dict = checkpoint  # Assume it's directly the state dict

        # Load the weights into the model
        self.model.load_state_dict(state_dict)

        # Verify loading was successful
        param_sum = sum(p.sum() for p in self.model.parameters())
        print(f"Sum of parameters after loading: {param_sum}")
        assert param_sum != 0, "Loading failed - all parameters are zero"

In [None]:
vocoder = Vocoder()

Initializing HiFi-GAN vocoder on cpu...
Downloading prematch_g_02500000.pt...


prematch_g_02500000.pt: 66.2MB [00:01, 59.2MB/s]                            


Loading checkpoint...


RuntimeError: Error(s) in loading state_dict for HifiganGenerator:
	Unexpected key(s) in state_dict: "lin_pre.weight", "lin_pre.bias". 
	size mismatch for conv_pre.parametrizations.weight.original1: copying a param with shape torch.Size([512, 512, 7]) from checkpoint, the shape in current model is torch.Size([512, 1024, 7]).
	size mismatch for ups.0.parametrizations.weight.original1: copying a param with shape torch.Size([512, 256, 20]) from checkpoint, the shape in current model is torch.Size([512, 256, 16]).
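
A mismatch like this can be listed explicitly before calling `load_state_dict`. The sketch below compares two name-to-shape mappings; with real tensors one could build them as `{k: tuple(v.shape) for k, v in state_dict.items()}`. The `diff_shapes` helper is hypothetical, the conv and upsampling shapes are copied from the error message above, and the two `lin_pre` shapes are assumptions made for illustration.

```python
def diff_shapes(checkpoint_shapes, model_shapes):
    """Return keys only in the checkpoint, keys only in the model, and keys whose shapes differ."""
    unexpected = sorted(set(checkpoint_shapes) - set(model_shapes))
    missing = sorted(set(model_shapes) - set(checkpoint_shapes))
    mismatched = {
        key: (checkpoint_shapes[key], model_shapes[key])
        for key in sorted(set(checkpoint_shapes) & set(model_shapes))
        if checkpoint_shapes[key] != model_shapes[key]
    }
    return unexpected, missing, mismatched


checkpoint_shapes = {
    "lin_pre.weight": (512, 1024),  # assumed shape
    "lin_pre.bias": (512,),         # assumed shape
    "conv_pre.parametrizations.weight.original1": (512, 512, 7),
    "ups.0.parametrizations.weight.original1": (512, 256, 20),
}
model_shapes = {
    "conv_pre.parametrizations.weight.original1": (512, 1024, 7),
    "ups.0.parametrizations.weight.original1": (512, 256, 16),
}

unexpected, missing, mismatched = diff_shapes(checkpoint_shapes, model_shapes)
print("unexpected:", unexpected)
print("missing:", missing)
for key, (ckpt, model) in mismatched.items():
    print(f"{key}: checkpoint {ckpt} vs model {model}")
```

Here the extra `lin_pre` layer and the wider first upsampling kernel suggest that the checkpoint was trained with a different generator definition, so adjusting constructor arguments alone may not be enough to load it.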

In [None]:
knn_vc = torch.hub.load(
    'bshall/knn-vc',
    'knn_vc',
    pretrained=True,
    prematched=True,
    trust_repo=True
)


Using cache found in /root/.cache/torch/hub/bshall_knn-vc_master


Removing weight norm...
[HiFiGAN] Generator loaded with 16,523,393 parameters.
WavLM-Large loaded with 315,453,120 parameters.


In [None]:
vocoder = knn_vc.hifigan
print(vocoder)

Generator(
  (lin_pre): Linear(in_features=1024, out_features=512, bias=True)
  (conv_pre): Conv1d(512, 512, kernel_size=(7,), stride=(1,), padding=(3,))
  (ups): ModuleList(
    (0): ConvTranspose1d(512, 256, kernel_size=(20,), stride=(10,), padding=(5,))
    (1): ConvTranspose1d(256, 128, kernel_size=(16,), stride=(8,), padding=(4,))
    (2): ConvTranspose1d(128, 64, kernel_size=(4,), stride=(2,), padding=(1,))
    (3): ConvTranspose1d(64, 32, kernel_size=(4,), stride=(2,), padding=(1,))
  )
  (resblocks): ModuleList(
    (0): ResBlock1(
      (convs1): ModuleList(
        (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
        (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
      )
    )
    (1): ResBlock1(
      (convs1): Modu

# Testing and Evaluation

To evaluate the zero-shot model, the authors used the LibriSpeech test-clean dataset for the target speaker reference utterances (ground truth). The dataset contains speech from 20 male and 20 female speakers, with 8 minutes of speech each. We downloaded the data; the part we need is the test-clean folder, which contains one subfolder per speaker, each holding one subfolder per chapter of a book that the speaker read from, which in turn holds the individual audio files (in .flac format, where each file is a sentence from the chapter).
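
The speaker/chapter/file layout described above can be traversed with a single glob pattern. This is only a sketch: `collect_utterances` is a hypothetical helper name, and it assumes the folder layout is exactly `<speaker>/<chapter>/<file>.flac`.

```python
from pathlib import Path


def collect_utterances(test_clean_dir):
    """Map each speaker ID (folder name) to the sorted list of its .flac files."""
    utterances = {}
    # One directory level per speaker, one per chapter, then the audio files
    for flac_path in sorted(Path(test_clean_dir).glob("*/*/*.flac")):
        speaker_id = flac_path.parent.parent.name
        utterances.setdefault(speaker_id, []).append(flac_path)
    return utterances
```

Calling `collect_utterances` on the test-clean folder would return one entry per speaker, which makes it easy to pick reference utterances speaker by speaker.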

\\

To create the output for the model, they passed in 100 English sentences for each speaker, taken from the FLoRes+ dataset. We downloaded the data and located the sentences. We only need one file, “devtest.eng_Latn”, which contains a large number of unrelated English sentences. Below you will find example sentences.

\\

MOS (mean opinion score) is a measure of the human-judged overall quality of an event or experience. For us, a MOS is a rating of the quality of speech utterances. It is most often judged on a scale of 1 (bad) to 5 (excellent), and it is the average of a number of individual human-given scores. Although MOS values were originally derived from surveys of expert observers, today a MOS is often produced by an objective measurement method that approximates a human rating. A score of 4.3-4.5 is considered an excellent target, because human raters rarely give out perfect 5s. A score below 3.5 is generally considered unacceptable.
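
As a small illustration of the averaging step, the sketch below computes a MOS together with an approximate 95% confidence interval. The function name `mean_opinion_score`, the normal-approximation interval, and the ratings themselves are all assumptions made for this example, not values or methods from the paper.

```python
import math
import statistics


def mean_opinion_score(ratings):
    """Return (MOS, half-width of an approximate 95% confidence interval)."""
    if not ratings:
        raise ValueError("need at least one rating")
    if any(r < 1 or r > 5 for r in ratings):
        raise ValueError("ratings must lie on the 1-5 scale")
    mos = statistics.mean(ratings)
    if len(ratings) < 2:
        return mos, 0.0
    # Normal approximation: 1.96 standard errors on each side of the mean
    half_width = 1.96 * statistics.stdev(ratings) / math.sqrt(len(ratings))
    return mos, half_width


# Hypothetical ratings from four listeners for one synthesized utterance
ratings = [4.0, 3.5, 4.5, 4.0]
mos, ci = mean_opinion_score(ratings)
print(f"MOS = {mos:.2f} +/- {ci:.2f}")
```

With only a handful of raters the interval is wide, which is worth keeping in mind when comparing small differences between systems.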

\\

All tests are conducted with $λ=1$. The evaluation focuses on a few key metrics of language:

- **Naturalness: UTMOS**
 - UTMOS = UTokyo-SaruLab Mean Opinion Score, an automatic method of predicting MOS.

- **Intelligibility: WER, PER**
 - WER = Word Error Rate, i.e. the ratio of word errors in a transcript to the total number of words spoken. A lower WER means the speech was recognized more accurately. In our case, it is calculated with the formula $\frac{S+D+I}{N}$, where S is the number of substitutions (instances where a word in the synthesized sentence would need to be substituted to match the ground-truth sentence), D is the number of deletions (instances where a word in the synthesized sentence would need to be deleted to match the ground-truth sentence), I is the number of insertions (instances where a new word would need to be inserted to match the ground-truth sentence), and N is the total number of words in the ground-truth sentence. The numerator is also known as the edit distance because it represents "how far apart" the two sentences are.
 - PER = Phoneme Error Rate, i.e. the ratio of phoneme errors in a transcript to the total number of phonemes spoken. As above, a lower PER means better accuracy. The formula is the same as above, except that it compares phonemes instead of words, so N is the total number of phonemes in the ground-truth sentence.
 - Both of these are calculated using the Whisper-Large v3 model.

- **Speaker Similarity: SECS**
 - SECS = Speaker-Encoder Cosine Similarity, i.e. the cosine similarity between the embeddings of two audio samples, which in our case are a ground-truth sample from one speaker and the synthesized sample for that same speaker. The original paper uses ECAPA2 to compute these embeddings and their similarity. The goal of speaker similarity is to determine whether two audio samples come from the same speaker: if the similarity score is above a certain threshold, they are considered to be from the same speaker; otherwise, they are considered to be from different speakers.

- **Subjective Evaluation: N-MOS, S-MOS**
 - N-MOS = Naturalness MOS, i.e. how natural the utterance (output) sounds compared to the ground-truth recording.
 - S-MOS = Similarity MOS, i.e. how similar the utterance sounds to the ground-truth recording.

The original paper had 10 raters go through 3 synthesized sentences per speaker, for 60 in total. The raters scored each synthesis from 1 to 5 in 0.5 increments. The raters were native English speakers in the United States, hired through Amazon Mechanical Turk; in our case, the raters would just be the 4 of us.
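
The two objective formulas in the list above, the error rate $\frac{S+D+I}{N}$ and the cosine similarity, can be sketched in plain Python. This is a minimal illustration, not the evaluation code from the paper: it assumes the transcripts (from an ASR model) and the speaker embeddings (from a speaker encoder) are already available, and the function names and example inputs are hypothetical.

```python
import math


def edit_distance(ref, hyp):
    """Minimum number of substitutions, deletions, and insertions
    needed to turn the token list `hyp` into the token list `ref`."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(
                prev[j] + 1,         # insert r into the hypothesis
                curr[j - 1] + 1,     # delete h from the hypothesis
                prev[j - 1] + cost,  # keep h, or substitute it with r
            ))
        prev = curr
    return prev[-1]


def error_rate(ref_tokens, hyp_tokens):
    """(S + D + I) / N, where N is the number of reference tokens."""
    if not ref_tokens:
        raise ValueError("reference must contain at least one token")
    return edit_distance(ref_tokens, hyp_tokens) / len(ref_tokens)


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length embedding vectors."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same length")
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    if norm == 0:
        raise ValueError("embeddings must be non-zero")
    return sum(x * y for x, y in zip(a, b)) / norm


# WER: tokens are words (hypothetical sentences)
reference = "the cat sat on the mat".split()
hypothesis = "the cat sit on mat".split()
print(f"WER = {error_rate(reference, hypothesis):.3f}")

# PER: the same function applied to phoneme tokens (hypothetical phonemes)
print(f"PER = {error_rate(['K', 'AE', 'T'], ['K', 'AH', 'T']):.3f}")

# SECS: toy 3-dimensional vectors standing in for real speaker embeddings
print(f"SECS = {cosine_similarity([0.2, 0.9, 0.1], [0.25, 0.8, 0.05]):.3f}")
```

Using one `error_rate` function for both WER and PER reflects the fact that the two metrics differ only in the unit being compared. Real evaluations usually normalize the text (case, punctuation) before computing the error rate.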

In [None]:
from google.colab import drive
drive.mount('/content/drive')

dev_test = '/content/drive/MyDrive/devtest.eng_Latn'

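# Print the first four sentences of the FLoRes+ devtest file
with open(dev_test, 'r', encoding='utf-8') as f:
  for i, sentence in enumerate(f, start=1):
    print(f'Example {i}: {sentence.strip()} \n')
    if i == 4:
      break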

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Example 1: "We now have 4-month-old mice that are non-diabetic that used to be diabetic," he added. 

Example 2: Dr. Ehud Ur, professor of medicine at Dalhousie University in Halifax, Nova Scotia and chair of the clinical and scientific division of the Canadian Diabetes Association cautioned that the research is still in its early days. 

Example 3: Like some other experts, he is skeptical about whether diabetes can be cured, noting that these findings have no relevance to people who already have Type 1 diabetes. 

Example 4: On Monday, Sara Danius, permanent secretary of the Nobel Committee for Literature at the Swedish Academy, publicly announced during a radio program on Sveriges Radio in Sweden the committee, unable to reach Bob Dylan directly about winning the 2016 Nobel Prize in Literature, had abandoned its efforts to reach him. 



In [None]:
test_clean = '/content/drive/MyDrive/LibriSpeech/test-clean'