In [None]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from google.colab import drive
from scipy.stats import entropy
from sklearn.metrics import roc_curve
from sklearn.svm import LinearSVC
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import WavLMModel, Wav2Vec2FeatureExtractor

# Mount Google Drive at /content/drive; the dataset directory and the saved
# checkpoints used further down all live under /content/drive/MyDrive.
drive.mount('/content/drive')

class AudioDataset(Dataset):
    """Dataset that loads audio files and returns model-ready input values.

    Args:
        audio_files: Sequence of paths readable by ``torchaudio.load``.
        labels: Sequence (or tensor) of labels, indexable in parallel with
            ``audio_files``.
        processor: Callable feature extractor; it is invoked with a 1-D numpy
            waveform and must return a mapping with an ``"input_values"``
            tensor whose leading dimension is the batch dimension.
    """

    def __init__(self, audio_files, labels, processor):
        self.audio_files = audio_files
        self.labels = labels
        self.processor = processor

    def __len__(self):
        return len(self.audio_files)

    @staticmethod
    def _prepare_waveform(waveform, sr, target_sr=16000):
        """Return ``waveform`` as a 1-D mono signal sampled at ``target_sr``."""
        if sr != target_sr:
            # Only build a resampler when the rates actually differ.
            waveform = torchaudio.transforms.Resample(sr, target_sr)(waveform)
        if waveform.dim() == 2:
            # Average the channels. A plain squeeze(0) would leave stereo
            # input 2-D, and the processor would then see it as a batch.
            waveform = waveform.mean(dim=0)
        return waveform

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(input_values, label)``."""
        waveform, sr = torchaudio.load(self.audio_files[idx])
        waveform = self._prepare_waveform(waveform, sr)
        if not torch.isfinite(waveform).all():
            print(f"Invalid waveform at index {idx}")
            # Zero out the bad samples so one corrupt clip cannot turn the
            # whole batch (and then the weights) into NaN.
            waveform = torch.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)
        inputs = self.processor(
            waveform.numpy(), sampling_rate=16000, return_tensors="pt", padding=True
        )
        return inputs["input_values"].squeeze(0), self.labels[idx]

def segment_audio(audio_path, segment_length=2, sample_rate=16000):
    """Split an audio file into consecutive, equally sized mono segments.

    The file is loaded, resampled to ``sample_rate`` when its native rate
    differs, and multi-channel audio is averaged down to one channel.

    Args:
        audio_path: Path of the audio file to load.
        segment_length: Length of each segment in seconds.
        sample_rate: Sampling rate the segments are expressed in.

    Returns:
        A list of tensors shaped ``(1, segment_length * sample_rate)``. A
        trailing remainder shorter than a full segment is discarded.
    """
    signal, native_rate = torchaudio.load(audio_path)
    if native_rate != sample_rate:
        resampler = torchaudio.transforms.Resample(native_rate, sample_rate)
        signal = resampler(signal)

    if signal.dim() == 1:
        signal = signal.unsqueeze(0)
    elif signal.dim() == 2 and signal.size(0) > 1:
        signal = signal.mean(dim=0, keepdim=True)

    hop = segment_length * sample_rate
    n_samples = signal.size(1)
    # Keep only the windows that fit entirely inside the signal.
    return [
        signal[:, begin:begin + hop]
        for begin in range(0, n_samples, hop)
        if begin + hop <= n_samples
    ]

def prepare_dataset(dataset_dir, output_dir="/content/segmented_audio",
                    num_speakers=56, segment_length=2, sample_rate=16000):
    """Segment every speaker recording in ``dataset_dir`` into fixed clips.

    File names are expected to start with a two-digit, 1-based speaker number
    (``"01..."`` -> label 0). Each accepted file is cut with ``segment_audio``
    and every clip is written to ``output_dir`` as ``<stem>_seg<i>.wav``.

    Args:
        dataset_dir: Directory containing the source ``.wav`` files.
        output_dir: Directory that receives the clips (created if missing).
        num_speakers: Labels outside ``[0, num_speakers)`` are ignored.
        segment_length: Clip length in seconds; shorter files are skipped.
        sample_rate: Sampling rate the clips are stored at.

    Returns:
        ``(audio_files, labels)``: the clip paths and their 0-based speaker
        labels, in matching order.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)
    audio_files = []
    labels = []
    # os.listdir order is arbitrary; sort so repeated runs give the same lists.
    filenames = sorted(f for f in os.listdir(dataset_dir) if f.endswith('.wav'))
    for filename in tqdm(filenames, desc="Processing Audio Files"):
        try:
            speaker_id = int(filename[:2]) - 1
        except ValueError:
            # A stray .wav without a numeric prefix should not abort the run.
            print(f"Skipping {filename}: name does not start with a two-digit speaker id")
            continue
        if not 0 <= speaker_id < num_speakers:
            continue
        file_path = os.path.join(dataset_dir, filename)
        waveform, sr = torchaudio.load(file_path)
        duration = waveform.size(1) / sr
        if duration < segment_length:
            print(f"Skipping {filename}: Duration {duration:.2f}s < {segment_length}s")
            continue
        segments = segment_audio(
            file_path, segment_length=segment_length, sample_rate=sample_rate
        )
        stem = filename[:-4]  # drop the ".wav" suffix
        for i, segment in enumerate(segments):
            segment_path = os.path.join(output_dir, f"{stem}_seg{i}.wav")
            torchaudio.save(segment_path, segment, sample_rate)
            audio_files.append(segment_path)
            labels.append(speaker_id)
    return audio_files, labels

