
aetherAI/tensorflow-huge-model-support


Tensorflow Huge Model Support (HMS)

This library is designed to speed up the training of huge models on unified memory. It takes a computation graph built by the user, analyzes it, and implements group execution and prefetching by editing the graph. A callback hook is provided to make it easy to apply HMS to a tf.keras model.
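To make the idea of group execution with prefetching more concrete, here is a conceptual sketch in plain Python. It is not HMS's implementation and does not edit a TensorFlow graph: ops are represented as strings, and `group_ops`, `prefetch_schedule`, the group size and the `lookahead` parameter are all hypothetical names chosen for illustration. The sketch splits a topologically ordered op list into consecutive groups and interleaves "prefetch" and "execute" steps so that the data for a later group is requested before the current group runs. On unified memory, the intent of such prefetching is to let page migration overlap with computation.

```python
from typing import List, Sequence, Tuple


def group_ops(ops: Sequence[str], group_size: int) -> List[List[str]]:
    """Split a topologically ordered op list into consecutive groups.

    Hypothetical helper for illustration; HMS performs its own analysis.
    """
    if group_size < 1:
        raise ValueError("group_size must be >= 1")
    return [list(ops[i:i + group_size]) for i in range(0, len(ops), group_size)]


def prefetch_schedule(num_groups: int, lookahead: int = 1) -> List[Tuple[str, int]]:
    """Interleave prefetch and execute steps over num_groups groups.

    Group g + lookahead is prefetched just before group g executes, so
    (conceptually) its data transfer can overlap with group g's compute.
    Every group is prefetched exactly once and before it executes.
    """
    if lookahead < 0:
        raise ValueError("lookahead must be >= 0")
    schedule: List[Tuple[str, int]] = []
    # Warm-up: prefetch the first `lookahead` groups before any execution.
    for g in range(min(lookahead, num_groups)):
        schedule.append(("prefetch", g))
    for g in range(num_groups):
        upcoming = g + lookahead
        if upcoming < num_groups:
            schedule.append(("prefetch", upcoming))
        schedule.append(("execute", g))
    return schedule
```

For example, `prefetch_schedule(3)` yields prefetch 0, prefetch 1, execute 0, prefetch 2, execute 1, execute 2. A larger `lookahead` hides more transfer latency at the cost of keeping more groups resident at once.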

Publications

Chen, CL., Chen, CC., Yu, WH. et al. An annotation-free whole-slide training approach to pathological classification of lung cancer types using deep learning. Nat Commun 12, 1193 (2021). https://doi.org/10.1038/s41467-021-21467-y

Chuang, WY., Chen, CC., Yu, WH. et al. Identification of nodal micrometastasis in colorectal cancer using deep learning on annotation-free whole-slide images. Mod Pathol (2021). https://doi.org/10.1038/s41379-021-00838-2

License

Copyright (C) 2021 aetherAI Co., Ltd. All rights reserved. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

Requirements

  • TensorFlow v1 (tensorflow-gpu==1.15.3)
  • GCC >= 7
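Before building, you may want to confirm that the environment matches these requirements. The sketch below is a hypothetical pre-install check, not part of HMS. It treats the pinned 1.15.3 as the only accepted TensorFlow version, which is stricter than "TensorFlow v1" and is our assumption, and it only compares the GCC major version against 7.

```python
from typing import Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn '1.15.3' into (1, 15, 3); stops at the first non-numeric part.

    Suffixes such as 'rc0' are ignored, so '1.15.3rc0' -> (1, 15, 3).
    """
    parts = []
    for piece in text.strip().split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_requirements(tf_version: str, gcc_version: str) -> bool:
    """Check the pinned TensorFlow 1.15.3 and GCC >= 7 (hypothetical helper)."""
    tf = parse_version(tf_version)
    gcc = parse_version(gcc_version)
    # An unparsable GCC version counts as not meeting the requirement.
    return tf[:3] == (1, 15, 3) and len(gcc) > 0 and gcc[0] >= 7
```

On a real machine the inputs could come from `tensorflow.__version__` and from `subprocess.check_output(["gcc", "-dumpversion"], text=True)`; depending on how GCC was configured, `-dumpversion` may print only the major version (e.g. `7`), which `parse_version` handles.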

Installation

To install HMS, run the following command from the repository root:

[CUDA_PATH=YOUR_CUDA_PATH] pip install .

where the optional CUDA_PATH setting points to your CUDA installation and defaults to /usr/local/cuda.
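If the build is driven from a Python script (for example, a provisioning step), the optional CUDA_PATH prefix corresponds to an environment variable set for the pip process. The helper below, `install_command`, is hypothetical and not part of HMS. It uses `python -m pip` rather than a bare `pip` so that the package is installed into the running interpreter, and it leaves CUDA_PATH untouched when no path is given, so any value already in the environment, or else the /usr/local/cuda default, applies.

```python
import os
import sys
from typing import Dict, List, Optional, Tuple


def install_command(
    cuda_path: Optional[str] = None,
    base_env: Optional[Dict[str, str]] = None,
) -> Tuple[List[str], Dict[str, str]]:
    """Build the argv and environment for installing HMS from the repo root.

    Hypothetical helper: the caller's base_env (default: os.environ) is
    copied, never mutated. CUDA_PATH is only set when cuda_path is given.
    """
    env = dict(os.environ if base_env is None else base_env)
    if cuda_path is not None:
        env["CUDA_PATH"] = cuda_path
    cmd = [sys.executable, "-m", "pip", "install", "."]
    return cmd, env
```

To actually install, run `subprocess.run(cmd, env=env, check=True)` with the repository root as the working directory.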

Usage

HMS can be applied to a tf.keras model through a callback, as described below.

  1. Import the HMS tf_keras module.
from tensorflow_huge_model_support.tf_keras import init, HMSTFKerasCallback
  2. Call init before building the model (and after Horovod has been initialized, if Horovod is used).

Without Horovod:

init()

With Horovod:

import horovod.tensorflow.keras as hvd
hvd.init()
init(hvd=hvd)
  3. Define an HMSTFKerasCallback.
hms_callback = HMSTFKerasCallback(
    hvd=hvd,
    default_batch_size=DEFAULT_BATCH_SIZE
)

where the hvd argument can be omitted if not using Horovod.

  4. Pass the callback to the Keras fit or fit_generator method.
model.fit_generator(..., callbacks=[hms_callback] + OTHER_CALLBACKS, ...)

Note: If using Horovod, don't forget to add hvd.callbacks.BroadcastGlobalVariablesCallback(0) to the callback list.
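Defining the HMS callback and assembling the callback list differ only in whether Horovod is used, so a training script that supports both modes can build the list in one place. The helper below, `build_callbacks`, is hypothetical and not part of HMS. It takes the callback class (HMSTFKerasCallback in practice) and the Horovod module as arguments, so it can be exercised without a GPU. It places BroadcastGlobalVariablesCallback(0) right after the HMS callback, an ordering chosen for illustration rather than one this README prescribes.

```python
from typing import Any, Callable, List, Optional, Sequence


def build_callbacks(
    callback_cls: Callable[..., Any],
    default_batch_size: int,
    other_callbacks: Sequence[Any] = (),
    hvd: Optional[Any] = None,
) -> List[Any]:
    """Assemble the Keras callback list for HMS (hypothetical helper).

    callback_cls: expected to be HMSTFKerasCallback.
    hvd: the horovod.tensorflow.keras module, or None without Horovod.
    """
    if hvd is None:
        # Without Horovod, the hvd argument is simply omitted.
        hms_callback = callback_cls(default_batch_size=default_batch_size)
        return [hms_callback] + list(other_callbacks)

    hms_callback = callback_cls(hvd=hvd, default_batch_size=default_batch_size)
    # With Horovod, broadcast rank 0's initial variables to all workers.
    broadcast = hvd.callbacks.BroadcastGlobalVariablesCallback(0)
    return [hms_callback, broadcast] + list(other_callbacks)
```

In a training script this would be used as `model.fit_generator(..., callbacks=build_callbacks(HMSTFKerasCallback, DEFAULT_BATCH_SIZE, OTHER_CALLBACKS, hvd=hvd), ...)`, passing `hvd=None` when Horovod is not used.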
