# Transfer Learning

## Load Libraries



In [1]:
import importlib
import sys
import os
import numpy as np
import time
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

import torch
from torch import nn
import torch.optim as optim

from torch.utils import data
from torchvision import transforms

sys.path.append(os.path.join(os.getcwd(), ".."))

from distiller import apputils
import ai8x

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

kws20 = importlib.import_module("datasets.kws20-horsecough")

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, precision_recall_fscore_support, roc_auc_score, roc_curve

def plot_confusion(y_true, y_pred, classes):
    """Print the confusion matrix of ``y_pred`` against ``y_true``.

    Despite the name, nothing is drawn: the matrix is written to stdout.
    One row/column is produced per entry of ``classes`` (label ids
    ``0 .. len(classes) - 1``), so classes absent from the data still appear.
    """
    label_ids = [idx for idx, _ in enumerate(classes)]
    print(confusion_matrix(y_true=y_true, y_pred=y_pred, labels=label_ids))

def plot_roc_curve(true_y, y_prob):
    """Draw a ROC curve for ``y_prob`` against ``true_y`` on the current axes.

    A dashed red diagonal is added as the chance-level reference. The figure
    is neither shown nor saved here; that is left to the caller.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(true_y, y_prob)
    plt.plot(false_pos_rate, true_pos_rate)

    chance_line = [0, 1]
    plt.plot(chance_line, chance_line, 'r--')

    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')

## Prepare Checkpoint, Plot, and Dataset Folders

In [2]:
# MODEL AND DATASET NAME
# The base name is defined exactly once; the de-duplication loop below derives
# every suffixed name from it, so renaming the model only needs one edit.
base_model_name = 'human_others'
classes = ["combined", 'human_cough']
checkpoint_root = './checkpoints/'

model_name = base_model_name
checkpoint_dir = checkpoint_root + model_name + '/'

# Never reuse an existing run directory: append _0, _1, ... until a free
# directory name is found.
indexer = 0
while os.path.exists(checkpoint_dir):
    model_name = base_model_name + '_' + str(indexer)
    checkpoint_dir = checkpoint_root + model_name + '/'
    indexer += 1

# exist_ok avoids a separate exists() check before creating the directory.
os.makedirs(checkpoint_dir, exist_ok=True)

print('Model Name: ', model_name)
print('Classes: ', classes)
print('Checkpoint Dir: ', checkpoint_dir)


Model Name:  human_others_4
Classes:  ['combined', 'human_cough']
Checkpoint Dir:  ./checkpoints/human_others_4/


## Calculate Class Weights

In [3]:
raw_data_path = Path("../data/KWS_EQUINE/raw/")

# Folders under raw/ that are skipped when counting class files.
excluded_dirs = {"__combinedkws", "__combined_others", "__human_cough"}

# iterdir() yields entries in arbitrary order; sort so the weight order is
# deterministic across platforms and runs.
# NOTE(review): the weights must line up with the label indices produced by
# the dataset (presumably the order of `classes`) - TODO confirm.
class_dirs = sorted(
    d for d in raw_data_path.iterdir()
    if d.is_dir() and d.stem not in excluded_dirs
)

# Number of files per class folder.
class_file_count = {}
for d in class_dirs:
    print(d)
    class_file_count[d] = len(list(d.iterdir()))

empty_dirs = [d.stem for d, count in class_file_count.items() if count == 0]
if empty_dirs:
    # An empty folder would otherwise surface as a bare ZeroDivisionError.
    raise ValueError(f"Class folders without any files: {empty_dirs}")

min_file_count = float(min(class_file_count.values()))

# Calculate weights: the smallest class gets 1.0, larger classes are scaled
# down by (smallest count / own count).
balanced_class_weights = []
for d in class_dirs:
    weight = round(min_file_count / class_file_count[d], 7)
    print(f"{d.stem}: {weight}")
    balanced_class_weights.append(weight)

# Manual override of the balanced weights computed above. Set to None to train
# with the balanced weights instead of these hand-picked values.
manual_class_weights = [0.5, 1]
class_weights = balanced_class_weights if manual_class_weights is None else manual_class_weights

assert len(class_weights) == len(classes), "expected one weight per class"
print('Weights: ',class_weights)

..\data\KWS_EQUINE\raw\combined
..\data\KWS_EQUINE\raw\human_cough
combined: 0.7269915
human_cough: 1.0
Weights:  [0.7269915, 1.0]


## Generate Processed Dataset

In [None]:
train_batch_size = 4096

# Select the compute device; `cpu` is forwarded to the data-loader factory.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    cpu = False
    device_name = torch.cuda.get_device_name(device)
else:
    device = torch.device('cpu')
    cpu = True
    # torch.cuda.get_device_name() cannot be called without CUDA, so report
    # the CPU explicitly instead of crashing.
    device_name = 'cpu'

print('Running on device: {}'.format(device_name))

processed_data_path = Path("../data/KWS_EQUINE/processed/")
# A missing folder is treated the same as an empty one.
if processed_data_path.is_dir() and any(processed_data_path.iterdir()):
    print('Dataset Exists')

# The loaders are created unconditionally: the training cell needs
# train_loader / test_loader even when the processed dataset already exists.
# NOTE(review): this assumes the dataset factory reuses existing processed
# files rather than regenerating them - TODO confirm.
train_loader, val_loader, test_loader, _ = apputils.get_data_loaders(kws20.KWS_HORSE_TF_get_datasets, ("../data", True), train_batch_size, 4, validation_split=0.1,cpu=cpu)
print(f"Dataset sizes:\n\ttraining={len(train_loader.sampler)}\n\tvalidation={len(val_loader.sampler)}\n\ttest={len(test_loader.sampler)}")




## Load Reference Model

In [None]:
def count_params(model):
    """Print and return the number of trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted, so frozen layers
    do not contribute to the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += np.prod(param.size())
    print(total)
    return total

