In [1]:
# Only needed when the kernel runs from a custom virtual environment: adds that environment's site-packages directory to the import search path.
import sys
import os
import pandas as pd
## goes back to the project directory
# Resolve the project root only once: a bare os.chdir("..") climbs one level
# higher every time this cell is re-run, silently breaking every relative path.
if "PROJECT_DIR" not in globals():
    PROJECT_DIR = os.path.abspath("..")
os.chdir(PROJECT_DIR)
# switch to the name of your virtual environment
kernel_name = ".venv_mp"
# "Lib/site-packages" is the layout of a Windows virtual environment; the path
# is assembled with os.path.join instead of hard-coded backslashes.
site_packages = os.path.join(os.getcwd(), kernel_name, "Lib", "site-packages")
# Avoid stacking duplicate entries on sys.path when the cell is re-run.
if site_packages not in sys.path:
    sys.path.append(site_packages)

In [2]:
from node import create_samples, save_samples
# Path of the raw node data, built with os.path.join so the separators are
# correct for the current platform (no doubled "//" segments).
directory = os.path.join(os.getcwd(), "data", "bank-marketing", "reduced", "raw_node_data.csv")
# First row is the header, first column is the index.
raw_data = pd.read_csv(directory, header=0, index_col=0)
n_samples = 100
# create_samples / save_samples come from the project's `node` module.
samples = create_samples(n_samples, raw_data)
save_samples(samples, "bank")

In [2]:
from node import get_node_data
# Number of sample_<i>.csv files expected in the samples directory (1-based names).
n_samples = 100
directory = os.path.join(os.getcwd(), "data", "bank-marketing", "samples")
# One get_node_data(...) result per sample file, in file-number order.
samples = [
    get_node_data(pd.read_csv(os.path.join(directory, f"sample_{i}.csv")))
    for i in range(1, n_samples + 1)
]

In [4]:
from modelling import select_model_data  
from similar import get_similar_pairs_nodes
from hypothesis_testing import test_hypothesis

def _undersample_majority(df, random_state=None):
    """Return a shuffled copy of ``df`` with equally many "yes" and "no" labels.

    The larger of the two label groups is randomly down-sampled to the size of
    the smaller one; rows whose ``label`` is neither "yes" nor "no" are dropped.
    """
    yes = df.loc[df.label == "yes"]
    no = df.loc[df.label == "no"]
    if yes.shape[0] < no.shape[0]:
        minority, majority = yes, no
    else:
        minority, majority = no, yes
    balanced = pd.concat(
        [minority, majority.sample(minority.shape[0], random_state=random_state)]
    )
    # Shuffle so the two label groups are interleaved, then renumber the rows.
    return balanced.sample(frac=1, random_state=random_state).reset_index(drop=True)


def get_results(raw_node_data, random_state=None):
    """Run the similarity / hypothesis-test pipeline on raw and balanced data.

    The pipeline is executed twice: once on ``raw_node_data`` as given
    (``balanced=False``) and once on a label-balanced, undersampled copy of
    every node's frame (``balanced=True``).

    Parameters
    ----------
    raw_node_data : iterable of pandas.DataFrame
        One frame per node; each needs a ``label`` column containing
        "yes" / "no" values.
    random_state : optional
        Forwarded to every ``DataFrame.sample`` call used for balancing.
        ``None`` (the default) keeps the previous unseeded behaviour; pass an
        int (or a NumPy random generator) to make the balanced run reproducible.

    Returns
    -------
    pandas.DataFrame
        The concatenated frames returned by ``test_hypothesis`` with two extra
        columns, ``asmmd`` and ``balanced``. An empty frame is returned when
        neither run produced any similar pairs.
    """
    results = []
    for balanced in [False, True]:
        if balanced:
            # The undersampled data is paired with the plain "lr" classifier name.
            clf_name = "lr"
            rnd = [_undersample_majority(df, random_state) for df in raw_node_data]
        else:
            # The untouched data is paired with the "lr_balanced" classifier name.
            clf_name = "lr_balanced"
            rnd = raw_node_data

        node_data, similar_pairs, similar_nodes, asmmd, mmd_scores, ocsvm_scores = get_similar_pairs_nodes(rnd, balanced)

        # Nothing to model or test when no pair of nodes was found similar.
        if similar_pairs == []:
            continue

        print(f"{similar_pairs} (balanced={balanced})", end=" ")
        model_data = select_model_data(node_data, similar_nodes)
        result_df = test_hypothesis(clf_name, model_data, similar_pairs, similar_nodes, mmd_scores, ocsvm_scores)
        # Tag every result row with the run-level values.
        result_df["asmmd"] = [asmmd] * result_df.shape[0]
        result_df["balanced"] = [balanced] * result_df.shape[0]
        results.append(result_df)

    if results:
        return pd.concat(results, ignore_index=True)
    return pd.DataFrame()

In [4]:
# Run the pipeline on the first sample only; the returned DataFrame is the
# cell's displayed output.
get_results(samples[0])

['pi4', 'pi5'] ['pi1', 'pi2', 'pi3']
[('pi1', 'pi2'), ('pi1', 'pi3'), ('pi2', 'pi3')] (balanced=False) ['pi1', 'pi2', 'pi3'] ['pi4', 'pi5']
[('pi1', 'pi2'), ('pi1', 'pi3'), ('pi4', 'pi5')] (balanced=True) 

