In [None]:
from test_fair_clustering import main
import argparse
import os
import csv
import numpy as np

In [None]:
# Results are written to <cwd>/OUTPUT_FOLDER/CSV_NAME.
OUTPUT_FOLDER = "outputs"
CSV_NAME = "results.csv"

# `lmbda` value handed to `main` for every dataset / cluster_option pair.
# The "ncut" option is currently disabled; its previous values are kept in the
# trailing comments for reference.
lambdas = {
    "Synthetic": {"kmedian": 10, "kmeans": 10},                # ncut: 10
    "Synthetic-unequal": {"kmedian": 10, "kmeans": 10},        # ncut: 10
    "Adult": {"kmedian": 9000, "kmeans": 9000},                # ncut: 10
    "Bank": {"kmedian": 9000, "kmeans": 6000},                 # ncut: 40
    "CensusII": {"kmedian": 500000, "kmeans": 500000},         # ncut: 100
}

# Target number of recorded runs for each setting of a given dataset.
n_runs = dict([
    ("Synthetic", 30),
    ("Synthetic-unequal", 30),
    ("Adult", 20),
    ("Bank", 20),
    ("CensusII", 3),
])

In [None]:
def get_args(seed=1, dataset="Synthetic-unequal", cluster_option="ncut", lmbda=10):
    """Assemble the argument namespace passed to ``main`` for one run.

    The plotting and tuning switches are fixed here (only
    ``plot_option_clusters_vs_lambda`` is enabled); ``seed``, ``dataset``,
    ``cluster_option`` and ``lmbda`` are copied from the parameters.
    ``data_dir`` and ``output_path`` are resolved against the current working
    directory.

    Returns:
        argparse.Namespace: the populated settings object.
    """
    cwd = os.getcwd()
    return argparse.Namespace(
        # Fixed switches.
        plot_option_clusters_vs_lambda=True,
        plot_option_fairness_vs_clusterE=False,
        plot_option_balance_vs_clusterE=False,
        plot_option_convergence=False,
        lmbda_tune=False,
        # Per-run settings.
        seed=seed,
        dataset=dataset,
        cluster_option=cluster_option,
        lmbda=lmbda,
        # Locations relative to where the notebook is executed.
        data_dir=os.path.join(cwd, "data"),
        output_path=os.path.join(cwd, OUTPUT_FOLDER),
    )

def make_csv(dir_path, csv_path, fieldnames):
    """Make sure ``csv_path`` exists and begins with a header row.

    Creates ``dir_path`` if needed. If ``csv_path`` already exists and contains
    at least one row, it is left untouched; otherwise the file is (re)written
    with just the header built from ``fieldnames``.

    Args:
        dir_path: directory that should contain the CSV file.
        csv_path: full path of the CSV file.
        fieldnames: column names for the header row.
    """
    os.makedirs(dir_path, exist_ok=True)
    if os.path.isfile(csv_path):
        # newline='' is what the csv module expects for files it reads/writes.
        with open(csv_path, "r", newline='') as f:
            # Only the first row is needed to know the file is non-empty;
            # avoid loading the whole results file.
            if next(csv.reader(f), None) is not None:
                return

    with open(csv_path, "w", newline='') as f:
        csv.DictWriter(f, fieldnames).writeheader()

def run_main(args, csv_name=CSV_NAME):
    """Execute one run of ``main`` and append its summary as a CSV row.

    Args:
        args: settings namespace; must provide ``dataset``, ``lmbda``,
            ``cluster_option``, ``seed``, ``lmbda_tune`` and ``output_path``.
        csv_name: file name of the results CSV inside ``args.output_path``.
    """
    outcome = main(args, logging=False, seedable=True)

    # Key order defines the column order of the CSV.
    row = {
        "dataset": args.dataset,
        "N": outcome['N'],
        "J": outcome['J'],
        "lmbda": args.lmbda,
        "Objective": outcome["clustering energy (Objective)"],
        "fairness error": outcome["fairness error"],
        "balance": outcome["balance"],
        "cluster_option": args.cluster_option,
        "time": outcome["time"],
        "seed": args.seed,
        "lmbda_tune": args.lmbda_tune,
        "K": outcome['K'],
    }
    columns = row.keys()

    target = os.path.join(args.output_path, csv_name)
    # Creates the folder/file with a header when it does not exist yet.
    make_csv(args.output_path, target, columns)
    with open(target, "a", newline='') as handle:
        csv.DictWriter(handle, columns).writerow(row)
    

