# Transfer Learning



## Load Libraries



In [1]:
import importlib
import sys
import os
import numpy as np
import time
from pathlib import Path
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset, TensorDataset, DataLoader

from torch.utils import data
from torchvision import transforms

sys.path.append(os.path.join(os.getcwd(), ".."))

from distiller import apputils
import ai8x

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

kws20 = importlib.import_module("datasets.kws20-horsecough")

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, precision_recall_fscore_support, roc_auc_score, roc_curve


import librosa
import random
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

# FIX SEED FOR REPRODUCIBILITY
seed = 69
torch.manual_seed(seed)

<torch._C.Generator at 0x18da83dad90>

## Function Declarations

In [2]:
def rescale(audio, min_val=-1,max_val=1):
    sig = audio
    mean = np.average(sig)

    sig = sig-mean # REMOVE DC COMPONENT

    sig_max = np.max(sig)
    sig_min = np.min(sig)

    if sig_max >= np.abs(sig_min):
        sig_scaled = sig/sig_max
    else:
        sig_scaled = sig/np.abs(sig_min)

    return sig_scaled

def rescale2(audio, min_val=-1,max_val=1):
    scaler = MinMaxScaler(feature_range=(min_val,max_val))
    audio = audio.reshape(-1,1)
    scaler.fit(audio)
    scaled = np.array(scaler.transform(audio))
    
    return scaled[:,0]


def plot_confusion(y_true, y_pred, classes):
    cf_matrix = confusion_matrix(y_true = y_true, y_pred = y_pred, labels =list(range(len(classes))))
    print(cf_matrix)

def plot_roc_curve(true_y, y_prob):
    fpr, tpr, thresholds = roc_curve(true_y, y_prob)
    plt.plot(fpr, tpr)
    plt.plot([0,1], [0,1], 'r--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')

def custom_dataloader(data, train_idx, val_idx, test_idx, batch_size = 2048):
    data_file = data
    x = np.asarray(data_file[0][test_idx])
    y = np.squeeze(np.asarray(data_file[1][test_idx]))
    test_set = TensorDataset(torch.Tensor(x),torch.Tensor(y).to(dtype =torch.int64))
    test_loader = DataLoader(test_set,batch_size=batch_size, num_workers=4, pin_memory=True)

    x = np.asarray(data_file[0][val_idx])
    y = np.squeeze(np.asarray(data_file[1][val_idx]))
    val_set = TensorDataset(torch.Tensor(x),torch.Tensor(y).to(dtype =torch.int64))
    val_loader = DataLoader(val_set,batch_size=batch_size, num_workers=4, pin_memory=True)

    x = np.asarray(data_file[0][train_idx])
    y = np.squeeze(np.asarray(data_file[1][train_idx]))
    train_set = TensorDataset(torch.Tensor(x),torch.Tensor(y).to(dtype =torch.int64))
    train_loader = DataLoader(train_set,batch_size=batch_size, num_workers=4, pin_memory=True)

    return train_loader,val_loader,test_loader

