In [5]:
from scipy.interpolate import make_interp_spline, BSpline
#from scipy.ndimage.filters import gaussian_filter1d
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

from scipy.interpolate import interp1d



In [27]:
def mse(y_pred, y_true):
    y_pred = tf.cast(y_pred, tf.float64)
    y_true = tf.cast(y_true,tf.float64)

    MSE = tf.reduce_mean(tf.pow(tf.subtract(y_true,y_pred),2))

    return MSE, y_pred, y_true

def calculate_pt(y_pred, y_true):

    y_pred = tf.cast(y_pred, tf.float64)
    y_true = tf.cast(y_true,tf.float64)
    
    pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))

    #ref:https://github.com/unsky/focal-loss/blob/master/focal_loss.py
    pt_1 = tf.clip_by_value(pt_1, 1e-14, 1.0) #avoid log(0) that returns inf    
    #pt_1 = tf.add(pt_1, 1e-14) #avoid log(0) that returns inf    

    #if value is zero in y_true, than take value from y_pred, otherwise, write zeros
    pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        
    return pt_0, pt_1

def calculate_CE(y_pred, y_true):

    pt_0, pt_1 = calculate_pt(y_pred, y_true)

    n1 =  tf.cast(-1.0, tf.float64)

    CE_1 = tf.multiply(n1,tf.log(pt_1))
    CE_0 = tf.multiply(n1,tf.log(tf.subtract(tf.cast(1.0, tf.float64),pt_0)))

    return CE_0, CE_1

def cross_entropy(y_pred, y_true,MAF_all_var_vector):
    
    pt_0, pt_1 = calculate_pt(y_pred, y_true)

    CE_0, CE_1 =  calculate_CE(y_pred, y_true) 
   
    CE_1 = tf.reduce_sum(CE_1)
    CE_0 = tf.reduce_sum(CE_0)

    CE = tf.add(CE_1, CE_0, name='reconstruction_loss')
    
    return CE

def calculate_alpha(MAF_all_var_vector):
    
    one=tf.cast(1.0, tf.float64)
    eps=tf.cast(1.0-1e-4, tf.float64)

    alpha = tf.multiply(tf.cast(MAF_all_var_vector,tf.float64),2.0)
    alpha = tf.clip_by_value(alpha, 0.0, eps)

    alpha_1 = tf.divide(one, alpha)
    alpha_0 = tf.divide(one, tf.subtract(one,alpha))
    
    #alpha_0,alpha_1 = alpha_1,alpha_0
    #alpha_1 = alpha_0
    
    return alpha_0, alpha_1

def weighted_cross_entropy(y_pred, y_true, MAF_all_var_vector):

    one=tf.cast(1.0, tf.float64)
    eps=tf.cast(1.0+1e-14, tf.float64)
    n1 =  tf.cast(-1.0, tf.float64)

    pt_0, pt_1 = calculate_pt(y_pred, y_true)

    CE_0, CE_1 =  calculate_CE(y_pred, y_true) 
        
    alpha_0, alpha_1 = calculate_alpha(MAF_all_var_vector)
     
    WCE_per_var_1 = tf.multiply(CE_1, alpha_1)
    WCE_per_var_0 = tf.multiply(CE_0, alpha_0)

    WCE_1 = tf.reduce_sum(WCE_per_var_1)
    WCE_0 = tf.reduce_sum(WCE_per_var_0)

    WCE = tf.add(WCE_1, WCE_0, name='reconstruction_loss')
    
    return WCE

