In [6]:
import json
import random
import numpy as np
import torch
import torchaudio
from pathlib import Path
from datasets import load_dataset, Dataset
from tqdm.notebook import tqdm
import requests
import io
import zipfile
from collections import defaultdict

def get_hidden_states_and_audio(
    text: str,
    speaker_id: str,
    timeout: float = 300.0,
) -> tuple[np.ndarray, torch.Tensor, dict]:
    """
    Get hidden states and audio for a piece of text using specified speaker.

    Posts the text to the speech service and unpacks the zip archive it
    returns (``hidden_states.npy``, ``audio.wav``, ``metadata.json``) entirely
    in memory.

    Args:
        text: Text to synthesize.
        speaker_id: Speaker identifier forwarded to the service.
        timeout: Seconds ``requests`` waits for the server to connect / send
            data before raising. Pass ``None`` to wait indefinitely (the
            previous behaviour).

    Returns:
        hidden_states: numpy array; the service is expected to return shape
            [n_frames, hidden_dim] (not verified here)
        audio: torch tensor as loaded by torchaudio, i.e. [channels, samples]
        metadata: dict parsed from ``metadata.json``

    Raises:
        RuntimeError: if the service answers with a non-200 status, or the
            returned audio is not sampled at 44100 Hz.
        requests.exceptions.RequestException: on connection errors/timeouts.
    """
    # Without a timeout a stalled server would block the caller forever.
    response = requests.post(
        "http://melchior:5000/v1/audio/speech/hidden",
        json={
            "text": text,
            "speaker_id": speaker_id,
            "return_audio": True
        },
        timeout=timeout,
    )

    if response.status_code != 200:
        raise RuntimeError(f"Request failed with status {response.status_code}: {response.text}")

    # Read zip contents in memory
    with zipfile.ZipFile(io.BytesIO(response.content)) as zf:
        # np.load / torchaudio.load need a seekable file object, so each
        # member is read fully and wrapped in BytesIO.
        hidden_states = np.load(io.BytesIO(zf.read('hidden_states.npy')))
        audio, sr = torchaudio.load(io.BytesIO(zf.read('audio.wav')))
        metadata = json.loads(zf.read('metadata.json'))

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if sr != 44100:
        raise RuntimeError(f"Expected 44100 Hz audio, got {sr} Hz")

    return hidden_states, audio, metadata


In [8]:
def rms_silence(audio: torch.Tensor, frame_rate=21.535, threshold=0.005, sample_rate=44100):
    """Dead simple RMS silence detection.

    Splits the waveform into consecutive, non-overlapping windows of
    ``int(sample_rate / frame_rate)`` samples and flags each window whose RMS
    energy is below ``threshold``.

    Args:
        audio: 1-D waveform, or a multi-dimensional tensor whose first entry
            along dim 0 is used (for [channels, samples] input this is the
            first channel; channels are not averaged).
        frame_rate: Frames per second used to derive the window length.
        threshold: RMS level below which a frame counts as silence.
        sample_rate: Sample rate of ``audio`` in Hz. Defaults to 44100, the
            value that was previously hard-coded.

    Returns:
        Bool tensor with one entry per complete window (True = silent).
        Trailing samples that do not fill a whole window are dropped.
        Audio shorter than one window is not handled specially here.
    """
    # Truncation (not rounding) is kept so frame counts match earlier runs.
    samples_per_frame = int(sample_rate / frame_rate)

    # Keep only the first channel if a channel dimension is present
    if audio.dim() > 1:
        audio = audio[0]

    # Compute RMS energy per frame (step == size, so windows do not overlap)
    windows = audio.unfold(0, samples_per_frame, samples_per_frame)
    energy = torch.sqrt(torch.mean(windows ** 2, dim=1))

    return energy < threshold


