Fix Macro-F1 & Use train labels by default #135

Merged: 7 commits, Dec 20, 2021
1 change: 1 addition & 0 deletions example_config/MIMIC-50/bigru.yml
@@ -3,6 +3,7 @@ data_dir: data/MIMIC-50
data_name: MIMIC-50
min_vocab_freq: 3
max_seq_length: 2500
include_test_labels: true

# train
seed: 1337
2 changes: 2 additions & 0 deletions example_config/MIMIC-50/caml.yml
@@ -4,6 +4,8 @@ data_dir: data/MIMIC-50
data_name: MIMIC-50
min_vocab_freq: 3
max_seq_length: 2500
# We follow caml-mimic that includes labels in both training and test datasets.
include_test_labels: true

# train
seed: 1337
1 change: 1 addition & 0 deletions example_config/MIMIC-50/caml_tune.yml
@@ -3,6 +3,7 @@ data_dir: data/MIMIC-50
data_name: MIMIC-50
min_vocab_freq: 3
max_seq_length: 2500
include_test_labels: true

# train
seed: 1337
1 change: 1 addition & 0 deletions example_config/MIMIC-50/cnn.yml
@@ -3,6 +3,7 @@ data_dir: data/MIMIC-50
data_name: MIMIC-50
min_vocab_freq: 3
max_seq_length: 2500
include_test_labels: true

# train
seed: 1337
58 changes: 48 additions & 10 deletions libmultilabel/metrics.py
@@ -1,7 +1,7 @@
import re

import torch
import numpy as np
import torch
import torchmetrics.classification
from torchmetrics import Metric, MetricCollection, Precision, Recall, RetrievalNormalizedDCG
from torchmetrics.utilities.data import select_topk
@@ -40,12 +40,54 @@ def compute(self):
return self.score / self.num_sample


class MacroF1(Metric):
"""The macro-f1 score computes the average f1 scores of all labels in the dataset.

Args:
num_classes (int): The number of classes.
metric_threshold (float): Threshold to monitor for metrics.
another_macro_f1 (bool, optional): Whether to compute the 'Another-Macro-F1' score.
The 'Another-Macro-F1' is the F1 value computed from macro-precision and macro-recall.
This variant of macro-f1 is less preferred but is used in some works.
Please refer to Opitz et al. 2019 [https://arxiv.org/pdf/1911.03347.pdf].
Defaults to False.
"""
def __init__(
self,
num_classes,
metric_threshold,
another_macro_f1=False
):
super().__init__()
self.metric_threshold = metric_threshold
self.another_macro_f1 = another_macro_f1
self.add_state("preds_sum", default=torch.zeros(num_classes, dtype=torch.double))
self.add_state("target_sum", default=torch.zeros(num_classes, dtype=torch.double))
self.add_state("tp_sum", default=torch.zeros(num_classes, dtype=torch.double))

def update(self, preds, target):
assert preds.shape == target.shape
preds = torch.where(preds > self.metric_threshold, 1, 0)
self.preds_sum = torch.add(self.preds_sum, preds.sum(dim=0))
self.target_sum = torch.add(self.target_sum, target.sum(dim=0))
self.tp_sum = torch.add(self.tp_sum, (preds & target).sum(dim=0))

def compute(self):
if self.another_macro_f1:
macro_prec = torch.mean(torch.nan_to_num(self.tp_sum / self.preds_sum, posinf=0.))
macro_recall = torch.mean(torch.nan_to_num(self.tp_sum / self.target_sum, posinf=0.))
return 2 * (macro_prec * macro_recall) / (macro_prec + macro_recall + 1e-10)
else:
label_f1 = 2 * self.tp_sum / (self.preds_sum + self.target_sum + 1e-10)
return torch.mean(label_f1)


def get_metrics(metric_threshold, monitor_metrics, num_classes):
"""Map monitor metrics to the corresponding classes defined in `torchmetrics.Metric`
(https://torchmetrics.readthedocs.io/en/latest/references/modules.html).

Args:
metric_threshold (float): Thresholds to monitor for metrics.
metric_threshold (float): Threshold to monitor for metrics.
monitor_metrics (list): Metrics to monitor while validating.
num_classes (int): Total number of classes.

@@ -86,15 +128,11 @@ def get_metrics(metric_threshold, monitor_metrics, num_classes):
elif metric_abbr == 'nDCG':
metrics[metric] = RetrievalNormalizedDCG(k=top_k)
elif metric == 'Another-Macro-F1':
# The f1 value of macro_precision and macro_recall. This variant of
# macro_f1 is less preferred but is used in some works. Please
# refer to Opitz et al. 2019 [https://arxiv.org/pdf/1911.03347.pdf]
macro_prec = Precision(num_classes, metric_threshold, average='macro')
macro_recall = Recall(num_classes, metric_threshold, average='macro')
metrics[metric] = 2 * (macro_prec * macro_recall) / \
(macro_prec + macro_recall + 1e-10)
metrics[metric] = MacroF1(num_classes, metric_threshold, another_macro_f1=True)
elif metric == 'Macro-F1':
metrics[metric] = MacroF1(num_classes, metric_threshold)
elif match_metric:
average_type = match_metric.group(1).lower() # Micro or Macro
average_type = match_metric.group(1).lower() # Micro
metric_type = match_metric.group(2) # Precision, Recall, or F1
metrics[metric] = getattr(torchmetrics.classification, metric_type)(
num_classes, metric_threshold, average=average_type)
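For reference, a minimal sketch of how the new MacroF1 metric behaves on its own: standard Macro-F1 averages the per-label F1 values, while Another-Macro-F1 first macro-averages precision and recall and then takes their F1. The toy tensors and the 0.5 threshold below are invented for illustration; only the MacroF1 class itself comes from this diff.

import torch
from libmultilabel.metrics import MacroF1

# Two samples, three labels: predictions are scores, targets are 0/1 indicators.
preds = torch.tensor([[0.9, 0.2, 0.7],
                      [0.6, 0.8, 0.1]])
target = torch.tensor([[1, 0, 1],
                       [1, 1, 0]])

macro_f1 = MacroF1(num_classes=3, metric_threshold=0.5)
macro_f1.update(preds, target)
print(macro_f1.compute())  # mean of the per-label F1 scores

another_macro_f1 = MacroF1(num_classes=3, metric_threshold=0.5, another_macro_f1=True)
another_macro_f1.update(preds, target)
print(another_macro_f1.compute())  # F1 computed from macro-precision and macro-recall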
19 changes: 12 additions & 7 deletions libmultilabel/nn/data_utils.py
@@ -227,27 +227,32 @@ def load_or_build_text_dict(
return vocabs


def load_or_build_label(datasets, label_file=None, silent=False):
def load_or_build_label(datasets, label_file=None, include_test_labels=False):
"""Generate label set either by the given datasets or a predefined label file.

Args:
datasets (dict): A dictionary of datasets. Each dataset contains list of instances with index, label, and tokenized text.
datasets (dict): A dictionary of datasets. Each dataset contains list of instances
with index, label, and tokenized text.
label_file (str, optional): Path to a file holding all labels.
silent (bool, optional): Disable print. Defaults to False.
include_test_labels (bool, optional): Whether to include labels in the test dataset.
Defaults to False.

Returns:
list: A list of labels sorted in alphabetical order.
"""
if label_file:
logging.info('Load labels from {label_file}')
logging.info(f'Load labels from {label_file}.')
with open(label_file, 'r') as fp:
classes = sorted([s.strip() for s in fp.readlines()])
else:
classes = set()
for dataset in datasets.values():
for d in tqdm(dataset, disable=silent):
classes.update(d['label'])
for split, data in datasets.items():
if split == 'test' and not include_test_labels:
continue
for instance in data:
classes.update(instance['label'])
classes = sorted(classes)
logging.info(f'Read {len(classes)} labels.')
return classes


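A hedged sketch of the new label-collection default: the toy datasets dict below is invented for illustration, while load_or_build_label and its signature come from this diff. By default, labels that appear only in the test split are now excluded; passing include_test_labels=True (as the MIMIC-50 configs do, following caml-mimic) keeps them.

from libmultilabel.nn import data_utils

datasets = {
    'train': [{'index': 0, 'label': ['A', 'B'], 'text': ['some', 'tokens']}],
    'test': [{'index': 1, 'label': ['C'], 'text': ['more', 'tokens']}],
}

# New default: labels seen only in the test split are ignored.
print(data_utils.load_or_build_label(datasets))  # ['A', 'B']

# caml-mimic setting: keep test-only labels as well.
print(data_utils.load_or_build_label(datasets, include_test_labels=True))  # ['A', 'B', 'C']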
2 changes: 1 addition & 1 deletion libmultilabel/nn/model.py
@@ -19,7 +19,7 @@ class MultiLabelModel(pl.LightningModule):
optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'.
momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9.
weight_decay (int, optional): Weight decay factor. Defaults to 0.
metric_threshold (float, optional): Thresholds to monitor for metrics. Defaults to 0.5.
metric_threshold (float, optional): Threshold to monitor for metrics. Defaults to 0.5.
monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None.
log_path (str): Path to a directory holding the log files and models.
silent (bool, optional): Enable silent mode. Defaults to False.
2 changes: 1 addition & 1 deletion libmultilabel/nn/nn_utils.py
@@ -62,7 +62,7 @@ def init_model(model_name,
optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'.
momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9.
weight_decay (int, optional): Weight decay factor. Defaults to 0.
metric_threshold (float, optional): Thresholds to monitor for metrics. Defaults to 0.5.
metric_threshold (float, optional): Threshold to monitor for metrics. Defaults to 0.5.
monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None.
silent (bool, optional): Enable silent mode. Defaults to False.
save_k_predictions (int, optional): Save top k predictions on test set. Defaults to 0.
2 changes: 2 additions & 0 deletions main.py
@@ -47,6 +47,8 @@ def get_config():
help='Whether to shuffle training data before each epoch (default: %(default)s)')
parser.add_argument('--merge_train_val', action='store_true',
help='Whether to merge the training and validation data. (default: %(default)s)')
parser.add_argument('--include_test_labels', action='store_true',
help='Whether to include labels in the test dataset. (default: %(default)s)')

# train
parser.add_argument('--seed', type=int,
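With the flag wired into get_config(), the caml-mimic label setting can also be enabled per run from the command line. The invocation below is illustrative only and assumes the usual --config entry point of main.py, which is not part of this diff; <your_config> is a placeholder.

python main.py --config <your_config>.yml --include_test_labels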
2 changes: 1 addition & 1 deletion torch_trainer.py
@@ -111,7 +111,7 @@ def _setup_model(
)
if not classes:
classes = data_utils.load_or_build_label(
self.datasets, self.config.label_file, self.config.silent)
self.datasets, self.config.label_file, self.config.include_test_labels)

if self.config.val_metric not in self.config.monitor_metrics:
logging.warn(