In [1]:
import numpy as np
import csv
# import sys
import random
from sklearn.preprocessing import PolynomialFeatures
import statsmodels.api as sm
from sklearn import linear_model
import matplotlib.pyplot as plt 
import gc
from tqdm.notebook import tqdm

# Choose which epistasis encoding the downstream cells use:
#   'stat'    -> genotypes are recoded as 2*(g - 0.5) before fitting
#   'biochem' -> genotypes are used exactly as read from the file
#ep_type = 'biochem' 
ep_type = 'stat'

# Input table and the column holding the MA90 phenotype.
kd_csv_path = '../../Kd_Inference/results_CH65/Kd_processed/20221008_CH65_QCfilt_REPfilt.csv'
geno_col = 0   # genotype string, one character per position
ma90_col = 3   # row for MA90

# read in data
geno_vectors_MA90 = []
phenos_MA90 = []

mutations_H1 = [str(x) for x in range(1,17)]


# newline='' is what the csv module asks for when it is handed a file object;
# the `with` block closes the file, so no explicit close() is needed.
with open(kd_csv_path, 'r', newline='') as readfile:
    kd_reader = csv.reader(readfile)
    # First row is the header; the default keeps an empty file from raising
    # StopIteration here.
    header = next(kd_reader, None)
    for row in kd_reader:
        geno = row[geno_col]

        # Each character of the genotype string becomes one float entry
        # (presumably '0'/'1' per position -- confirm against the input file).
        geno_vec = np.array([float(x) for x in geno])

        pheno_MA90 = row[ma90_col]

        # Keep only genotypes that actually have an MA90 measurement.
        if len(pheno_MA90) != 0:
            geno_vectors_MA90.append(geno_vec)
            phenos_MA90.append(float(pheno_MA90))



In [2]:
phenos_MA90 = np.array(phenos_MA90)

# Stack the per-genotype vectors into one 2-D array: one row per measured
# genotype, one column per character of the genotype string.
genos_MA90 = np.vstack(geno_vectors_MA90).astype(float)
if ep_type == 'stat':
    # Recentre the encoding (maps 0 -> -1 and 1 -> +1).
    genos_MA90 = 2*(genos_MA90-0.5)    


num_folds = 8
max_order = 7

# proportion of data to be tested 
prop_test = 0.1

# Seed for the train/test splits so that Restart & Run All reproduces the
# same R^2 table.
cv_seed = 0
cv_rng = random.Random(cv_seed)

size_test_MA90 = int(prop_test*len(genos_MA90))
size_train_MA90 = len(genos_MA90)-size_test_MA90

# lists to store r squared values, indexed [fold, order]
rsq_train_list_MA90 = np.zeros((num_folds, max_order+1))
rsq_test_list_MA90 = np.zeros((num_folds, max_order+1))


def r_squared(observed, predicted):
    """Return 1 - SS_res / SS_tot, with SS_tot taken around the mean of `observed`."""
    ss_res = np.sum((observed-predicted)**2)
    ss_tot = np.sum((observed-np.mean(observed))**2)
    return 1-ss_res/ss_tot


# loop over CV folds
for f in tqdm(range(num_folds)):
    # Randomly select the held-out test rows for this fold. Each fold draws its
    # sample independently, so test sets of different folds may overlap.
    indices_permuted_MA90 = cv_rng.sample(range(0,len(genos_MA90)), size_test_MA90)

    # np.delete and fancy indexing both return new arrays, so no extra copies
    # are needed.
    genos_train_MA90 = np.delete(genos_MA90, indices_permuted_MA90, 0)
    genos_test_MA90 = genos_MA90[indices_permuted_MA90]
    phenos_train_MA90 = np.delete(phenos_MA90, indices_permuted_MA90, 0)
    phenos_test_MA90 = phenos_MA90[indices_permuted_MA90]

    # fit models of increasing order
    for order in range(0,max_order+1):
        reg_MA90_current = linear_model.Ridge(alpha=0.01, solver='lsqr', fit_intercept=False)
        # interaction_only=True: products of distinct columns only; the bias
        # column produced by PolynomialFeatures stands in for the intercept.
        poly_MA90_current = PolynomialFeatures(order,interaction_only=True)
        genos_train_MA90_current = poly_MA90_current.fit_transform(genos_train_MA90)
        # Fit on the training data only; the test data are just transformed.
        genos_test_MA90_current = poly_MA90_current.transform(genos_test_MA90)
        reg_MA90_current.fit(genos_train_MA90_current, phenos_train_MA90)

        rsq_train_list_MA90[f, order] = r_squared(
            phenos_train_MA90, reg_MA90_current.predict(genos_train_MA90_current))
        rsq_test_list_MA90[f, order] = r_squared(
            phenos_test_MA90, reg_MA90_current.predict(genos_test_MA90_current))

        # The expanded design matrices are the largest objects in this cell;
        # release them before the next (larger) order is built.
        del genos_train_MA90_current
        del genos_test_MA90_current

    # Drop the per-fold objects before starting the next fold.
    del reg_MA90_current
    del indices_permuted_MA90
    del genos_train_MA90
    del genos_test_MA90
    del phenos_train_MA90
    del phenos_test_MA90
    del poly_MA90_current
    gc.collect()
        


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=8.0), HTML(value='')))




In [3]:
import pandas as pd
# Flatten the [fold, order] R^2 arrays into one record per (fold, order) pair.
lst = [
    (f, o, rsq_train_list_MA90[f, o], rsq_test_list_MA90[f, o])
    for f in range(num_folds)
    for o in range(0, max_order+1)
]
df = pd.DataFrame(lst, columns=["fold_nb", "order", "train", "test"])

# Save the per-fold values; the file name records which epistasis type was used.
df.to_csv(f"r2_CV_{ep_type}_MA90.csv", index=False)

# Displayed output: mean train/test R^2 per model order, averaged over folds.
df.groupby("order").agg({"train":"mean", "test": "mean"})

Unnamed: 0_level_0,train,test
order,Unnamed: 1_level_1,Unnamed: 2_level_1
0,-1.434922e-11,-0.000183
1,0.7550479,0.754806
2,0.9564684,0.956421
3,0.9800033,0.979851
4,0.9868478,0.985986
5,0.9900636,0.987514
6,0.9927433,0.987323
7,0.9949761,0.983613
