In [1]:
import os

# Point Hugging Face Hub traffic at a mirror. This runs in its own cell, ahead
# of the cell that imports `transformers`.
# NOTE(review): presumably the hub client reads HF_ENDPOINT when it is first
# imported, so this must stay the first cell -- confirm.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

In [1]:
import csv
import json
import math
import os
import random
import time
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import DataLoader, Dataset
from transformers import SpeechT5HifiGan

# Pretrained SpeechT5 HiFi-GAN vocoder; the demo cells below use it to turn
# log-mel spectrograms back into a waveform.
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

In [31]:
# Shared mel front end for 16 kHz audio: 1024-sample FFT/window, 256-sample hop,
# 80 Slaney-scaled and Slaney-normalised mel bins covering 80 Hz - 7.6 kHz.
# NOTE(review): these values look chosen to mirror the SpeechT5 feature
# extractor used in a later cell (16k, 80 mel, fmin=80, fmax=7600, 16 ms hop,
# 64 ms window) -- confirm before changing any of them.
meltransformer = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,
    n_fft=1024,
    win_length=1024,
    hop_length=256,
    n_mels=80,
    f_min=80.0,
    f_max=7600.0,
    power=1.0,  # magnitude mel, then take log
    mel_scale="slaney",
    norm="slaney",
)


def to_logmelspec(waveform, sr):
    """Turn a waveform into a log10 mel spectrogram with time ahead of the mel axis.

    Audio that is not already sampled at 16 kHz is resampled first. The mel
    magnitudes produced by the module-level ``meltransformer`` are floored at
    1e-5 before the logarithm so that silent bins do not become ``-inf``.

    Args:
        waveform: audio tensor in a layout accepted by ``meltransformer``.
        sr: sampling rate of ``waveform`` in Hz.

    Returns:
        The log-mel tensor with its last two axes swapped relative to the
        output of ``meltransformer``.
    """
    target_sr = 16000
    if sr == target_sr:
        audio = waveform
    else:
        audio = torchaudio.functional.resample(waveform, sr, target_sr)
    floored = meltransformer(audio).clamp(min=1e-5)
    return floored.log10().transpose(-1, -2)


from glob import glob

from IPython.display import Audio

# Collect every .ogg clip below VOC_JP/ (searched recursively). glob() gives no
# ordering guarantee, so the result is sorted: this keeps the sample order --
# and therefore the seeded train/validation split further down -- identical
# between runs and machines.
audio_files = sorted(glob("VOC_JP/**/*.ogg", recursive=True))

# Round-trip sanity check kept for reference (not executed):
# waveform, sr = torchaudio.load(audio_files[0])
# mel = to_logmelspec(waveform, sr)
# reconstructed_waveform = vocoder(mel)
# Audio(reconstructed_waveform.detach().numpy(), rate=16000)

In [5]:
from torch.utils.data import Dataset


class LogMelDataset(Dataset):
    """Map-style dataset pairing per-utterance log-mel features with labels.

    The two collections are stored as given on the ``logmel`` and ``label``
    attributes (other cells read ``dataset.label`` directly), and item ``i`` is
    the pair ``(logmel[i], label[i])``.

    Args:
        logmel: sized, indexable collection of feature tensors.
        label: sized, indexable collection of labels aligned with ``logmel``.

    Raises:
        ValueError: if ``logmel`` and ``label`` differ in length. Without this
            check a mismatch only surfaces later as an IndexError, or silently
            leaves trailing labels unused.
    """

    def __init__(self, logmel, label):
        if len(logmel) != len(label):
            raise ValueError(
                "logmel and label must have the same length, "
                f"got {len(logmel)} and {len(label)}"
            )
        self.logmel = logmel
        self.label = label

    def __len__(self):
        """Return the number of utterances."""
        return len(self.logmel)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        return self.logmel[idx], self.label[idx]

In [None]:

from tqdm.auto import tqdm

