# Transfer Learning

## Load Libraries and Define Helper Functions



In [1]:
import importlib
import sys
import os
import numpy as np
import time
from pathlib import Path
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset, TensorDataset, DataLoader

from torch.utils import data
from torchvision import transforms

sys.path.append(os.path.join(os.getcwd(), ".."))

from distiller import apputils
import ai8x

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

kws20 = importlib.import_module("datasets.kws20-horsecough")

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, precision_recall_fscore_support, roc_auc_score, roc_curve


import librosa
import random
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

# FIX SEEDS FOR REPRODUCIBILITY
# Seeding torch alone is not enough: numpy and the stdlib `random` module are
# imported above as well, so seed all three global RNGs from the same value
# so that any use of them is repeatable across runs.
seed = 69
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

def rescale(audio, min_val=-1, max_val=1):
    """Remove the DC offset from *audio* and peak-normalise it.

    The mean is subtracted first. The result is then divided by its largest
    absolute sample, so the output's peak magnitude is exactly 1 and every
    value lies in [-1, 1].

    Args:
        audio: array-like of samples (ndarray or sequence).
        min_val, max_val: unused; kept only for call compatibility with
            ``rescale2``. The output range is always [-1, 1].

    Returns:
        An ndarray with the same shape as *audio*. A constant (silent) input
        returns all zeros instead of the NaNs a 0/0 division would produce.
    """
    sig = np.asarray(audio)
    sig = sig - np.average(sig)  # REMOVE DC COMPONENT

    # After mean removal max(sig) >= 0 >= min(sig), so the larger of
    # max(sig) and |min(sig)| is simply the peak absolute value.
    peak = np.max(np.abs(sig))
    if peak == 0:
        return np.zeros_like(sig)
    return sig / peak

def rescale2(audio, min_val=-1,max_val=1):
    """Linearly min-max scale *audio* into ``[min_val, max_val]``.

    Uses scikit-learn's ``MinMaxScaler``, fitted on this one signal only.

    Args:
        audio: ndarray of samples (must support ``.reshape``).
        min_val, max_val: target output range.

    Returns:
        A 1-D ndarray of the scaled samples.
    """
    # MinMaxScaler works on 2-D (n_samples, n_features) data: one feature column.
    column = audio.reshape(-1, 1)
    fitted = MinMaxScaler(feature_range=(min_val, max_val)).fit(column)
    scaled_column = np.array(fitted.transform(column))
    return scaled_column[:, 0]


def plot_confusion(y_true, y_pred, classes):
    """Print (not plot) the confusion matrix for labels ``0 .. len(classes)-1``.

    Only ``len(classes)`` is used, to fix the label set and order.
    """
    label_ids = [*range(len(classes))]
    print(confusion_matrix(y_true=y_true, y_pred=y_pred, labels=label_ids))

