In [2]:
from datasets import load_dataset

# Read the local LibriSQA JSON file as a single split named 'data' and keep
# only that split, so `dataset` is bound directly to the loaded split.
dataset = load_dataset(
    'json',
    data_files={'data': 'config/LibriSQA-PartI-flac.json'},
)['data']


In [3]:
# Show the dataset summary followed by its number of rows
print(dataset, f"Number of examples: {len(dataset)}", sep="\n")

Dataset({
    features: ['text', 'duration', 'question', 'answer', 'speech_path'],
    num_rows: 2620
})
Number of examples: 2620


In [4]:
import torch
import torchaudio
from transformers import AutoFeatureExtractor, MimiModel
import os
from pathlib import Path
import logging
from typing import List, Dict, Optional, Union
from dataclasses import dataclass
from tqdm import tqdm
import math
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from pydub import AudioSegment
import json
import torch.nn.functional as F
import IPython.display as ipd

In [5]:
# Configure root logging at INFO level and create this notebook's logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared MIMI model / feature extractor; both stay None until
# initialize_mimi_model() populates them
model = feature_extractor = None

In [6]:
def initialize_mimi_model():
    """Load the pretrained MIMI model and feature extractor on first use.

    The loaded objects are cached in the module-level ``model`` and
    ``feature_extractor`` globals. If either global is still ``None``, both
    are loaded from the ``kyutai/mimi`` checkpoint.

    Returns:
        tuple: ``(model, feature_extractor)``.
    """
    global model, feature_extractor
    already_loaded = model is not None and feature_extractor is not None
    if not already_loaded:
        model = MimiModel.from_pretrained("kyutai/mimi")
        feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")
    return model, feature_extractor

@dataclass
class AudioChunkInfo:
    """Metadata record describing a single audio chunk.

    All time-valued fields (``start_time``, ``end_time``, ``duration``) are
    expressed in seconds.
    """
    # Index of the chunk
    chunk_number: int
    # Chunk boundaries and length, in seconds
    start_time: float
    end_time: float
    duration: float
    # Path associated with the chunk
    file_path: str

@dataclass
class SimilarityResult:
    """Outcome of comparing one chunk embedding against a query embedding."""
    # Identifier of the compared chunk
    chunk_id: str
    # Similarity score assigned to the chunk
    similarity_score: float
    # Start / end times recorded for the chunk
    chunk_start_time: float
    chunk_end_time: float
    # Path of the embedding file that was compared
    embedding_path: str

def extract_mimi_embeddings(audio_path):
    """Extract a fixed-size, L2-normalized MIMI embedding from an audio file.

    The audio is loaded with ``torchaudio``, down-mixed to mono, resampled to
    the feature extractor's sampling rate and encoded with the MIMI model.
    The encoder's ``audio_codes`` are cast to float and averaged over the
    time (frame) axis, so the embedding length no longer depends on the
    duration of the audio.

    Args:
        audio_path: Path of the audio file to load.

    Returns:
        torch.Tensor: L2-normalized embedding. For the 3-D ``audio_codes``
        tensor this has shape ``(batch, num_quantizers)``.

    Raises:
        Exception: Any error raised while loading or encoding is logged and
        re-raised unchanged.
    """
    global model, feature_extractor

    # Initialize model if not already done
    if model is None or feature_extractor is None:
        model, feature_extractor = initialize_mimi_model()

    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_path)

        # Convert to mono if the file has more than one channel
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample if the file's rate differs from what the extractor expects
        target_rate = feature_extractor.sampling_rate
        if sample_rate != target_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, target_rate)
            waveform = resampler(waveform)

        # Prepare inputs
        inputs = feature_extractor(
            raw_audio=waveform.squeeze().numpy(),
            sampling_rate=target_rate,
            return_tensors="pt"
        )

        # Extract features without tracking gradients
        with torch.no_grad():
            encoder_outputs = model.encode(inputs["input_values"])
            # NOTE(review): audio_codes are discrete codebook indices; averaging
            # them as floats is a crude embedding -- confirm this is intended.
            embeddings = encoder_outputs.audio_codes.float()

            # Convert to a fixed-size embedding by averaging over time.
            # The transformers documentation gives audio_codes as
            # (batch_size, num_quantizers, codes_length), so the time axis is
            # the LAST one. The previous dim=1 averaged over the quantizers
            # and left a duration-dependent length.
            if len(embeddings.shape) == 3:
                embeddings = torch.mean(embeddings, dim=-1)
            elif len(embeddings.shape) == 2:
                # presumably (num_quantizers, frames) without a batch axis --
                # TODO confirm; pool over frames and restore the batch axis
                embeddings = torch.mean(embeddings, dim=-1).unsqueeze(0)

            # Normalize the embeddings to unit L2 norm along the feature axis
            embeddings = F.normalize(embeddings, p=2, dim=-1)

        return embeddings

    except Exception as e:
        logger.error(f"Error in extract_mimi_embeddings: {str(e)}")
        raise

