/
main_fullshot.py
117 lines (92 loc) · 3.92 KB
/
main_fullshot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Entry point to compute the loss decomposition for different models.
This should be called by `python main_fullshot.py <conf>` where <conf> sets all configs from the cli, see
the file `config/main.yaml` for details about the configs. or use `python main_fullshot.py -h`.
"""
from __future__ import annotations
# Best-effort speedup: Intel's sklearnex patches sklearn's LogisticRegression
# with an accelerated implementation. Must run BEFORE any `import sklearn`.
# Skip silently when the extension is unavailable or patching fails.
try:
    from sklearnex import patch_sklearn

    patch_sklearn(["LogisticRegression"])
except Exception:  # narrow from bare `except:` so Ctrl-C / SystemExit still propagate
    pass
import logging
import traceback
import os
import sys
import pandas as pd
import hydra
from utils.cluster import nlp_cluster
from utils.helpers import LightningWrapper
import hubconf
from utils.tune_hyperparam import tune_hyperparam_
from main_fewshot import begin, instantiate_datamodule_, run_component_, save_results
try:
import wandb
except ImportError:
pass
logger = logging.getLogger(__name__)
RESULTS_FILE = "results_{component}.csv"
LAST_CHECKPOINT = "last.ckpt"
FILE_END = "end.txt"
@hydra.main(config_name="main", config_path="config")
def main_except(cfg):
    """Hydra entry point: run `main`, inside the NLP-cluster context when requested."""
    if not cfg.is_nlp_cluster:
        main(cfg)
        return
    with nlp_cluster(cfg):
        main(cfg)
def main(cfg):
    """Compute the full-shot loss decomposition for the model selected by `cfg`.

    Pipeline: startup -> embed the data with a pretrained representor ->
    for each decomposition component, tune hyperparameters then fit/evaluate
    the downstream predictor -> save all component results as one CSV.

    Args:
        cfg: Hydra config (see `config/main.yaml`). Fields read here:
            `representor`, `predictor.is_tune_hyperparam`, `is_supervised`,
            `is_alternative_decomposition`, `is_nlp_cluster` (in caller),
            and `paths.tuning` / `paths.results`.
    """
    # log the hostname so cluster runs are traceable to a node
    logger.info(os.uname().nodename)

    ############## STARTUP ##############
    logger.info("Stage : Startup")
    begin(cfg)

    ############## REPRESENT DATA ##############
    logger.info(f"Representing data with {cfg.representor}")
    # hubconf exposes each pretrained representor as a module-level callable
    # returning (model, preprocess); wrap the model for Lightning-based inference.
    representor, preprocess = hubconf.__dict__[cfg.representor]()
    representor = LightningWrapper(representor)
    datamodule = instantiate_datamodule_(cfg, representor, preprocess)

    ############## DOWNSTREAM PREDICTOR ##############
    results = dict()

    # this script only supports the tuned-hyperparameter path
    assert cfg.predictor.is_tune_hyperparam

    # those components can have the same hyperparameters
    # NOTE(review): the "train_test-cmplmnt-0.1" entry has no `label_size`, so the
    # `.format(**...)` below would raise KeyError if that component were enabled
    # (it is currently commented out) — confirm intended default before enabling.
    components2hypopt = {"train_train": dict(train_on="train-sbst-0.5", validate_on="train-sbst-0.5", label_size=0.2),
                         "train-cmplmnt-ntest_train-sbst-ntest": dict(train_on="train-sbst-0.5", validate_on="train-cmplmnt-0.5", label_size=0.2),
                         "train_test": dict(train_on="train-sbst-0.5", validate_on="test", label_size=0.2),  # valdiation should be done on test-sbst-0.1
                         "train_test-cmplmnt-0.1": dict(train_on="train", validate_on="test-sbst-0.1"),
                         "union_test": dict(train_on="train-sbst-0.5", validate_on="train-sbst-0.5", label_size=0.2),
                         }

    # those components have the same training setup so don't retrain
    components_same_train = {}

    if cfg.is_supervised:
        # only need train on train for supervised baselines (i.e. approx error) and train on test (agg risk)
        components = ["train_train",
                      "train_test"
                      #"train_test-cmplmnt-0.1",
                      ]
    else:
        # test should be replaced by test-cmplmnt-0.1
        components = ["train_train",
                      "train_test",
                      "train-cmplmnt-ntest_train-sbst-ntest"]

        # assumes the alternative decomposition only applies to the non-supervised
        # path (source indentation was lost in extraction) — TODO confirm
        if cfg.is_alternative_decomposition:
            components += ["union_test"]

    for component in components:
        # unique suffix per hyperparameter setup so tuning/results are cached per-config
        sffx_hypopt = "hyp_{train_on}_{validate_on}_{label_size}".format(**components2hypopt[component])
        tune_hyperparam_(datamodule, cfg,
                         tuning_path=cfg.paths.tuning + sffx_hypopt,
                         **components2hypopt[component])
        # accumulates this component's metrics into `results`
        results = run_component_(component, datamodule, cfg, results, components_same_train,
                                 results_path=cfg.paths.results + sffx_hypopt)

    # save results
    results = pd.DataFrame.from_dict(results)
    save_results(cfg, results, "all")
if __name__ == "__main__":
    try:
        main_except()
    except Exception:  # top-level boundary: log, report, exit non-zero
        logger.exception("Failed this error:")
        # exit gracefully, so wandb logs the problem.
        # (fix: `print(traceback.print_exc(), ...)` printed the literal `None`
        # return value — print_exc already writes the traceback to stderr.)
        traceback.print_exc(file=sys.stderr)
        sys.exit(1)
    finally:
        # wandb is an optional import above; guard so a missing wandb does not
        # raise NameError here and mask the original failure.
        if "wandb" in sys.modules:
            wandb.finish()