# Configure ai8x for device 85 with simulation and average rounding disabled.
ai8x.set_device(device=85, simulate=False, round_avg=False)

mod = importlib.import_module("models.ai85net-kws20-v3")

# The network is built with 21 output classes here; its `fc` head is swapped
# for a len(classes)-way layer further down in this cell.
model = mod.AI85KWS20Netv3(
    num_classes=21,
    num_channels=128,
    dimensions=(128, 1),
    bias=False,
)
param_count = count_params(model)
print(f'Number of Model Params: {param_count}')

# WEIGHTS OF REFERENCE MODEL
reference_checkpoint = "../logs/kws20_original/qat_best.pth.tar"
model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
    model, reference_checkpoint)

 # FREEZE SOME LAYERS
def freeze_layer(layer):
    """Turn off gradient computation for every parameter of ``layer``."""
    for param in layer.parameters():
        param.requires_grad_(False)

# voice_conv1, voice_conv2 and kws_conv4 are not in this list (their freeze
# calls were commented out), so they stay trainable.
for frozen_layer in (
    model.voice_conv3,
    model.voice_conv4,
    model.kws_conv1,
    model.kws_conv2,
    model.kws_conv3,
):
    freeze_layer(frozen_layer)

# Replace the classifier head with a fresh layer sized for `classes`.
model.fc = ai8x.Linear(256, len(classes), bias=False, wide=True)

model = model.to(device)

## Prepare Training Params

In [None]:
num_epochs = 20000

# Adam over all model parameters; the learning rate is multiplied by 0.2 at
# epochs 500 and 1000.
optimizer = optim.Adam(model.parameters(), lr=0.001)
ms_lr_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[500, 1000],
    gamma=0.2,
)

# Class-weighted cross entropy, moved to the training device.
class_weight_tensor = torch.FloatTensor(class_weights)
criterion = torch.nn.CrossEntropyLoss(weight=class_weight_tensor)
criterion.to(device)

# Quantization-aware training policy read by the training loop.
qat_policy = dict(start_epoch=400, weight_bits=8)

## Train the Model

In [None]:
def _class_accuracy(y_true, y_pred, class_num):
    """Return the percentage of ``class_num`` samples predicted correctly.

    ``y_true`` / ``y_pred`` are 1-D numpy arrays of label ids. Returns 0.0
    when the class does not occur in ``y_true``.
    """
    mask = y_true == class_num
    if not mask.any():
        return 0.0
    return float(np.mean(y_pred[mask] == class_num)) * 100