In [7]:
def get_language_buffer(lang_code, buffer_size=1000):
    """Collect the leading samples of one language's Emilia stream into a list.

    Every call opens a new streaming dataset for ``<LANG_CODE>/*.tar`` and
    reads from its start; no shuffling is applied here, so repeated calls
    presumably return the same leading samples (verify against the dataset).

    Args:
        lang_code: Language code; upper-cased for the tar path and
            lower-cased for the split name.
        buffer_size: Number of samples to collect. Reading stops once this
            many have been gathered; at least one sample is taken whenever
            the stream yields anything.

    Returns:
        List of sample dicts (shorter than ``buffer_size`` if the stream ends
        early).
    """
    split_name = lang_code.lower()
    stream = load_dataset(
        "amphion/Emilia-Dataset",
        data_files={split_name: f"{lang_code.upper()}/*.tar"},
        split=split_name,
        streaming=True
    )

    collected = []
    for count, item in enumerate(stream, start=1):
        collected.append(item)
        # Stop pulling from the stream as soon as the buffer is full.
        if count >= buffer_size:
            break

    return collected

# Column order used when converting collected records to a HF dataset.
_SAMPLE_FIELDS = ('hidden_states', 'is_silence', 'language', 'n_frames', 'speaker_id', 'text')


def _load_speakers_by_language(index_path):
    """Read the voice index and group speaker ids by the prefix before '_'."""
    with open(index_path) as f:
        speakers = json.load(f)['speakers']

    grouped = defaultdict(list)
    for speaker_id in speakers:
        if speaker_id != 'default':
            grouped[speaker_id.split('_')[0]].append(speaker_id)
    return grouped


def _records_to_columns(records):
    """Turn a list of per-sample dicts into the column dict Dataset.from_dict expects."""
    return {field: [record[field] for record in records] for field in _SAMPLE_FIELDS}


def create_silence_dataset(threshold=0.005, n_train=500, n_test=250, n_val=100, seed=42,
                           buffer_size=1000, output_dir='./silence_dataset_prod'):
    """Create silence detection dataset balanced across languages.

    For each of EN / JA / ZH, draws random texts from a buffer of source
    samples, synthesizes them with a random speaker of that language via
    ``get_hidden_states_and_audio``, labels frames with ``rms_silence`` and
    keeps going until ``n_train + n_test + n_val`` samples succeeded. All
    samples are then shuffled together and cut into train/test/val, so the
    totals are balanced per language but the individual splits are only
    balanced in expectation.

    Args:
        threshold: RMS threshold passed to ``rms_silence``.
        n_train, n_test, n_val: Per-language sample counts for each split.
        seed: Seed for ``random`` and ``np.random``.
        buffer_size: Number of source samples fetched per buffer fill.
        output_dir: Directory the splits and ``metadata.json`` are written to.

    Returns:
        ``Path`` of the output directory.

    Raises:
        ValueError: if the voice index has no speaker for a required language.
        RuntimeError: if no source samples can be fetched for a language.
    """
    random.seed(seed)
    np.random.seed(seed)

    # Load and group speakers by language
    lang_speakers = _load_speakers_by_language('./voices/index.json')

    per_language = n_train + n_test + n_val
    samples_per_lang = {lang: per_language for lang in ('EN', 'JA', 'ZH')}
    n_langs = len(samples_per_lang)

    # Fail fast with a clear message instead of an IndexError from
    # random.choice([]) after data has already been fetched.
    missing = [lang for lang in samples_per_lang if not lang_speakers.get(lang)]
    if missing:
        raise ValueError(f"No speakers found in voice index for languages: {missing}")

    all_data = []

    for lang, n_samples in tqdm(samples_per_lang.items(), desc="Languages"):
        # Get a buffer of samples for this language
        buffer = get_language_buffer(lang.lower(), buffer_size)

        samples_collected = 0
        with tqdm(total=n_samples, desc=f"{lang} samples") as pbar:
            while samples_collected < n_samples:
                if not buffer:  # If buffer empty, refill it
                    # NOTE(review): a refill re-reads the stream from its
                    # start, so texts can repeat once the first buffer is
                    # exhausted; if every request keeps failing this loop
                    # never terminates.
                    buffer = get_language_buffer(lang.lower(), buffer_size)
                    if not buffer:
                        raise RuntimeError(f"No source samples available for language {lang}")

                # Pick a uniformly random sample and remove exactly that
                # element (list.remove would scan for the first *equal* dict).
                sample = buffer.pop(random.randrange(len(buffer)))

                # Get a random speaker for this language
                speaker_id = random.choice(lang_speakers[lang])

                try:
                    text = sample['json']['text']
                    hidden, audio, meta = get_hidden_states_and_audio(text, speaker_id)
                    silence_mask = rms_silence(audio, threshold=threshold)

                    all_data.append({
                        'hidden_states': hidden,
                        'is_silence': silence_mask.numpy(),
                        'language': lang,
                        'n_frames': len(silence_mask),
                        'speaker_id': speaker_id,
                        'text': text
                    })

                    samples_collected += 1
                    pbar.update(1)

                except Exception as e:
                    # Best effort: a failed sample is reported and replaced
                    # by another draw on the next iteration.
                    print(f"Error processing sample: {e}")

    # Split into train/test/val; per-language counts times number of languages
    random.shuffle(all_data)
    train_end = n_train * n_langs
    test_end = train_end + n_test * n_langs
    splits = {
        'train': all_data[:train_end],
        'test': all_data[train_end:test_end],
        'val': all_data[test_end:]
    }

    # Convert to HF datasets
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    for split_name, split_data in splits.items():
        ds = Dataset.from_dict(_records_to_columns(split_data))
        ds.save_to_disk(output_path / split_name)

    # Save metadata
    with open(output_path / 'metadata.json', 'w') as f:
        json.dump({
            'threshold': threshold,
            'n_train': n_train,
            'n_test': n_test,
            'n_val': n_val,
            'seed': seed
        }, f)

    return output_path


