In [2]:
from src.util_classes import get_train_test_dataloaders, LogMelspec, train
from src.configs import TaskConfig, StreamingTaskConfig
from src.model import CRNN, StreamingCRNN
from src.wandb_pipeline import train_baseline

import torch

# Configuration of the baseline model: both size settings are fixed to 32.
baseline_task_config = TaskConfig(hidden_size=32, bottleneck_size=32)

# Train / validation loaders for the "speech_commands" data, built from the
# baseline configuration.
train_loader, val_loader = get_train_test_dataloaders(
    "speech_commands",
    baseline_task_config,
)

# Two LogMelspec front-ends sharing the baseline configuration: the first is
# created with is_train=True, the second with is_train=False.
baseline_melspec_train, baseline_melspec_val = (
    LogMelspec(is_train=is_train, config=baseline_task_config)
    for is_train in (True, False)
)

In [3]:
import wandb
from src.wandb_pipeline import evaluate_model

# Build the baseline (teacher) network directly on the CPU. The original cell
# moved it to CUDA and then straight back with `.cpu()` before its first use,
# so the final device is the same but no GPU is needed to run this cell.
teacher = CRNN(baseline_task_config)

# Download the checkpoint logged by the referenced W&B run. `wandb.restore`
# hands back an open file object (or None when the file is missing); take the
# local path from it and close the handle instead of leaking it.
checkpoint_file = wandb.restore('baseline.pt', run_path="broccoliman/kws/381ohren")
if checkpoint_file is None:
    raise FileNotFoundError(
        "wandb.restore could not find 'baseline.pt' in run broccoliman/kws/381ohren"
    )
checkpoint_path = checkpoint_file.name
checkpoint_file.close()

# Load from the path W&B actually wrote to (not a guessed relative name), and
# map tensors to the CPU so the load works regardless of the saving device.
teacher.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

# The validation front-end must be on the same device as the model.
baseline_melspec_val.melspec.cpu()

# Reference metrics of the unmodified teacher, measured on the CPU.
evaluate_model(teacher, val_loader, baseline_melspec_val, "cpu")

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

{'area under FA/FR curve': 2.6525748998086058e-05,
 'evaluation time (s)': 1.1379876136779785,
 'memory size (MB)': 0.10112476348876953,
 'number of parameters': 25387,
 'MACs': 54911167.058823526}

In [None]:
# Dynamic quantization is applied to a CPU model here. Move the teacher (and
# the validation front-end) to the CPU inside this cell, so it no longer
# depends on the in-place `.cpu()` side effects of the previous evaluation
# cell. Both calls are no-ops when the objects are already on the CPU.
teacher_fp16 = torch.quantization.quantize_dynamic(teacher.cpu(), dtype=torch.float16)
baseline_melspec_val.melspec.cpu()

# Same evaluation as for the full-precision teacher, for a direct comparison.
evaluate_model(teacher_fp16, val_loader, baseline_melspec_val, "cpu")

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

{'area under FA/FR curve': 4.7638932340094716e-05,
 'evaluation time (s)': 1.2087290287017822,
 'memory size (MB)': 0.10608959197998047,
 'number of parameters': 808,
 'MACs': 20102823.529411763}

In [None]:
from src.configs import DistillTaskConfig
from src.wandb_pipeline import train_distillation

# Student configuration. Compared with the baseline config defined earlier
# (hidden_size=32, bottleneck_size=32) the student uses smaller sizes.
distill_config = DistillTaskConfig(
    # model size
    hidden_size=24,
    bottleneck_size=16,
    cnn_out_channels=6,
    # mel-spectrogram framing
    melspec_win_length=400,
    melspec_hop_length=160,
    # loss weights -- presumably for the KL and attention distillation terms
    # that show up as train/kl_loss and train/attn_loss in the logs; confirm
    # against train_distillation
    distill_w=0.25,
    attn_distill_w=0.05,
    # optimisation schedule
    num_epochs=50,
    use_scheduler=True,
)

# Student front-ends built from the student configuration: the first with
# is_train=True, the second with is_train=False.
student_melspec_train, student_melspec_val = (
    LogMelspec(is_train=is_train, config=distill_config)
    for is_train in (True, False)
)