def focal_loss(y_pred, y_true, MAF_all_var_vector, gamma):

    one=tf.cast(1.0, tf.float64)
    eps=tf.cast(1.0+1e-14, tf.float64)
    
    pt_0, pt_1 = calculate_pt(y_pred, y_true)

    CE_0, CE_1 =  calculate_CE(y_pred, y_true) 

    alpha_0, alpha_1 = calculate_alpha(MAF_all_var_vector)
        
    gamma_0 = tf.pow(pt_0, gamma)
    gamma_1 = tf.pow(tf.subtract(eps, pt_1), gamma)
    
    FL_per_var_1 = tf.multiply(gamma_1, CE_1)
    FL_per_var_0 = tf.multiply(gamma_0, CE_0)

    FL_per_var_1 = tf.multiply(FL_per_var_1, alpha_1)
    FL_per_var_0 = tf.multiply(FL_per_var_0, alpha_0)
    
    FL_1 = tf.reduce_sum(FL_per_var_1, axis=0)
    FL_0 = tf.reduce_sum(FL_per_var_0, axis=0)

    FL = tf.add(FL_1, FL_0, name='reconstruction_loss')
    
    return FL


In [28]:
proportions = [0.005, 0.01, 0.06, 0.11, 0.23,0.34, 0.50]
y_pred = np.zeros((1000,1))

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
probabilities = [0.05,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]
#probabilities = np.arange(0.1,1.01,0.1)

result_list = []
FL1_list = []
FL2_list = []
FL3_list = []
FL4_list = []
FL5_list = []
CE_list = []
WCE_list = []
MSE_list = []
proportions_list = []
probabilities_list = []

for prob in probabilities:
    #if(prob==0.3):
    #    break

    for p in proportions:
        p1 = p
        p0 = (1-p)

        y_true = np.zeros(1000)
        ones = np.ones(int(p*1000))
        y_true[0:int(p*1000)] = ones

        y_true.reshape(1000,1)

        y_pred=np.copy(y_true)
        y_pred=np.multiply(y_pred,prob)

        #print(y_pred)
        #print(y_true)
        #print(p1, y_pred)	
        #np.clip(y_pred,0,1,out=y_pred)
        #print(y_pred)    
        a0, a1 = calculate_alpha(p1)
        
        CE = cross_entropy(y_pred, y_true,p1)

        MSE, a, b = mse(y_pred,y_true)

        WCE = weighted_cross_entropy(y_pred, y_true,p1)

        FL1 = focal_loss(y_pred, y_true,p1,0)
        FL2 = focal_loss(y_pred, y_true,p1,0.5)
        FL3 = focal_loss(y_pred, y_true,p1,1)
        FL4 = focal_loss(y_pred, y_true,p1,2)
        FL5 = focal_loss(y_pred, y_true,p1,5)
        #FL6 = focal_loss(y_pred, y_true,p1,6)
        #FL7 = focal_loss(y_pred, y_true,p1,7)

        with tf.Session(config=config) as sess:

            my_a0, my_a1, my_MSE, my_CE, my_WCE, my_FL1, my_FL2 , my_FL3, my_FL4, my_FL5,a,b = sess.run([a0, a1, MSE, CE, WCE, FL1, FL2, FL3, FL4, FL5,a,b])

            #print(a)
            #print(b)
            #print("p1", p1,"CE,WCE,FL", my_CE, my_WCE, my_FL)
            #print("prob", prob, "p1", p1,"CE,WCE,FL", my_CE, my_WCE, my_FL1, my_FL2, my_FL5, my_FL7)
            tmp_result = [prob, p, my_MSE, my_CE, my_WCE, my_FL1, my_FL2 , my_FL3, my_FL4, my_FL5 ]
            print("##################", tmp_result, my_a0, my_a1)
            result_list.append(tmp_result)
            FL1_list.append(my_FL1)
            FL2_list.append(my_FL2)
            FL3_list.append(my_FL3)
            FL4_list.append(my_FL4)
            FL5_list.append(my_FL5)
            CE_list.append(my_CE)
            WCE_list.append(my_WCE)
            MSE_list.append(my_MSE)
            proportions_list.append(p)
            probabilities_list.append(prob)

            #pint("prob", prob, "p1", p1,"CE,WCE,FL", my_CE, my_WCE, my_FL1, my_FL2, my_FL3, my_FL4, my_FL5)
            
            #print("prob", prob, "p1", p1,"CE,WCE,FL", my_CE, my_WCE, my_FL5)
            #print("prob", prob, "p1", p1,"CE,WCE,FL", my_CE, my_WCE, my_FL1, my_FL2 , my_FL3, my_FL4, my_FL5, my_FL6, my_FL7)
            #print("p1", p1,"CE,WCE,FL", my_CE, my_WCE)
    
        sess.close()