class VAMA(nn.Module):
    """Multi-scale dilated convolution block with attention fusion and a
    normalisation that mixes per-clip and running statistics.

    Input and output are shaped ``(batch, time, hidden_size)``.

    Args:
        hidden_size: Feature dimension of the input frames.
        num_scales: Number of parallel convolution branches.
    """

    def __init__(self, hidden_size=768, num_scales=3):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_scales = num_scales
        # Predicts one (pre-softplus) dilation rate per branch from the
        # time-averaged input.
        self.context_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, num_scales)
        )
        self.convs = nn.ModuleList([
            nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding="same")
            for _ in range(num_scales)
        ])
        # Scores each branch; the scores are soft-maxed across branches.
        self.attn_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, 1)
        )
        # Affine parameters of the final normalisation.
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        # Mixing weights between clip statistics and running statistics.
        self.alpha = nn.Parameter(torch.tensor(0.5))
        self.beta_param = nn.Parameter(torch.tensor(0.5))
        self.register_buffer("speaker_mean", torch.zeros(hidden_size))
        self.register_buffer("speaker_var", torch.ones(hidden_size))
        self.momentum = 0.1

    def estimate_mi(self, features, labels):
        """Return a label-informativeness proxy for ``features``.

        The score is the training accuracy of a ``LinearSVC`` fitted on the
        time-averaged features of the current batch; it is not a true mutual
        information estimate and carries no gradient.

        Returns:
            A 0-dim tensor on ``features.device``. It is ``0.0`` when the
            batch holds fewer than two distinct labels, because the SVM
            cannot be fitted in that case.
        """
        if labels.unique().numel() < 2:
            return torch.tensor(0.0, dtype=features.dtype, device=features.device)
        targets = labels.detach().cpu().numpy()
        pooled = features.mean(dim=1).detach().cpu().numpy()
        clf = LinearSVC(max_iter=1000)
        clf.fit(pooled, targets)
        score = float(clf.score(pooled, targets))
        return torch.tensor(score, dtype=features.dtype, device=features.device)

    def forward(self, x, labels=None):
        """Fuse the convolution branches and normalise the result.

        Args:
            x: Tensor shaped ``(batch, time, hidden_size)``.
            labels: Optional ``(batch,)`` integer labels. When given they
                bias the branch attention through ``estimate_mi`` and, in
                training mode, update the running statistics.

        Returns:
            Tensor with the same shape as ``x``.
        """
        _, time, _ = x.shape
        if labels is not None:
            labels = labels.to(x.device)

        context = x.mean(dim=1)
        dilation_rates = F.softplus(self.context_mlp(context))
        # One host sync for all branches. The integer cast is not
        # differentiable, so context_mlp receives no gradient from this path.
        dilations = [int(rate + 1) for rate in dilation_rates.mean(dim=0).tolist()]

        x = x.transpose(1, 2)
        # Call the functional conv so the dilation is applied per call
        # without mutating the Conv1d modules.
        scale_feats = [
            F.conv1d(x, conv.weight, conv.bias, padding="same", dilation=dilation).transpose(1, 2)
            for conv, dilation in zip(self.convs, dilations)
        ]

        scores = []
        for feat in scale_feats:
            attn_score = self.attn_mlp(feat.mean(dim=1)).squeeze(-1)
            if labels is not None:
                attn_score = attn_score + 0.1 * self.estimate_mi(feat, labels)
            scores.append(attn_score)
        weights = F.softmax(torch.stack(scores, dim=1), dim=1)
        weights = weights.unsqueeze(2).unsqueeze(3)
        fused = (weights * torch.stack(scale_feats, dim=1)).sum(dim=1)

        clip_mean = fused.mean(dim=1)
        # The unbiased variance of a single frame is NaN; fall back to the
        # biased estimate (0) in that case only.
        clip_var = fused.var(dim=1, unbiased=time > 1)

        if self.training and labels is not None:
            # Update the running statistics outside autograd. Storing
            # graph-attached tensors in the buffers would keep every past
            # graph alive and break the next backward pass.
            with torch.no_grad():
                running_mean, running_var = self.speaker_mean, self.speaker_var
                for lbl in labels.unique():
                    mask = labels == lbl
                    lbl_mean = clip_mean[mask].mean(dim=0)
                    lbl_var = clip_var[mask].mean(dim=0)
                    running_mean = (1 - self.momentum) * running_mean + self.momentum * lbl_mean
                    running_var = (1 - self.momentum) * running_var + self.momentum * lbl_var
                self.speaker_mean = running_mean
                self.speaker_var = running_var

        mu = self.alpha * clip_mean + (1 - self.alpha) * self.speaker_mean
        mixed_var = self.beta_param * clip_var + (1 - self.beta_param) * self.speaker_var
        # beta_param is unconstrained, so the mix can dip below zero; clamp
        # before the square root to avoid NaN.
        sigma = torch.sqrt(mixed_var.clamp_min(0.0) + 1e-6)
        normed = self.gamma * (fused - mu.unsqueeze(1)) / sigma.unsqueeze(1) + self.beta
        return normed

