# Transfer Learning

## Load Libraries



In [1]:
import importlib
import sys
import os
import numpy as np
import time

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

import torch
from torch import nn
import torch.optim as optim

from torch.utils import data
from torchvision import transforms

sys.path.append(os.path.join(os.getcwd(), ".."))

from distiller import apputils
import ai8x

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

kws20 = importlib.import_module("datasets.kws20-horsecough")

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
def plot_confusion(y_true, y_pred, classes):
    """Print the confusion matrix of ``y_pred`` against ``y_true``.

    Despite its name this helper does not draw anything; it prints the raw
    matrix returned by ``sklearn.metrics.confusion_matrix``.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices, in the same order as ``y_true``.
        classes: Sequence of class names. Only its length is used: the matrix
            is built for the labels ``0 .. len(classes) - 1``.
    """
    label_ids = list(range(len(classes)))
    matrix = confusion_matrix(y_true=y_true, y_pred=y_pred, labels=label_ids)
    print(matrix)

## Create Combined Dataset (Class Merging)

In [2]:
# DISABLED CELL: one-off script that copied every class folder of the raw dataset into a
# single "combined" folder, prefixing each file name with the index of its source folder.
# NOTE(review): if this is ever re-enabled, `class_file_count` is created empty below and is
# never filled before it is read in the loop, so the print would raise KeyError. The paths
# are also absolute paths on the original author's machine.

# from pathlib import Path

# #raw_data_path = Path("C:/Users/J_C/Desktop/DATASETS/raw")
# raw_data_path = Path("C:/Users/J_C/Desktop/DATASETS/raw/")
# class_file_count = {}

# class_dirs = [d for d in raw_data_path.iterdir() if d.is_dir() and (d.stem != "_background_noise_")]

# # Create combined Dataset
# import shutil
# combined_path = "C:/Users/J_C/Documents/GitHub/ai8x-training/data/KWS_EQUINE/combined/"
# if os.path.isdir(combined_path) == False:
#     os.makedirs(combined_path)

# for i,d in enumerate(class_dirs):
#     print(class_file_count[d] ," files in folder ",d)
#     fnames =  os.listdir(d)
#     for fi,f in enumerate(d.iterdir()):
#         shutil.copyfile(f, combined_path+'/'+str(i)+'_'+fnames[fi])

## Calculate weights

In [3]:
from pathlib import Path

# Directory holding one sub-folder of audio files per class.
raw_data_path = Path("../data/KWS_EQUINE/raw/")

# Path.iterdir() yields entries in arbitrary, OS-dependent order. Sort by name so that the
# position of every weight in `class_weights` is the same on every machine and every run.
# NOTE(review): the weights must line up with the label indices used by the dataset class;
# alphabetical order matches classes = ["combined", "human_cough"] — confirm against the loader.
class_dirs = sorted(
    (d for d in raw_data_path.iterdir() if d.is_dir() and d.stem != "_background_noise_"),
    key=lambda d: d.name,
)
if not class_dirs:
    raise ValueError(f"No class folders found in {raw_data_path}")

# Number of files per class folder. These stay raw counts; the weights are kept separately.
class_file_count = {d: sum(1 for _ in d.iterdir()) for d in class_dirs}

empty_classes = [d.name for d, count in class_file_count.items() if count == 0]
if empty_classes:
    # An empty folder would otherwise surface as a ZeroDivisionError further down.
    raise ValueError(f"Class folders without any files: {empty_classes}")

min_file_count = float(min(class_file_count.values()))

# Calculate weights: the smallest class gets 1.0, larger classes get proportionally less.
class_weights = []
for d in class_dirs:
    weight = round(min_file_count / class_file_count[d], 7)
    print(f"{d.stem}: {weight}")
    class_weights.append(weight)

print('Weights: ',class_weights)

combined: 0.0171171
human_cough: 1.0
Weights:  [0.0171171, 1.0]


## Generate Processed Dataset

In [4]:
# Build the train / validation / test loaders for the two-class dataset.
# NOTE: on first use the dataset class generates a processed dataset from the raw samples;
# the cell output reports that this can take significant time (~60 minutes).
# NOTE(review): the class dictionaries of the dataset class in datasets/kws20-horsecough
# presumably have to match `classes` below — confirm before changing either.
train_batch_size = 256
# Positional arguments after the dataset factory: the data spec ("../data", True), the batch
# size and `1` (presumably the number of loader workers — confirm against
# apputils.get_data_loaders). 10% of the training data is split off for validation.
train_loader, val_loader, test_loader, _ = apputils.get_data_loaders(
    kws20.KWS_HORSE_TF_get_datasets, ("../data", True), train_batch_size, 1, validation_split=0.1)