# Each clip is labelled with the name of the directory that contains it.
audio_labels = [os.path.basename(os.path.dirname(f)) for f in audio_files]
label_vocab = sorted(set(audio_labels))
label_to_idx = {label: idx for idx, label in enumerate(label_vocab)}
idx_to_label = {idx: label for label, idx in label_to_idx.items()}

# Extract one float32 (frames, mel) log-mel matrix per clip. Sequences keep
# their natural length; padding is done per batch in collate_fn. (An earlier
# variant that padded everything into a single TensorDataset up front was
# dropped -- it referenced names that no longer exist in this notebook.)
logmel_sequences = []
for path in tqdm(audio_files):
    waveform, sr = torchaudio.load(path)
    waveform = waveform.mean(dim=0)  # average the channels -> 1-D mono signal
    logmel = to_logmelspec(waveform, sr).float()  # (frames, mel)
    logmel_sequences.append(logmel)

dataset = LogMelDataset(logmel_sequences, audio_labels)
# torch.save pickles the whole LogMelDataset object, so loading it back needs
# the class to be defined and weights_only=False.
torch.save(dataset, "dataset_slim.pt")

  0%|          | 0/19452 [00:00<?, ?it/s]

In [86]:
import torch, torchaudio
from transformers import SpeechT5HifiGan, SpeechT5FeatureExtractor

# 1) load audio (mono, 16 kHz)
wav, sr = torchaudio.load(audio_files[0])  # (C, T)
if wav.size(0) > 1:
    wav = wav[:1]  # keep only the first channel (the dataset cell averages channels instead)
if sr != 16000:
    wav = torchaudio.functional.resample(wav, sr, 16000)
wav = wav.squeeze(0).float()  # (T,)

# 2) exact SpeechT5 log-mel with HF extractor
fe = (
    SpeechT5FeatureExtractor()
)  # defaults: 16k, 80 mel, fmin=80, fmax=7600, hop=16ms, win=64ms
mel = fe(
    audio_target=wav.numpy(), sampling_rate=16000, return_tensors="pt"
).input_values
print(mel.shape)
# mel comes back already batched as (1, time, 80) -- see the printed shape in
# this cell's output -- so no unsqueeze is needed before vocoding.

# 3) vocode
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").eval()
with torch.no_grad():
    out = vocoder(mel)  # returns 1-D tensor if unbatched, (B,T) if batched
# mel is 3-D here, so the else-branch runs and drops the batch axis.
wave = out if mel.dim() == 2 else out.squeeze(0)
# torchaudio.save("recon.wav", wave.unsqueeze(0), 16000)

# NOTE(review): despite the "orig_16k" name this writes the *vocoded* audio,
# resampled from 16 kHz back to the clip's original rate `sr` (`sr` is never
# reassigned above). The comparison cell that follows saves to the same path,
# so whichever cell ran last wins -- confirm whether a distinct file name was
# intended.
torchaudio.save(
    "orig_16k.mp3", torchaudio.functional.resample(wave.unsqueeze(0), 16000, sr), sr
)
Audio(wave.detach().numpy(), rate=16000)

torch.Size([1, 203, 80])


In [79]:
# Listen to the first clip after resampling to 16 kHz and at its native rate.
wav, sr = torchaudio.load(audio_files[0])
wav16k = torchaudio.functional.resample(wav, sr, 16000)
from IPython.display import display

display(Audio(wav16k, rate=16000))
display(Audio(wav, rate=sr))
# save two versions for comparison
print(wav16k.shape, wav.shape)
# The 16 kHz signal is resampled back to `sr` before saving, so both files are
# written at the clip's original sampling rate.
torchaudio.save("orig_16k.mp3", torchaudio.functional.resample(wav16k, 16000, sr), sr)
torchaudio.save("orig_44.1k.mp3", wav, sr)

torch.Size([1, 51798]) torch.Size([1, 142767])


In [6]:
# Reload the dataset cached by the feature-extraction cell.
# weights_only=False enables full unpickling, which is what rebuilds the
# LogMelDataset object (its class must already be defined in this kernel).
# Unpickling can execute arbitrary code, so only load files you created yourself.
dataset = torch.load("dataset_slim.pt", weights_only=False)