In [12]:
# Build a small dataset: 100 train / 50 test / 10 val samples per language,
# which create_silence_dataset multiplies by its 3 languages (300 / 150 / 30).
# NOTE(review): the recorded output of this cell reports "silence_dataset",
# but create_silence_dataset currently saves under './silence_dataset_prod';
# the output appears to predate that change — re-run to refresh it.
dataset_path = create_silence_dataset(
    threshold=0.005,
    n_train=100,
    n_test=50,
    n_val=10,
    seed=42,
    buffer_size=1000
)
print(f"Dataset saved to {dataset_path}")

Languages:   0%|          | 0/3 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1140 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1140 [00:00<?, ?it/s]

EN samples:   0%|          | 0/160 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/70 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/70 [00:00<?, ?it/s]

JA samples:   0%|          | 0/160 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/920 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/920 [00:00<?, ?it/s]

ZH samples:   0%|          | 0/160 [00:00<?, ?it/s]

Saving the dataset (0/1 shards):   0%|          | 0/300 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/150 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/30 [00:00<?, ? examples/s]

Dataset saved to silence_dataset


In [14]:
def analyze_balance_and_metrics(dataset_path, model):
    """Print silence class balance and validation metrics for a silence probe.

    Loads the 'train' and 'val' splits saved under ``dataset_path``, prints
    the share of silent frames in each, then runs ``model`` on every
    validation frame and prints precision / recall / F1 and error counts for
    several decision thresholds.

    Args:
        dataset_path: Directory containing 'train' and 'val' datasets with
            'hidden_states' and 'is_silence' columns.
        model: Callable mapping a float tensor of stacked frames to one score
            per frame (assumed to live on CPU, since the input tensor is
            created on CPU — confirm against the training cell).

    Raises:
        ValueError: if the model does not return exactly one score per frame.

    Ratios with a zero denominator are reported as ``nan`` instead of raising
    or emitting numpy warnings.
    """
    def safe_ratio(numerator, denominator):
        # nan marks "undefined" (e.g. precision when nothing was predicted).
        return float(numerator) / float(denominator) if denominator else float('nan')

    train_ds = load_from_disk(str(Path(dataset_path) / 'train'))
    val_ds = load_from_disk(str(Path(dataset_path) / 'val'))

    # Get class balance
    def count_labels(ds):
        labels = ds['is_silence']  # read the column once
        silence_counts = sum(int(np.sum(s)) for s in labels)
        total_frames = sum(len(s) for s in labels)
        return silence_counts, total_frames

    train_silence, train_total = count_labels(train_ds)
    val_silence, val_total = count_labels(val_ds)

    print("Class balance:")
    print(f"Train: {train_silence}/{train_total} silent frames ({100*safe_ratio(train_silence, train_total):.1f}%)")
    print(f"Val: {val_silence}/{val_total} silent frames ({100*safe_ratio(val_silence, val_total):.1f}%)")

    # Detailed metrics on validation set: one row / label per frame
    X_val = np.vstack([h for sample in val_ds['hidden_states'] for h in sample])
    y_val = np.hstack([s for sample in val_ds['is_silence'] for s in sample])
    y_true = y_val.astype(bool)

    # Run inference in eval mode (affects dropout / batch-norm layers) and
    # restore the caller's training mode afterwards.
    was_training = isinstance(model, torch.nn.Module) and model.training
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            val_pred = model(torch.FloatTensor(X_val)).squeeze().numpy()
    finally:
        if was_training:
            model.train()

    # squeeze() yields a 0-d array for a single frame; also guard against
    # shapes that would otherwise broadcast silently against the labels.
    val_pred = np.atleast_1d(val_pred)
    if val_pred.shape != y_true.shape:
        raise ValueError(
            f"Model returned predictions of shape {val_pred.shape}, "
            f"expected one score per frame {y_true.shape}"
        )

    # Check different thresholds
    print("\nMetrics at different thresholds:")
    for threshold in [0.3, 0.4, 0.5, 0.6, 0.7]:
        pred_bool = val_pred > threshold
        true_pos = int((pred_bool & y_true).sum())
        false_pos = int((pred_bool & ~y_true).sum())
        false_neg = int((~pred_bool & y_true).sum())

        precision = safe_ratio(true_pos, true_pos + false_pos)
        recall = safe_ratio(true_pos, true_pos + false_neg)
        f1 = safe_ratio(2 * (precision * recall), precision + recall)

        print(f"\nThreshold {threshold}:")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1: {f1:.4f}")
        print(f"False positives: {false_pos}")
        print(f"False negatives: {false_neg}")

