In [67]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [136]:
import numpy as np
import polars as pl
from sweep.experiment import Experiment, ExperimentType
from model.data import DataModel, KFeaturesDefinition, k_features_factory
from erm.problems.problems import ProblemType
import subprocess
from experiments.data_loading import read_polars_dataframe

Define a datamodel

In [69]:
d = 1000

In [70]:
x_diagonal = KFeaturesDefinition(diagonal=[(0.5,d)])
θ_diagonal = KFeaturesDefinition(diagonal=[(2,d)])
ω_diagonal = KFeaturesDefinition(diagonal=[(1,d)])
δ_diagonal = KFeaturesDefinition(diagonal=[(1,d)])
ν_diagonal = KFeaturesDefinition(diagonal=[(1,d)])

low_low_kwargs = {
    "x_diagonal": x_diagonal,
    "θ_diagonal": θ_diagonal,
    "ω_diagonal": ω_diagonal,
    "δ_diagonal": δ_diagonal,
    "ν_diagonal": ν_diagonal,
}

low_robustness_low_usefulness = DataModel(
    d,
    normalize_matrices=False,
    data_model_factory=k_features_factory,
    factory_kwargs=low_low_kwargs,
    name="low_robustness_low_usefulness"
)


In [71]:
x_diagonal = KFeaturesDefinition(diagonal=[(0.5,d)])
θ_diagonal = KFeaturesDefinition(diagonal=[(8,d)])
ω_diagonal = KFeaturesDefinition(diagonal=[(1,d)])
δ_diagonal = KFeaturesDefinition(diagonal=[(1,d)])
ν_diagonal = KFeaturesDefinition(diagonal=[(1,d)])

low_high_kwargs = {
    "x_diagonal": x_diagonal,
    "θ_diagonal": θ_diagonal,
    "ω_diagonal": ω_diagonal,
    "δ_diagonal": δ_diagonal,
    "ν_diagonal": ν_diagonal,
}

low_robustness_high_usefulness = DataModel(
    d,
    normalize_matrices=False,
    data_model_factory=k_features_factory,
    factory_kwargs=low_high_kwargs,
    name="low_robustness_high_usefulness"
)


In [72]:
x_diagonal = KFeaturesDefinition(diagonal=[(2,d)])
θ_diagonal = KFeaturesDefinition(diagonal=[(0.5,d)])
ω_diagonal = KFeaturesDefinition(diagonal=[(1,d)])
δ_diagonal = KFeaturesDefinition(diagonal=[(1,d)])
ν_diagonal = KFeaturesDefinition(diagonal=[(1,d)])

high_low_kwargs = {
    "x_diagonal": x_diagonal,
    "θ_diagonal": θ_diagonal,
    "ω_diagonal": ω_diagonal,
    "δ_diagonal": δ_diagonal,
    "ν_diagonal": ν_diagonal,
}

high_robustness_low_usefulness = DataModel(
    d,
    normalize_matrices=False,
    data_model_factory=k_features_factory,
    factory_kwargs=high_low_kwargs,
    name="high_robustness_low_usefulness"
)


In [73]:
x_diagonal = KFeaturesDefinition(diagonal=[(2,d)])
θ_diagonal = KFeaturesDefinition(diagonal=[(2,d)])
ω_diagonal = KFeaturesDefinition(diagonal=[(1,d)])
δ_diagonal = KFeaturesDefinition(diagonal=[(1,d)])
ν_diagonal = KFeaturesDefinition(diagonal=[(1,d)])

high_high_kwargs = {
    "x_diagonal": x_diagonal,
    "θ_diagonal": θ_diagonal,
    "ω_diagonal": ω_diagonal,
    "δ_diagonal": δ_diagonal,
    "ν_diagonal": ν_diagonal,
}

high_robustness_high_usefulness = DataModel(
    d,
    normalize_matrices=False,
    data_model_factory=k_features_factory,
    factory_kwargs=high_high_kwargs,
    name="high_robustness_high_usefulness"
)


Define an experiment

