In [None]:
import os
import sys
from dotenv import load_dotenv

load_dotenv(override=True)

sys.path.append(os.environ["WORKING_DIR"])
from os.path import join
from snorkel.labeling import labeling_function
#from snorkel.preprocess import preprocessor
import copy
import json
#import configargparse
from sklearn.metrics.pairwise import cosine_similarity
#from snorkel.labeling import PandasLFApplier
from sklearn.preprocessing import LabelEncoder
import pandas as pd
from sklearn.metrics import classification_report
import numpy as np
from data_loader.utils import load_public_bi_table_by_cols
from scipy.stats import wasserstein_distance, kruskal

# All input locations are resolved relative to the WORKING_DIR environment variable.
_working_dir = os.environ["WORKING_DIR"]
_extract_out_dir = join(_working_dir, "data", "extract", "out")

# JSON files holding the labeled / unlabeled / test id splits
labeled_unlabeled_test_split_path = join(_extract_out_dir,
                                         "labeled_unlabeled_test_split")
# JSON files holding the valid headers and their semantic types
valid_headers_path = join(_extract_out_dir, "valid_headers")
# CSV files with already generated training labels
gen_train_data_path = join(_working_dir, "labeling_functions",
                           "combined_LFs", "gen_training_data")

# Fit the label encoder on the semantic types listed for TYPENAME in types.json
_types_file_path = join(_extract_out_dir, "valid_types", "types.json")
with open(_types_file_path) as f:
    valid_types = json.load(f)[os.environ["TYPENAME"]]

label_enc = LabelEncoder()
label_enc.fit(valid_types)

In [None]:
# Show the semantic types the label encoder was fitted on.
label_enc.classes_

In [None]:
# --- Experiment configuration -------------------------------------------------
# The split sizes, the corpus and the random state are used to build the names of
# the split file and of the generated-training-data file that are loaded below.
corpus = "public_bi_num"
labeled_data_size = 5
unlabeled_data_size = "absolute"
test_data_size = 20.0
random_state = 2

# Which id list of the split file is used as test data ("test" or "unlabeled")
validation_on = "test"
# Read the unlabeled ids at all / read them from the plain "unlabeled" key
gen_train_data = True
absolute_numbers = True
n_worker = 2

# Semantic types that are treated as numeric in this notebook
numeric_types = [
    "X1B", "X2B", "X3B", "TB", "HR",
    "R", "BB", "AB", "GIDP", "HBP",
    "H", "SF", "SH", "SO", "iBB",
    "CS", "SB",
    "latitude", "longitude", "year",
]



In [None]:
# ------------------------------------------------------------------
# Load the labeled / unlabeled / test split and the valid headers
# ------------------------------------------------------------------

# The name of the split file encodes corpus, split sizes and random state
split_file_name = f"{corpus}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.json"
with open(join(labeled_unlabeled_test_split_path, split_file_name)) as f:
    labeled_unlabeled_test_split_file = json.load(f)

labeled_data_ids = labeled_unlabeled_test_split_file[f"labeled{labeled_data_size}"]

if gen_train_data:
    # with absolute numbers the ids are stored under the plain "unlabeled" key,
    # otherwise the key carries the unlabeled data size as suffix
    unlabeled_key = "unlabeled" if absolute_numbers else f"unlabeled{unlabeled_data_size}"
    unlabeled_data_ids = labeled_unlabeled_test_split_file[unlabeled_key]
    print(f"Unlabeled Data: {len(unlabeled_data_ids)}")

# validating on the unlabeled data uses the plain key, everything else the sized key
test_key = validation_on if validation_on == "unlabeled" else f"{validation_on}{test_data_size}"
test_data_ids = labeled_unlabeled_test_split_file[test_key]

print(f"Labeled Data: {len(labeled_data_ids)}")
print(f"Test Data: {len(test_data_ids)}")

# Load the valid headers together with their real semantic types
valid_header_file = f"{corpus}_{os.environ['TYPENAME']}.json"
with open(join(valid_headers_path, valid_header_file), "r") as file:
    valid_headers = json.load(file)

# Flatten {table: {column: {...}}} into one row per column, so that it can be
# filtered and joined via the "dataset_id" ("<table>+<column>")
valid_header_df_data = []
for table, columns in valid_headers.items():
    for column, column_info in columns.items():
        valid_header_df_data.append(
            [table, column, f"{table}+{column}", column_info["semanticType"]])

valid_header_df = pd.DataFrame(
    valid_header_df_data,
    columns=["table", "column", "dataset_id", "semanticType"])

In [None]:
# Columns of the valid headers that belong to the unlabeled split
is_unlabeled = valid_header_df["dataset_id"].isin(unlabeled_data_ids)
unlabeled_data_df = valid_header_df.loc[is_unlabeled]

# Columns of the valid headers that belong to the labeled split
is_labeled = valid_header_df["dataset_id"].isin(labeled_data_ids)
labeled_data_df = valid_header_df.loc[is_labeled]

# Already generated labeled training data; the column names are supplied explicitly
gen_train_data_file = f"public_bi_gen_training_data_all_combined_maj_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"
gen_labeled_data_df = pd.read_csv(
    join(gen_train_data_path, gen_train_data_file),
    names=["table", "column", "dataset_id", "semanticType"],
)

# Given labels plus generated labels, without exact duplicate rows
total_labeled_data_df = pd.concat([labeled_data_df, gen_labeled_data_df]).drop_duplicates()

# Only the unlabeled columns whose semantic type is numeric
is_numeric_type = unlabeled_data_df["semanticType"].isin(numeric_types)
numeric_unlabeled_data_df = unlabeled_data_df.loc[is_numeric_type]

In [None]:
# Inspect the unlabeled columns whose semantic type is one of `numeric_types`.
numeric_unlabeled_data_df

In [None]:
# Show every labeled column (given and generated labels) of the table "CommonGovernment_1".
print(total_labeled_data_df[total_labeled_data_df["table"] == "CommonGovernment_1"])

In [None]:
# Baseline experiment: compare ONE unlabeled numeric column against already labeled numeric
# columns via the earth mover's distance (EMD, scipy's `wasserstein_distance`) and rank the
# labeled columns by that distance.
numeric_column_to_label = numeric_unlabeled_data_df.iloc[15]
print(f"Numeric col to label: {numeric_column_to_label}")
# Passed as `frac` to the loader for the labeled columns in the loop below.
table_frac=None

# load the table with the numeric column to label
# NOTE(review): unlike the loads inside the loop, this call does not pass `frac=table_frac`.
cols_to_load = [numeric_column_to_label["dataset_id"]]
# The column number is parsed from the dataset id (text after "+", then after "_").
df_cols_to_load = pd.DataFrame({"col_num": [int(col.split(
    "+")[1].split("_")[1]) for col in cols_to_load], "col_header": cols_to_load}).sort_values(by="col_num")
df_table_with_n_col_to_label = load_public_bi_table_by_cols(numeric_column_to_label["table"].split(
    "_")[0], numeric_column_to_label["table"], usecols=df_cols_to_load["col_num"].values, col_headers=df_cols_to_load["col_header"].values)

print(df_table_with_n_col_to_label.head(2))

# search all already labeled numeric cols in the corpus
already_labeled_numeric_cols = total_labeled_data_df.loc[total_labeled_data_df["semanticType"].isin(
    numeric_types)].drop_duplicates()

# iterate over the already labeled numeric columns and compute the EMD to the column to label
results = []
for index, row in already_labeled_numeric_cols.iterrows():
    #print(index)
    # NOTE(review): `index` is the index LABEL yielded by iterrows(), not a running counter.
    # `total_labeled_data_df` is built with pd.concat without `ignore_index`, so labels are
    # presumably neither unique nor ordered - this cap may stop the loop at an arbitrary row.
    if index > 700:
        break
    #print(f"curren labeled num col: {row}")
    # load only the labeled numeric column of the current row
    df_table_with_labeled_numeric = load_public_bi_table_by_cols(row["table"].split(
            "_")[0], row["table"], usecols=[int(row["column"].split("_")[1])], col_headers=[row["dataset_id"]], frac=table_frac)
    #print(f"table wit labeled num col: {df_table_with_labeled_numeric.head(2)}")
    # EMD between the values of the column to label and the values of the labeled column
    emd = wasserstein_distance(df_table_with_n_col_to_label[numeric_column_to_label["dataset_id"]].to_list(), df_table_with_labeled_numeric[row["dataset_id"]].to_list())
    #print(f"EMD: {emd}")
    results.append([numeric_column_to_label["dataset_id"], numeric_column_to_label["semanticType"],
                           row["dataset_id"], row["semanticType"], emd])

df_results = pd.DataFrame(results, columns=[
                            "unlabeled_col", "real_semantic_type", "labeled_col", "semantic_type", "EMD"])  # .sort_values(by="EMD")
# keep only the rows whose EMD can be converted to a non-null number
df_results = df_results[pd.to_numeric(
    df_results['EMD'], errors='coerce').notnull()]
# smallest distance (best match) first
df_results = df_results.sort_values(by="EMD")
df_results

In [None]:
# Inspect the labeled numeric columns that serve as comparison candidates.
already_labeled_numeric_cols

In [None]:
# Renumber the rows 0..n-1 so that the positions used with .iloc are visible as index,
# then list the unlabeled columns whose real semantic type is "HBP".
numeric_unlabeled_data_df = numeric_unlabeled_data_df.reset_index(drop=True)
is_hbp = numeric_unlabeled_data_df["semanticType"].eq("HBP")
numeric_unlabeled_data_df.loc[is_hbp]

In [None]:
# 1. Pick the unlabeled numeric column to label next and collect the already labeled
#    non-numeric ("string") columns that live in the same table.
numeric_column_to_label = numeric_unlabeled_data_df.iloc[2]

in_same_table = total_labeled_data_df["table"] == numeric_column_to_label["table"]
has_numeric_type = total_labeled_data_df["semanticType"].isin(numeric_types)
string_already_labeled_cols_in_table = total_labeled_data_df[in_same_table & ~has_numeric_type]
string_already_labeled_cols_in_table

In [None]:
# Show the unlabeled numeric column that is currently being labeled.
numeric_column_to_label

In [None]:
# NOTE(review): only the last expression of a cell is displayed, so the two-column
# projection in the first line is computed and then discarded.
string_already_labeled_cols_in_table[["dataset_id", "semanticType"]]
string_already_labeled_cols_in_table

In [None]:
# Column numbers of the labeled string columns, parsed from the text after "_" in "column".
# NOTE(review): `string_cols_to_load` is reassigned to the list of dataset ids in the next cell.
string_cols_to_load =  [int(col.split("_")[1]) for col in string_already_labeled_cols_in_table["column"].values]

In [None]:
# Load the numeric column to label together with the labeled string columns of its table
string_cols_to_load = list(
    string_already_labeled_cols_in_table["dataset_id"].values)
cols_to_load = [numeric_column_to_label["dataset_id"]] + string_cols_to_load

# The column number is parsed from the dataset id (e.g. "MLB_1+column_10" -> 10);
# the columns are handed to the loader sorted by that number.
col_numbers = [int(dataset_id.split("+")[1].split("_")[1]) for dataset_id in cols_to_load]
df_cols_to_load = pd.DataFrame(
    {"col_num": col_numbers, "col_header": cols_to_load}).sort_values(by="col_num")

table_name = numeric_column_to_label["table"]
df_table_with_n_col_to_label = load_public_bi_table_by_cols(
    table_name.split("_")[0],
    table_name,
    usecols=df_cols_to_load["col_num"].values,
    col_headers=df_cols_to_load["col_header"].values,
)

# Treat every labeled string column as text
for string_col in string_cols_to_load:
    df_table_with_n_col_to_label[string_col] = df_table_with_n_col_to_label[string_col].astype(str)

df_table_with_n_col_to_label.head(10)



In [None]:
# Dataset ids of the labeled string columns that were loaded with the numeric column.
string_cols_to_load

## Correlation analysis between the numeric column to label and the labeled string columns

In [None]:
#### Correlation analysis between the numeric column to label and each labeled string column:
#### a string column is kept when a Kruskal-Wallis test over its value groups yields p < 0.05
string_cols_with_corr = []
numeric_col_id = numeric_column_to_label["dataset_id"]

for string_col in string_cols_to_load:
    print(string_col)
    groups = df_table_with_n_col_to_label[string_col].unique()
    n_rows = len(df_table_with_n_col_to_label)
    print(f"Number of groups: {len(groups)}")
    print(f"Number of rows: {n_rows}")
    group_ratio = len(groups) / n_rows
    print(group_ratio)
    # skip columns where more than 10% of the rows are distinct values
    if group_ratio > 0.1:
        print("To many groups in the column for kruskal test")
        continue

    # one sample of non-null numeric values per distinct string value
    kruskal_input_groups = []
    for group in groups:
        in_group = df_table_with_n_col_to_label[string_col] == group
        kruskal_input = df_table_with_n_col_to_label.loc[in_group, numeric_col_id].dropna()
        if len(kruskal_input) == 0:
            print("kruskal input 0")
            continue
        kruskal_input_groups.append(kruskal_input)

    # the test needs at least two samples
    if len(kruskal_input_groups) < 2:
        print(len(kruskal_input_groups))
        print("kruskal input groups smaller than 2")
        continue

    try:
        F, p = kruskal(*kruskal_input_groups)
        print(f"Kruskal-Wallis test:  \tF:{F}, \tp:{p}")
        if p < 0.05:
            string_cols_with_corr.append(string_col)
    except Exception as e:
        # a failing test only skips the current column
        print(e)


# from here on `string_cols_with_corr` holds the metadata rows of the kept columns
is_correlated = string_already_labeled_cols_in_table["dataset_id"].isin(string_cols_with_corr)
string_cols_with_corr = string_already_labeled_cols_in_table.loc[is_correlated]

In [None]:
# Labeled string columns for which the Kruskal-Wallis test gave p < 0.05.
string_cols_with_corr

In [None]:
# 2. Search for other tables that already have labeled columns of the same semantic
#    types, so that a join over those columns becomes possible.
#    This variant only considers the correlated string columns; the variant over all
#    labeled string columns of the table is computed in a later cell.
has_corr_type = total_labeled_data_df["semanticType"].isin(
    string_cols_with_corr["semanticType"])
joinable_tables = total_labeled_data_df.loc[has_corr_type].drop_duplicates()

# per other table: number of matching labeled columns, most matches first
is_other_table = joinable_tables["table"] != numeric_column_to_label["table"]
joinable_tables_sorted = (
    joinable_tables[is_other_table]
    .groupby("table")
    .count()
    .sort_values(by=["semanticType"], ascending=False)
)



In [None]:
pd.set_option('display.max_rows', 50)
# One row per other table with the list of its matching semantic types ("semanticTypes").
joinable_tables_grouped = joinable_tables[joinable_tables["table"] != numeric_column_to_label["table"]].groupby("table").apply(lambda x: list(x["semanticType"].values)).reset_index().rename(columns={0:"semanticTypes"})
# Keep only the tables whose semantic types cover ALL semantic types of the correlated string columns.
# NOTE(review): the right-hand side is a plain Python set; this presumably relies on pandas
# comparing every element of the Series against that set as a scalar - confirm on the pandas version in use.
joinable_tables_grouped = joinable_tables_grouped[joinable_tables_grouped["semanticTypes"].map(lambda x: set(x).intersection(set(string_cols_with_corr["semanticType"]))) == set(string_cols_with_corr["semanticType"])]

In [None]:
# Matching labeled columns in all tables other than the table of the column to label.
joinable_tables[joinable_tables["table"] != numeric_column_to_label["table"]]

In [None]:
# Tables that provide every semantic type of the correlated string columns.
joinable_tables_grouped

In [None]:
# Per-table counts of matching labeled columns, sorted descending by the "semanticType" count.
# NOTE(review): the tables that cover ALL correlated types are in `joinable_tables_grouped`, not in this frame.
joinable_tables_sorted

In [None]:
# Spark variant of the "join and group by" EMD experiment, restricted to the string columns
# that passed the Kruskal-Wallis test (`string_cols_with_corr`).
# NOTE(review): `spark`, `col`, `count`, `collect_list` and `emd_UDF` are only created in the
# "Test pyspark" cell further down in this notebook; that cell has to be executed before this
# one, so the notebook does not run top to bottom on a fresh kernel.
#import pyspark.sql.functions as func    
# NOTE(review): this import shadows Python's builtin `sum` for every cell executed afterwards.
from pyspark.sql.functions import sum

# Group-size threshold used by the filters on the per-group row count below.
max_group_size = 10
# Minimal share of the distinct values of a string column (table to label) that must also
# occur in the candidate column of the other table to use the pair as join columns.
PERCENTAGE_OF_OVERLAP_FOR_JOIN_CANDIDATES = 0.2
# 3. check which tables are suitable for a join and group by
results = []
for index, row in joinable_tables_grouped.iterrows():
    joinable_table = row["table"]
    # Debug filter: every table except "MLB_11" is skipped.
    if joinable_table != "MLB_11":
        continue
    # if row["table"] not in ["MLB_11"]:#["CommonGovernment_2", "MLB_3", "MLB_65", "MLB_26"]:
    #     continue
    print(joinable_table)
    # first check if there is a numeric column in the table which is already labeled with a numeric type
    numerics_in_joinable_table = total_labeled_data_df[
        total_labeled_data_df["table"] == joinable_table]
    numerics_in_joinable_table = numerics_in_joinable_table.loc[numerics_in_joinable_table["semanticType"].isin(
        numeric_types)]
    #print(numerics_in_joinable_table)

    # only if there are numeric cols already labeled in the joinable table, do the join and group by and try to match
    if len(numerics_in_joinable_table) > 0:
        #print(joinable_table)
        #print(numerics_in_joinable_table)
        # labeled columns of the joinable table whose type matches a correlated string column
        strings_in_joinable_table = joinable_tables[joinable_tables["table"]
                                                    == joinable_table]
        #print(strings_in_joinable_table)

        # load the joinable table, do the join and group by EMD measurement
        # before the join, check if there are overlapping values at all
        # this check can also be used when two or more candidates with the same semantic type could be used for a join:
        # then use the column with more overlapping values
        cols_to_load_for_joinable_table = list(
            numerics_in_joinable_table["dataset_id"].values)
        cols_to_load_for_joinable_table = cols_to_load_for_joinable_table + \
            list(strings_in_joinable_table["dataset_id"].values)
        #print(cols_to_load_for_joinable_table)
        # column numbers are parsed from the dataset ids; columns are loaded sorted by that number
        df_cols_to_load_for_joinable_table = pd.DataFrame({"col_num": [int(col.split(
            "+")[1].split("_")[1]) for col in cols_to_load_for_joinable_table], "col_header": cols_to_load_for_joinable_table}).sort_values(by="col_num")
        df_joinable_table = load_public_bi_table_by_cols(joinable_table.split(
            "_")[0], joinable_table, usecols=df_cols_to_load_for_joinable_table["col_num"].values, col_headers=df_cols_to_load_for_joinable_table["col_header"].values)
        
        # typing of the cols for pyspark dataframe
        for string_col in list(strings_in_joinable_table["dataset_id"].values):
            df_joinable_table[string_col] = df_joinable_table[string_col].astype(
                str)
        
        
        #print(df_joinable_table.head())
        # only use cols for the join where overlapping values are present
        # each entry of cols_to_join is [column in the table to label, column in the joinable table]
        cols_to_join = []
        for string_already_labeled_col_in_table_i, string_already_labeled_col_in_table_row in string_cols_with_corr.iterrows():
            for strings_in_joinable_table_i, strings_in_joinable_table_row in strings_in_joinable_table[strings_in_joinable_table["semanticType"] == string_already_labeled_col_in_table_row["semanticType"]].iterrows():
                #print(
                #    string_already_labeled_col_in_table_row["dataset_id"], strings_in_joinable_table_row["dataset_id"])
                #print(
                #    f"overlapping values: {len(set(df_table_with_n_col_to_label[string_already_labeled_col_in_table_row['dataset_id']]) & set(df_joinable_table[strings_in_joinable_table_row['dataset_id']]))} from {len(set(df_table_with_n_col_to_label[string_already_labeled_col_in_table_row['dataset_id']]))} unique")
                # share of distinct values of the column in the table to label that also occur in the candidate column
                percentage_of_overlap = (len(set(df_table_with_n_col_to_label[string_already_labeled_col_in_table_row['dataset_id']]) & set(
                    df_joinable_table[strings_in_joinable_table_row['dataset_id']]))) / len(set(df_table_with_n_col_to_label[string_already_labeled_col_in_table_row['dataset_id']]))
                print(f"Percentage of overlap: {percentage_of_overlap}")
                if percentage_of_overlap >= PERCENTAGE_OF_OVERLAP_FOR_JOIN_CANDIDATES:
                    # every column of the table to label is paired at most once (first sufficient candidate wins)
                    if string_already_labeled_col_in_table_row["dataset_id"] not in [x[0] for x in cols_to_join]:
                        cols_to_join.append(
                            [string_already_labeled_col_in_table_row["dataset_id"], strings_in_joinable_table_row["dataset_id"]])
        print("Cols to join:", cols_to_join )

        # if not all correlated columns are available for the join and group by in the other table => do not use this table
        if len(cols_to_join) != len(string_cols_with_corr):
            print("len(cols_to_join) != len(string_cols_with_corr)")
            continue

        # drop na values in cols to join in both tables => NaN values lead to high memory usage in the merge
        # NOTE(review): inplace=True mutates `df_table_with_n_col_to_label`, which is shared with other
        # cells and later iterations; rows dropped here stay dropped until the table is reloaded.
        df_table_with_n_col_to_label.dropna(
            subset=[col1 for [col1, col2] in cols_to_join], inplace=True)
        df_joinable_table.dropna(
            subset=[col2 for [col1, col2] in cols_to_join], inplace=True)

        # Transform the pandas dfs into pyspark tables, registered as temp views under the table names
        sdf_1 = spark.createDataFrame(df_table_with_n_col_to_label)
        sdf_1.createOrReplaceTempView(numeric_column_to_label["table"])
        sdf_2 = spark.createDataFrame(df_joinable_table)
        sdf_2.createOrReplaceTempView(joinable_table)

        ## Approach 3
        # The "(" opened here is closed by the ")" at the end of the SQL f-string below.
        join_condition = "ON (" + " AND ".join(
            map(lambda join_att: f"`{join_att[0]}` = `{join_att[1]}`", cols_to_join))
        #print(join_condition)
        # projection: join columns of the table to label, the column to label, all labeled numeric columns of the other table
        projection_list = " , ".join(map(lambda attr: f"`{attr}`", list(set([col1 for [col1, col2] in cols_to_join])) + [
                                     numeric_column_to_label["dataset_id"]] + list(numerics_in_joinable_table["dataset_id"].values)))
        #print(projection_list)
        #print(
        #    f"SELECT {projection_list} FROM {numeric_column_to_label['table']} JOIN {joinable_table} {join_condition})")
        sql_df = spark.sql(
            f"SELECT {projection_list} FROM {numeric_column_to_label['table']} JOIN {joinable_table} {join_condition})")
        # filter out tuples with null values in any of the projected columns
        sql_df = sql_df.dropna(subset=list(
            map(lambda cur_col: "`{cur_col}`".format(cur_col=cur_col),
                sql_df.columns)))

        #print(sql_df.show())
        for i, numeric_col_in_joinable_table in numerics_in_joinable_table.iterrows():
            print("EMD Calc:")
            print(numeric_col_in_joinable_table["dataset_id"])
            # determine the groups (value combinations of the join columns) with at most max_group_size rows
            valid_group_size_values = sql_df.groupby([col_1 for [col_1, col_2] in cols_to_join]).agg(count(f"`{numeric_col_in_joinable_table['dataset_id']}`").alias("count")).where(col("count") <= max_group_size)
            #print(sql_df.show())
            sql_df.createOrReplaceTempView("sql_df")
            valid_group_size_values.createOrReplaceTempView("sdf_valid_groups")
            
            # preselect only the rows that belong to a group fulfilling the max_group_size condition
            join_condition = "ON (" + " AND ".join(
                map(lambda join_att: f"sql_df.`{join_att[0]}` = sdf_valid_groups.`{join_att[0]}`", cols_to_join))
            #print(join_condition)
            projection_list = " , ".join(map(lambda attr: f"sql_df.`{attr}`", list(set([col1 for [col1, col2] in cols_to_join])) + [
                                        numeric_column_to_label["dataset_id"]] + list(numerics_in_joinable_table["dataset_id"].values)))
            #print(
            #    f"SELECT {projection_list} FROM sql_df JOIN sdf_valid_groups {join_condition})")
            # NOTE(review): `sql_df` is overwritten here, so the next iteration of this loop starts
            # from the already filtered frame instead of the full join result.
            sql_df = spark.sql(
                f"SELECT {projection_list} FROM sql_df JOIN sdf_valid_groups {join_condition})")
            print("sql_df count: ",sql_df.count())
            # `sum` is the pyspark function imported at the top of this cell
            print("valid group size count: ",valid_group_size_values.select(sum("count")).collect())
            #print(f"Whole df count: {sql_df.count()}")
            #print(f"valid group size count: {valid_group_size_values.count()}")
            if valid_group_size_values.count() == 0:
                print("no valid group size")
                continue


            # filter out just the values which results in fewer than max_group_size
            #sql_df = sql_df.join(valid_group_size_values, [col_1 for [col_1, col_2] in cols_to_join], "leftsemi")
            #print(f"Whole df count after filtering: {sql_df.count()}")
            # per group: EMD between the values of the column to label and of the labeled numeric column, plus the row count
            cur_df = sql_df.groupby([col_1 for [col_1, col_2] in cols_to_join]).agg(emd_UDF(collect_list(
                f"`{numeric_column_to_label['dataset_id']}`"), collect_list(f"`{numeric_col_in_joinable_table['dataset_id']}`")).alias("EMD"), count(f"`{numeric_col_in_joinable_table['dataset_id']}`").alias("count"))#.select(col("EMD"), col("count")).where(col("count") <= max_group_size).groupby().avg("EMD").alias("avg(EMD)")
            #print(f"Länge: {cur_df.count()}")
            #print(f"DF: {cur_df.show()}")
            print(cur_df.show())
            #cur_df = cur_df.select("*").where(col("count") <= max_group_size).groupby().avg("EMD").alias("avg(EMD)")
            #cur_df = cur_df.select("*").where(col("count") <= max_group_size).groupby().agg(func.percentile_approx("EMD", 0.5).alias("med(EMD)"))
            #print(cur_df.show())
            #print(cur_df.select("*").where(col("count") <= max_group_size).count())
            # NOTE(review): the pre-selection above keeps groups with count <= max_group_size,
            # while this filter keeps count >= max_group_size; combined, presumably only groups
            # with exactly max_group_size rows survive. The commented-out variants use `<=` -
            # confirm which comparison is intended.
            grouped_n_joined_emds = cur_df.select("EMD").where(col("count") >= max_group_size).toPandas()
            # annotate every per-group EMD with the compared columns; the resulting column order
            # (EMD first) is the order of the values appended to `results`
            grouped_n_joined_emds["unlabeled_col"] = numeric_column_to_label["dataset_id"]
            grouped_n_joined_emds["real_semantic_type"] = numeric_column_to_label["semanticType"]
            grouped_n_joined_emds["labeled_col"] = numeric_col_in_joinable_table["dataset_id"]
            grouped_n_joined_emds["semantic_type"] = numeric_col_in_joinable_table["semanticType"]
            # standard deviation of the whole column to label (not per group)
            grouped_n_joined_emds["std"] = df_table_with_n_col_to_label[numeric_column_to_label["dataset_id"]].std()
            print(grouped_n_joined_emds)
            #print(f"Länge: {cur_df.count()}")
            #print(f"DF: {cur_df.show()}")
            #grouped_n_joined_emd = cur_df.collect()[0]["avg(EMD)"]
            #print(grouped_n_joined_emd)
            #results.append([numeric_column_to_label["dataset_id"], numeric_column_to_label["semanticType"], numeric_col_in_joinable_table["dataset_id"], numeric_col_in_joinable_table["semanticType"], grouped_n_joined_emd, df_table_with_n_col_to_label[numeric_column_to_label["dataset_id"]].mean()])
            results.extend(grouped_n_joined_emds.values.tolist())
            
        # drop the temp views of both tables in spark
        # NOTE(review): the views "sql_df" and "sdf_valid_groups" created in the loop above are not dropped.
        spark.catalog.dropTempView(numeric_column_to_label["table"])
        spark.catalog.dropTempView(joinable_table)
    else:
        print("no labeled numerical cols in other table")

