Merge pull request #20 from SenWu/slicing
Reorg and contrib
Showing 32 changed files with 1,075 additions and 207 deletions.
@@ -0,0 +1,51 @@
[Unreleased]
------------

Added
^^^^^
* `@senwu`_: Add `clear_intermediate_checkpoints` and `clear_all_checkpoints` arguments
  to let users clear checkpoints.
* `@senwu`_: Add `min_len` and `max_len` in `Meta.config` to support setting sequence
  length.
* `@senwu`_: Add overall and task-specific loss during evaluation by default.
* `@senwu`_: Calculate overall and task-specific metrics and loss during training.
* `@senwu`_: Add more util functions, e.g., `array_to_numpy`, `construct_identifier`,
  and `random_string`.
* `@senwu`_: Enforce that datasets have a `uids` attribute.
* `@senwu`_: Add micro/macro metric options with both split-wise and global-wise
  micro/macro averages. The metric names are:

  ::

    split-wise micro average: `model/all/{split}/micro_average`
    split-wise macro average: `model/all/{split}/macro_average`
    global-wise micro average: `model/all/all/micro_average`
    global-wise macro average: `model/all/all/macro_average`

  *Note*: `micro` averages all metrics from all tasks, while `macro` averages the
  per-task average metrics.

* `@senwu`_: Add a contrib folder to support unofficial usages.

Fixed
^^^^^
* `@senwu`_: Add a warning when only one class is present for the ROC AUC metric.
* `@senwu`_: Fix missing support for the StepLR and MultiStepLR lr schedulers.
* `@senwu`_: Add the missing pytest.ini and fix an issue where tests could not remove
  the temp dir.
* `@senwu`_: Fix the default train loss metric from `model/train/all/loss` to
  `model/all/train/loss` to follow the `TASK_NAME/DATA_NAME/SPLIT/METRIC` pattern.

Changed
^^^^^^^
* `@senwu`_: Change the default counter unit to epoch.
* `@senwu`_: Update the metrics to return one metric value by default.

Removed
^^^^^^^
* `@senwu`_: Remove the `checkpoint_clear` argument.

..
   For convenience, all username links for contributors can be listed here

.. _@senwu: https://github.com/senwu
@@ -0,0 +1,5 @@
Emmental contrib
================

Any code in this directory is not officially supported, and may change or be removed
at any time without notice.
@@ -0,0 +1,3 @@
from emmental.contrib import slicing

__all__ = ["slicing"]
@@ -0,0 +1,5 @@
from emmental.contrib.slicing.data import add_slice_labels
from emmental.contrib.slicing.slicing_function import slicing_function
from emmental.contrib.slicing.task import build_slice_tasks

__all__ = ["add_slice_labels", "slicing_function", "build_slice_tasks"]
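
For orientation, a minimal sketch of how these exports fit together. The names `task`, `dataloaders`, and the `text` field are hypothetical stand-ins for an Emmental task and its dataloaders; the exact signature of `build_slice_tasks` does not appear in this diff, so it is only referenced in a comment.

    from emmental.contrib.slicing import add_slice_labels, slicing_function

    # 1. Declare a slice with the decorator (see slicing_function.py below)
    @slicing_function(fields=["text"])
    def short_text(example):
        return len(example.text) < 10

    slice_func_dict = {"short_text": short_text}

    # 2. Attach slice indicator/predictor labels to every dataloader of `task`
    slice_distribution = add_slice_labels(task, dataloaders, slice_func_dict)

    # 3. build_slice_tasks then converts the base task into slice-aware tasks
    #    (its signature is not shown in this diff)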
@@ -0,0 +1,76 @@
import logging

import numpy as np
import torch

from emmental.contrib.slicing.slicing_function import slicing_function

logger = logging.getLogger(__name__)


def add_slice_labels(task, dataloaders, slice_func_dict, split="train"):
    """Extend the dataloaders with slice indicator and predictor labels, and
    return the class distribution of each slice on the given split.
    """

    # Class balance of each slice, used to weight the slice losses
    slice_distribution = {}

    # Add the base slice (which contains every example) if needed
    if "base" not in slice_func_dict.keys():
        slice_func_dict["base"] = base_slice

    for dataloader in dataloaders:
        labels = dataloader.dataset.Y_dict[dataloader.task_to_label_dict[task.name]]
        for slice_name, slice_func in slice_func_dict.items():
            indicators = slice_func(dataloader.dataset)
            slice_ind_name = f"{task.name}_slice:ind_{slice_name}"
            slice_pred_name = f"{task.name}_slice:pred_{slice_name}"

            # Predictor labels: the gold label inside the slice, 0 (masked) outside
            pred_labels = indicators * labels
            # Indicator labels: 1 = in slice, 2 = not in slice (labels are 1-indexed)
            ind_labels = indicators
            ind_labels[ind_labels == 0] = 2

            if dataloader.split == split and slice_name != "base":
                # Inverse-frequency class weights for the indicator head
                ind_classes, ind_counts = np.unique(
                    ind_labels.numpy(), return_counts=True
                )
                if ind_classes.shape[0] == 2:
                    slice_distribution[slice_ind_name] = torch.Tensor(
                        np.sum(ind_counts) / ind_counts / ind_classes.shape[0]
                    )
                # Class weights for the predictor head, skipping the masked class 0
                pred_classes, pred_counts = np.unique(
                    pred_labels.numpy(), return_counts=True
                )
                if (pred_classes[0] == 0 and pred_classes.shape[0] == 3) or (
                    pred_classes[0] == 1 and pred_classes.shape[0] == 2
                ):
                    if pred_classes[0] == 0:
                        slice_distribution[slice_pred_name] = torch.Tensor(
                            1 - pred_counts[1:] / np.sum(pred_counts[1:])
                        )
                    else:
                        slice_distribution[slice_pred_name] = torch.Tensor(
                            1 - pred_counts / np.sum(pred_counts)
                        )

            # Update slice indicator and predictor labels
            dataloader.dataset.Y_dict.update(
                {slice_ind_name: ind_labels, slice_pred_name: pred_labels}
            )
            # Update dataloader task_to_label_dict
            dataloader.task_to_label_dict.update(
                {slice_ind_name: slice_ind_name, slice_pred_name: slice_pred_name}
            )
            msg = (
                f"Loaded slice labels for task {task.name}, slice {slice_name}, "
                f"split {dataloader.split}."
            )
            logger.info(msg)

    return slice_distribution


@slicing_function()
def base_slice(example):
    """Define the base slice, which every example belongs to."""
    return True
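
To make the label encoding above concrete, here is a small sketch (assuming the 1-indexed label convention used throughout, so 0 can serve as a masked value; the `.clone()` is a safety measure in this sketch, since the function above reuses the indicator tensor in place):

    import torch

    indicators = torch.tensor([1, 0, 1, 1])  # 1 = example is in the slice
    labels = torch.tensor([1, 2, 2, 1])      # 1-indexed gold task labels

    pred_labels = indicators * labels        # -> [1, 0, 2, 1]; 0 is masked out
    ind_labels = indicators.clone()          # avoid mutating `indicators`
    ind_labels[ind_labels == 0] = 2          # -> [1, 2, 1, 1]; 2 = "not in slice"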
Empty file.
src/emmental/contrib/slicing/modules/slice_attention_module.py (83 additions, 0 deletions)
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class SliceAttentionModule(nn.Module):
    """An attention module to leverage all slice representations."""

    def __init__(
        self,
        slice_ind_key="_slice_ind_",
        slice_pred_key="_slice_pred_",
        slice_pred_feat_key="_slice_feat_",
    ):
        super().__init__()

        self.slice_ind_key = slice_ind_key
        self.slice_pred_key = slice_pred_key
        self.slice_pred_feat_key = slice_pred_feat_key

    def forward(self, intermediate_output_dict):
        # Collect ordered slice indicator head names
        slice_indicator_names = sorted(
            [
                flow_name
                for flow_name in intermediate_output_dict.keys()
                if self.slice_ind_key in flow_name
            ]
        )
        # Collect ordered slice predictor head names
        slice_predictor_names = sorted(
            [
                flow_name
                for flow_name in intermediate_output_dict.keys()
                if self.slice_pred_key in flow_name
            ]
        )
        # Concat the probability of class 0 from each slice indicator head
        slice_indicator_predictions = torch.cat(
            [
                F.softmax(intermediate_output_dict[slice_indicator_name][0], dim=1)[
                    :, 0
                ].unsqueeze(1)
                for slice_indicator_name in slice_indicator_names
            ],
            dim=-1,
        )
        # Concat the probability of class 0 from each slice predictor head
        slice_predictor_predictions = torch.cat(
            [
                F.softmax(intermediate_output_dict[slice_predictor_name][0], dim=1)[
                    :, 0
                ].unsqueeze(1)
                for slice_predictor_name in slice_predictor_names
            ],
            dim=-1,
        )
        # Collect ordered slice feature head names
        slice_feature_names = sorted(
            [
                flow_name
                for flow_name in intermediate_output_dict.keys()
                if self.slice_pred_feat_key in flow_name
            ]
        )
        # Concat slice representations: (batch, num_slices, feature_dim)
        slice_representations = torch.cat(
            [
                intermediate_output_dict[slice_feature_name][0].unsqueeze(1)
                for slice_feature_name in slice_feature_names
            ],
            dim=1,
        )
        # Attention weights over slices, expanded to the feature dimension
        A = (
            F.softmax(slice_indicator_predictions * slice_predictor_predictions, dim=1)
            .unsqueeze(-1)
            .expand([-1, -1, slice_representations.size(-1)])
        )

        # Reweight slice representations by attention: (batch, feature_dim)
        reweighted_representation = torch.sum(A * slice_representations, 1)

        return reweighted_representation
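
A shape-level sketch of how the module consumes an `intermediate_output_dict`. The key names and random tensors below are invented for illustration; in practice the entries come from the slice heads wired up elsewhere, with each value a list whose first element is the head output.

    import torch

    module = SliceAttentionModule()
    batch_size, feature_dim = 4, 16

    intermediate_output_dict = {
        "task_slice_ind_a": [torch.randn(batch_size, 2)],
        "task_slice_ind_b": [torch.randn(batch_size, 2)],
        "task_slice_pred_a": [torch.randn(batch_size, 2)],
        "task_slice_pred_b": [torch.randn(batch_size, 2)],
        "task_slice_feat_a": [torch.randn(batch_size, feature_dim)],
        "task_slice_feat_b": [torch.randn(batch_size, feature_dim)],
    }

    reweighted = module(intermediate_output_dict)
    print(reweighted.shape)  # torch.Size([4, 16])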
@@ -0,0 +1,13 @@
import torch.nn.functional as F


def ce_loss(module_name, intermediate_output_dict, Y, active, weight=None):
    # Shift the 1-indexed gold labels to 0-indexed targets and restrict the
    # loss to the active examples
    return F.cross_entropy(
        intermediate_output_dict[module_name][0][active],
        (Y.view(-1) - 1)[active],
        weight,
    )


def output(module_name, intermediate_output_dict):
    return F.softmax(intermediate_output_dict[module_name][0], dim=1)
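
A small sketch of the 1-indexed label convention these helpers assume (the tensor values and the "head" key are invented for illustration):

    import torch

    logits = torch.randn(4, 2)
    Y = torch.tensor([1, 2, 2, 1])                    # 1-indexed gold labels
    active = torch.tensor([True, True, False, True])  # mask of live examples

    intermediate_output_dict = {"head": [logits]}
    # Cross-entropy over the active examples with 0-indexed targets [0, 1, 0]
    loss = ce_loss("head", intermediate_output_dict, Y, active)
    probs = output("head", intermediate_output_dict)  # softmax over 2 classes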
@@ -0,0 +1,39 @@
import logging
from functools import wraps
from types import SimpleNamespace

import numpy as np
import torch

logger = logging.getLogger(__name__)


class slicing_function:
    """
    When wrapped with this decorator, slicing functions only need to return an
    indicator for whether an individual example (bundle of attributes) belongs in
    that slice. Iterating through the dataset, making the pred array (and masking),
    etc. are all handled automatically.
    """

    def __init__(self, fields=()):
        # Immutable default avoids sharing one mutable list across instances
        self.fields = fields

    def __call__(self, f):
        @wraps(f)
        def wrapped_f(dataset):
            inds = []
            for idx in range(len(dataset)):
                # Bundle the requested fields of one example into a namespace
                example = SimpleNamespace(
                    **{field: dataset.X_dict[field][idx] for field in self.fields}
                )
                in_slice = f(example)
                inds.append(1 if in_slice else 0)
            inds = torch.from_numpy(np.array(inds)).view(-1)
            logger.info(
                f"Total {int((inds == 1).sum())} / {len(dataset)} examples are "
                f"in slice {f.__name__}"
            )
            return inds

        return wrapped_f
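
A self-contained sketch of the decorator in action. `ToyDataset` is invented here; any object exposing `X_dict` and `__len__`, like the Emmental datasets above, works the same way.

    import torch


    class ToyDataset:
        """Minimal stand-in for an Emmental dataset with an X_dict."""

        def __init__(self, texts):
            self.X_dict = {"text": texts}

        def __len__(self):
            return len(self.X_dict["text"])


    @slicing_function(fields=["text"])
    def short_text(example):
        return len(example.text) < 6


    dataset = ToyDataset(["hi", "a longer sentence", "ok"])
    inds = short_text(dataset)  # tensor([1, 0, 1])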