In [None]:
# %load C:\Users\Patron\Desktop\model_project\Lab\3SpeciesModel\trainModelAndTest.py
"""
Created on Wed Aug  5 09:29:26 2020

This module contains functions that train the linear regression models and evaluate them
by computing the Pearson and Spearman coefficients and the % of correct signs

@author: ZC
"""

from sklearn import linear_model
import numpy as np 
from sklearn.model_selection import cross_validate
from matplotlib import pyplot as plt
from scipy.stats import spearmanr
from scipy.stats import pearsonr 

In [None]:
def gLVMtrProcess(gLVMtr):
    """Drop the diagonal (self-interaction) terms from a square gLV coefficient matrix.

    Input:
        gLVMtr: square n*n array-like of gLV interaction coefficients
    Return:
        an n*(n-1) float array; row i is row i of gLVMtr with the diagonal
        element gLVMtr[i, i] removed, the remaining entries keeping their order
    Raises:
        ValueError if gLVMtr is not a square 2-D matrix
    """
    aij = np.asarray(gLVMtr, dtype=float)
    if aij.ndim != 2 or aij.shape[0] != aij.shape[1]:
        raise ValueError("gLVMtr must be a square 2-D matrix")

    n = aij.shape[0]
    # Boolean mask that is True everywhere except on the diagonal; boolean
    # indexing walks the matrix row by row, so each row contributes n-1 values.
    off_diagonal = ~np.eye(n, dtype=bool)
    return aij[off_diagonal].reshape(n, max(n - 1, 0))
    


##############################################################################


def ridgeCVModel( normedData, i, fold, non_present ):
    """Train a ridge-regularized linear regression model, choosing alpha by k-fold cross validation.

    Input:
        normedData: the dataset (samples * species) used for training
        i: the index of the targeted species (y's column index)
        fold: the number of cross-validation folds
        non_present: matrix with the same shape as normedData; a row is kept
            for training only when non_present[row, i] == 0
    Return:
        the Ridge model object fitted on the kept rows with the selected alpha
    """
    # Candidate regularization strengths: 0 plus 9 log-spaced values in [1e-5, 1e2].
    alphas = np.insert(np.logspace(-5, 2, 9), 0, 0)

    # Exclude samples in which the target species is absent.
    tr_set = normedData[non_present[:, i] == 0]
    tr_x = np.delete(tr_set, i, axis=1)
    tr_y = tr_set[:, i]

    scores = np.zeros(len(alphas))
    for j, alpha in enumerate(alphas):
        candidate = linear_model.Ridge(alpha=alpha, solver='sag')
        cv_results = cross_validate(candidate, tr_x, tr_y, cv=fold, scoring='neg_mean_squared_error')
        # Average over however many folds were actually run.
        scores[j] = np.mean(cv_results['test_score'])

    # The score is the NEGATIVE mse, so the best alpha has the LARGEST score.
    best = np.argmax(scores)
    mdl2 = linear_model.Ridge(alpha=alphas[best], solver='sag')
    mdl2.fit(tr_x, tr_y)

    return mdl2

def ridgeLOOModel( normedData, i, non_present ):
    """Train a ridge-regularized linear regression model with RidgeCV choosing alpha.

    Input:
        normedData: the dataset (samples * species) used for training
        i: the index of the targeted species (y's column index)
        non_present: matrix with the same shape as normedData; a row is kept
            for training only when non_present[row, i] == 0
    Return:
        the fitted RidgeCV model object
    """
    # 15 log-spaced candidate alphas between 1e-5 and 1e5.
    candidate_alphas = np.logspace(-5, 5, 15)
    mdl = linear_model.RidgeCV(
        alphas=candidate_alphas,
        cv=None,
        scoring='neg_mean_squared_error',
        store_cv_values=True,
    )

    # Keep only the samples in which the target species is present.
    present_rows = non_present[:, i] == 0
    samples = normedData.copy()[present_rows]

    # Column i is the regression target; every other column is a feature.
    features = np.delete(samples, i, axis=1)
    target = samples[:, i]

    mdl.fit(features, target)
    return mdl

def trainModelAllSp(data, regMode, fold=10, normed=False, thresh=0) :
    """Train one ridge regression model per species (column) of data.

    Input:
        data: samples * species array; it is not modified by this function
        regMode: 'LOO' (uses ridgeLOOModel) or 'FOLD' (uses ridgeCVModel)
        fold: number of cross-validation folds, only used when regMode is 'FOLD'
        normed: set to True if data has already been normalized; the absence
            matrix is then read from ./interm/no_present.csv instead of being
            computed from data
        thresh: only used when normed is False; values below it are treated as 0
    Return:
        a list of trained models; len = number of species
    Raises:
        ValueError if regMode is neither 'LOO' nor 'FOLD'
    """
    # Validate up front: otherwise an unknown mode would fail with a NameError
    # (or silently reuse a stale model) inside the training loop.
    if regMode not in ('LOO', 'FOLD'):
        raise ValueError(f"regMode must be 'LOO' or 'FOLD', got {regMode!r}")

    if not normed:
        # Work on a copy so the thresholding does not leak into the caller's array.
        data = data.copy()
        data[data < thresh] = 0
        # 1 where a species is absent (value is 0 after thresholding), 0 elsewhere.
        non_present = (data == 0).astype(float)
        np.savetxt("./interm/no_present_noNorm.csv", non_present, delimiter=',')
    else:
        # Read the absence matrix from file.
        non_present = np.genfromtxt("./interm/no_present.csv", delimiter=',')

    models = []
    for i in range(data.shape[1]):
        if regMode == 'LOO':
            models.append(ridgeLOOModel(data, i, non_present))
        else:
            models.append(ridgeCVModel(data, i, fold, non_present))

    return models

