Skip to content

Commit

Permalink
1. Load existed ToolBox object if there is one in the saving_path.
Browse files Browse the repository at this point in the history
2. Return the existed split settings if the data has already been split in ToolBox object, in case of overwriting.
  • Loading branch information
tangypnuaa authored and tangypnuaa committed Jan 17, 2021
1 parent 60fed64 commit 859cc18
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions alipy/toolbox.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import os
import pickle
import warnings
import inspect

from sklearn.linear_model import LogisticRegression
Expand Down Expand Up @@ -168,8 +169,20 @@ def __init__(self, y, X=None, instance_indexes=None,
self._saving_path = os.path.abspath(saving_path)
if os.path.isdir(self._saving_path):
self._saving_dir = self._saving_path
if os.path.exists(os.path.join(saving_path, 'al_settings.pkl')):
with open(os.path.join(saving_path, 'al_settings.pkl'), 'rb') as f:
existed_toolbox = pickle.load(f)
for ke in existed_toolbox.__dict__.keys():
setattr(self, ke, getattr(existed_toolbox, ke))
return
else:
self._saving_dir = os.path.split(self._saving_path)[0] # if a directory, a dir and None will return.
if os.path.exists(saving_path):
with open(os.path.abspath(saving_path), 'rb') as f:
existed_toolbox = pickle.load(f)
for ke in existed_toolbox.__dict__.keys():
setattr(self, ke, getattr(existed_toolbox, ke))
return
self.save()

def split_AL(self, test_ratio=0.3, initial_label_rate=0.05,
Expand Down Expand Up @@ -209,6 +222,11 @@ def split_AL(self, test_ratio=0.3, initial_label_rate=0.05,
"""
# should support other query types in the future
if self._split is True:
warnings.warn("Data has already been split. Return the existed split in case of overwriting.",
category=RuntimeWarning)
return self.train_idx, self.test_idx, self.label_idx, self.unlabel_idx

self.split_count = split_count
if self._target_type != 'Features':
if self._target_type != 'multilabel':
Expand Down Expand Up @@ -241,6 +259,7 @@ def split_AL(self, test_ratio=0.3, initial_label_rate=0.05,
saving_path=self._saving_path
)
self._split = True
self.save()
return self.train_idx, self.test_idx, self.label_idx, self.unlabel_idx

def get_split(self, round=None):
Expand Down

0 comments on commit 859cc18

Please sign in to comment.