In [None]:
import sys
import argparse
import torch
import os, sys
import numpy as np
import warnings
from time import time
import datetime
import gc
import torch
import numpy as np
import torchaudio
import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import Audio, display
import matplotlib.pyplot as plt

from smooth import divide_batch, Smooth
from parser_certify import get_parser
from our_utils import Producer, CustomModel, predict_with_radius, predict_with_radius_2, predict_with_radius_3
from our_utils import ERA_of_f, ERA_of_g

In [None]:
# Experiment configuration.
# Parse an empty argv so every option starts at the parser's default, then
# override the values this notebook cares about.
args = get_parser().parse_args([])

# --- data ---
args.dataset = "VoxCeleb2"
args.dataset_train = None
args.dataset_test = ...  # path to dataset (Ellipsis is a placeholder; set before running)
# args.outdur = ... # path to save results

# --- evaluation / smoothing settings ---
args.num_support_val = 5
args.classes_per_it_val = 118  # 1118, 7363
args.K = 5
args.sigma = 0.01
args.N = 20000

# --- normalisation switches ---
args.normalize1 = True
args.normalize_enrlollment_prototypes = True

# --- device ---
# Use the selected GPU only when CUDA is present and the parser's `cuda`
# option is truthy; otherwise fall back to the CPU.
args.cuda_number = 0
if torch.cuda.is_available() and args.cuda:
    device = f"cuda:{args.cuda_number}"
else:
    device = "cpu"
args.device = device

In [None]:
# Backbones this notebook has been configured for, each paired with the
# embedding size it was used with. Keeping the pairs in one table means
# switching the model only requires editing `args.model_name`; the matching
# `args.emb_size` can no longer be forgotten or mistyped.
EMB_SIZE_BY_MODEL = {
    "ecapa-tdnn": 192,
    "pyannote": 512,
    "wavlm": 512,
    "campplus": 192,
    "eres2net": 192,
    "wespeaker": 256,
}

args.model_name = "ecapa-tdnn"
# A KeyError here means the chosen model has no registered embedding size.
args.emb_size = EMB_SIZE_BY_MODEL[args.model_name]

# Build the model, move it to the configured device and switch to eval mode.
model = CustomModel(args)
model = model.to(device)
model.eval()

In [None]:
# Producer yields the class prototypes, the enrollment / inference audio
# collections per speaker, and the speaker-id <-> class-index mappings.
pr = Producer(model, args, args.normalize_enrlollment_prototypes)
(
    class_prototypes,
    speaker_enrollment_audios,
    speaker_inference_audios,
    id2class,
    class2id,
) = pr.produce_subsets()

# Stack the prototypes in dict order, drop axis 1 and move them to `device`.
class_prototypes_list = torch.stack([proto for proto in class_prototypes.values()])
centroids = class_prototypes_list.squeeze(1).to(device)

# Wrap the base model in the smoothing classifier.
smooth_kwargs = dict(
    base_model=model,
    device=device,
    num_classes=args.classes_per_it_val,
    sigma=args.sigma,
    alpha=args.alpha,
    mode=args.mode,
    normalize1=args.normalize1,
    emb_size=model.emb_size,
)
smoothed_model = Smooth(**smooth_kwargs)

In [None]:
# Empirical robust accuracy of the base model f under a PGD attack,
# evaluated on `n_grid` radii between 0 and `max_delta`.
max_delta = 2
n_grid = 10
batch_size = 8

results_f, results2_f, era_f = ERA_of_f(
    model,
    pr,
    device,
    max_delta,
    n_grid=n_grid,
    batch_size=batch_size,
    attack='pgd',
)

# x-axis uses the same grid parameters that were handed to ERA_of_f.
attack_radii_f = np.linspace(0, max_delta, n_grid)

fig_f, ax_f = plt.subplots()
ax_f.plot(attack_radii_f, results2_f.mean(dim=0))
ax_f.set_ylabel('Empirical robust accuracy (f)')
ax_f.set_xlabel('Attack radius')

fig_f.tight_layout()
plt.show()

In [None]:
# Empirical robust accuracy of the smoothed model g, evaluated on `n_grid`
# radii between 0 and `max_delta`.
max_delta = 2
n_grid = 10
batch_size = 8  # NOTE(review): not passed to ERA_of_g below — confirm whether it should be

# Pass the `n_grid` variable (not a literal) so the number of evaluated radii
# always matches the x-axis built with np.linspace further down.
results_g, results2_g, era_g = ERA_of_g(model=model,
                                        smoothed_model=smoothed_model,
                                        args=args,
                                        producer=pr,
                                        device=device,
                                        delta_max=max_delta,
                                        n_grid=n_grid)

plt.plot(np.linspace(0, max_delta, n_grid), results2_g.mean(dim=0))

plt.ylabel('Empirical robust accuracy (g)')
plt.xlabel('Attack radius')

plt.tight_layout()
plt.show()

In [None]:
# Certification loop: for every test speaker, draw one random inference
# utterance, run `predict_with_radius` (smoothed model) and
# `predict_with_radius_2` (base model), and record predictions, radii and
# per-sample correctness for both.
audio_len = 3  # crop length; multiplied by 16000 below — presumably seconds at 16 kHz, TODO confirm

y = []

sum_audios = 0
correct = 0

preds = []
preds_cohen = []

radii = []
new_radii = []
radii_cohen = []