def count_params(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print(params)
    return params

def freeze_layer(layer):
    for p in layer.parameters():
        p.requires_grad = False

## Prepare Checkpoint, Plot, and Dataset Folders

In [3]:
model_name = 'human_others'
classes = ["combined",'cough']
checkpoint_dir = './checkpoints/'+model_name+'/'

indexer = 0
while os.path.exists(checkpoint_dir):
    model_name = 'human_others'
    model_name = model_name + '_' + str(indexer)
    checkpoint_dir = './checkpoints/'+model_name+'/'
    indexer += 1

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

print('Model Name: ', model_name)
print('Classes: ', classes)
print('Checkpoint Dir: ', checkpoint_dir)


Model Name:  human_others_13
Classes:  ['combined', 'cough']
Checkpoint Dir:  ./checkpoints/human_others_13/


## Calculate weights

In [4]:
raw_data_path = Path("../data/KWS_EQUINE/raw/")
class_file_count = {}

class_dirs = [d for d in raw_data_path.iterdir() if d.is_dir() and (d.stem != "__combinedkws") and (d.stem != "__combined_others") and (d.stem != "__human_cough")and (d.stem != "__human_cough_v2")]

for d in class_dirs:
    print(d)
    class_file_count[d] = len(list(d.iterdir()))

min_file_count = float(min(class_file_count.values()))

# Calculate weights
class_weights = []
for d in class_dirs:
    class_file_count[d] = min_file_count / class_file_count[d]
    print(f"{d.stem}: {round(class_file_count[d], 7)}")
    class_weights.append(round(class_file_count[d], 7))
print('Weights: ',class_weights)

..\data\KWS_EQUINE\raw\combined_others
..\data\KWS_EQUINE\raw\human_cough_v2
combined_others: 1.0
human_cough_v2: 0.98327
Weights:  [1.0, 0.98327]


## Generate Processed Datafile

In [6]:
train_batch_size = 4096

processed_data_path = Path("../data/KWS_EQUINE/processed/")
fname = model_name+'.pt'
_, _, _, _ = apputils.get_data_loaders(kws20.KWS_HORSE_TF_get_datasets,("../data", True), 
                                       train_batch_size, 4, validation_split=0.1)

No key `noise_var` in input augmentation dictionary!  Using defaults: [Min: 0., Max: 1.]
No key `shift` in input augmentation dictionary! Using defaults: [Min:-0.1, Max: 0.1]
No key `strech` in input augmentation dictionary! Using defaults: [Min: 0.8, Max: 1.3]
Generating dataset from raw data samples for the first time. 
This process will take significant time (~60 minutes)...
data_len: 16384
------------- Label Size ---------------
combined_others:  	5172
human_cough_v2:  	5260
------------------------------------------
Processing the label: combined_others. 1 of 2
	1 of 5172
	1001 of 5172
	2001 of 5172
	3001 of 5172
	4001 of 5172
	5001 of 5172
Finished in 106.081 seconds.
(15516, 128, 128)
Data concatenation finished in 0.026 seconds.
Processing the label: human_cough_v2. 2 of 2
	1 of 5260
	1001 of 5260
	2001 of 5260
	3001 of 5260
	4001 of 5260
	5001 of 5260
Finished in 105.289 seconds.
(15780, 128, 128)
Data concatenation finished in 0.064 seconds.
Dataset created.
Training+Validat

## Read Processed Dataset

In [7]:
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    cpu = False
else:
     device = torch.device('cpu')
     cpu = True

print('Running on device: {}'.format(torch.cuda.get_device_name()))

processed_dir = "../data/KWS_EQUINE/processed/"

test_ratio = 0.1
val_ratio = 0.1

if os.path.exists(processed_dir+fname):
     data_file = torch.load(processed_dir+fname) # (data, class, type)
     len_dataset = len(data_file[0])

     train_indices, test_indices = train_test_split(np.arange(len_dataset),test_size=test_ratio, train_size=1-test_ratio)
     train_indices, val_indices = train_test_split(train_indices,test_size=val_ratio/(1-test_ratio), train_size=1-(val_ratio/(1-test_ratio)))

     train_loader,val_loader,test_loader = custom_dataloader(data=data_file, train_idx = train_indices,
     val_idx = val_indices, test_idx = test_indices, batch_size = train_batch_size)

     # x = np.asarray(data_file[0][test_indices])
     # y = np.squeeze(np.asarray(data_file[1][test_indices]))

     # test_set = TensorDataset(torch.Tensor(x),torch.Tensor(y).to(dtype =torch.int64))
     # test_loader = DataLoader(test_set,batch_size=train_batch_size, num_workers=4, pin_memory=True)

     # x = np.asarray(data_file[0][val_indices])
     # y = np.squeeze(np.asarray(data_file[1][val_indices]))

     # val_set = TensorDataset(torch.Tensor(x),torch.Tensor(y).to(dtype =torch.int64))
     # val_loader = DataLoader(val_set,batch_size=train_batch_size, num_workers=4, pin_memory=True)
     
     # x = np.asarray(data_file[0][train_indices])
     # y = np.squeeze(np.asarray(data_file[1][train_indices]))

     # train_set = TensorDataset(torch.Tensor(x),torch.Tensor(y).to(dtype =torch.int64))
     # train_loader = DataLoader(train_set,batch_size=train_batch_size, num_workers=4, pin_memory=True)
     
     print('Dataloaders Created')
     print('Train Loader Size: ',len(train_loader.dataset))
     print('Validation Loader Size: ',len(val_loader.dataset))
     print('Test Loader Size: ',len(test_loader.dataset))
else:
     print('Dataset does not Exist')
    

Running on device: NVIDIA GeForce RTX 2070 SUPER
Dataloaders Created
Train Loader Size:  25036
Validation Loader Size:  3130
Test Loader Size:  3130


## LOOP AT EVERY CHANGE OF FROZEN LAYER TO FIND BEST PERFORMING MODEL

In [9]:
num_epochs = 1500

criterion = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights))
print(criterion.type)
criterion.to(device)