def plot_roc_curve(true_y, y_prob):
    """Draw a ROC curve on the current matplotlib axes.

    A red dashed diagonal is added as the chance-level reference line.

    Args:
        true_y: ground-truth binary labels.
        y_prob: scores / probabilities for the positive class.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(true_y, y_prob)
    plt.plot(false_pos_rate, true_pos_rate)

    # Chance-level diagonal from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], 'r--')

    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')

def custom_dataloader(data, train_idx, val_idx, test_idx, batch_size=2048,
                      num_workers=4, pin_memory=True):
    """Build train/validation/test DataLoaders from an in-memory ``(x, y)`` pair.

    Args:
        data: indexable pair. ``data[0]`` holds the inputs and ``data[1]`` the
            integer class labels; both are indexed with the ``*_idx`` arguments.
        train_idx, val_idx, test_idx: indices selecting each split.
        batch_size: batch size used for all three loaders.
        num_workers, pin_memory: forwarded to every ``DataLoader``. The
            defaults equal the previously hard-coded values.

    Returns:
        ``(train_loader, val_loader, test_loader)``. No loader shuffles.
        Inputs are float32 tensors and labels are int64 tensors.
    """
    inputs, labels = data[0], data[1]

    def _make_loader(idx):
        # torch.tensor always copies, so the dataset never aliases the caller's arrays.
        x = torch.tensor(np.asarray(inputs[idx]), dtype=torch.float32)

        y = np.asarray(labels[idx])
        # Drop singleton axes such as (n, 1) -> (n,), but never the sample axis.
        # A bare np.squeeze turned a one-sample split into a 0-d array, which
        # TensorDataset rejects.
        trailing_singletons = tuple(ax for ax in range(1, y.ndim) if y.shape[ax] == 1)
        if trailing_singletons:
            y = np.squeeze(y, axis=trailing_singletons)
        # Convert straight to int64. Going through float32 (torch.Tensor)
        # loses precision for large label values.
        y = torch.tensor(y, dtype=torch.int64)

        return DataLoader(TensorDataset(x, y), batch_size=batch_size,
                          num_workers=num_workers, pin_memory=pin_memory)

    test_loader = _make_loader(test_idx)
    val_loader = _make_loader(val_idx)
    train_loader = _make_loader(train_idx)

    return train_loader, val_loader, test_loader
    

## Calculate class weights

In [3]:
from pathlib import Path

# Class weight = (smallest per-class entry count) / (this class's entry count).
# The smallest class therefore gets 1.0 and larger classes are down-weighted.
raw_data_path = Path("../data/KWS_EQUINE/raw/")

# Every sub-directory of raw/ is a class, except the two aggregate folders.
class_dirs = [
    d
    for d in raw_data_path.iterdir()
    if d.is_dir() and d.stem not in {"__combinedkws", "__human_cough"}
]

# Number of directory entries per class.
class_file_count = {d: len(list(d.iterdir())) for d in class_dirs}

min_file_count = float(min(class_file_count.values()))

# Replace each count with its weight in place (same dict) and report it.
for d in class_dirs:
    class_file_count[d] = min_file_count / class_file_count[d]
    print(f"{d.stem}: {round(class_file_count[d], 7)}")

combined_others: 1.0
human_cough_v2: 0.98327


## Generate processed dataset

In [7]:
# NOTE: Smart compressed file creation
# Change class dicts of main Dataloader class
# NOTE(review): the recorded run of this cell failed with
#   TypeError: KWS_HORSE_TF_get_datasets() got an unexpected keyword argument 'download'
# so something on this call path passes `download=` to the dataset factory.
# The loaders used by later cells presumably come from an earlier session.
# Confirm the factory's signature.
train_batch_size = 128
train_loader, val_loader, test_loader, _ = apputils.get_data_loaders(
    kws20.KWS_HORSE_TF_get_datasets,
    ("../data", False),
    train_batch_size,
    1,
    validation_split=0.1,
)

print(
    "Dataset sizes:"
    f"\n\ttraining={len(train_loader.sampler)}"
    f"\n\tvalidation={len(val_loader.sampler)}"
    f"\n\ttest={len(test_loader.sampler)}"
)

TypeError: KWS_HORSE_TF_get_datasets() got an unexpected keyword argument 'download'

In [4]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(f'Running on device: {device}')

Running on device: cpu


In [5]:
# Output classes of the new transfer-learning head, in label-index order.
# NOTE(review): the raw class folders reported by the weights cell are
# combined_others / human_cough_v2. Confirm this order matches the dataset's labels.
classes = ["combined", "horse_cough"]

In [6]:
def count_params(model):
    """Return the number of trainable (``requires_grad``) parameter elements in *model*.

    ``Tensor.numel()`` keeps the result an exact Python ``int``.
    ``np.prod(p.size())`` returned numpy scalars instead, and ``1.0`` (a float)
    for 0-d parameters.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [7]:
# Target the MAX78000 (device 85), running on the host (no simulation).
ai8x.set_device(device=85, simulate=False, round_avg=False)