Unnamed: 0,model_node,model,model_r2,train_time,optimisation_time,test_node,discrepancy,model_r2-d,test_r2,mmd_score,ocsvm_score,asmmd,balanced
0,pi1,"LogisticRegression(C=0.01, class_weight='balan...",0.85,0.02,2.27,pi2,0.01,0.840329,0.827984,0.12,0.0,0.149243,False
1,pi1,"LogisticRegression(C=0.01, class_weight='balan...",0.85,0.02,2.27,pi3,0.02,0.832099,0.819753,0.03,0.01,0.149243,False
2,pi2,"LogisticRegression(class_weight='balanced', ma...",0.83,0.03,2.61,pi1,0.0,0.826051,0.847486,0.12,0.06,0.149243,False
3,pi2,"LogisticRegression(class_weight='balanced', ma...",0.83,0.03,2.61,pi3,0.02,0.804115,0.819753,0.09,0.99,0.149243,False
4,pi3,"LogisticRegression(C=0.01, class_weight='balan...",0.82,0.04,2.56,pi1,0.02,0.840066,0.847486,0.03,0.07,0.149243,False
5,pi3,"LogisticRegression(C=0.01, class_weight='balan...",0.82,0.04,2.56,pi2,0.02,0.836214,0.827984,0.09,0.99,0.149243,False
6,pi1,LogisticRegression(max_iter=100000),0.84,0.02,1.71,pi2,0.01,0.846753,0.851948,0.27,0.01,0.957293,True
7,pi1,LogisticRegression(max_iter=100000),0.84,0.02,1.71,pi3,0.05,0.886243,0.886243,0.51,0.01,0.957293,True
8,pi2,LogisticRegression(max_iter=100000),0.85,0.02,2.12,pi1,0.0,0.84788,0.840399,0.27,0.14,0.957293,True
9,pi3,LogisticRegression(max_iter=100000),0.89,0.04,1.67,pi1,0.04,0.842893,0.840399,0.51,0.12,0.957293,True


In [9]:
def run(samples, start=11, stop=100, output_dir="results/bank-marketing/reduced"):
    """Compute and save the results of ``get_results`` for a range of samples.

    Parameters
    ----------
    samples : list
        Sample inputs; ``samples[i]`` is passed to ``get_results``.
    start : int
        Zero-based index of the first sample to process. Defaults to 11,
        the previously hard-coded starting point.
    stop : int
        Zero-based index one past the last sample to process. Clamped to
        ``len(samples)`` so a shorter list does not raise ``IndexError``.
    output_dir : str
        Directory receiving one ``sample_<n>.csv`` per processed sample,
        where ``n`` is the one-based sample number. Created if missing.
    """
    # Create the target directory up front so a missing folder cannot make
    # to_csv fail after the computation for a sample has already been done.
    os.makedirs(output_dir, exist_ok=True)
    last = min(stop, len(samples))
    for sample_id in range(start, last):
        print(f"Sample {sample_id+1}", end=": ")
        results = get_results(samples[sample_id])
        results.to_csv(os.path.join(output_dir, f"sample_{sample_id+1}.csv"), index=False)
        print()

In [10]:
# Process the loaded samples; run() prints one progress line per sample and
# writes one results CSV per sample.
run(samples)

Sample 12: [('pi1', 'pi3'), ('pi2', 'pi3')] (balanced=False) [('pi1', 'pi2'), ('pi1', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 13: [('pi2', 'pi3')] (balanced=False) [('pi1', 'pi2'), ('pi1', 'pi3'), ('pi2', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 14: [('pi1', 'pi3'), ('pi2', 'pi3')] (balanced=False) [('pi1', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 15: [('pi1', 'pi2'), ('pi1', 'pi3'), ('pi2', 'pi3')] (balanced=False) [('pi1', 'pi3'), ('pi2', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 16: [('pi1', 'pi2'), ('pi1', 'pi3'), ('pi2', 'pi3')] (balanced=False) [('pi1', 'pi3'), ('pi2', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 17: [('pi1', 'pi3'), ('pi2', 'pi3')] (balanced=False) [('pi1', 'pi3'), ('pi2', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 18: [('pi1', 'pi3'), ('pi2', 'pi3')] (balanced=False) [('pi1', 'pi3'), ('pi2', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 19: [('pi1', 'pi2'), ('pi1', 'pi3'), ('pi2', 'pi3')] (balanced=False) [('pi1', 'pi3'), (

Sample 80: [('pi1', 'pi2'), ('pi1', 'pi3')] (balanced=False) [('pi1', 'pi2'), ('pi1', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 81: [('pi2', 'pi3')] (balanced=False) [('pi1', 'pi3'), ('pi2', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 82: [('pi1', 'pi3')] (balanced=False) [('pi1', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 83: [('pi1', 'pi3'), ('pi2', 'pi3')] (balanced=False) [('pi1', 'pi2'), ('pi1', 'pi3'), ('pi2', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 84: [('pi1', 'pi3')] (balanced=False) [('pi1', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 85: [('pi1', 'pi3'), ('pi2', 'pi3')] (balanced=False) [('pi1', 'pi2'), ('pi4', 'pi5')] (balanced=True) 
Sample 86: [('pi2', 'pi3')] (balanced=False) [('pi2', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 87: [('pi2', 'pi3')] (balanced=False) [('pi1', 'pi2'), ('pi1', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
Sample 88: [('pi1', 'pi3'), ('pi2', 'pi3')] (balanced=False) [('pi1', 'pi3'), ('pi4', 'pi5')] (balanced=True) 
S