################## [0.05, 0.005, 0.0045125, 14.978661367769956, 1497.8661702569134, 1497.8661702569134, 1459.9393517540743, 1422.9728617440678, 1351.824218656864, 1159.0202894709287] 1.0101010098729544 100.00000223517424
################## [0.05, 0.01, 0.009025, 29.95732273553991, 1497.8661702569134, 1497.8661702569134, 1459.9393517540743, 1422.9728617440678, 1351.824218656864, 1159.0202894709287] 1.0204081627998387 50.00000111758712
################## [0.05, 0.06, 0.05415, 179.74393641323945, 1497.866170256913, 1497.866170256913, 1459.9393517540739, 1422.9728617440674, 1351.8242186568646, 1159.0202894709284] 1.13636363290004 8.333333519597852
################## [0.05, 0.11, 0.099275, 329.5305500909392, 1497.8661448933385, 1497.8661448933385, 1459.93932703272, 1422.9728376486712, 1351.8241957662385, 1159.0202698450776] 1.282051280091892 4.545454570084564
################## [0.05, 0.23, 0.20757500000000031, 689.018422917418, 1497.8661096048907, 1497.8661096048907, 1459.9392926377918, 14

################## [0.5, 0.11, 0.0275, 76.24618986159396, 346.5735921579178, 346.5735921579178, 245.06453719504435, 173.2867960789589, 86.64339803947945, 10.830424754934931] 1.282051280091892 4.545454570084564
################## [0.5, 0.23, 0.0575, 159.42385152878734, 346.57358399293946, 346.57358399293946, 245.06453142153282, 173.28679199646973, 86.64339599823487, 10.830424499779358] 1.8518518804686228 2.173913004042296
################## [0.5, 0.34, 0.085, 235.67004139038153, 346.57358663454977, 346.57358663454977, 245.06453328943326, 173.28679331727488, 86.64339665863744, 10.83042458232968] 3.125000069849195 1.4705882198257845
################## [0.5, 0.5, 0.125, 346.5735902799731, 346.6082568572519, 346.6082568572519, 245.08904883901224, 173.30412842862594, 86.65206421431297, 10.831508026789122] 9998.340882002383 1.0001000265982527
################## [0.6, 0.005, 0.0008000000000000001, 2.5541281188299534, 255.41281759191673, 255.41281759191673, 161.5372494383159, 102.1651270367667,

In [29]:
#for i in range(len(result_list)):
#    print(result_list[i])

fig = plt.figure()
ax = plt.axes()

#FL1_list = np.divide(FL1_list, np.max(FL1_list))
#FL2_list = np.divide(FL1_list, np.max(FL2_list))
#FL3_list = np.divide(FL1_list, np.max(FL3_list))
#FL4_list = np.divide(FL1_list, np.max(FL4_list))
#FL5_list = np.divide(FL1_list, np.max(FL5_list))

proportions = [0.005, 0.01, 0.06, 0.11, 0.23,0.34, 0.50]

Y1 = []
Y2 = []
Y3 = []
Y4 = []
Y5 = []
Y6 = []
Y7 = []