mod = importlib.import_module("models.ai85net-kws20-kabayo")

# Build the KWS20 v3 network with its 21-class head, presumably to match the
# pretrained checkpoint below. The head is swapped out in a later cell.
model = mod.AI85KWS20Netv3(
    num_classes=21,
    num_channels=128,
    dimensions=(128, 1),
    bias=False,
)
print(f'Number of Model Params: {count_params(model)}')

checkpoint_result = apputils.load_checkpoint(model, "../logs/kws20/qat_best.pth.tar")
model, compression_scheduler, optimizer, start_epoch = checkpoint_result

Configuring device: MAX78000, simulate=False.
Number of Model Params: 173568


## Replace the FC layer and freeze the early feature layers (voice_conv1-4 and kws_conv1)

In [8]:
def freeze_layer(layer):
    """Stop gradient updates for every parameter of *layer*, including sub-modules.

    Returns None; the module is modified in place.
    """
    layer.requires_grad_(False)

In [9]:
# Freeze the pretrained front end: voice_conv1-4 plus the first kws_conv layer.
# kws_conv2-4 stay trainable (add them here to freeze more of the network).
for frozen_name in ("voice_conv1", "voice_conv2", "voice_conv3", "voice_conv4", "kws_conv1"):
    freeze_layer(getattr(model, frozen_name))

# Replace the 21-class head with a fresh, trainable len(classes)-way classifier.
model.fc = ai8x.Linear(256, len(classes), bias=False, wide=True)

model = model.to(device)

## Train the model

In [10]:
num_epochs = 500
epoch = 0

optimizer = optim.Adam(model.parameters(), lr=0.0001)
# The LR is halved when the scheduler has been stepped 20 and 80 times
# (it is stepped once per epoch in the training loop).
ms_lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 80], gamma=0.5)

# Equal per-class weights for the two-class head.
criterion = torch.nn.CrossEntropyLoss(weight=torch.Tensor((1, 1)))
criterion.to(device)

# Quantization-aware training begins at epoch 20 with 8-bit weights.
qat_policy = {'start_epoch': 20, 'weight_bits': 8}

In [11]:
from sklearn.metrics import confusion_matrix

def plot_confusion(y_true, y_pred, classes):
    """Print (not plot) the confusion matrix for labels ``0 .. len(classes)-1``.

    ``labels`` must be passed by keyword. In current scikit-learn it is
    keyword-only, and a third positional argument raises TypeError.
    """
    cf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    print(cf_matrix)

In [12]:
# for i, (input, target) in enumerate(train_loader):
#     print(np.shape(input), np.shape(target), sep=" ")

