Skip to content

Commit

Permalink
fix task_feature in LCEMBO
Browse files Browse the repository at this point in the history
Summary:
`task_features` is 1-dim list that is same across all metrics. Passing `task_features[i]` will cause indexing error.

Here I add a validation of `task_features` and pass right task_feature to each LCEMGP.

Reviewed By: bletham

Differential Revision: D26190071

fbshipit-source-id: 483a6ff723107fcaa8415effb61813dd49bbc765
  • Loading branch information
qingfeng10 authored and facebook-github-bot committed Feb 4, 2021
1 parent 464cdd8 commit 5e5947d
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions ax/models/torch/cbo_lcem.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,19 @@ def get_and_fit_model(
Xs: List of X data, one tensor per outcome.
Ys: List of Y data, one tensor per outcome.
Yvars:List of Noise variance of Yvar data, one tensor per outcome.
Returns: Fitted multi-task contextual GP model.
task_features: List of columns of X that are tasks.
Returns: ModeListGP that each model is a fitted LCEM GP model.
"""

if len(task_features) == 1:
task_feature = task_features[0]
elif len(task_features) > 1:
raise NotImplementedError(
f"LCEMBO only supports 1 task feature (got {task_features})"
)
else:
raise ValueError("LCEMBO requires context input as task features")

models = []
for i, X in enumerate(Xs):
# validate input Yvars
Expand All @@ -62,7 +72,7 @@ def get_and_fit_model(
gp_m = LCEMGP(
train_X=X,
train_Y=Ys[i],
task_feature=task_features[i],
task_feature=task_feature,
context_cat_feature=self.context_cat_feature,
context_emb_feature=self.context_emb_feature,
embs_dim_list=self.embs_dim_list,
Expand All @@ -72,7 +82,7 @@ def get_and_fit_model(
train_X=X,
train_Y=Ys[i],
train_Yvar=Yvar,
task_feature=task_features[i],
task_feature=task_feature,
context_cat_feature=self.context_cat_feature,
context_emb_feature=self.context_emb_feature,
embs_dim_list=self.embs_dim_list,
Expand Down

0 comments on commit 5e5947d

Please sign in to comment.