for i in range(len(proportions_list)):
    if(proportions_list[i]==0.005):
        #X1 = FL1_list[i]
        #X2 = FL2_list[i]
        #X3 = FL3_list[i]
        Y1.append(FL3_list[i])
        #X4 = FL4_list[i]
        #X5 = FL5_list[i]
    if(proportions_list[i]==0.01):
        Y2.append(FL3_list[i])
    if(proportions_list[i]==0.06):
        Y3.append(FL3_list[i])
    if(proportions_list[i]==0.11):
        Y4.append(FL3_list[i])
    if(proportions_list[i]==0.23):
        Y5.append(FL3_list[i])
    if(proportions_list[i]==0.34):
        Y6.append(FL3_list[i])
    if(proportions_list[i]==0.50):
        Y7.append(FL3_list[i])

    #Y1 = np.divide(Y1, np.max(Y1))
    #Y2 = np.divide(Y2, np.max(Y2))
    #Y3 = np.divide(Y3, np.max(Y3))
    #Y4 = np.divide(Y4, np.max(Y4))
    #Y5 = np.divide(Y5, np.max(Y5))
    #Y6 = np.divide(Y6, np.max(Y6))
    #Y7 = np.divide(Y7, np.max(Y7))


X = probabilities
ax.plot(X, Y1)
ax.plot(X, Y2)
ax.plot(X, Y3)
ax.plot(X, Y4)
ax.plot(X, Y5)
ax.plot(X, Y6)
ax.plot(X, Y7)

plt.legend(['MAF = 0.005', 'MAF = 0.01', 'MAF = 0.06', 'MAF = 0.11', 'MAF = 0.23', 'MAF = 0.34', 'MAF = 0.50'], loc='upper right')
#plt.show()

fig.savefig('test.png') # save the figure to file
plt.close(fig) # close the figure






X = probabilities

for p in proportions:

    Y1 = []
    Y2 = []
    Y3 = []
    Y4 = []
    Y5 = []
    Y6 = []
    Y7 = []

    fig = plt.figure()
    ax = plt.axes()

    for i in range(len(proportions_list)):
        if(proportions_list[i]==p):
            Y1.append(FL1_list[i])
            Y2.append(FL2_list[i])
            Y3.append(FL3_list[i])
            Y4.append(FL4_list[i])
            Y5.append(FL5_list[i])
            Y6.append(MSE_list[i])
     
    print(p, Y6)

    Y1 = np.divide(Y1, np.max(Y1))
    Y2 = np.divide(Y2, np.max(Y2))
    Y3 = np.divide(Y3, np.max(Y3))
    Y4 = np.divide(Y4, np.max(Y4))
    Y5 = np.divide(Y5, np.max(Y5))
    Y6 = np.divide(Y6, np.max(Y6))
    #Y7 = np.divide(Y7, np.max(Y7))


    #Y1 = gaussian_filter1d(Y1, sigma=2)
    #Y2 = gaussian_filter1d(Y2, sigma=2)
    #Y3 = gaussian_filter1d(Y3, sigma=2)
    #Y4 = gaussian_filter1d(Y4, sigma=2)
    #Y5 = gaussian_filter1d(Y5, sigma=2)
    #Y6 = gaussian_filter1d(Y6, sigma=2)


    ax.plot(X, Y1)
    ax.plot(X, Y2)
    ax.plot(X, Y3)
    ax.plot(X, Y4)
    ax.plot(X, Y5)
    ax.plot(X, Y6)
    #ax.plot(X, Y7)

    plt.xlabel('predicted probability')
    plt.ylabel('cumulative scaled loss')


    plt.legend(['g = 0', 'g = 0.5', 'g = 1', 'g = 2', 'g = 5', 'MSE'], loc='upper right')

    fig.savefig('FL_MAF_'+str(p)+'_ns.png') # save the figure to file
    plt.close(fig) # close the figure

    fig = plt.figure()
    ax = plt.axes()


    xnew = np.linspace(np.min(X),np.max(X),10000) #300 represents number of points to make between T.min and T.max

    spl = make_interp_spline(X, Y1) #BSpline object
    s1 = spl(xnew)
    spl = make_interp_spline(X, Y2) #BSpline object
    s2 = spl(xnew)
    spl = make_interp_spline(X, Y3) #BSpline object
    s3 = spl(xnew)
    spl = make_interp_spline(X, Y4) #BSpline object
    s4 = spl(xnew)
    spl = make_interp_spline(X, Y5) #BSpline object
    s5 = spl(xnew)
    spl = make_interp_spline(X, Y6) #BSpline object
    s6 = spl(xnew)


    ax.plot(xnew, s1)
    ax.plot(xnew, s2)
    ax.plot(xnew, s3)
    ax.plot(xnew, s4)
    ax.plot(xnew, s5)
    ax.plot(xnew, s6)

    plt.xlabel('predicted probability')
    plt.ylabel('cumulative scaled loss')

    plt.legend(['g = 0', 'g = 0.5', 'g = 1', 'g = 2', 'g = 5', 'MSE'], loc='upper right')
    #plt.show()

    fig.savefig('FL_MAF'+str(p)+'_s.png') # save the figure to file

    plt.close(fig) # close the figure




