# LLM is a great rule-based feature engineer in few-shot tabular learning
## Overview
This notebook runs training and inference for a few-shot tabular learning task on benchmark datasets. The GPT-4o and GPT-3.5-turbo models are used in this tutorial.

## Overall process
* Feature selection
* Prepare datasets
* Extract rules for prediction from training samples with the help of LLM
* Parse rules to the program code and convert data into the transformed datasets with rules
* Train models with each of the datasets
* Make inference with ensembling

**DISCLAIMER:**
This code is inspired by the open-source project FeatLLM developed by Sungwon Han.
You can find the original repository at: https://github.com/Sungwon-Han/FeatLLM.


In [126]:
import os
import numpy as np
from importlib import reload
import utils
reload(utils)
from tqdm import tqdm
from sklearn.tree import export_graphviz
import graphviz

## Prepare datasets
1. Set the dataset and simulation parameters (e.g., the number of training shots and the random seed).
2. Load the SNP data and split it into train/test sets according to those parameters.
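
As a rough illustration of what a seeded few-shot split can look like, here is a minimal sketch. The helper `few_shot_split` is hypothetical and is not the implementation of `utils.get_dataset`; in particular, it assumes that the shot count is the total number of training rows drawn uniformly at random, whereas the real utility may stratify by class or split differently.

```python
import random


def few_shot_split(n_rows, shot, seed):
    """Hypothetical helper: pick `shot` row indices for training, the rest for testing."""
    rng = random.Random(seed)      # local generator so the split is reproducible
    indices = list(range(n_rows))
    rng.shuffle(indices)
    train_idx = sorted(indices[:shot])
    test_idx = sorted(indices[shot:])
    return train_idx, test_idx


# Toy usage: 100 rows, 20 training shots, seed 0 (mirroring _SHOT and _SEED)
train_idx, test_idx = few_shot_split(n_rows=100, shot=20, seed=0)
print(len(train_idx), len(test_idx))
```

Fixing the seed makes the same rows land in the training set on every run, which keeps the cached rule files consistent with the data they were generated from.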

In [95]:
_NUM_QUERY = 20 # Number of Queries
_SHOT = 20 # Number of training shots 
_SEED = 0 # Seed for fixing randomness
_NUM_OF_CONDITIONS = 15
_NUM_OF_CONDITIONS_FOR_INTERACTIONS = 0
_DATA = "ancestry_15_features"
_MODEL = "gpt-4o-2024-05-13"
_FUNCTION_MODEL = "gpt-3.5-turbo"
_REWRITING_FUNCTION_MODEL = "gpt-4-1106-preview"
_PROMPT_VERSION = "v0"
_NOTE = "" # Start note with a dash
_RECORD_LOGS = True
_METADATA_VERSION = "v0" 
utils.set_seed(_SEED)

Now, let's load the dataset.

In [96]:
df, X_train, X_test, y_train, y_test, target_attr, label_list, is_cat = utils.get_dataset(_DATA, _SHOT, _SEED)
X_all = df.drop(target_attr, axis=1)

## Extract rules for prediction from training samples with the help of LLM
To help the LLM extract rules along a more accurate reasoning path, we guide the problem-solving process to mimic how a person might approach a tabular learning task.

We divide the problem into two sub-tasks:
1. Understand the task description and the features provided by the data, inferring the causal relationships beforehand.
2. Use the inferred information and the few-shot samples to deduce the prediction rules for each class.

This two-step reasoning process discourages the model from picking up spurious correlations in irrelevant columns and helps it focus on the more significant features.

Our prompt comprises three main components as follows:  
* Task description
* Reasoning instruction
* Response instruction
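
To make the three components concrete, here is a minimal sketch of how such a prompt could be assembled. The function `build_prompt`, the section headings, and the example row are all hypothetical; the actual prompt is produced by `utils.get_prompt_for_asking` from a template file, and its layout may differ. The sketch assumes that the task description part carries the task, the feature list, and the few-shot examples.

```python
def build_prompt(task_description, feature_lines, example_rows,
                 reasoning_instruction, response_instruction):
    """Assemble a prompt from the three components (hypothetical layout)."""
    features_block = "\n".join(f"- {line}" for line in feature_lines)
    examples_block = "\n".join(example_rows)
    sections = [
        # Task description: the task, the features, and the few-shot samples
        "## Task\n" + task_description,
        "## Features\n" + features_block,
        "## Examples\n" + examples_block,
        # Reasoning instruction: how the model should think through the problem
        "## Reasoning\n" + reasoning_instruction,
        # Response instruction: the output format expected from the model
        "## Response format\n" + response_instruction,
    ]
    return "\n\n".join(sections)


prompt = build_prompt(
    task_description="What is the subject's genomic ancestry?",
    feature_lines=[
        "rs671: (numerical variable with categories [0,1,2])",
        "rs1426654: (numerical variable with categories [0,1,2])",
    ],
    # Made-up example row, for illustration only
    example_rows=["rs671 is 0. rs1426654 is 2. Answer: European"],
    reasoning_instruction="First reason about each feature, then derive rules.",
    response_instruction="List each engineered feature on its own line.",
)
print(prompt)
```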

In [97]:
# Initialization
if 'ancestry' in _DATA:
    _DATA_TYPE = 'ancestry'
else:
    _DATA_TYPE = 'hearing_loss'
if "gpt" in _MODEL:
    ask_file_name = f'./templates/ask_llm_{_PROMPT_VERSION}_{_DATA_TYPE}.txt'
else: 
    ask_file_name = f'./templates/ask_llm_llama_{_PROMPT_VERSION}.txt'
    
meta_data_name = f"../data/{_DATA}-metadata-{_METADATA_VERSION}.json"
templates, feature_desc = utils.get_prompt_for_asking(
    _DATA, X_all, X_train, y_train, label_list, target_attr, ask_file_name, 
    meta_data_name, is_cat, num_query=_NUM_QUERY, num_conditions=_NUM_OF_CONDITIONS,
    prompt_version =_PROMPT_VERSION
)
print(templates[0])

{}
You are an expert in genetics. Given the task description and the list of features and data examples, you are extracting and engineering novel features to solve the task. The purpose of this process is to generate a set of rich, dense and robust features that better express the data.

## Task
What is the subject's genomic ancestry? European, South Asian, East Asian, African, or American?


## Features
- rs671:  (numerical variable with categories [0,1,2])
- rs1426654:  (numerical variable with categories [0,1,2])
- rs16891982:  (numerical variable with categories [0,1,2])
- rs4988235:  (numerical variable with categories [0,1,2])
- rs12913832:  (numerical variable with categories [0,1,2])
- rs2814778:  (numerical variable with categories [0,1,2])
- rs1042602:  (numerical variable with categories [0,1,2])
- rs10498746:  (numerical variable with categories [0,1,2])
- rs3827760:  (numerical variable with categories [0,1,2])
- rs2192416:  (numerical variable with categories [0,1,2])
- r

In [98]:
_DIVIDER = "\n\n---DIVIDER---\n\n"
_VERSION = "\n\n---VERSION---\n\n"

rule_file_name = f'./rules/{_DATA}/{_SHOT}_shot/rule-s{_SHOT}-c{_NUM_OF_CONDITIONS}{_PROMPT_VERSION}-{_MODEL}-q{_NUM_QUERY}-{_SEED}{_NOTE}.out'
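if not os.path.isfile(rule_file_name):
    results = utils.query_gpt(templates, max_tokens=2000, temperature=1, model=_MODEL)
    if _RECORD_LOGS:
        # Make sure the output directory exists before caching the raw LLM responses
        os.makedirs(os.path.dirname(rule_file_name), exist_ok=True)
        with open(rule_file_name, 'w') as f:
            total_rules = _DIVIDER.join(results)
            f.write(total_rules)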
else:
    with open(rule_file_name, 'r') as f:
        total_rules_str = f.read().strip()
        results = total_rules_str.split(_DIVIDER)

print(results[0])

100%|██████████| 20/20 [03:32<00:00, 10.61s/it]

Sure, let's proceed step by step to analyze and engineer features from the provided SNPs (single nucleotide polymorphisms) to predict the genomic ancestry of individuals. Here's a structured approach:

### Step 1: Understand the Role of Individual SNPs
1. **rs671 (ALDH2)**: Known to be a strong marker in East Asian populations.
2. **rs1426654 (SLC24A5)**: Associated with lighter skin pigmentation, common in European populations.
3. **rs16891982 (SLC45A2)**: The G allele is associated with lighter skin pigmentation, typically found in Europeans.
4. **rs4988235 (LCT)**: The T allele indicates lactase persistence, common in European populations.
5. **rs12913832 (HERC2/OCA2)**: The A allele is associated with lighter eye color, particularly blue, common in Northern Europeans.
6. **rs2814778 (DARC)**: The T allele indicates resistance to malaria, common in African populations.
7. **rs1042602 (TYR)**: The A allele is associated with lighter skin, common in various populations.
8. **rs1049874




In [99]:
parsed_rules = []

# Iterate through each result in the results list
for result in results:
    # Use utils.query_gpt to transform each result
    transformed_result = utils.query_gpt(
        [f"Extract the list of engineered features and list them one after another in a new line: {result}\n\nIf some features are clumped up together, list them separately. Also, make sure to include the equation/instruction for each feature. \n\nList:"], 
        max_tokens=2000, 
        temperature=0, 
        model=_FUNCTION_MODEL
    )
    # Append the transformed result to the parsed_rules list
    parsed_rules.append(transformed_result[0])

# The parsed_rules list now contains all the transformed results
print(parsed_rules)


100%|██████████| 1/1 [00:03<00:00,  3.34s/it]
100%|██████████| 1/1 [00:01<00:00,  1.21s/it]
100%|██████████| 1/1 [00:01<00:00,  1.15s/it]
100%|██████████| 1/1 [00:01<00:00,  1.65s/it]
100%|██████████| 1/1 [00:01<00:00,  1.25s/it]
100%|██████████| 1/1 [00:01<00:00,  1.52s/it]
100%|██████████| 1/1 [00:00<00:00,  1.14it/s]
100%|██████████| 1/1 [00:01<00:00,  1.45s/it]
100%|██████████| 1/1 [00:02<00:00,  2.96s/it]
100%|██████████| 1/1 [00:01<00:00,  1.74s/it]
100%|██████████| 1/1 [00:01<00:00,  1.52s/it]
100%|██████████| 1/1 [00:01<00:00,  1.43s/it]
100%|██████████| 1/1 [00:03<00:00,  3.98s/it]
100%|██████████| 1/1 [00:02<00:00,  2.20s/it]
100%|██████████| 1/1 [00:01<00:00,  1.19s/it]
100%|██████████| 1/1 [00:01<00:00,  1.85s/it]
100%|██████████| 1/1 [00:01<00:00,  1.00s/it]
100%|██████████| 1/1 [00:02<00:00,  2.20s/it]
100%|██████████| 1/1 [00:01<00:00,  1.40s/it]
100%|██████████| 1/1 [00:02<00:00,  2.44s/it]

['- rs671\n- rs1426654\n- rs16891982\n- rs4988235\n- rs12913832\n- rs2814778\n- rs10498746\n- rs3827760\n- interaction_rs1426654_rs16891982: rs1426654 * rs16891982\n- interaction_rs671_rs10498746: rs671 * rs10498746\n- sum_rs10498746_rs3827760: rs10498746 + rs3827760\n- interaction_rs2814778_rs12913832: rs2814778 * rs12913832\n- interaction_rs1042602_rs1344011: rs1042602 * rs1344011\n- rs1426654_ge_1: rs1426654 ≥ 1\n- rs12913832_ge_1: rs12913832 ≥ 1\n- is_east_asian: (rs671 ≥ 1) AND (rs10498746 ≥ 1)\n- is_european: (rs1426654 ≥ 1) AND (rs16891982 ≥ 1)\n- is_african: (rs2814778 ≥ 1) AND (rs12913832 ≤ 1)', '- rs12913832 * rs16891982\n- rs671 + rs10498746\n- rs2814778 + rs1878685\n- rs1426654 * rs16891982 * rs4988235\n- rs1426654 + rs16891982\n- rs1390723 + rs3814381\n- rs12913832 AND rs16891982\n- rs2814778 == 2 AND rs1878685 == 1\n- rs671 == 1 AND rs10498746 >= 1', '- rs1426654\n- rs16891982\n- rs4988235\n- European_marker = (rs1426654 >= 1) + (rs16891982 >= 1) + (rs4988235 >= 1)\n- Afr




In [117]:
print(parsed_rules[0])

- rs671
- rs1426654
- rs16891982
- rs4988235
- rs12913832
- rs2814778
- rs10498746
- rs3827760
- interaction_rs1426654_rs16891982: rs1426654 * rs16891982
- interaction_rs671_rs10498746: rs671 * rs10498746
- sum_rs10498746_rs3827760: rs10498746 + rs3827760
- interaction_rs2814778_rs12913832: rs2814778 * rs12913832
- interaction_rs1042602_rs1344011: rs1042602 * rs1344011
- rs1426654_ge_1: rs1426654 ≥ 1
- rs12913832_ge_1: rs12913832 ≥ 1
- is_east_asian: (rs671 ≥ 1) AND (rs10498746 ≥ 1)
- is_european: (rs1426654 ≥ 1) AND (rs16891982 ≥ 1)
- is_african: (rs2814778 ≥ 1) AND (rs12913832 ≤ 1)


## Parse rules to the program code and convert data into the binary vector

We use the rules generated in the previous stage to transform each sample into a binary vector. These vectors are created for each answer class and indicate whether the sample satisfies the rules associated with that class. However, since the rules generated by the LLM are expressed in natural language, the text must be parsed into program code before the data can be transformed automatically.

To handle the noisy text, we leverage the LLM itself instead of writing a complex hand-built parser. We include the function name, input and output descriptions, and the inferred rules in the prompt, then pass it to the LLM. The generated code is executed with Python's `exec()` function, and the resulting function is called by its name to perform the data conversion.
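
The `exec()` step can be sketched as follows. This is a minimal illustration, not the code inside `utils.convert_to_binary_vectors_simple`: the string `generated_code` is a hand-written stand-in for an LLM response, and the toy DataFrame is made up.

```python
import pandas as pd

# Stand-in for code returned by the LLM (hand-written for this sketch)
generated_code = '''
def extracting_engineered_features(df_input):
    df_output = pd.DataFrame()
    df_output['rs671'] = df_input['rs671']
    df_output['interaction_rs671_rs10498746'] = df_input['rs671'] * df_input['rs10498746']
    df_output['is_east_asian'] = ((df_input['rs671'] >= 1) & (df_input['rs10498746'] >= 1)).astype(int)
    return df_output
'''
function_name = "extracting_engineered_features"

# Execute the code in its own namespace so it cannot overwrite notebook variables;
# the namespace must expose the names the generated code relies on (here, pd)
namespace = {"pd": pd}
exec(generated_code, namespace)
transform = namespace[function_name]

# Made-up genotype values for three samples
df_toy = pd.DataFrame({"rs671": [0, 1, 2], "rs10498746": [2, 1, 0]})
df_features = transform(df_toy)
print(df_features)
```

Because `exec()` runs whatever code it is given, generated functions should be reviewed or run in an isolated environment before being applied to real data.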

In [101]:
reload(utils)
_DIVIDER = "\n\n---DIVIDER---\n\n"
_VERSION = "\n\n---VERSION---\n\n"

saved_file_name = f'./rules/{_DATA}/{_SHOT}_shot/function-s{_SHOT}-c{_NUM_OF_CONDITIONS}{_PROMPT_VERSION}-{_MODEL}-{_FUNCTION_MODEL}-q{_NUM_QUERY}-{_SEED}{_NOTE}.out'    
if not os.path.isfile(saved_file_name):
    function_file_name = './templates/ask_for_function_v2.txt'
    fct_strs_all = []
    for parsed_rule in tqdm(parsed_rules):
        fct_template = utils.get_prompt_for_generating_function_simple(
            parsed_rule, feature_desc, function_file_name
        )
        fct_results = utils.query_gpt(fct_template, max_tokens=2500, temperature=0, model = _FUNCTION_MODEL)
        fct_strs = [fct_txt.split('<start>')[1].split('<end>')[0].strip() for fct_txt in fct_results]
        fct_strs_all.append(fct_strs[0])
    if _RECORD_LOGS:
        with open(saved_file_name, 'w') as f:
            total_str = _VERSION.join(fct_strs_all)
            f.write(total_str)
else:
    with open(saved_file_name, 'r') as f:
        total_str = f.read().strip()
        fct_strs_all = total_str.split(_VERSION)

100%|██████████| 20/20 [01:09<00:00,  3.45s/it]


In [102]:
# Examine function output
print(fct_strs_all[0])

def extracting_engineered_features(df_input):
    df_output = pd.DataFrame()
    df_output['rs671'] = df_input['rs671']
    df_output['rs1426654'] = df_input['rs1426654']
    df_output['rs16891982'] = df_input['rs16891982']
    df_output['rs4988235'] = df_input['rs4988235']
    df_output['rs12913832'] = df_input['rs12913832']
    df_output['rs2814778'] = df_input['rs2814778']
    df_output['rs10498746'] = df_input['rs10498746']
    df_output['rs3827760'] = df_input['rs3827760']
    df_output['interaction_rs1426654_rs16891982'] = df_input['rs1426654'] * df_input['rs16891982']
    df_output['interaction_rs671_rs10498746'] = df_input['rs671'] * df_input['rs10498746']
    df_output['sum_rs10498746_rs3827760'] = df_input['rs10498746'] + df_input['rs3827760']
    df_output['interaction_rs2814778_rs12913832'] = df_input['rs2814778'] * df_input['rs12913832']
    df_output['interaction_rs1042602_rs1344011'] = df_input['rs1042602'] * df_input['rs1344011']
    df_output['rs1426654_ge_1'] = df_inp

#### Self-Critiquing Function Writing

In [103]:
reload(utils)
critique_fct_strs_all = utils.self_critique_functions(parsed_rules, feature_desc, fct_strs_all, X_train, _NUM_OF_CONDITIONS, _NUM_OF_CONDITIONS_FOR_INTERACTIONS, _REWRITING_FUNCTION_MODEL, condition_tolerance=30)

In [104]:
if _RECORD_LOGS:
    # Overwrite the cached functions with the self-critiqued versions
    with open(saved_file_name, 'w') as f:
        total_str = _VERSION.join(critique_fct_strs_all)
        f.write(total_str)

In [105]:
import re

# Get function names and strings
fct_names = []
fct_strs_final = []
for fct_str in critique_fct_strs_all:
    # Skip outputs that do not contain a function definition
    match = re.search(r'def\s+(\w+)\s*\(', fct_str)
    if match is None:
        continue
    fct_names.append(match.group(1))
    fct_strs_final.append(fct_str)

### Generating Transformed Datasets

In [106]:
# Drop test rows that contain missing values
mask = X_test.notna().all(axis=1)
X_test = X_test[mask]
y_test = y_test[mask]

In [107]:
executable_list, X_train_all_dict, X_test_all_dict = utils.convert_to_binary_vectors_simple(fct_strs_final, 
                                                                                     fct_names, 
                                                                                     X_train, 
                                                                                     X_test, 
                                                                                     num_of_features=_NUM_OF_CONDITIONS,
                                                                                     include_original_features=True)

In [108]:
# The number of executable functions should equal the number of transformed datasets.
# A lower count than expected means some generated functions were faulty and were dropped.
len(executable_list)

20

## Train the linear model to predict the likelihood of each class from the binary vector
When given the rules for each class and a sample, a simple way to measure the sample's class likelihood is to count how many rules of each class it satisfies (i.e., the sum of the binary vector per class). However, not all rules are equally important, so their weights need to be learned from the training samples.

We learn these weights with a simple linear model: scikit-learn's `LogisticRegression` with default settings is fitted on each transformed dataset, and the same model fitted on the original features serves as a baseline.
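
To illustrate why weighting matters, the sketch below compares plain rule counting with weighted scoring on made-up data. The rule-to-class assignment and the weights are hypothetical values chosen for the example; in the notebook, the weights are learned by the linear model, not set by hand.

```python
import numpy as np

# Made-up binary rule indicators: 3 samples x 4 rules
rule_hits = np.array([
    [1, 1, 0, 0],
    [1, 1, 1, 0],
    [0, 0, 1, 1],
])
# Assumption: rules 0-1 belong to class 0, rules 2-3 to class 1
rule_to_class = np.array([0, 0, 1, 1])
n_classes = 2

# One-hot membership matrix of shape (n_rules, n_classes)
membership = np.eye(n_classes)[rule_to_class]

# Unweighted scoring: count the satisfied rules per class
count_scores = rule_hits @ membership
count_pred = count_scores.argmax(axis=1)

# Weighted scoring: hypothetical per-rule importances (learned in practice)
rule_weights = np.array([0.2, 0.3, 1.5, 1.0])
weighted_scores = rule_hits @ (membership * rule_weights[:, None])
weighted_pred = weighted_scores.argmax(axis=1)

print("count-based prediction:   ", count_pred)
print("weight-based prediction:  ", weighted_pred)
```

The second sample satisfies two weak class-0 rules and one strong class-1 rule, so counting assigns it to class 0 while the weighted score assigns it to class 1.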

In [109]:
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

multiclass = len(label_list) > 2
y_train_num = np.array([label_list.index(k) for k in y_train])
y_test_num = np.array([label_list.index(k) for k in y_test])
single_model = LogisticRegression(random_state=_SEED)

# Fit a baseline model on the original (untransformed) features
single_model.fit(X_train, y_train_num)
lr_pred_probs_train = single_model.predict_proba(X_train)
lr_metrics_train = utils.evaluate(lr_pred_probs_train, y_train_num, multiclass=multiclass, class_level_analysis=True, label_list=label_list)
lr_pred_probs_test = single_model.predict_proba(X_test)
lr_metrics_test = utils.evaluate(lr_pred_probs_test, y_test_num, multiclass=multiclass, class_level_analysis=True, label_list=label_list)
lr_metrics_test

{'AUC': 0.969554334497049,
 'Accuracy': 0.8706070287539937,
 'F1-Score': 0.8506037121536921,
 'Class African Ancestry Precision': 0.9589442815249267,
 'Class African Ancestry Recall': 0.9879154078549849,
 'Class African Ancestry F1-Score': 0.9732142857142857,
 'Class American Ancestry Precision': 0.8524590163934426,
 'Class American Ancestry Recall': 0.30057803468208094,
 'Class American Ancestry F1-Score': 0.4444444444444445,
 'Class East Asian Ancestry Precision': 0.8723404255319149,
 'Class East Asian Ancestry Recall': 0.9761904761904762,
 'Class East Asian Ancestry F1-Score': 0.9213483146067415,
 'Class European Ancestry Precision': 0.775,
 'Class European Ancestry Recall': 0.9841269841269841,
 'Class European Ancestry F1-Score': 0.8671328671328671,
 'Class South Asian Ancestry Precision': 0.875,
 'Class South Asian Ancestry Recall': 0.889344262295082,
 'Class South Asian Ancestry F1-Score': 0.8821138211382114}

In [110]:

models = []

# Train a logistic regression model on each transformed version of the training data
for X_train_now, X_test_now in zip(X_train_all_dict.values(), X_test_all_dict.values()):
    model = LogisticRegression(random_state=_SEED)
    model.fit(X_train_now, y_train_num)
    models.append(model)
    lr_pred_probs_train = model.predict_proba(X_train_now)
    lr_metrics_train = utils.evaluate(lr_pred_probs_train, y_train_num, multiclass=multiclass, class_level_analysis=True, label_list=label_list)
    lr_pred_probs_test = model.predict_proba(X_test_now)
    lr_metrics_test = utils.evaluate(lr_pred_probs_test, y_test_num, multiclass=multiclass, class_level_analysis=True, label_list=label_list)
    print("num of features: ", X_train_now.shape[1])
    print(lr_metrics_test['AUC'])

# Initialize arrays to store ensemble predictions
ensemble_pred_probs_train = np.zeros((X_train_all_dict[0].shape[0], len(label_list)))
ensemble_pred_probs_test = np.zeros((X_test_all_dict[0].shape[0], len(label_list)))

# Predict probabilities for training and test sets using each model and combine them
for i, (X_train_now, X_test_now) in enumerate(zip(X_train_all_dict.values(), X_test_all_dict.values())):
    ensemble_pred_probs_train += models[i].predict_proba(X_train_now)
    ensemble_pred_probs_test += models[i].predict_proba(X_test_now)
    

# Average the probabilities
ensemble_pred_probs_train /= len(X_train_all_dict)
ensemble_pred_probs_test /= len(X_test_all_dict)

# Evaluate the ensemble predictions
ensemble_metrics_train = utils.evaluate(
    ensemble_pred_probs_train, 
    y_train_num, 
    multiclass=multiclass, 
    class_level_analysis=True, 
    label_list=label_list
)

ensemble_metrics_test = utils.evaluate(
    ensemble_pred_probs_test, 
    y_test_num, 
    multiclass=multiclass, 
    class_level_analysis=True, 
    label_list=label_list
)

# Output the test metrics
ensemble_metrics_test

num of features:  33
0.9697826369775411
num of features:  24
0.9635840445335706
num of features:  21
0.9680005236557422
num of features:  24
0.9702244901307093
num of features:  21
0.9701163242912226
num of features:  22
0.9655161057134105
num of features:  20
0.9695989813395048
num of features:  21
0.9647511094239517
num of features:  24
0.9717911408647584
num of features:  24
0.9641433621424074
num of features:  21
0.97032004497116
num of features:  34
0.9686770495844191
num of features:  23
0.9595605871568985
num of features:  23
0.9694344855483854
num of features:  22
0.9660964290859682
num of features:  28
0.9705685493466685
num of features:  20
0.9707145224572409
num of features:  23
0.9705227367411494
num of features:  23
0.9743781348988744
num of features:  37
0.9712535050934834


{'AUC': 0.9730741192810519,
 'Accuracy': 0.8753993610223643,
 'F1-Score': 0.8562916444502596,
 'Class African Ancestry Precision': 0.9646017699115044,
 'Class African Ancestry Recall': 0.9879154078549849,
 'Class African Ancestry F1-Score': 0.9761194029850746,
 'Class American Ancestry Precision': 0.8059701492537313,
 'Class American Ancestry Recall': 0.31213872832369943,
 'Class American Ancestry F1-Score': 0.45,
 'Class East Asian Ancestry Precision': 0.8848920863309353,
 'Class East Asian Ancestry Recall': 0.9761904761904762,
 'Class East Asian Ancestry F1-Score': 0.9283018867924528,
 'Class European Ancestry Precision': 0.8071895424836601,
 'Class European Ancestry Recall': 0.9801587301587301,
 'Class European Ancestry F1-Score': 0.8853046594982079,
 'Class South Asian Ancestry Precision': 0.8473282442748091,
 'Class South Asian Ancestry Recall': 0.9098360655737705,
 'Class South Asian Ancestry F1-Score': 0.8774703557312252}

### Random Forest Analysis

In [119]:
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

multiclass = len(label_list) > 2
y_train_num = np.array([label_list.index(k) for k in y_train])
y_test_num = np.array([label_list.index(k) for k in y_test])
single_model = RandomForestClassifier(random_state=_SEED)

# Fit a baseline model on the original (untransformed) features
single_model.fit(X_train, y_train_num)
lr_pred_probs_train = single_model.predict_proba(X_train)
lr_metrics_train = utils.evaluate(lr_pred_probs_train, y_train_num, multiclass=multiclass, class_level_analysis=True, label_list=label_list)
lr_pred_probs_test = single_model.predict_proba(X_test)
lr_metrics_test = utils.evaluate(lr_pred_probs_test, y_test_num, multiclass=multiclass, class_level_analysis=True, label_list=label_list)
lr_metrics_test

{'AUC': 0.9578963880654356,
 'Accuracy': 0.8538338658146964,
 'F1-Score': 0.8405902174391501,
 'Class African Ancestry Precision': 0.9480122324159022,
 'Class African Ancestry Recall': 0.9365558912386707,
 'Class African Ancestry F1-Score': 0.9422492401215805,
 'Class American Ancestry Precision': 0.7764705882352941,
 'Class American Ancestry Recall': 0.3815028901734104,
 'Class American Ancestry F1-Score': 0.5116279069767441,
 'Class East Asian Ancestry Precision': 0.8531468531468531,
 'Class East Asian Ancestry Recall': 0.9682539682539683,
 'Class East Asian Ancestry F1-Score': 0.9070631970260223,
 'Class European Ancestry Precision': 0.78125,
 'Class European Ancestry Recall': 0.9920634920634921,
 'Class European Ancestry F1-Score': 0.8741258741258742,
 'Class South Asian Ancestry Precision': 0.8504273504273504,
 'Class South Asian Ancestry Recall': 0.8155737704918032,
 'Class South Asian Ancestry F1-Score': 0.8326359832635982}

In [133]:

estimator = single_model.estimators_[4]

# Export the tree to Graphviz format
dot_data = export_graphviz(estimator, out_file=None,
                           class_names=label_list,
                           filled=True, rounded=True,
                           special_characters=True)

# Visualize the tree using graphviz
graph = graphviz.Source(dot_data)  
graph.render("plots/tree")  # Save the tree as a file
graph.view()  # View the tree

'plots/tree.pdf'

In [128]:
models = []

# Train a random forest model on each transformed version of the training data
for X_train_now, X_test_now in zip(X_train_all_dict.values(), X_test_all_dict.values()):
    model = RandomForestClassifier(random_state=_SEED)
    model.fit(X_train_now, y_train_num)
    models.append(model)
    lr_pred_probs_train = model.predict_proba(X_train_now)
    lr_metrics_train = utils.evaluate(lr_pred_probs_train, y_train_num, multiclass=multiclass, class_level_analysis=True, label_list=label_list)
    lr_pred_probs_test = model.predict_proba(X_test_now)
    lr_metrics_test = utils.evaluate(lr_pred_probs_test, y_test_num, multiclass=multiclass, class_level_analysis=True, label_list=label_list)
    print("num of features: ", X_train_now.shape[1])
    print(lr_metrics_test['AUC'])

# Initialize arrays to store ensemble predictions
ensemble_pred_probs_train = np.zeros((X_train_all_dict[0].shape[0], len(label_list)))
ensemble_pred_probs_test = np.zeros((X_test_all_dict[0].shape[0], len(label_list)))

# Predict probabilities for training and test sets using each model and combine them
for i, (X_train_now, X_test_now) in enumerate(zip(X_train_all_dict.values(), X_test_all_dict.values())):
    ensemble_pred_probs_train += models[i].predict_proba(X_train_now)
    ensemble_pred_probs_test += models[i].predict_proba(X_test_now)
    

# Average the probabilities
ensemble_pred_probs_train /= len(X_train_all_dict)
ensemble_pred_probs_test /= len(X_test_all_dict)

# Evaluate the ensemble predictions
ensemble_metrics_train = utils.evaluate(
    ensemble_pred_probs_train, 
    y_train_num, 
    multiclass=multiclass, 
    class_level_analysis=True, 
    label_list=label_list
)

ensemble_metrics_test = utils.evaluate(
    ensemble_pred_probs_test, 
    y_test_num, 
    multiclass=multiclass, 
    class_level_analysis=True, 
    label_list=label_list
)

# Output the test metrics
ensemble_metrics_test

num of features:  33
0.9604055638360265
num of features:  24
0.963850644330997
num of features:  21
0.9651194211347158
num of features:  24
0.9608693997409599
num of features:  21
0.9658239533808324
num of features:  22
0.9592185415826915
num of features:  20
0.9648763181844668
num of features:  21
0.965003526008376
num of features:  24
0.9668977361621801
num of features:  24
0.9650423256509164
num of features:  21
0.9587053244582086
num of features:  34
0.9640501067793199
num of features:  23
0.9550881419397588
num of features:  23
0.9624929262985589
num of features:  22
0.9608536632277399
num of features:  28
0.9629612690313311
num of features:  20
0.9625805686343346
num of features:  23
0.9581236817206495
num of features:  23
0.9718080775111352
num of features:  37
0.959202106286489


{'AUC': 0.969950951589888,
 'Accuracy': 0.8490415335463258,
 'F1-Score': 0.8311596051541861,
 'Class African Ancestry Precision': 0.9936102236421726,
 'Class African Ancestry Recall': 0.9395770392749244,
 'Class African Ancestry F1-Score': 0.9658385093167702,
 'Class American Ancestry Precision': 0.7428571428571429,
 'Class American Ancestry Recall': 0.30057803468208094,
 'Class American Ancestry F1-Score': 0.4279835390946502,
 'Class East Asian Ancestry Precision': 0.82,
 'Class East Asian Ancestry Recall': 0.9761904761904762,
 'Class East Asian Ancestry F1-Score': 0.8913043478260869,
 'Class European Ancestry Precision': 0.7830188679245284,
 'Class European Ancestry Recall': 0.9880952380952381,
 'Class European Ancestry F1-Score': 0.8736842105263158,
 'Class South Asian Ancestry Precision': 0.8167330677290837,
 'Class South Asian Ancestry Recall': 0.8401639344262295,
 'Class South Asian Ancestry F1-Score': 0.8282828282828283}

In [134]:
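# `model` is the last random forest trained in the loop above (fitted on the final transformed dataset)
estimator = model.estimators_[4]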

# Export the tree to Graphviz format
dot_data = export_graphviz(estimator, out_file=None,
                           class_names=label_list,
                           filled=True, rounded=True,
                           special_characters=True)

# Visualize the tree using graphviz
graph = graphviz.Source(dot_data)  
graph.render("plots/ensemble_tree")  # Save the tree as a file
graph.view()  # View the tree

'plots/ensemble_tree.pdf'