In [None]:
import pandas as pd
import os
import json
import ast
from utils import *

In [None]:
# Trait/condition pairs to analyse; the analysis loop reads the 'Trait' and 'Condition' columns.
pairs = pd.read_csv("trait_condition_pairs.csv")
missing_pair_cols = {'Trait', 'Condition'} - set(pairs.columns)
assert not missing_pair_cols, f"trait_condition_pairs.csv is missing columns: {missing_pair_cols}"

In [None]:
def normalize_trait(trait):
    """Convert a trait/condition name into the directory-name form used for data paths.

    Every run of whitespace becomes a single hyphen (leading/trailing whitespace
    is dropped), and apostrophes are removed.
    Example: "Crohn's Disease" -> "Crohns-Disease".
    """
    hyphenated = "-".join(trait.split())
    return hyphenated.replace("'", "")

In [None]:
# Build the trait -> related-genes mapping. The CSV stores each gene list as a
# Python literal string, which ast.literal_eval parses safely (no arbitrary eval).
rel = (
    pd.read_csv("trait_related_genes.csv")
    .drop(columns=["Unnamed: 0"])
)
rel["Related_Genes"] = rel["Related_Genes"].map(ast.literal_eval)
t2g = dict(zip(rel["Trait"], rel["Related_Genes"]))  # trait name -> list of genes

In [None]:
# Preview the first few entries of the trait -> related-genes mapping instead of dumping all of it.
pd.Series(t2g, name="Related_Genes").head()

In [None]:
# Location of the per-trait cohort data and of the saved results.
# NOTE(review): absolute local path -- consider making this configurable.
GOLD_SUBSET_DIR = '/home/techt/Desktop/a4s/gold_subset'
OUTPUT_ROOT = './output'
MAX_SUCCESSES = 10        # stop after this many pairs have been analysed successfully
RESULT_THRESHOLD = 0.05   # passed to interpret_result as `threshold`


def _load_best_cohort(cohort_dir, *rank_args):
    """Read the top-ranked cohort CSV in ``cohort_dir`` as a float DataFrame.

    ``rank_args`` are forwarded to ``filter_and_rank_cohorts`` after the path
    of ``cohort_info.json``; the chosen cohort is read from ``<cohort_id>.csv``.
    """
    cohort_id, _ = filter_and_rank_cohorts(os.path.join(cohort_dir, 'cohort_info.json'), *rank_args)
    return pd.read_csv(os.path.join(cohort_dir, cohort_id + '.csv')).astype('float')


def _select_trait_model(X):
    """Return (model_constructor, model_params) for the trait model.

    VariableSelection is used when detect_batch_effect flags X; otherwise Lasso.
    """
    if detect_batch_effect(X):
        return VariableSelection, {'modified': True, 'lamda': 3e-4}
    return Lasso, {'alpha': 1.0, 'random_state': 42}


def _predict_condition(trait_data, condition_data, regressors, condition):
    """Train a model of `condition` on the condition cohort and predict it for the trait cohort.

    The predictors are the shared gene `regressors`. Each cohort's regressor matrix is
    normalized separately with normalize_data. For a two-valued condition, the
    positive-class probability is returned; otherwise the regression prediction is.
    """
    X_condition = condition_data[regressors].values
    Y_condition = condition_data[condition].values

    is_binary = len(np.unique(Y_condition)) == 2
    # More features than samples: fall back to L1-penalised models.
    more_features_than_samples = X_condition.shape[1] > X_condition.shape[0]
    if is_binary:
        if more_features_than_samples:
            condition_model = LogisticRegression(penalty='l1', solver='liblinear', random_state=42)
        else:
            condition_model = LogisticRegression()
    else:
        condition_model = Lasso() if more_features_than_samples else LinearRegression()

    normalized_X_condition, _ = normalize_data(X_condition)
    condition_model.fit(normalized_X_condition, Y_condition)

    normalized_regressors_in_trait, _ = normalize_data(trait_data[regressors].values)
    if is_binary:
        return condition_model.predict_proba(normalized_regressors_in_trait)[:, 1]
    return condition_model.predict(normalized_regressors_in_trait)