In [74]:
experiment = Experiment(
    state_evolution_repetitions=1,
    erm_repetitions=1,
    alphas=np.logspace(-0.2,2,10),
    epsilons=np.array([0.0, 0.1, 0.2, 0.3]),
    lambdas=np.array([0.001]),
    taus=np.array([0.05]),
    d=d,
    experiment_type=ExperimentType.Sweep,
    data_models = [low_robustness_low_usefulness, low_robustness_high_usefulness, high_robustness_low_usefulness, high_robustness_high_usefulness],
    test_against_epsilons=np.array([0.2]),
    erm_problem_type=ProblemType.Logistic,
    gamma_fair_error=0.01,
    name="feature_combinations"
)
experiment_json = experiment.to_json()

In [75]:
with open("feature_combinations.json", "w") as f:
    f.write(experiment_json)

In [76]:
# venv_python = ".venv/bin/python"
# command = ["mpiexec", "-n", "5", venv_python, "sweep/run_sweep.py", "--json", experiment_json, "--log-level", "INFO"]
# subprocess.run(command)

In [196]:
df = read_polars_dataframe(experiment.name)

  df = df.rename({"epsilon_g_0":"epsilon_g"})


In [197]:
df.shape

(160, 148)

In [198]:
df

alpha,epsilon,tau,lam,epsilon_g,data_model_name,id,date,task_type,erm_problem_type,test_against_epsilons,d,values,gamma_fair_error,gamma,generalization_error,adversarial_generalization_errors,training_error,training_loss,test_losses,m,q,sigma,A,P,F,m_hat,q_hat,sigma_hat,A_hat,F_hat,P_hat,n_m,n_q,n_sigma,n_A,n_P,…,adversarial_generalization_errors_teacher,adversarial_generalization_errors_overlap,fair_adversarial_errors,training_error_erm,boundary_loss_train,boundary_loss_test_es,training_loss_erm,test_losses_erm,duration_erm,id_std_erm,date_std_erm,task_type_std_erm,erm_problem_type_std_erm,test_against_epsilons_std_erm,d_std_erm,values_std_erm,gamma_fair_error_std_erm,gamma_std_erm,ρ_std,m_std_erm,F_std_erm,Q_std,A_std_erm,P_std_erm,angle_std_erm,generalization_error_erm_std,generalization_error_overlap_std,adversarial_generalization_errors_std_erm,adversarial_generalization_errors_teacher_std,adversarial_generalization_errors_overlap_std,fair_adversarial_errors_std,training_error_std_erm,boundary_loss_train_std,boundary_loss_test_es_std,training_loss_std_erm,test_losses_std_erm,duration_std_erm
f64,f64,f64,f64,f64,str,str,str,str,str,list[f64],f64,struct[0],f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,str,str,str,list[f64],f64,struct[0],f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
0.630957,0.0,0.05,0.001,0.2,"""low_robustness_high_usefulness""",,,,,,1000.0,{},0.01,1.0,0.30329,0.423484,0.0,0.011026,4.362861,4.427787,14.59142,291.863443,29.18284,29.18284,8.855575,0.003793,0.000114,0.001426,0.0,0.0,0.0,4.427787,14.59142,291.863443,29.18284,29.18284,…,0.1295,0.595167,0.4159,0.0,0.0,18.46232,0.010394,1.173841,18.655912,,,,,,,{},,,,,,,,,,,,,,,,,,,,,
100.0,0.0,0.05,0.001,0.2,"""low_robustness_high_usefulness""",,,,,,1000.0,{},0.01,1.0,0.010747,0.222703,0.005813,0.022581,351.075538,61.262941,938.770702,0.917681,1877.541404,1877.541404,122.525881,16.689601,0.574405,1.087703,0.0,0.0,0.0,61.262941,938.770702,0.917681,1877.541404,1877.541404,…,0.23,0.570608,0.1025,0.00573,0.0,2094.717195,0.022662,1.051357,59.017711,,,,,,,{},,,,,,,,,,,,,,,,,,,,,
5.994843,0.0,0.05,0.001,0.2,"""high_robustness_low_usefulness""",,,,,,1000.0,{},0.01,1.0,0.076062,0.144942,1.1302e-29,0.008788,18.103042,17.194694,312.422403,75.169449,156.211201,156.211201,8.597347,0.228746,0.002967,0.012803,0.0,0.0,0.0,17.194694,312.422403,75.169449,156.211201,156.211201,…,0.1117,-0.431544,0.1298,0.0,0.0,340.524849,0.008571,0.525551,49.044804,,,,,,,{},,,,,,,,,,,,,,,,,,,,,
32.442261,0.2,0.05,0.001,0.2,"""high_robustness_low_usefulness""",,,,,,1000.0,{},0.01,1.0,0.028747,0.114271,0.014585,4.314174,4.38137,10.504156,110.964747,0.565413,55.482374,55.482374,5.252078,18.577849,1.962684,1.71634,0.0,0.0,0.103558,10.504156,110.964747,0.565413,55.482374,55.482374,…,0.1077,-0.441666,0.0661,0.014703,406.354458,125.488195,0.184422,0.203569,30.646932,,,,,,,{},,,,,,,,,,,,,,,,,,,,,
5.994843,0.3,0.05,0.001,0.2,"""high_robustness_high_usefulnes…",,,,,,1000.0,{},0.01,1.0,0.073855,0.143106,0.004528,11.532701,6.177271,22.948687,138.924017,4.642131,69.462009,69.462009,11.474343,1.235892,0.337061,0.202323,0.0,0.0,0.02519,22.948687,138.924017,4.642131,69.462009,69.462009,…,0.1135,0.545327,0.1257,0.00317,154.948698,159.249107,0.179152,0.392646,35.183068,,,,,,,{},,,,,,,,,,,,,,,,,,,,,
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
56.958108,0.0,0.05,0.001,0.2,"""low_robustness_high_usefulness""",,,,,,1000.0,{},0.01,1.0,0.013246,0.222703,0.004762,0.024205,247.825375,51.786245,671.196611,1.50442,1342.393223,1342.393223,103.572489,8.605685,0.328208,0.662708,0.0,0.0,0.0,51.786245,671.196611,1.50442,1342.393223,1342.393223,…,0.2226,0.580978,0.1126,0.00467,0.0,1509.104774,0.023997,0.895668,41.887833,,,,,,,{},,,,,,,,,,,,,,,,,,,,,
0.630957,0.2,0.05,0.001,0.2,"""low_robustness_low_usefulness""",,,,,,1000.0,{},0.01,1.0,0.304297,0.424433,0.0,2.646676,6.344104,2.625037,20.659576,213.274232,41.319152,41.319152,5.250074,0.012308,0.000303,0.001997,0.0,0.0,0.000346,2.625037,20.659576,213.274232,41.319152,41.319152,…,0.1288,0.606269,0.4188,0.0,4.782306,26.601356,0.018156,1.409722,28.110478,,,,,,,{},,,,,,,,,,,,,,,,,,,,,
3.414549,0.2,0.05,0.001,0.2,"""high_robustness_low_usefulness""",,,,,,1000.0,{},0.01,1.0,0.123782,0.187667,3.6641e-38,17.653053,19.328098,16.572011,319.937972,93.104672,159.968986,159.968986,8.286006,0.177993,0.005227,0.009755,0.0,0.0,0.000971,16.572011,319.937972,93.104672,159.968986,159.968986,…,0.1054,0.547274,0.1751,0.0,128.542665,334.57136,0.01857,0.974807,44.219075,,,,,,,{},,,,,,,,,,,,,,,,,,,,,
10.525003,0.2,0.05,0.001,0.2,"""low_robustness_low_usefulness""",,,,,,1000.0,{},0.01,1.0,0.060315,0.227148,0.01995,7.601932,7.801971,5.389877,30.044196,0.93741,60.088392,60.088392,10.779755,5.749756,1.130507,0.948861,0.0,0.0,0.057954,5.389877,30.044196,0.93741,60.088392,60.088392,…,0.2159,-0.378445,0.1653,0.018527,70.649461,64.545011,0.325801,0.427423,35.28939,,,,,,,{},,,,,,,,,,,,,,,,,,,,,