def thetaMtr( modelList):
    """Collect the trained parameters of every model into a single matrix.

    Input:
        modelList: a list of n trained model objects exposing intercept_ and coef_
    Return:
        an n*n matrix; column k holds the parameters of modelList[k], with the
        bias (intercept_) in the first row followed by the weights (coef_)
    """
    size = len(modelList)
    params = np.zeros([size, size])

    for col, model in enumerate(modelList):
        intercept = model.intercept_
        weights = model.coef_
        # Prepend the bias so each column reads [bias, w1, w2, ...].
        params[:, col] = np.insert(weights, 0, intercept)

    return params

In [None]:
def processMtrForCompare( thetaMtr, gLVMtr, filt=False, thresh = 0.01):
    '''Flatten the trained parameters and the gLV coefficients into two aligned lists.

    The bias row is removed from thetaMtr and the self-interaction terms are
    removed from gLVMtr (via gLVMtrProcess) before flattening.

    thetaMtr: trained parameter mtr extracted from the trained models (bias in the first row)
    gLVMtr: interaction coefficient mtr used to run the gLV model to generate raw data
    filt: if true, keep only the entries whose trained parameter is larger than
          thresh in absolute value
    thresh: filtering threshold, only used when filt is true
    Return:
        [thetaFlat, glvFlat]; 1-D arrays when filt is false. When filt is true
        both are column arrays (k*1) listing the entries with theta > thresh
        first, followed by the entries with theta < -thresh.'''

    # gLV coefficients without the diagonal, row by row:
    # [a1<-2, a1<-3, ..., a2<-1, ...]
    interactions = gLVMtrProcess(gLVMtr).flatten()

    # Drop the bias row, then transpose so that each model's weights are
    # contiguous after flattening: [th11, th12, ..., th21, th22, ...]
    weights = np.delete(thetaMtr, 0, axis=0).T.flatten()

    if not filt:
        return [weights, interactions]

    # Positions of the clearly positive / clearly negative trained parameters.
    above = np.argwhere(weights > thresh)
    below = np.argwhere(weights < -1 * thresh)

    kept_weights = np.concatenate((weights[above], weights[below]))
    kept_interactions = np.concatenate((interactions[above], interactions[below]))
    return [kept_weights, kept_interactions]

def calcSpPe(mtr, gLVMtr, filtparam, localthresh):
    """Compute the spearman and pearson coef between trained parameters and the gLV parameters.

    mtr: parameter mtr extracted from trained models
    gLVMtr: aij mtr used to generate the raw data
    filtparam: boolean value to be set true if you want to filter out small parameters for comparison
    localthresh: threshold value for the filtration, local to this function

    The two flattened lists are also written to ./interm/paramls.csv and
    ./interm/gLVLs.csv.

    Return a list: [Spearman coef, Pearson Coef] """
    [ls, g] = processMtrForCompare(mtr, gLVMtr, filtparam, localthresh)
    np.savetxt("./interm/paramls.csv", ls, delimiter=',')
    np.savetxt("./interm/gLVLs.csv", g, delimiter=',')

    # Correlate the in-memory values directly. Re-reading the files only served
    # to flatten the arrays, and the read used a differently-cased file name
    # ("paramLs.csv"), which fails on case-sensitive filesystems.
    ls = np.ravel(ls)
    g = np.ravel(g)

    sp, _ = spearmanr(ls, g)  # second value is the p-value
    pe, _ = pearsonr(ls, g)  # second value is the p-value
    return [sp, pe]
    
    
