# Transfer Learning

## Load Libraries



In [1]:
import importlib
import sys
import os
import numpy as np
import time

import torch
from torch import nn
import torch.optim as optim

from torch.utils import data
from torchvision import transforms

sys.path.append(os.path.join(os.getcwd(), ".."))

from distiller import apputils
import ai8x

kws20 = importlib.import_module("datasets.kws20-horsecough")

## Calculate weights

In [2]:
from pathlib import Path

raw_data_path = Path("../data/KWS/raw/")
class_file_count = {}

class_dirs = [d for d in raw_data_path.iterdir() if d.is_dir() and d.stem != "_background_noise_"]

for d in class_dirs:
    class_file_count[d] = len(list(d.iterdir()))

min_file_count = float(min(class_file_count.values()))

for d in class_dirs:
    class_file_count[d] = min_file_count / class_file_count[d]
    print(f"{d.stem}: {round(class_file_count[d], 7)}")

backward: 0.0120192
bed: 0.0099305
bird: 0.0096899
cat: 0.0098474
dog: 0.0093985
down: 0.0051059
eight: 0.0052812
five: 0.0049358
follow: 0.0126662
forward: 0.0128452
four: 0.0053648
go: 0.0051546
happy: 0.0097371
horse_cough: 0.3773585
horse_neigh: 0.7407407
house: 0.0094652
human_cough: 1.0
learn: 0.0126984
left: 0.0052618
marvin: 0.0095238
nine: 0.0050839
no: 0.0050749
off: 0.0053405
on: 0.0052016
one: 0.0051414
right: 0.0052938
seven: 0.0050025
sheila: 0.0098912
six: 0.0051813
stop: 0.0051653
three: 0.0053662
tree: 0.0113701
two: 0.0051546
up: 0.005372
visual: 0.0125628
wow: 0.0094206
yes: 0.0049456
zero: 0.0049358


## Generate processed dataset

In [3]:
train_batch_size = 128
train_loader, val_loader, test_loader, _ = apputils.get_data_loaders(
    kws20.KWS_HORSE_TF_get_datasets, ("../data", False), train_batch_size, 1, validation_split=0)
print(f"Dataset sizes:\n\ttraining={len(train_loader.sampler)}\n\tvalidation={len(val_loader.sampler)}\n\ttest={len(test_loader.sampler)}")

No key `noise_var` in input augmentation dictionary!  Using defaults: [Min: 0., Max: 1.]
No key `shift` in input augmentation dictionary! Using defaults: [Min:-0.1, Max: 0.1]
No key `strech` in input augmentation dictionary! Using defaults: [Min: 0.8, Max: 1.3]
Using downloaded and verified file: ../data\KWS\raw\speech_commands_v0.02.tar.gz
Extracting ../data\KWS\raw\speech_commands_v0.02.tar.gz to ../data\KWS\raw
Generating dataset from raw data samples for the first time. 
This process will take significant time (~60 minutes)...
data_len: 16384
------------- Label Size ---------------
backward:  	1664
bed     :  	2014
bird    :  	2064
cat     :  	2031
dog     :  	2128
down    :  	3917
eight   :  	3787
five    :  	4052
follow  :  	1579
forward :  	1557
four    :  	3728
go      :  	3880
happy   :  	2054
horse_cough:  	53
horse_neigh:  	27
house   :  	2113
human_cough:  	20
learn   :  	1575
left    :  	3801
marvin  :  	2100
nine    :  	3934
no      :  	3941
off     :  	3745
on      :  	

AttributeError: 'bool' object has no attribute 'truncate_testset'

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

In [None]:
classes = ["horse_cough", "horse_neigh", "human_cough"]

In [None]:
def count_params(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params

In [None]:
ai8x.set_device(device=85, simulate=False, round_avg=False)

mod = importlib.import_module("models.ai85net-kws20-kabayo")

model = mod.AI85KWS20Netv3(num_classes=21, num_channels=128, dimensions=(128, 1), bias=False)
print(f'Number of Model Params: {count_params(model)}')

# model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
#             model, "logs/2023.01.10-013008/qat_best.pth.tar", model_device='cuda')

model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
            model, "logs/2023.01.23-053753/qat_best.pth.tar")