In [11]:
from sklearn.model_selection import train_test_split

# Seeded 80/20 train/validation split (no stratification by label).
# NOTE(review): scikit-learn presumably indexes the Dataset item by item, which
# would make train_ds / val_ds plain lists of (logmel, label) pairs -- confirm.
train_ds, val_ds = train_test_split(dataset, test_size=0.2, random_state=42)

In [38]:
from torch.nn.utils.rnn import pad_sequence

# Label <-> integer id lookups. Ids follow the sorted order of the distinct
# label strings stored on the dataset, so they are stable across runs.
all_labels = sorted(set(dataset.label))
label_to_idx = dict(zip(all_labels, range(len(all_labels))))
idx_to_label = dict(enumerate(all_labels))


def collate_fn(batch):
    """Assemble a list of ``(features, label)`` pairs into padded batch tensors.

    Args:
        batch: sequence of ``(features, label)`` pairs, where ``features`` is a
            tensor whose first axis is time and ``label`` is a key of the
            module-level ``label_to_idx`` mapping.

    Returns:
        A 3-tuple of
        * the zero-padded features with the time axis moved last (B, F, T),
        * the integer class ids as an int64 tensor of shape (B,),
        * the unpadded frame count of every item, shape (B,).
    """
    sequences, names = zip(*batch)
    frame_counts = torch.tensor([seq.size(0) for seq in sequences])
    padded = pad_sequence(list(sequences), batch_first=True)  # (B, T, F)
    class_ids = torch.tensor(
        [label_to_idx[name] for name in names], dtype=torch.long
    )
    return padded.transpose(1, 2), class_ids, frame_counts


# Mini-batches of 16 utterances, padded per batch by collate_fn. Only the
# training loader shuffles.
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, collate_fn=collate_fn)

In [None]:
class XVector(nn.Module):
    """TDNN x-vector network: frame-level convolutions, masked statistics
    pooling over time, a segment-level MLP and a linear classifier.

    Args:
        feat_dim: number of input feature channels (mel bins).
        emb_dim: size of the utterance embedding.
        num_classes: number of output classes.
    """

    def __init__(self, feat_dim=80, emb_dim=256, num_classes=100):
        super().__init__()
        # Unpadded dilated convolutions: each layer shortens the time axis by
        # dilation * (kernel - 1) frames (4 + 4 + 6 = 14 in total).
        self.tdnn = nn.Sequential(
            nn.Conv1d(feat_dim, 512, 5, dilation=1),
            nn.ReLU(),
            nn.Conv1d(512, 512, 3, dilation=2),
            nn.ReLU(),
            nn.Conv1d(512, 512, 3, dilation=3),
            nn.ReLU(),
        )
        # 1024 = 512 channel means + 512 channel standard deviations.
        self.segment = nn.Sequential(
            nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, emb_dim)
        )
        self.classifier = nn.Linear(emb_dim, num_classes)

    def stats_pooling(self, x, lengths):
        """Masked mean / standard deviation over the time axis.

        Args:
            x: (B, D, T) float tensor.
            lengths: (B,) integer tensor, number of valid frames per item
                counted on the time axis of ``x``.

        Returns:
            (B, 2D) tensor ``[mean || std]`` computed over valid frames only.
        """
        T = x.size(2)
        EPS = 1e-8
        # (B, 1, T) mask: 1.0 for valid frames, 0.0 for padded ones.
        mask = (
            (torch.arange(T, device=x.device)[None, :] < lengths[:, None])
            .float()
            .unsqueeze(1)
        )

        # Number of valid frames, at least 1 to avoid division by zero.
        denom = mask.sum(dim=2, keepdim=True).clamp_min(1.0)  # (B,1,1)

        # First and second moments (masked).
        m1 = (x * mask).sum(dim=2) / denom.squeeze(2)  # (B, D)
        m2 = ((x**2) * mask).sum(dim=2) / denom.squeeze(2)  # (B, D)
        # var = E[x^2] - E[x]^2 ; clamp to avoid tiny negatives from FP error.
        var = (m2 - m1 * m1).clamp_min(0.0)
        std = torch.sqrt(var + EPS)  # (B, D)

        return torch.cat([m1, std], dim=1)  # (B, 2D)

    def forward(self, x, lengths):
        """Classify a padded batch.

        Args:
            x: (B, feat_dim, T) padded features.
            lengths: (B,) number of real (unpadded) input frames per item.

        Returns:
            ``(logits, emb)`` with shapes (B, num_classes) and (B, emb_dim).
        """
        h = self.tdnn(x)  # (B, 512, T - shrink)
        # The convolutions are unpadded, so the output is shorter than the
        # input and an output frame is only free of padding if its whole
        # receptive field lies inside the real signal. Shift the lengths by the
        # same amount before masking; otherwise padding-contaminated frames
        # leak into the statistics of every shorter utterance in the batch.
        shrink = x.size(-1) - h.size(-1)
        # Utterances shorter than the receptive field keep a single frame so
        # the pooled statistics stay defined.
        valid_frames = (lengths - shrink).clamp(min=1)
        pooled = self.stats_pooling(h, valid_frames)
        emb = self.segment(pooled)  # (B, emb_dim)
        logits = self.classifier(emb)  # (B, num_classes)
        return logits, emb


