Commit
[Feature] ViT trace support and example (#165)
Co-authored-by: zhangyunchen <zhangyunchen@sensetime.com>
1 parent 9335ab3 · commit ffb0eb0
Showing 9 changed files with 51,730 additions and 0 deletions.
@@ -0,0 +1,109 @@ (new file: accuracy metric script)
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Accuracy metric."""

from sklearn.metrics import accuracy_score, top_k_accuracy_score

import datasets
import numpy as np


_DESCRIPTION = """
Accuracy is the proportion of correct predictions among the total number of cases processed. It can be computed with:
Accuracy = (TP + TN) / (TP + TN + FP + FN)
Where:
TP: True positive
TN: True negative
FP: False positive
FN: False negative
"""


_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `int`): Predicted labels.
    references (`list` of `int`): Ground truth labels.
    normalize (`boolean`): If set to False, returns the number of correctly classified samples. Otherwise, returns the fraction of correctly classified samples. Defaults to True.
    sample_weight (`list` of `float`): Sample weights. Defaults to None.
Returns:
    accuracy (`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input if `normalize` is set to `False`. A higher score means higher accuracy.
Examples:
    Example 1-A simple example
        >>> accuracy_metric = datasets.load_metric("accuracy")
        >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0])
        >>> print(results)
        {'accuracy': 0.5}
    Example 2-The same as Example 1, except with `normalize` set to `False`.
        >>> accuracy_metric = datasets.load_metric("accuracy")
        >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], normalize=False)
        >>> print(results)
        {'accuracy': 3.0}
    Example 3-The same as Example 1, except with `sample_weight` set.
        >>> accuracy_metric = datasets.load_metric("accuracy")
        >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], sample_weight=[0.5, 2, 0.7, 0.5, 9, 0.4])
        >>> print(results)
        {'accuracy': 0.8778625954198473}
"""


_CITATION = """
@article{scikit-learn,
  title={Scikit-learn: Machine Learning in {P}ython},
  author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
          and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
          and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
          Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
  journal={Journal of Machine Learning Research},
  volume={12},
  pages={2825--2830},
  year={2011}
}
"""


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Accuracy(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("int32")),
                    "references": datasets.Sequence(datasets.Value("int32")),
                }
                if self.config_name == "multilabel"
                else {
                    # Predictions are per-class scores so that top-k accuracy can be computed.
                    "predictions": datasets.Sequence(datasets.Value("float32")),
                    "references": datasets.Value("int32"),
                }
            ),
            reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
        )

    def _compute(self, predictions, references, normalize=True, sample_weight=None):
        return {
            "top-1_accuracy": float(
                top_k_accuracy_score(references, predictions, k=1, normalize=normalize, sample_weight=sample_weight)
            ),
            "top-5_accuracy": float(
                top_k_accuracy_score(references, predictions, k=5, normalize=normalize, sample_weight=sample_weight)
            ),
        }
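Note that the docstring examples above exercise the stock label-based accuracy metric, while the `_compute` override returns top-1 and top-5 accuracy from per-class scores via scikit-learn's `top_k_accuracy_score`, which is why the non-multilabel features declare `predictions` as a sequence of floats. A minimal sketch of the inputs it expects, using made-up toy data rather than anything from the commit:

import numpy as np
from sklearn.metrics import top_k_accuracy_score

# Hypothetical toy data: 8 samples, 10 classes. y_score stands in for the
# per-class probabilities/logits a ViT classification head would produce.
rng = np.random.default_rng(0)
num_samples, num_classes = 8, 10
y_score = rng.random((num_samples, num_classes))
y_true = rng.integers(0, num_classes, size=num_samples)

# The metric's _compute makes the same k=1 and k=5 calls; `labels` is given
# here only so sklearn knows the full class set even if the toy sample does
# not contain every class.
top1 = top_k_accuracy_score(y_true, y_score, k=1, labels=np.arange(num_classes))
top5 = top_k_accuracy_score(y_true, y_score, k=5, labels=np.arange(num_classes))
print({"top-1_accuracy": float(top1), "top-5_accuracy": float(top5)})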
@@ -0,0 +1,78 @@ (new file: quantization config for the ViT-Base PTQ example)
quant:
  a_qconfig:
    quantizer: FixedFakeQuantize
    observer: EMAMSEObserver # EMAMSEObserver EMAMinMaxObserver EMAQuantileObserver EMAPruneMinMaxObserver
    bit: 8
    symmetric: False
    ch_axis: -1 # -1 = per-layer
  w_qconfig:
    quantizer: FixedFakeQuantize
    observer: MinMaxObserver
    bit: 8
    symmetric: False
    ch_axis: 0 # 0 = per-channel, -1 = per-layer
  calibrate: 1024
  backend: academic

data:
  dataset_name: null
  dataset_config_name: null
  train_dir: /root/imagenet/ILSVRC/Data/CLS-LOC/train
  validation_dir: /root/imagenet/ILSVRC/Data/CLS-LOC/val
  train_val_split: null
  max_train_samples: null
  max_eval_samples: null

model:
  model_name_or_path: /root/pretrained-models/vit-base-patch16-224
  model_type: null
  config_name: null
  cache_dir: null
  model_revision: null
  feature_extractor_name: null
  use_auth_token: False
  ignore_mismatched_sizes: False

train:
  seed: 42
  output_dir: ptq-vit-base
  overwrite_output_dir: True # use this to continue training if output_dir points to a checkpoint directory
  do_train: False
  do_eval: True
  do_predict: False
  evaluation_strategy: "epoch" # The evaluation strategy to use: "no"; "steps"; "epoch"
  eval_steps: null # Run an evaluation every X steps.
  per_device_train_batch_size: 32 # Batch size per GPU/TPU core/CPU for training.
  per_device_eval_batch_size: 32 # Batch size per GPU/TPU core/CPU for evaluation.
  gradient_accumulation_steps: 1 # Number of update steps to accumulate before performing a backward/update pass.
  learning_rate: 1.0e-5 # The initial learning rate for AdamW.
  weight_decay: 0.01 # Weight decay for AdamW, if any is applied.
  max_grad_norm: 1.0 # Max gradient norm.
  num_train_epochs: 10.0 # Total number of training epochs to perform.
  max_steps: -1 # If > 0: set total number of training steps to perform. Overrides num_train_epochs.
  lr_scheduler_type: linear # The scheduler type to use.
  warmup_ratio: 0.06 # Linear warmup over warmup_ratio fraction of total steps.
  warmup_steps: 0 # Linear warmup over warmup_steps.
  gradient_checkpointing: False # If True, use gradient checkpointing to save memory at the expense of a slower backward pass.
  remove_unused_columns: False
  label_names: ['labels']

progress:
  log_level: passive # Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug', 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and lets the application set the level. Defaults to 'passive'.
  log_level_replica: passive # Logger log level to use on replica nodes.
  logging_dir: null # TensorBoard log dir.
  logging_strategy: epoch # The logging strategy to use: "no"; "steps"; "epoch".
  logging_steps: null # Log every X update steps.

  save_strategy: "epoch" # The checkpoint save strategy to use: "no"; "steps"; "epoch".
  save_steps: null # Save a checkpoint every X update steps.
  save_total_limit: null # Limit the total number of checkpoints.
  # Deletes the older checkpoints in output_dir. Default is unlimited checkpoints.
  save_on_each_node: False # When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on the main one.

  no_cuda: False # Do not use CUDA even when it is available.
  run_name: null # An optional descriptor for the run. Notably used for wandb logging.
  disable_tqdm: null # Whether or not to disable the tqdm progress bars. Use False or True.

  load_best_model_at_end: False # Whether or not to load the best model found during training at the end of training.
  metric_for_best_model: null # The metric to use to compare two different models.
  greater_is_better: null # Whether the `metric_for_best_model` should be maximized or not.
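The config is plain key/value YAML, so it can be consumed with a standard YAML loader. A minimal sketch, assuming PyYAML and a placeholder file name (the actual file name and the example script that consumes it are not shown in this excerpt):

import yaml

# "ptq_vit_base.yaml" is a placeholder name for the config file shown above.
with open("ptq_vit_base.yaml") as f:
    cfg = yaml.safe_load(f)

# Top-level sections mirror the file: quant / data / model / train / progress.
print(cfg["quant"]["a_qconfig"]["observer"])        # EMAMSEObserver
print(cfg["quant"]["w_qconfig"]["ch_axis"])         # 0 (per-channel weights)
print(cfg["model"]["model_name_or_path"])           # /root/pretrained-models/vit-base-patch16-224
print(cfg["train"]["per_device_eval_batch_size"])   # 32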