-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: allow to import training sets produced with nanite
- Loading branch information
1 parent
586dc27
commit 4ffb0fd
Showing
4 changed files
with
107 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import appdirs | ||
import collections | ||
import pathlib | ||
import shutil | ||
import zipfile | ||
|
||
import nanite.rate.rater | ||
|
||
|
||
RATING_SCHEMES = collections.OrderedDict() | ||
RATING_SCHEMES["Default (zef18 & Extra Trees)"] = ["zef18", "Extra Trees"] | ||
RATING_SCHEMES["Disabled"] = ["none", "none"] | ||
|
||
#: Rating configuration directory | ||
CFG_DIR = pathlib.Path(appdirs.user_config_dir(appname="PyJibe")) / "rating" | ||
#: Path to main rating configuration file | ||
CFG_PATH = CFG_DIR / "rating_schemes.txt" | ||
|
||
|
||
def get_rating_schemes(): | ||
"""Return an ordered dict with available rating schemes""" | ||
schemes = collections.OrderedDict() | ||
ts = get_training_set_paths() | ||
# We currently stick to Extra Trees | ||
for key in ts: | ||
schemes["{} + Extra Trees".format(key)] = [ts[key], "Extra Trees"] | ||
schemes["Disabled"] = ["none", "none"] | ||
return schemes | ||
|
||
|
||
def get_training_set_paths(): | ||
"""Return ordered dict with available training set names and paths""" | ||
ts = collections.OrderedDict() | ||
# training sets from nanite | ||
nanite_list = nanite.rate.rater.get_available_training_sets() | ||
for key in nanite_list: | ||
ts[key] = nanite.rate.IndentationRater.get_training_set_path(key) | ||
# user-imported training sets | ||
for pp in sorted(CFG_DIR.glob("ts_*")): | ||
ts[pp.name[3:]] = pp | ||
return ts | ||
|
||
|
||
def import_training_set(ts_zip, override=False): | ||
"""Open a training set zip file and import it to :const:`CFG_DIR`""" | ||
path = pathlib.Path(ts_zip) | ||
if not path.suffix == ".zip": | ||
raise ValueError("Training set file suffix must be '.zip', " | ||
"got '{}'!".format(ts_zip.suffix)) | ||
if not path.name.startswith("ts_"): | ||
raise ValueError("Training set file name must begin with 'ts_', " | ||
"got '{}'!".format(ts_zip.name)) | ||
pout = CFG_DIR / path.with_suffix("").name | ||
|
||
if pout.exists(): | ||
if override: | ||
shutil.rmtree(pout) | ||
else: | ||
raise OSError("Training set already exists: {}".format(pout)) | ||
|
||
pout.mkdir(exist_ok=True, parents=True) | ||
|
||
with zipfile.ZipFile(ts_zip) as zp: | ||
zp.extractall(pout) | ||
|
||
# return index in new training set collection | ||
ts = get_training_set_paths() | ||
for idx, key in enumerate(ts.keys()): | ||
if ts[key] == pout: | ||
break | ||
return idx |