import math

import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

# --- Train quick ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model = XVector(feat_dim=80, emb_dim=256, num_classes=len(all_labels)).to(device)
opt = torch.optim.AdamW(model.parameters(), amsgrad=True)
crit = nn.CrossEntropyLoss()

writer = SummaryWriter()

for epoch in range(100):  # quick demo, increase for real use
    model.train()
    progbar = tqdm(train_loader, leave=False)
    progbar.set_description(f"Epoch {epoch+1}")
    # enumerate() supplies the batch index: tqdm only refreshes its own `.n`
    # counter lazily while iterating, so it is not a reliable global step.
    for step, (feats, labs, lengths) in enumerate(progbar):
        feats, labs, lengths = feats.to(device), labs.to(device), lengths.to(device)
        logits, emb = model(feats, lengths)
        loss = crit(logits, labs)
        loss_value = loss.item()
        # Abort before the optimizer step so a NaN loss cannot poison the weights.
        assert not math.isnan(loss_value)
        opt.zero_grad()
        loss.backward()
        opt.step()
        progbar.set_postfix(loss=loss_value)
        writer.add_scalar("Loss/train", loss_value, epoch * len(train_loader) + step)
    model.eval()
    with torch.no_grad():
        total, correct = 0, 0
        for feats, labs, lengths in tqdm(val_loader, leave=False):
            feats, labs, lengths = feats.to(device), labs.to(device), lengths.to(device)
            logits, emb = model(feats, lengths)
            preds = logits.argmax(dim=1)
            total += labs.size(0)
            correct += (preds == labs).sum().item()
        acc = correct / total
        print(f"Epoch {epoch+1} Validation accuracy: {acc*100:.2f}%")

# Flush pending TensorBoard events to disk.
writer.close()
torch.save(model.state_dict(), "xvector_params.pt")

In [None]:
# %% Encoder+Attention x-vector (Transformer + Attentive Stats Pooling + ArcFace)
import math, torch, torch.nn as nn
import torch.nn.functional as F


