# Initialisation and Dependencies

In [1]:
import sys

# Directory put at the front of the import path (presumably where the
# project-local `utils` package imported below lives — confirm).
lib_path = '/home/jovyan/libs'

# Re-running this cell used to stack a duplicate entry on every execution.
# Drop any existing occurrence first so the path is listed exactly once and
# still takes priority over everything else.
while lib_path in sys.path:
    sys.path.remove(lib_path)
sys.path.insert(0, lib_path)

In [2]:
%reload_ext autoreload
%autoreload 2

import gc, math, traceback, datetime

import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from datasets import load_from_disk

import whisper
from whisper.tokenizer import get_tokenizer

from utils import audio, gradient, gpu
from utils.attacks import PrepareFront, PrepareAtPosition

  warn(


# GPU RAM Tracking

In [3]:
# Device handle from the project `gpu` helper; the models loaded below are
# moved onto it (the recorded run printed "Device: cuda").
device = gpu.get_device()

Device: cuda


# Load Model

In [4]:
# Drop a `model` binding left over from a previous run of this kernel
# (presumably so its memory can be reclaimed before the models are reloaded).
# NOTE(review): the loading cell binds `small_model` / `tiny_model`, not
# `model` — confirm this cell still targets the intended name.
try:
    del model
    print("Model deleted!")
except NameError:
    # Nothing to delete on a fresh kernel. Only the missing-name case is
    # expected; any other exception (or a keyboard interrupt) should surface.
    pass

In [5]:
SMALL_MODEL = "small.en"
TINY_MODEL = "tiny.en"

# Load each checkpoint, move it to `device`, and switch it to eval mode.
# `nn.Module.eval()` returns the module itself, so the calls can be chained.
small_model = whisper.load_model(SMALL_MODEL).to(device).eval()
tiny_model = whisper.load_model(TINY_MODEL).to(device).eval()

# Ids looked up once from the small model via the project helper.
target_id, sot_ids = gradient._get_ids(small_model)

  checkpoint = torch.load(fp, map_location=device)


# Load Data

In [6]:
# On-disk locations of the three dataset splits, relative to this notebook.
tedlium_path = "../tedlium"
train_path = f"{tedlium_path}/train_idx.hf"
validation_path = f"{tedlium_path}/validation_idx.hf"
test_path = f"{tedlium_path}/test.hf"

In [7]:
# Number of examples kept from each split after shuffling.
TRAIN_SELECT = 250
VALID_SELECT = 100
TEST_SELECT = 150

# Seed shared by all three shuffles so the selected subsets are repeatable.
SEED = 1


def _load_subset(path, count):
    """Load the dataset saved at `path` in torch format, shuffle it with SEED, and keep the first `count` rows."""
    split = load_from_disk(path).with_format("torch")
    return split.shuffle(seed=SEED).select(range(count))


tedlium_train = _load_subset(train_path, TRAIN_SELECT)
tedlium_validation = _load_subset(validation_path, VALID_SELECT)
tedlium_test = _load_subset(test_path, TEST_SELECT)

Loading dataset from disk:   0%|          | 0/109 [00:00<?, ?it/s]

In [8]:
# def collate(ls):
#     pad_to = max(list(map(lambda x: x["audio"].shape[0], ls)))
#     return torch.cat(list(map(lambda x: F.pad(x["audio"], (0, pad_to - x["audio"].shape[0])).unsqueeze(0).to(torch.bfloat16), ls)), dim=0)

def collate_idx(ls):
    """Collate a one-sample batch into an ``(audio, idx)`` pair.

    Args:
        ls: list of samples handed over by the DataLoader. Each sample is a
            mapping with an "audio" tensor and a single-element "idx" tensor.

    Returns:
        A tuple ``(audio, idx)``: the sample's audio with a new leading
        dimension added, and its "idx" converted to a plain Python number.

    Raises:
        ValueError: if the batch does not hold exactly one sample. Only the
            first sample was ever used, so a larger batch would silently
            discard data.
    """
    if len(ls) != 1:
        raise ValueError(
            f"collate_idx expects exactly one sample per batch, got {len(ls)}; "
            "keep the DataLoader batch_size at 1"
        )
    sample = ls[0]
    return sample["audio"].unsqueeze(0), sample["idx"].item()

# collate_idx builds its result from a single sample, so the train and
# validation batch sizes are expected to stay at 1.
TRAIN_BATCH_SIZE = 1 # highly recommended to be 1
VALID_BATCH_SIZE = 1

# NOTE: despite the `*_dataset` names, these are DataLoader objects.
train_dataset = DataLoader(tedlium_train, batch_size=TRAIN_BATCH_SIZE, collate_fn=collate_idx)
validation_dataset = DataLoader(tedlium_validation, batch_size=VALID_BATCH_SIZE, collate_fn=collate_idx)
# The test loader keeps DataLoader's defaults (no custom collate_fn, default batch size).
test_dataset = DataLoader(tedlium_test)

# Training Loop

In [9]:
# Project helper call made right before training; presumably releases cached
# GPU memory — confirm against utils.gpu.
gpu.cleanup()

In [10]:
# Settings for the training and evaluation cells below.

# Passed to gradient.train as `lr`.
LR = 1e-3
# Passed to gradient.train as `patience`.
PATIENCE = 7
# Passed to gradient.train as `mins_limit` (logged as "Time Limit (Mins)").
MIN_LIMIT = None
# Passed to gradient.train as `iter_limit`.
ITER_LIMIT = None
# Passed to gradient.train as `clamp_epsilon` and to gradient.evaluate.
CLAMP_EP = 0.005
# Shape of the trained snippet, handed to the prepare method.
SNIPPET_SIZE = (1, 10_240)
# Second argument of PrepareAtPosition; also passed to gradient.evaluate.
POSITION = 0
PREPARE_METHOD = PrepareAtPosition(SNIPPET_SIZE, POSITION)
# Passed to gradient.train as `gap`.
DELTA = 0.01

# Candidate clamp values; not referenced by any other live cell in this notebook.
CLAMP_EPS = (0.00125, 0.0025, 0.005)

# No TensorBoard writer by default; the cell that would create a SummaryWriter
# is commented out, so gradient.train receives None.
writer = None

In [11]:
# tensorboard writer
# timestamp = datetime.datetime.now().strftime(f'%Y%m%d-%H%M%S_size_{SNIPPET_SIZE}_{PREPARE_METHOD.name}')
# writer = SummaryWriter(log_dir=f"../runs/size_tests/{timestamp}", max_queue=5)

In [12]:
# Keyword options for the training run, gathered in one place.
train_options = dict(
    lr=LR,
    train_success=None,
    valid_success=None,
    iter_limit=ITER_LIMIT,
    mins_limit=MIN_LIMIT,
    patience=PATIENCE,
    clamp_epsilon=CLAMP_EP,
    gap=DELTA,
)

# Train the snippet against the small model. gradient.train returns four
# values; only the best snippet and the collected snippets are kept.
best_snippet, snippets, _, _ = gradient.train(
    small_model,
    gradient.forward_multi,
    train_dataset,
    validation_dataset,
    PREPARE_METHOD,
    writer,
    **train_options,
)



Prepare method: prepare_at_position
Snippet initialised to [1.0212062306891312e-06, 0.004999465774744749] of size (1, 10240)
Clamp: 0.005
Time Limit (Mins): None
Epochs Limit: None
Tracking training success: False
Tracking valid success: False


Training:   0% 0/1 [02:49<?, ?it/s, Iter 2, Training Batch 1/250]    

Trng Avg Loss: 7.408213442802429 | Valid Avg Loss: 7.685433864593506 | Patience: 7 | LR: [0.001] | Epoch Limit: None


Training:   0% 0/1 [05:13<?, ?it/s, Iter 3, Training Batch 1/250]    

Trng Avg Loss: 7.050932130813599 | Valid Avg Loss: 7.495156288146973 | Patience: 7 | LR: [0.001] | Epoch Limit: None


Training:   0% 0/1 [07:39<?, ?it/s, Iter 4, Training Batch 1/250]    

Trng Avg Loss: 7.549940779209137 | Valid Avg Loss: 7.170105457305908 | Patience: 7 | LR: [0.001] | Epoch Limit: None


Training:   0% 0/1 [09:58<?, ?it/s, Iter 5, Training Batch 1/250]    

Trng Avg Loss: 6.502966806888581 | Valid Avg Loss: 7.135991096496582 | Patience: 7 | LR: [0.001] | Epoch Limit: None


Training:   0% 0/1 [12:17<?, ?it/s, Iter 6, Training Batch 1/250]    

Trng Avg Loss: 6.67595689868927 | Valid Avg Loss: 7.142073154449463 | Patience: 6 | LR: [0.001] | Epoch Limit: None


Training:   0% 0/1 [14:23<?, ?it/s, Iter 7, Training Batch 1/250]    

Trng Avg Loss: 5.086055135250091 | Valid Avg Loss: 5.7306694984436035 | Patience: 7 | LR: [0.0005] | Epoch Limit: None


Training:   0% 0/1 [16:30<?, ?it/s, Iter 8, Training Batch 1/250]    

Trng Avg Loss: 5.157282932758331 | Valid Avg Loss: 5.015968322753906 | Patience: 7 | LR: [0.0005] | Epoch Limit: None


Training:   0% 0/1 [18:42<?, ?it/s, Iter 9, Training Batch 1/250]    

Trng Avg Loss: 5.485568301677704 | Valid Avg Loss: 6.522902011871338 | Patience: 6 | LR: [0.0005] | Epoch Limit: None


Training:   0% 0/1 [20:48<?, ?it/s, Iter 10, Training Batch 1/250]   

Trng Avg Loss: 5.055273509025573 | Valid Avg Loss: 5.493022918701172 | Patience: 5 | LR: [0.0005] | Epoch Limit: None


Training:   0% 0/1 [22:51<?, ?it/s, Iter 11, Training Batch 1/250]    

Trng Avg Loss: 4.834996129512787 | Valid Avg Loss: 4.82178258895874 | Patience: 7 | LR: [0.0005] | Epoch Limit: None


Training:   0% 0/1 [24:36<?, ?it/s, Iter 12, Training Batch 1/250]    

Trng Avg Loss: 3.4732904291152953 | Valid Avg Loss: 3.905441999435425 | Patience: 7 | LR: [0.00025] | Epoch Limit: None


Training:   0% 0/1 [26:23<?, ?it/s, Iter 13, Training Batch 2/250]    

Trng Avg Loss: 3.7926881642341614 | Valid Avg Loss: 3.9466350078582764 | Patience: 6 | LR: [0.00025] | Epoch Limit: None


Training:   0% 0/1 [28:08<?, ?it/s, Iter 14, Training Batch 1/250]    

Trng Avg Loss: 3.777814440727234 | Valid Avg Loss: 4.352321624755859 | Patience: 5 | LR: [0.00025] | Epoch Limit: None


Training:   0% 0/1 [29:53<?, ?it/s, Iter 15, Training Batch 1/250]    

Trng Avg Loss: 3.5098293986320495 | Valid Avg Loss: 3.863316059112549 | Patience: 7 | LR: [0.00025] | Epoch Limit: None


Training:   0% 0/1 [31:28<?, ?it/s, Iter 16, Training Batch 1/250]    

Trng Avg Loss: 3.18808256149292 | Valid Avg Loss: 2.8326191902160645 | Patience: 7 | LR: [0.00025] | Epoch Limit: None


Training:   0% 0/1 [32:51<?, ?it/s, Iter 17, Training Batch 1/250]    

Trng Avg Loss: 2.3053841490745546 | Valid Avg Loss: 3.0545878410339355 | Patience: 6 | LR: [0.000125] | Epoch Limit: None


Training:   0% 0/1 [34:12<?, ?it/s, Iter 18, Training Batch 1/250]    

Trng Avg Loss: 2.129672840595245 | Valid Avg Loss: 2.760840892791748 | Patience: 7 | LR: [0.000125] | Epoch Limit: None


Training:   0% 0/1 [35:34<?, ?it/s, Iter 19, Training Batch 1/250]    

Trng Avg Loss: 2.1550323598384855 | Valid Avg Loss: 2.462954521179199 | Patience: 7 | LR: [0.000125] | Epoch Limit: None


Training:   0% 0/1 [36:53<?, ?it/s, Iter 20, Training Batch 1/250]    

Trng Avg Loss: 1.9487013676166534 | Valid Avg Loss: 2.2845845222473145 | Patience: 7 | LR: [0.000125] | Epoch Limit: None


Training:   0% 0/1 [38:08<?, ?it/s, Iter 21, Training Batch 1/250]    

Trng Avg Loss: 1.7600070093870164 | Valid Avg Loss: 2.7867279052734375 | Patience: 6 | LR: [0.000125] | Epoch Limit: None


Training:   0% 0/1 [39:22<?, ?it/s, Iter 22, Training Batch 1/250]    

Trng Avg Loss: 1.5322322051525117 | Valid Avg Loss: 1.8358352184295654 | Patience: 7 | LR: [6.25e-05] | Epoch Limit: None


Training:   0% 0/1 [40:32<?, ?it/s, Iter 23, Training Batch 1/250]    

Trng Avg Loss: 1.3157225685715674 | Valid Avg Loss: 1.2536004781723022 | Patience: 7 | LR: [6.25e-05] | Epoch Limit: None


Training:   0% 0/1 [41:37<?, ?it/s, Iter 24, Training Batch 2/250]    

Trng Avg Loss: 1.0839357984662057 | Valid Avg Loss: 1.1307522058486938 | Patience: 7 | LR: [6.25e-05] | Epoch Limit: None


Training:   0% 0/1 [42:41<?, ?it/s, Iter 25, Training Batch 1/250]    

Trng Avg Loss: 1.00020832580328 | Valid Avg Loss: 1.5069950819015503 | Patience: 6 | LR: [6.25e-05] | Epoch Limit: None


Training:   0% 0/1 [43:47<?, ?it/s, Iter 26, Training Batch 2/250]    

Trng Avg Loss: 0.891886518239975 | Valid Avg Loss: 1.6240203380584717 | Patience: 5 | LR: [6.25e-05] | Epoch Limit: None


Training:   0% 0/1 [44:42<?, ?it/s, Iter 27, Training Batch 2/250]    

Trng Avg Loss: 0.640330873399973 | Valid Avg Loss: 0.9296208620071411 | Patience: 7 | LR: [3.125e-05] | Epoch Limit: None


Training:   0% 0/1 [45:33<?, ?it/s, Iter 28, Training Batch 2/250]    

Trng Avg Loss: 0.4922091069817543 | Valid Avg Loss: 0.9157963395118713 | Patience: 7 | LR: [3.125e-05] | Epoch Limit: None


Training:   0% 0/1 [46:25<?, ?it/s, Iter 29, Training Batch 2/250]    

Trng Avg Loss: 0.5772376187443733 | Valid Avg Loss: 0.7145307064056396 | Patience: 7 | LR: [3.125e-05] | Epoch Limit: None


Training:   0% 0/1 [47:25<?, ?it/s, Iter 30, Training Batch 2/250]    

Trng Avg Loss: 0.37000648303329947 | Valid Avg Loss: 0.8241088390350342 | Patience: 6 | LR: [3.125e-05] | Epoch Limit: None


Training:   0% 0/1 [48:20<?, ?it/s, Iter 31, Training Batch 2/250]    

Trng Avg Loss: 0.5655862446576357 | Valid Avg Loss: 0.7333850860595703 | Patience: 5 | LR: [3.125e-05] | Epoch Limit: None


Training:   0% 0/1 [49:10<?, ?it/s, Iter 32, Training Batch 2/250]    

Trng Avg Loss: 0.3468946526348591 | Valid Avg Loss: 0.7024179100990295 | Patience: 7 | LR: [1.5625e-05] | Epoch Limit: None


Training:   0% 0/1 [50:03<?, ?it/s, Iter 33, Training Batch 2/250]    

Trng Avg Loss: 0.30378655920922754 | Valid Avg Loss: 0.6970689296722412 | Patience: 6 | LR: [1.5625e-05] | Epoch Limit: None


Training:   0% 0/1 [50:57<?, ?it/s, Iter 34, Training Batch 2/250]    

Trng Avg Loss: 0.28705418469011784 | Valid Avg Loss: 0.6941412091255188 | Patience: 5 | LR: [1.5625e-05] | Epoch Limit: None


Training:   0% 0/1 [51:47<?, ?it/s, Iter 35, Training Batch 2/250]    

Trng Avg Loss: 0.28584949153661726 | Valid Avg Loss: 0.6941326260566711 | Patience: 4 | LR: [1.5625e-05] | Epoch Limit: None


Training:   0% 0/1 [52:15<?, ?it/s, Iter 35, Training Batch 178/250]


Cleared buffer
Cleared loss


In [13]:
# audio.view_mel(best_snippet.detach().to("cpu").squeeze())

# Evaluation

In [14]:
# Evaluate the trained snippet on the test loader with the same prepare
# method, clamp value and position used for training.
# NOTE: this call is live and runs on "Run All"; comment it out if the
# evaluation should not start automatically.
gradient.evaluate(small_model, best_snippet, PREPARE_METHOD, test_dataset, CLAMP_EP, POSITION)

Clamp: 0.005
Prepare Method: prepare_at_position
Snippet Size: (1, 10240)
Position: 0


Inference: 100%|██████████| 150/150 [01:03<00:00,  2.37it/s, Valid Examples: 120 | Empty Sequences: 27 | Total SL: 8283 | Non-empty ASL: 89.06451612903226 | Total Bleu Score: 41.109153747558594]



Total valid examples: 120
Success rate (Empty): 0.225
Success rate (ASL): 69.025 (attacked) out of 122.75833333333334 (original)
Average Bleu Score: 0.3425762951374054
Average WER: 0.6654056953407352





In [15]:
# random_snippet = (torch.rand(SNIPPET_SIZE) - 0.5) * (CLAMP_EP / 0.5)
# print(torch.max(random_snippet), torch.min(random_snippet))
# gradient.evaluate(model, random_snippet, PREPARE_METHOD, test_dataset, CLAMP_EP, POSITION) # commented to prevent the runtime from autorunning and crashing the thing

# Batch Evaluation

In [16]:
# for C in (0.005,):
#     print(C)
#     best_snippet, _, _, _ = gradient.train(model, gradient.forward_auto,
#                                               train_dataset, validation_dataset,
#                                               PREPARE_METHOD,
#                                               writer, lr=LR, 
#                                               train_success=None, valid_success=None,
#                                               iter_limit=ITER_LIMIT, mins_limit=MIN_LIMIT, patience=PATIENCE, clamp_epsilon=C, gap=DELTA)
#     gradient.evaluate(model, best_snippet, PREPARE_METHOD, test_dataset, C, POSITION) # commented to prevent the runtime from autorunning and crashing the thing
#     print("\n")

In [17]:
# LENGTHS = (1600, 3200, 4800, 6400, 6800, 7200, 8000, 10240, 16000)

# for L in LENGTHS:
#     print(L)
#     PREPARE_METHOD = PrepareAtPosition((1, L), 0)
#     best_snippet, _, _, _ = gradient.train(model, gradient.forward_auto,
#                                               train_dataset, validation_dataset,
#                                               PREPARE_METHOD,
#                                               writer, lr=LR, 
#                                               train_success=None, valid_success=None,
#                                               iter_limit=ITER_LIMIT, mins_limit=MIN_LIMIT, patience=PATIENCE, clamp_epsilon=CLAMP_EP, gap=DELTA)
#     gradient.evaluate(model, best_snippet, PREPARE_METHOD, test_dataset, CLAMP_EP, POSITION) # commented to prevent the runtime from autorunning and crashing the thing
#     print("\n")

# Save Tensors

In [18]:
# Move every collected snippet to the CPU, append the best snippet last, and
# stack everything into one tensor.
# NOTE: `snippets` is rebound here, so this cell is meant to run once per
# training run.
history_cpu = [snip.cpu() for snip in snippets]
history_cpu.append(best_snippet.cpu())
snippets = torch.stack(history_cpu)
snippets.shape

torch.Size([36, 1, 10240])

In [19]:
# torch.save(snippets.squeeze(), "snippets.pt")

In [20]:
# torch.save(torch.stack(list(map(torch.tensor, train_success.values()))), "train_success.pt")
# torch.save(torch.tensor(list(train_success.keys())), "train_ids.pt")

In [21]:
# torch.save(torch.stack(list(map(torch.tensor, valid_success.values()))), "valid_success.pt")
# torch.save(torch.tensor(list(valid_success.keys())), "valid_ids.pt")

# Save and Hear Snippet

In [22]:
# def normalise(random_snippet, ep):
#     # we assume torch.rand inits to [0, 1)
#     res = random_snippet * ep * 2 - ep
#     print(f"Normalised, Min {torch.min(res)}, Max {torch.max(res)}")
#     return res

In [23]:
# Save snippet to wav file
# save_audio(snippet, f"./snippets/clamp_{CLAMP_EP}_{PREPARE_METHOD.name}_snippet_only.wav")

In [24]:
# save_audio(PREPARE_METHOD(snippet.to("cpu"), tedlium_test[2]["audio"].unsqueeze(0)), f"./snippets/clamp_{CLAMP_EP}_{PREPARE_METHOD.name}_combined.wav")