Merge pull request #20 from SenWu/slicing
Reorg and contrib
senwu committed Aug 15, 2019
2 parents d412e25 + b764f00 commit 1b46a7a
Showing 32 changed files with 1,075 additions and 207 deletions.
51 changes: 51 additions & 0 deletions CHANGELOG.rst
@@ -0,0 +1,51 @@
[Unreleased]
------------

Added
^^^^^
* `@senwu`_: Add `clear_intermediate_checkpoints` and `clear_all_checkpoints` arguments
  to support clearing checkpoints.
* `@senwu`_: Add `min_len` and `max_len` in `Meta.config` to support setting sequence
length.
* `@senwu`_: Add overall and task-specific loss during evaluation by default.
* `@senwu`_: Calculate overall and task-specific metrics and loss during training.
* `@senwu`_: Add more util functions, e.g., `array_to_numpy`, `construct_identifier`,
  and `random_string`.
* `@senwu`_: Enforce that datasets have a `uids` attribute.
* `@senwu`_: Add micro/macro metric options with both split-wise and global-wise
  micro/macro averages. The names of the metrics are:

::

split-wise micro average: `model/all/{split}/micro_average`
split-wise macro average: `model/all/{split}/macro_average`
global-wise micro average: `model/all/all/micro_average`
global-wise macro average: `model/all/all/macro_average`

*Note*: `micro` averages all metrics from all tasks, while `macro` averages the
per-task average metrics.
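
For example, with hypothetical metric values where task1 reports accuracy 0.9
and f1 0.5, and task2 reports accuracy 0.6:

::

  micro average = mean(0.9, 0.5, 0.6) ≈ 0.667
  macro average = mean(mean(0.9, 0.5), 0.6) = 0.65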

* `@senwu`_: Add contrib folder to support unofficial usages.

Fixed
^^^^^
* `@senwu`_: Add a warning when only one class is present for the ROC AUC metric.
* `@senwu`_: Fix missing support for the StepLR and MultiStepLR lr schedulers.
* `@senwu`_: Fix missing pytest.ini and fix the issue where tests could not remove
  the temp directory.
* `@senwu`_: Fix the default train loss metric from `model/train/all/loss` to
  `model/all/train/loss` to follow the `TASK_NAME/DATA_NAME/SPLIT/METRIC` pattern.

Changed
^^^^^^^
* `@senwu`_: Change the default counter unit to epoch.
* `@senwu`_: Update metrics to return a single metric value by default.

Removed
^^^^^^^
* `@senwu`_: Remove `checkpoint_clear` argument.

..
For convenience, all username links for contributors can be listed here
.. _@senwu: https://github.com/senwu
6 changes: 6 additions & 0 deletions Makefile
@@ -15,6 +15,12 @@ check:
flake8 src/
flake8 tests/

format:
isort -rc src/
isort -rc tests/
black src/
black tests/

docs:
sphinx-build -W -b html docs/ _build/html

5 changes: 5 additions & 0 deletions src/emmental/contrib/README.rst
@@ -0,0 +1,5 @@
Emmental contrib
================

Any code in this directory is not officially supported, and may change or be removed
at any time without notice.
3 changes: 3 additions & 0 deletions src/emmental/contrib/__init__.py
@@ -0,0 +1,3 @@
from emmental.contrib import slicing

__all__ = ["slicing"]
5 changes: 5 additions & 0 deletions src/emmental/contrib/slicing/__init__.py
@@ -0,0 +1,5 @@
from emmental.contrib.slicing.data import add_slice_labels
from emmental.contrib.slicing.slicing_function import slicing_function
from emmental.contrib.slicing.task import build_slice_tasks

__all__ = ["add_slice_labels", "slicing_function", "build_slice_tasks"]
76 changes: 76 additions & 0 deletions src/emmental/contrib/slicing/data.py
@@ -0,0 +1,76 @@
import logging

import numpy as np
import torch

from emmental.contrib.slicing.slicing_function import slicing_function

logger = logging.getLogger(__name__)


def add_slice_labels(task, dataloaders, slice_func_dict, split="train"):
"""A function to extend dataloader by adding slice indicator and predictor
labels.
"""

# Calculate class balance
slice_distribution = {}

# Add base slice if needed
if "base" not in slice_func_dict.keys():
slice_func_dict["base"] = base_slice

for dataloader in dataloaders:
labels = dataloader.dataset.Y_dict[dataloader.task_to_label_dict[task.name]]
for slice_name, slice_func in slice_func_dict.items():
indicators = slice_func(dataloader.dataset)
slice_ind_name = f"{task.name}_slice:ind_{slice_name}"
slice_pred_name = f"{task.name}_slice:pred_{slice_name}"

            # Mask out predictor labels for examples outside the slice (0 = masked)
            pred_labels = indicators * labels
            # Remap indicator labels to a binary task: 1 = in slice, 2 = not in slice
            ind_labels = indicators
            ind_labels[ind_labels == 0] = 2

if dataloader.split == split and slice_name != "base":
ind_classes, ind_counts = np.unique(
ind_labels.numpy(), return_counts=True
)
if ind_classes.shape[0] == 2:
slice_distribution[slice_ind_name] = torch.Tensor(
np.sum(ind_counts) / ind_counts / ind_classes.shape[0]
)
pred_classes, pred_counts = np.unique(
pred_labels.numpy(), return_counts=True
)
if (pred_classes[0] == 0 and pred_classes.shape[0] == 3) or (
pred_classes[0] == 1 and pred_classes.shape[0] == 2
):
if pred_classes[0] == 0:
slice_distribution[slice_pred_name] = torch.Tensor(
1 - pred_counts[1:] / np.sum(pred_counts[1:])
)
else:
slice_distribution[slice_pred_name] = torch.Tensor(
1 - pred_counts / np.sum(pred_counts)
)

# Update slice indicator and predictor labels
dataloader.dataset.Y_dict.update(
{slice_ind_name: ind_labels, slice_pred_name: pred_labels}
)
# Update dataloader task_to_label_dict
dataloader.task_to_label_dict.update(
{slice_ind_name: slice_ind_name, slice_pred_name: slice_pred_name}
)
msg = (
f"Loaded slice labels for task {task.name}, slice {slice_name}, "
f"split {dataloader.split}."
)
logger.info(msg)

return slice_distribution


@slicing_function()
def base_slice(example):
return True
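
A minimal, runnable sketch of how `add_slice_labels` might be invoked. The
`ToyDataset`, the `"length"` field, and the `slice_long` function are
hypothetical stand-ins (only the attributes the function touches are provided),
not part of this diff:

import torch
from types import SimpleNamespace

from emmental.contrib.slicing.data import add_slice_labels
from emmental.contrib.slicing.slicing_function import slicing_function


class ToyDataset:
    """Stand-in exposing only X_dict, Y_dict, and __len__."""

    def __init__(self):
        self.X_dict = {"length": torch.tensor([3, 15, 7, 20])}
        self.Y_dict = {"labels": torch.tensor([1, 2, 1, 2])}

    def __len__(self):
        return len(self.Y_dict["labels"])


# Hypothetical slicing function over the made-up "length" field.
@slicing_function(fields=["length"])
def slice_long(example):
    return example.length > 10


# Stand-ins for an EmmentalTask / EmmentalDataLoader.
task = SimpleNamespace(name="task1")
dataloader = SimpleNamespace(
    dataset=ToyDataset(), task_to_label_dict={"task1": "labels"}, split="train"
)

slice_distribution = add_slice_labels(task, [dataloader], {"long": slice_long})
# The dataset's Y_dict now also holds "task1_slice:ind_long" and
# "task1_slice:pred_long", and task_to_label_dict maps the new slice heads.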
0 changes: 0 additions &amp; 0 deletions src/emmental/contrib/slicing/modules/__init__.py
Empty file.
83 changes: 83 additions & 0 deletions src/emmental/contrib/slicing/modules/slice_attention_module.py
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class SliceAttentionModule(nn.Module):
"""An attention module to leverage all slice representations."""

def __init__(
self,
slice_ind_key="_slice_ind_",
slice_pred_key="_slice_pred_",
slice_pred_feat_key="_slice_feat_",
):
super().__init__()

self.slice_ind_key = slice_ind_key
self.slice_pred_key = slice_pred_key
self.slice_pred_feat_key = slice_pred_feat_key

def forward(self, intermediate_output_dict):
# Collect ordered slice indicator head names
slice_indicator_names = sorted(
[
flow_name
for flow_name in intermediate_output_dict.keys()
if self.slice_ind_key in flow_name
]
)
# Collect ordered slice predictor head names
slice_predictor_names = sorted(
[
flow_name
for flow_name in intermediate_output_dict.keys()
if self.slice_pred_key in flow_name
]
)
        # Concat slice indicator predictions
        slice_indicator_predictions = torch.cat(
            [
                F.softmax(intermediate_output_dict[slice_indicator_name][0], dim=1)[
                    :, 0
                ].unsqueeze(1)
                for slice_indicator_name in slice_indicator_names
            ],
            dim=-1,
        )
        # Concat slice predictor predictions
        slice_predictor_predictions = torch.cat(
            [
                F.softmax(intermediate_output_dict[slice_predictor_name][0], dim=1)[
                    :, 0
                ].unsqueeze(1)
                for slice_predictor_name in slice_predictor_names
            ],
            dim=-1,
        )
# Collect ordered slice feature head names
slice_feature_names = sorted(
[
flow_name
for flow_name in intermediate_output_dict.keys()
if self.slice_pred_feat_key in flow_name
]
)
# Concat slice representations
slice_representations = torch.cat(
[
intermediate_output_dict[slice_feature_name][0].unsqueeze(1)
for slice_feature_name in slice_feature_names
],
dim=1,
)
# Attention
A = (
F.softmax(slice_indicator_predictions * slice_predictor_predictions, dim=1)
.unsqueeze(-1)
.expand([-1, -1, slice_representations.size(-1)])
)

reweighted_representation = torch.sum(A * slice_representations, 1)

return reweighted_representation
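
A hedged usage sketch for `SliceAttentionModule`. The head names and tensor
shapes below are illustrative: indicator and predictor heads are assumed to
emit `[batch, 2]` logits, feature heads `[batch, hidden_dim]` representations,
and the keys follow the module's default substring conventions:

import torch

from emmental.contrib.slicing.modules.slice_attention_module import (
    SliceAttentionModule,
)

batch_size, hidden_dim = 4, 16

# Two hypothetical slices ("base" and "short") for a task named task1.
intermediate_output_dict = {
    "task1_slice_ind_base": [torch.randn(batch_size, 2)],
    "task1_slice_ind_short": [torch.randn(batch_size, 2)],
    "task1_slice_pred_base": [torch.randn(batch_size, 2)],
    "task1_slice_pred_short": [torch.randn(batch_size, 2)],
    "task1_slice_feat_base": [torch.randn(batch_size, hidden_dim)],
    "task1_slice_feat_short": [torch.randn(batch_size, hidden_dim)],
}

attention = SliceAttentionModule()
reweighted = attention(intermediate_output_dict)
print(reweighted.shape)  # torch.Size([4, 16]): one reweighted vector per example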
13 changes: 13 additions & 0 deletions src/emmental/contrib/slicing/modules/utils.py
@@ -0,0 +1,13 @@
import torch.nn.functional as F


def ce_loss(module_name, intermediate_output_dict, Y, active, weight=None):
    # Shift the 1-indexed gold labels to 0-indexed targets and compute the
    # cross-entropy loss over only the active examples in the batch.
    return F.cross_entropy(
        intermediate_output_dict[module_name][0][active],
        (Y.view(-1) - 1)[active],
        weight,
    )


def output(module_name, intermediate_output_dict):
return F.softmax(intermediate_output_dict[module_name][0], dim=1)
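
A small, hypothetical usage sketch for these helpers. The `"pred_head"` name,
logits, and labels are made up; `Y` is assumed to hold 1-indexed gold labels
(hence the `- 1` shift in `ce_loss`) and `active` masks which examples count
toward the loss:

import torch

from emmental.contrib.slicing.modules.utils import ce_loss, output

intermediate_output_dict = {"pred_head": [torch.randn(4, 2)]}
Y = torch.tensor([1, 2, 1, 2])                      # 1-indexed gold labels
active = torch.tensor([True, True, False, True])    # mask of scored examples

loss = ce_loss("pred_head", intermediate_output_dict, Y, active)
probs = output("pred_head", intermediate_output_dict)  # [4, 2] softmax scores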
39 changes: 39 additions & 0 deletions src/emmental/contrib/slicing/slicing_function.py
@@ -0,0 +1,39 @@
import logging
from functools import wraps
from types import SimpleNamespace

import numpy as np
import torch

logger = logging.getLogger(__name__)


class slicing_function:
"""
When wrapped with this decorator, slicing functions only need to return an indicator
for whether an individual example (bundle of attributes) belongs in that slice.
Iterating through the dataset, making the pred array (and masking), etc. are all
handled automatically.
"""

    def __init__(self, fields=None):
        # Avoid a mutable default argument; fields lists the X_dict keys exposed
        # to the wrapped function.
        self.fields = fields if fields is not None else []

def __call__(self, f):
@wraps(f)
def wrapped_f(dataset):
inds = []
for idx in range(len(dataset)):
example = SimpleNamespace(
**{field: dataset.X_dict[field][idx] for field in self.fields}
)
in_slice = f(example)
inds.append(1 if in_slice else 0)
inds = torch.from_numpy(np.array(inds)).view(-1)
logger.info(
f"Total {int((inds == 1).sum())} / {len(dataset)} examples are "
f"in slice {f.__name__}"
)
return inds

return wrapped_f
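
A minimal sketch of the decorator in action, using a hypothetical dataset
stand-in (the wrapper only needs `__len__` and `X_dict`):

import torch

from emmental.contrib.slicing.slicing_function import slicing_function


class ToyDataset:
    def __init__(self):
        self.X_dict = {"text": ["hi", "a much longer example", "ok", "short"]}

    def __len__(self):
        return len(self.X_dict["text"])


# The wrapped function sees one example at a time, as a SimpleNamespace holding
# just the requested fields, and returns slice membership.
@slicing_function(fields=["text"])
def slice_long_text(example):
    return len(example.text) > 5


inds = slice_long_text(ToyDataset())
print(inds)  # tensor([0, 1, 0, 0]): 1 marks members of the slice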