class WavLM_VISTA(nn.Module):
    """WavLM encoder with mean pooling, a classifier layer and two
    gradient-reversed adversarial heads.

    Args:
        num_classes: Output size of ``self.fc``.
        use_vama: When True, frame features pass through ``self.vama`` before
            pooling. Defaults to False, which reproduces the previous
            behaviour where the VAMA block was built but never applied.
        adv_lambda: Gradient-reversal strength applied in front of the
            adversarial heads.
    """

    def __init__(self, num_classes=56, use_vama=False, adv_lambda=0.1):
        super().__init__()
        self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")
        # Read the width from the checkpoint config instead of hard-coding it.
        hidden = self.wavlm.config.hidden_size
        self.use_vama = use_vama
        self.adv_lambda = adv_lambda
        self.vama = VAMA(hidden_size=hidden)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(hidden, num_classes)
        # Two 3-way heads; the training loop feeds them pseudo labels.
        self.adv_heads = nn.ModuleList([
            nn.Sequential(nn.Linear(hidden, 128), nn.ReLU(), nn.Linear(128, 3)),
            nn.Sequential(nn.Linear(hidden, 128), nn.ReLU(), nn.Linear(128, 3))
        ])

    def forward(self, input_values, attention_mask=None, labels=None, return_adv=False):
        """Embed a batch of waveforms.

        Args:
            input_values: Model inputs passed straight to WavLM.
            attention_mask: Optional mask passed to WavLM. The pooling below
                averages over all frames and does not apply it.
            labels: Optional labels. Only consulted when ``use_vama`` is True
                and the module is in training mode, so evaluation embeddings
                never depend on the ground truth.
            return_adv: When True, also return the adversarial head outputs.

        Returns:
            ``embedding`` shaped ``(batch, hidden)``, or
            ``(embedding, [head_0_logits, head_1_logits])`` if ``return_adv``.
        """
        outputs = self.wavlm(input_values, attention_mask=attention_mask)
        features = outputs.last_hidden_state
        if self.use_vama:
            features = self.vama(features, labels if self.training else None)
        embedding = self.pool(features.transpose(1, 2)).squeeze(-1)
        if return_adv:
            reversed_embedding = grad_reverse(embedding, self.adv_lambda)
            adv_outputs = [head(reversed_embedding) for head in self.adv_heads]
            return embedding, adv_outputs
        return embedding