def _save_confusion_plot(y_true, y_pred, out_path):
    """Save a confusion-matrix figure titled with per-class precision/recall/F1.

    ``labels`` is passed explicitly so the matrix and the metric arrays always
    have one entry per class, even if a class is missing from the data. The
    title format assumes exactly two classes.
    """
    label_ids = list(range(len(classes)))
    conf_mat = confusion_matrix(y_true, y_pred, labels=label_ids)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=label_ids, average=None)
    cm_display = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=classes)
    cm_display.plot(cmap='Blues', colorbar=False, values_format='d')
    plt.title('Precision: ({:.2f} {:.2f})   Recall: ({:.2f} {:.2f})   F1-Score: ({:.2f} {:.2f})'.format(
        precision[0], precision[1], recall[0], recall[1], f1[0], f1[1]))
    plt.savefig(out_path)
    plt.close()


def _save_roc_plot(y_true, y_pred, out_path):
    """Save a ROC figure titled with the AUC.

    NOTE(review): the callers pass hard argmax predictions, not class
    probabilities, so the curve has a single operating point.
    """
    plot_roc_curve(y_true, y_pred)
    plt.title('AUC: {:.2f}'.format(roc_auc_score(y_true, y_pred)))
    plt.savefig(out_path)
    plt.close()


best_acc = 0
best_qat_acc = 0
best_loss = 0
best_epoch = 0

train_acc = []
train_loss = []
val_acc = []
val_loss =[]
running_class_acc = []

# NOTE(review): the "validation" metrics and the checkpoint selection below are
# computed on test_loader; val_loader from the dataset cell is never used.
# Switch this to val_loader if the test set should stay held out.
eval_loader = test_loader