correct_or_not = []
correct_or_not_cohen = []
for speaker_id in tqdm(pr.speakers_test_only, total=len(pr.speakers_test_only)):
    gt_class = id2class[speaker_id]
    wavs_paths = speaker_inference_audios[class2id[gt_class]]
    # NOTE: unseeded random draw — results vary between runs.
    wav_path = wavs_paths[np.random.choice(len(wavs_paths))]

    # Take channel 0, crop to the first `audio_len*16000` samples, and move the
    # waveform to the same device as the model and centroids. (A hard-coded
    # "cuda:3" here would mismatch the configured `device`.)
    sample = torchaudio.load(wav_path)[0][0, :audio_len*16000].to(device)
    sum_audios += 1
    pred_class, gamma_lcb, radius, time_elapsed, n_samples, radius_as_in_article, pred_centroid, adv_centroid = predict_with_radius(
        args,
        smoothed_model,
        sample=sample,
        centroids=centroids,
        centroid_target=torch.arange(args.classes_per_it_val)
        )

    pred_class_cohen, radius_cohen = predict_with_radius_2(args, model, sample, centroids, centroid_target=torch.arange(args.classes_per_it_val))

    radius_as_in_article = torch.tensor(radius_as_in_article)
    print(f"GT class: {gt_class}   |    Predicted class: {pred_class}   |    Predicted class RS: {pred_class_cohen}  |   Radius as in article: {radius_as_in_article}    |   New radius: {radius}|   Radius RS: {radius_cohen}")

    is_correct = int(pred_class == gt_class)
    preds.append(pred_class)
    preds_cohen.append(pred_class_cohen)
    radii.append(radius_as_in_article)
    new_radii.append(radius)
    radii_cohen.append(radius_cohen)
    if is_correct:
        correct += 1
    correct_or_not.append(is_correct)
    correct_or_not_cohen.append(int(pred_class_cohen == gt_class))

# Plain accuracy of the first predictor over all processed utterances.
print(correct / sum_audios)

# Convert the accumulators to arrays for the vectorised comparisons that follow.
preds = np.array(preds)
preds_cohen = np.array(preds_cohen)
radii = np.array(radii)
new_radii = np.array(new_radii)
radii_cohen = np.array(radii_cohen)
correct_or_not = np.array(correct_or_not)
correct_or_not_cohen = np.array(correct_or_not_cohen)

In [None]:
# Accuracy-vs-radius curves: for each attack radius r, the fraction of samples
# that are both correctly classified and have a recorded radius >= r.
r_attack = np.linspace(0, 0.1, 100)

cra_old_list = [np.mean((radii >= r) * correct_or_not) for r in r_attack]
cra_new_list = [np.mean((new_radii >= r) * correct_or_not) for r in r_attack]
cra_cohen_list = [np.mean((radii_cohen >= r) * correct_or_not_cohen) for r in r_attack]

plt.plot(r_attack, cra_old_list, label="old_approach")
plt.plot(r_attack, cra_new_list, label="new_approach")
plt.plot(r_attack, cra_cohen_list, label="ordinary_rs")
# Label the axes so the figure stands alone, matching the ERA plots.
plt.xlabel('Attack radius')
plt.ylabel('Certified robust accuracy')
plt.legend()
plt.show()

In [None]:
# Second evaluation pass using only `predict_with_radius_2` on the base model.
# NOTE: this cell rebinds `preds`, `new_radii` and `correct_or_not`; any plot
# cell that reads those names will show this cell's results once it has run.
y = []

sum_audios = 0
correct = 0

preds = []
new_radii = []
correct_or_not = []

# Number of leading samples kept from each utterance.
crop_samples = 3*16000

for speaker_id in tqdm(pr.speakers_test_only, total=len(pr.speakers_test_only)):
    gt_class = id2class[speaker_id]
    wavs_paths = speaker_inference_audios[class2id[gt_class]]
    # NOTE: unseeded random draw — results vary between runs.
    wav_path = wavs_paths[np.random.choice(len(wavs_paths))]

    # Take channel 0, crop, and move the waveform to the same device as the
    # model and centroids. (A hard-coded "cuda:3" here would mismatch the
    # configured `device`.)
    sample = torchaudio.load(wav_path)[0][0, :crop_samples].to(device)

    sum_audios += 1
    pred_class, radius = predict_with_radius_2(args, model, sample, centroids, centroid_target=torch.arange(args.classes_per_it_val))

    print(f"GT class: {gt_class}   |    Predicted class: {pred_class}   |   Radius: {radius}")

    is_correct = int(pred_class == gt_class)
    preds.append(pred_class)
    new_radii.append(radius)
    if is_correct:
        correct += 1
    correct_or_not.append(is_correct)

# Plain accuracy over all processed utterances.
print(correct / sum_audios)
preds = np.array(preds)
new_radii = np.array(new_radii)
correct_or_not = np.array(correct_or_not)

In [None]:
# Accuracy-vs-radius curve: for each attack radius r, the fraction of samples
# that are both correctly classified and have a recorded radius >= r.
r_attack = np.linspace(0, 0.04, 100)

cra_new_list = [np.mean((new_radii >= r) * correct_or_not) for r in r_attack]

plt.plot(r_attack, cra_new_list, label="other_approach")
# Label the axes so the figure stands alone, matching the ERA plots.
plt.xlabel('Attack radius')
plt.ylabel('Certified robust accuracy')
plt.legend()
plt.show()