class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; multiplies gradients by ``-alpha`` in
    the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the scale for the backward pass.
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign and scale; ``alpha`` itself receives no gradient.
        flipped = -grad_output * ctx.alpha
        return flipped, None

def grad_reverse(x, alpha=1.0):
    """Apply ``GradientReversal`` to ``x`` with reversal strength ``alpha``."""
    reversed_x = GradientReversal.apply(x, alpha)
    return reversed_x

class ArcFaceLoss(nn.Module):
    """Additive angular margin (ArcFace) classification loss.

    Args:
        s: Scale applied to the logits before the cross-entropy.
        m: Angular margin, in radians, added to the target-class angle.

    ``logits`` passed to ``forward`` are expected to be cosine similarities
    in ``[-1, 1]``; values outside that range are clamped before the angle
    is taken for the target class.
    """

    def __init__(self, s=30.0, m=0.5):
        super().__init__()
        self.s = s
        self.m = m

    def forward(self, logits, labels):
        """Return the mean cross-entropy over margin-adjusted, scaled logits.

        Args:
            logits: ``(batch, num_classes)`` cosine similarities.
            labels: ``(batch,)`` integer class indices.
        """
        cosine = torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7)
        theta = torch.acos(cosine)
        # cos(theta + m) only decreases while theta + m <= pi. Past that
        # point it rises again and would reward the worst-classified
        # samples, so switch to a linear penalty there.
        fallback = cosine - self.m * math.sin(math.pi - self.m)
        target_logits = torch.where(
            theta + self.m <= math.pi, torch.cos(theta + self.m), fallback
        )
        one_hot = F.one_hot(labels, num_classes=logits.size(1)).float()
        output = self.s * (one_hot * target_logits + (1 - one_hot) * logits)
        return F.cross_entropy(output, labels)

def pretrain_vacl(model, loader, epochs=10, device="cuda",
                  save_path="/content/drive/MyDrive/Speaker Recognition/wavlm_vista_vacl.pth"):
    """Pre-train ``model`` with an embedding reconstruction objective.

    A linear head maps each embedding to a reconstruction that is regressed
    (MSE) onto the detached embedding. The model's ``state_dict`` is saved to
    ``save_path`` after the last epoch.

    Args:
        model: Module returning ``(batch, 768)`` embeddings for ``model(inputs)``.
        loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        epochs: Number of passes over ``loader``.
        device: Device the inputs and the reconstruction head are moved to.
        save_path: Destination of the checkpoint.
    """
    model.train()
    reconstruction_head = nn.Linear(768, 768).to(device)
    # The head must be optimised too; otherwise it stays at its random
    # initialisation and the encoder is fitted to a fixed random projection.
    # NOTE(review): the objective can still be minimised trivially (identity
    # head or collapsed embeddings) - confirm this is the intended pretext task.
    trainable = list(model.parameters()) + list(reconstruction_head.parameters())
    optimizer = torch.optim.Adam(trainable, lr=1e-6, weight_decay=1e-5)
    for epoch in range(epochs):
        total_loss = 0.0
        used_batches = 0
        for inputs, _ in tqdm(loader, desc=f"VACL Pre-training Epoch {epoch}"):
            inputs = inputs.to(device)
            optimizer.zero_grad()
            emb = model(inputs)
            if not torch.isfinite(emb).all():
                # Skip the batch: stepping on NaN gradients would corrupt
                # every weight.
                print("NaN or Inf detected in embeddings")
                continue
            recon = reconstruction_head(emb)
            if not torch.isfinite(recon).all():
                print("NaN or Inf detected in reconstruction")
                continue
            loss = F.mse_loss(recon, emb.detach())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable, max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
            used_batches += 1
        # Average over the batches that contributed; also guards against an
        # empty loader.
        print(f"VACL Pre-training Epoch {epoch}, Loss: {total_loss / max(used_batches, 1):.4f}")
    torch.save(model.state_dict(), save_path)