# Run distillation from the teacher and keep the returned student. The
# baseline training front-end is passed alongside the two student front-ends;
# the run is logged to W&B under the given name.
student = train_distillation(
    teacher,
    train_loader,
    val_loader,
    baseline_melspec_train,
    student_melspec_train,
    student_melspec_val,
    distill_config,
    log_wandb=True,
    name_wandb="triple-50-epochs",
)

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train/attn_loss,████▇▇▆▅▆▄▄▄▃▁▂▁
train/cls_loss,█▇▇▇▇▇▆▇▅▃▄▃▂▃▃▁
train/kl_loss,▇▇███▆▆▇▆▃▅▄▄▃▄▁
train/loss,█▇██▇▇▆▆▆▃▄▃▃▃▃▁
train/lr,████▇▇▇▆▆▅▅▄▃▃▂▁

0,1
train/attn_loss,54.29375
train/cls_loss,16.37372
train/kl_loss,16.4942
train/loss,23.21196
train/lr,0.0003


Number of trainable parameters: 14321
EPOCH: 0


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 1


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 2


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 3


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 4


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 5


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 6


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 7


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 8


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 9


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 10


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 11


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 12


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 13


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 14


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 15


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 16


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 17


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 18


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 19


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 20


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 21


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 22


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 23


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 24


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 25


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 26


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 27


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 28


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 29


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 30


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 31


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 32


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 33


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 34


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 35


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 36


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 37


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 38


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 39


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 40


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 41


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 42


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 43


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 44


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 45


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 46


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 47


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 48


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

EPOCH: 49


  0%|          | 0/405 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

VBox(children=(Label(value='2.955 MB of 2.955 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
test: MACs,▁
test: area under FA/FR curve,▁
test: evaluation time (s),▁
test: memory size (MB),▁
test: number of parameters,▁
train/attn_loss,█▆▄▄▄▂▄▂▃▃▂▂▃▂▂▃▂▂▁▃▂▂▄▂▂▂▁▄▃▃▂▁▂▃▂▃▂▂▄▂
train/cls_loss,█▆▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▂▁▁▁▁▁▁▁▂▁▁▁▁▁
train/kl_loss,█▆▄▄▃▁▃▁▄▃▃▂▂▃▁▄▂▄▂▂▂▁▄▂▃▁▁▄▃▄▄▁▂▄▃▃▃▂▅▃
train/loss,█▆▄▄▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▂▂▂▂▂▁▁▁▂▂▁▁▂▁

0,1
epoch,49.0
test: MACs,34605284.70588
test: area under FA/FR curve,5e-05
test: evaluation time (s),1.13919
test: memory size (MB),0.05901
test: number of parameters,14321.0
train/attn_loss,49.65679
train/cls_loss,4.80611
train/kl_loss,11.38213
train/loss,10.13448


In [None]:
# The evaluation below is requested on the "cpu" device, so relocate the
# student's validation front-end there before measuring.
student_melspec_val.melspec.to("cpu")

# Metrics of the distilled student; the model is moved to the CPU as well.
evaluate_model(student.to("cpu"), val_loader, student_melspec_val, "cpu")

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

{'area under FA/FR curve': 5.3522934255080735e-05,
 'evaluation time (s)': 1.185582160949707,
 'memory size (MB)': 0.06118488311767578,
 'number of parameters': 14321,
 'MACs': 34605284.705882356}

In [None]:
# Dynamic quantization is applied to a CPU model here. Move the student (and
# its validation front-end) to the CPU inside this cell, so it no longer
# depends on the previous evaluation cell having done so in place. Both calls
# are no-ops when the objects are already on the CPU.
student_fp16 = torch.quantization.quantize_dynamic(student.cpu(), dtype=torch.float16)
student_melspec_val.melspec.cpu()

# Same evaluation as for the full-precision student, for a direct comparison.
evaluate_model(student_fp16, val_loader, student_melspec_val, "cpu")

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

{'area under FA/FR curve': 5.3516966707499614e-05,
 'evaluation time (s)': 1.3515667915344238,
 'memory size (MB)': 0.06379222869873047,
 'number of parameters': 606,
 'MACs': 15077117.647058824}