fig = plt.figure()
ax = plt.axes()


xnew = np.linspace(np.min(X),np.max(X),100) #300 represents number of points to make between T.min and T.max
#xnew = np.arange(0, 1, 0.001)

f1 = interp1d(X, Y1,kind='cubic')
s1 = f1(xnew)
f2 = interp1d(X, Y2,kind='cubic')
s2 = f2(xnew)
f3 = interp1d(X, Y3,kind='cubic')
s3 = f3(xnew)
f4 = interp1d(X, Y4,kind='cubic')
s4 = f4(xnew)
f5 = interp1d(X, Y5,kind='cubic')
s5 = f5(xnew)
f6 = interp1d(X, Y6,kind='cubic')
s6 = f6(xnew)

a=len(X)

from scipy.signal import savgol_filter
s1 = savgol_filter(Y1, a, 3)
s2 = savgol_filter(Y2, a, 3)
s3 = savgol_filter(Y3, a, 3)
s4 = savgol_filter(Y4, a, 3)
s5 = savgol_filter(Y5, a, 3)
s6 = savgol_filter(Y6, a, 3)

ax.plot(X, s1)
ax.plot(X, s2)
ax.plot(X, s3)
ax.plot(X, s4)
ax.plot(X, s5)
ax.plot(X, s6)
plt.legend(['g = 0', 'g = 0.5', 'g = 1', 'g = 2', 'g = 5', 'MSE'], loc='upper right')
#plt.show()
plt.xlabel('predicted probability')
plt.ylabel('cumulative scaled loss')

fig.savefig('test2_s2.png') # save the figure to file

plt.close(fig) # close the figure


0.005 [0.0045125, 0.004050000000000001, 0.0032000000000000006, 0.00245, 0.0018, 0.00125, 0.0008000000000000001, 0.0004500000000000001, 0.0001999999999999999, 4.9999999999999975e-05, 0.0]
0.01 [0.009025, 0.008100000000000001, 0.006400000000000001, 0.0049, 0.0036, 0.0025, 0.0016000000000000003, 0.0009000000000000002, 0.0003999999999999998, 9.999999999999995e-05, 0.0]
0.06 [0.05415, 0.048600000000000025, 0.03840000000000002, 0.029400000000000003, 0.0216, 0.015, 0.009600000000000004, 0.005400000000000003, 0.0023999999999999994, 0.0005999999999999998, 0.0]
0.11 [0.099275, 0.0891, 0.07040000000000005, 0.0539, 0.0396, 0.0275, 0.01760000000000001, 0.009900000000000003, 0.004400000000000001, 0.0011000000000000003, 0.0]
0.23 [0.20757500000000031, 0.1863000000000001, 0.1472000000000001, 0.11269999999999984, 0.08279999999999994, 0.0575, 0.03680000000000003, 0.020699999999999986, 0.009200000000000005, 0.0023000000000000013, 0.0]
0.34 [0.3068500000000006, 0.2754000000000003, 0.21760000000000015, 0.1