def calcPctCorrectSign(thetaMtr, gLVMtr , calcPctfilt = False, calcPctThr = 0):
    """Return the fraction of trained parameters whose sign matches the corresponding gLV parameter.

    A point counts as correct when both values are strictly positive or both
    are strictly negative; the denominator is the total number of compared points.

    thetaMtr: trained parameter mtr (bias in the first row)
    gLVMtr: interaction coefficient mtr used for the gLV simulation
    calcPctfilt: boolean value to set if you want to filter out small values before calculating the fraction
    calcPctThr: threshold value used to filter small values if calcPctfilt is set to true

    Return: the fraction of points with the correct signs
    Raises: ZeroDivisionError if no point is left to compare
    """
    [thetaFlat, glvFlat] = processMtrForCompare(thetaMtr, gLVMtr, filt=calcPctfilt, thresh=calcPctThr)

    # ravel (not squeeze) so a single surviving point stays a 1-element array
    # instead of collapsing to a 0-d array with no shape[0].
    thetaFlat = np.ravel(thetaFlat)
    glvFlat = np.ravel(glvFlat)

    both_positive = (glvFlat > 0) & (thetaFlat > 0)
    both_negative = (glvFlat < 0) & (thetaFlat < 0)
    ct = np.count_nonzero(both_positive | both_negative)

    correctPercent = ct / float(glvFlat.shape[0])
    return correctPercent
    
def calcPctNotWrong(thetaMtr, gLVMtr, calcPctfilt = False, calcPctThr = 0):
    """Return the fraction of trained parameters whose sign does NOT contradict the corresponding gLV parameter.

    A point is wrong only when the two values have strictly opposite signs, so
    points where either value is exactly 0 count as not wrong. Parameters are
    the same as for calcPctCorrectSign.

    Return: 1 - (number of wrong points / total number of compared points)
    Raises: ZeroDivisionError if no point is left to compare
    """
    [thetaFlat, glvFlat] = processMtrForCompare(thetaMtr, gLVMtr, filt=calcPctfilt, thresh=calcPctThr)

    # ravel (not squeeze) so a single surviving point stays a 1-element array
    # instead of collapsing to a 0-d array with no shape[0].
    thetaFlat = np.ravel(thetaFlat)
    glvFlat = np.ravel(glvFlat)

    positive_but_negative = (glvFlat > 0) & (thetaFlat < 0)
    negative_but_positive = (glvFlat < 0) & (thetaFlat > 0)
    wr = np.count_nonzero(positive_but_negative | negative_but_positive)

    notWrongPct = 1 - (wr / float(glvFlat.shape[0]))
    return notWrongPct

def correlatePlot(thetaMtr, gLVMtr , name='xyz', path= '', filt=False, th = 0.01):
    """Plot the gLV coef vs the trained parameters and report the sign agreement.

    thetaMtr: the n*n mtr of trained parameters with each col as the trained params for each individual model that predicts a single species' abundance from the other species'. The first param of the col is the bias term
    gLVMtr: the interaction coef between species used for gLV simulation. Each row represents the coefs between the targeted species and others
    name: appended to the figure title
    path: where the figure would be saved; currently unused because saving is disabled
    filt: if true, only parameters larger than th in absolute value are compared
    th: filtering threshold, only used when filt is true

    Points are drawn in red when the gLV coef is positive and green when it is
    negative; the marker is '+' for a positive trained parameter, '.' for a
    negative one and '^' for a parameter that is exactly 0.

    Return the fraction of points whose signs are aligned with the gLV interaction coef
    """
    # gLVMtr must be forwarded: it is a required argument of processMtrForCompare.
    [thetaFlat, glvFlat] = processMtrForCompare(thetaMtr, gLVMtr, filt=filt, thresh=th)

    # ravel (not squeeze) so a single surviving point stays a 1-element array.
    thetaFlat = np.ravel(thetaFlat)
    glvFlat = np.ravel(glvFlat)

    glv_pos = glvFlat > 0
    glv_neg = glvFlat < 0
    theta_pos = thetaFlat > 0
    theta_neg = thetaFlat < 0
    theta_zero = thetaFlat == 0

    # r*: positive gLV coef, g*: negative gLV coef; suffix is the sign of theta.
    rplus = glv_pos & theta_pos
    rminus = glv_pos & theta_neg
    rz = glv_pos & theta_zero
    gplus = glv_neg & theta_pos
    gminus = glv_neg & theta_neg
    gz = glv_neg & theta_zero

    # Signs agree for the (+,+) and (-,-) groups.
    ct = np.count_nonzero(rplus) + np.count_nonzero(gminus)
    correctPercent = ct / float(glvFlat.shape[0])
    print(f"pct: {correctPercent}")

    plt.figure()
    plt.plot(thetaFlat[rplus], glvFlat[rplus], 'r+')
    plt.plot(thetaFlat[rminus], glvFlat[rminus], 'r.')
    plt.plot(thetaFlat[gplus], glvFlat[gplus], 'g+')
    plt.plot(thetaFlat[gminus], glvFlat[gminus], 'g.')
    plt.plot(thetaFlat[rz], glvFlat[rz], 'r^')
    plt.plot(thetaFlat[gz], glvFlat[gz], 'g^')
    plt.annotate(f"%correct:{correctPercent: .3f}", (0.01,2))
    figTitle = "gLV vs trained param" + name
    plt.title(figTitle)
    # Saving is intentionally disabled; re-enable with:
    # plt.savefig(path + name + ".png")
    plt.show()
    plt.close()

    return correctPercent