In [None]:
def compare_entry(args, row):
    """Return True when ``row`` records the same settings as ``args``.

    Only ``dataset``, ``lmbda``, ``cluster_option`` and ``lmbda_tune`` are
    compared. Each attribute of ``args`` is converted with ``str`` before the
    comparison, so ``row`` is expected to hold string values (as rows read
    back from a CSV file do).
    """
    compared = ("dataset", "lmbda", "cluster_option", "lmbda_tune")
    return all(str(getattr(args, name)) == row[name] for name in compared)

def find_same_options(csv_name, args):
    """Collect the recorded rows whose settings match ``args``.

    Args:
        csv_name: file name of the results CSV inside ``args.output_path``.
        args: settings namespace; rows are matched with ``compare_entry``.

    Returns:
        list: the matching rows as produced by ``csv.DictReader``, in file
        order. Empty when the results file does not exist yet.
    """
    csv_path = os.path.join(args.output_path, csv_name)
    # Before the very first run there is no results file; treat that as
    # "no previous results" instead of failing with FileNotFoundError.
    if not os.path.isfile(csv_path):
        return []
    with open(csv_path, "r", newline='') as f:
        return [row for row in csv.DictReader(f) if compare_entry(args, row)]

In [None]:
# Top up every dataset / cluster_option setting until the results CSV holds
# n_runs[dataset] rows for it.
for dataset in lambdas:
    for cluster_option, lmbda in lambdas[dataset].items():
        args = get_args(dataset=dataset, cluster_option=cluster_option, lmbda=lmbda)
        existing_entries = find_same_options(CSV_NAME, args)

        # Number of runs still missing for this setting.
        n = n_runs[dataset] - len(existing_entries)
        if n >= 1:
            # Continue after the highest seed already recorded; the extra 0
            # makes the first seed 1 when nothing has been recorded yet.
            seeds = [int(entry["seed"]) for entry in existing_entries] + [0]
            next_seed = max(seeds) + 1
            for seed in range(next_seed, next_seed + n):
                args.seed = seed
                print()
                run_main(args, CSV_NAME)
        else:
            print("enough results for these settings")

In [None]:
# Fetch results: print mean and standard deviation of the recorded metrics
# for every dataset / cluster_option setting.
for dataset in lambdas:
    print(f"\n\n{dataset}")
    for cluster_option, lmbda in lambdas[dataset].items():
        print("\n"+cluster_option.upper())

        args = get_args(dataset=dataset, cluster_option=cluster_option, lmbda=lmbda)
        existing_entries = find_same_options(CSV_NAME, args)

        if not existing_entries:
            print("no data yet")
            continue

        # Only needed by the (currently disabled) descriptive label below.
        entry = existing_entries[0]
        # name = f"{dataset} (N = {entry['N']}, J = {entry['J']}, lmbda = {lmbda})"

        keys = ["Objective", "fairness error", "balance"]
        for key in keys:
            # CSV values are strings; convert before aggregating.
            data = [float(row[key]) for row in existing_entries]
            mean, std = np.mean(data), np.std(data)
            # NOTE(review): np.std uses ddof=0 by default (population SD) -
            # confirm this is the intended statistic.

            # Metric name is left-aligned in a 20-character column.
            print(f"{key:<20} M = {mean:.2f}     SD = {std:.2f}")

In [None]:
# run_main(get_args())