In [None]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = ""

from keras.models import load_model
from keras.layers import Lambda
from keras import losses
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import scatter
from matplotlib.lines import Line2D
import tensorflow as tf
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
from mpl_toolkits import mplot3d
from tqdm import tqdm


In [None]:
# Tab-separated MIMI mosquito table. The default is a machine-specific absolute
# path; set the MOSQUITO_DATA_PATH environment variable to run elsewhere.
DATA_PATH = os.environ.get(
    "MOSQUITO_DATA_PATH",
    "/home/josh/Documents/Mosquito_Project/New_Data/Data/MIMIdata/mosquitos.dat",
)

# `sep` must be a keyword: since pandas 2.0 every read_csv argument after the
# path is keyword-only, so a positional '\t' raises TypeError.
df = pd.read_csv(DATA_PATH, sep='\t')
# Exclude rows whose rearing condition is 'VF'. `df` is read as a global by
# generate_sensitivity_Z_score.
df = df[df['RearCnd'] != 'VF']

In [None]:
## Functions for Z-score and sensitivity for input-output

def generate_sensitivity_Z_score(model, layer_name, age, size=1650):
    """Return the unit-norm mean |d output / d input| sensitivity for one class.

    Selects the rows of the global ``df`` belonging to the requested class.
    For each row it evaluates the gradient of one output unit of
    ``layer_name`` with respect to the model input. It then averages the
    absolute gradients over rows and scales the result to unit L2 norm.

    Parameters
    ----------
    model : keras.Model
        Model with an output layer named ``layer_name``.
    layer_name : {'age', 'species'}
        Which output layer to differentiate.
    age : int or str
        Class to explain.
        - ``layer_name='age'``: an integer age. Rows with
          ``df['Age'] == age`` are used, with output unit ``age - 1``.
        - ``layer_name='species'``: one of ``'AA'``, ``'AC'``, ``'AG'``.
          Rows with ``df['Species'] == age`` are used, with output units
          0, 1, 2 respectively.
    size : int, optional
        Unused; kept only for backward compatibility.

    Returns
    -------
    numpy.ndarray
        Unit-norm vector of mean absolute input gradients (squeezed).

    Raises
    ------
    ValueError
        If ``layer_name`` or the species label is not recognised, or if no
        rows of ``df`` match the requested class.
    """
    layer_output = model.get_layer(layer_name).output

    if layer_name == 'age':
        df_1 = df[df['Age'] == age]
        loss = layer_output[:, age - 1]
    elif layer_name == 'species':
        # Output-unit index for each species label in the 'species' layer.
        species_to_unit = {'AA': 0, 'AC': 1, 'AG': 2}
        if age not in species_to_unit:
            raise ValueError("Unknown species {!r}; expected one of {}".format(
                age, sorted(species_to_unit)))
        df_1 = df[df['Species'] == age]
        loss = layer_output[:, species_to_unit[age]]
    else:
        raise ValueError(
            "layer_name must be 'age' or 'species', got {!r}".format(layer_name))

    # Model inputs are taken from column 5 onward (presumably the leading
    # columns are metadata such as Species/Age/RearCnd -- TODO confirm).
    X = np.asarray(df_1.iloc[:, 5:])
    if len(X) == 0:
        raise ValueError("No rows in df for {}={!r}".format(layer_name, age))

    # Symbolic gradient of the selected output unit w.r.t. the model input;
    # only the gradient is fetched (the output value itself was never used).
    grads = K.gradients(loss, model.input)[0]
    iterate = K.function([model.input], [grads])

    gradients = []
    for sample in X:
        # Shape (1, n_features, 1): a batch of one with a trailing channel axis.
        input_img_data = sample[np.newaxis, :, np.newaxis]
        (grads_value,) = iterate([input_img_data])
        gradients.append(np.squeeze(np.abs(grads_value)))

    sensitivity = np.mean(gradients, axis=0)
    return sensitivity / np.linalg.norm(sensitivity)