def _fit_and_interpret(trait_data, trait, condition, output_dir):
    """Cross-validate and fit a ResidualizationRegressor of `trait` on the remaining
    columns, controlling for `condition`. Save the interpretation to `output_dir`.
    """
    X = trait_data.drop(columns=[trait, condition]).values
    Y = trait_data[trait].values
    Z = trait_data[condition].values

    model_constructor, model_params = _select_trait_model(X)
    trait_type = 'binary' if len(np.unique(Y)) == 2 else 'continuous'
    # NOTE(review): the CV scores are computed but never reported.
    cv_mean, cv_std = cross_validation(X, Y, Z, model_constructor, model_params, target_type=trait_type)

    normalized_X, _ = normalize_data(X)
    normalized_Z, _ = normalize_data(Z)

    model = ResidualizationRegressor(model_constructor, model_params)
    model.fit(normalized_X, Y, normalized_Z)

    feature_cols = trait_data.columns.tolist()
    feature_cols.remove(trait)
    interpret_result(model, feature_cols, trait, condition, threshold=RESULT_THRESHOLD,
                     save_output=True, output_dir=output_dir)


suc = 0
for i, (_, row) in enumerate(pairs.iterrows()):
    # Check the cap before doing any work so no output dir is created for a skipped row.
    if suc >= MAX_SUCCESSES:
        break
    # Read outside the try (with .get) so the error message never refers to an
    # undefined name or to the previous row's values.
    trait, condition = row.get('Trait'), row.get('Condition')
    try:
        nm_trait = normalize_trait(trait)
        nm_condition = normalize_trait(condition)
        print(nm_trait, nm_condition)
        trait_dir = os.path.join(GOLD_SUBSET_DIR, nm_trait)
        output_dir = os.path.join(OUTPUT_ROOT, nm_trait)
        os.makedirs(output_dir, exist_ok=True)

        if condition in ['Age', 'Gender']:
            # Demographic condition: it is already a column of the trait cohort.
            trait_data = _load_best_cohort(trait_dir, condition)
            # Keep only the demographic column under study.
            remove_col = 'Age' if condition == 'Gender' else 'Gender'
            trait_data = trait_data.drop(columns=[remove_col], errors='ignore')
        else:
            # Other condition: predict it in the trait cohort from genes shared with the condition cohort.
            condition_dir = os.path.join(GOLD_SUBSET_DIR, nm_condition)
            trait_data = _load_best_cohort(trait_dir)
            condition_data = _load_best_cohort(condition_dir)

            related_genes = t2g[condition]
            regressors = get_gene_regressors(trait, trait_data, condition_data, related_genes)
            if regressors is None:
                print(f'No gene regressors for trait {trait} and condition {condition}')
                continue

            print("Common gene regressors for condition and trait", regressors)
            trait_data[condition] = _predict_condition(trait_data, condition_data, regressors, condition)
            trait_data = trait_data.drop(columns=regressors)
            trait_data = trait_data.drop(columns=['Age', 'Gender'], errors='ignore')

        _fit_and_interpret(trait_data, trait, condition, output_dir)
        suc += 1
    except Exception as e:
        # Best-effort per row: report the failure and move on to the next pair.
        print(f"Error processing row {i}, for the trait '{trait}' and the condition '{condition}'\n: {e}")



In [None]:
# Ad-hoc inspection of one preprocessed cohort file.
xena_path = 'data/preprocessed/Lung-Cancer/Xena.csv'
p = pd.read_csv(xena_path)

In [None]:
# Display the column labels of the inspected cohort file.
p.columns

In [None]:
def load_best_cohorts(data_dir='./preprocessed'):
    """Load the top-ranked cohort CSV for every trait directory under ``data_dir``.

    Each sub-directory must contain a ``cohort_info.json``. The cohort id chosen by
    ``filter_and_rank_cohorts`` is read from ``<trait_dir>/<cohort_id>.csv``.
    Non-directory entries in ``data_dir`` are skipped.

    Returns:
        dict mapping trait directory name -> DataFrame of the best cohort.

    Raises:
        FileNotFoundError: if a trait directory has no ``cohort_info.json``.
    """
    cohorts = {}
    for trait in sorted(os.listdir(data_dir)):
        trait_dir = os.path.join(data_dir, trait)
        if not os.path.isdir(trait_dir):
            continue
        json_path = os.path.join(trait_dir, 'cohort_info.json')
        if not os.path.isfile(json_path):
            raise FileNotFoundError(json_path)
        best_cohort_id, _ = filter_and_rank_cohorts(json_path)
        # The cohort id is a file stem; the data lives in '<id>.csv'.
        cohorts[trait] = pd.read_csv(os.path.join(trait_dir, best_cohort_id + '.csv'))
    return cohorts
