Skip to content

Commit

Permalink
feat: allow to import training sets produced with nanite
Browse files Browse the repository at this point in the history
  • Loading branch information
paulmueller committed Sep 1, 2019
1 parent 586dc27 commit 4ffb0fd
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
0.5.0
- feat: allow to import training sets produced with nanite
0.4.4
- ci: pin 'joblib==0.11.0' on travis-CI (workaround for infinite loop
in macOS build, https://github.com/pyinstaller/pyinstaller/issues/4067)
Expand Down
42 changes: 33 additions & 9 deletions pyjibe/fd/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import collections
import inspect
import io
import pkg_resources
Expand All @@ -18,11 +17,7 @@
from .mpl_indent import MPLIndentation
from .mpl_edelta import MPLEDelta
from .mpl_qmap import MPLQMap


RATING_SCHEMES = collections.OrderedDict()
RATING_SCHEMES["Default (zef18 & Extra Trees)"] = ["zef18", "Extra Trees"]
RATING_SCHEMES["Disabled"] = ["none", "none"]
from . import rating_scheme


# load QWidget from ui file
Expand Down Expand Up @@ -71,6 +66,9 @@ def __init__(self, parent_widget):
# fitting setup
self.fit_setup()

# rating scheme
self.rating_scheme_setup()

self.tabs.currentChanged.connect(self.on_tab_changed)
self.signal_slot(True)
self.btn_rater.clicked.connect(self.on_user_rate)
Expand Down Expand Up @@ -328,6 +326,24 @@ def mpl_qmap_setup(self):
self.qmap_sp_range2.valueChanged.connect(self.on_qmap_min_max_changed)
self.mpl_qmap.connect_curve_selection_event(self.on_qmap_selection)

def on_cb_rating_scheme(self):
scheme_id = self.cb_rating_scheme.currentIndex()
schemes = rating_scheme.get_rating_schemes()
if len(schemes) == scheme_id:
search_dir = ""
exts_str = "Training set zip file (*.zip)"
tsz, _e = QtWidgets.QFileDialog.getOpenFileName(
self.parent_widget, "Import a training set",
search_dir, exts_str)
if tsz:
idx = rating_scheme.import_training_set(tsz)
self.rating_scheme_setup()
self.cb_rating_scheme.setCurrentIndex(idx)
else:
self.cb_rating_scheme.setCurrentIndex(0)
else:
self.on_params_init()

def on_mpl_curve_update(self):
fdist = self.current_curve
self.mpl_curve_update(fdist)
Expand Down Expand Up @@ -697,8 +713,9 @@ def rate_data(self, data):
return_single = False

scheme_id = self.cb_rating_scheme.currentIndex()
scheme_key = list(RATING_SCHEMES.keys())[scheme_id]
training_set, regressor = RATING_SCHEMES[scheme_key]
schemes = rating_scheme.get_rating_schemes()
scheme_key = list(schemes.keys())[scheme_id]
training_set, regressor = schemes[scheme_key]
rates = []
for fdist in data:
rt = fdist.rate_quality(regressor=regressor,
Expand All @@ -710,6 +727,12 @@ def rate_data(self, data):
else:
return rates

def rating_scheme_setup(self):
self.cb_rating_scheme.clear()
schemes = rating_scheme.get_rating_schemes()
self.cb_rating_scheme.addItems(list(schemes.keys()))
self.cb_rating_scheme.addItem("Import...")

@property
def selected_curves(self):
"""Return an IndentationGroup with all curves selected by the user"""
Expand Down Expand Up @@ -745,7 +768,8 @@ def signal_slot(self, enable):
[self.sp_delta_num_samples.valueChanged, self.on_params_init],
[self.sp_delta_num_samples.valueChanged, self.mpl_edelta_update],
# rating scheme dropdown
[self.cb_rating_scheme.currentTextChanged, self.on_params_init],
[self.cb_rating_scheme.currentTextChanged,
self.on_cb_rating_scheme],
]

for signal, slot in cn:
Expand Down
2 changes: 1 addition & 1 deletion pyjibe/fd/mpl_qmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def update(self, qmap, feature, cmap="viridis", vmin=None, vmax=None):
[yv-dy*.4, yv+dy*.4],
color=color,
lw=1)[0]
)
)

# common variables
self.dx = dx
Expand Down
71 changes: 71 additions & 0 deletions pyjibe/fd/rating_scheme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import appdirs
import collections
import pathlib
import shutil
import zipfile

import nanite.rate.rater


RATING_SCHEMES = collections.OrderedDict()
RATING_SCHEMES["Default (zef18 & Extra Trees)"] = ["zef18", "Extra Trees"]
RATING_SCHEMES["Disabled"] = ["none", "none"]

#: Rating configuration directory
CFG_DIR = pathlib.Path(appdirs.user_config_dir(appname="PyJibe")) / "rating"
#: Path to main rating configuration file
CFG_PATH = CFG_DIR / "rating_schemes.txt"


def get_rating_schemes():
"""Return an ordered dict with available rating schemes"""
schemes = collections.OrderedDict()
ts = get_training_set_paths()
# We currently stick to Extra Trees
for key in ts:
schemes["{} + Extra Trees".format(key)] = [ts[key], "Extra Trees"]
schemes["Disabled"] = ["none", "none"]
return schemes


def get_training_set_paths():
"""Return ordered dict with available training set names and paths"""
ts = collections.OrderedDict()
# training sets from nanite
nanite_list = nanite.rate.rater.get_available_training_sets()
for key in nanite_list:
ts[key] = nanite.rate.IndentationRater.get_training_set_path(key)
# user-imported training sets
for pp in sorted(CFG_DIR.glob("ts_*")):
ts[pp.name[3:]] = pp
return ts


def import_training_set(ts_zip, override=False):
"""Open a training set zip file and import it to :const:`CFG_DIR`"""
path = pathlib.Path(ts_zip)
if not path.suffix == ".zip":
raise ValueError("Training set file suffix must be '.zip', "
"got '{}'!".format(ts_zip.suffix))
if not path.name.startswith("ts_"):
raise ValueError("Training set file name must begin with 'ts_', "
"got '{}'!".format(ts_zip.name))
pout = CFG_DIR / path.with_suffix("").name

if pout.exists():
if override:
shutil.rmtree(pout)
else:
raise OSError("Training set already exists: {}".format(pout))

pout.mkdir(exist_ok=True, parents=True)

with zipfile.ZipFile(ts_zip) as zp:
zp.extractall(pout)

# return index in new training set collection
ts = get_training_set_paths()
for idx, key in enumerate(ts.keys()):
if ts[key] == pout:
break
return idx

0 comments on commit 4ffb0fd

Please sign in to comment.