print(f"Dataset sizes:\n\ttraining={len(train_loader.sampler)}\n\tvalidation={len(val_loader.sampler)}\n\ttest={len(test_loader.sampler)}")

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

# Class names; their count sizes the new classifier head and the confusion matrix.
classes = ["combined",'human_cough']

No key `noise_var` in input augmentation dictionary!  Using defaults: [Min: 0., Max: 1.]
No key `shift` in input augmentation dictionary! Using defaults: [Min:-0.1, Max: 0.1]
No key `strech` in input augmentation dictionary! Using defaults: [Min: 0.8, Max: 1.3]
Generating dataset from raw data samples for the first time. 
This process will take significant time (~60 minutes)...
data_len: 16384
------------- Label Size ---------------
combined:  	219664
human_cough:  	3760
------------------------------------------
Processing the label: combined. 1 of 2
	1 of 219664
	1001 of 219664
	2001 of 219664
	3001 of 219664
	4001 of 219664
	5001 of 219664
	6001 of 219664
	7001 of 219664
	8001 of 219664
	9001 of 219664
	10001 of 219664
	11001 of 219664
	12001 of 219664
	13001 of 219664
	14001 of 219664
	15001 of 219664
	16001 of 219664
	17001 of 219664
	18001 of 219664
	19001 of 219664
	20001 of 219664
	21001 of 219664
	22001 of 219664
	23001 of 219664
	24001 of 219664
	25001 of 219664
	26001 of 21

## Load Reference Model

In [7]:
def count_params(model):
    """Count, print and return the number of trainable parameters of ``model``.

    Only parameters with ``requires_grad=True`` are counted.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        int: Total number of elements over all trainable parameters
        (0 when there are none).
    """
    # numel() is exact for every shape, including 0-dim (scalar) parameters, and always
    # yields a plain int; np.prod over an empty size would return the float 1.0 instead.
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(params)
    return params

# Configure ai8x for device 85 without simulation (the cell output reports "MAX78000").
ai8x.set_device(device=85, simulate=False, round_avg=False)

# The module name contains hyphens, so it cannot be imported with a plain `import` statement.
mod = importlib.import_module("models.ai85net-kws20-v3")

# Build the network with 21 output classes — presumably the shape of the reference model
# whose checkpoint is loaded below (TODO confirm). The classifier head is swapped out later.
model = mod.AI85KWS20Netv3(num_classes=21, num_channels=128, dimensions=(128, 1), bias=False)
# count_params() prints the count itself as well, so the number appears twice in the output.
print(f'Number of Model Params: {count_params(model)}')

# WEIGHTS OF REFERENCE MODEL
# The optimizer returned here is replaced by a fresh Adam optimizer in the training-parameter
# cell; compression_scheduler and start_epoch are not used by the training loop.
model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
            model, "../logs/kws20_original/qat_best.pth.tar")

 # FREEZE SOME LAYERS
def freeze_layer(layer):
    """Exclude every parameter of ``layer`` from gradient computation.

    Args:
        layer: A ``torch.nn.Module``; all parameters it yields from
            ``parameters()`` get ``requires_grad`` set to ``False`` in place.
    """
    for param in layer.parameters():
        param.requires_grad_(False)
# Freeze the first six layers of the reference model; kws_conv3 and kws_conv4 are left
# trainable (add them to the tuple to freeze them as well).
for _frozen_layer in (
        model.voice_conv1,
        model.voice_conv2,
        model.voice_conv3,
        model.voice_conv4,
        model.kws_conv1,
        model.kws_conv2,
        # model.kws_conv3,
        # model.kws_conv4,
):
    freeze_layer(_frozen_layer)

# Swap the classifier for a new linear layer with one output per entry in `classes`.
# It is created after the freezing above, so it is not touched by freeze_layer().
model.fc = ai8x.Linear(256, len(classes), bias=False, wide=True)

model = model.to(device)

Configuring device: MAX78000, simulate=False.
169472
Number of Model Params: 169472


## Prepare Training Params