for epoch in range(0, num_epochs):
    train_start = time.time()
    if epoch > 0 and epoch == qat_policy['start_epoch']:
        print('QAT is starting!')
        # Fuse the BN parameters into conv layers before Quantization Aware Training (QAT)
        ai8x.fuse_bn_layers(model)
        # Switch model from unquantized to quantized for QAT
        ai8x.initiate_qat(model, qat_policy)
        # Model is re-transferred to GPU in case parameters were added
        model.to(device)

    # Mean of train and validation accuracy; used to pick the best checkpoint.
    average_acc = 0
    ############ TRAIN SECTION ############
    running_loss = []
    y_pred_train = []
    y_true_train = []
    acc_b = []
    model.train()
    for inputs, target in train_loader:
        inputs = inputs.to(device)
        target = target.to(device)
        optimizer.zero_grad()

        model_out = model(inputs)
        target_out = torch.argmax(model_out, dim=1)

        y_pred_train.extend(target_out.cpu().numpy())
        y_true_train.extend(target.cpu().numpy())

        tp = torch.sum(target_out == target)
        acc_b.append((tp / target_out.numel()).item())

        loss = criterion(model_out, target)
        loss.backward()
        optimizer.step()

        running_loss.append(loss.item())

    total_acc = np.mean(acc_b)*100
    mean_loss = np.mean(running_loss)

    # TRAIN ACCURACY / TRAIN LOSS
    train_acc.append(total_acc)
    train_loss.append(mean_loss)
    average_acc += total_acc/2

    ############ VALIDATION SECTION ############
    acc_b = []
    y_pred_val = []
    y_true_val = []
    running_v_loss = []

    model.eval()
    with torch.no_grad():
        for inputs, target in eval_loader:
            inputs = inputs.to(device)
            target = target.to(device)
            model_out = model(inputs)
            target_out = torch.argmax(model_out, dim=1)

            y_pred_val.extend(target_out.cpu().numpy())
            y_true_val.extend(target.cpu().numpy())

            tp = torch.sum(target_out == target)
            acc_b.append((tp / target_out.numel()).item())

            # Record the loss of THIS batch (not the last training loss).
            v_loss = criterion(model_out, target)
            running_v_loss.append(v_loss.item())

    total_acc = np.mean(acc_b)*100
    mean_loss = np.mean(running_v_loss)

    # Restart best-accuracy tracking at the epoch where QAT begins, so the
    # first quantized checkpoint is always saved.
    if epoch == qat_policy['start_epoch']: best_acc = 0

    # VALIDATION ACCURACY / VALIDATION LOSS
    val_acc.append(total_acc)
    val_loss.append(mean_loss)
    average_acc += total_acc/2

    train_end = time.time()
    # Read the LR that was actually used this epoch before stepping the scheduler.
    epoch_lr = ms_lr_scheduler.get_last_lr()
    ms_lr_scheduler.step()

    ############ CLASS ACCURACY ############
    # Per class: mean of the train and validation per-class accuracy (percent).
    true_train_arr = np.asarray(y_true_train)
    pred_train_arr = np.asarray(y_pred_train)
    true_val_arr = np.asarray(y_true_val)
    pred_val_arr = np.asarray(y_pred_val)

    class_acc = np.zeros(len(classes))
    for class_num in range(len(classes)):
        class_acc[class_num] = 0.5 * (
            _class_accuracy(true_train_arr, pred_train_arr, class_num)
            + _class_accuracy(true_val_arr, pred_val_arr, class_num)
        )

    running_class_acc.append(class_acc)
    print('---------------------------------------------')
    print("\n\n Epoch: {}/{} \tLR: {} \tDur: {:.2f} sec".format(epoch+1, num_epochs, epoch_lr , (train_end-train_start)))

    ############ SAVE CHECKPOINT ############
    if average_acc > best_acc:
        best_acc = average_acc
        # Tracked directly so the plot section never depends on a checkpoint
        # having been written first.
        best_epoch = epoch
        checkpoint_extras = {'best_ave_acc': best_acc,
                                'best_epoch': epoch}

        model_prefix = f'{model_name}' if epoch < qat_policy['start_epoch'] else (f'qat_{model_name}')
        apputils.save_checkpoint(epoch, model_name+'_'+str(epoch), model, optimizer=optimizer,
                                    scheduler=None, extras=checkpoint_extras,
                                    is_best=True, name=model_prefix,
                                    dir=checkpoint_dir)

        # PLOT CONFUSION MATRIX, STAT MEASURES AND ROC ON TRAIN
        _save_confusion_plot(y_true_train, y_pred_train, checkpoint_dir+model_name+'_cm_train.png')
        _save_roc_plot(y_true_train, y_pred_train, checkpoint_dir+model_name+'_roc_train.png')

        # PLOT CONFUSION MATRIX, STAT MEASURES AND ROC ON VALIDATION
        _save_confusion_plot(y_true_val, y_pred_val, checkpoint_dir+model_name+'_cm_Val.png')
        _save_roc_plot(y_true_val, y_pred_val, checkpoint_dir+model_name+'_roc_val.png')

        print(f' --------------------------------------------------------->  Model Checkpoints Saved with Mean Accuracy : {best_acc:.2f}%')

    ############ CONFUSION MATRIX ############
    print("\n Training - Confusion Matrix: ")
    plot_confusion(y_true_train, y_pred_train, classes)
    print("\n Validation - Confusion Matrix: ")
    plot_confusion(y_true_val, y_pred_val, classes)

    ############ ACC and LOSS ############
    print('\nTrain Acc : ', train_acc[-1])
    print('Train Loss : ', train_loss[-1])
    print('Val Acc : ', val_acc[-1])
    print('Val Loss : ', val_loss[-1])

    ############ PLOTS ############
    # Refresh the training-curve figure on every even epoch and on the last one.
    if (epoch%2 == 0 or epoch==num_epochs-1) and epoch > 0:
        class_acc_history = np.asarray(running_class_acc)

        plt.figure(figsize=(20,10))
        # suptitle labels the whole figure without creating an extra axes.
        plt.suptitle(model_name)
        plt.subplot(1,2,1)
        plt.plot(val_acc, color ='green')
        plt.plot(train_acc, color = 'red')
        plt.plot(class_acc_history[:,0],color='orange')
        plt.plot(class_acc_history[:,1],color='yellow')
        plt.stem([best_epoch],[train_acc[best_epoch]])
        plt.legend(['Validation','Train',classes[0],classes[1],'Checkpoint'])
        plt.title('Accuracy: {:.2f}'.format(train_acc[best_epoch]))
        plt.xlabel('Epochs')
        plt.ylabel('Value')

        plt.subplot(1,2,2)
        plt.plot(val_loss,color='green')
        plt.plot(train_loss,color='red')
        plt.stem([best_epoch],[train_loss[best_epoch]])
        plt.legend(['Validation','Train','Checkpoint'])
        plt.title('Loss: {:.2f}'.format(train_loss[best_epoch]))
        plt.xlabel('Epochs')
        plt.ylabel('Value')

        plt.savefig(checkpoint_dir+model_name+'.png')
        plt.close()

    
    
