In [None]:
# Colab-only setup: mount Google Drive at /content/drive so files stored there
# can be read through the local filesystem by the cells below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Root folder of the DIHARD III dev data on the mounted Drive.
# main() reads the audio from "<data>/flac" and the annotations from "<data>/rttm".
data = "/content/drive/MyDrive/third_dihard_challenge_dev_LDC2022S12/third_dihard_challenge_dev/data"

In [None]:
!pip install speechbrain

Collecting speechbrain
  Downloading speechbrain-1.0.3-py3-none-any.whl.metadata (24 kB)
Collecting hyperpyyaml (from speechbrain)
  Downloading HyperPyYAML-1.2.2-py3-none-any.whl.metadata (7.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.9->speechbrain)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.9->speechbrain)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.9->speechbrain)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.9->speechbrain)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.9->speechbrain)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3

In [None]:
import torch
import torchaudio
import numpy as np
from speechbrain.pretrained import EncoderClassifier

class FeatureExtractor:
    """Sliding-window speaker-embedding extractor built on a pretrained SpeechBrain model.

    The model is loaded once in ``__init__``; create a single instance and reuse it
    for every recording rather than constructing one per file.
    """

    def __init__(self):
        # Load pretrained speaker embedding model (x-vector)
        self.embedding_model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-xvect-voxceleb",
            savedir="pretrained_models/spkrec-xvect-voxceleb"
        )

    def _embed(self, segment):
        """Return the embedding of one ``(1, samples)`` waveform segment as a NumPy array."""
        with torch.no_grad():
            embedding = self.embedding_model.encode_batch(segment)
        # squeeze() drops the singleton dimensions so one vector per segment remains.
        return embedding.squeeze().cpu().numpy()

    def extract_embeddings(self, audio_file, segment_length=1.5, hop_length=0.75):
        """Extract speaker embeddings from overlapping windows of an audio file.

        The audio is down-mixed to mono and resampled to 16 kHz, then cut into
        windows of ``segment_length`` seconds taken every ``hop_length`` seconds.
        A trailing remainder shorter than one full window is not embedded.

        A non-empty recording that is shorter than one window is zero-padded to a
        single window so that it still yields exactly one embedding (instead of
        none, which downstream code cannot handle).

        Args:
            audio_file: Path of the audio file, passed to ``torchaudio.load``.
            segment_length: Window length in seconds.
            hop_length: Distance between consecutive window starts in seconds.

        Returns:
            ``(embeddings, timestamps)`` where ``embeddings`` is a NumPy array with
            one row per window and ``timestamps`` is a list of
            ``(start_seconds, end_seconds)`` tuples in the same order. Both are
            empty when the file contains no samples.

        Raises:
            ValueError: If ``segment_length`` or ``hop_length`` is not long enough
                to cover at least one sample.
        """
        waveform, sample_rate = torchaudio.load(audio_file)

        # Convert to mono if needed
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000

        # Segment the audio
        segment_samples = int(segment_length * sample_rate)
        hop_samples = int(hop_length * sample_rate)
        if segment_samples <= 0 or hop_samples <= 0:
            raise ValueError(
                "segment_length and hop_length must each span at least one sample "
                f"(got segment_length={segment_length}, hop_length={hop_length})"
            )
        total_samples = waveform.shape[1]

        embeddings = []
        timestamps = []

        if 0 < total_samples < segment_samples:
            # Too short for a full window: right-pad with silence and embed once.
            # The timestamp reports the real extent of the audio, not the padding.
            padded = torch.nn.functional.pad(waveform, (0, segment_samples - total_samples))
            embeddings.append(self._embed(padded))
            timestamps.append((0.0, total_samples / sample_rate))
            return np.array(embeddings), timestamps

        for start in range(0, total_samples - segment_samples + 1, hop_samples):
            end = start + segment_samples
            embeddings.append(self._embed(waveform[:, start:end]))

            # Store timestamp
            timestamps.append((start / sample_rate, end / sample_rate))

        return np.array(embeddings), timestamps

DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for _speechbrain_save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for _speechbrain_load
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for load
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for _save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for _recover
  from speechbrain.pretrained import EncoderClassifier


In [None]:
import gym
from gym import spaces
import numpy as np