In [None]:
# Show all per-group EMD values collected by the Spark loop, smallest distance first.
pd.set_option("display.max_rows", 500000)
# the rows in `results` start with the EMD, followed by the annotation columns
spark_result_columns = ["EMD", "unlabeled_col", "real_semantic_type",
                        "labeled_col", "semantic_type", "std"]
df_spark_results = pd.DataFrame(results, columns=spark_result_columns)
df_spark_results.sort_values("EMD").reset_index(drop=True)

In [None]:
# 2. Search for other tables that already have labeled columns of the same semantic
#    types as ANY labeled string column of the table to label (no correlation filter).
has_string_type = total_labeled_data_df["semanticType"].isin(
    string_already_labeled_cols_in_table["semanticType"])
joinable_tables = total_labeled_data_df.loc[has_string_type].drop_duplicates()

# per other table: number of matching labeled columns, most matches first
is_other_table = joinable_tables["table"] != numeric_column_to_label["table"]
joinable_tables_sorted = (
    joinable_tables[is_other_table]
    .groupby("table")
    .count()
    .sort_values(by=["semanticType"], ascending=False)
)



In [None]:
# Labeled string columns of the table that contains the numeric column to label.
string_already_labeled_cols_in_table

In [None]:
# Pandas variant of the "join and group by" EMD experiment. It iterates over
# `joinable_tables_sorted` (index = table name) and uses all labeled string columns of the
# table to label (`string_already_labeled_cols_in_table`) as join candidates.
# Groups of the joined table with more than this many rows are ignored in the EMD average.
max_group_size = 3
# 3. check which tables are suitable for a join and group by
results = []
for joinable_table, row in joinable_tables_sorted.iterrows():
    # if joinable_table not in ["MLB_64", "MLB_26"]:#["CommonGovernment_2", "MLB_3", "MLB_65", "MLB_26"]:
    #     continue
    print(joinable_table)
    # first check if there is a numeric column in the table which is already labeled with a numeric type
    numerics_in_joinable_table = total_labeled_data_df[
        total_labeled_data_df["table"] == joinable_table]
    numerics_in_joinable_table = numerics_in_joinable_table.loc[numerics_in_joinable_table["semanticType"].isin(
        numeric_types)]
    #print(numerics_in_joinable_table)

    # only if there are numeric cols already labeled in the joinable table, do the join and group by and try to match
    if len(numerics_in_joinable_table) > 0:
        #print(joinable_table)
        #print(numerics_in_joinable_table)
        # labeled columns of the joinable table whose type matches a labeled string column
        strings_in_joinable_table = joinable_tables[joinable_tables["table"]
                                                    == joinable_table]
        #print(strings_in_joinable_table)
        # load the joinable table, do the join and group by EMD measurement
        # before the join, check if there are overlapping values at all
        # this check can also be used when two or more candidates with the same semantic type could be used for a join:
        # then use the column with more overlapping values
        cols_to_load_for_joinable_table = list(
            numerics_in_joinable_table["dataset_id"].values)
        cols_to_load_for_joinable_table = cols_to_load_for_joinable_table + \
            list(strings_in_joinable_table["dataset_id"].values)
        #print(cols_to_load_for_joinable_table)
        # column numbers are parsed from the dataset ids; columns are loaded sorted by that number
        df_cols_to_load_for_joinable_table = pd.DataFrame({"col_num": [int(col.split(
            "+")[1].split("_")[1]) for col in cols_to_load_for_joinable_table], "col_header": cols_to_load_for_joinable_table}).sort_values(by="col_num")
        df_joinable_table = load_public_bi_table_by_cols(joinable_table.split(
            "_")[0], joinable_table, usecols=df_cols_to_load_for_joinable_table["col_num"].values, col_headers=df_cols_to_load_for_joinable_table["col_header"].values)
        #print(df_joinable_table.head())
        # only use cols for the join where overlapping values are present
        # each entry of cols_to_join is [column in the table to label, column in the joinable table]
        cols_to_join = []
        for string_already_labeled_col_in_table_i, string_already_labeled_col_in_table_row in string_already_labeled_cols_in_table.iterrows():
            for strings_in_joinable_table_i, strings_in_joinable_table_row in strings_in_joinable_table[strings_in_joinable_table["semanticType"] == string_already_labeled_col_in_table_row["semanticType"]].iterrows():
                #print(
                #    string_already_labeled_col_in_table_row["dataset_id"], strings_in_joinable_table_row["dataset_id"])
                #print(
                #    f"overlapping values: {len(set(df_table_with_n_col_to_label[string_already_labeled_col_in_table_row['dataset_id']]) & set(df_joinable_table[strings_in_joinable_table_row['dataset_id']]))} from {len(set(df_table_with_n_col_to_label[string_already_labeled_col_in_table_row['dataset_id']]))} unique")
                # share of distinct values of the column in the table to label that also occur in the candidate column
                percentage_of_overlap = (len(set(df_table_with_n_col_to_label[string_already_labeled_col_in_table_row['dataset_id']]) & set(
                    df_joinable_table[strings_in_joinable_table_row['dataset_id']]))) / len(set(df_table_with_n_col_to_label[string_already_labeled_col_in_table_row['dataset_id']]))
                #print(percentage_of_overlap)
                # NOTE(review): hard-coded threshold; unlike the Spark variant there is no check that a
                # column of the table to label is paired only once.
                if percentage_of_overlap >= 0.75:
                    cols_to_join.append(
                        [string_already_labeled_col_in_table_row["dataset_id"], strings_in_joinable_table_row["dataset_id"]])
        #print("Cols to join:", cols_to_join )
        # drop na values in cols to join in both tables => NaN values lead to high memory usage in the merge
        # NOTE(review): inplace=True mutates `df_table_with_n_col_to_label`, which is shared with other
        # cells and later iterations; rows dropped here stay dropped until the table is reloaded.
        df_table_with_n_col_to_label.dropna(
            subset=[col1 for [col1, col2] in cols_to_join], inplace=True)
        df_joinable_table.dropna(
            subset=[col2 for [col1, col2] in cols_to_join], inplace=True)

        # Do the join and group by over the found columns and calc the EMD for each built group between the numeric columns
        # join
        #print([col1.split("+")[1] for [col1, col2] in cols_to_join])
        # without any join column this table cannot be used
        if len(cols_to_join) == 0:
            continue
        print(
            f"Table-Join: {numeric_column_to_label['table']} <-> {joinable_table} Cols: {cols_to_join}")
        df_joined = df_table_with_n_col_to_label.merge(df_joinable_table, left_on=[col1 for [
                                                       col1, col2] in cols_to_join], right_on=[col2 for [col1, col2] in cols_to_join])
        print("Finished Table-Join")
        # keep only the labeled cols from table to label and the numeric cols from both
        # df_joined = df_joined[list(set([col1 for [col1, col2] in cols_to_join])) + [numeric_column_to_label["dataset_id"]] + list(
        #     numerics_in_joinable_table["dataset_id"].values)]
        # delete the labeled string cols from the joinable table
        #df_joined.drop(columns=list(set([col2 for [col1, col2] in cols_to_join])), inplace=True)
        #print("Joined Table:")
        #print(df_joined.head(10))
        # group by the join columns of the table to label
        print(f"Group BY: {cols_to_join}")
        df_joined_n_grouped = df_joined.groupby(
            [col1 for [col1, col2] in cols_to_join])
        #print(df_joined_n_grouped.count())
        # iterate over the combinations of the numeric column to label and each already labeled numeric column
        for i, numeric_col_in_joinable_table in numerics_in_joinable_table.iterrows():
            print("EMD Calc:")
            print(numeric_col_in_joinable_table["dataset_id"])
            # per group the EMD between both numeric columns; groups with more than max_group_size rows
            # yield None (presumably ignored by .mean() - confirm); the result is the mean over the groups
            grouped_n_joined_emd = df_joined_n_grouped.apply(lambda x: wasserstein_distance(
                x[numeric_column_to_label["dataset_id"]], x[numeric_col_in_joinable_table["dataset_id"]]) if len(x) <= max_group_size else None).mean()
            print(grouped_n_joined_emd)
            results.append([numeric_column_to_label["dataset_id"], numeric_column_to_label["semanticType"],
                           numeric_col_in_joinable_table["dataset_id"], numeric_col_in_joinable_table["semanticType"], grouped_n_joined_emd])
        # stop looking at further tables once at least 50 results were collected
        if len(results) >= 50:
            break


In [None]:
# Turn the collected EMD results into a frame, keep only the rows whose EMD can be
# converted to a non-null number and rank them with the smallest distance first.
result_columns = ["unlabeled_col", "real_semantic_type", "labeled_col", "semantic_type", "EMD"]
df_results = pd.DataFrame(results, columns=result_columns)

has_numeric_emd = pd.to_numeric(df_results["EMD"], errors="coerce").notna()
df_results = df_results.loc[has_numeric_emd].sort_values(by="EMD")
df_results

In [None]:
# Predict the semantic type of the column to label as the type of the best match, i.e. the
# first row of `df_results` (sorted ascending by EMD), and encode it with the label encoder.
# Without any result there is nothing to encode: both names are reset to None, so that
# neither a NameError nor a stale value from an earlier run can occur.
if len(df_results) > 0:
    predicted_sematic_type = df_results.iloc[0]["semantic_type"]
    print(predicted_sematic_type)
    predicted_label = label_enc.transform([predicted_sematic_type])
else:
    predicted_sematic_type = None
    predicted_label = None
    print("No EMD result available - no semantic type predicted")

predicted_label

In [None]:
# Debug cell: look at the first group of `df_joined_n_grouped` (the groupby object left over
# from the last table processed in the pandas loop) and compute the EMD of two hard-coded columns.
# NOTE(review): "MLB_1+column_10" and "MLB_26+column_18" must both exist in that joined table.
for x, y in enumerate(df_joined_n_grouped):
    # y is a (group key, group frame) pair; only the first group is inspected
    if x > 0:
        break
    print(y)
    print(len(y[1]))
    print(wasserstein_distance(y[1]["MLB_1+column_10"].values, y[1]["MLB_26+column_18"].values))

# Test the Python script

In [None]:
# Run the script (presumably the script version of the join-and-group-by EMD experiment) in this kernel.
# NOTE(review): the script is called with --labeled_data_size 4, while this notebook sets labeled_data_size = 5 - confirm this is intended.
%run ./joined_n_grouped_EMD/run_joined_n_grouped_EMD.py --labeled_data_size 4 --corpus "public_bi_num" --gen_train_data True --n_worker 2 --random_state 2 --absolute_numbers True 

# Test pyspark

In [None]:
import pandas as pd
import numpy as np
import pyspark.pandas as ps
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col, pandas_udf, PandasUDFType, collect_list, count, avg, lit, mean, stddev, monotonically_increasing_id, row_number, when
from pyspark.sql.types import StringType, IntegerType, FloatType, DoubleType, StructType, StructField

# Spark settings for the local session (executor/driver memory and off-heap memory)
spark_settings = [
    ("spark.executor.memory", "150g"),
    ("spark.driver.memory", "150g"),
    ("spark.memory.offHeap.enabled", "true"),
    ("spark.memory.offHeap.size", "50g"),
]

# Local master using all available cores, application name "STEER"
conf = (
    SparkConf()
    .setAll(spark_settings)
    .setMaster("local[*]")
    .setAppName("STEER")
)

# Reuse a running session or create a new one with the configuration above
spark = SparkSession.builder.config(conf=conf).getOrCreate()

# Create and register the UDF that computes the EMD (Wasserstein distance).
@udf(returnType=FloatType())
def emd_UDF(col1, col2) -> "float | None":
    """Return the Wasserstein distance between two collections of numbers.

    Args:
        col1: values of the first distribution (e.g. the result of ``collect_list``).
        col2: values of the second distribution.

    Returns:
        The distance as a Python float, or ``None`` (a SQL NULL) when either
        side is missing or empty. ``wasserstein_distance`` raises on empty
        input, which would otherwise abort the whole Spark job.
    """
    if col1 is None or col2 is None or len(col1) == 0 or len(col2) == 0:
        return None
    return float(wasserstein_distance(col1, col2))

# Make the UDF callable from Spark SQL under the same name.
spark.udf.register("emd_UDF", emd_UDF)

In [None]:
# Join-and-group-by EMD using Spark: for every candidate joinable table, join it with the
# table that contains the numeric column to label, group by the join columns and average
# the per-group EMD between the column to label and each already labeled numeric column.
# Groups with more than `max_group_size` rows are excluded from the average.
max_group_size = 3
# 3. Check which tables are suitable for a join and group-by.
results = []
for joinable_table, row in joinable_tables_sorted.iterrows():
    # ["CommonGovernment_2", "MLB_3", "MLB_65", "MLB_26"]:
    # if joinable_table not in ["MLB_13", "MLB_64"]:
    #     continue
    print(joinable_table)
    # First check whether the table has a numeric column that is already labeled with a numeric type.
    numerics_in_joinable_table = total_labeled_data_df[
        total_labeled_data_df["table"] == joinable_table]
    numerics_in_joinable_table = numerics_in_joinable_table.loc[numerics_in_joinable_table["semanticType"].isin(
        numeric_types)]
    #print(numerics_in_joinable_table)

    # Only if the joinable table has already labeled numeric columns, do the join and group-by and try to match.
    if len(numerics_in_joinable_table) > 0:
        print(joinable_table)
        #print(numerics_in_joinable_table)
        strings_in_joinable_table = joinable_tables[joinable_tables["table"]
                                                    == joinable_table]
        #print(strings_in_joinable_table)
        # Load the joinable table and do the join / group-by EMD measurement.
        # Before joining, check whether there are overlapping values at all.
        # The same check can be used when two or more candidates with the same semantic type
        # could be used for the join: prefer the column with more overlapping values.
        cols_to_load_for_joinable_table = list(
            numerics_in_joinable_table["dataset_id"].values)
        cols_to_load_for_joinable_table = cols_to_load_for_joinable_table + \
            list(strings_in_joinable_table["dataset_id"].values)
        #print(cols_to_load_for_joinable_table)
        # The column number is parsed from the dataset id ("<table>+column_<n>").
        df_joinable_table = load_public_bi_table_by_cols(joinable_table.split("_")[0], joinable_table, usecols=[int(
            col.split("+")[1].split("_")[1]) for col in cols_to_load_for_joinable_table], col_headers=cols_to_load_for_joinable_table, frac=None)
        #print(df_joinable_table.head())
        # Only use columns for the join whose values overlap sufficiently.
        cols_to_join = []
        for string_already_labeled_col_in_table_i, string_already_labeled_col_in_table_row in string_already_labeled_cols_in_table.iterrows():
            for strings_in_joinable_table_i, strings_in_joinable_table_row in strings_in_joinable_table[strings_in_joinable_table["semanticType"] == string_already_labeled_col_in_table_row["semanticType"]].iterrows():
                #print(
                #    string_already_labeled_col_in_table_row["dataset_id"], strings_in_joinable_table_row["dataset_id"])
                #print(
                #    f"overlapping values: {len(set(df_table_with_n_col_to_label[string_already_labeled_col_in_table_row['dataset_id']]) & set(df_joinable_table[strings_in_joinable_table_row['dataset_id']]))} from {len(set(df_table_with_n_col_to_label[string_already_labeled_col_in_table_row['dataset_id']]))} unique")
                # Share of the distinct values of the column in the table to label that also occur
                # in the candidate column of the joinable table.
                # NOTE(review): divides by zero if the column in the table to label has no values.
                percentage_of_overlap = (len(set(df_table_with_n_col_to_label[string_already_labeled_col_in_table_row['dataset_id']]) & set(
                    df_joinable_table[strings_in_joinable_table_row['dataset_id']]))) / len(set(df_table_with_n_col_to_label[string_already_labeled_col_in_table_row['dataset_id']]))
                #print(f"Percentage of overlap: {percentage_of_overlap}")
                if percentage_of_overlap >= 0.75:
                    cols_to_join.append(
                        [string_already_labeled_col_in_table_row["dataset_id"], strings_in_joinable_table_row["dataset_id"]])
        #print("Cols to join:", cols_to_join )
        # Drop NaN values in the join columns of both tables; NaN values lead to high memory usage in the merge.
        # NOTE(review): df_table_with_n_col_to_label is modified in place and is not reloaded inside
        # this loop, so rows dropped here stay dropped for all later joinable tables — confirm intended.
        df_table_with_n_col_to_label.dropna(
            subset=[col1 for [col1, col2] in cols_to_join], inplace=True)
        df_joinable_table.dropna(
            subset=[col2 for [col1, col2] in cols_to_join], inplace=True)

        # Join and group by over the columns found above and compute the EMD for each group
        # between the numeric columns.
        # join
        #print([col1.split("+")[1] for [col1, col2] in cols_to_join])
        if len(cols_to_join) == 0:
            continue

        # Transform the pandas dfs into pyspark tables (temp views named after the tables).
        sdf_1 = spark.createDataFrame(df_table_with_n_col_to_label)
        sdf_1.createOrReplaceTempView(numeric_column_to_label["table"])
        sdf_2 = spark.createDataFrame(df_joinable_table)
        sdf_2.createOrReplaceTempView(joinable_table)

        ## Approach 3
        # The condition string opens with "ON (" here; the matching ")" is appended in the
        # f-string passed to spark.sql below.
        join_condition = "ON (" + " AND ".join(
            map(lambda join_att: f"`{join_att[0]}` = `{join_att[1]}`", cols_to_join))
        #print(join_condition)
        # Projected columns: the distinct join columns of the table to label, the numeric column
        # to label, and all already labeled numeric columns of the joinable table.
        projection_list = " , ".join(map(lambda attr: f"`{attr}`", list(set([col1 for [col1, col2] in cols_to_join])) + [
                                     numeric_column_to_label["dataset_id"]] + list(numerics_in_joinable_table["dataset_id"].values)))
        #print(projection_list)
        #print(
        #    f"SELECT {projection_list} FROM {numeric_column_to_label['table']} JOIN {joinable_table} {join_condition})")
        sql_df = spark.sql(
            f"SELECT {projection_list} FROM {numeric_column_to_label['table']} JOIN {joinable_table} {join_condition})")
        # Filter out tuples that contain null values in any projected column.
        sql_df = sql_df.dropna(subset=list(
            map(lambda cur_col: "`{cur_col}`".format(cur_col=cur_col),
                sql_df.columns)))

        #print(sql_df.show())
        for i, numeric_col_in_joinable_table in numerics_in_joinable_table.iterrows():
            print("EMD Calc:")
            print(numeric_col_in_joinable_table["dataset_id"])
            # First check which groups are no larger than max_group_size.
            # This frame is only used for the emptiness check below (the semi-join that would
            # apply it as a filter is commented out).
            valid_group_size_values = sql_df.groupby([col_1 for [col_1, col_2] in cols_to_join]).agg(count(f"`{numeric_col_in_joinable_table['dataset_id']}`").alias("count")).where(col("count") <= max_group_size)
            #print(valid_group_size_values.show())
            #print(f"Whole df count: {sql_df.count()}")
            #print(f"valid group size count: {valid_group_size_values.count()}")
            if valid_group_size_values.count() == 0:
                continue
            # Keep only the groups that are no larger than max_group_size.
            #sql_df = sql_df.join(valid_group_size_values, [col_1 for [col_1, col_2] in cols_to_join], "leftsemi")
            #print(f"Whole df count after filtering: {sql_df.count()}")
            # Per group: EMD between the column to label and the labeled numeric column, plus the group size.
            cur_df = sql_df.groupby([col_1 for [col_1, col_2] in cols_to_join]).agg(emd_UDF(collect_list(
                f"`{numeric_column_to_label['dataset_id']}`"), collect_list(f"`{numeric_col_in_joinable_table['dataset_id']}`")).alias("EMD"), count(f"`{numeric_col_in_joinable_table['dataset_id']}`").alias("count"))#.select(col("EMD"), col("count")).where(col("count") <= max_group_size).groupby().avg("EMD").alias("avg(EMD)")
            #print(f"Length: {cur_df.count()}")
            #print(f"DF: {cur_df.show()}")
            # Average the EMD over all groups that are small enough; the result is a single row.
            cur_df = cur_df.select("*").where(col("count") <= max_group_size).groupby().avg("EMD").alias("avg(EMD)")
            #print(cur_df.show())
            #print(f"Length: {cur_df.count()}")
            #print(f"DF: {cur_df.show()}")
            grouped_n_joined_emd = cur_df.collect()[0]["avg(EMD)"]
            print(grouped_n_joined_emd)
            results.append([numeric_column_to_label["dataset_id"], numeric_column_to_label["semanticType"], numeric_col_in_joinable_table["dataset_id"], numeric_col_in_joinable_table["semanticType"], grouped_n_joined_emd])

        # Drop the temp views in Spark again.
        spark.catalog.dropTempView(numeric_column_to_label["table"])
        spark.catalog.dropTempView(joinable_table)

        # if len(results) >= 50:
        #     break