def _sensitivities_across_models(layer_name, target, n_models=10,
                                 model_dir="../opt_tailored_output/CNN_8_8s_3_6s_5p"):
    """Compute one sensitivity vector per trained model for a single class.

    Loads ``<model_dir>/Orig_0_<k>_Model.h5`` for k = 1..n_models. Each model
    is passed to ``generate_sensitivity_Z_score(model, layer_name, target)``.
    The results are returned in model order.
    """
    sensitivities = []
    for count in tqdm(range(1, n_models + 1)):
        model = load_model(os.path.join(model_dir, "Orig_0_" + str(count) + "_Model.h5"))
        try:
            for layer in model.layers:
                layer.trainable = False
            sensitivities.append(generate_sensitivity_Z_score(model, layer_name, target))
        finally:
            del model
            # `del` alone does not shrink the backend graph: every loaded model
            # and every K.gradients call adds ops to the global session graph.
            # Reset it so memory stays flat across the many models loaded here.
            K.clear_session()
    return sensitivities


def sensitivites_for_age(age):
    """Return the list of per-model sensitivity vectors for output age ``age``."""
    return _sensitivities_across_models('age', age)


def sensitivites_for_species(species):
    """Return the list of per-model sensitivity vectors for ``species`` ('AA'/'AC'/'AG')."""
    return _sensitivities_across_models('species', species)


In [None]:
## Generate Z-score and sensitivity outputs for input-output
## Specifically for Age

sensitivities_save = []
# Wavenumber axis (cm^-1) for the plotted sensitivity: 3800, 3798, ..., 502
# (1650 points).
wavenumbers = np.arange(3800, 500, -2)
poly_index = 3  # degree of the polynomial fitted to the (signal, Z-score) cloud