class SpeakerDiarizationEnv(gym.Env):
    """Sequential speaker-assignment environment over one recording's segment embeddings.

    Each step handles one segment. The agent either opens a new speaker (action 0)
    or assigns the segment to existing speaker ``k - 1`` (action ``k`` in
    ``1..max_speakers``). The only non-zero reward is given on the final step.

    This class follows the classic gym API: ``reset()`` returns only the
    observation and ``step()`` returns ``(obs, reward, done, info)``.
    """

    def __init__(self, embeddings, ground_truth_num_speakers=None):
        """
        Args:
            embeddings: Non-empty sequence/array of per-segment embedding vectors.
            ground_truth_num_speakers: True speaker count used for the final
                reward; when ``None`` a silhouette-based reward is used instead.

        Raises:
            ValueError: If ``embeddings`` is empty.
        """
        super(SpeakerDiarizationEnv, self).__init__()

        if len(embeddings) == 0:
            raise ValueError("SpeakerDiarizationEnv requires at least one segment embedding")

        self.embeddings = embeddings
        self.num_segments = len(embeddings)
        self.ground_truth_num_speakers = ground_truth_num_speakers
        self.max_speakers = 20  # Maximum number of speakers to consider

        # Define action space:
        # - Action 0: Add new speaker
        # - Action 1 to max_speakers: Assign to existing speaker
        self.action_space = spaces.Discrete(self.max_speakers + 1)

        # Define observation space:
        # - Current embedding
        # - Speaker centroids
        # - Number of segments per speaker so far
        embedding_dim = embeddings[0].shape[0]
        self.observation_space = spaces.Dict({
            'current_embedding': spaces.Box(low=-np.inf, high=np.inf, shape=(embedding_dim,)),
            'speaker_centroids': spaces.Box(low=-np.inf, high=np.inf, shape=(self.max_speakers, embedding_dim)),
            'speaker_counts': spaces.Box(low=0, high=self.num_segments, shape=(self.max_speakers,))
        })

        # Initialize state
        self.reset()

    def reset(self):
        """Clear all assignments and return the observation for the first segment."""
        self.current_segment = 0
        self.speaker_assignments = []
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.speaker_centroids = np.zeros((self.max_speakers, self.embeddings[0].shape[0]))
        self.speaker_counts = np.zeros(self.max_speakers)
        self.active_speakers = 0

        return self._get_observation()

    def _get_observation(self):
        """Build the observation dict for the current segment.

        Centroids and counts are returned as copies: ``step`` updates the internal
        arrays in place, so handing out the live arrays would silently rewrite any
        observation a caller has kept (e.g. in a replay buffer).
        """
        return {
            'current_embedding': self.embeddings[self.current_segment],
            'speaker_centroids': self.speaker_centroids.copy(),
            'speaker_counts': self.speaker_counts.copy()
        }

    def _calculate_reward(self):
        """Return the terminal reward, or 0 while segments remain to be processed."""
        if self.current_segment < self.num_segments:
            return 0  # Intermediate steps have no reward

        # Final reward based on correct number of speakers
        estimated_speakers = int(np.count_nonzero(self.speaker_counts))

        if self.ground_truth_num_speakers is not None:
            # If we know ground truth, use it for reward
            error = abs(estimated_speakers - self.ground_truth_num_speakers)
            return 10.0 if error == 0 else -error

        # Without ground truth, use cluster quality metrics
        # Simplified reward based on silhouette score
        from sklearn.metrics import silhouette_score

        # Create data for silhouette calculation (assigned segments only)
        data = []
        labels = []

        for spk_idx, segments in enumerate(self.speaker_embeddings):
            if len(segments) > 0:
                data.extend(segments)
                labels.extend([spk_idx] * len(segments))

        # silhouette_score is only defined for 2 <= n_clusters <= n_samples - 1;
        # outside that range (one cluster, or every segment its own cluster) it
        # raises, so fall back to the fixed penalty.
        n_clusters = len(set(labels))
        if 1 < n_clusters < len(labels):
            score = silhouette_score(np.array(data), np.array(labels))
            return score * 10  # Scale the score
        return -5  # Penalty for a degenerate clustering

    def step(self, action):
        """Apply ``action`` to the current segment and advance to the next one.

        Invalid actions (adding a speaker when all ``max_speakers`` slots are used,
        or assigning to a speaker slot that has no segments yet) leave the segment
        unassigned, recorded as ``-1`` in ``speaker_assignments``.

        Returns:
            ``(observation, reward, done, info)``. ``observation`` is ``None`` once
            the episode is done. ``info`` holds ``'num_speakers'`` (a Python int)
            and a copy of ``'speaker_counts'``.
        """
        # Process action
        if action == 0 and self.active_speakers < self.max_speakers:
            # Add new speaker
            new_speaker_id = self.active_speakers
            self.speaker_assignments.append(new_speaker_id)
            self.speaker_embeddings[new_speaker_id].append(self.embeddings[self.current_segment])
            self.speaker_centroids[new_speaker_id] = self.embeddings[self.current_segment]
            self.speaker_counts[new_speaker_id] += 1
            self.active_speakers += 1
        elif 1 <= action <= self.max_speakers and self.speaker_counts[action-1] > 0:
            # Assign to existing speaker
            speaker_id = action - 1
            self.speaker_assignments.append(speaker_id)
            self.speaker_embeddings[speaker_id].append(self.embeddings[self.current_segment])

            # Update centroid
            all_embeddings = np.array(self.speaker_embeddings[speaker_id])
            self.speaker_centroids[speaker_id] = np.mean(all_embeddings, axis=0)
            self.speaker_counts[speaker_id] += 1
        else:
            # Invalid action
            self.speaker_assignments.append(-1)  # Mark as unassigned

        # Move to next segment
        self.current_segment += 1
        done = self.current_segment >= self.num_segments

        # Calculate reward
        reward = self._calculate_reward() if done else 0

        # Get new observation
        new_obs = self._get_observation() if not done else None

        # Info dictionary. The speaker count is converted to a built-in int so that
        # callers can serialize it (NumPy integers are rejected by the json module).
        info = {
            'num_speakers': int(np.count_nonzero(self.speaker_counts)),
            'speaker_counts': self.speaker_counts.copy()
        }

        return new_obs, reward, done, info

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque

class DQNetwork(nn.Module):
    """Q-network that scores every action from three inputs.

    The inputs are the current segment embedding, one speaker centroid and that
    speaker's segment count. Each input is encoded by its own single-layer branch;
    the branch outputs are concatenated and passed through an MLP that emits
    ``max_speakers + 1`` Q-values.
    """

    def __init__(self, embedding_dim, max_speakers):
        super().__init__()

        # Input branches (created in this order so parameter initialisation is unchanged).
        self.embedding_processor = nn.Sequential(nn.Linear(embedding_dim, 128), nn.ReLU())
        self.centroid_processor = nn.Sequential(nn.Linear(embedding_dim, 128), nn.ReLU())
        self.count_processor = nn.Sequential(nn.Linear(1, 32), nn.ReLU())

        # Width of the concatenated branch outputs.
        fused_width = 128 + 128 + 32
        self.decision_network = nn.Sequential(
            nn.Linear(fused_width, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, max_speakers + 1),  # one Q-value per action
        )

    @staticmethod
    def _with_batch_dim(features):
        """Promote an unbatched 1-D feature vector to a batch of one."""
        return features.unsqueeze(0) if features.dim() == 1 else features

    def forward(self, embedding, centroid, count):
        """Return Q-values of shape ``[batch, max_speakers + 1]``.

        Accepted input layouts:
          * ``embedding``: ``[embedding_dim]`` or ``[batch, embedding_dim]``.
          * ``centroid``: same as ``embedding``; a 3-D
            ``[batch, num_centroids, embedding_dim]`` tensor is reduced to its
            first centroid per batch item.
          * ``count``: a 1-D tensor is read as one scalar per batch item and
            reshaped to ``[batch, 1]``.

        Unbatched inputs produce a batch of size one.
        """
        embedding_branch = self.embedding_processor(embedding)

        if centroid.dim() == 3:
            # Several centroids were supplied: keep only the first one per item.
            centroid = centroid[:, 0, :]
        centroid_branch = self.centroid_processor(centroid)

        if count.dim() == 1:
            count = count.unsqueeze(1)
        count_branch = self.count_processor(count)

        fused = torch.cat(
            [self._with_batch_dim(branch)
             for branch in (embedding_branch, centroid_branch, count_branch)],
            dim=1,
        )
        return self.decision_network(fused)