In [None]:
df_table_with_n_col_to_label["MLB_1+column_33"] = df_table_with_n_col_to_label["MLB_1+column_33"].astype(str)

In [None]:
df_table_with_n_col_to_label.dtypes

In [None]:
sdf_1 = spark.createDataFrame(df_table_with_n_col_to_label)
sdf_1.createOrReplaceTempView(numeric_column_to_label["table"])

In [None]:
# Rank the EMD measurements collected by the Spark run, smallest distance first.
emd_result_columns = ["unlabeled_col", "real_semantic_type", "labeled_col", "semantic_type", "EMD"]
df_results = pd.DataFrame(results, columns=emd_result_columns)
# Drop rows whose EMD is missing or cannot be interpreted as a number.
emd_is_numeric = pd.to_numeric(df_results["EMD"], errors="coerce").notnull()
df_results = df_results[emd_is_numeric].sort_values(by="EMD")
#df_results.to_csv(f"./joined_n_grouped_EMD/results/{numeric_column_to_label['dataset_id']}_{max_group_size}.csv", index=False)
df_results

# Test on Football Stats

In [None]:
import pandas as pd
from scipy.stats import wasserstein_distance

# Both season files are reduced to the same four columns.
football_columns = ["Spieler", "Pos", "GS", "Tor"]
df_2019 = pd.read_csv("2019-2020_football_stats.csv")[football_columns]  # .sort_values(by="Tor", ascending=False)
df_2021 = pd.read_csv("2020-2021_football_stats.csv")[football_columns]  # .sort_values(by="Tor", ascending=False)

In [None]:
# EMD between every combination of the "GS" and "Tor" columns
# (first column from the 2019 frame, second from the 2021 frame).
stat_columns = ("GS", "Tor")
results = [
    [col1, col2, wasserstein_distance(df_2019[col1], df_2021[col2])]
    for col1 in stat_columns
    for col2 in stat_columns
]

pd.DataFrame(results, columns=["Col_1", "Col_2", "EMD"])

In [None]:
df_2019

In [None]:
df_2021

In [None]:
# Pair every 2019 row with every 2021 row that has the same "Pos" value,
# then group the joined rows by "Pos".
df_joined = pd.merge(df_2019, df_2021, on=["Pos"])
# Alternative join key (player and position):
#df_joined = df_2019.merge(df_2021, left_on=["Spieler","Pos"], right_on=["Spieler","Pos"])
print(df_joined.head(n=5))
df_joined_n_grouped = df_joined.groupby(by="Pos")

In [None]:
# Per-"Pos" EMD for each pair of joined columns.
# `results` holds one row per (pair, group); `results_combined` holds the minimum EMD per pair.
max_group_size = None
results = []
results_combined = []
column_pairs = [("GS_x", "GS_y"), ("GS_x", "Tor_y"), ("Tor_x", "GS_y"), ("Tor_x", "Tor_y")]
for col1, col2 in column_pairs:

    def emd_within_group(group):
        """EMD of the current column pair inside one group; None if the group exceeds max_group_size."""
        oversized = max_group_size is not None and len(group) > max_group_size
        return None if oversized else wasserstein_distance(group[col1], group[col2])

    j_n_g_emd = df_joined_n_grouped.apply(emd_within_group)
    results.extend([col1, col2, position, emd] for position, emd in j_n_g_emd.items())
    results_combined.append([col1, col2, j_n_g_emd.min()])

df_results = pd.DataFrame(results, columns=["Col_1", "Col_2", "Pos", "EMD"])
# Show only the rows for the "FW" group.
df_results[df_results["Pos"] == "FW"]
#df_results

In [None]:
df_results = pd.DataFrame(results_combined, columns=["Col_1", "Col_2", "EMD"])
df_results

## Correlation analysis

In [None]:
df_2019

In [None]:
from scipy import stats
from scipy.stats import wasserstein_distance

col = "Tor"

# One-way ANOVA of `col` across the three "Pos" values TW, DF and FW.
position_samples = [
    df_2019.loc[df_2019["Pos"] == position, col]
    for position in ("TW", "DF", "FW")
]
F, p = stats.f_oneway(*position_samples)
print(F, p)

In [None]:
# For each value column, split df_2019 by `corr_col` and run ANOVA and
# Kruskal-Wallis on the resulting samples. A blank line separates the reports.
corr_col = "Pos"
for report_no, col in enumerate(("Tor", "GS")):
    if report_no > 0:
        print("")

    # One sample of `col` per distinct value of `corr_col`.
    results = [
        df_2019[df_2019[corr_col] == category][col]
        for category in df_2019[corr_col].unique()
    ]

    print(f"Correlation from '{col}' -> '{corr_col}' :")
    F, p = stats.f_oneway(*results)
    print(f"ANOVA test:  \t\tF:{F}, \tp:{p}")

    F, p = stats.kruskal(*results)
    print(f"Kruskal-Wallis test:  \tF:{F}, \tp:{p}")


In [None]:
# One-way ANOVA of the "Tor" column between the two season frames.
season_samples = (df_2019["Tor"], df_2021["Tor"])
F, p = stats.f_oneway(*season_samples)
print(F, p)

In [None]:
# ANOVA p-value for every combination of the "GS" and "Tor" columns
# (first column from the 2019 frame, second from the 2021 frame).
stat_columns = ("GS", "Tor")
results = [
    [col1, col2, stats.f_oneway(df_2019[col1], df_2021[col2]).pvalue]
    for col1 in stat_columns
    for col2 in stat_columns
]

df_results = pd.DataFrame(results, columns=["Col_1", "Col_2", "p"])
df_results

In [None]:
# Kruskal-Wallis p-value for every combination of the "GS" and "Tor" columns
# (first column from the 2019 frame, second from the 2021 frame).
stat_columns = ("GS", "Tor")
results = [
    [col1, col2, stats.kruskal(df_2019[col1], df_2021[col2]).pvalue]
    for col1 in stat_columns
    for col2 in stat_columns
]

df_results = pd.DataFrame(results, columns=["Col_1", "Col_2", "p"])
df_results

# Correlation analysis on MLB tables

In [None]:
import os
import sys
from dotenv import load_dotenv

load_dotenv(override=True)

sys.path.append(os.environ["WORKING_DIR"])
from os.path import join
from snorkel.labeling import labeling_function
#from snorkel.preprocess import preprocessor
import copy
import json
#import configargparse
from sklearn.metrics.pairwise import cosine_similarity
#from snorkel.labeling import PandasLFApplier
from sklearn.preprocessing import LabelEncoder
import pandas as pd
from sklearn.metrics import classification_report
import numpy as np
from data_loader.utils import load_public_bi_table_by_cols, load_public_bi_table
from scipy.stats import wasserstein_distance, kruskal, f_oneway

In [None]:
df_mlb_10 = load_public_bi_table("MLB", "MLB_10", 1)
df_mlb_11 = load_public_bi_table("MLB", "MLB_11", 1)
df_mlb_10.columns

In [None]:
corr_col = "parentteam"
col = "H"

# One sample of `col` per distinct value of `corr_col` in MLB_10.
results = [
    df_mlb_10.loc[df_mlb_10[corr_col] == category, col]
    for category in df_mlb_10[corr_col].unique()
]

F, p = kruskal(*results)
print(f"Kruskal-Wallis test:  \tF:{F}, \tp:{p}")

In [None]:
df_mlb_10[corr_col].unique()

In [None]:
# Distinct values, total rows, and their ratio for "batter_name".
batter_names = df_mlb_10["batter_name"]
n_distinct_batters = len(batter_names.unique())
n_batter_rows = len(batter_names)
print(n_distinct_batters)
print(n_batter_rows)

print(n_distinct_batters / n_batter_rows)

In [None]:
# Distinct values, total rows, and their ratio for "parentteam".
parent_teams = df_mlb_10["parentteam"]
n_distinct_teams = len(parent_teams.unique())
n_team_rows = len(parent_teams)
print(n_distinct_teams)
print(n_team_rows)

print(n_distinct_teams / n_team_rows)

In [None]:
len(df_mlb_10["batter_name"].unique())

# Embed numerical columns with pre-trained LMs

In [None]:
from transformers import BertTokenizer, BertModel

# from transformers import AutoTokenizer, AutoModelForMaskedLM
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

# Tokenizer and model are loaded from the same pre-trained checkpoint.
BERT_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
model = BertModel.from_pretrained(BERT_CHECKPOINT)


In [None]:
import torch

# Each text is encoded on its own; the id lists are then concatenated by hand
# and fed to the model as a single sequence.
encode_options = {"return_tensors": "pt", "max_length": 50}
numbers_text = "10, 20, 30, 40, 50, 60, 70"

col_1 = tokenizer.encode(numbers_text, **encode_options)
col_2 = tokenizer.encode("David Beckham, David Alaba", **encode_options)

# First sequence without its last token id, followed by the full second sequence.
list_1 = col_1.tolist()[0][:-1] + col_2.tolist()[0]

col_3 = tokenizer.encode(numbers_text, **encode_options)
#col_4 = tokenizer.encode("David Beckham, David Alaba", return_tensors="pt", max_length=50)
col_4 = tokenizer.encode("Bayern Muenchen Schalke 04", **encode_options)

list_2 = col_3.tolist()[0][:-1] + col_4.tolist()[0]

output_1 = model(torch.LongTensor([list_1]))
output_2 = model(torch.LongTensor([list_2]))

In [None]:
# Each pair of texts is tokenized as one padded batch of two sequences.
batch_options = {"return_tensors": "pt", "max_length": 50, "padding": True}

col_1 = "10, 20, 30, 40, 50, 60, 70"
col_2 = "David Beckham, David Alaba"
col_3 = "10, 20, 30, 40, 50, 60, 70"
#col_4 = "David Beckham, David Alaba"
col_4 = "Bayern Muenchen Schalke 04"

list_1 = tokenizer([col_1, col_2], **batch_options)
list_2 = tokenizer([col_3, col_4], **batch_options)

output_1 = model(**list_1)
output_2 = model(**list_2)

In [None]:
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

cosine_similarity([output_1[0][0][0].detach().numpy()], [output_2[0][0][0].detach().numpy()])

In [None]:

# Euclidean distance between the [0][0][0] vectors of two model outputs.
# NOTE(review): `output_de` / `output_en` are not assigned in the neighbouring cells
# (those produce `output_1` / `output_2`); presumably they come from a cell that was
# removed — confirm, otherwise this cell fails on a fresh kernel.
#cosine_similarity([output_de[0][0][0].detach().numpy()], [output_en[0][0][0].detach().numpy()])
euclidean_distances([output_de[0][0][0].detach().numpy()], [output_en[0][0][0].detach().numpy()])

# normal EMD approach using SportsDB

In [1]:
import sys
import os
from dotenv import load_dotenv

load_dotenv(override=True)
sys.path.append(os.environ["WORKING_DIR"])

from scipy.stats import wasserstein_distance, kruskal
from data_loader.utils import load_public_bi_table_by_cols, load_sportsDB_soccer_table
import numpy as np
from sklearn.metrics import classification_report
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
import json
import copy
from os.path import join
from snorkel.labeling import labeling_function
from snorkel.preprocess import preprocessor
import configargparse
from snorkel.labeling import PandasLFApplier
import pyspark.pandas as ps
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col, pandas_udf, PandasUDFType, collect_list, count, avg, lit, mean, stddev, monotonically_increasing_id, row_number
from pyspark.sql.types import StringType, IntegerType, FloatType, DoubleType, StructType, StructField


# create and register UDF-Function to calc EMD-Distance


# @udf(returnType=FloatType())
# def emd_UDF(col1, col2) -> FloatType:
#     return float(wasserstein_distance(col1, col2))


# conf = SparkConf()
# # conf.set("spark.executor.instances","2")
# # conf.set("spark.executor.cores","2")
# conf.set("spark.executor.memory", "150g")
# conf.set("spark.driver.memory", "150g")
# conf.set("spark.memory.offHeap.enabled", "true")
# conf.set("spark.memory.offHeap.size", "50g")
# #conf.set("spark.sql.execution.arrow.enabled", "true")
# conf.setMaster("local[*]")
# conf.setAppName("STEER")

#spark = SparkSession.builder.config(conf=conf).getOrCreate()

#spark.udf.register("emd_UDF", emd_UDF)

labeled_unlabeled_test_split_path = join(os.environ["WORKING_DIR"], "data",
                                         "extract", "out",
                                         "labeled_unlabeled_test_split")

valid_headers_path = join(os.environ["WORKING_DIR"], "data", "extract", "out",
                          "valid_headers")

gen_train_data_path = join(os.environ["WORKING_DIR"], "labeling_functions", "combined_LFs",
                           "gen_training_data")

numeric_types = ["X1B",
                 "X2B",
                 "X3B",
                 "TB",
                 "HR",
                 "R",
                 "BB",
                 "AB",
                 "GIDP",
                 "HBP",
                 "H",
                 "SF",
                 "SH",
                 "SO",
                 "iBB",
                 "CS",
                 "SB",
                 "latitude",
                 "longitude",
                 "year"]

numeric_types_sportsDB = [
    "age",
    "assists",
    "gamesPlayed",
    "goals",
    "goalsPlusAssistsPer90Min",
    "minutesPlayed",
    "nonPenaltyXGoalsPer90Min",
    "nonPenaltyXGoalsPlusAssists",
    "penaltiesAttempted",
    "penaltiesScored",
    "redCards",
    "xAssistsPer90Min",
    "xGoalsPer90Min",
    "xGoalsPlusAssistsPer90Min",
    "yellowCards"
  ]

# LabelEncoder
with open(
    join(os.environ["WORKING_DIR"], "data", "extract", "out",
         "valid_types", "types.json")) as f:
    valid_types = json.load(f)["type_sportsDB"]

label_enc = LabelEncoder()
label_enc.fit(valid_types)

# Thresholds. None of these three constants is referenced in the part of this cell reviewed here.
PERCENTAGE_OF_OVERLAP_FOR_JOIN_CANDIDATES = 0.75
# NOTE(review): the name looks like a mangled "P_VALUE_CORRELATION_ANALYSIS" — check all usages before renaming.
P_VALUE_CORRSTEERTION_ANALYZATION = 0.05
PERCENTAGE_THRESHOLD_UNIQUE_VALUES = 0.1

# Experiment configuration. Several values are part of the split / result file names below.
labeled_data_size = 4  # overwritten by the outer loop below (1..5)
unlabeled_data_size = "absolute"
test_data_size = 20.1
validation_on = "test"  # key prefix of the test ids in the split file
gen_train_data = True  # True: label the unlabeled split; False: validate on the test split
corpus = "sportsDB"
absolute_numbers = True
n_worker = 4  # only used by the commented-out parallelize(...) call
threshold_EMD_factor = 0.1  # a match is rejected when its EMD >= std of the column to label * this factor
max_group_size = 4
random_state = 2  # overwritten by the inner loop below (1..5)
table_frac = None  # passed as `frac` when loading the labeled tables
approach = 1  # because this script is just for normal EMD


