In [10]:
import numpy as np
import csv
# import sys
import random
from sklearn.preprocessing import PolynomialFeatures
import statsmodels.api as sm
from sklearn import linear_model
import matplotlib.pyplot as plt 
import gc
from tqdm.notebook import tqdm

# Choose statistical ('stat') or biochemical ('biochem') epistasis.
# With 'stat', the genotype matrix is later recoded as 2*(g - 0.5), i.e. 0/1 -> -1/+1.
# ep_type = 'biochem'
ep_type = 'stat'

# Processed Kd table (path relative to this notebook's directory).
KD_CSV_PATH = '../../Kd_Inference/results_CH65/Kd_processed/20221008_CH65_QCfilt_REPfilt.csv'
# Column indices in the Kd table -- taken from the original code; verify against `header`.
GENO_COL = 0        # genotype string; each character is parsed as one float per site
PHENO_COL_SI06 = 7  # SI06 phenotype; empty string when not measured
FILTER_COL = 23     # rows are kept only when this equals '1' (presumably a QC/replicate flag -- TODO confirm)

# read in data
geno_vectors_SI06 = []
phenos_SI06 = []

# Site labels '1'..'16'; not used in the cells shown here (presumably used later).
mutations_H1 = [str(x) for x in range(1, 17)]

# The csv module expects file objects to be opened with newline=''.
# The `with` statement closes the file, so no explicit close() is needed.
with open(KD_CSV_PATH, 'r', newline='') as readfile:
    kd_reader = csv.reader(readfile)
    header = next(kd_reader)  # consume the header row
    for row in kd_reader:
        pheno_SI06 = row[PHENO_COL_SI06]
        # Keep only rows with a measured SI06 phenotype that also pass the filter flag;
        # the genotype is parsed only for rows that are kept.
        if pheno_SI06 and row[FILTER_COL] == '1':
            geno = row[GENO_COL]
            geno_vec = np.array([float(x) for x in geno])
            geno_vectors_SI06.append(geno_vec)
            phenos_SI06.append(float(pheno_SI06))



In [11]:
# Build the (n_rows, n_sites) genotype matrix and phenotype vector from the loaded rows.
# Re-running this cell is safe: both are rebuilt from the raw lists / arrays each time.
phenos_SI06 = np.array(phenos_SI06)
if len(phenos_SI06) == 0:
    raise ValueError('No SI06 rows passed the filters; nothing to fit.')
genos_SI06 = np.array(geno_vectors_SI06, dtype=float)

if ep_type == 'stat':
    # Statistical epistasis: recode 0/1 genotypes to -1/+1.
    genos_SI06 = 2 * (genos_SI06 - 0.5)

num_folds = 8      # number of random train/test splits (repeated random subsampling; splits may overlap)
max_order = 7      # highest interaction order fitted
prop_test = 0.1    # fraction of rows held out in each split
ridge_alpha = 0.01  # ridge penalty strength
cv_seed = 0        # seed for the split RNG so the R^2 values are reproducible

size_test_SI06 = int(prop_test * len(genos_SI06))
size_train_SI06 = len(genos_SI06) - size_test_SI06

# R^2 per (split, order)
rsq_train_list_SI06 = np.zeros((num_folds, max_order + 1))
rsq_test_list_SI06 = np.zeros((num_folds, max_order + 1))


def _rsquared(y_true, y_pred):
    """Return R^2 = 1 - SS_res / SS_tot, with SS_tot taken about the mean of y_true."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return 1 - ss_res / ss_tot


# Dedicated RNG instance: seeded, and independent of the global `random` state.
split_rng = random.Random(cv_seed)

for f in tqdm(range(num_folds)):
    # Draw a fresh random held-out subset for this split.
    indices_permuted_SI06 = split_rng.sample(range(len(genos_SI06)), size_test_SI06)

    # np.delete and integer-array indexing both return new arrays; no explicit copy needed.
    genos_train_SI06 = np.delete(genos_SI06, indices_permuted_SI06, 0)
    genos_test_SI06 = genos_SI06[indices_permuted_SI06]
    phenos_train_SI06 = np.delete(phenos_SI06, indices_permuted_SI06, 0)
    phenos_test_SI06 = phenos_SI06[indices_permuted_SI06]

    # Fit models of increasing interaction order.
    for order in range(max_order + 1):
        # fit_intercept=False: PolynomialFeatures adds a constant (bias) column by default.
        reg_SI06_current = linear_model.Ridge(alpha=ridge_alpha, solver='lsqr', fit_intercept=False)
        # interaction_only=True: products of distinct sites only, no powers of a single site.
        poly_SI06_current = PolynomialFeatures(order, interaction_only=True)
        genos_train_SI06_current = poly_SI06_current.fit_transform(genos_train_SI06)
        # Apply the feature map fitted on the training set; do not refit on test data.
        genos_test_SI06_current = poly_SI06_current.transform(genos_test_SI06)
        reg_SI06_current.fit(genos_train_SI06_current, phenos_train_SI06)

        rsquared_train_SI06_current = _rsquared(
            phenos_train_SI06, reg_SI06_current.predict(genos_train_SI06_current))
        rsquared_test_SI06_current = _rsquared(
            phenos_test_SI06, reg_SI06_current.predict(genos_test_SI06_current))
        rsq_train_list_SI06[f, order] = rsquared_train_SI06_current
        rsq_test_list_SI06[f, order] = rsquared_test_SI06_current

        # The expanded design matrices grow rapidly with order. Free them now so the
        # next (larger) pair is not built while this pair is still alive.
        del genos_train_SI06_current, genos_test_SI06_current

    del reg_SI06_current, poly_SI06_current, indices_permuted_SI06
    del genos_train_SI06, genos_test_SI06, phenos_train_SI06, phenos_test_SI06
    gc.collect()
        


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=8.0), HTML(value='')))




In [12]:
import pandas as pd

# One record per (split, order) holding the train and test R^2; the output file name encodes ep_type.
lst = [
    (f, o, rsq_train_list_SI06[f, o], rsq_test_list_SI06[f, o])
    for f in range(num_folds)
    for o in range(max_order + 1)
]
df = pd.DataFrame(lst, columns=["fold_nb", "order", "train", "test"])
df.to_csv(f"r2_CV_{ep_type}_SI06.csv", index=False)

In [13]:
# Mean train/test R^2 across splits for each interaction order.
df.groupby("order")[["train", "test"]].mean()

Unnamed: 0_level_0,train,test
order,Unnamed: 1_level_1,Unnamed: 2_level_1
0,-8.2791e-12,-0.000413
1,0.6359801,0.634829
2,0.8784113,0.877702
3,0.9334966,0.931288
4,0.9525848,0.9455
5,0.9666007,0.951147
6,0.977326,0.945859
7,0.9867621,0.925501