In [13]:
best_acc = 0
best_qat_acc = 0  # NOTE(review): never read; best_acc is reset at the QAT switch instead
for epoch in range(num_epochs):
    if epoch > 0 and epoch == qat_policy['start_epoch']:
        print('QAT is starting!')
        # Fuse the BN parameters into conv layers before Quantization Aware Training (QAT)
        ai8x.fuse_bn_layers(model)

        # Switch model from unquantized to quantized for QAT
        ai8x.initiate_qat(model, qat_policy)

        # Model is re-transferred to GPU in case parameters were added
        # NOTE(review): `optimizer` was built from the pre-QAT parameter list.
        # Any parameters added here would not be optimized; confirm.
        model.to(device)

    # ---- training pass ----
    running_loss = []
    train_start = time.time()
    model.train()
    for inputs, target in train_loader:
        inputs = inputs.to(device)
        target = target.to(device)
        optimizer.zero_grad()

        model_out = model(inputs)

        loss = criterion(model_out, target)
        loss.backward()
        optimizer.step()

        running_loss.append(loss.item())

    mean_loss = np.mean(running_loss)
    train_end = time.time()
    # get_last_lr() reports the LR actually used this epoch. get_lr() is an
    # internal hook: called from outside step() it warns, and at a milestone
    # epoch it returns the LR decayed a second time.
    print("Epoch: {}/{}\t LR: {}\t Train Loss: {:.4f}\t Dur: {:.2f} sec.".format(
        epoch+1, num_epochs, ms_lr_scheduler.get_last_lr(), mean_loss, (train_end-train_start)))

    # ---- evaluation pass ----
    # NOTE(review): the best checkpoint is chosen on test_loader, so the
    # reported best test accuracy is optimistically biased. Selecting on
    # val_loader and reporting test separately would avoid that.
    model.eval()
    correct = 0
    total = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for inputs, target in test_loader:
            inputs = inputs.to(device)
            target = target.to(device)
            model_out = model(inputs)
            target_out = torch.argmax(model_out, dim=1)

            y_pred.extend(target_out.cpu().numpy())
            y_true.extend(target.cpu().numpy())

            # Exact integer counts instead of a float ratio re-multiplied by batch size.
            correct += (target_out == target).sum().item()
            total += target_out.numel()

        total_acc = 100 * (correct / total)

        # Restart best-model tracking when QAT begins, so the quantized model
        # can be checkpointed even if it scores below the best float model.
        if epoch == qat_policy['start_epoch']:
            best_acc = 0
        if total_acc > best_acc:
            best_acc = total_acc
            checkpoint_extras = {'current_top1': best_acc,
                                 'best_top1': best_acc,
                                 'best_epoch': epoch}
            model_name = 'ai85net_kws_equine'
            model_prefix = f'{model_name}' if epoch < qat_policy['start_epoch'] else (f'qat_{model_name}')
            apputils.save_checkpoint(epoch, model_name, model, optimizer=optimizer,
                                     scheduler=None, extras=checkpoint_extras,
                                     is_best=True, name=model_prefix,
                                     dir='.')
            print(f'Best model saved with accuracy: {best_acc:.2f}%')

        print('\t\t Test Acc: {:.2f}'.format(total_acc))
        print("\t\tConfusion:")
        plot_confusion(y_true, y_pred, classes)
    ms_lr_scheduler.step()

Epoch: 1/50	 LR: [0.0001]	 Train Loss: 0.6912	 Dur: 11.03 sec.
Best model saved with accuracy: 76.67%
		 Test Acc: 76.67
		Confusion:
[[ 0  6]
 [ 1 23]]
Epoch: 2/50	 LR: [0.0001]	 Train Loss: 0.6823	 Dur: 10.50 sec.
		 Test Acc: 76.67
		Confusion:
[[ 0  6]
 [ 1 23]]
Epoch: 3/50	 LR: [0.0001]	 Train Loss: 0.6639	 Dur: 10.55 sec.
Best model saved with accuracy: 80.00%
		 Test Acc: 80.00
		Confusion:
[[ 1  5]
 [ 1 23]]
Epoch: 4/50	 LR: [0.0001]	 Train Loss: 0.6553	 Dur: 10.40 sec.
Best model saved with accuracy: 83.33%
		 Test Acc: 83.33
		Confusion:
[[ 2  4]
 [ 1 23]]
Epoch: 5/50	 LR: [0.0001]	 Train Loss: 0.6532	 Dur: 10.69 sec.
		 Test Acc: 80.00
		Confusion:
[[ 2  4]
 [ 2 22]]
Epoch: 6/50	 LR: [0.0001]	 Train Loss: 0.6287	 Dur: 11.13 sec.
		 Test Acc: 80.00
		Confusion:
[[ 3  3]
 [ 3 21]]
Epoch: 7/50	 LR: [0.0001]	 Train Loss: 0.6414	 Dur: 10.39 sec.
		 Test Acc: 83.33
		Confusion:
[[ 4  2]
 [ 3 21]]
Epoch: 8/50	 LR: [0.0001]	 Train Loss: 0.6336	 Dur: 10.92 sec.
		 Test Acc: 76.67
		C