for labeled_data_size in [1,2,3,4,5]:
    for random_state in [1,2,3,4,5]:

        if absolute_numbers:
            unlabeled_data_size = "absolute"
            labeled_data_size = int(labeled_data_size)

        #############
        # Load data
        #############

        # load labeled data from labeled, unlabeled, test split file
        with open(
                join(
                    labeled_unlabeled_test_split_path,
                    f"{corpus}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.json"
                )) as f:
            labeled_unlabeled_test_split_file = json.load(f)
            labeled_data_ids = labeled_unlabeled_test_split_file[
                f"labeled{labeled_data_size}"]
            if gen_train_data:
                if absolute_numbers:
                    unlabeled_data_ids = labeled_unlabeled_test_split_file[
                        f"unlabeled"]
                else:
                    unlabeled_data_ids = labeled_unlabeled_test_split_file[
                        f"unlabeled{unlabeled_data_size}"]
                print(f"Unlabeled Data: {len(unlabeled_data_ids)}")
            if validation_on == "unlabeled":
                test_data_ids = labeled_unlabeled_test_split_file[
                    f"{validation_on}"]
            else:
                test_data_ids = labeled_unlabeled_test_split_file[
                    f"{validation_on}{test_data_size}"]

        print(f"Labeled Data: {len(labeled_data_ids)}")
        print(f"Test Data: {len(test_data_ids)}")

        # load the valid headers with real sem. types
        valid_header_file = f"{corpus}_type_sportsDB.json"
        valid_headers = join(valid_headers_path, valid_header_file)
        with open(valid_headers, "r") as file:
            valid_headers = json.load(file)
        # transform valid header into df to make it joinable with word embeddings
        valid_header_df_data = []
        for table in valid_headers.keys():
            for column in valid_headers[table].keys():
                valid_header_df_data.append([
                    table, column, table + "+" + column,
                    valid_headers[table][column]["semanticType"]
                ])
        valid_header_df = pd.DataFrame(
            valid_header_df_data,
            columns=["table", "column", "dataset_id", "semanticType"])

        #############
        # Build LF
        #############


        @labeling_function()
        def normal_EMD(numeric_column_to_label):
            """Label a numeric column with the semantic type of its closest labeled column.

            The values of the column to label are compared, via the Wasserstein
            distance (EMD), with the values of every already labeled numeric
            column in ``total_labeled_data_df``. All measurements are written to
            a CSV file; the semantic type of the closest column is returned.

            Args:
                numeric_column_to_label: row with the keys ``"dataset_id"``
                    (``"<table>+column_<n>"``), ``"table"`` and ``"semanticType"``.

            Returns:
                The ``label_enc``-encoded semantic type of the best match, or -1
                when no label is assigned: the column has no numeric values,
                nothing could be compared, the best EMD is not below
                ``std * threshold_EMD_factor``, or the two best matches tie on
                EMD with different semantic types.
            """
            target_id = numeric_column_to_label["dataset_id"]
            print("Numeric Column to label: " + target_id)

            # load the table with the numeric column to label
            cols_to_load = [target_id]
            df_cols_to_load = pd.DataFrame({
                "col_num": [int(col.split("+")[1].split("_")[1]) for col in cols_to_load],
                "col_header": cols_to_load,
            }).sort_values(by="col_num")

            df_table_with_n_col_to_label = load_sportsDB_soccer_table(
                numeric_column_to_label["table"],
                usecols=df_cols_to_load["col_num"].values,
                col_headers=df_cols_to_load["col_header"].values)

            # non-numeric entries become NaN and are dropped
            df_table_with_n_col_to_label[target_id] = pd.to_numeric(
                df_table_with_n_col_to_label[target_id], errors="coerce")
            df_table_with_n_col_to_label.dropna(inplace=True)
            if len(df_table_with_n_col_to_label) == 0:
                return -1

            # these do not depend on the labeled column, so compute them once
            target_values = df_table_with_n_col_to_label[target_id]
            target_list = target_values.to_list()
            target_mean = target_values.mean()
            target_std = target_values.std()

            # search all already labeled numeric cols in the corpus
            already_labeled_numeric_cols = total_labeled_data_df.loc[total_labeled_data_df["semanticType"].isin(
                numeric_types_sportsDB)].drop_duplicates()

            # iterate over all already labeled numeric cols and do the EMD measure
            results = []
            for _, row in already_labeled_numeric_cols.iterrows():
                labeled_id = row["dataset_id"]
                df_table_with_labeled_numeric = load_sportsDB_soccer_table(
                    row["table"],
                    usecols=[int(row["column"].split("_")[1])],
                    col_headers=[labeled_id],
                    frac=table_frac)

                df_table_with_labeled_numeric[labeled_id] = pd.to_numeric(
                    df_table_with_labeled_numeric[labeled_id], errors="coerce")
                df_table_with_labeled_numeric.dropna(inplace=True)
                labeled_list = df_table_with_labeled_numeric[labeled_id].to_list()
                if len(labeled_list) == 0:
                    continue

                # EMD calc
                emd = wasserstein_distance(target_list, labeled_list)
                print(f"EMD: {emd}")
                results.append([target_id, numeric_column_to_label["semanticType"],
                                labeled_id, row["semanticType"], emd, target_mean, target_std])

            df_results = pd.DataFrame(results, columns=[
                "unlabeled_col", "real_semantic_type", "labeled_col", "semantic_type", "EMD", "mean", "std"])
            df_results = df_results[pd.to_numeric(
                df_results['EMD'], errors='coerce').notnull()]
            df_results = df_results.sort_values(by="EMD")

            # training-data runs and test-data runs are stored in separate folders
            results_dir = "results" if gen_train_data else "results_test_data"
            df_results.to_csv(join(os.environ["WORKING_DIR"], "labeling_functions", "numerics", "normal_EMD", results_dir, corpus,
                                   f"{target_id}_appr{approach}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"), index=False)

            if len(df_results) == 0:
                return -1

            best_match = df_results.iloc[0]
            # reject matches that are too far away relative to the spread of the column to label
            if best_match["EMD"] >= (best_match["std"] * threshold_EMD_factor):
                return -1

            # With two or more candidates, an EMD tie between different semantic
            # types is ambiguous. Exactly two candidates is treated like any
            # larger number (it used to fall through and always return -1).
            if len(df_results) >= 2:
                runner_up = df_results.iloc[1]
                if (best_match["EMD"] == runner_up["EMD"]) and (best_match["semantic_type"] != runner_up["semantic_type"]):
                    return -1

            return label_enc.transform([best_match["semantic_type"]])[0]


        if gen_train_data:
            # filter out unlabeled data from valid_headers
            unlabeled_data_df = valid_header_df.loc[
                valid_header_df["dataset_id"].isin(unlabeled_data_ids)]

            # load already labeled data
            labeled_data_df = valid_header_df.loc[valid_header_df["dataset_id"].isin(
                labeled_data_ids)]

            # load already generated labeled train data
            # gen_labeled_data_df = pd.read_csv(join(gen_train_data_path, f"public_bi_gen_training_data_all_combined_maj_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"), names=[
            #                                   "table", "column", "dataset_id", "semanticType"])

            ### drop duplicate only on cols "table", "column", "dataset_id"!!! This must be fixed. Actually it cant happen that there are two duplicte sets with different semantic types!
            # total_labeled_data_df = pd.concat(
            #     [labeled_data_df, gen_labeled_data_df]).drop_duplicates(subset=["table", "column", "dataset_id"])
            total_labeled_data_df = labeled_data_df.drop_duplicates(
                subset=["table", "column", "dataset_id"])

            # only unlabaled columns of tyoe numeric
            numeric_unlabeled_data_df = unlabeled_data_df.loc[unlabeled_data_df["semanticType"].isin(
                numeric_types_sportsDB)]
            #numeric_unlabeled_data_df = numeric_unlabeled_data_df[678:]

            # define LF to apply
            lfs = [normal_EMD]

            # snorkel pandas applier for apply lfs to the data
            applier = PandasLFApplier(lfs=lfs)

            from multiprocessing import Pool
            from multiprocessing.pool import ThreadPool as Pool
            from functools import partial
            import numpy as np
            from tqdm.auto import tqdm

            def parallelize(data, func, num_of_processes=8):
                """Apply ``func`` to chunks of ``data`` in a worker pool and concatenate the results.

                Args:
                    data: array-like that ``np.array_split`` can split.
                    func: callable applied to each chunk; must return something
                        ``np.concatenate`` accepts.
                    num_of_processes: number of chunks and of pool workers.

                Returns:
                    The chunk results concatenated along axis 0.

                Raises:
                    Whatever ``func`` raises; the pool is still shut down in that case.
                """
                data_split = np.array_split(data, num_of_processes)
                pool = Pool(num_of_processes)
                try:
                    mapped_chunks = pool.map(func, data_split)
                finally:
                    # always release the workers, also when func raised
                    pool.close()
                    pool.join()
                #data = pd.concat(mapped_chunks)
                return np.concatenate(mapped_chunks, axis=0)

            L_train = applier.apply(df=numeric_unlabeled_data_df)
            #L_train = parallelize(numeric_unlabeled_data_df,
            #                      applier.apply, n_worker)

            print(
                f"Length of labeled data: {len([x for x in L_train if x != -1])}")

            numeric_unlabeled_data_df["predicted_semantic_type"] = [
                label_enc.inverse_transform([x])[0] if x != -1 else "None"
                for x in L_train
            ]
            numeric_unlabeled_data_df.to_csv(join(
                os.environ["WORKING_DIR"], "labeling_functions",
                "numerics", "normal_EMD", "out", "results",
                f"{corpus}_normal_EMD_{threshold_EMD_factor}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"
            ),
                index=False)

            # save gen train data
            class_reportable_data = numeric_unlabeled_data_df.drop(numeric_unlabeled_data_df[
                numeric_unlabeled_data_df["predicted_semantic_type"] == "None"].index)

            class_reportable_data[[
                "table", "column", "dataset_id", "predicted_semantic_type"
            ]].to_csv(join(
                os.environ["WORKING_DIR"], "labeling_functions", "numerics",
                "normal_EMD", "out", "gen_train_data",
                f"{corpus}_gen_training_data_{threshold_EMD_factor}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"
            ),
                index=False)

            cls_report = classification_report(
                class_reportable_data["semanticType"],
                class_reportable_data["predicted_semantic_type"],
                output_dict=True)

            # save classification_report
            with open(
                    join(
                        os.environ["WORKING_DIR"], "labeling_functions", "numerics",
                        "normal_EMD", "out", "validation",
                        f"{corpus}_classification_report_unlabeled_{threshold_EMD_factor}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.json"
                    ), "w") as f:
                json.dump(cls_report, f)

        ########################################
        ## Validation only / no gen training data
        ########################################
        else:
            # filter out unlabeled data from valid_headers
            test_data_df = valid_header_df.loc[
                valid_header_df["dataset_id"].isin(test_data_ids)]

            # load already labeled data
            labeled_data_df = valid_header_df.loc[valid_header_df["dataset_id"].isin(
                labeled_data_ids)]

            # load already generated labeled train data
            gen_labeled_data_df = pd.read_csv(join(gen_train_data_path, f"public_bi_gen_training_data_all_combined_maj_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"), names=[
                "table", "column", "dataset_id", "semanticType"])

            ### Drops duplicates only on the cols "table", "column", "dataset_id"! This must be fixed. Actually, it can't happen that there are two duplicate entries with different semantic types!
            total_labeled_data_df = pd.concat(
                [labeled_data_df, gen_labeled_data_df]).drop_duplicates(subset=["table", "column", "dataset_id"])

            # only test columns of a numeric type
            numeric_test_data_df = test_data_df.loc[test_data_df["semanticType"].isin(
                numeric_types_sportsDB)]
            #numeric_unlabeled_data_df = numeric_unlabeled_data_df[678:]

            # define LF to apply
            lfs = [normal_EMD]

            # snorkel pandas applier for apply lfs to the data
            applier = PandasLFApplier(lfs=lfs)

            from multiprocessing import Pool
            from multiprocessing.pool import ThreadPool as Pool
            from functools import partial
            import numpy as np
            from tqdm.auto import tqdm

            def parallelize(data, func, num_of_processes=8):
                """Apply ``func`` to chunks of ``data`` concurrently and concatenate the results.

                ``data`` is split into ``num_of_processes`` chunks with
                ``np.array_split``; each chunk is handed to ``func`` through
                ``Pool.map`` and the per-chunk results are joined along axis 0
                with ``np.concatenate``.

                Args:
                    data: Array-like accepted by ``np.array_split``.
                    func: Callable taking one chunk and returning an array-like
                        that ``np.concatenate`` can join with the other results.
                    num_of_processes: Number of chunks and of pool workers.

                Returns:
                    The concatenated per-chunk results, in chunk order.

                Raises:
                    Whatever ``func`` raises for any chunk; the pool is shut
                    down before the exception propagates.
                """
                data_split = np.array_split(data, num_of_processes)
                pool = Pool(num_of_processes)
                try:
                    chunk_results = pool.map(func, data_split)
                finally:
                    # Release the workers even when ``func`` fails; otherwise every
                    # failed call would leave a live pool behind in the kernel.
                    pool.close()
                    pool.join()
                return np.concatenate(chunk_results, axis=0)

            L_train = applier.apply(df=numeric_test_data_df)
            #L_train = parallelize(numeric_unlabeled_data_df,
            #                      applier.apply, n_worker)

            print(
                f"Length of labeled data: {len([x for x in L_train if x != -1])}")


  5%|▌         | 45/857 [00:17<04:01,  3.36it/s]

EMD: 12.698687135449713
Loading Table: soccer-germany-Players-2019-2020....
Loaded with 1318 rows
EMD: 13.993080499855559
Loading Table: soccer-germany-Players-2019-2020....
Loaded with 1318 rows
EMD: 14.10173241132437
Loading Table: soccer-germany-Players-2019-2020....
Loaded with 1318 rows
EMD: 14.513804845932016
Loading Table: soccer-germany-Players-2020-2021....
Loaded with 1282 rows
EMD: 341.0503944763959
Loading Table: soccer-germany-Players-2020-2021....
Loaded with 1282 rows
EMD: 15.547302583896238
Loading Table: soccer-germany-Players-2020-2021....
Loaded with 1282 rows
EMD: 14.502821321020168
Loading Table: soccer-germany-Players-2020-2021....
Loaded with 1282 rows
EMD: 15.490464885376605
Loading Table: soccer-italy-Players-1994-1995....
Loaded with 246 rows
Loading Table: soccer-italy-Players-1996-1997....
Loaded with 467 rows
EMD: 10.185397557752996
Loading Table: soccer-italy-Players-1997-1998....
Loaded with 629 rows
Loading Table: soccer-italy-Players-2006-2007....
Loade

  5%|▌         | 46/857 [00:18<04:27,  3.03it/s]

Loaded with 1670 rows
EMD: 366.44329926925866
Numeric Column to label: soccer-big5Leagues-Players-2015-2016+column_8
Loading Table: soccer-big5Leagues-Players-2015-2016....
Loaded with 6776 rows
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 1.4964934587624432
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 1.2495079018775268
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 1.4781627910106945
Loading Table: soccer-big5Leagues-Players-2016-2017....
Loaded with 6817 rows
EMD: 1.4566117620422254
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 0.2724194836305559
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 0.31337461922929044
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 1.2624756158873804
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 1.48962

  5%|▌         | 47/857 [00:18<05:04,  2.66it/s]

Loading Table: soccer-italy-Players-1997-1998....
Loaded with 629 rows
Loading Table: soccer-italy-Players-2006-2007....
Loaded with 1146 rows
EMD: 360.3108173940507
Loading Table: soccer-italy-Players-2010-2011....
Loaded with 1156 rows
EMD: 0.21817293817293826
Loading Table: soccer-italy-Players-2010-2011....
Loaded with 1156 rows
EMD: 1.5463623458433144
Loading Table: soccer-italy-Players-2011-2012....
Loaded with 1123 rows
EMD: 0.6076775162591397
Loading Table: soccer-italy-Players-2012-2013....
Loaded with 1080 rows
EMD: 1.4141285146299074
Loading Table: soccer-italy-Players-2015-2016....
Loaded with 1583 rows
EMD: 0.2879654595118515
Loading Table: soccer-italy-Players-2017-2018....
Loaded with 1570 rows
EMD: 0.3697324587557144
Loading Table: soccer-italy-Players-2017-2018....
Loaded with 1570 rows
EMD: 1.497212721129
Loading Table: soccer-italy-Players-2018-2019....
Loaded with 1551 rows
EMD: 0.2685129329734592
Loading Table: soccer-italy-Players-2018-2019....
Loaded with 1551 ro

  6%|▌         | 48/857 [00:19<05:26,  2.48it/s]

Numeric Column to label: soccer-big5Leagues-Players-2015-2016+column_10
Loading Table: soccer-big5Leagues-Players-2015-2016....
Loaded with 6776 rows
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 1.3666848479331095
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 1.272259644153809
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 1.3496254532325138
Loading Table: soccer-big5Leagues-Players-2016-2017....
Loaded with 6817 rows
EMD: 1.326803151212892
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 0.33918637877497
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 0.26453584486291426
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 1.389328264340872
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 1.3616378825404771
Loading Table: soccer-big5Leagues-Play

  6%|▌         | 49/857 [00:19<05:44,  2.35it/s]

EMD: 0.28577304357443506
Loading Table: soccer-italy-Players-2017-2018....
Loaded with 1570 rows
EMD: 1.369092923072021
Loading Table: soccer-italy-Players-2018-2019....
Loaded with 1551 rows
EMD: 0.2882320332428946
Loading Table: soccer-italy-Players-2018-2019....
Loaded with 1551 rows
EMD: 1.2198823304729538
Loading Table: soccer-italy-Players-2019-2020....
Loaded with 1592 rows
EMD: 1.2560980797907326
Loading Table: soccer-italy-Players-2019-2020....
Loaded with 1592 rows
EMD: 0.2595158143310762
Loading Table: soccer-italy-Players-2019-2020....
Loaded with 1592 rows
EMD: 1.3235947496493254
Loading Table: soccer-italy-Players-2020-2021....
Loaded with 1670 rows
EMD: 0.4726780528338006
Loading Table: soccer-italy-Players-2020-2021....
Loaded with 1670 rows
EMD: 0.29812673402979906
Numeric Column to label: soccer-big5Leagues-Players-2015-2016+column_11
Loading Table: soccer-big5Leagues-Players-2015-2016....
Loaded with 6776 rows
Loading Table: soccer-big5Leagues-Players-2013-2014....
L

  6%|▌         | 50/857 [00:20<05:47,  2.32it/s]

Loading Table: soccer-germany-Players-2020-2021....
Loaded with 1282 rows
EMD: 0.030336715924723312
Loading Table: soccer-germany-Players-2020-2021....
Loaded with 1282 rows
EMD: 1.0748179788007914
Loading Table: soccer-germany-Players-2020-2021....
Loaded with 1282 rows
EMD: 0.23812593028441398
Loading Table: soccer-italy-Players-1994-1995....
Loaded with 246 rows
Loading Table: soccer-italy-Players-1996-1997....
Loaded with 467 rows
EMD: 24.229412504634528
Loading Table: soccer-italy-Players-1997-1998....
Loaded with 629 rows
Loading Table: soccer-italy-Players-2006-2007....
Loaded with 1146 rows
EMD: 361.8211706186837
Loading Table: soccer-italy-Players-2010-2011....
Loaded with 1156 rows
EMD: 1.2921802864601286
Loading Table: soccer-italy-Players-2010-2011....
Loaded with 1156 rows
EMD: 2.8307421283091685
Loading Table: soccer-italy-Players-2011-2012....
Loaded with 1123 rows
EMD: 0.9026757083739273
Loading Table: soccer-italy-Players-2012-2013....
Loaded with 1080 rows
EMD: 0.0962

  6%|▌         | 51/857 [00:20<05:54,  2.27it/s]

Loaded with 1136 rows
EMD: 0.19049644160899032
Loading Table: soccer-germany-Players-2018-2019....
Loaded with 1160 rows
EMD: 2.846412345909182
Loading Table: soccer-germany-Players-2019-2020....
Loaded with 1318 rows
EMD: 1.5520189815033376
Loading Table: soccer-germany-Players-2019-2020....
Loaded with 1318 rows
EMD: 1.443367070034525
Loading Table: soccer-germany-Players-2019-2020....
Loaded with 1318 rows
EMD: 1.031294635426879
Loading Table: soccer-germany-Players-2020-2021....
Loaded with 1282 rows
EMD: 356.59549395775485
Loading Table: soccer-germany-Players-2020-2021....
Loaded with 1282 rows
EMD: 0.13767706805123925
Loading Table: soccer-germany-Players-2020-2021....
Loaded with 1282 rows
EMD: 1.0422781603387261
Loading Table: soccer-germany-Players-2020-2021....
Loaded with 1282 rows
EMD: 0.1893560982061555
Loading Table: soccer-italy-Players-1994-1995....
Loaded with 246 rows
Loading Table: soccer-italy-Players-1996-1997....
Loaded with 467 rows
EMD: 24.196872686172465
Loadi

Loaded with 1570 rows
EMD: 1.4086921970918516
Loading Table: soccer-italy-Players-2017-2018....
Loaded with 1570 rows
EMD: 0.04181976966306057
Loading Table: soccer-italy-Players-2018-2019....
Loaded with 1551 rows
EMD: 1.548779779342173
Loading Table: soccer-italy-Players-2018-2019....
Loaded with 1551 rows


  6%|▌         | 52/857 [00:21<06:06,  2.20it/s]

EMD: 2.534141621447437
Loading Table: soccer-italy-Players-2019-2020....
Loaded with 1592 rows
EMD: 0.28344830444965413
Loading Table: soccer-italy-Players-2019-2020....
Loaded with 1592 rows
EMD: 1.4205305661638663
Loading Table: soccer-italy-Players-2019-2020....
Loaded with 1592 rows
EMD: 0.07255941013365291
Loading Table: soccer-italy-Players-2020-2021....
Loaded with 1670 rows
EMD: 0.8916035964277503
Loading Table: soccer-italy-Players-2020-2021....
Loaded with 1670 rows
EMD: 1.5734822084919702
Numeric Column to label: soccer-big5Leagues-Players-2015-2016+column_19
Loading Table: soccer-big5Leagues-Players-2015-2016....
Loaded with 6776 rows
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 0.1945458084085196
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 2.481904011061063
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 0.010082844366312339
Loading Table: soccer-big5Leagues-Play

  6%|▌         | 53/857 [00:21<06:15,  2.14it/s]

EMD: 1.0343607087202982
Loading Table: soccer-germany-Players-2020-2021....
Loaded with 1282 rows
EMD: 0.060434975472200184
Loading Table: soccer-italy-Players-1994-1995....
Loaded with 246 rows
Loading Table: soccer-italy-Players-1996-1997....
Loaded with 467 rows
EMD: 24.185187466100537
Loading Table: soccer-italy-Players-1997-1998....
Loaded with 629 rows
Loading Table: soccer-italy-Players-2006-2007....
Loaded with 1146 rows
EMD: 361.77694558014974
Loading Table: soccer-italy-Players-2010-2011....
Loaded with 1156 rows
EMD: 1.2682735346516818
Loading Table: soccer-italy-Players-2010-2011....
Loaded with 1156 rows
EMD: 2.7890860228116514
Loading Table: soccer-italy-Players-2011-2012....
Loaded with 1123 rows
EMD: 0.8891707701100555
Loading Table: soccer-italy-Players-2012-2013....
Loaded with 1080 rows
EMD: 0.25545716115327993
Loading Table: soccer-italy-Players-2015-2016....
Loaded with 1583 rows
EMD: 1.2056525232517983
Loading Table: soccer-italy-Players-2017-2018....
Loaded with 

  6%|▋         | 54/857 [00:22<06:23,  2.09it/s]

EMD: 14.710810985271806
Loading Table: soccer-italy-Players-2020-2021....
Loaded with 1670 rows
EMD: 13.99470038998474
Numeric Column to label: soccer-big5Leagues-Players-2016-2017+column_5
Loading Table: soccer-big5Leagues-Players-2016-2017....
Loaded with 6817 rows
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 12.366979598137473
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 9.855909149830062
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 12.333045014506691
Loading Table: soccer-big5Leagues-Players-2016-2017....
Loaded with 6817 rows
EMD: 12.327097901417254
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 10.680914747888988
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 10.835586028165805
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 9.626822609944487
Loading Table: socce

  6%|▋         | 55/857 [00:22<06:38,  2.01it/s]

EMD: 12.297505835239608
Loading Table: soccer-italy-Players-1994-1995....
Loaded with 246 rows
Loading Table: soccer-italy-Players-1996-1997....
Loaded with 467 rows
EMD: 12.622924920093608
Loading Table: soccer-italy-Players-1997-1998....
Loaded with 629 rows
Loading Table: soccer-italy-Players-2006-2007....
Loaded with 1146 rows
EMD: 349.44033125467564
Loading Table: soccer-italy-Players-2010-2011....
Loaded with 1156 rows
EMD: 11.088659077547966
Loading Table: soccer-italy-Players-2010-2011....
Loaded with 1156 rows
EMD: 9.550097235698926
Loading Table: soccer-italy-Players-2011-2012....
Loaded with 1123 rows
EMD: 11.478163655634166
Loading Table: soccer-italy-Players-2012-2013....
Loaded with 1080 rows
EMD: 12.284614654004937
Loading Table: soccer-italy-Players-2015-2016....
Loaded with 1583 rows
EMD: 11.155896307671794
Loading Table: soccer-italy-Players-2017-2018....
Loaded with 1570 rows
EMD: 10.923784388642268
Loading Table: soccer-italy-Players-2017-2018....
Loaded with 1570 r

  7%|▋         | 56/857 [00:23<06:36,  2.02it/s]