qat_policy = {
    'start_epoch': 100,
    'weight_bits': 8
    }
print('Epochs: ', num_epochs)
# print('Optimizer: \n',optimizer)
print('Loss Function: \n',criterion)
print('QAT: \n',qat_policy)

ai8x.set_device(device=85, simulate=False, round_avg=False)

mod = importlib.import_module("models.ai85net-kws20-v3")

for fl in range (0,8):
    model = mod.AI85KWS20Netv3(num_classes=21, num_channels=128, dimensions=(128, 1), bias=False)
    print(f'Number of Model Params: {count_params(model)}')

    # WEIGHTS OF REFERENCE MODEL
    model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
                model, "../logs/kws20_original/qat_best.pth.tar")

    # FREEZE SOME LAYERS
    if fl >= 0: freeze_layer(model.voice_conv1)
    if fl >= 1: freeze_layer(model.voice_conv2)
    if fl >= 2: freeze_layer(model.voice_conv3)
    if fl >= 3: freeze_layer(model.voice_conv4)
    if fl >= 4: freeze_layer(model.kws_conv1)
    if fl >= 5: freeze_layer(model.kws_conv2)
    if fl >= 6: freeze_layer(model.kws_conv3)
    if fl >= 7: freeze_layer(model.kws_conv4)
    model.fc = ai8x.Linear(256, len(classes), bias=False, wide=True)

    model = model.to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    ms_lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[500, 1000], gamma=0.2)

    best_acc = 0
    best_qat_acc = 0
    best_loss = 0
    best_epoch = 0

    train_acc = []
    train_loss = []
    val_acc = []
    val_loss =[]
    running_class_acc = []

    

    continue_train = True
    epoch = 0
    while(epoch < num_epochs):
        train_start = time.time()
        if epoch > 0 and epoch == qat_policy['start_epoch']:
            print('QAT is starting!')
            # Fuse the BN parameters into conv layers before Quantization Aware Training (QAT)
            ai8x.fuse_bn_layers(model)
            # Switch model from unquantized to quantized for QAT
            ai8x.initiate_qat(model, qat_policy)
            # Model is re-transferred to GPU in case parameters were added
            model.to(device)

        average_acc = 0
        ############ TRAIN SECTION ############
        running_loss = []
        y_pred_train = []
        y_true_train = []
        acc_b = []
        model.train()
        for idx, (inputs, target) in enumerate(train_loader):
            inputs = inputs.to(device)
            target = target.type(torch.int64)
            target = target.to(device)
            optimizer.zero_grad()
            
            model_out = model(inputs)
            target_out = torch.argmax(model_out, dim=1)
            
            y_pred_train.extend(target_out.cpu().numpy())
            y_true_train.extend(target.cpu().numpy())
            
            tp = torch.sum(target_out == target)
            acc_b.extend([(tp / target_out.numel()).detach().item()])

            loss = criterion(model_out, target)     
            loss.backward()
            optimizer.step()
            
            running_loss.append(loss.cpu().detach().numpy())
        
        total_acc = np.mean(acc_b)*100
        mean_loss = np.mean(running_loss)

        # TRAIN ACCURACY / TRAIN LOSS
        train_acc.append(total_acc)
        train_loss.append(mean_loss)
        average_acc += total_acc*0.5

        ############ VALIDATION SECTION ############
        acc_b = []
        y_pred_val = []
        y_true_val = []
        running_v_loss = []
        
        model.eval()
        with torch.no_grad():
            for inputs, target in val_loader:
                inputs = inputs.to(device)
                target = target.to(device)
                model_out = model(inputs)
                target_out = torch.argmax(model_out, dim=1)
                
                y_pred_val.extend(target_out.cpu().numpy())
                y_true_val.extend(target.cpu().numpy())
                
                tp = torch.sum(target_out == target)
                acc_b.extend([(tp / target_out.numel()).detach().item()])

                v_loss = criterion(model_out, target)
                running_v_loss.append(loss.cpu().detach().numpy())

            total_acc = np.mean(acc_b)*100
            mean_loss = np.mean(running_v_loss)

            if epoch == qat_policy['start_epoch']: best_acc = 0
            
            # VALIDATION ACCURACY / VALIDATION LOSS
            val_acc.append(total_acc)
            val_loss.append(mean_loss)
            average_acc += total_acc*0.5
        
        ############ CLASS ACCURACY ############
        class_acc = np.zeros(len(classes))
        for class_num,class_type in enumerate(classes):
            class_count_train = y_true_train.count(class_num)
            class_count_val = y_true_val.count(class_num)

            for t_idx, targ_val in enumerate(y_true_train):
                if targ_val == y_pred_train[t_idx] and targ_val == class_num:
                    class_acc[class_num] += 1/class_count_train*10

            for t_idx, targ_val in enumerate(y_true_val):
                if targ_val == y_pred_val[t_idx] and targ_val == class_num:
                    class_acc[class_num] += 1/class_count_val*90
            
        running_class_acc.append(class_acc)
        
        train_end = time.time()
        print('---------------------------------------------')
        print("\n\n Epoch: {}/{} \tLR: {} \tDur: {:.2f} sec".format(epoch+1, num_epochs, ms_lr_scheduler.get_lr() , (train_end-train_start)))

        ############ CONFUSION MATRIX ############   
        print("\n Training - Confusion Matrix: ")
        plot_confusion(y_true_train, y_pred_train, classes)
        print("\n Validation - Confusion Matrix: ")
        plot_confusion(y_true_val, y_pred_val, classes)
        
        ############ ACC and LOSS ############  
        print('\nTrain Acc : ', train_acc[-1])
        print('Train Loss : ', train_loss[-1])
        print('Val Acc : ', val_acc[-1])
        print('Val Loss : ', val_loss[-1])

        ############ PLOTS ############
        if (epoch%10 == 0 or epoch==num_epochs-1) and epoch > 0:
            best_epoch = checkpoint_extras['best_epoch']
            
            plt.figure(figsize=(20,10),dpi=1000)
            plt.title(model_name)
            plt.subplot(1,2,1)
            plt.plot(np.asarray(running_class_acc)[:,0],color='orange')
            plt.plot(np.asarray(running_class_acc)[:,1],color='yellow')
            plt.plot(val_acc, color ='green')
            plt.plot(train_acc, color = 'red')
            plt.stem(best_epoch,train_acc[best_epoch])
            plt.legend([classes[0],classes[1],'Validation','Train','Checkpoint'])
            plt.title('Train Acc: {:.2f}   Val Acc: {:.2f}'.format(train_acc[best_epoch],val_acc[best_epoch]))
            plt.xlabel('Epochs')
            plt.ylabel('Value')

            plt.subplot(1,2,2)
            plt.plot(val_loss,color='green')
            plt.plot(train_loss,color='red')
            plt.stem(best_epoch,train_loss[best_epoch])
            plt.legend(['Validation','Train','Checkpoint'])
            plt.title('Train Loss: {:.2f}   Train Loss: {:.2f}'.format(train_loss[best_epoch],val_loss[best_epoch]))
            plt.xlabel('Epochs')
            plt.ylabel('Value')

            plt.savefig(checkpoint_dir+model_name+'_'+str(fl)+'.png')
            
            plt.clf()
            plt.cla()
            plt.close()
        
        ############ CHECKPOINT CRITERION ############   
        if average_acc > best_acc:
            best_acc = average_acc

            # SAVE CHECKPOINT
            checkpoint_extras = {'best_ave_acc': best_acc,'best_epoch': epoch}
            # model_prefix = f'{model_name}' if epoch < qat_policy['start_epoch'] else (f'qat_{model_name}')
            # apputils.save_checkpoint(epoch, model_name, model, optimizer=optimizer,
            #                             scheduler=None, extras=checkpoint_extras,
            #                             is_best=True, name=model_prefix,
            #                             dir=checkpoint_dir)

            # PLOT CONFUSION MATRIX AND STAT MEASURES ON TRAIN
            conf_mat_train = confusion_matrix(y_true_train, y_pred_train)
            cm_display_train = ConfusionMatrixDisplay(confusion_matrix = conf_mat_train, display_labels = classes)
            p_train,r_train,f1_train,_= precision_recall_fscore_support(y_true_train, y_pred_train, average=None)
            cm_display_train.plot(cmap= 'Blues',colorbar=False, values_format = 'd')
            plt.title('Preicison: ({:.2f} {:.2f})   Recall: ({:.2f} {:.2f})   F1-Score: ({:.2f} {:.2f})'.format(p_train[0],p_train[1],r_train[0],r_train[1],f1_train[0],f1_train[1]))
            plt.savefig(checkpoint_dir+model_name+'_'+str(fl)+'_cm_train.png')
            plt.clf()
            plt.cla()
            plt.close()

            # PLOT CONFUSION MATRIX AND STAT MEASURES ON VALIDATION
            conf_mat_val = confusion_matrix(y_true_val, y_pred_val)
            cm_display_val = ConfusionMatrixDisplay(confusion_matrix = conf_mat_val, display_labels = classes)
            p_val,r_val,f1_val,_= precision_recall_fscore_support(y_true_val, y_pred_val, average=None)
            cm_display_val.plot(cmap= 'Blues',colorbar=False, values_format = 'd')
            plt.title('Preicison: ({:.2f} {:.2f})   Recall: ({:.2f} {:.2f})   F1-Score: ({:.2f} {:.2f})'.format(p_val[0],p_val[1],r_val[0],r_val[1],f1_val[0],f1_val[1]))
            plt.savefig(checkpoint_dir+model_name+'_'+str(fl)+'_cm_Val.png')
            plt.clf()
            plt.cla()
            plt.close()

            # PLOT ROC ON VAL
            plot_roc_curve(y_true_val, y_pred_val)
            plt.title('AUC: {:.2f}'.format(roc_auc_score(y_true_val, y_pred_val)))
            plt.savefig(checkpoint_dir+model_name+'_'+str(fl)+'_roc_val.png')
            plt.clf()
            plt.cla()
            plt.close()

            print(' --------------------------------------------------------->  EPOCH {:.0f} Checkpoint Saved w/ Acc: {:.2f}% '.format(best_epoch,best_acc))
        
        ############ STOP TRAINING ############ 
        if epoch > num_epochs*0.75 and best_acc > 80 and average_acc > best_acc:
            print('Ending Training, Best Checkpoint Found')
            continue_train = False
            break
        
        ms_lr_scheduler.step()
        epoch += 1