In [8]:
# Optimizer and schedule: Adam, with the learning rate multiplied by 0.2 after 25% and
# again after 50% of the epochs.
num_epochs = 1000
optimizer = optim.Adam(model.parameters(), lr=0.001)
ms_lr_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[num_epochs // 4, num_epochs // 2], gamma=0.2)

# Cross entropy weighted per class with the weights computed in the "Calculate weights" cell.
# torch.tensor(..., dtype=float32) replaces the legacy torch.FloatTensor constructor.
criterion = torch.nn.CrossEntropyLoss(
    weight=torch.tensor(class_weights, dtype=torch.float32))
criterion.to(device)

# Quantization-aware training starts after the first 20% of the epochs, with 8-bit weights.
qat_policy = {
    'start_epoch': num_epochs // 5,
    'weight_bits': 8,
}

# Output locations for checkpoints and for the periodically saved training curves.
model_name = 'human_kws20_1'
checkpoint_dir = './checkpoints/'+model_name+'/'
plot_dir ='./checkpoints/plots/'+model_name+'/'

# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(plot_dir, exist_ok=True)


## Train the Model

In [11]:
# Training run: float training first, then quantization-aware training (QAT) from
# qat_policy['start_epoch'] onwards. After every epoch the model is evaluated on the
# validation split and the best-scoring weights are written to checkpoint_dir.
best_acc = 0
best_qat_acc = 0  # not updated anywhere in this cell; kept so the name stays defined

# Per-epoch history, used for the curves saved every 100 epochs.
train_acc = []
train_loss = []
val_acc = []
val_loss = []

for epoch in range(0, num_epochs):
    if epoch > 0 and epoch == qat_policy['start_epoch']:
        print('QAT is starting!')
        # Fuse the BN parameters into conv layers before Quantization Aware Training (QAT)
        ai8x.fuse_bn_layers(model)
        # Switch model from unquantized to quantized for QAT
        ai8x.initiate_qat(model, qat_policy)
        # Model is re-transferred to GPU in case parameters were added
        model.to(device)

    ############ TRAIN SECTION ############
    running_loss = []
    correct = 0   # number of correctly classified training samples this epoch
    seen = 0      # number of training samples processed this epoch

    train_start = time.time()
    model.train()
    for inputs, target in train_loader:
        inputs = inputs.to(device)
        target = target.to(device)
        optimizer.zero_grad()

        model_out = model(inputs)
        loss = criterion(model_out, target)
        loss.backward()
        optimizer.step()

        target_out = torch.argmax(model_out, dim=1)
        correct += torch.sum(target_out == target).item()
        seen += target_out.shape[0]
        running_loss.append(loss.item())

    total_acc = 100 * (correct / seen)
    mean_loss = np.mean(running_loss)

    # TRAIN ACCURACY / TRAIN LOSS
    train_acc.append(total_acc)
    train_loss.append(mean_loss)
    # One TensorBoard point per epoch; logging every batch under the same step would stack
    # many values on a single x position.
    writer.add_scalar("Loss/train", mean_loss, epoch)

    train_end = time.time()
    # get_last_lr() reports the rate actually used in this epoch; get_lr() is only meant to
    # be called from inside scheduler.step().
    print("Epoch: {}/{}\t LR: {}\t \t Dur: {:.2f} sec.".format(
        epoch+1, num_epochs, ms_lr_scheduler.get_last_lr(), (train_end-train_start)))

    ############ VALIDATION SECTION ############
    model.eval()
    correct = 0
    seen = 0
    y_true = []
    y_pred = []
    running_v_loss = []

    with torch.no_grad():
        # Use val_loader (not test_loader) so that checkpoint selection below never sees
        # the test split, which stays untouched for the final report.
        for inputs, target in val_loader:
            inputs = inputs.to(device)
            target = target.to(device)
            model_out = model(inputs)
            target_out = torch.argmax(model_out, dim=1)

            y_pred.extend(target_out.cpu().numpy())
            y_true.extend(target.cpu().numpy())

            correct += torch.sum(target_out == target).item()
            seen += target_out.shape[0]

            # Record the loss of THIS validation batch (not the last training loss).
            v_loss = criterion(model_out, target)
            running_v_loss.append(v_loss.item())

    mean_loss = np.mean(running_v_loss)
    total_acc = 100 * (correct / seen)

    # Scores before and after the QAT switch are not comparable: restart best-model tracking.
    if epoch == qat_policy['start_epoch']:
        best_acc = 0

    # VALIDATION ACCURACY / VALIDATION LOSS
    val_acc.append(total_acc)
    val_loss.append(mean_loss)
    writer.add_scalar("Loss/val", mean_loss, epoch)

    if total_acc > best_acc:
        best_acc = total_acc
        checkpoint_extras = {'current_top1': best_acc,
                             'best_top1': best_acc,
                             'best_epoch': epoch}

        # Checkpoints written during QAT get a 'qat_' prefix.
        model_prefix = f'{model_name}' if epoch < qat_policy['start_epoch'] else (f'qat_{model_name}')
        apputils.save_checkpoint(epoch, model_name, model, optimizer=optimizer,
                                 scheduler=None, extras=checkpoint_extras,
                                 is_best=True, name=model_prefix,
                                 dir=checkpoint_dir)
        print(f'Best model saved with accuracy: {best_acc:.2f}%')

    print("\tConfusion:")
    plot_confusion(y_true, y_pred, classes)

    ms_lr_scheduler.step()

    print('----------------------')
    print('Train Acc : ', train_acc[-1])
    print('Train Loss : ', train_loss[-1])
    print('Val Acc : ', val_acc[-1])
    print('Val Loss : ', val_loss[-1])
    print('----------------------')

    # Every 100 epochs, save the accuracy and loss curves collected so far.
    if epoch%100 == 0:
        fig = plt.figure(epoch,figsize=(20,10))
        plt.subplot(1,2,1)
        plt.plot(train_acc, color = 'red')
        plt.plot(val_acc, color ='blue')
        plt.legend(['Train','Validation'])
        plt.title('Accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Value')

        plt.subplot(1,2,2)
        plt.plot(train_loss,color='red')
        plt.plot(val_loss,color='blue')
        plt.legend(['Train','Validation'])
        plt.title('Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Value')

        plt.savefig(plot_dir+str(epoch)+'.png')
        # Release the figure: it is already on disk, and open figures would otherwise
        # accumulate for the whole run.
        plt.close(fig)

# Make sure buffered TensorBoard events reach the disk once training has finished.
writer.flush()

KeyboardInterrupt: 

In [None]:
def plot_training_curves(train_acc, val_acc, train_loss, val_loss, out_path):
    """Save a two-panel figure with the accuracy and loss history of a run.

    Replaces the plotting snippet that used to be disabled inside a string
    literal (which was echoed as cell output). Defining the function has no
    side effects; call it whenever a snapshot of the curves is wanted, e.g.
    after interrupting training.

    Args:
        train_acc: Per-epoch training accuracies.
        val_acc: Per-epoch validation accuracies.
        train_loss: Per-epoch training losses.
        val_loss: Per-epoch validation losses.
        out_path: File the figure is written to (passed to ``savefig``).
    """
    fig, axes = plt.subplots(1, 2, figsize=(20, 10))
    panels = (
        ('Accuracy', train_acc, val_acc),
        ('Loss', train_loss, val_loss),
    )
    for ax, (title, train_values, val_values) in zip(axes, panels):
        ax.plot(train_values, color='red')
        ax.plot(val_values, color='blue')
        ax.legend(['Train', 'Validation'])
        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Value')
    fig.savefig(out_path)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Example:
# plot_training_curves(train_acc, val_acc, train_loss, val_loss, model_name+'_check.png')

## INFERENCE ON HORSE COUGH AUDIO

In [None]:
def evaluate_model(model, loader, criterion, device):
    """Run a single inference pass over ``loader`` and collect the results.

    Replaces the evaluation snippet that used to be disabled inside a string
    literal. Defining the function has no side effects; nothing is trained,
    quantized or saved by it.

    Args:
        model: ``torch.nn.Module`` to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(inputs, target)`` batches, where
            ``target`` holds class indices.
        criterion: Loss callable taking ``(model_output, target)``.
        device: ``torch.device`` the batches are moved to.

    Returns:
        dict with the keys
            ``'accuracy'``: percentage of correctly classified samples,
            ``'loss'``: mean of the per-batch losses,
            ``'y_true'``: list of target class indices,
            ``'y_pred'``: list of predicted class indices.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    y_true = []
    y_pred = []
    batch_losses = []

    with torch.no_grad():
        for inputs, target in loader:
            inputs = inputs.to(device)
            target = target.to(device)
            model_out = model(inputs)

            batch_losses.append(criterion(model_out, target).item())
            y_pred.extend(torch.argmax(model_out, dim=1).cpu().tolist())
            y_true.extend(target.cpu().tolist())

    if not y_true:
        raise ValueError("loader produced no samples to evaluate")

    correct = sum(1 for predicted, expected in zip(y_pred, y_true) if predicted == expected)
    return {
        'accuracy': 100 * correct / len(y_true),
        'loss': sum(batch_losses) / len(batch_losses),
        'y_true': y_true,
        'y_pred': y_pred,
    }

# Example (held-out test split of this notebook):
# results = evaluate_model(model, test_loader, criterion, device)
# print('Test Acc: {:.2f}'.format(results['accuracy']))
# plot_confusion(results['y_true'], results['y_pred'], classes)