Loaded with 1592 rows
EMD: 14.165356018148398
Loading Table: soccer-italy-Players-2020-2021....
Loaded with 1670 rows
EMD: 13.377415718710854
Loading Table: soccer-italy-Players-2020-2021....
Loaded with 1670 rows
EMD: 12.661305123423787
Numeric Column to label: soccer-big5Leagues-Players-2016-2017+column_8
Loading Table: soccer-big5Leagues-Players-2016-2017....
Loaded with 6817 rows
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 1.5655910169462874
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 1.22201148426075
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 1.5465462881411658
Loading Table: soccer-big5Leagues-Players-2016-2017....
Loaded with 6817 rows
EMD: 1.5257093202260696
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 0.26601030195657405
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 0.3466648274272698
Lo

  7%|▋         | 57/857 [00:23<06:35,  2.02it/s]

EMD: 360.2417198358667
Loading Table: soccer-italy-Players-2010-2011....
Loaded with 1156 rows
EMD: 0.2872704963567825
Loading Table: soccer-italy-Players-2010-2011....
Loaded with 1156 rows
EMD: 1.5116246023752782
Loading Table: soccer-italy-Players-2011-2012....
Loaded with 1123 rows
EMD: 0.6767750744429839
Loading Table: soccer-italy-Players-2012-2013....
Loaded with 1080 rows
EMD: 1.4832260728137516
Loading Table: soccer-italy-Players-2015-2016....
Loaded with 1583 rows
EMD: 0.3568925713426878
Loading Table: soccer-italy-Players-2017-2018....
Loaded with 1570 rows
EMD: 0.4094817016680006
Loading Table: soccer-italy-Players-2017-2018....
Loaded with 1570 rows
EMD: 1.5652788577913057
Loading Table: soccer-italy-Players-2018-2019....
Loaded with 1551 rows
EMD: 0.28758654455260046
Loading Table: soccer-italy-Players-2018-2019....
Loaded with 1551 rows
EMD: 1.050091535362818
Loading Table: soccer-italy-Players-2019-2020....
Loaded with 1592 rows
EMD: 1.4550042488039103
Loading Table: so

  7%|▋         | 58/857 [00:24<06:31,  2.04it/s]

Numeric Column to label: soccer-big5Leagues-Players-2016-2017+column_11
Loading Table: soccer-big5Leagues-Players-2016-2017....
Loaded with 6817 rows
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 0.009957187536363876
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 2.5126219320932215
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 0.19189218698892918
Loading Table: soccer-big5Leagues-Players-2016-2017....
Loaded with 6817 rows
EMD: 0.08490405841399218
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 1.6883492633716057
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 1.533677983094789
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 2.756784480595751
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 0.22852894379351335
Loading Table: soccer-big5Leagu

  7%|▋         | 59/857 [00:24<06:27,  2.06it/s]

Loaded with 1551 rows
EMD: 1.5869664450223875
Loading Table: soccer-italy-Players-2018-2019....
Loaded with 1551 rows
EMD: 2.572328287127651
Loading Table: soccer-italy-Players-2019-2020....
Loaded with 1592 rows
EMD: 0.1121382519281876
Loading Table: soccer-italy-Players-2019-2020....
Loaded with 1592 rows
EMD: 1.4572101244626803
Loading Table: soccer-italy-Players-2019-2020....
Loaded with 1592 rows
EMD: 0.2526507893512362
Loading Table: soccer-italy-Players-2020-2021....
Loaded with 1670 rows
EMD: 0.8955582788851195
Loading Table: soccer-italy-Players-2020-2021....
Loaded with 1670 rows
EMD: 1.6116688741721852
Numeric Column to label: soccer-big5Leagues-Players-2016-2017+column_12
Loading Table: soccer-big5Leagues-Players-2016-2017....
Loaded with 6817 rows
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 0.04267767358840591
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 2.4688114379516017
Loading Table: soccer-bi

  7%|▋         | 60/857 [00:25<06:24,  2.07it/s]

Loaded with 1282 rows
EMD: 0.26613706036930745
Loading Table: soccer-italy-Players-1994-1995....
Loaded with 246 rows
Loading Table: soccer-italy-Players-1996-1997....
Loaded with 467 rows
EMD: 24.1732937284081
Loading Table: soccer-italy-Players-1997-1998....
Loaded with 629 rows
Loading Table: soccer-italy-Players-2006-2007....
Loaded with 1146 rows
EMD: 361.76505184245724
Loading Table: soccer-italy-Players-2010-2011....
Loaded with 1156 rows
EMD: 1.2360615102336956
Loading Table: soccer-italy-Players-2010-2011....
Loaded with 1156 rows
EMD: 2.7746233520827346
Loading Table: soccer-italy-Players-2011-2012....
Loaded with 1123 rows
EMD: 0.8465569321474941
Loading Table: soccer-italy-Players-2012-2013....
Loaded with 1080 rows
EMD: 0.040991395600408875
Loading Table: soccer-italy-Players-2015-2016....
Loaded with 1583 rows
EMD: 1.1688242801098676
Loading Table: soccer-italy-Players-2017-2018....
Loaded with 1570 rows
EMD: 1.4022805624992594
Loading Table: soccer-italy-Players-2017-201

  7%|▋         | 61/857 [00:25<06:26,  2.06it/s]

Numeric Column to label: soccer-germany-Players-1999-2000+column_3
Loading Table: soccer-germany-Players-1999-2000....
Loaded with 423 rows
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 26.430277783185065
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 23.919207334877658
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 26.39634319955429
Loading Table: soccer-big5Leagues-Players-2016-2017....
Loaded with 6817 rows
EMD: 26.390396086464847
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 24.743480003599267
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 24.89815128387609
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 23.675121679877623
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 26.36683346972177
Loading Table: soccer-big5Leagues-Players-2018-

  7%|▋         | 62/857 [00:25<06:03,  2.19it/s]

Loaded with 3883 rows
EMD: 383.2923490360646
Loading Table: soccer-germany-Players-2004-2005....
Loaded with 708 rows
EMD: 23.921925712225026
Loading Table: soccer-germany-Players-2007-2008....
Loaded with 786 rows
EMD: 26.45360294515727
Loading Table: soccer-italy-Players-1990-1991....
Loaded with 297 rows
EMD: 349.9418051102692
Loading Table: soccer-italy-Players-1990-1991....
Loaded with 297 rows
EMD: 26.288753799392097
Loading Table: soccer-germany-Players-2009-2010....
Loaded with 966 rows
EMD: 24.808422536452134
Loading Table: soccer-germany-Players-2009-2010....
Loaded with 966 rows
EMD: 26.4387167897696
Loading Table: soccer-germany-Players-2010-2011....
Loaded with 998 rows
EMD: 1.8979250334672022
Loading Table: soccer-germany-Players-2012-2013....
Loaded with 1060 rows
EMD: 26.44510075743877
Loading Table: soccer-germany-Players-2012-2013....
Loaded with 1060 rows
EMD: 24.10799027585237
Loading Table: soccer-germany-Players-2015-2016....
Loaded with 1180 rows
EMD: 24.98905541

  7%|▋         | 63/857 [00:26<05:57,  2.22it/s]

Numeric Column to label: soccer-germany-Players-1999-2000+column_7
Loading Table: soccer-germany-Players-1999-2000....
Loaded with 423 rows
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 14.828859343468753
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 12.31778889516134
Loading Table: soccer-big5Leagues-Players-2013-2014....
Loaded with 5465 rows
EMD: 14.794924759837972
Loading Table: soccer-big5Leagues-Players-2016-2017....
Loaded with 6817 rows
EMD: 14.788977646748535
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 13.142061563882958
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 13.296732844159777
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 12.075241110211294
Loading Table: soccer-big5Leagues-Players-2017-2018....
Loaded with 6624 rows
EMD: 14.765415030005457
Loading Table: soccer-big5Leagues-Players-201

# Join n' Group (J n' G) approach using SportsDB

In [None]:
import sys
import os
from dotenv import load_dotenv

load_dotenv(override=True)
sys.path.append(os.environ["WORKING_DIR"])
from pyspark.sql.types import StringType, IntegerType, FloatType, DoubleType, StructType, StructField
from pyspark.sql.functions import udf, col, pandas_udf, PandasUDFType, collect_list, count, avg, lit, mean, stddev, monotonically_increasing_id, row_number
from pyspark.sql import SparkSession
from pyspark import SparkConf, SparkContext
import pyspark.pandas as ps
from snorkel.labeling import PandasLFApplier
import configargparse
from snorkel.preprocess import preprocessor
from snorkel.labeling import labeling_function
from os.path import join
import copy
import json
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import LabelEncoder
import pandas as pd
from sklearn.metrics import classification_report
import numpy as np
from data_loader.utils import load_public_bi_table_by_cols, load_sportsDB_soccer_table
from scipy.stats import wasserstein_distance, kruskal



# create and register UDF-Function to calc EMD-Distance

@udf(returnType=FloatType())
def emd_UDF(col1, col2) -> float:
    """Spark UDF computing the Wasserstein (Earth Mover's) distance of two samples.

    Args:
        col1: Values of the first sample, passed to
            ``scipy.stats.wasserstein_distance`` as ``u_values``.
        col2: Values of the second sample, passed as ``v_values``.

    Returns:
        The distance as a plain Python ``float``. The Spark-side column type
        is declared by the decorator (``FloatType``), not by this annotation.
    """
    # Cast explicitly: scipy returns a numpy scalar, the UDF has to hand back
    # a built-in float.
    return float(wasserstein_distance(col1, col2))


# Spark configuration for this notebook: master "local[*]", app name "STEER",
# 150g for driver and executor plus 50g of off-heap memory.
# Settings that are currently switched off:
#   spark.executor.instances = 2
#   spark.executor.cores = 2
#   spark.sql.execution.arrow.enabled = true
conf = (
    SparkConf()
    .set("spark.executor.memory", "150g")
    .set("spark.driver.memory", "150g")
    .set("spark.memory.offHeap.enabled", "true")
    .set("spark.memory.offHeap.size", "50g")
    .setMaster("local[*]")
    .setAppName("STEER")
)

spark = SparkSession.builder.config(conf=conf).getOrCreate()

# Make the EMD UDF available under the name "emd_UDF" in this session.
spark.udf.register("emd_UDF", emd_UDF)

# Directories below WORKING_DIR that the notebook reads from.
_extract_out_dir = join(os.environ["WORKING_DIR"], "data", "extract", "out")

labeled_unlabeled_test_split_path = join(_extract_out_dir,
                                         "labeled_unlabeled_test_split")
valid_headers_path = join(_extract_out_dir, "valid_headers")

gen_train_data_path = join(
    os.environ["WORKING_DIR"],
    "labeling_functions",
    "combined_LFs",
    "gen_training_data",
)

# numeric_types = ["X1B",
#                  "X2B",
#                  "X3B",
#                  "TB",
#                  "HR",
#                  "R",
#                  "BB",
#                  "AB",
#                  "GIDP",
#                  "HBP",
#                  "H",
#                  "SF",
#                  "SH",
#                  "SO",
#                  "iBB",
#                  "CS",
#                  "SB",
#                  "latitude",
#                  "longitude",
#                  "year"]

# Semantic types that are treated as numeric. The labeling function uses this
# list to tell numeric columns apart from already labeled non-numeric ones.
numeric_types = [
    # age and playing time
    "age", "gamesPlayed", "gamesStarted", "minutesPlayed",
    "matchesPlayedCompletely",
    # goals, assists, penalties and cards
    "goals", "assists", "nonPenaltyGoals", "penaltiesScored",
    "penaltiesAttempted", "yellowCards", "redCards",
    # goal / assist rates
    "goalsPer90Min", "assistsPer90Min", "goalsPlusAssistsPer90Min",
    "nonPenaltyGoalsPer90Min", "nonPenaltyGoalsPlusAssists",
    # expected-goal ("x") statistics
    "xGoals", "nonPenaltyXGoals", "xAssists", "nonPenaltyXGoalsPlusAssists",
    "xGoalsPer90Min", "xAssistsPer90Min", "xGoalsPlusAssistsPer90Min",
    "nonPenaltyXGoalsPer90Min", "nonPenaltyXGoalsPlusAssistsPer90Min",
]


# LabelEncoder fitted on the "type_sportsDB" entry of types.json
valid_types_file = join(os.environ["WORKING_DIR"], "data", "extract", "out",
                        "valid_types", "types.json")
with open(valid_types_file) as f:
    valid_types = json.load(f)["type_sportsDB"]

label_enc = LabelEncoder().fit(valid_types)

# --- thresholds -------------------------------------------------------------
PERCENTAGE_THRESHOLD_UNIQUE_VALUES = 0.1
P_VALUE_CORRSTEERTION_ANALYZATION = 0.05
PERCENTAGE_OF_OVERLAP_FOR_JOIN_CANDIDATES = 0.75
threshold_EMD = 0.01

# --- which labeled / unlabeled / test split to use --------------------------
# These values are interpolated into the name of the split file and into the
# keys read from it.
corpus = "sportsDB"
random_state = 2
test_data_size = 20.1
unlabeled_data_size = "absolute"
labeled_data_size = 1
absolute_numbers = True
validation_on = "test"

# --- run options ------------------------------------------------------------
gen_train_data = True
approach = 5
max_group_size = 1
table_frac = None
n_worker = 4

# With absolute numbers the unlabeled size is the literal "absolute" and the
# labeled size has to be an integer.
if absolute_numbers:
    labeled_data_size = int(labeled_data_size)
    unlabeled_data_size = "absolute"

#############
# Load data
#############

# Read the labeled / unlabeled / test split file that matches the configuration
split_file_name = f"{corpus}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.json"
with open(join(labeled_unlabeled_test_split_path, split_file_name)) as f:
    labeled_unlabeled_test_split_file = json.load(f)

labeled_data_ids = labeled_unlabeled_test_split_file[
    f"labeled{labeled_data_size}"]

# The unlabeled ids are only needed when training data is generated.
if gen_train_data:
    unlabeled_key = ("unlabeled" if absolute_numbers else
                     f"unlabeled{unlabeled_data_size}")
    unlabeled_data_ids = labeled_unlabeled_test_split_file[unlabeled_key]
    print(f"Unlabeled Data: {len(unlabeled_data_ids)}")

# Validation on "unlabeled" uses the plain key, every other target is
# suffixed with the test data size.
test_key = (f"{validation_on}" if validation_on == "unlabeled" else
            f"{validation_on}{test_data_size}")
test_data_ids = labeled_unlabeled_test_split_file[test_key]

print(f"Labeled Data: {len(labeled_data_ids)}")
print(f"Test Data: {len(test_data_ids)}")

# Read the valid headers, i.e. the real semantic type of every column
valid_header_file = f"{corpus}_type_sportsDB.json"
with open(join(valid_headers_path, valid_header_file), "r") as file:
    valid_headers = json.load(file)

# Flatten the nested table -> column -> info mapping into one row per column.
# "dataset_id" is "<table>+<column>".
valid_header_df_data = []
for table, columns_of_table in valid_headers.items():
    for column, column_info in columns_of_table.items():
        valid_header_df_data.append(
            [table, column, f"{table}+{column}", column_info["semanticType"]])

valid_header_df = pd.DataFrame(
    valid_header_df_data,
    columns=["table", "column", "dataset_id", "semanticType"],
)

#############
# Build LF
#############
@labeling_function()
def joined_n_grouped_EMD(numeric_column_to_label):
    """Predict the semantic type of a numeric column with a join-and-group EMD.

    The table of the column is joined (Spark SQL) with every other table that
    shares already labeled non-numeric semantic types and also contains
    already labeled numeric columns. The joined rows are grouped by the join
    columns; for each labeled numeric column the per-group ``emd_UDF`` result
    is averaged over all groups with at most ``max_group_size`` rows. The
    labeled column with the smallest average determines the prediction.

    Relies on module-level names defined elsewhere in the notebook:
    ``total_labeled_data_df``, ``numeric_types``, ``spark``, ``emd_UDF``,
    ``count``, ``col``, ``collect_list``, ``max_group_size``,
    ``PERCENTAGE_OF_OVERLAP_FOR_JOIN_CANDIDATES`` and the run configuration
    that is encoded in the result file name.

    Args:
        numeric_column_to_label: Data point supporting item access for
            ``"dataset_id"`` (``<table>+<column>``, where the column part ends
            in ``_<number>``), ``"table"`` and ``"semanticType"``.

    Returns:
        The ``label_enc``-encoded semantic type of the best matching labeled
        numeric column, or ``-1`` if no candidate yielded a numeric EMD.

    Side effects:
        Prints progress information and writes all candidate scores, sorted
        by EMD, to a CSV file below
        ``$WORKING_DIR/labeling_functions/numerics/joined_n_grouped_EMD/results``.
    """
    target_id = numeric_column_to_label["dataset_id"]
    target_table = numeric_column_to_label["table"]
    print("Numeric Column to label: " + target_id)

    # Fixed temp view names (same scheme as joined_n_grouped_EMD_correlation):
    # raw table names are not guaranteed to be valid SQL identifiers.
    target_view = "numeric_column_to_label"
    joinable_view = "joinable_table"

    # 1. Already labeled non-numeric columns in the table of the column to label.
    string_already_labeled_cols_in_table = total_labeled_data_df[
        total_labeled_data_df["table"] == target_table]
    string_already_labeled_cols_in_table = string_already_labeled_cols_in_table.loc[
        ~string_already_labeled_cols_in_table["semanticType"].isin(numeric_types)]

    # Load the column to label together with those string columns.
    string_cols_to_load = list(
        string_already_labeled_cols_in_table["dataset_id"].values)
    cols_to_load = [target_id] + string_cols_to_load
    df_cols_to_load = pd.DataFrame({
        "col_num": [int(header.split("+")[1].split("_")[1]) for header in cols_to_load],
        "col_header": cols_to_load,
    }).sort_values(by="col_num")
    df_table_with_n_col_to_label = load_public_bi_table_by_cols(
        target_table.split("_")[0],
        target_table,
        usecols=df_cols_to_load["col_num"].values,
        col_headers=df_cols_to_load["col_header"].values)
    # Join columns are handed to Spark as strings.
    for string_col in string_cols_to_load:
        df_table_with_n_col_to_label[string_col] = df_table_with_n_col_to_label[string_col].astype(
            str)

    # 2. Other tables that have the same (non-numeric) semantic types labeled,
    #    ordered by the number of matching labeled columns.
    joinable_tables = total_labeled_data_df.loc[total_labeled_data_df["semanticType"].isin(
        string_already_labeled_cols_in_table["semanticType"])].drop_duplicates()
    joinable_tables_sorted = joinable_tables[joinable_tables["table"] != target_table].groupby(
        "table").count().sort_values(by=["semanticType"], ascending=False)

    # 3. Join with every suitable table:
    #    - the other table needs at least one already labeled numeric column
    #    - only string columns with enough overlapping values are join keys
    results = []
    for joinable_table in joinable_tables_sorted.index:
        numerics_in_joinable_table = total_labeled_data_df[
            total_labeled_data_df["table"] == joinable_table]
        numerics_in_joinable_table = numerics_in_joinable_table.loc[
            numerics_in_joinable_table["semanticType"].isin(numeric_types)]
        if len(numerics_in_joinable_table) == 0:
            continue

        strings_in_joinable_table = joinable_tables[joinable_tables["table"]
                                                    == joinable_table]
        numeric_ids_in_joinable_table = list(
            numerics_in_joinable_table["dataset_id"].values)
        string_ids_in_joinable_table = list(
            strings_in_joinable_table["dataset_id"].values)

        # Load the labeled numeric and the candidate join columns of the other table.
        cols_to_load_for_joinable_table = numeric_ids_in_joinable_table + \
            string_ids_in_joinable_table
        df_cols_to_load_for_joinable_table = pd.DataFrame({
            "col_num": [int(header.split("+")[1].split("_")[1])
                        for header in cols_to_load_for_joinable_table],
            "col_header": cols_to_load_for_joinable_table,
        }).sort_values(by="col_num")
        df_joinable_table = load_public_bi_table_by_cols(
            joinable_table.split("_")[0],
            joinable_table,
            usecols=df_cols_to_load_for_joinable_table["col_num"].values,
            col_headers=df_cols_to_load_for_joinable_table["col_header"].values,
            frac=table_frac)
        for string_col in string_ids_in_joinable_table:
            df_joinable_table[string_col] = df_joinable_table[string_col].astype(
                str)

        # Pair up columns of equal semantic type whose values overlap enough.
        cols_to_join = []
        for _, target_row in string_already_labeled_cols_in_table.iterrows():
            target_values = set(
                df_table_with_n_col_to_label[target_row["dataset_id"]])
            if not target_values:
                # No values at all: the overlap ratio is undefined, skip the column.
                continue
            same_type_candidates = strings_in_joinable_table[
                strings_in_joinable_table["semanticType"] == target_row["semanticType"]]
            for _, candidate_row in same_type_candidates.iterrows():
                candidate_values = set(
                    df_joinable_table[candidate_row["dataset_id"]])
                percentage_of_overlap = len(
                    target_values & candidate_values) / len(target_values)
                if percentage_of_overlap >= PERCENTAGE_OF_OVERLAP_FOR_JOIN_CANDIDATES:
                    cols_to_join.append(
                        [target_row["dataset_id"], candidate_row["dataset_id"]])

        if len(cols_to_join) == 0:
            continue

        left_join_cols = [left for left, _ in cols_to_join]
        right_join_cols = [right for _, right in cols_to_join]

        # Drop rows with missing join keys (they blow up the join). This is
        # done on per-candidate copies so that the rows available for later
        # candidate tables do not depend on the tables processed before.
        df_left = df_table_with_n_col_to_label.dropna(subset=left_join_cols)
        df_right = df_joinable_table.dropna(subset=right_join_cols)

        try:
            # Register both pandas frames as Spark temp views.
            spark.createDataFrame(df_left).createOrReplaceTempView(target_view)
            spark.createDataFrame(df_right).createOrReplaceTempView(joinable_view)

            join_condition = " AND ".join(
                f"`{left}` = `{right}`" for left, right in cols_to_join)
            projected_cols = list(set(left_join_cols)) + \
                [target_id] + numeric_ids_in_joinable_table
            projection_list = " , ".join(
                f"`{attr}`" for attr in projected_cols)
            sql_df = spark.sql(
                f"SELECT {projection_list} FROM {target_view} JOIN {joinable_view} ON ({join_condition})")
            # Remove joined tuples that contain null values.
            sql_df = sql_df.dropna(
                subset=[f"`{cur_col}`" for cur_col in sql_df.columns])

            for _, numeric_col_in_joinable_table in numerics_in_joinable_table.iterrows():
                labeled_id = numeric_col_in_joinable_table["dataset_id"]
                print("EMD Calc:")
                print(labeled_id)
                # Skip the column if no group stays within max_group_size.
                valid_group_size_values = sql_df.groupby(left_join_cols).agg(
                    count(f"`{labeled_id}`").alias("count")
                ).where(col("count") <= max_group_size)
                if valid_group_size_values.count() == 0:
                    continue

                # EMD per group, then the average over all sufficiently small groups.
                cur_df = sql_df.groupby(left_join_cols).agg(
                    emd_UDF(collect_list(f"`{target_id}`"),
                            collect_list(f"`{labeled_id}`")).alias("EMD"),
                    count(f"`{labeled_id}`").alias("count"))
                cur_df = cur_df.select(
                    "*").where(col("count") <= max_group_size).groupby().avg("EMD").alias("avg(EMD)")
                grouped_n_joined_emd = cur_df.collect()[0]["avg(EMD)"]
                print(grouped_n_joined_emd)
                results.append([target_id,
                                numeric_column_to_label["semanticType"],
                                labeled_id,
                                numeric_col_in_joinable_table["semanticType"],
                                grouped_n_joined_emd])
        finally:
            # Always release the temp views, also when Spark raised.
            spark.catalog.dropTempView(target_view)
            spark.catalog.dropTempView(joinable_view)

    df_results = pd.DataFrame(results, columns=[
        "unlabeled_col", "real_semantic_type", "labeled_col", "semantic_type", "EMD"])
    # Keep only candidates with a numeric EMD and rank them, best first.
    df_results = df_results[pd.to_numeric(
        df_results['EMD'], errors='coerce').notnull()]
    df_results = df_results.sort_values(by="EMD")

    results_dir = join(os.environ["WORKING_DIR"], "labeling_functions",
                       "numerics", "joined_n_grouped_EMD", "results")
    # Make sure the (expensive) result can actually be written.
    os.makedirs(results_dir, exist_ok=True)
    df_results.to_csv(join(
        results_dir,
        f"{numeric_column_to_label['dataset_id']}_appr{approach}_{max_group_size}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"), index=False)

    predicted_semantic_type = -1
    if len(df_results) > 0:
        predicted_semantic_type = label_enc.transform(
            [df_results.iloc[0]["semantic_type"]])[0]

    return predicted_semantic_type