## Replace FC layer and freeze the rest of the layers

In [None]:
def freeze_layer(layer):
    for p in layer.parameters():
        p.requires_grad = False

In [None]:
freeze_layer(model.voice_conv1)
freeze_layer(model.voice_conv2)
freeze_layer(model.voice_conv3)
freeze_layer(model.voice_conv4)
freeze_layer(model.kws_conv1)
# freeze_layer(model.kws_conv2)
# freeze_layer(model.kws_conv3)
# freeze_layer(model.kws_conv4)
model.fc = ai8x.Linear(256, 8, bias=False, wide=True)

model = model.to(device)

## Train the model

In [None]:
num_epochs = 20
epoch = 0
optimizer = optim.Adam(model.parameters(), lr=0.0001)
ms_lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 80], gamma=0.5)
criterion = torch.nn.CrossEntropyLoss(
    torch.Tensor((1, 1, 0.01, 1, 1, 0.01, 0.01, 0.01))
)
criterion.to(device)

qat_policy = {
    'start_epoch': 20,
    'weight_bits': 8
    }

In [None]:
from sklearn.metrics import confusion_matrix

def plot_confusion(y_true, y_pred, classes):
    cf_matrix = confusion_matrix(y_true, y_pred, list(range(len(classes))))
    print(cf_matrix)

In [None]:
best_acc = 0
best_qat_acc = 0
for epoch in range(0, num_epochs):
    if epoch > 0 and epoch == qat_policy['start_epoch']:
        print('QAT is starting!')
        # Fuse the BN parameters into conv layers before Quantization Aware Training (QAT)
        ai8x.fuse_bn_layers(model)

        # Switch model from unquantized to quantized for QAT
        ai8x.initiate_qat(model, qat_policy)

        # Model is re-transferred to GPU in case parameters were added
        model.to(device)
    running_loss = []
    train_start = time.time()
    model.train()
    for idx, (inputs, target) in enumerate(train_loader):
        inputs = inputs.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        
        model_out = model(inputs)
        
        loss = criterion(model_out, target)
        loss.backward()
        optimizer.step()
        
        running_loss.append(loss.cpu().detach().numpy())

    mean_loss = np.mean(running_loss)
    train_end = time.time()
    print("Epoch: {}/{}\t LR: {}\t Train Loss: {:.4f}\t Dur: {:.2f} sec.".format(
        epoch+1, num_epochs, ms_lr_scheduler.get_lr(), mean_loss, (train_end-train_start)))
    
    model.eval()
    acc = 0.
    acc_weight = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for inputs, target in test_loader:
            inputs = inputs.to(device)
            target = target.to(device)
            model_out = model(inputs)
            target_out = torch.argmax(model_out, dim=1)
            
            y_pred.extend(target_out.cpu().numpy())
            y_true.extend(target.cpu().numpy())
            
            tp = torch.sum(target_out == target)
            acc_batch = (tp / target_out.numel()).detach().item()
            acc += target_out.shape[0] * acc_batch
            acc_weight += target_out.shape[0]
            
        total_acc = 100 * (acc / acc_weight)
        if epoch == qat_policy['start_epoch']: best_acc = 0
        if total_acc > best_acc:
            best_acc = total_acc
            checkpoint_extras = {'current_top1': best_acc,
                                 'best_top1': best_acc,
                                 'best_epoch': epoch}
            model_name = 'ai85net_kws_dash'
            model_prefix = f'{model_name}' if epoch < qat_policy['start_epoch'] else (f'qat_{model_name}')
            apputils.save_checkpoint(epoch, model_name, model, optimizer=optimizer,
                                     scheduler=None, extras=checkpoint_extras,
                                     is_best=True, name=model_prefix,
                                     dir='.')
            print(f'Best model saved with accuracy: {best_acc:.2f}%')
            
        print('\t\t Test Acc: {:.2f}'.format(total_acc))
        print("\t\tConfusion:")
        plot_confusion(y_true, y_pred, classes)
    ms_lr_scheduler.step()