Skip to content

Commit

Permalink
add customized ind_head support
Browse files Browse the repository at this point in the history
  • Loading branch information
senwu committed Aug 9, 2019
1 parent 11c5698 commit acb8411
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/emmental/contrib/slicing/task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import logging
from functools import partial

Expand All @@ -13,7 +14,13 @@
logger = logging.getLogger(__name__)


def build_slice_tasks(task, slice_func_dict, slice_distribution={}, dropout=0.0):
def build_slice_tasks(
task,
slice_func_dict,
slice_distribution={},
dropout=0.0,
slice_ind_head_module=None,
):
"""A function to build slice tasks based on slicing functions.
We assume the original task flow contains feature extractor and predictor head.
Expand Down Expand Up @@ -49,6 +56,11 @@ def build_slice_tasks(task, slice_func_dict, slice_distribution={}, dropout=0.0)
slice_module_pool[module_name] = module
slice_actions = [action for action in base_task_task_flow]

if slice_ind_head_module is None:
slice_ind_head_module = nn.Linear(task_feature_size, 2)

assert isinstance(slice_ind_head_module, nn.Module)

# Create slice indicator tasks.
# (Note: indicator only has two classes, e.g, in the slice or out)
for slice_name in slice_func_dict.keys():
Expand All @@ -57,7 +69,7 @@ def build_slice_tasks(task, slice_func_dict, slice_distribution={}, dropout=0.0)

# Create ind module
ind_head_module_name = f"{ind_task_name}_head"
ind_head_module = nn.Linear(task_feature_size, 2)
ind_head_module = copy.deepcopy(slice_ind_head_module)

ind_head_dropout_module_name = f"{task.name}_slice:dropout_{slice_name}"
ind_head_dropout_module = nn.Dropout(p=dropout)
Expand Down

0 comments on commit acb8411

Please sign in to comment.