@labeling_function()
def joined_n_grouped_EMD_correlation(numeric_column_to_label):
    """Join-and-group EMD restricted to string columns that relate to the target.

    Steps:
      1. Load the column to label plus the already labeled non-numeric columns
         of its table; keep only rows where the target parses as a number.
      2. Run a Kruskal-Wallis test of the target values grouped by each string
         column and keep the columns with ``p < P_VALUE_CORRSTEERTION_ANALYZATION``
         (columns whose share of unique values exceeds
         ``PERCENTAGE_THRESHOLD_UNIQUE_VALUES`` are skipped).
      3. For every other table that has all semantic types of the kept columns
         labeled and also holds labeled numeric columns, join both tables on
         those columns (Spark SQL), group by the join columns and average the
         per-group ``emd_UDF`` result over groups with at most
         ``max_group_size`` rows.
      4. Predict the semantic type of the labeled numeric column with the
         smallest average.

    Relies on module-level names defined elsewhere in the notebook:
    ``total_labeled_data_df``, ``numeric_types``, ``load_sportsDB_soccer_table``,
    ``spark``, ``emd_UDF``, ``count``, ``col``, ``collect_list``,
    ``max_group_size`` and the thresholds named above.

    Args:
        numeric_column_to_label: Data point supporting item access for
            ``"dataset_id"`` (``<table>+<column>``, where the column part ends
            in ``_<number>``), ``"table"`` and ``"semanticType"``.

    Returns:
        The ``label_enc``-encoded semantic type of the best matching labeled
        numeric column, or ``-1`` if the target has no numeric rows, no string
        column passed the test, or no candidate yielded a numeric EMD.

    Side effects:
        Prints progress information. Unless it returned early, writes all
        candidate scores (with mean/std of the target column), sorted by EMD,
        to a CSV file below
        ``$WORKING_DIR/labeling_functions/numerics/joined_n_grouped_EMD/results/<corpus>``.
    """
    target_id = numeric_column_to_label["dataset_id"]
    target_table = numeric_column_to_label["table"]
    print("Numeric Column to label: " + target_id)

    # 1. Already labeled non-numeric columns in the table of the column to label.
    string_already_labeled_cols_in_table = total_labeled_data_df[
        total_labeled_data_df["table"] == target_table]
    string_already_labeled_cols_in_table = string_already_labeled_cols_in_table.loc[
        ~string_already_labeled_cols_in_table["semanticType"].isin(numeric_types)]

    # Load the column to label together with those string columns.
    string_cols_to_load = list(
        string_already_labeled_cols_in_table["dataset_id"].values)
    cols_to_load = [target_id] + string_cols_to_load
    df_cols_to_load = pd.DataFrame({
        "col_num": [int(header.split("+")[1].split("_")[1]) for header in cols_to_load],
        "col_header": cols_to_load,
    }).sort_values(by="col_num")
    df_table_with_n_col_to_label = load_sportsDB_soccer_table(
        target_table,
        usecols=df_cols_to_load["col_num"].values,
        col_headers=df_cols_to_load["col_header"].values)

    # Keep only rows whose target value is numeric. The explicit copy makes the
    # column assignments below operate on an independent frame, not on a slice.
    target_is_numeric = pd.to_numeric(
        df_table_with_n_col_to_label[target_id], errors="coerce").notnull()
    df_table_with_n_col_to_label = df_table_with_n_col_to_label[target_is_numeric].copy()
    df_table_with_n_col_to_label[target_id] = df_table_with_n_col_to_label[target_id].astype(
        float)
    if len(df_table_with_n_col_to_label) == 0:
        return -1
    # Join columns are handed to Spark as strings.
    for string_col in string_cols_to_load:
        df_table_with_n_col_to_label[string_col] = df_table_with_n_col_to_label[string_col].astype(
            str)

    # 2. Which string columns split the target values into differing groups?
    correlated_col_ids = []
    for string_col in string_cols_to_load:
        print(string_col)
        groups = df_table_with_n_col_to_label[string_col].unique()
        unique_ratio = len(groups) / len(df_table_with_n_col_to_label)
        print(len(groups))
        print(len(df_table_with_n_col_to_label))
        print(unique_ratio)
        if unique_ratio > PERCENTAGE_THRESHOLD_UNIQUE_VALUES:
            print("To many groups in the column for kruskal test")
            continue

        kruskal_input_groups = []
        for group in groups:
            kruskal_input = df_table_with_n_col_to_label[
                df_table_with_n_col_to_label[string_col] == group][target_id].dropna()
            if len(kruskal_input) == 0:
                continue
            kruskal_input_groups.append(kruskal_input)
        if len(kruskal_input_groups) < 2:
            print("kruskal input groups smaller than 2")
            continue

        # Best effort: a failing test only excludes this column.
        try:
            F, p = kruskal(*kruskal_input_groups)
            print(f"Kruskal-Wallis test:  \tF:{F}, \tp:{p}")
            if p < P_VALUE_CORRSTEERTION_ANALYZATION:
                correlated_col_ids.append(string_col)
        except Exception as e:
            print(e)

    string_cols_with_corr = string_already_labeled_cols_in_table.loc[
        string_already_labeled_cols_in_table["dataset_id"].isin(correlated_col_ids)]
    if len(string_cols_with_corr) == 0:
        return -1

    # 3. Other tables in which ALL semantic types of the kept columns are labeled.
    required_types = set(string_cols_with_corr["semanticType"])
    joinable_tables = total_labeled_data_df.loc[
        total_labeled_data_df["semanticType"].isin(required_types)].drop_duplicates()
    candidate_rows = joinable_tables[joinable_tables["table"] != target_table]
    # Iterating the groupby directly also works when there are no candidate rows.
    suitable_tables = [
        table_name
        for table_name, table_rows in candidate_rows.groupby("table")
        if required_types.issubset(table_rows["semanticType"])
    ]

    table_view_1 = "numeric_column_to_label"
    table_view_2 = "joinable_table"

    results = []
    for joinable_table in suitable_tables:
        print(joinable_table)
        # The other table needs at least one already labeled numeric column.
        numerics_in_joinable_table = total_labeled_data_df[
            total_labeled_data_df["table"] == joinable_table]
        numerics_in_joinable_table = numerics_in_joinable_table.loc[
            numerics_in_joinable_table["semanticType"].isin(numeric_types)]
        if len(numerics_in_joinable_table) == 0:
            continue

        strings_in_joinable_table = joinable_tables[joinable_tables["table"]
                                                    == joinable_table]
        numeric_ids_in_joinable_table = list(
            numerics_in_joinable_table["dataset_id"].values)
        string_ids_in_joinable_table = list(
            strings_in_joinable_table["dataset_id"].values)

        # Load the labeled numeric and the candidate join columns of the other table.
        cols_to_load_for_joinable_table = numeric_ids_in_joinable_table + \
            string_ids_in_joinable_table
        df_cols_to_load_for_joinable_table = pd.DataFrame({
            "col_num": [int(header.split("+")[1].split("_")[1])
                        for header in cols_to_load_for_joinable_table],
            "col_header": cols_to_load_for_joinable_table,
        }).sort_values(by="col_num")
        df_joinable_table = load_sportsDB_soccer_table(
            joinable_table,
            usecols=df_cols_to_load_for_joinable_table["col_num"].values,
            col_headers=df_cols_to_load_for_joinable_table["col_header"].values)

        # Keep only rows in which every labeled numeric column parses as a number.
        for numeric_col in numeric_ids_in_joinable_table:
            df_joinable_table = df_joinable_table[pd.to_numeric(
                df_joinable_table[numeric_col], errors="coerce").notnull()]
        if len(df_joinable_table) == 0:
            continue
        # Independent copy before assigning columns on the filtered frame.
        df_joinable_table = df_joinable_table.copy()
        for string_col in string_ids_in_joinable_table:
            df_joinable_table[string_col] = df_joinable_table[string_col].astype(
                str)

        # Pair every kept column with the first column of the same semantic
        # type in the other table whose values overlap enough.
        cols_to_join = []
        for _, target_row in string_cols_with_corr.iterrows():
            target_values = set(
                df_table_with_n_col_to_label[target_row["dataset_id"]])
            if not target_values:
                # No values at all: the overlap ratio is undefined, skip the column.
                continue
            same_type_candidates = strings_in_joinable_table[
                strings_in_joinable_table["semanticType"] == target_row["semanticType"]]
            for _, candidate_row in same_type_candidates.iterrows():
                candidate_values = set(
                    df_joinable_table[candidate_row["dataset_id"]])
                percentage_of_overlap = len(
                    target_values & candidate_values) / len(target_values)
                # NOTE(review): literal threshold here, while joined_n_grouped_EMD uses
                # PERCENTAGE_OF_OVERLAP_FOR_JOIN_CANDIDATES - confirm this is intended.
                if percentage_of_overlap >= 0.75:
                    if target_row["dataset_id"] not in [pair[0] for pair in cols_to_join]:
                        cols_to_join.append(
                            [target_row["dataset_id"], candidate_row["dataset_id"]])

        # Use the table only if every kept column found a join partner.
        if len(cols_to_join) != len(string_cols_with_corr):
            continue

        left_join_cols = [left for left, _ in cols_to_join]
        right_join_cols = [right for _, right in cols_to_join]

        # Drop rows with missing join keys (they blow up the join).
        df_table_with_n_col_to_label = df_table_with_n_col_to_label.dropna(
            subset=left_join_cols)
        df_joinable_table = df_joinable_table.dropna(subset=right_join_cols)

        # Statistics of the target column that accompany every result row.
        target_mean = df_table_with_n_col_to_label[target_id].mean()
        target_std = df_table_with_n_col_to_label[target_id].std()

        try:
            # Register both pandas frames as Spark temp views.
            spark.createDataFrame(
                df_table_with_n_col_to_label).createOrReplaceTempView(table_view_1)
            spark.createDataFrame(
                df_joinable_table).createOrReplaceTempView(table_view_2)

            join_condition = " AND ".join(
                f"`{left}` = `{right}`" for left, right in cols_to_join)
            projected_cols = list(set(left_join_cols)) + \
                [target_id] + numeric_ids_in_joinable_table
            projection_list = " , ".join(
                f"`{attr}`" for attr in projected_cols)
            sql_df = spark.sql(
                f"SELECT {projection_list} FROM {table_view_1} JOIN {table_view_2} ON ({join_condition})")
            # Remove joined tuples that contain null values.
            sql_df = sql_df.dropna(
                subset=[f"`{cur_col}`" for cur_col in sql_df.columns])

            for _, numeric_col_in_joinable_table in numerics_in_joinable_table.iterrows():
                labeled_id = numeric_col_in_joinable_table["dataset_id"]
                print("EMD Calc:")
                print(labeled_id)
                # Groups that stay within max_group_size for this labeled column.
                valid_group_size_values = sql_df.groupby(left_join_cols).agg(
                    count(f"`{labeled_id}`").alias("count")
                ).where(col("count") <= max_group_size)
                sql_df.createOrReplaceTempView("sql_df")
                valid_group_size_values.createOrReplaceTempView(
                    "sdf_valid_groups")

                # Preselect only the rows that belong to such a group.
                # NOTE(review): sql_df is overwritten here, so the following
                # labeled columns start from the already filtered rows.
                group_join_condition = " AND ".join(
                    f"sql_df.`{left}` = sdf_valid_groups.`{left}`" for left in left_join_cols)
                group_projection_list = " , ".join(
                    f"sql_df.`{attr}`" for attr in projected_cols)
                sql_df = spark.sql(
                    f"SELECT {group_projection_list} FROM sql_df JOIN sdf_valid_groups ON ({group_join_condition})")
                if valid_group_size_values.count() == 0:
                    print("no valid group size")
                    continue

                # EMD per group, then the average over all sufficiently small groups.
                cur_df = sql_df.groupby(left_join_cols).agg(
                    emd_UDF(collect_list(f"`{target_id}`"),
                            collect_list(f"`{labeled_id}`")).alias("EMD"),
                    count(f"`{labeled_id}`").alias("count"))
                cur_df = cur_df.select(
                    "*").where(col("count") <= max_group_size).groupby().avg("EMD").alias("avg(EMD)")
                grouped_n_joined_emd = cur_df.collect()[0]["avg(EMD)"]
                print(grouped_n_joined_emd)
                results.append([target_id,
                                numeric_column_to_label["semanticType"],
                                labeled_id,
                                numeric_col_in_joinable_table["semanticType"],
                                grouped_n_joined_emd,
                                target_mean,
                                target_std])
        finally:
            # Always release every temp view registered above, also on errors.
            for view_name in (table_view_1, table_view_2, "sql_df", "sdf_valid_groups"):
                spark.catalog.dropTempView(view_name)

    df_results = pd.DataFrame(results, columns=[
        "unlabeled_col", "real_semantic_type", "labeled_col", "semantic_type", "EMD", "mean", "std"])
    # Keep only candidates with a numeric EMD and rank them, best first.
    df_results = df_results[pd.to_numeric(
        df_results['EMD'], errors='coerce').notnull()]
    df_results = df_results.sort_values(by="EMD")

    results_dir = join(os.environ["WORKING_DIR"], "labeling_functions",
                       "numerics", "joined_n_grouped_EMD", "results", corpus)
    # Make sure the (expensive) result can actually be written.
    os.makedirs(results_dir, exist_ok=True)
    df_results.to_csv(join(
        results_dir,
        f"{numeric_column_to_label['dataset_id']}_appr{approach}_{max_group_size}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"), index=False)

    predicted_semantic_type = -1
    if len(df_results) > 0:
        predicted_semantic_type = label_enc.transform(
            [df_results.iloc[0]["semantic_type"]])[0]

    return predicted_semantic_type

@labeling_function()
def normal_EMD(numeric_column_to_label):
    """Predict the semantic type of a numeric column by plain EMD matching.

    The numeric values of the column are compared, via
    ``scipy.stats.wasserstein_distance``, with the values of every already
    labeled numeric column in ``total_labeled_data_df``. The labeled column
    with the smallest distance determines the prediction.

    Relies on module-level names defined elsewhere in the notebook:
    ``total_labeled_data_df``, ``numeric_types``, ``label_enc`` and the run
    configuration that is encoded in the result file name.

    Args:
        numeric_column_to_label: Data point supporting item access for
            ``"dataset_id"`` (``<table>+<column>``, where the column part ends
            in ``_<number>``), ``"table"`` and ``"semanticType"``.

    Returns:
        The ``label_enc``-encoded semantic type of the closest labeled numeric
        column, or ``-1`` if the column holds no numeric values or no labeled
        column could be compared.

    Side effects:
        Prints progress information. Unless the column holds no numeric
        values, writes all candidate distances (with mean/std of the target
        column), sorted by EMD, to a CSV file below
        ``$WORKING_DIR/labeling_functions/numerics/normal_EMD/results``.
    """
    target_id = numeric_column_to_label["dataset_id"]
    target_table = numeric_column_to_label["table"]
    print("Numeric Column to label: " + target_id)

    # Load only the column that has to be labeled.
    cols_to_load = [target_id]
    df_cols_to_load = pd.DataFrame({
        "col_num": [int(header.split("+")[1].split("_")[1]) for header in cols_to_load],
        "col_header": cols_to_load,
    }).sort_values(by="col_num")
    df_table_with_n_col_to_label = load_public_bi_table_by_cols(
        target_table.split("_")[0],
        target_table,
        usecols=df_cols_to_load["col_num"].values,
        col_headers=df_cols_to_load["col_header"].values)

    # Non-numeric entries become NaN and are removed.
    df_table_with_n_col_to_label[target_id] = pd.to_numeric(
        df_table_with_n_col_to_label[target_id], errors="coerce")
    df_table_with_n_col_to_label = df_table_with_n_col_to_label.dropna()
    if len(df_table_with_n_col_to_label) == 0:
        return -1

    # The target side is identical for every comparison, so compute it once.
    target_values = df_table_with_n_col_to_label[target_id].to_list()
    target_mean = df_table_with_n_col_to_label[target_id].mean()
    target_std = df_table_with_n_col_to_label[target_id].std()

    # All already labeled numeric columns in the corpus.
    already_labeled_numeric_cols = total_labeled_data_df.loc[
        total_labeled_data_df["semanticType"].isin(numeric_types)].drop_duplicates()

    results = []
    for _, labeled_row in already_labeled_numeric_cols.iterrows():
        labeled_id = labeled_row["dataset_id"]
        df_table_with_labeled_numeric = load_public_bi_table_by_cols(
            labeled_row["table"].split("_")[0],
            labeled_row["table"],
            usecols=[int(labeled_row["column"].split("_")[1])],
            col_headers=[labeled_id],
            frac=table_frac)

        df_table_with_labeled_numeric[labeled_id] = pd.to_numeric(
            df_table_with_labeled_numeric[labeled_id], errors="coerce")
        df_table_with_labeled_numeric = df_table_with_labeled_numeric.dropna()
        labeled_values = df_table_with_labeled_numeric[labeled_id].to_list()
        if len(labeled_values) == 0:
            # Nothing numeric to compare against.
            continue

        emd = wasserstein_distance(target_values, labeled_values)
        print(f"EMD: {emd}")
        results.append([target_id,
                        numeric_column_to_label["semanticType"],
                        labeled_id,
                        labeled_row["semanticType"],
                        emd,
                        target_mean,
                        target_std])

    df_results = pd.DataFrame(results, columns=[
        "unlabeled_col", "real_semantic_type", "labeled_col", "semantic_type", "EMD", "mean", "std"])
    # Keep only candidates with a numeric EMD and rank them, best first.
    df_results = df_results[pd.to_numeric(
        df_results['EMD'], errors='coerce').notnull()]
    df_results = df_results.sort_values(by="EMD")

    results_dir = join(os.environ["WORKING_DIR"], "labeling_functions",
                       "numerics", "normal_EMD", "results")
    # Make sure the result can actually be written.
    os.makedirs(results_dir, exist_ok=True)
    df_results.to_csv(join(
        results_dir,
        f"{numeric_column_to_label['dataset_id']}_appr{approach}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"), index=False)

    predicted_semantic_type = -1
    if len(df_results) > 0:
        predicted_semantic_type = label_enc.transform(
            [df_results.iloc[0]["semantic_type"]])[0]

    return predicted_semantic_type