def finetune_doai(model, loader, epochs=20, device="cuda",
                  save_path="/content/drive/MyDrive/Speaker Recognition/wavlm_vista_doai.pth"):
    """Fine-tune ``model`` with ArcFace plus gradient-reversed adversarial heads.

    Args:
        model: Module exposing ``fc`` and returning
            ``(embedding, [adv_logits_0, adv_logits_1])`` when called with
            ``return_adv=True``.
        loader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``loader``.
        device: Device the batches are moved to.
        save_path: Destination of the final ``state_dict`` checkpoint.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    arcface_loss = ArcFaceLoss(s=30.0, m=0.5)
    ce_loss = nn.CrossEntropyLoss()
    alpha = 0.01  # weight of the adversarial term
    for epoch in range(epochs):
        total_loss = 0.0
        used_batches = 0
        for inputs, labels in tqdm(loader, desc=f"DOAI Fine-tuning Epoch {epoch}"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            emb, adv_outputs = model(inputs, labels=labels, return_adv=True)
            # ArcFace works on angles, so feed it cosine similarities between
            # the embedding and the class weights, not raw linear logits.
            logits = F.linear(F.normalize(emb, dim=1), F.normalize(model.fc.weight, dim=1))
            arc_loss = arcface_loss(logits, labels)
            # NOTE(review): these targets are random, so the adversarial heads
            # have no signal to learn from - replace with real attribute labels.
            pseudo_emotion = torch.randint(0, 3, (inputs.size(0),), device=device)
            pseudo_rate = torch.randint(0, 3, (inputs.size(0),), device=device)
            adv_loss = ce_loss(adv_outputs[0], pseudo_emotion) + ce_loss(adv_outputs[1], pseudo_rate)
            # The model already reverses the gradient in front of the heads,
            # so the adversarial term is added. Subtracting it flips the sign
            # a second time: the heads un-learn and the encoder helps them.
            loss = arc_loss + alpha * adv_loss
            if not torch.isfinite(loss):
                # Both terms are non-negative, so the old negative-loss check
                # can no longer fire. Guard against NaN/Inf instead.
                print(f"Skipping batch in epoch {epoch} due to non-finite loss: {loss.item()}")
                continue
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            used_batches += 1
        print(f"DOAI Epoch {epoch}, Loss: {total_loss / max(used_batches, 1):.4f}")
    torch.save(model.state_dict(), save_path)

def evaluate_metrics(model, loader, device="cuda"):
    """Compute and print embedding-quality metrics for ``model`` on ``loader``.

    Args:
        model: Module returning ``(embedding, adv_outputs)`` when called with
            ``return_adv=True``.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the inputs are moved to.

    Returns:
        ``(mi_score, lin_probe_acc, adv_failure, compactness, eer, kl_div)``:

        * ``mi_score`` / ``lin_probe_acc``: accuracy of a ``LinearSVC`` fitted
          and scored on the same embeddings (both names hold the same value).
        * ``adv_failure``: mean agreement of the adversarial heads with
          random 3-class targets.
        * ``compactness``: mean same-label cosine similarity (self included)
          divided by mean different-label cosine similarity.
        * ``eer``: false-positive rate at the ROC point where it is closest
          to the false-negative rate, over all unordered pairs of distinct
          embeddings.
        * ``kl_div``: KL divergence between 2-D t-SNE histograms of the first
          and second half of the collected embeddings.
    """
    from sklearn.manifold import TSNE

    was_training = model.training
    # Stay in eval mode without autograd for the whole pass. Embeddings and
    # adversarial outputs come from the same forward call, so the encoder
    # runs only once per batch.
    model.eval()
    embeddings, labels, adv_accs = [], [], []
    with torch.no_grad():
        for inputs, lbls in tqdm(loader, desc="Evaluating Metrics"):
            inputs = inputs.to(device)
            emb, adv_outputs = model(inputs, return_adv=True)
            embeddings.append(emb.cpu())
            labels.append(lbls.cpu())
            for adv_out in adv_outputs:
                pred = adv_out.argmax(dim=1)
                # NOTE(review): random targets, so this sits near chance by
                # construction - swap in real attribute labels when available.
                target = torch.randint(0, 3, pred.size(), device=pred.device)
                adv_accs.append((pred == target).float().mean().item())
    model.train(was_training)

    embeddings = torch.cat(embeddings)
    labels = torch.cat(labels)
    features = embeddings.numpy()

    clf = LinearSVC(max_iter=1000)
    clf.fit(features, labels.numpy())
    mi_score = clf.score(features, labels.numpy())
    lin_probe_acc = mi_score
    adv_failure = np.mean(adv_accs)

    # An N x N cosine matrix via one matmul. Broadcasting cosine_similarity
    # over (N, 1, D) and (1, N, D) would materialise an N x N x D tensor.
    unit = F.normalize(embeddings, dim=1)
    similarity = unit @ unit.t()
    same_speaker = labels.unsqueeze(1) == labels.unsqueeze(0)
    same_w = same_speaker.float()
    diff_w = 1.0 - same_w
    intra_sim = (similarity * same_w).sum(dim=1) / same_w.sum(dim=1)
    inter_sim = (similarity * diff_w).sum(dim=1) / diff_w.sum(dim=1)
    compactness = np.mean(intra_sim.tolist()) / np.mean(inter_sim.tolist())

    # Score each unordered pair once and leave out the trivial self-pairs.
    n_items = similarity.size(0)
    rows, cols = torch.triu_indices(n_items, n_items, offset=1)
    scores = similarity[rows, cols]
    true_labels = same_speaker[rows, cols]
    fpr, tpr, _ = roc_curve(true_labels.numpy(), scores.numpy())
    eer = fpr[np.nanargmin(np.abs(fpr - (1 - tpr)))]

    tsne = TSNE(n_components=2).fit_transform(features)
    half = len(tsne) // 2
    # Both histograms share one set of bin edges; otherwise their cells
    # cover different regions and the divergence is meaningless.
    # NOTE(review): the two halves follow loader order, not a real
    # train/test split.
    bounds = [(tsne[:, axis].min(), tsne[:, axis].max()) for axis in range(2)]
    hist_train, _ = np.histogramdd(tsne[:half], bins=20, range=bounds)
    hist_test, _ = np.histogramdd(tsne[half:], bins=20, range=bounds)
    kl_div = entropy(hist_train.flatten() + 1e-10, hist_test.flatten() + 1e-10)

    print(f"MI Score: {mi_score:.4f}, Linear Probing: {lin_probe_acc:.4f}")
    print(f"Adversarial Failure: {adv_failure:.4f}, Compactness: {compactness:.4f}")
    print(f"EER: {eer:.4f}, t-SNE Divergence: {kl_div:.4f}")
    return mi_score, lin_probe_acc, adv_failure, compactness, eer, kl_div

if __name__ == "__main__":
    # End-to-end run: segment the dataset, pre-train, fine-tune, evaluate.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    base_dir = "/content/drive/MyDrive/Speaker Recognition"
    dataset_dir = os.path.join(base_dir, "Dataset")
    summary_path = os.path.join(base_dir, "dataset_summary.txt")
    processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus")

    print("Processing audio files and segmenting into 2-second clips...")
    audio_files, label_ids = prepare_dataset(dataset_dir)
    if not audio_files:
        # Fail with a clear message instead of crashing later on the
        # min()/max() of an empty tensor.
        raise SystemExit(f"No usable .wav clips found in {dataset_dir}")

    # Write the summary before the long training run so it still exists if
    # the session is interrupted.
    with open(summary_path, "w", encoding="utf-8") as f:
        f.writelines(
            f"File: {path}, Label: {label}\n"
            for path, label in zip(audio_files, label_ids)
        )
    print(f"Dataset summary saved to {summary_path}")

    labels = torch.tensor(label_ids)
    dataset = AudioDataset(audio_files, labels, processor)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    print(f"Total number of 2-second clips: {len(dataset)}")
    print(f"Labels range: {labels.min()} to {labels.max()} (should be 0 to 55)")

    model = WavLM_VISTA(num_classes=56).to(device)
    pretrain_vacl(model, loader, epochs=10, device=device)
    finetune_doai(model, loader, epochs=20, device=device)
    # NOTE(review): metrics are computed on the training loader; hold out a
    # split before reporting them as generalisation numbers.
    evaluate_metrics(model, loader, device=device)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Processing audio files and segmenting into 2-second clips...


Processing Audio Files: 100%|██████████| 56/56 [00:29<00:00,  1.89it/s]


Total number of 2-second clips: 3439
Labels range: 0 to 55 (should be 0 to 55)


VACL Pre-training Epoch 0: 100%|██████████| 430/430 [2:01:31<00:00, 16.96s/it]


VACL Pre-training Epoch 0, Loss: 0.0066


VACL Pre-training Epoch 1: 100%|██████████| 430/430 [2:03:36<00:00, 17.25s/it]


VACL Pre-training Epoch 1, Loss: 0.0050


VACL Pre-training Epoch 2:  43%|████▎     | 183/430 [52:44<1:08:10, 16.56s/it]