for age in tqdm(range(1, 18)):
    # One unit-norm sensitivity vector per trained model for this age.
    sensitivities = sensitivites_for_age(age)
    sensitivities_save.append(sensitivities)

    # Pool (signal, Z-score) pairs over every ordered pair of models (i, j),
    # including i == j:
    # 1. Combine the pair as (s_i + s_j) / sqrt(2).
    # 2. Z-score each value within consecutive 50-point windows.
    # A trailing partial window is ignored.
    Z_scores = []
    m_signals = []
    n_models = len(sensitivities)
    for sens1 in range(n_models):
        for sens2 in range(n_models):
            s_signal = (sensitivities[sens1] + sensitivities[sens2]) / np.sqrt(2)
            n_windows = len(s_signal) // 50
            windows = s_signal[:50 * n_windows].reshape(n_windows, 50)
            z_windows = ((windows - windows.mean(axis=1, keepdims=True))
                         / windows.std(axis=1, keepdims=True))
            Z_scores.extend(z_windows.ravel())
            m_signals.extend(windows.ravel())

    # Fit Z-score as a polynomial in signal value once; evaluate on the sorted
    # unique signal values.
    x_grid = np.unique(m_signals)
    z_fit = np.poly1d(np.polyfit(m_signals, Z_scores, poly_index))(x_grid)

    # The threshold is the signal value where the fitted Z-score reaches ~1.645
    # (one-sided 95%). Take the middle of the grid points whose fit lies in
    # (1.644, 1.6458). The count must come from the index array; np.where
    # returns a 1-tuple, whose length is always 1.
    # If no point falls in that narrow band, fall back to the closest one.
    candidates = np.flatnonzero((z_fit < 1.6458) & (z_fit > 1.644))
    if candidates.size > 0:
        index_95 = candidates[candidates.size // 2]
    else:
        index_95 = int(np.argmin(np.abs(z_fit - 1.645)))
    print(index_95)
    y_value = z_fit[index_95]
    x_value = x_grid[index_95]

    # Z-score scatter, fitted curve and the 95% threshold guides.
    fig_z = plt.figure()
    plt.scatter(m_signals, Z_scores)
    plt.plot(x_grid, z_fit, color='k', linewidth=3)
    plt.plot([0, x_value], [y_value, y_value], 'k--')
    plt.plot([x_value, x_value], [-4, y_value], 'k--')
    plt.xlim([0, 0.2])
    plt.ylim([-4, 6])
    plt.xlabel('Signal value')
    plt.ylabel('Z-score')
    plt.title(('Z-score Calculation - Age '+str(age)))
    fig_z.tight_layout()
    # NOTE(review): age figures are written into the .../Species/ folder.
    fig_z.savefig(('Sensitivity/Sensitivity_Maps_Confidence_Level/Species/Z_Score_age_'+str(age)+'.png'))

    # Sensitivity map of the first model with band annotations.
    # NOTE(review): only sensitivities[0] is plotted, not the model average.
    fig = plt.figure(figsize=(8,4))
    ax = fig.add_subplot(1,1,1)
    ax.plot(wavenumbers, np.squeeze(sensitivities[0]), 'b')
    ax.plot([3800, 500], [x_value, x_value], 'k--')
    ax.set_xlim(3800, 500)
    ax.set_ylim(0, 0.16)
    ax.set_xlabel('Wavenumber $cm^{-1}$')
    ax.set_ylabel('Sensitivity')
    ax.set_title(('Sensitivity map - Age '+str(age)))

    for mol in [3500, 3100, 2950, 2800, 1820, 1670, 1519, 1377, 1150, 1020, 1154, 1000, 675]:
        ax.plot([mol, mol], [0, 0.16], 'k', linewidth=1)
    for band_high, band_low in [(3500, 3100), (2950, 2800), (1820, 1670), (1150, 1020), (1000, 675)]:
        ax.fill_between([band_high, band_low], [0], [0.16], color='k', alpha=0.2)

    # Secondary top axis labelling the molecular bands.
    ax2 = ax.twiny()
    ax2.set_xlim(ax.get_xlim())
    ax2.set_xticks([3300, 2925, 1745, 1519, 1377, 1065, 1154, 837.5])
    ax2.set_xticklabels(['N-H', 'C-H2', 'C=O', 'O=C-N', 'C-CH3', 'C-N', 'C-O-C', 'C-H2'])
    plt.setp(ax2.get_xticklabels(), rotation=90)

    custom_points = [Line2D([3800], [0], color='b'),
                     Line2D([3800], [0], color='k', lw=1),
                     Line2D([3800], [0], color='k', linestyle='--')]
    ax.legend(custom_points, ['Sensitivity', 'Molecular Bands', '95% confidence interval'], loc='center left', bbox_to_anchor=(1,0.5))

    fig.tight_layout()
    fig.savefig(('Sensitivity/Sensitivity_Maps_Confidence_Level/Species/Sensitivity_Map_Age_'+str(age)+'.png'))

    # Render this iteration's figures, then release them. Otherwise 34 figures
    # stay open until the cell ends.
    plt.show()
    plt.close(fig_z)
    plt.close(fig)
    

In [None]:
## Generate Z-score and sensitivity outputs for input-output
## Specifically for Species

sensitivities_save = []
# Wavenumber axis (cm^-1) for the plotted sensitivity: 3800, 3798, ..., 502
# (1650 points).
wavenumbers = np.arange(3800, 500, -2)
poly_index = 3  # degree of the polynomial fitted to the (signal, Z-score) cloud

for species in tqdm(['AA', 'AC', 'AG']):
    # One unit-norm sensitivity vector per trained model for this species.
    sensitivities = sensitivites_for_species(species)
    sensitivities_save.append(sensitivities)

    # Pool (signal, Z-score) pairs over every ordered pair of models (i, j),
    # including i == j:
    # 1. Combine the pair as (s_i + s_j) / sqrt(2).
    # 2. Z-score each value within consecutive 50-point windows.
    # A trailing partial window is ignored.
    Z_scores = []
    m_signals = []
    n_models = len(sensitivities)
    for sens1 in range(n_models):
        for sens2 in range(n_models):
            s_signal = (sensitivities[sens1] + sensitivities[sens2]) / np.sqrt(2)
            n_windows = len(s_signal) // 50
            windows = s_signal[:50 * n_windows].reshape(n_windows, 50)
            z_windows = ((windows - windows.mean(axis=1, keepdims=True))
                         / windows.std(axis=1, keepdims=True))
            Z_scores.extend(z_windows.ravel())
            m_signals.extend(windows.ravel())

    # Fit Z-score as a polynomial in signal value once; evaluate on the sorted
    # unique signal values.
    x_grid = np.unique(m_signals)
    z_fit = np.poly1d(np.polyfit(m_signals, Z_scores, poly_index))(x_grid)

    # The threshold is the signal value where the fitted Z-score reaches ~1.645
    # (one-sided 95%). Take the middle of the grid points whose fit lies in
    # (1.644, 1.6458). The count must come from the index array; np.where
    # returns a 1-tuple, whose length is always 1.
    # If no point falls in that narrow band, fall back to the closest one.
    candidates = np.flatnonzero((z_fit < 1.6458) & (z_fit > 1.644))
    if candidates.size > 0:
        index_95 = candidates[candidates.size // 2]
    else:
        index_95 = int(np.argmin(np.abs(z_fit - 1.645)))
    print(index_95)
    y_value = z_fit[index_95]
    x_value = x_grid[index_95]

    # Z-score scatter, fitted curve and the 95% threshold guides.
    fig_z = plt.figure()
    plt.scatter(m_signals, Z_scores)
    plt.plot(x_grid, z_fit, color='k', linewidth=3)
    plt.plot([0, x_value], [y_value, y_value], 'k--')
    plt.plot([x_value, x_value], [-4, y_value], 'k--')
    plt.xlim([0, 0.2])
    plt.ylim([-4, 6])
    plt.xlabel('Signal value')
    plt.ylabel('Z-score')
    plt.title(('Z-score Calculation - Species '+species))
    fig_z.tight_layout()
    fig_z.savefig(('Sensitivity/Sensitivity_Maps_Confidence_Level/Species/Z_Score_species_'+species+'.png'))

    # Sensitivity map of the first model with band annotations.
    # NOTE(review): only sensitivities[0] is plotted, not the model average.
    fig = plt.figure(figsize=(8,4))
    ax = fig.add_subplot(1,1,1)
    ax.plot(wavenumbers, np.squeeze(sensitivities[0]), 'b')
    ax.plot([3800, 500], [x_value, x_value], 'k--')
    ax.set_xlim(3800, 500)
    ax.set_ylim(0, 0.22)
    ax.set_xlabel('Wavenumber $cm^{-1}$')
    ax.set_ylabel('Sensitivity')
    ax.set_title(('Sensitivity map - Species '+species))

    for mol in [3500, 3100, 2950, 2800, 1820, 1670, 1519, 1377, 1150, 1020, 1154, 1000, 675]:
        ax.plot([mol, mol], [0, 0.22], 'k', linewidth=1)
    for band_high, band_low in [(3500, 3100), (2950, 2800), (1820, 1670), (1150, 1020), (1000, 675)]:
        ax.fill_between([band_high, band_low], [0], [0.22], color='k', alpha=0.2)

    # Secondary top axis labelling the molecular bands.
    ax2 = ax.twiny()
    ax2.set_xlim(ax.get_xlim())
    ax2.set_xticks([3300, 2925, 1745, 1519, 1377, 1065, 1154, 837.5])
    ax2.set_xticklabels(['N-H', 'C-H2', 'C=O', 'O=C-N', 'C-CH3', 'C-N', 'C-O-C', 'C-H2'])
    plt.setp(ax2.get_xticklabels(), rotation=90)

    custom_points = [Line2D([3800], [0], color='b'),
                     Line2D([3800], [0], color='k', lw=1),
                     Line2D([3800], [0], color='k', linestyle='--')]
    ax.legend(custom_points, ['Sensitivity', 'Molecular Bands', '95% confidence interval'], loc='center left', bbox_to_anchor=(1,0.5))

    fig.tight_layout()
    # Species maps were previously saved as "Sensitivity_Map_Age_<species>"
    # (copy-paste from the age cell); name them by species instead.
    fig.savefig(('Sensitivity/Sensitivity_Maps_Confidence_Level/Species/Sensitivity_Map_Species_'+species+'.png'))

    # Render this iteration's figures, then release them so they do not
    # accumulate until the cell ends.
    plt.show()
    plt.close(fig_z)
    plt.close(fig)
    