@labeling_function()
def joined_n_grouped_EMD_correlation_noAvg(numeric_column_to_label):
    """Predict the semantic type of a numeric column from per-group EMDs (no averaging).

    Steps:
      1. Collect the already labeled non-numeric columns of the same table and
         keep those whose groups differ significantly in the target column
         (Kruskal-Wallis test, p < P_VALUE_CORRSTEERTION_ANALYZATION).
      2. Find other tables that contain labeled columns for every semantic
         type of those correlated columns.
      3. For each such table that also has labeled numeric columns, join it
         with the target table on the matching string columns, group by the
         join keys and compute the EMD between the target column and each
         labeled numeric column inside every group.
      4. Every per-group EMD becomes one result row; the semantic type of the
         row with the smallest EMD is the prediction.

    Args:
        numeric_column_to_label: row-like object providing the keys
            "dataset_id", "table" and "semanticType".

    Returns:
        The label encoded with ``label_enc``, or -1 if no prediction is possible.

    Side effects:
        Prints progress information, writes the sorted result table as CSV
        and creates/drops Spark temp views.
    """
    target_col = numeric_column_to_label["dataset_id"]
    print("Numeric Column to label: " + target_col)

    # Minimum share of the target table's distinct values that must also
    # occur in the candidate column for the pair to be used as a join key.
    # NOTE(review): a later cell defines PERCENTAGE_OF_OVERLAP_FOR_JOIN_CANDIDATES
    # with the same value; it is not referenced here because it may not be
    # defined when this cell runs.
    min_overlap_for_join = 0.75

    # 1. Already labeled non-numeric columns in the table of the target column.
    string_already_labeled_cols_in_table = total_labeled_data_df[
        total_labeled_data_df["table"] == numeric_column_to_label["table"]]
    string_already_labeled_cols_in_table = string_already_labeled_cols_in_table.loc[
        ~string_already_labeled_cols_in_table["semanticType"].isin(numeric_types)]

    # Load the target column together with those string columns.
    string_cols_to_load = list(
        string_already_labeled_cols_in_table["dataset_id"].values)
    cols_to_load = [target_col] + string_cols_to_load
    df_cols_to_load = pd.DataFrame({
        "col_num": [int(col_id.split("+")[1].split("_")[1]) for col_id in cols_to_load],
        "col_header": cols_to_load,
    }).sort_values(by="col_num")
    df_table_with_n_col_to_label = load_sportsDB_soccer_table(
        numeric_column_to_label["table"],
        usecols=df_cols_to_load["col_num"].values,
        col_headers=df_cols_to_load["col_header"].values)

    # Keep only rows whose target value is numeric. The explicit copy makes
    # the column assignments below act on an independent frame instead of a
    # slice (avoids SettingWithCopyWarning).
    is_numeric_row = pd.to_numeric(
        df_table_with_n_col_to_label[target_col], errors="coerce").notnull()
    df_table_with_n_col_to_label = df_table_with_n_col_to_label[is_numeric_row].copy()
    df_table_with_n_col_to_label[target_col] = df_table_with_n_col_to_label[target_col].astype(float)
    if len(df_table_with_n_col_to_label) == 0:
        return -1
    # Spark needs a uniform type per column, so string columns are cast to str.
    for string_col in string_cols_to_load:
        df_table_with_n_col_to_label[string_col] = df_table_with_n_col_to_label[string_col].astype(
            str)

    # Correlation analysis between the target column and each string column.
    correlated_col_ids = []
    for string_col in string_cols_to_load:
        print(string_col)
        groups = df_table_with_n_col_to_label[string_col].unique()
        print(len(groups))
        print(len(df_table_with_n_col_to_label))
        print(len(groups)/len(df_table_with_n_col_to_label))
        if len(groups)/len(df_table_with_n_col_to_label) > PERCENTAGE_THRESHOLD_UNIQUE_VALUES:
            print("To many groups in the column for kruskal test")
            continue
        kruskal_input_groups = []
        for group in groups:
            kruskal_input = df_table_with_n_col_to_label[
                df_table_with_n_col_to_label[string_col] == group][target_col].dropna()
            if len(kruskal_input) == 0:
                continue
            kruskal_input_groups.append(kruskal_input)
        if len(kruskal_input_groups) < 2:
            print("kruskal input groups smaller than 2")
            continue
        try:
            F, p = kruskal(*kruskal_input_groups)
            print(f"Kruskal-Wallis test:  \tF:{F}, \tp:{p}")
            if p < P_VALUE_CORRSTEERTION_ANALYZATION:
                correlated_col_ids.append(string_col)
        except Exception as e:
            # Best effort: a failing test only excludes this string column.
            print(e)

    string_cols_with_corr = string_already_labeled_cols_in_table.loc[
        string_already_labeled_cols_in_table["dataset_id"].isin(correlated_col_ids)]

    if len(string_cols_with_corr) == 0:
        return -1

    # 2. Other tables that have labeled columns for ALL correlated semantic types.
    required_types = set(string_cols_with_corr["semanticType"])
    joinable_tables = total_labeled_data_df.loc[
        total_labeled_data_df["semanticType"].isin(required_types)].drop_duplicates()
    other_tables = joinable_tables[
        joinable_tables["table"] != numeric_column_to_label["table"]]
    # Grouping the single column yields an (possibly empty) Series, so the
    # case "no other table qualifies" simply produces no candidates.
    types_per_table = other_tables.groupby("table")["semanticType"].apply(list)
    candidate_tables = [table_name for table_name, types in types_per_table.items()
                        if required_types.issubset(types)]

    # 3. Check which candidate tables are suitable for the join and group-by.
    results = []
    for joinable_table in candidate_tables:
        print(joinable_table)
        # The candidate is only useful if it has labeled numeric columns.
        numerics_in_joinable_table = total_labeled_data_df[
            total_labeled_data_df["table"] == joinable_table]
        numerics_in_joinable_table = numerics_in_joinable_table.loc[
            numerics_in_joinable_table["semanticType"].isin(numeric_types)]
        if len(numerics_in_joinable_table) == 0:
            continue

        strings_in_joinable_table = joinable_tables[
            joinable_tables["table"] == joinable_table]

        # Load the labeled numeric columns and the join-candidate string columns.
        numeric_col_ids = list(numerics_in_joinable_table["dataset_id"].values)
        string_col_ids = list(strings_in_joinable_table["dataset_id"].values)
        cols_to_load_for_joinable_table = numeric_col_ids + string_col_ids
        df_cols_to_load_for_joinable_table = pd.DataFrame({
            "col_num": [int(col_id.split("+")[1].split("_")[1])
                        for col_id in cols_to_load_for_joinable_table],
            "col_header": cols_to_load_for_joinable_table,
        }).sort_values(by="col_num")
        df_joinable_table = load_sportsDB_soccer_table(
            joinable_table,
            usecols=df_cols_to_load_for_joinable_table["col_num"].values,
            col_headers=df_cols_to_load_for_joinable_table["col_header"].values)

        # Keep only rows in which every labeled numeric column is numeric.
        for numeric_col in numeric_col_ids:
            df_joinable_table = df_joinable_table[pd.to_numeric(
                df_joinable_table[numeric_col], errors="coerce").notnull()]
        if len(df_joinable_table) == 0:
            continue
        # Independent copy before assigning columns (see above).
        df_joinable_table = df_joinable_table.copy()
        for string_col in string_col_ids:
            df_joinable_table[string_col] = df_joinable_table[string_col].astype(
                str)

        # Pair every correlated column with the first same-typed column of the
        # candidate table whose values overlap sufficiently.
        cols_to_join = []
        for _, own_row in string_cols_with_corr.iterrows():
            own_values = set(df_table_with_n_col_to_label[own_row["dataset_id"]])
            same_type_candidates = strings_in_joinable_table[
                strings_in_joinable_table["semanticType"] == own_row["semanticType"]]
            for _, candidate_row in same_type_candidates.iterrows():
                candidate_values = set(df_joinable_table[candidate_row["dataset_id"]])
                percentage_of_overlap = len(own_values & candidate_values) / len(own_values)
                if percentage_of_overlap >= min_overlap_for_join:
                    if own_row["dataset_id"] not in [pair[0] for pair in cols_to_join]:
                        cols_to_join.append(
                            [own_row["dataset_id"], candidate_row["dataset_id"]])

        # Use the table only if every correlated column found a join partner.
        if len(cols_to_join) != len(string_cols_with_corr):
            continue

        join_cols = [own_col for own_col, _ in cols_to_join]
        partner_cols = [partner_col for _, partner_col in cols_to_join]

        # NaN values in join keys lead to high memory usage in the merge.
        df_table_with_n_col_to_label = df_table_with_n_col_to_label.dropna(
            subset=join_cols)
        df_joinable_table = df_joinable_table.dropna(subset=partner_cols)

        target_mean = df_table_with_n_col_to_label[target_col].mean()
        target_std = df_table_with_n_col_to_label[target_col].std()

        # Expose both pandas frames to Spark SQL under fixed view names.
        table_view_1 = "numeric_column_to_label"
        table_view_2 = "joinable_table"
        spark.createDataFrame(
            df_table_with_n_col_to_label).createOrReplaceTempView(table_view_1)
        spark.createDataFrame(
            df_joinable_table).createOrReplaceTempView(table_view_2)

        projected_cols = join_cols + [target_col] + numeric_col_ids
        join_condition = "ON (" + " AND ".join(
            map(lambda join_att: f"`{join_att[0]}` = `{join_att[1]}`", cols_to_join))
        projection_list = " , ".join(f"`{attr}`" for attr in projected_cols)
        sql_df = spark.sql(
            f"SELECT {projection_list} FROM {table_view_1} JOIN {table_view_2} {join_condition})")
        # Filter out tuples that contain null values.
        sql_df = sql_df.dropna(
            subset=[f"`{cur_col}`" for cur_col in sql_df.columns])

        for _, numeric_col_in_joinable_table in numerics_in_joinable_table.iterrows():
            labeled_col = numeric_col_in_joinable_table["dataset_id"]
            print("EMD Calc:")
            print(labeled_col)
            # Groups (by join keys) with at most max_group_size joined rows.
            valid_group_size_values = sql_df.groupby(join_cols).agg(
                count(f"`{labeled_col}`").alias("count")).where(col("count") <= max_group_size)
            sql_df.createOrReplaceTempView("sql_df")
            valid_group_size_values.createOrReplaceTempView("sdf_valid_groups")

            # Restrict the joined rows to those valid groups.
            join_condition = "ON (" + " AND ".join(
                map(lambda join_att: f"sql_df.`{join_att[0]}` = sdf_valid_groups.`{join_att[0]}`", cols_to_join))
            projection_list = " , ".join(
                f"sql_df.`{attr}`" for attr in projected_cols)
            sql_df = spark.sql(
                f"SELECT {projection_list} FROM sql_df JOIN sdf_valid_groups {join_condition})")
            if valid_group_size_values.count() == 0:
                print("no valid group size")
                continue

            # One EMD per group between the target and the labeled numeric column.
            cur_df = sql_df.groupby(join_cols).agg(
                emd_UDF(collect_list(f"`{target_col}`"),
                        collect_list(f"`{labeled_col}`")).alias("EMD"),
                count(f"`{labeled_col}`").alias("count"))
            # NOTE(review): the rows were pre-filtered to count <= max_group_size,
            # so ">=" keeps only groups of exactly max_group_size rows. Confirm
            # whether "<=" was intended; behavior is left unchanged here.
            grouped_n_joined_emds = cur_df.select("EMD").where(
                col("count") >= max_group_size).toPandas()
            # Column order must match the result frame built below.
            grouped_n_joined_emds["unlabeled_col"] = target_col
            grouped_n_joined_emds["real_semantic_type"] = numeric_column_to_label["semanticType"]
            grouped_n_joined_emds["labeled_col"] = labeled_col
            grouped_n_joined_emds["semantic_type"] = numeric_col_in_joinable_table["semanticType"]
            grouped_n_joined_emds["mean"] = target_mean
            grouped_n_joined_emds["std"] = target_std
            results.extend(grouped_n_joined_emds.values.tolist())

        # Drop the views under the names they were actually registered with.
        for view_name in (table_view_1, table_view_2, "sql_df", "sdf_valid_groups"):
            spark.catalog.dropTempView(view_name)

    df_results = pd.DataFrame(results, columns=[
        "EMD", "unlabeled_col", "real_semantic_type", "labeled_col", "semantic_type", "mean", "std"])

    # Discard rows without a usable EMD and rank the rest by distance.
    df_results = df_results[pd.to_numeric(
        df_results['EMD'], errors='coerce').notnull()]
    df_results = df_results.sort_values(by="EMD")
    df_results.to_csv(join(os.environ["WORKING_DIR"], "labeling_functions", "numerics", "joined_n_grouped_EMD", "results", corpus,
                            f"{numeric_column_to_label['dataset_id']}_appr{approach}_{max_group_size}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"), index=False)
    predicted_semantic_type = -1
    if len(df_results) > 0:
        predicted_semantic_type = df_results.iloc[0]["semantic_type"]
        predicted_semantic_type = label_enc.transform(
            [predicted_semantic_type])[0]

    return predicted_semantic_type


if gen_train_data:
    # Columns the labeling function has to predict.
    unlabeled_data_df = valid_header_df.loc[
        valid_header_df["dataset_id"].isin(unlabeled_data_ids)]

    # Columns whose semantic type is already known.
    labeled_data_df = valid_header_df.loc[valid_header_df["dataset_id"].isin(
        labeled_data_ids)]

    # TODO: duplicates are dropped on ("table", "column", "dataset_id") only.
    # The same column should never occur with two different semantic types;
    # if it does, the upstream data has to be fixed.
    total_labeled_data_df = labeled_data_df.drop_duplicates(subset=["table", "column", "dataset_id"])

    # Only unlabeled columns of a numeric type. The copy makes this an
    # independent frame so that the prediction column can be added below
    # without a SettingWithCopyWarning.
    numeric_unlabeled_data_df = unlabeled_data_df.loc[
        unlabeled_data_df["semanticType"].isin(numeric_types)].copy()

    # Select the labeling function for the configured approach. The names are
    # resolved lazily (per branch) because not every LF has to be defined.
    if approach == 4:
        lfs = [joined_n_grouped_EMD]
    elif approach == 1:
        lfs = [normal_EMD]
    elif approach == 5:
        lfs = [joined_n_grouped_EMD_correlation]
    elif approach == 6:
        lfs = [joined_n_grouped_EMD_correlation_noAvg]
    else:
        raise ValueError(
            f"Unsupported approach {approach!r}; expected one of 1, 4, 5, 6")

    # Snorkel applier that runs the LFs row by row over a pandas frame.
    applier = PandasLFApplier(lfs=lfs)

    from multiprocessing.pool import ThreadPool as Pool
    from functools import partial
    import numpy as np
    from tqdm.auto import tqdm

    def parallelize(data, func, num_of_processes=8):
        """Apply ``func`` to ``num_of_processes`` chunks of ``data`` in a thread pool.

        The per-chunk results are concatenated along axis 0 and returned. The
        pool is released even if ``func`` raises.
        """
        data_split = np.array_split(data, num_of_processes)
        with Pool(num_of_processes) as pool:
            return np.concatenate(pool.map(func, data_split), axis=0)

    # Sequential run; parallelize(numeric_unlabeled_data_df, applier.apply, n_worker)
    # is the threaded alternative.
    L_train = applier.apply(df=numeric_unlabeled_data_df)

    # The label matrix has one column per LF; exactly one LF is applied, so
    # the first column holds the scalar prediction of every row.
    predicted_labels = L_train[:, 0]

    print(
        f"Length of labeled data: {len([x for x in predicted_labels if x != -1])}")

    numeric_unlabeled_data_df["predicted_semantic_type"] = [
        label_enc.inverse_transform([x])[0] if x != -1 else "None"
        for x in predicted_labels
    ]

    # Save the LF results.
    if approach > 1:
        numeric_unlabeled_data_df.to_csv(join(
            os.environ["WORKING_DIR"], "labeling_functions",
            "numerics", "joined_n_grouped_EMD", "out", "results",
            f"{corpus}_{approach}_joined_n_grouped_EMD_{threshold_EMD}_{max_group_size}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"
        ),
            index=False)
    elif approach == 1:
        numeric_unlabeled_data_df.to_csv(join(
            os.environ["WORKING_DIR"], "labeling_functions",
            "numerics", "normal_EMD", "out", "results",
            f"{corpus}_normal_EMD_{threshold_EMD}_{max_group_size}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"
        ),
            index=False)


# normal EMD with pruning

In [None]:
import sys
import os
from dotenv import load_dotenv

load_dotenv(override=True)
sys.path.append(os.environ["WORKING_DIR"])
from pyspark.sql.types import StringType, IntegerType, FloatType, DoubleType, StructType, StructField
from pyspark.sql.functions import udf, col, pandas_udf, PandasUDFType, collect_list, count, avg, lit, mean, stddev, monotonically_increasing_id, row_number
from pyspark.sql import SparkSession
from pyspark import SparkConf, SparkContext
import pyspark.pandas as ps
from snorkel.labeling import PandasLFApplier
import configargparse
from snorkel.preprocess import preprocessor
from snorkel.labeling import labeling_function
from os.path import join
import copy
import json
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import LabelEncoder
import pandas as pd
from sklearn.metrics import classification_report
import numpy as np
from data_loader.utils import load_public_bi_table_by_cols
from scipy.stats import wasserstein_distance, kruskal

# Spark UDF that computes the EMD (Wasserstein distance) of two value lists.
@udf(returnType=FloatType())
def emd_UDF(col1, col2):
    """Return the earth mover's distance between two collections of numbers.

    Args:
        col1: values of the first distribution (e.g. from ``collect_list``).
        col2: values of the second distribution.

    Returns:
        The distance as a Python float, or ``None`` (a null in Spark) when
        either input is missing or empty, because ``wasserstein_distance``
        raises on empty distributions and would abort the whole Spark job.
    """
    if not col1 or not col2:
        return None
    return float(wasserstein_distance(col1, col2))


# Spark configuration for a local session. Options that were tried but are
# currently disabled: spark.executor.instances, spark.executor.cores and
# spark.sql.execution.arrow.enabled.
spark_settings = [
    ("spark.executor.memory", "150g"),
    ("spark.driver.memory", "150g"),
    ("spark.memory.offHeap.enabled", "true"),
    ("spark.memory.offHeap.size", "50g"),
]
conf = SparkConf().setAll(spark_settings)
conf.setMaster("local[*]").setAppName("STEER")

# Reuse a running session if there is one, otherwise start a new one.
spark = SparkSession.builder.config(conf=conf).getOrCreate()

# Register the EMD UDF on the session under the name "emd_UDF".
spark.udf.register("emd_UDF", emd_UDF)

# Input/output directories, all located below WORKING_DIR.
labeled_unlabeled_test_split_path = join(
    os.environ["WORKING_DIR"], "data", "extract", "out",
    "labeled_unlabeled_test_split")
valid_headers_path = join(
    os.environ["WORKING_DIR"], "data", "extract", "out", "valid_headers")
gen_train_data_path = join(
    os.environ["WORKING_DIR"], "labeling_functions", "combined_LFs",
    "gen_training_data")

# Semantic types that are treated as numeric.
numeric_types = [
    "X1B", "X2B", "X3B", "TB", "HR", "R", "BB", "AB", "GIDP", "HBP",
    "H", "SF", "SH", "SO", "iBB", "CS", "SB",
    "latitude", "longitude", "year",
]

# Label encoder fitted on the valid types configured via TYPENAME.
with open(join(os.environ["WORKING_DIR"], "data", "extract", "out",
               "valid_types", "types.json")) as f:
    valid_types = json.load(f)[os.environ["TYPENAME"]]
label_enc = LabelEncoder().fit(valid_types)

# Thresholds used by the labeling functions.
PERCENTAGE_OF_OVERLAP_FOR_JOIN_CANDIDATES = 0.75
P_VALUE_CORRSTEERTION_ANALYZATION = 0.05
PERCENTAGE_THRESHOLD_UNIQUE_VALUES = 0.1

# Run configuration (also part of the split and result file names).
labeled_data_size = 5
unlabeled_data_size = "absolute"
test_data_size = 20.0
validation_on = "test"
gen_train_data = True
corpus = "public_bi_num"
absolute_numbers = True
n_worker = 4
threshold_EMD_factor = 0.01
max_group_size = 4
random_state = 2
table_frac = None
approach = 1  # this script only covers the normal EMD approach

if absolute_numbers:
    unlabeled_data_size, labeled_data_size = "absolute", int(labeled_data_size)

# ---------------------------------------------------------------------------
# Load data
# ---------------------------------------------------------------------------

# The split file lists the dataset ids of the labeled, unlabeled and test part.
split_file_name = f"{corpus}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.json"
with open(join(labeled_unlabeled_test_split_path, split_file_name)) as f:
    labeled_unlabeled_test_split_file = json.load(f)

labeled_data_ids = labeled_unlabeled_test_split_file[f"labeled{labeled_data_size}"]
if gen_train_data:
    # With absolute numbers the unlabeled part is stored without a size suffix.
    unlabeled_key = "unlabeled" if absolute_numbers else f"unlabeled{unlabeled_data_size}"
    unlabeled_data_ids = labeled_unlabeled_test_split_file[unlabeled_key]
    print(f"Unlabeled Data: {len(unlabeled_data_ids)}")
# Validating on "unlabeled" uses the plain key, otherwise key + test size.
test_key = validation_on if validation_on == "unlabeled" else f"{validation_on}{test_data_size}"
test_data_ids = labeled_unlabeled_test_split_file[test_key]

print(f"Labeled Data: {len(labeled_data_ids)}")
print(f"Test Data: {len(test_data_ids)}")

# Valid headers carry the real semantic type of every column.
valid_header_file = f"{corpus}_{os.environ['TYPENAME']}.json"
with open(join(valid_headers_path, valid_header_file), "r") as file:
    valid_headers = json.load(file)

# Flatten {table: {column: {...}}} into one row per column; dataset_id is
# "<table>+<column>".
valid_header_df_data = [
    [table, column, table + "+" + column, column_info["semanticType"]]
    for table, columns in valid_headers.items()
    for column, column_info in columns.items()
]
valid_header_df = pd.DataFrame(
    valid_header_df_data,
    columns=["table", "column", "dataset_id", "semanticType"])