# Check it
# NOTE(review): this reads './silence_dataset', while the dataset-creation
# cell above saves under './silence_dataset_prod'. `model` and
# `load_from_disk` are not defined in the cells shown here — presumably they
# come from cells not visible; confirm this survives Restart & Run All.
analyze_balance_and_metrics('./silence_dataset', model)

Class balance:
Train: 10062/57054 silent frames (17.6%)
Val: 1253/6347 silent frames (19.7%)

Metrics at different thresholds:

Threshold 0.3:
Precision: 0.9332
Recall: 0.9473
F1: 0.9402
False positives: 85
False negatives: 66

Threshold 0.4:
Precision: 0.9387
Recall: 0.9409
F1: 0.9398
False positives: 77
False negatives: 74

Threshold 0.5:
Precision: 0.9436
Recall: 0.9346
F1: 0.9391
False positives: 70
False negatives: 82

Threshold 0.6:
Precision: 0.9472
Recall: 0.9314
F1: 0.9392
False positives: 65
False negatives: 86

Threshold 0.7:
Precision: 0.9507
Recall: 0.9234
F1: 0.9368
False positives: 60
False negatives: 96


In [45]:
def train_language_probe(dataset_path, batch_size=256, lr=1e-3, epochs=20, device="cuda"):
    """Train a frame-level language classifier on stored hidden states.

    Every frame of every sample becomes one training example labelled with
    its sample's language (EN / JA / ZH). After each epoch the train and
    validation loss / accuracy are printed; every 5th epoch a row-normalized
    confusion matrix over the validation frames is printed as well.

    Args:
        dataset_path: Directory containing 'train' and 'val' datasets with
            'hidden_states' and 'language' columns.
        batch_size: Mini-batch size for training.
        lr: Adam learning rate.
        epochs: Number of passes over the training frames.
        device: Torch device holding the data and the model.

    Returns:
        The trained ``LanguageProbe`` model.
    """
    train_ds = load_from_disk(str(Path(dataset_path) / 'train'))
    val_ds = load_from_disk(str(Path(dataset_path) / 'val'))

    # Create language mapping
    lang_to_idx = {'EN': 0, 'JA': 1, 'ZH': 2}
    n_classes = len(lang_to_idx)

    def prepare_data(ds):
        hidden_states = []
        languages = []
        for sample in ds:
            frames = sample['hidden_states']
            # Keep all frames
            hidden_states.append(frames)
            # Repeat language label for each frame
            languages.extend([lang_to_idx[sample['language']]] * len(frames))

        hidden_states = np.vstack(hidden_states)  # Stack all frames
        return (torch.FloatTensor(hidden_states).to(device),
                torch.LongTensor(languages).to(device))

    X_train, y_train = prepare_data(train_ds)
    X_val, y_val = prepare_data(val_ds)

    print(f"Training on {len(X_train)} frames from {len(train_ds)} samples")
    print(f"Validating on {len(X_val)} frames from {len(val_ds)} samples")

    train_loader = DataLoader(
        TensorDataset(X_train, y_train),
        batch_size=batch_size,
        shuffle=True
    )

    model = LanguageProbe().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        train_correct = 0

        for batch_X, batch_y in train_loader:
            optimizer.zero_grad()
            pred = model(batch_X)
            loss = criterion(pred, batch_y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_correct += (pred.argmax(1) == batch_y).sum().item()

        # Mean of per-batch losses (the last, smaller batch weighs the same)
        train_loss /= len(train_loader)
        train_acc = train_correct / len(X_train)

        # Validation: the whole validation set goes through in one forward pass
        model.eval()
        with torch.no_grad():
            val_pred = model(X_val)
            val_loss = criterion(val_pred, y_val).item()
            preds = val_pred.argmax(1)
            val_correct = (preds == y_val).sum().item()
            val_acc = val_correct / len(X_val)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train loss: {train_loss:.4f}, Train acc: {train_acc:.4f}")
        print(f"Val loss: {val_loss:.4f}, Val acc: {val_acc:.4f}")

        if (epoch + 1) % 5 == 0:
            # Frame-level confusion matrix, built only when it is printed.
            # Encoding (true, predicted) as true * n_classes + predicted lets
            # one bincount replace a Python loop over every frame.
            pair_ids = y_val * n_classes + preds
            counts = torch.bincount(pair_ids, minlength=n_classes * n_classes)
            conf_matrix = counts.reshape(n_classes, n_classes).float().cpu()
            # Normalize by true class counts; clamp so a class that is absent
            # from the validation set gives a zero row instead of NaNs.
            conf_matrix = conf_matrix / conf_matrix.sum(dim=1, keepdim=True).clamp(min=1)

            print("\nNormalized Confusion Matrix:")
            print("    EN  JA  ZH")
            for lang, i in lang_to_idx.items():
                print(f"{lang}: {conf_matrix[i].tolist()}")
        print()

    return model

# Let's see how it does frame-by-frame!
# NOTE(review): trains on './silence_dataset', while the dataset-creation cell
# above saves under './silence_dataset_prod'. `LanguageProbe`, `nn`,
# `DataLoader`, `TensorDataset` and `load_from_disk` are not defined in the
# cells shown here — presumably they come from cells not visible; confirm.
language_model = train_language_probe('./silence_dataset')

Training on 57054 frames from 300 samples
Validating on 6347 frames from 30 samples
Epoch 1/20
Train loss: 0.7218, Train acc: 0.8280
Val loss: 0.5939, Val acc: 0.9570

Epoch 2/20
Train loss: 0.5818, Train acc: 0.9689
Val loss: 0.5883, Val acc: 0.9627

Epoch 3/20
Train loss: 0.5680, Train acc: 0.9832
Val loss: 0.5806, Val acc: 0.9701

Epoch 4/20
Train loss: 0.5654, Train acc: 0.9858
Val loss: 0.5723, Val acc: 0.9786

Epoch 5/20
Train loss: 0.5611, Train acc: 0.9902
Val loss: 0.5788, Val acc: 0.9720

Normalized Confusion Matrix:
    EN  JA  ZH
EN: [0.9725330471992493, 0.016954900696873665, 0.010512038134038448]
JA: [0.03423607721924782, 0.9519672989845276, 0.013796627521514893]
ZH: [0.0, 0.002081887563690543, 0.9979181289672852]

Epoch 6/20
Train loss: 0.5602, Train acc: 0.9909
Val loss: 0.5782, Val acc: 0.9726

Epoch 7/20
Train loss: 0.5574, Train acc: 0.9940
Val loss: 0.5784, Val acc: 0.9726

Epoch 8/20
Train loss: 0.5570, Train acc: 0.9943
Val loss: 0.5659, Val acc: 0.9849

Epoch 9/20