In [7]:
class EmbeddingSimilarityCalculator:
    """Ranks stored chunk embeddings by cosine similarity to a query embedding.

    Chunk embeddings are read from ``chunk_*.pt`` files inside
    ``embeddings_dir``. Tensors are loaded onto CUDA when it is available and
    onto the CPU otherwise.
    """

    def __init__(self, embeddings_dir: str):
        """Remember the embeddings directory and pick the compute device."""
        self.embeddings_dir = Path(embeddings_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_embedding(self, embedding_path: Union[str, Path]) -> torch.Tensor:
        """Load an embedding file and reduce it to a ``(1, features)`` tensor.

        Args:
            embedding_path: Path of a file written with ``torch.save``.

        Returns:
            A floating-point 2-D tensor with a single row.

        Raises:
            Exception: Any load/processing error is logged and re-raised.
        """
        try:
            # NOTE(review): torch.load unpickles the file; only load embedding
            # files produced by this pipeline, never untrusted ones.
            embedding = torch.load(embedding_path, map_location=self.device)

            # Convert to float if needed
            if not embedding.is_floating_point():
                embedding = embedding.float()

            # The checks below intentionally cascade (plain `if`, not `elif`):
            # a 3-D tensor is reduced to 2-D, then to a single row.
            if len(embedding.shape) == 3:  # assumed (batch, sequence, features) -- TODO confirm
                embedding = embedding.mean(dim=1)
            if len(embedding.shape) == 2:  # collapse the rows into one vector
                embedding = embedding.mean(dim=0, keepdim=True)
            if len(embedding.shape) == 1:  # (features,) -> (1, features)
                embedding = embedding.unsqueeze(0)

            return embedding

        except Exception as e:
            logger.error(f"Error loading embedding from {embedding_path}: {str(e)}")
            raise

    def process_embedding_for_comparison(self, embedding: torch.Tensor, target_dim: int) -> torch.Tensor:
        """Resize a ``(rows, features)`` embedding to ``target_dim`` features.

        Embeddings whose second dimension already equals ``target_dim`` are
        returned unchanged; others are resized with linear interpolation.
        """
        if embedding.shape[1] != target_dim:
            # F.interpolate in 'linear' mode needs a channel axis, so add one
            # for the call and drop it again afterwards.
            embedding = F.interpolate(
                embedding.unsqueeze(1),
                size=target_dim,
                mode='linear',
                align_corners=False
            ).squeeze(1)
        return embedding

    def compute_cosine_similarity(self,
                                query_embedding: torch.Tensor,
                                chunk_embedding: torch.Tensor) -> float:
        """Compute the cosine similarity between a query and a chunk embedding.

        1-D inputs get a leading batch axis. Both embeddings are resized to
        the smaller of the two feature sizes, L2-normalized and compared.

        Returns:
            The similarity as a Python float (``.item()`` of the result, so
            the inputs must reduce to a single row).

        Raises:
            Exception: Any error is logged and re-raised.
        """
        try:
            with torch.no_grad():
                # Ensure both are 2D
                if len(query_embedding.shape) == 1:
                    query_embedding = query_embedding.unsqueeze(0)
                if len(chunk_embedding.shape) == 1:
                    chunk_embedding = chunk_embedding.unsqueeze(0)

                # Resize both embeddings to the smaller feature dimension
                min_dim = min(query_embedding.shape[1], chunk_embedding.shape[1])
                query_embedding = self.process_embedding_for_comparison(query_embedding, min_dim)
                chunk_embedding = self.process_embedding_for_comparison(chunk_embedding, min_dim)

                # Normalize embeddings
                query_embedding = F.normalize(query_embedding, p=2, dim=1)
                chunk_embedding = F.normalize(chunk_embedding, p=2, dim=1)

                # Compute similarity
                similarity = F.cosine_similarity(query_embedding, chunk_embedding, dim=1)

                return similarity.item()

        except Exception as e:
            logger.error(f"Error in compute_cosine_similarity: {str(e)}")
            raise

    @staticmethod
    def _index_chunk_metadata(raw_metadata) -> Dict[str, dict]:
        """Normalize loaded metadata into a ``{chunk stem: entry}`` mapping.

        Two layouts are accepted:

        * ``{"chunks": [entry, ...]}`` -- the layout written by
          ``AudioMetadata.save_chunk_metadata``. Entries are keyed by the stem
          of their ``file_path`` (e.g. ``chunk_000``), falling back to
          ``chunk_{chunk_number:03d}`` when no path is present.
        * a flat ``{stem: entry}`` dict, which is returned as-is.

        Anything else yields an empty mapping.
        """
        if not isinstance(raw_metadata, dict):
            return {}

        chunk_entries = raw_metadata.get("chunks")
        if not isinstance(chunk_entries, list):
            return raw_metadata

        indexed = {}
        for entry in chunk_entries:
            if not isinstance(entry, dict):
                continue
            file_path = entry.get("file_path")
            if file_path:
                indexed[Path(file_path).stem] = entry
            elif "chunk_number" in entry:
                indexed[f"chunk_{int(entry['chunk_number']):03d}"] = entry
        return indexed

    def find_most_similar_chunks(self,
                               query_path: str,
                               top_k: int = 1,
                               metadata_path: Optional[str] = None) -> "List[SimilarityResult]":
        """Find the chunks whose embeddings are most similar to a query.

        Args:
            query_path: Path of the saved query embedding.
            top_k: Maximum number of results to return.
            metadata_path: Optional path of a chunk-metadata JSON file used to
                fill in chunk start/end times. It is ignored when missing.

        Returns:
            Up to ``top_k`` ``SimilarityResult`` objects sorted by descending
            similarity. Chunks that fail to load or compare are logged and
            skipped. Start/end times default to 0.0 without metadata.
        """
        # Load query embedding
        query_embedding = self.load_embedding(query_path)

        # Load metadata if available, normalized so it can be looked up by
        # chunk stem regardless of the on-disk layout
        chunk_metadata = {}
        if metadata_path and Path(metadata_path).exists():
            with open(metadata_path, 'r') as f:
                chunk_metadata = self._index_chunk_metadata(json.load(f))

        # Process all chunk embeddings
        results = []
        chunk_paths = sorted(self.embeddings_dir.glob("chunk_*.pt"))

        for chunk_path in tqdm(chunk_paths, desc="Processing chunks"):
            try:
                metadata = chunk_metadata.get(chunk_path.stem, {})
                chunk_embedding = self.load_embedding(chunk_path)
                similarity_score = self.compute_cosine_similarity(
                    query_embedding,
                    chunk_embedding
                )

                result = SimilarityResult(
                    chunk_id=chunk_path.stem,
                    similarity_score=similarity_score,
                    chunk_start_time=metadata.get('start_time', 0.0),
                    chunk_end_time=metadata.get('end_time', 0.0),
                    embedding_path=str(chunk_path)
                )
                results.append(result)
            except Exception as e:
                # Best effort: one bad chunk must not abort the whole search
                logger.warning(f"Error processing chunk {chunk_path}: {str(e)}")
                continue

        # Sort by similarity score and get top-k
        results.sort(key=lambda x: x.similarity_score, reverse=True)
        return results[:top_k]


In [None]:
import os
import torch
import logging
import soundfile as sf
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List
from dataclasses import dataclass

# Configure logging: bind the module-level logger for this cell. It is looked
# up by the same name (__name__) as in the earlier logging-setup cell.
logger = logging.getLogger(__name__)

@dataclass
class ChunkInfo:
    """Record describing one chunk file produced by splitting an audio file."""
    # Index of the chunk within the source audio
    chunk_number: int
    # Path of the written chunk file
    file_path: str
    # Chunk boundaries and length
    start_time: float
    end_time: float
    duration: float

@dataclass
class SimilarityResult:
    """Similarity outcome for one chunk.

    Re-declares the ``SimilarityResult`` dataclass from the earlier cell with
    the same fields in the same order.
    """
    # Identifier of the compared chunk
    chunk_id: str
    # Similarity score assigned to the chunk
    similarity_score: float
    # Start / end times recorded for the chunk
    chunk_start_time: float
    chunk_end_time: float
    # Path of the embedding file that was compared
    embedding_path: str

class FLACSplitter:
    """Splits an audio file into consecutive chunk files of fixed duration.

    Chunks are ``chunk_duration`` seconds long. The trailing remainder of the
    file is kept as a shorter final chunk when it lasts at least
    ``min_chunk_duration`` seconds, and dropped otherwise.
    """

    def __init__(self, chunk_duration: int, output_dir: str, min_chunk_duration: int):
        """Store the chunking parameters.

        Args:
            chunk_duration: Length of each chunk, in seconds.
            output_dir: Directory the chunk files are written to.
            min_chunk_duration: Chunks shorter than this many seconds are skipped.
        """
        self.chunk_duration = chunk_duration
        self.output_dir = output_dir
        self.min_chunk_duration = min_chunk_duration

    def _chunk_bounds(self, total_samples: int, samplerate: int) -> List[tuple]:
        """Compute ``(chunk_number, start_sample, end_sample)`` for every kept chunk.

        Unlike a ``duration // chunk_duration`` count, this also visits the
        trailing partial chunk, so the ``min_chunk_duration`` filter decides
        whether the tail is kept.

        Raises:
            ValueError: If ``chunk_duration`` is not positive.
        """
        if self.chunk_duration <= 0:
            raise ValueError("chunk_duration must be positive")

        bounds = []
        index = 0
        while True:
            start_sample = int(index * self.chunk_duration * samplerate)
            if start_sample >= total_samples:
                break
            # Clamp the last chunk to the end of the audio
            end_sample = min(int((index + 1) * self.chunk_duration * samplerate), total_samples)
            # Skip if chunk is too short
            if (end_sample - start_sample) / samplerate >= self.min_chunk_duration:
                bounds.append((index, start_sample, end_sample))
            index += 1
        return bounds

    def split_audio(self, input_file: str) -> "List[ChunkInfo]":
        """Split ``input_file`` into chunk files and describe each one.

        Chunk files are written to ``output_dir`` (created if missing) as
        ``chunk_NNN.flac`` at the source sample rate.

        Args:
            input_file: Path of the audio file to read with ``soundfile``.

        Returns:
            One ``ChunkInfo`` per written chunk, in order. ``end_time`` is
            clamped to the file duration for a shorter final chunk.

        Raises:
            ValueError: If ``chunk_duration`` is not positive.
        """
        # Read the audio file
        data, samplerate = sf.read(input_file)
        total_samples = len(data)
        duration = total_samples / samplerate

        # Create output directory if it doesn't exist
        os.makedirs(self.output_dir, exist_ok=True)

        chunk_infos = []
        for i, start_sample, end_sample in self._chunk_bounds(total_samples, samplerate):
            chunk_data = data[start_sample:end_sample]

            # Save chunk
            chunk_path = os.path.join(self.output_dir, f"chunk_{i:03d}.flac")
            sf.write(chunk_path, chunk_data, samplerate)

            start_time = i * self.chunk_duration
            chunk_infos.append(ChunkInfo(
                chunk_number=i,
                file_path=chunk_path,
                start_time=start_time,
                end_time=min(start_time + self.chunk_duration, duration),
                duration=len(chunk_data) / samplerate
            ))

        return chunk_infos

class AudioMetadata:
    """Namespace for persisting chunk metadata."""

    @staticmethod
    def save_chunk_metadata(chunk_infos: List[ChunkInfo], output_dir: str):
        """Write the chunk descriptions to ``chunk_metadata.json`` in ``output_dir``.

        The JSON document holds one top-level ``"chunks"`` list with one
        object per chunk, in the order given.
        """
        records = []
        for info in chunk_infos:
            records.append({
                "chunk_number": info.chunk_number,
                "file_path": info.file_path,
                "start_time": info.start_time,
                "end_time": info.end_time,
                "duration": info.duration
            })

        destination = os.path.join(output_dir, "chunk_metadata.json")
        with open(destination, 'w') as handle:
            json.dump({"chunks": records}, handle, indent=2)

class SimilarityVisualizer:
    """Helpers that persist similarity results as a text report and a bar chart."""

    @staticmethod
    def save_results(results: List[SimilarityResult], output_path: str):
        """Write a ranked, human-readable report of ``results`` to ``output_path``."""
        separator = "-" * 50 + "\n"
        with open(output_path, 'w') as report:
            report.write("Similarity Results:\n")
            report.write(separator)

            # Ranks are 1-based and follow the order of `results`
            for rank, entry in enumerate(results, start=1):
                report.write(f"Rank {rank}:\n")
                report.write(f"  Chunk ID: {entry.chunk_id}\n")
                report.write(f"  Similarity Score: {entry.similarity_score:.4f}\n")
                report.write(f"  Time Range: {entry.chunk_start_time:.2f}s - "
                             f"{entry.chunk_end_time:.2f}s\n")
                report.write(f"  Embedding: {entry.embedding_path}\n")
                report.write(separator)

    @staticmethod
    def plot_similarities(results: List[SimilarityResult], output_path: str):
        """Save a bar chart of the similarity scores to ``output_path``.

        Bars follow the order of ``results`` and are labelled with chunk ids.
        """
        labels = [entry.chunk_id for entry in results]
        values = [entry.similarity_score for entry in results]
        positions = range(len(labels))

        plt.figure(figsize=(12, 6))
        sns.barplot(x=positions, y=values)
        plt.title("Chunk Similarity Scores")
        plt.xlabel("Chunk ID")
        plt.ylabel("Cosine Similarity")
        plt.xticks(positions, labels, rotation=45)
        plt.tight_layout()
        plt.savefig(output_path)
        # Release the figure so repeated calls do not accumulate open figures
        plt.close()

def process_audio_files(
    input_file: str = "/mnt/data/ashwin/SpeechRAG/libriSQA/twomerged/61_672_merged.flac",
    output_dir: str = "/mnt/data/ashwin/SpeechRAG/libriSQA/processed_audio",
    embeddings_dir: str = "/mnt/data/ashwin/SpeechRAG/libriSQA/embeddings/chunks_embedding",
    chunk_duration: int = 15,
    min_chunk_duration: int = 5,
):
    """Split the main audio file into chunks and save one embedding per chunk.

    All parameters default to the values that were previously hard-coded, so
    calling ``process_audio_files()`` with no arguments behaves as before.

    Args:
        input_file: Audio file to split.
        output_dir: Directory that receives the chunk files and
            ``chunk_metadata.json``.
        embeddings_dir: Directory that receives one ``chunk_NNN.pt`` file per
            chunk.
        chunk_duration: Chunk length in seconds.
        min_chunk_duration: Minimum chunk length in seconds.

    Raises:
        Exception: Errors outside the per-chunk loop are logged and re-raised.
            Per-chunk embedding failures are logged and skipped.
    """
    try:
        # Initialize MIMI model
        initialize_mimi_model()

        # Split audio into chunk files
        splitter = FLACSplitter(
            chunk_duration=chunk_duration,
            output_dir=output_dir,
            min_chunk_duration=min_chunk_duration
        )
        chunk_infos = splitter.split_audio(input_file)

        if not chunk_infos:
            # Nothing was written, so do not report success
            logger.warning(f"No chunks were produced from {input_file}")
            return

        # Save metadata next to the chunk files and create embeddings directory
        AudioMetadata.save_chunk_metadata(chunk_infos, output_dir)
        Path(embeddings_dir).mkdir(parents=True, exist_ok=True)

        # Process each chunk; a failing chunk is logged and skipped
        failed_chunks = 0
        for chunk_info in tqdm(chunk_infos, desc="Extracting embeddings"):
            try:
                embeddings = extract_mimi_embeddings(chunk_info.file_path)
                output_path = os.path.join(embeddings_dir, f"chunk_{chunk_info.chunk_number:03d}.pt")
                torch.save(embeddings, output_path)
            except Exception as e:
                failed_chunks += 1
                logger.error(f"Error processing chunk {chunk_info.chunk_number}: {str(e)}")
                continue

        if failed_chunks:
            logger.warning(
                f"Processing completed with {failed_chunks} failed chunk(s) "
                f"out of {len(chunk_infos)}"
            )
        else:
            logger.info("Processing completed successfully!")

    except Exception as e:
        logger.error(f"Error: {str(e)}")
        raise

def process_query_audio(query_audio_path: str, output_path: str):
    """Extract the MIMI embedding of a query audio file and save it.

    Args:
        query_audio_path: Audio file containing the query.
        output_path: Destination file for ``torch.save``. Its parent
            directory is created if it does not exist yet.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        initialize_mimi_model()
        embeddings = extract_mimi_embeddings(query_audio_path)
        # torch.save does not create missing directories
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(embeddings, output_path)
        logger.info(f"Query embeddings saved to {output_path}")
    except Exception as e:
        logger.error(f"Error processing query audio: {str(e)}")
        raise

def find_similar_chunks(
    output_dir: str = "/mnt/data/ashwin/SpeechRAG/libriSQA/results",
    embeddings_dir: str = "/mnt/data/ashwin/SpeechRAG/libriSQA/embeddings/chunks_embedding",
    query_path: str = "/mnt/data/ashwin/SpeechRAG/libriSQA/embeddings/question_embedding.pt",
    top_k: int = 20,
    metadata_path: Optional[str] = None,
):
    """Rank the stored chunk embeddings against the query and save the results.

    All parameters default to the values that were previously hard-coded, so
    calling ``find_similar_chunks()`` with no arguments behaves as before.

    Args:
        output_dir: Directory that receives ``similarity_results.txt`` and
            ``similarity_plot.png``. It is created if missing.
        embeddings_dir: Directory holding the ``chunk_*.pt`` embeddings.
        query_path: Saved query embedding file.
        top_k: Maximum number of results to keep.
        metadata_path: Optional chunk-metadata JSON file. When ``None``
            (the default), no metadata is used.

    Returns:
        The list of results returned by
        ``EmbeddingSimilarityCalculator.find_most_similar_chunks``.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    calculator = EmbeddingSimilarityCalculator(embeddings_dir)
    results = calculator.find_most_similar_chunks(
        query_path,
        top_k=top_k,
        metadata_path=metadata_path
    )

    SimilarityVisualizer.save_results(
        results,
        output_dir / "similarity_results.txt"
    )

    SimilarityVisualizer.plot_similarities(
        results,
        output_dir / "similarity_plot.png"
    )

    logger.info("Similarity search completed successfully!")

    return results

In [10]:
initialize_mimi_model()

# First process the main audio file
logger.info("Processing main audio file...")
process_audio_files()

# Then process the query audio with MIMI
logger.info("Processing query audio with MIMI embeddings...")
query_audio = "/mnt/data/ashwin/SpeechRAG/libriSQA/question_long_news.mp3"
query_embedding_path = "/mnt/data/ashwin/SpeechRAG/libriSQA/embeddings/question_embedding.pt"

# process_audio_files() only creates the embeddings directory when the split
# yields at least one chunk, so make sure the destination exists before saving.
Path(query_embedding_path).parent.mkdir(parents=True, exist_ok=True)

# Extract MIMI embeddings for query
embeddings = extract_mimi_embeddings(query_audio)
torch.save(embeddings, query_embedding_path)
logger.info(f"Query embeddings saved to {query_embedding_path}")

config.json:   0%|          | 0.00/1.12k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/385M [00:00<?, ?B/s]

  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


preprocessor_config.json:   0%|          | 0.00/234 [00:00<?, ?B/s]

INFO:__main__:Processing main audio file...
Extracting embeddings: 100%|██████████| 65/65 [00:32<00:00,  2.00it/s]
INFO:__main__:Processing completed successfully!
INFO:__main__:Processing query audio with MIMI embeddings...
INFO:__main__:Query embeddings saved to /mnt/data/ashwin/SpeechRAG/libriSQA/embeddings/question_embedding.pt