#############
# Build LF
#############
@labeling_function()
def normal_EMD(numeric_column_to_label):
    """Label a numeric column with the type of its closest labeled column by EMD.

    The values of the column are compared (``wasserstein_distance``) against
    every labeled numeric column in ``total_labeled_data_df``. All pairwise
    results are written to a CSV file, and the semantic type of the closest
    labeled column is returned unless the match is too weak or ambiguous.

    Reads the notebook-level globals ``total_labeled_data_df``,
    ``numeric_types``, ``table_frac``, ``threshold_EMD_factor``,
    ``gen_train_data``, ``approach``, ``label_enc`` and the split-size
    settings used in the result file name.

    Args:
        numeric_column_to_label: row-like object providing the keys
            ``"table"``, ``"dataset_id"`` and ``"semanticType"``.

    Returns:
        The ``label_enc``-encoded semantic type of the closest labeled column,
        or ``-1`` (abstain) when the column has no numeric values, no labeled
        column could be compared, the smallest EMD is not below
        ``std * threshold_EMD_factor``, or the two closest columns tie on EMD
        while carrying different semantic types.
    """
    target_id = numeric_column_to_label["dataset_id"]
    target_table = numeric_column_to_label["table"]
    print("Numeric Column to label: " + target_id)

    # load the table with the numeric column to label
    cols_to_load = [target_id]
    df_cols_to_load = pd.DataFrame({
        "col_num": [int(col.split("+")[1].split("_")[1]) for col in cols_to_load],
        "col_header": cols_to_load,
    }).sort_values(by="col_num")

    df_table_with_n_col_to_label = load_public_bi_table_by_cols(
        target_table.split("_")[0],
        target_table,
        usecols=df_cols_to_load["col_num"].values,
        col_headers=df_cols_to_load["col_header"].values)

    # non-numeric cells become NaN and are dropped
    df_table_with_n_col_to_label[target_id] = pd.to_numeric(
        df_table_with_n_col_to_label[target_id], errors="coerce")
    df_table_with_n_col_to_label = df_table_with_n_col_to_label.dropna()
    if len(df_table_with_n_col_to_label) == 0:
        return -1

    # these do not change while iterating over the labeled columns
    target_values = df_table_with_n_col_to_label[target_id].to_list()
    target_mean = df_table_with_n_col_to_label[target_id].mean()
    target_std = df_table_with_n_col_to_label[target_id].std()

    # search all already labeled numeric cols in the corpus
    already_labeled_numeric_cols = total_labeled_data_df.loc[
        total_labeled_data_df["semanticType"].isin(numeric_types)].drop_duplicates()

    # iterate over all already labeled numeric cols and measure the EMD
    results = []
    for _, row in already_labeled_numeric_cols.iterrows():
        labeled_id = row["dataset_id"]
        df_table_with_labeled_numeric = load_public_bi_table_by_cols(
            row["table"].split("_")[0],
            row["table"],
            usecols=[int(row["column"].split("_")[1])],
            col_headers=[labeled_id],
            frac=table_frac)

        df_table_with_labeled_numeric[labeled_id] = pd.to_numeric(
            df_table_with_labeled_numeric[labeled_id], errors="coerce")
        df_table_with_labeled_numeric = df_table_with_labeled_numeric.dropna()
        labeled_values = df_table_with_labeled_numeric[labeled_id].to_list()
        if len(labeled_values) == 0:
            continue

        # EMD calc
        emd = wasserstein_distance(target_values, labeled_values)
        print(f"EMD: {emd}")
        results.append([target_id, numeric_column_to_label["semanticType"],
                        labeled_id, row["semanticType"], emd, target_mean, target_std])

    df_results = pd.DataFrame(results, columns=[
        "unlabeled_col", "real_semantic_type", "labeled_col", "semantic_type", "EMD", "mean", "std"])
    # keep only rows with a usable EMD value and rank them closest-first
    df_results = df_results[pd.to_numeric(
        df_results['EMD'], errors='coerce').notnull()]
    df_results = df_results.sort_values(by="EMD")

    # persist the full ranking; the target directory depends on the run mode
    results_file = f"{numeric_column_to_label['dataset_id']}_appr{approach}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"
    if gen_train_data:
        df_results.to_csv(join(os.environ["WORKING_DIR"], "labeling_functions", "numerics", "normal_EMD", "results",
                               results_file), index=False)
    elif gen_train_data == False:
        df_results.to_csv(join(os.environ["WORKING_DIR"], "labeling_functions", "numerics", "normal_EMD", "results_test_data",
                               results_file), index=False)

    if len(df_results) == 0:
        return -1

    best_match = df_results.iloc[0]
    # abstain when even the closest column is too far away
    if best_match["EMD"] >= (best_match["std"] * threshold_EMD_factor):
        return -1

    # Abstain on a tie between two different types. This check now also covers
    # exactly two results; previously that case matched neither the "< 2" nor
    # the "> 2" branch and therefore always abstained.
    if len(df_results) >= 2:
        runner_up = df_results.iloc[1]
        if (best_match["EMD"] == runner_up["EMD"]
                and best_match["semantic_type"] != runner_up["semantic_type"]):
            return -1

    return label_enc.transform([best_match["semantic_type"]])[0]

if gen_train_data:
    # rows of the valid headers that belong to the unlabeled partition
    unlabeled_data_df = valid_header_df[
        valid_header_df["dataset_id"].isin(unlabeled_data_ids)]

    # rows of the valid headers that belong to the labeled partition
    labeled_data_df = valid_header_df[
        valid_header_df["dataset_id"].isin(labeled_data_ids)]

    # labels that were generated earlier by the combined labeling functions
    gen_train_data_file = f"public_bi_gen_training_data_all_combined_maj_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"
    gen_labeled_data_df = pd.read_csv(
        join(gen_train_data_path, gen_train_data_file),
        names=["table", "column", "dataset_id", "semanticType"])
    # generated labels of numeric types from the previous LFs are not reused here
    gen_labeled_data_df = gen_labeled_data_df[
        ~gen_labeled_data_df["semanticType"].isin(numeric_types)]

    # TODO: duplicates are dropped on "table", "column", "dataset_id" only.
    # The same column should never appear twice with different semantic
    # types; this still needs to be fixed upstream.
    total_labeled_data_df = pd.concat(
        [labeled_data_df, gen_labeled_data_df]
    ).drop_duplicates(subset=["table", "column", "dataset_id"])
    #total_labeled_data_df = labeled_data_df.drop_duplicates(subset=["table", "column", "dataset_id"])

    # only the unlabeled columns whose semantic type is numeric
    numeric_unlabeled_data_df = unlabeled_data_df[
        unlabeled_data_df["semanticType"].isin(numeric_types)]
    #numeric_unlabeled_data_df = numeric_unlabeled_data_df[678:]

    # # define LF to apply
    # lfs = [normal_EMD]


    # # snorkel pandas applier for apply lfs to the data
    # applier = PandasLFApplier(lfs=lfs)

    # from multiprocessing import Pool
    # from multiprocessing.pool import ThreadPool as Pool
    # from functools import partial
    # import numpy as np
    # from tqdm.auto import tqdm

    # def parallelize(data, func, num_of_processes=8):
    #     data_split = np.array_split(data, num_of_processes)
    #     pool = Pool(num_of_processes)
    #     #data = pd.concat(pool.map(func, data_split))
    #     data = np.concatenate(pool.map(func, data_split), axis=0)
    #     pool.close()
    #     pool.join()
    #     return data

    # L_train = applier.apply(df=numeric_unlabeled_data_df)
    # #L_train = parallelize(numeric_unlabeled_data_df,
    # #                      applier.apply, n_worker)

    # print(
    #     f"Length of labeled data: {len([x for x in L_train if x != -1])}")

    # numeric_unlabeled_data_df["predicted_semantic_type"] = [
    #     label_enc.inverse_transform([x])[0] if x != -1 else "None"
    #     for x in L_train
    # ]
    # numeric_unlabeled_data_df.to_csv(join(
    #     os.environ["WORKING_DIR"], "labeling_functions",
    #     "numerics", "normal_EMD", "out", "results",
    #     f"{corpus}_normal_EMD_{threshold_EMD_factor}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"
    # ),
    #     index=False)

    # # save gen train data
    # class_reportable_data = numeric_unlabeled_data_df.drop(numeric_unlabeled_data_df[
    #     numeric_unlabeled_data_df["predicted_semantic_type"] == "None"].index)

    # class_reportable_data[[
    #     "table", "column", "dataset_id", "predicted_semantic_type"
    # ]].to_csv(join(
    #     os.environ["WORKING_DIR"], "labeling_functions", "numerics",
    #     "normal_EMD", "out", "gen_train_data",
    #     f"{corpus}_gen_training_data_{threshold_EMD_factor}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"
    # ),
    #             index=False)

    # cls_report = classification_report(
    #     class_reportable_data["semanticType"],
    #     class_reportable_data["predicted_semantic_type"],
    #     output_dict=True)

    # # save classification_report
    # with open(
    #         join(
    #             os.environ["WORKING_DIR"], "labeling_functions", "numerics",
    #             "normal_EMD", "out", "validation",
    #             f"{corpus}_classification_report_unlabeled_{threshold_EMD_factor}_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.json"
    #         ), "w") as f:
    #     json.dump(cls_report, f)

In [None]:
# Collect every labeled column in the corpus whose semantic type is numeric.
already_labeled_numeric_cols = total_labeled_data_df[
    total_labeled_data_df["semanticType"].isin(numeric_types)
].drop_duplicates()

print(f"labeled numerical columns: {len(already_labeled_numeric_cols)}")

In [None]:
# Display the unlabeled columns whose semantic type is one of numeric_types.
numeric_unlabeled_data_df

In [None]:
# Take the first unlabeled numeric column as the example for the cells below.
numeric_column_to_label = numeric_unlabeled_data_df.iloc[0]
numeric_column_to_label

In [None]:
# 1. Find the already labeled non-numeric (string) columns that sit in the
#    same table as the unlabeled numeric column we want to label next.
string_already_labeled_cols_in_table = total_labeled_data_df[
    (total_labeled_data_df["table"] == numeric_column_to_label["table"])
    & ~total_labeled_data_df["semanticType"].isin(numeric_types)
]

string_already_labeled_cols_in_table

In [None]:
# Load the numeric column to label together with the labeled string columns
# of its table.
string_cols_to_load = list(
    string_already_labeled_cols_in_table["dataset_id"].values)
cols_to_load = [numeric_column_to_label["dataset_id"]] + string_cols_to_load

# The integer after the "_" in the column part of each dataset id is handed
# to the loader as `usecols`; rows are ordered by that number.
df_cols_to_load = pd.DataFrame({
    "col_num": [int(col.split("+")[1].split("_")[1]) for col in cols_to_load],
    "col_header": cols_to_load,
}).sort_values(by="col_num")

df_table_with_n_col_to_label = load_public_bi_table_by_cols(
    numeric_column_to_label["table"].split("_")[0],
    numeric_column_to_label["table"],
    usecols=df_cols_to_load["col_num"].values,
    col_headers=df_cols_to_load["col_header"].values,
)

df_table_with_n_col_to_label

In [None]:
# 2. Search tables from the same domain that may hold labeled numeric columns.
#    The domain is detected through the labeled string columns: another table
#    qualifies when it shares at least one labeled string semantic type with
#    the table of the column to label.
joinable_tables = total_labeled_data_df[
    total_labeled_data_df["semanticType"].isin(
        string_already_labeled_cols_in_table["semanticType"])
].drop_duplicates()

# Exclude the column's own table and rank the others by number of shared labels.
joinable_tables_sorted = (
    joinable_tables[joinable_tables["table"] != numeric_column_to_label["table"]]
    .groupby("table")
    .count()
    .sort_values(by=["semanticType"], ascending=False)
)

print(f"tables with same labeled string col: {len(joinable_tables_sorted)}")

In [None]:
# Count the labeled numeric columns across all joinable tables and print the
# name of every table that contributes at least one.
total_number_of_labeled_n_cols_pruning = 0
for joinable_table in joinable_tables_sorted.index:
    # labeled columns of this table that carry a numeric semantic type
    numerics_in_joinable_table = total_labeled_data_df[
        (total_labeled_data_df["table"] == joinable_table)
        & total_labeled_data_df["semanticType"].isin(numeric_types)
    ]
    n_numeric_cols = len(numerics_in_joinable_table)

    total_number_of_labeled_n_cols_pruning += n_numeric_cols
    if n_numeric_cols > 0:
        print(joinable_table)

print(f"Total number of labeled num cols with pruning: {total_number_of_labeled_n_cols_pruning}")

In [None]:
# Find the tables that have labeled numeric columns but no labeled string
# columns, and count the labeled numeric columns they contain.
total_number_of_lab_num_cols_in_total_unlabeled_tables = 0
tables_with_no_strCol_labels = []
for table_name, df in total_labeled_data_df.groupby("table"):
    is_numeric_label = df["semanticType"].isin(numeric_types)
    df_lab_num_col = df[is_numeric_label]
    df_lab_str_col = df[~is_numeric_label]
    # the table qualifies only if it has numeric labels and nothing else
    if len(df_lab_num_col) > 0 and len(df_lab_str_col) == 0:
        total_number_of_lab_num_cols_in_total_unlabeled_tables += len(df_lab_num_col)
        tables_with_no_strCol_labels.append(table_name)

print(f"Total number of labeled num cols in tables with no labeled str cols: {total_number_of_lab_num_cols_in_total_unlabeled_tables}")
print(f"Total number of tables with labeled num cols and no labeled str cols: {len(tables_with_no_strCol_labels)}")
print(tables_with_no_strCol_labels)

In [None]:
# Estimate the saving of the "pruning" approach by counting, for every
# unlabeled numeric column, how many EMD comparisons would be needed.

# mask of labeled columns with a numeric semantic type (fixed for this cell)
numeric_label_mask = total_labeled_data_df["semanticType"].isin(numeric_types)

# baseline: every labeled numeric column in the corpus
total_already_labeled_numeric_cols = total_labeled_data_df[
    numeric_label_mask].drop_duplicates()

print(f"labeled numerical columns: {len(total_already_labeled_numeric_cols)}")


# labeled numeric columns in tables that have no labeled string column at all
total_number_of_lab_num_cols_in_total_unlabeled_tables = 0
for _, table_df in total_labeled_data_df.groupby("table"):
    table_numeric_mask = table_df["semanticType"].isin(numeric_types)
    n_numeric_cols = int(table_numeric_mask.sum())
    # count the table only when all of its labels are numeric
    if n_numeric_cols > 0 and n_numeric_cols == len(table_df):
        total_number_of_lab_num_cols_in_total_unlabeled_tables += n_numeric_cols

print(f"Total number of labeled num cols in tables with no labeled str cols: {total_number_of_lab_num_cols_in_total_unlabeled_tables}")


results = []
for idx, numeric_column_to_label in numeric_unlabeled_data_df.iterrows():
    # 1. labeled string columns in the table of the column to label
    string_already_labeled_cols_in_table = total_labeled_data_df[
        (total_labeled_data_df["table"] == numeric_column_to_label["table"])
        & ~numeric_label_mask
    ]

    # 2. tables from the same domain: they share at least one labeled string
    #    semantic type with the table of the column to label
    joinable_tables = total_labeled_data_df[
        total_labeled_data_df["semanticType"].isin(
            string_already_labeled_cols_in_table["semanticType"])
    ].drop_duplicates()
    joinable_tables_sorted = (
        joinable_tables[joinable_tables["table"] != numeric_column_to_label["table"]]
        .groupby("table")
        .count()
        .sort_values(by=["semanticType"], ascending=False)
    )

    # labeled numeric columns reachable through the joinable tables
    total_number_of_labeled_n_cols_pruning = 0
    for joinable_table in joinable_tables_sorted.index:
        in_joinable_table = total_labeled_data_df["table"] == joinable_table
        total_number_of_labeled_n_cols_pruning += int(
            (in_joinable_table & numeric_label_mask).sum())

    results.append([
        total_number_of_labeled_n_cols_pruning,
        total_number_of_lab_num_cols_in_total_unlabeled_tables,
        len(total_already_labeled_numeric_cols),
    ])

df_results = pd.DataFrame(results, columns=["n_cols_pruning", "n_cols_no_lab_tablecols", "total_lab_num_cols"])

In [None]:
import matplotlib.pyplot as plt

# Compare, per unlabeled numeric column, how many labeled numeric columns
# have to be measured with and without pruning.
fig, ax = plt.subplots(figsize=(20, 10))

pruning_counts = df_results["n_cols_pruning"]
ax.plot(pruning_counts, label="pruning")
ax.plot(pruning_counts + df_results["n_cols_no_lab_tablecols"], label="pruning + unlabeled tables")
ax.plot(df_results["total_lab_num_cols"], label="# labeled num col")
ax.set_xlabel("unlabeled numerical column")
ax.set_ylabel("# numeric columns against that are measured")
ax.legend()
ax.grid()

plt.show()

## implementing pruning

In [None]:
# Take the unlabeled numeric column at position 15 as the example for the pruning cells below.
numeric_column_to_label = numeric_unlabeled_data_df.iloc[15]
numeric_column_to_label

In [None]:
# 1. Labeled string (non-numeric) columns located in the same table as the
#    unlabeled numeric column that is labeled next.
in_same_table = total_labeled_data_df["table"] == numeric_column_to_label["table"]
has_numeric_type = total_labeled_data_df["semanticType"].isin(numeric_types)
string_already_labeled_cols_in_table = total_labeled_data_df[in_same_table & ~has_numeric_type]

string_already_labeled_cols_in_table

In [None]:
# Load the table of the numeric column to label, restricted to that column
# and the labeled string columns of the same table.
string_cols_to_load = list(
    string_already_labeled_cols_in_table["dataset_id"].values)
cols_to_load = [numeric_column_to_label["dataset_id"]] + string_cols_to_load

# `usecols` receives the integer that follows the "_" in the column part of
# each dataset id; rows are sorted by it.
col_numbers = [int(col.split("+")[1].split("_")[1]) for col in cols_to_load]
df_cols_to_load = pd.DataFrame(
    {"col_num": col_numbers, "col_header": cols_to_load}
).sort_values(by="col_num")

df_table_with_n_col_to_label = load_public_bi_table_by_cols(
    numeric_column_to_label["table"].split("_")[0],
    numeric_column_to_label["table"],
    usecols=df_cols_to_load["col_num"].values,
    col_headers=df_cols_to_load["col_header"].values,
)

df_table_with_n_col_to_label.head(5)

In [None]:
# 2. Search tables from the same domain as the column to label. A table
#    counts as same-domain when it has at least one labeled string column
#    whose semantic type also occurs among the labeled string columns of the
#    column's own table.
shares_string_type = total_labeled_data_df["semanticType"].isin(
    string_already_labeled_cols_in_table["semanticType"])
joinable_tables = total_labeled_data_df[shares_string_type].drop_duplicates()

# Leave out the column's own table; tables with the most shared labels come first.
other_tables = joinable_tables[joinable_tables["table"] != numeric_column_to_label["table"]]
joinable_tables_sorted = other_tables.groupby("table").count().sort_values(
    by=["semanticType"], ascending=False)

joinable_tables_sorted.head(3)

In [None]:
table_frac = None
approach = 1
pruning_mode = 1


def _emd_rows_for_table(joinable_table, unlabeled_col):
    """Measure the EMD between the unlabeled column and every labeled numeric column of one table.

    Reads the notebook globals ``total_labeled_data_df``, ``numeric_types``,
    ``table_frac``, ``numeric_column_to_label`` and
    ``df_table_with_n_col_to_label``.

    Args:
        joinable_table: name of the table holding the labeled numeric columns.
        unlabeled_col: list of non-null values of the column to label.

    Returns:
        A list of result rows ``[unlabeled id, real type, labeled id,
        labeled type, EMD, mean, std]``; empty when the table has no labeled
        numeric column or all of them are empty after dropping nulls.
    """
    # first check if there is a column in the table which is already labeled with a numeric type
    numerics_in_joinable_table = total_labeled_data_df[
        total_labeled_data_df["table"] == joinable_table]
    numerics_in_joinable_table = numerics_in_joinable_table.loc[
        numerics_in_joinable_table["semanticType"].isin(numeric_types)]
    if len(numerics_in_joinable_table) == 0:
        return []
    print(numerics_in_joinable_table)

    # load the table, only with the numeric cols to measure emd against
    cols_to_load_for_joinable_table = list(
        numerics_in_joinable_table["dataset_id"].values)
    df_cols_to_load_for_joinable_table = pd.DataFrame({
        "col_num": [int(col.split("+")[1].split("_")[1])
                    for col in cols_to_load_for_joinable_table],
        "col_header": cols_to_load_for_joinable_table,
    }).sort_values(by="col_num")
    df_joinable_table = load_public_bi_table_by_cols(
        joinable_table.split("_")[0],
        joinable_table,
        usecols=df_cols_to_load_for_joinable_table["col_num"].values,
        col_headers=df_cols_to_load_for_joinable_table["col_header"].values,
        frac=table_frac)

    print(df_joinable_table.head(3))

    target_id = numeric_column_to_label["dataset_id"]
    rows = []
    for _, numeric_col_in_joinable_table in numerics_in_joinable_table.iterrows():
        print("EMD Calc:")
        print(numeric_col_in_joinable_table)
        labeled_col = df_joinable_table[
            numeric_col_in_joinable_table["dataset_id"]].dropna().to_list()
        print(len(labeled_col))
        if len(labeled_col) == 0:
            continue

        # EMD calc
        emd = wasserstein_distance(unlabeled_col, labeled_col)
        print(f"EMD: {emd}")

        rows.append([target_id,
                     numeric_column_to_label["semanticType"],
                     numeric_col_in_joinable_table["dataset_id"],
                     numeric_col_in_joinable_table["semanticType"],
                     emd,
                     df_table_with_n_col_to_label[target_id].mean(),
                     df_table_with_n_col_to_label[target_id].std()])
    return rows


# 3. Iterate over all found tables + the tables with no string labels inside.
results = []
# the values of the column to label are the same for every comparison
unlabeled_col = df_table_with_n_col_to_label[
    numeric_column_to_label["dataset_id"]].dropna().to_list()

for joinable_table, row in joinable_tables_sorted.iterrows():
    # NOTE(review): looks like a debugging filter. `break` ends the whole scan
    # at the first table that is not "MLB_4", so nothing is measured unless
    # "MLB_4" is ranked first. Confirm whether `continue` or no filter at all
    # is intended.
    if joinable_table != "MLB_4":
        break
    results.extend(_emd_rows_for_table(joinable_table, unlabeled_col))

# pruning mode 1 additionally measures against tables that only have numeric labels
if pruning_mode == 1:
    for joinable_table in tables_with_no_strCol_labels:
        results.extend(_emd_rows_for_table(joinable_table, unlabeled_col))

        


# Collect all EMD measurements, keep the rows with a usable EMD value and
# rank them closest-first.
df_results = pd.DataFrame(results, columns=[
                            "unlabeled_col", "real_semantic_type", "labeled_col", "semantic_type", "EMD", "mean", "std"])  # .sort_values(by="EMD")
df_results = df_results[pd.to_numeric(
    df_results['EMD'], errors='coerce').notnull()]
df_results = df_results.sort_values(by="EMD")

# Persist the ranking; the target directory depends on the run mode.
if gen_train_data:
    df_results.to_csv(join(os.environ["WORKING_DIR"], "labeling_functions", "numerics", "normal_EMD", "results",
                        f"{numeric_column_to_label['dataset_id']}_appr{approach}_pruning_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"), index=False)
elif gen_train_data == False:
    df_results.to_csv(join(os.environ["WORKING_DIR"], "labeling_functions", "numerics", "normal_EMD", "results_test_data",
                        f"{numeric_column_to_label['dataset_id']}_appr{approach}_pruning_{table_frac}_{labeled_data_size}_{unlabeled_data_size}_{test_data_size}_{random_state}.csv"), index=False)

# Decide on the prediction; -1 means "abstain".
# This cell runs at top level, so the abstain cases are expressed as flags:
# a `return` statement here is a SyntaxError outside of a function.
predicted_semantic_type = -1
if len(df_results) > 0:
    best_match = df_results.iloc[0]
    # abstain when even the closest column is too far away
    too_far = best_match["EMD"] >= (best_match["std"] * threshold_EMD_factor)
    # Abstain when the two closest columns tie on EMD but disagree on the
    # type. `>= 2` also covers exactly two results, which previously matched
    # neither the "< 2" nor the "> 2" branch.
    ambiguous = False
    if len(df_results) >= 2:
        runner_up = df_results.iloc[1]
        ambiguous = (best_match["EMD"] == runner_up["EMD"]
                     and best_match["semantic_type"] != runner_up["semantic_type"])
    if not too_far and not ambiguous:
        predicted_semantic_type = label_enc.transform(
            [best_match["semantic_type"]])[0]

print(predicted_semantic_type)

In [None]:
# Display the EMD results of the pruning run, sorted by ascending EMD.
df_results