class PosEnc(nn.Module):
    """Fixed sinusoidal positional encoding added to ``(B, T, D)`` inputs.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``. The table is
    precomputed for ``max_len`` positions and stored as the buffer ``pe``.

    Args:
        d_model: feature dimension ``D``; may be even or odd.
        max_len: number of positions to precompute. Inputs longer than this
            cannot be encoded.
    """

    def __init__(self, d_model, max_len=20000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = pos * div  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so the surplus frequency is dropped instead of raising a
        # shape mismatch. For even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)  # (L, D)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``T`` positions.

        Args:
            x: tensor of shape (B, T, D).
        """
        T = x.size(1)
        return x + self.pe[:T].unsqueeze(0)  # type: ignore


class AttentiveStatsPooling(nn.Module):
    """Multi-head attentive statistics pooling.

    A linear layer scores every frame once per head; the scores are softmaxed
    over time (padded frames are pushed to -1e4 first) and used as weights for
    a per-head weighted mean and standard deviation of the input. The output
    is ``[mean_head0 .. mean_headH-1 || std_head0 .. std_headH-1]`` passed
    through batch normalisation.

    Args:
        d_in: feature dimension of the input frames.
        heads: number of attention heads.
    """

    def __init__(self, d_in, heads=4):
        super().__init__()
        self.heads = heads
        self.att = nn.Linear(d_in, heads)  # one score per frame and head
        self.bn = nn.BatchNorm1d(2 * heads * d_in)

    def forward(self, x, lengths):
        """Pool ``x`` (B, T, D) over time using ``lengths`` (B,) valid frames.

        Returns:
            Tensor of shape (B, 2 * heads * D).
        """
        _, n_frames, _ = x.shape
        frame_idx = torch.arange(n_frames, device=x.device)
        is_valid = frame_idx[None, :] < lengths[:, None]  # (B, T)

        scores = self.att(x).masked_fill(~is_valid.unsqueeze(-1), -1e4)
        weights = F.softmax(scores, dim=1)  # (B, T, H), normalised over time

        head_means = []
        head_stds = []
        for head in range(self.heads):
            head_w = weights.narrow(2, head, 1)  # (B, T, 1)
            norm = head_w.sum(dim=1, keepdim=True).clamp_min(1e-6)  # (B, 1, 1)
            mean = (x * head_w).sum(dim=1, keepdim=True) / norm  # (B, 1, D)
            variance = ((x - mean) ** 2 * head_w).sum(dim=1, keepdim=True) / norm
            # Floor the variance so the square root keeps a finite gradient.
            std = torch.sqrt(variance.clamp_min(1e-5))
            head_means.append(mean.squeeze(1))  # (B, D)
            head_stds.append(std.squeeze(1))  # (B, D)

        stats = torch.cat(
            [torch.cat(head_means, dim=1), torch.cat(head_stds, dim=1)], dim=1
        )  # (B, 2*H*D)
        return self.bn(stats)


class ArcMarginProduct(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    Embeddings and class weights are L2-normalised, so their product is the
    cosine of the angle between them. When labels are supplied, the angle to
    the target class is enlarged by ``m`` before the cosine is rescaled by
    ``s``; without labels the plain scaled cosines are returned.

    Args:
        in_features: embedding dimension.
        out_features: number of classes.
        s: scale applied to the returned logits.
        m: angular margin in radians added to the target-class angle.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.2):
        super().__init__()
        self.s = s
        self.m = m
        self.W = nn.Parameter(torch.randn(out_features, in_features))
        nn.init.xavier_uniform_(self.W)

    def forward(self, emb, labels=None):
        """Return scaled logits of shape (B, out_features).

        Args:
            emb: (B, in_features) embeddings.
            labels: optional (B,) int64 class ids; enables the margin.
        """
        unit_emb = F.normalize(emb, p=2, dim=1)
        unit_w = F.normalize(self.W, p=2, dim=1)
        cosine = F.linear(unit_emb, unit_w)  # cos(theta), (B, C)

        if labels is not None:
            # Keep acos away from +-1, where its derivative is unbounded.
            angle = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
            with_margin = torch.cos(angle + self.m)
            target = F.one_hot(labels, num_classes=unit_w.size(0)).float()
            # Margin cosine in the target column, plain cosine everywhere else.
            cosine = (1 - target) * cosine + target * with_margin

        return self.s * cosine


class EncAttnXVector(nn.Module):
    """Transformer-encoder x-vector with attentive statistics pooling and an
    ArcFace head.

    Args:
        feat_dim: number of input feature channels (mel bins).
        d_model: transformer width.
        nhead: self-attention heads per encoder layer.
        nlayers: number of encoder layers.
        emb_dim: size of the utterance embedding.
        num_classes: number of ArcFace classes.
        pool_heads: attention heads used by the pooling layer. It also fixes
            the width of the pooled vector (``2 * pool_heads * d_model``), so
            it is a single parameter rather than a literal repeated in two
            places. The default reproduces the previous hard-coded value.
    """

    def __init__(
        self,
        feat_dim=80,
        d_model=256,
        nhead=4,
        nlayers=4,
        emb_dim=256,
        num_classes=100,
        pool_heads=2,
    ):
        super().__init__()
        self.proj = nn.Linear(feat_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=nlayers)
        self.pos = PosEnc(d_model)
        self.pool = AttentiveStatsPooling(d_model, heads=pool_heads)
        # Pooling emits a mean and a std vector of width d_model per head.
        self.fc = nn.Sequential(
            nn.Linear(2 * pool_heads * d_model, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )
        self.arc = ArcMarginProduct(emb_dim, num_classes, s=30.0, m=0.2)

    def forward(self, x, lengths, labels=None):
        """Embed and classify a padded batch.

        Args:
            x: (B, feat_dim, T) padded features.
            lengths: (B,) number of valid frames per item.
            labels: optional (B,) class ids. When given, the ArcFace head
                applies its margin (training); otherwise it returns plain
                scaled cosine logits (inference).

        Returns:
            ``(logits, emb)`` with shapes (B, num_classes) and (B, emb_dim).
        """
        # x: (B,80,T) -> (B,T,80)
        x = x.transpose(1, 2)
        x = self.pos(self.proj(x))  # (B,T,d_model)
        # Key padding mask for the encoder: True marks padded frames.
        T = x.size(1)
        mask = torch.arange(T, device=x.device)[None, :] >= lengths[:, None]  # (B,T)
        h = self.encoder(x, src_key_padding_mask=mask)  # (B,T,d_model)
        pooled = self.pool(h, lengths)  # (B, 2*pool_heads*d_model)
        emb = self.fc(pooled)  # (B, emb_dim)
        logits = self.arc(emb, labels)  # (B, num_classes)
        return logits, emb


num_classes = len(label_to_idx)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EncAttnXVector(
    feat_dim=80, d_model=256, nhead=4, nlayers=3, emb_dim=256, num_classes=num_classes
).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
# The training cell below steps the scheduler once per epoch for 20 epochs, so
# T_max has to be 20 for a single cosine decay from 2e-4 towards 0. With the
# previous T_max=10 the learning rate reached 0 halfway through training and
# then climbed back up to its initial value.
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=20)
crit = nn.CrossEntropyLoss()


In [None]:
for epoch in range(20):
    # ---- training pass ----
    model.train()
    loss_sum = 0.0
    n_seen = 0
    for feats, labs, lengths in tqdm(train_loader, leave=False):
        feats = feats.to(device)
        labs = labs.to(device)
        lengths = lengths.to(device)
        # Labels are passed to the model so the ArcFace head applies its
        # margin; the returned logits are already scaled.
        logits, emb = model(feats, lengths, labels=labs)
        loss = crit(logits, labs)
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        opt.step()
        batch_size = feats.size(0)
        loss_sum += loss.item() * batch_size  # weight by batch size for a per-sample mean
        n_seen += batch_size
    sched.step()  # one scheduler step per epoch
    print(f"epoch {epoch+1} | loss {loss_sum/n_seen:.4f}")

    # ---- validation pass: no labels, so no margin is applied ----
    model.eval()
    n_total = 0
    n_correct = 0
    with torch.no_grad():
        for feats, labs, lengths in tqdm(val_loader, leave=False):
            feats = feats.to(device)
            labs = labs.to(device)
            lengths = lengths.to(device)
            logits, emb = model(feats, lengths)
            preds = logits.argmax(dim=1)
            n_total += labs.size(0)
            n_correct += (preds == labs).sum().item()
    acc = n_correct / n_total
    print(f"Epoch {epoch+1} Validation accuracy: {acc*100:.2f}%")

In [79]:
# Persist only the parameters (state_dict) of whatever is currently bound to
# `model`; the file name indicates this is meant to be the EncAttnXVector.
torch.save(model.state_dict(), "enc_attn_xvector_params.pt")