<bound method Module.type of CrossEntropyLoss()>
Epochs:  3000
Loss Function: 
 CrossEntropyLoss()
QAT: 
 {'start_epoch': 100, 'weight_bits': 8}
Configuring device: MAX78000, simulate=False.
169472
Number of Model Params: 169472
---------------------------------------------


 Epoch: 1/3000 	LR: [0.0001] 	Dur: 3.58 sec

 Training - Confusion Matrix: 
[[6741 5696]
 [6785 5814]]

 Validation - Confusion Matrix: 
[[ 684  806]
 [ 637 1003]]

Train Acc :  50.10147733347756
Train Loss :  0.7042886
Val Acc :  53.897762298583984
Val Loss :  0.70359206
 --------------------------------------------------------->  EPOCH 0 Checkpoint Saved w/ Acc: 52.00% 
---------------------------------------------


 Epoch: 2/3000 	LR: [0.0001] 	Dur: 3.22 sec

 Training - Confusion Matrix: 
[[6031 6406]
 [5456 7143]]

 Validation - Confusion Matrix: 
[[ 650  840]
 [ 534 1106]]

Train Acc :  53.11499152864728
Train Loss :  0.69369984
Val Acc :  56.10223412513733
Val Loss :  0.6836463
 ---------------------------

KeyboardInterrupt: 