add swinv2_loader into libai #353

Merged
merged 8 commits into from
Aug 15, 2022

shaoshitong
Contributor

Add swinv2_loader to LiBai so that Hugging Face weights can be loaded into LiBai models.
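
As a rough illustration of what a loader like this typically has to do, the sketch below renames checkpoint keys from one naming scheme to another and fuses separate query/key/value weights into a single qkv entry. The key names, the `RENAME_RULES` table, and the helper `convert_state_dict` are all hypothetical and chosen only for illustration; they are not taken from the loader added in this PR.

```python
# Hypothetical sketch: key names and rename rules are illustrative only,
# not the real Hugging Face or LiBai parameter names.
RENAME_RULES = [
    ("encoder.layers.", "layers."),
    ("attention.output.dense.", "attn.proj."),
]


def rename(key):
    """Apply every rename rule to a key, in order."""
    for old, new in RENAME_RULES:
        key = key.replace(old, new)
    return key


def convert_state_dict(src):
    """Rename keys and fuse separate q/k/v weights into one qkv entry."""
    dst = {}
    fused = {}  # key prefix -> {"query": rows, "key": rows, "value": rows}
    for key, value in src.items():
        for part in ("query", "key", "value"):
            marker = f"attention.self.{part}.weight"
            if key.endswith(marker):
                fused.setdefault(key[: -len(marker)], {})[part] = value
                break
        else:
            # Not a q/k/v weight: only the name changes.
            dst[rename(key)] = value
    for prefix, parts in fused.items():
        # Concatenate along the output dimension (rows), in q, k, v order.
        dst[rename(prefix) + "attn.qkv.weight"] = (
            parts["query"] + parts["key"] + parts["value"]
        )
    return dst
```

Weights are plain nested lists here so the sketch runs without a tensor library; a real loader would concatenate tensors instead.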

@shaoshitong
Contributor Author

shaoshitong commented Aug 12, 2022

Test code:

import logging
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from libai.config import LazyConfig, default_argument_parser, try_get_key
from configs.common.models.swinv2.swinv2_tiny_patch4_window8_256 import cfg as model_cfg
from libai.engine import DefaultTrainer, default_setup
from libai.utils.checkpoint import Checkpointer
from libai.models.utils import SwinV2LoaderHuggerFace
from libai.models import SwinTransformerV2
logger = logging.getLogger("libai." + __name__)


def main(args):
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)

    if args.fast_dev_run:
        cfg.train.train_epoch = 0
        cfg.train.train_iter = 20
        cfg.train.evaluation.eval_period = 10
        cfg.train.log_period = 1

    if args.eval_only:
        tokenizer = None
        if try_get_key(cfg, "tokenization") is not None:
            tokenizer = DefaultTrainer.build_tokenizer(cfg)
        loader = SwinV2LoaderHuggerFace(
            SwinTransformerV2,
            model_cfg,
            "/root/.cache/huggingface/hub/models--microsoft--swinv2-tiny-patch4-window8-256/snapshots/2b979ac403df19f72443cd151e9e957842eb9645",
        )
        model = loader.load()
        if try_get_key(cfg, "train.graph.enabled", default=False):
            model = DefaultTrainer.build_graph(cfg, model, is_train=False)
        test_loader = DefaultTrainer.build_test_loader(cfg, tokenizer)
        if len(test_loader) == 0:
            logger.info("No dataset in dataloader.test, please set dataset for dataloader.test")
        _ = DefaultTrainer.test(cfg, test_loader, model)
        return

    trainer = DefaultTrainer(cfg)
    return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    main(args)

Command:
bash tools/train.sh test_swinv2.py configs/swinv2_imagenet.py 1 --eval-only
Test results:

[08/12 23:44:36 lb.engine.default]: Prepare testing set
[08/12 23:44:36 lb.evaluation.evaluator]: with eval_iter 100000.0, reset total samples 50000 to 50000
[08/12 23:44:36 lb.evaluation.evaluator]: Start inference on 50000 samples
[08/12 23:44:43 lb.evaluation.evaluator]: Inference done 1408/50000. Dataloading: 0.0007 s/iter. Inference: 0.0187 s/iter. Eval: 0.2582 s/iter. Total: 0.2776 s/iter. ETA=0:01:45
[08/12 23:44:48 lb.evaluation.evaluator]: Inference done 3712/50000. Dataloading: 0.0008 s/iter. Inference: 0.0159 s/iter. Eval: 0.2626 s/iter. Total: 0.2794 s/iter. ETA=0:01:40
[08/12 23:44:53 lb.evaluation.evaluator]: Inference done 6016/50000. Dataloading: 0.0008 s/iter. Inference: 0.0171 s/iter. Eval: 0.2630 s/iter. Total: 0.2810 s/iter. ETA=0:01:36
[08/12 23:44:58 lb.evaluation.evaluator]: Inference done 8320/50000. Dataloading: 0.0008 s/iter. Inference: 0.0165 s/iter. Eval: 0.2642 s/iter. Total: 0.2815 s/iter. ETA=0:01:31
[08/12 23:45:04 lb.evaluation.evaluator]: Inference done 10624/50000. Dataloading: 0.0009 s/iter. Inference: 0.0170 s/iter. Eval: 0.2644 s/iter. Total: 0.2824 s/iter. ETA=0:01:26
[08/12 23:45:09 lb.evaluation.evaluator]: Inference done 12928/50000. Dataloading: 0.0009 s/iter. Inference: 0.0171 s/iter. Eval: 0.2647 s/iter. Total: 0.2827 s/iter. ETA=0:01:21
[08/12 23:45:14 lb.evaluation.evaluator]: Inference done 15232/50000. Dataloading: 0.0009 s/iter. Inference: 0.0170 s/iter. Eval: 0.2650 s/iter. Total: 0.2829 s/iter. ETA=0:01:16
[08/12 23:45:19 lb.evaluation.evaluator]: Inference done 17536/50000. Dataloading: 0.0009 s/iter. Inference: 0.0176 s/iter. Eval: 0.2648 s/iter. Total: 0.2834 s/iter. ETA=0:01:11
[08/12 23:45:24 lb.evaluation.evaluator]: Inference done 19840/50000. Dataloading: 0.0009 s/iter. Inference: 0.0182 s/iter. Eval: 0.2645 s/iter. Total: 0.2837 s/iter. ETA=0:01:06
[08/12 23:45:29 lb.evaluation.evaluator]: Inference done 22144/50000. Dataloading: 0.0009 s/iter. Inference: 0.0185 s/iter. Eval: 0.2644 s/iter. Total: 0.2839 s/iter. ETA=0:01:01
[08/12 23:45:34 lb.evaluation.evaluator]: Inference done 24448/50000. Dataloading: 0.0010 s/iter. Inference: 0.0188 s/iter. Eval: 0.2646 s/iter. Total: 0.2844 s/iter. ETA=0:00:56
[08/12 23:45:40 lb.evaluation.evaluator]: Inference done 26752/50000. Dataloading: 0.0009 s/iter. Inference: 0.0185 s/iter. Eval: 0.2650 s/iter. Total: 0.2846 s/iter. ETA=0:00:51
[08/12 23:45:45 lb.evaluation.evaluator]: Inference done 29056/50000. Dataloading: 0.0009 s/iter. Inference: 0.0183 s/iter. Eval: 0.2654 s/iter. Total: 0.2847 s/iter. ETA=0:00:46
[08/12 23:45:50 lb.evaluation.evaluator]: Inference done 31360/50000. Dataloading: 0.0010 s/iter. Inference: 0.0184 s/iter. Eval: 0.2655 s/iter. Total: 0.2850 s/iter. ETA=0:00:41
[08/12 23:45:55 lb.evaluation.evaluator]: Inference done 33664/50000. Dataloading: 0.0010 s/iter. Inference: 0.0187 s/iter. Eval: 0.2654 s/iter. Total: 0.2851 s/iter. ETA=0:00:36
[08/12 23:46:00 lb.evaluation.evaluator]: Inference done 35968/50000. Dataloading: 0.0010 s/iter. Inference: 0.0186 s/iter. Eval: 0.2655 s/iter. Total: 0.2852 s/iter. ETA=0:00:31
[08/12 23:46:05 lb.evaluation.evaluator]: Inference done 38272/50000. Dataloading: 0.0010 s/iter. Inference: 0.0188 s/iter. Eval: 0.2656 s/iter. Total: 0.2854 s/iter. ETA=0:00:25
[08/12 23:46:11 lb.evaluation.evaluator]: Inference done 40576/50000. Dataloading: 0.0010 s/iter. Inference: 0.0189 s/iter. Eval: 0.2657 s/iter. Total: 0.2857 s/iter. ETA=0:00:20
[08/12 23:46:16 lb.evaluation.evaluator]: Inference done 42880/50000. Dataloading: 0.0010 s/iter. Inference: 0.0188 s/iter. Eval: 0.2659 s/iter. Total: 0.2858 s/iter. ETA=0:00:15
[08/12 23:46:21 lb.evaluation.evaluator]: Inference done 45184/50000. Dataloading: 0.0010 s/iter. Inference: 0.0189 s/iter. Eval: 0.2660 s/iter. Total: 0.2859 s/iter. ETA=0:00:10
[08/12 23:46:26 lb.evaluation.evaluator]: Inference done 47488/50000. Dataloading: 0.0010 s/iter. Inference: 0.0192 s/iter. Eval: 0.2659 s/iter. Total: 0.2862 s/iter. ETA=0:00:05
[08/12 23:46:31 lb.evaluation.evaluator]: Inference done 49664/50000. Dataloading: 0.0010 s/iter. Inference: 0.0192 s/iter. Eval: 0.2664 s/iter. Total: 0.2866 s/iter. ETA=0:00:00
[08/12 23:46:33 lb.evaluation.evaluator]: Total valid samples: 50000
[08/12 23:46:33 lb.evaluation.evaluator]: Total inference time: 0:01:50.978096 (0.002220 s / iter per device, on 1 devices)
[08/12 23:46:33 lb.evaluation.evaluator]: Total inference pure compute time: 0:00:07 (0.000148 s / iter per device, on 1 devices)
[08/12 23:46:33 lb.engine.default]: Evaluation results for ImageNetDataset in csv format:
[08/12 23:46:33 lb.evaluation.utils]: copypaste: Acc@1=81.854
[08/12 23:46:33 lb.evaluation.utils]: copypaste: Acc@5=95.944
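
To compare runs like this one without reading the log by hand, the final metrics can be pulled out of the `copypaste:` lines. This is a minimal sketch that assumes the line format shown above; `parse_metrics` and `METRIC_RE` are hypothetical helpers, not part of LiBai.

```python
import re

# Matches evaluator lines of the form "... copypaste: Acc@1=81.854".
METRIC_RE = re.compile(
    r"copypaste:\s*(?P<name>[^=\s]+)=(?P<value>[-+]?\d+(?:\.\d+)?)"
)


def parse_metrics(log_text):
    """Return {metric_name: float} for every 'copypaste:' line in the log."""
    metrics = {}
    for line in log_text.splitlines():
        match = METRIC_RE.search(line)
        if match:
            metrics[match.group("name")] = float(match.group("value"))
    return metrics


sample_log = """\
[08/12 23:46:33 lb.evaluation.utils]: copypaste: Acc@1=81.854
[08/12 23:46:33 lb.evaluation.utils]: copypaste: Acc@5=95.944
"""
print(parse_metrics(sample_log))
```

The resulting dictionary can then be checked against the accuracy reported for the reference model.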

Inference accuracy of the reference implementation:
[Screenshot from 2022-08-12 23-56-52]

So the results are aligned.

@shaoshitong shaoshitong requested review from xiezipeng-ML and oneflow-ci-bot and removed request for oneflow-ci-bot August 12, 2022 16:16
@xiezipeng-ML
Contributor

Please add a test similar to this one for Swin; the values used in the allclose check need to be obtained from the Hugging Face model.
https://github.com/Oneflow-Inc/libai/blob/main/tests/model_utils/test_swin_loader.py
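
The kind of check being asked for can be sketched as follows. This is only an illustration of the tolerance comparison, written in plain Python so it runs without dependencies: the `allclose` helper mirrors the rule `numpy.allclose` applies (with looser default tolerances chosen here), and the numbers in `EXPECTED` and `actual` are placeholders, not values from any model.

```python
def allclose(actual, expected, rtol=1e-4, atol=1e-4):
    """Elementwise |a - e| <= atol + rtol * |e|, the rule numpy.allclose uses."""
    if len(actual) != len(expected):
        return False
    return all(
        abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected)
    )


# Placeholder numbers only: in a real test EXPECTED would be recorded once
# from the Hugging Face model on a fixed input, and `actual` would come from
# the LiBai model loaded through the loader on the same input.
EXPECTED = [0.1234, -5.6789]
actual = [0.12341, -5.67885]

assert allclose(actual, EXPECTED)
assert not allclose([0.2, -5.6789], EXPECTED)
```

Recording the expected values from the Hugging Face model, rather than from LiBai itself, is what makes the test catch a loader that maps weights incorrectly.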

@shaoshitong shaoshitong requested review from oneflow-ci-bot and removed request for oneflow-ci-bot August 15, 2022 04:26
@shaoshitong
Contributor Author

Please add a test similar to this one for Swin; the values used in the allclose check need to be obtained from the Hugging Face model. https://github.com/Oneflow-Inc/libai/blob/main/tests/model_utils/test_swin_loader.py

Added.

@shaoshitong shaoshitong requested review from oneflow-ci-bot and removed request for oneflow-ci-bot August 15, 2022 04:27
@xiezipeng-ML xiezipeng-ML requested review from oneflow-ci-bot and removed request for oneflow-ci-bot August 15, 2022 07:53
@xiezipeng-ML xiezipeng-ML requested review from oneflow-ci-bot and removed request for oneflow-ci-bot August 15, 2022 07:54
@xiezipeng-ML xiezipeng-ML enabled auto-merge (squash) August 15, 2022 07:56
@xiezipeng-ML xiezipeng-ML merged commit 1a19b66 into main Aug 15, 2022
@xiezipeng-ML xiezipeng-ML deleted the add_swinv2_loader_to_libai branch August 15, 2022 07:56