1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -4,6 +4,7 @@ Changelog
[Ver 0.1.*]
-----------

+* |Fix| Fix missing base estimators when calling :meth:`load()` for all ensembles | `@xuyxu <https://github.com/xuyxu>`__
* |MajorFeature| Add methods on model deserialization :meth:`load()` for all ensembles | `@mttgdd <https://github.com/mttgdd>`__

[Beta]
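For context, the entry above refers to the round trip exercised by the new tests below: an ensemble is trained with save_model=True, a fresh ensemble with the same settings is constructed, and io.load() restores it. A minimal sketch of that usage follows; it mirrors the calls visible in this PR, uses VotingClassifier as just one example ensemble, defines a toy MLP and data only for self-containment, and leaves optimizer and learning-rate settings at their defaults, so treat it as illustrative rather than a verbatim excerpt.

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import torchensemble
from torchensemble.utils import io


class MLP(nn.Module):
    # Tiny toy classifier used only for this illustration.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, X):
        return self.linear(X)


X = torch.Tensor(np.array([[0.1, 0.1], [0.9, 0.9]]))
y = torch.LongTensor(np.array([0, 1]))
loader = DataLoader(TensorDataset(X, y), batch_size=2)

# Train an ensemble and persist a checkpoint via `save_model=True`.
ensemble = torchensemble.VotingClassifier(estimator=MLP, n_estimators=2, cuda=False)
ensemble.fit(loader, epochs=1, save_model=True)

# Re-create an ensemble with the same settings and restore the saved weights.
# Before this fix, `estimators_` stayed empty, so the checkpoint could not be applied.
restored = torchensemble.VotingClassifier(estimator=MLP, n_estimators=2, cuda=False)
io.load(restored)  # reads the checkpoint from the default save_dir "./"
restored.predict(loader)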
41 changes: 30 additions & 11 deletions torchensemble/tests/test_all_models.py
@@ -5,6 +5,7 @@
from torch.utils.data import TensorDataset, DataLoader

import torchensemble
+from torchensemble.utils import io
from torchensemble.utils.logging import set_logger


@@ -24,6 +25,10 @@
torchensemble.AdversarialTrainingRegressor]


+# Remove randomness
+np.random.seed(0)
+torch.manual_seed(0)

set_logger("pytest_all_models")


@@ -66,12 +71,10 @@ def forward(self, X):

# Testing data
X_test = torch.Tensor(np.array(([0.5, 0.5],
-[0.6, 0.6],
-[0.7, 0.7],
-[0.8, 0.8])))
+[0.6, 0.6])))

-y_test_clf = torch.LongTensor(np.array(([1, 1, 0, 0])))
-y_test_reg = torch.FloatTensor(np.array(([0.5, 0.6, 0.7, 0.8])))
+y_test_clf = torch.LongTensor(np.array(([1, 0])))
+y_test_reg = torch.FloatTensor(np.array(([0.5, 0.6])))
y_test_reg = y_test_reg.view(-1, 1)


@@ -94,9 +97,9 @@ def test_clf(clf):

# Prepare data
train = TensorDataset(X_train, y_train_clf)
-train_loader = DataLoader(train, batch_size=2)
+train_loader = DataLoader(train, batch_size=2, shuffle=False)
test = TensorDataset(X_test, y_test_clf)
-test_loader = DataLoader(test, batch_size=2)
+test_loader = DataLoader(test, batch_size=2, shuffle=False)

# Snapshot ensemble needs more epochs
if isinstance(model, torchensemble.SnapshotEnsembleClassifier):
@@ -109,7 +112,15 @@ def test_clf(clf):
save_model=True)

# Test
-model.predict(test_loader)
+prev_acc = model.predict(test_loader)
+
+# Reload
+new_model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False)
+io.load(new_model)
+
+post_acc = new_model.predict(test_loader)
+
+assert prev_acc == post_acc # ensure the same performance


@pytest.mark.parametrize("reg", all_reg)
@@ -131,9 +142,9 @@ def test_reg(reg):

# Prepare data
train = TensorDataset(X_train, y_train_reg)
-train_loader = DataLoader(train, batch_size=2)
+train_loader = DataLoader(train, batch_size=2, shuffle=False)
test = TensorDataset(X_test, y_test_reg)
-test_loader = DataLoader(test, batch_size=2)
+test_loader = DataLoader(test, batch_size=2, shuffle=False)

# Snapshot ensemble needs more epochs
if isinstance(model, torchensemble.SnapshotEnsembleRegressor):
@@ -146,7 +157,15 @@ def test_reg(reg):
save_model=True)

# Test
-model.predict(test_loader)
+prev_mse = model.predict(test_loader)
+
+# Reload
+new_model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False)
+io.load(new_model)
+
+post_mse = new_model.predict(test_loader)
+
+assert prev_mse == post_mse # ensure the same performance


@pytest.mark.parametrize("method", all_clf + all_reg)
15 changes: 13 additions & 2 deletions torchensemble/utils/io.py
@@ -14,7 +14,11 @@ def save(model, save_dir, logger):
filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,
model.base_estimator_.__name__,
model.n_estimators)
state = {"model": model.state_dict()}

# The real number of base estimators in some ensembles is not same as
# `n_estimators`.
state = {"n_estimators": len(model.estimators_),
"model": model.state_dict()}
save_dir = os.path.join(save_dir, filename)

logger.info("Saving the model to `{}`".format(save_dir))
@@ -39,4 +43,11 @@ def load(model, save_dir="./", logger=None):
if logger:
logger.info("Loading the model from `{}`".format(save_dir))

model.load_state_dict(torch.load(save_dir)["model"])
state = torch.load(save_dir)
n_estimators = state["n_estimators"]
model_params = state["model"]

# Pre-allocate and load all base estimators
for _ in range(n_estimators):
model.estimators_.append(model._make_estimator())
model.load_state_dict(model_params)
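As a side note on why the pre-allocation loop above is needed: load_state_dict() can only fill parameters that already exist, so a freshly constructed ensemble with an empty estimators_ list has no estimators_.N.* entries to receive the saved weights, which is exactly the bug this PR fixes. A small self-contained illustration of that behaviour follows; the SmallNet module and the two-element ModuleList are purely hypothetical stand-ins for the base estimators.

import torch.nn as nn


class SmallNet(nn.Module):
    # Illustrative stand-in for a base estimator.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)


# A container holding two sub-modules, analogous to `model.estimators_`.
full = nn.ModuleList([SmallNet(), SmallNet()])
state = full.state_dict()  # keys such as "0.fc.weight", "1.fc.bias", ...

empty = nn.ModuleList()
try:
    empty.load_state_dict(state)  # no modules exist yet to receive these keys
except RuntimeError as err:
    print("Loading without pre-allocation failed:", err)

# Pre-allocating the sub-modules first (as io.load() now does) makes the keys match.
empty.extend([SmallNet() for _ in range(2)])
empty.load_state_dict(state)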