class DQNAgent:
    """Epsilon-greedy DQN agent with experience replay and a target network.

    The Q-network does not see the whole observation. Each state is reduced to
    three vectors: the current segment embedding, the centroid of the speaker with
    the most segments so far, and that speaker's segment count (zeros when no
    speaker exists yet).
    """

    def __init__(self, embedding_dim, max_speakers, learning_rate=0.001, gamma=0.99,
                 epsilon_start=1.0, epsilon_decay=0.995, epsilon_min=0.01,
                 target_update_interval=100):
        """
        Args:
            embedding_dim: Length of one segment embedding.
            max_speakers: Maximum number of speakers; the action space has
                ``max_speakers + 1`` actions.
            learning_rate: Adam learning rate.
            gamma: Discount factor.
            epsilon_start: Initial exploration probability.
            epsilon_decay: Factor applied to epsilon after every training step.
            epsilon_min: Lower bound for epsilon.
            target_update_interval: Number of training steps between target-network
                synchronisations.
        """
        self.embedding_dim = embedding_dim
        self.max_speakers = max_speakers
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        # Q-Network
        self.q_network = DQNetwork(embedding_dim, max_speakers)
        self.target_network = DQNetwork(embedding_dim, max_speakers)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)
        self.criterion = nn.MSELoss()

        # Experience replay
        self.memory = deque(maxlen=10000)
        self.batch_size = 64

        # Deterministic target-network schedule (sync every N training steps).
        self.target_update_interval = max(1, int(target_update_interval))
        self.train_steps = 0

    def _summarize_state(self, state):
        """Reduce an observation to the ``(embedding, centroid, count)`` network inputs.

        Returns three freshly allocated float32 arrays of shapes
        ``(embedding_dim,)``, ``(embedding_dim,)`` and ``(1,)``. A ``None`` state
        (the terminal "next state") maps to all zeros.
        """
        if state is None:
            return (np.zeros(self.embedding_dim, dtype=np.float32),
                    np.zeros(self.embedding_dim, dtype=np.float32),
                    np.zeros(1, dtype=np.float32))

        embedding = np.array(state['current_embedding'], dtype=np.float32)

        # Find the speaker with highest count
        speaker_counts = np.asarray(state['speaker_counts'])
        if np.sum(speaker_counts) > 0:
            best_speaker = int(np.argmax(speaker_counts))
            centroid = np.array(state['speaker_centroids'][best_speaker], dtype=np.float32)
            count = np.array([speaker_counts[best_speaker]], dtype=np.float32)
        else:
            # No speakers yet
            centroid = np.zeros(self.embedding_dim, dtype=np.float32)
            count = np.zeros(1, dtype=np.float32)

        return embedding, centroid, count

    @staticmethod
    def _stack_features(features):
        """Stack per-experience ``(embedding, centroid, count)`` tuples into batch tensors."""
        embeddings, centroids, counts = zip(*features)
        return (torch.from_numpy(np.stack(embeddings)),
                torch.from_numpy(np.stack(centroids)),
                torch.from_numpy(np.stack(counts)))

    def store_experience(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer.

        Only the compact network inputs are stored, as copies, so entries already
        in the buffer cannot change if the observation arrays are mutated later,
        and the buffer does not hold a full centroid matrix per transition.

        Each buffer entry is ``(state_features, action, reward, next_features, done)``.
        """
        self.memory.append((
            self._summarize_state(state),
            int(action),
            float(reward),
            self._summarize_state(next_state),
            bool(done),
        ))

    def select_action(self, state, current_speakers):
        """Choose an action with an epsilon-greedy policy restricted to valid actions.

        Valid actions are ``0`` (new speaker, unless ``max_speakers`` is reached)
        and ``1..current_speakers`` (assign to an existing speaker).

        Args:
            state: Observation dict from the environment.
            current_speakers: Number of speakers created so far.

        Returns:
            The chosen action as a Python int.
        """
        current_speakers = int(current_speakers)

        if current_speakers == 0:
            return 0  # Must add first speaker

        # Action 0 is unavailable once every speaker slot is in use.
        first_valid = 1 if current_speakers >= self.max_speakers else 0
        last_valid = min(current_speakers, self.max_speakers)

        # Epsilon-greedy action selection
        if random.random() < self.epsilon:
            return random.randint(first_valid, last_valid)

        embedding, centroid, count = (
            torch.from_numpy(features) for features in self._summarize_state(state)
        )

        self.q_network.eval()
        with torch.no_grad():
            # The network returns a [1, num_actions] batch for a single state;
            # flatten it so actions can be indexed directly.
            q_values = self.q_network(embedding, centroid, count).reshape(-1)

        valid_q = q_values[first_valid:last_valid + 1]
        return first_valid + int(torch.argmax(valid_q).item())

    def train(self):
        """Run one DQN update on a random minibatch; no-op until the buffer holds a full batch.

        NOTE: epsilon decays once per training step (not per episode), so with the
        default decay it reaches ``epsilon_min`` after roughly 900 updates.
        """
        if len(self.memory) < self.batch_size:
            return

        # Sample batch from memory
        batch = random.sample(self.memory, self.batch_size)
        state_features, actions, rewards, next_features, dones = zip(*batch)

        embeddings, centroids, counts = self._stack_features(state_features)
        next_embeddings, next_centroids, next_counts = self._stack_features(next_features)

        actions = torch.tensor(actions, dtype=torch.long)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)

        # Compute Q-values of the actions that were actually taken
        self.q_network.train()
        q_values = self.q_network(embeddings, centroids, counts)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Compute target Q-values; terminal transitions contribute only their reward
        self.target_network.eval()
        with torch.no_grad():
            next_q_values = self.target_network(next_embeddings, next_centroids, next_counts)
            max_next_q = torch.max(next_q_values, dim=1)[0]
            targets = rewards + (1 - dones) * self.gamma * max_next_q

        # Compute loss and optimize
        loss = self.criterion(q_values, targets)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        # Synchronise the target network on a fixed schedule so runs are reproducible
        self.train_steps += 1
        if self.train_steps % self.target_update_interval == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

In [None]:
def train_agent(agent, data_loader, num_episodes=1000, feature_extractor=None):
    """Train ``agent`` for ``num_episodes`` episodes, one randomly chosen recording each.

    Args:
        agent: Object exposing ``select_action``, ``store_experience``, ``train``
            and an ``epsilon`` attribute.
        data_loader: Object whose ``get_random_sample()`` returns
            ``(audio_path, num_speakers)``.
        num_episodes: Number of episodes to attempt.
        feature_extractor: Optional extractor with an ``extract_embeddings(path)``
            method. When omitted, one ``FeatureExtractor`` is created and reused
            for the whole run.

    Returns:
        ``(total_rewards, speaker_estimate_errors)``: one entry per completed
        episode. Recordings that yield no segments are skipped and add no entry.
    """
    total_rewards = []
    speaker_estimate_errors = []

    # Loading the pretrained model is expensive, so do it once, not once per episode.
    if feature_extractor is None:
        feature_extractor = FeatureExtractor()

    # Recordings are sampled with replacement; cache their embeddings so each file
    # is only processed the first time it is drawn. This trades memory for time.
    embedding_cache = {}

    for episode in range(num_episodes):
        # Get an audio file and its ground truth
        audio_file, num_speakers = data_loader.get_random_sample()

        # Extract features (or reuse the cached ones)
        if audio_file not in embedding_cache:
            embedding_cache[audio_file], _ = feature_extractor.extract_embeddings(audio_file)
        embeddings = embedding_cache[audio_file]

        if len(embeddings) == 0:
            # Nothing to diarize; the environment cannot be built from zero segments.
            print(f"Episode {episode}/{num_episodes}: skipping {audio_file} (no segments extracted)")
            continue

        # Create environment
        env = SpeakerDiarizationEnv(embeddings, ground_truth_num_speakers=num_speakers)

        # Reset environment
        state = env.reset()
        done = False
        episode_reward = 0

        while not done:
            # Get current number of active speakers
            current_speakers = int(np.count_nonzero(env.speaker_counts))

            # Select action
            action = agent.select_action(state, current_speakers)

            # Take action
            next_state, reward, done, info = env.step(action)

            # Store experience
            agent.store_experience(state, action, reward, next_state, done)

            # Train agent
            agent.train()

            # Update state
            state = next_state
            episode_reward += reward

        # Episode complete
        total_rewards.append(episode_reward)
        estimated_speakers = int(info['num_speakers'])
        error = abs(estimated_speakers - num_speakers)
        speaker_estimate_errors.append(error)

        # Print progress
        if episode % 10 == 0:
            print(f"Episode {episode}/{num_episodes}")
            print(f"  Reward: {episode_reward:.2f}")
            print(f"  Estimated speakers: {estimated_speakers}, True: {num_speakers}, Error: {error}")
            print(f"  Mean error (last 100): {np.mean(speaker_estimate_errors[-100:]):.2f}")
            print(f"  Epsilon: {agent.epsilon:.4f}")

    return total_rewards, speaker_estimate_errors

In [None]:
import os
import random
from pathlib import Path
import pandas as pd

class DIHARD3Loader:
    """Index of DIHARD III recordings: pairs each RTTM annotation with its FLAC audio."""

    def __init__(self, data_path, rttm_path):
        """
        Initialize the DIHARD III dataset loader

        Args:
            data_path: Path to the audio files
            rttm_path: Path to the RTTM files with speaker annotations
        """
        self.data_path = Path(data_path)
        self.rttm_path = Path(rttm_path)

        # Load and parse the dataset
        self.samples = self._load_dataset()

    @staticmethod
    def _count_speakers(rttm_file):
        """Return the number of distinct speaker IDs on the SPEAKER lines of an RTTM file.

        Blank lines and lines with fewer than eight fields are ignored.
        """
        speakers = set()
        with open(rttm_file, 'r') as f:
            for line in f:
                parts = line.split()
                if len(parts) > 7 and parts[0] == "SPEAKER":
                    speakers.add(parts[7])  # Speaker ID is in column 8
        return len(speakers)

    def _load_dataset(self):
        """Build the sample list from every ``*.rttm`` file that has a matching ``.flac``.

        Files are visited in sorted order so the sample list is the same on every run.
        """
        samples = []

        for rttm_file in sorted(self.rttm_path.glob("*.rttm")):
            file_id = rttm_file.stem
            audio_file = self.data_path / f"{file_id}.flac"

            if not audio_file.exists():
                continue  # annotation without audio

            samples.append({
                'file_id': file_id,
                'audio_path': str(audio_file),
                'rttm_path': str(rttm_file),
                'num_speakers': self._count_speakers(rttm_file)
            })

        print(f"Loaded {len(samples)} samples from DIHARD III dataset")
        return samples

    def get_random_sample(self):
        """Get a random sample from the dataset as ``(audio_path, num_speakers)``.

        Raises:
            IndexError: If no samples were found.
        """
        if not self.samples:
            raise IndexError(
                f"No samples loaded: no .rttm in {self.rttm_path} has a matching .flac in {self.data_path}"
            )
        sample = random.choice(self.samples)
        return sample['audio_path'], sample['num_speakers']

    def get_all_samples(self):
        """Get all samples from the dataset as a list of ``(audio_path, num_speakers)``."""
        return [(s['audio_path'], s['num_speakers']) for s in self.samples]

In [None]:
def evaluate_agent(agent, data_loader, num_samples=100, feature_extractor=None):
    """Evaluate the agent's speaker-count estimate with a greedy (no exploration) policy.

    Args:
        agent: Object exposing ``select_action`` and an ``epsilon`` attribute.
        data_loader: Object whose ``get_all_samples()`` returns a list of
            ``(audio_path, num_speakers)`` pairs.
        num_samples: Maximum number of recordings to evaluate; a random subset is
            drawn when more are available.
        feature_extractor: Optional extractor with an ``extract_embeddings(path)``
            method. When omitted, one ``FeatureExtractor`` is created and reused.

    Returns:
        A list with one dict per evaluated recording, with keys ``'file_id'``,
        ``'true_speakers'``, ``'estimated_speakers'`` and ``'error'``. All values
        are built-in Python types, so the list can be written with ``json.dump``.
    """
    samples = data_loader.get_all_samples()
    if num_samples < len(samples):
        samples = random.sample(samples, num_samples)

    # Load the pretrained model once for the whole evaluation.
    if feature_extractor is None:
        feature_extractor = FeatureExtractor()

    results = []

    # Disable exploration for the evaluation only, and restore the agent's epsilon
    # afterwards (even on error) so that later training is not silently affected.
    saved_epsilon = agent.epsilon
    agent.epsilon = 0
    try:
        for i, (audio_file, num_speakers) in enumerate(samples):
            print(f"Evaluating sample {i+1}/{len(samples)}: {os.path.basename(audio_file)}")

            # Extract features
            embeddings, timestamps = feature_extractor.extract_embeddings(audio_file)
            if len(embeddings) == 0:
                print("  Skipped: no segments extracted")
                continue

            # Create environment
            env = SpeakerDiarizationEnv(embeddings, ground_truth_num_speakers=num_speakers)

            # Reset environment
            state = env.reset()
            done = False

            while not done:
                # Get current number of active speakers
                current_speakers = int(np.count_nonzero(env.speaker_counts))

                # Select action (greedy, because epsilon is 0 here)
                action = agent.select_action(state, current_speakers)

                # Take action
                next_state, reward, done, info = env.step(action)

                # Update state
                state = next_state

            # Record results as built-in ints (NumPy integers are not JSON serializable)
            estimated_speakers = int(info['num_speakers'])
            num_speakers = int(num_speakers)
            error = abs(estimated_speakers - num_speakers)

            results.append({
                'file_id': os.path.basename(audio_file),
                'true_speakers': num_speakers,
                'estimated_speakers': estimated_speakers,
                'error': error
            })

            print(f"  True: {num_speakers}, Estimated: {estimated_speakers}, Error: {error}")
    finally:
        agent.epsilon = saved_epsilon

    if not results:
        print("Evaluation results: no samples were evaluated")
        return results

    # Calculate metrics
    errors = [r['error'] for r in results]
    mean_error = np.mean(errors)
    exact_match = np.mean([r['error'] == 0 for r in results])

    print(f"Evaluation results:")
    print(f"  Mean absolute error: {mean_error:.2f}")
    print(f"  Exact match accuracy: {exact_match:.2f}")

    return results

In [None]:
def main():
    """Train the DQN agent on the DIHARD III data, then plot, save and evaluate it.

    Reads the dataset root from the notebook-level ``data`` variable and writes
    ``training_progress.png``, ``speaker_diarization_rl_model.pt`` and
    ``evaluation_results.json`` to the current working directory.
    """
    # Setup paths
    dihard_data_path = f"{data}/flac"
    dihard_rttm_path = f"{data}/rttm"

    # Initialize dataset loader
    data_loader = DIHARD3Loader(dihard_data_path, dihard_rttm_path)

    # Initialize agent
    embedding_dim = 512  # Dimension of x-vectors
    max_speakers = 20
    agent = DQNAgent(embedding_dim, max_speakers)

    # Train the agent
    rewards, errors = train_agent(agent, data_loader, num_episodes=1000)

    # Save the model first so a plotting problem cannot cost the trained weights
    torch.save(agent.q_network.state_dict(), 'speaker_diarization_rl_model.pt')

    # Plot training progress
    import matplotlib.pyplot as plt

    fig, (reward_ax, error_ax) = plt.subplots(1, 2, figsize=(12, 5))

    reward_ax.plot(rewards)
    reward_ax.set_title('Episode Rewards')
    reward_ax.set_xlabel('Episode')
    reward_ax.set_ylabel('Reward')

    # Moving average of the error. The window is capped at the number of episodes:
    # with mode='valid' a kernel longer than the data does not give a moving average.
    window = min(100, len(errors))
    if window > 0:
        error_ax.plot(np.convolve(errors, np.ones(window) / window, mode='valid'))
    error_ax.set_title(f'Mean Absolute Error ({window}-episode window)')
    error_ax.set_xlabel('Episode')
    error_ax.set_ylabel('Error')

    fig.tight_layout()
    fig.savefig('training_progress.png')
    plt.close(fig)

    # Evaluate the agent
    results = evaluate_agent(agent, data_loader, num_samples=50)

    # Save results
    import json

    def to_builtin(value):
        """json fallback: convert NumPy scalars (e.g. np.int64) to Python scalars."""
        if isinstance(value, np.generic):
            return value.item()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    with open('evaluation_results.json', 'w') as f:
        json.dump(results, f, indent=2, default=to_builtin)

if __name__ == "__main__":
    main()