[Question] How can I pre-calculate the GPU memory required for embedding cache size? #427

Open
tuanavu opened this issue Oct 27, 2023 · 2 comments
Labels: question (Further information is requested)
tuanavu commented Oct 27, 2023

My company currently operates a recommender model trained with TensorFlow 2 (TF2) and served on CPU pods. We are exploring HugeCTR because of its promising GPU embedding cache capabilities and are considering switching our model to it.
We have successfully retrained our existing TF2 model with SparseOperationsKit (SOK) and created the inference graph with HPS, as demonstrated in the sok_to_hps_dlrm_demo.ipynb and demo_for_tf_trained_model.ipynb notebooks.

Results:
We deployed the model and used Triton's perf_analyzer to test its performance with varying batch sizes. The results were as follows:

  • Batch size: 24k
    • GPU memory usage: 13013MiB / 15109MiB (with "gpucacheper" set at 0.8)
    • GPU utilization: 51%
  • Batch size: 16k
    • GPU memory usage: 14013MiB / 15109MiB (with "gpucacheper" set at 1.0)
    • GPU utilization: 50%

Testing Environment:

To maximize throughput, we plan to test the model across different instance types with varying GPU memory sizes. However, tuning the different parameters in the config and selecting the best instance type for inference requires a clear understanding of how the embedding cache size is calculated.

Details about the current model and embedding tables:

Our current model has various dense, sparse, and pre-trained sparse features. After exporting the TF+SOK model to HPS, we have a total of 42 embedding tables, i.e., the sparse_files in hps_config.json. Here are the stats:

  • dense features: 1 embedding table
    • embedding_dimension: 2
    • num features: 221
    • total rows: 16343 (sum of (num quantiles+1))
    • max_nnz: 1
  • trainable sparse features: 38 embedding tables in total
    • embedding dimension: 2
    • 1d features (single table)
      • num features: 8
      • total rows: 12239 (sum of (num vocabs+1))
      • max_nnz: 1
    • Nd features: 37 tables in total, which together have:
      • num features: 37
      • total rows: 1280412 (sum of (num vocabs+1))
      • max_nnz: 100
  • pre-trained sparse features: 3 embedding tables in total
    • embedding dimension: 8
    • num features: 3
      • sf_1, sf_7
        • total rows: 262145
        • max_nnz: 100
      • sf_2
        • total rows: 524289
        • max_nnz: 1
  • hps.Init output
====================================================HPS Create====================================================
[HCTR][19:30:44.749][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][19:30:44.749][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][19:30:44.749][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][19:30:44.749][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][19:30:44.749][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][19:30:44.765][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_dense_emb_monolith; cached 16343 / 16343 embeddings in volatile database (HashMapBackend); load: 16343 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.775][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sparse_emb_monolith; cached 12239 / 12239 embeddings in volatile database (HashMapBackend); load: 12239 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.789][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_101; cached 66672 / 66672 embeddings in volatile database (HashMapBackend); load: 66672 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.800][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_102; cached 61395 / 61395 embeddings in volatile database (HashMapBackend); load: 61395 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.812][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_11; cached 73572 / 73572 embeddings in volatile database (HashMapBackend); load: 73572 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.823][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_14; cached 66534 / 66534 embeddings in volatile database (HashMapBackend); load: 66534 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.835][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_15; cached 64777 / 64777 embeddings in volatile database (HashMapBackend); load: 64777 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.848][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_18; cached 59276 / 59276 embeddings in volatile database (HashMapBackend); load: 59276 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.859][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_3; cached 14670 / 14670 embeddings in volatile database (HashMapBackend); load: 14670 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.871][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_4; cached 19489 / 19489 embeddings in volatile database (HashMapBackend); load: 19489 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.881][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_5; cached 20859 / 20859 embeddings in volatile database (HashMapBackend); load: 20859 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.893][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_56; cached 52218 / 52218 embeddings in volatile database (HashMapBackend); load: 52218 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.904][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_6; cached 21863 / 21863 embeddings in volatile database (HashMapBackend); load: 21863 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.915][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_60; cached 11075 / 11075 embeddings in volatile database (HashMapBackend); load: 11075 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.926][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_67; cached 28075 / 28075 embeddings in volatile database (HashMapBackend); load: 28075 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.936][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_68; cached 26174 / 26174 embeddings in volatile database (HashMapBackend); load: 26174 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.947][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_69; cached 13388 / 13388 embeddings in volatile database (HashMapBackend); load: 13388 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.957][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_70; cached 13759 / 13759 embeddings in volatile database (HashMapBackend); load: 13759 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.968][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_71; cached 4157 / 4157 embeddings in volatile database (HashMapBackend); load: 4157 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.978][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_72; cached 4622 / 4622 embeddings in volatile database (HashMapBackend); load: 4622 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.992][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_73; cached 71902 / 71902 embeddings in volatile database (HashMapBackend); load: 71902 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.005][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_74; cached 81772 / 81772 embeddings in volatile database (HashMapBackend); load: 81772 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.018][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_75; cached 79451 / 79451 embeddings in volatile database (HashMapBackend); load: 79451 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.031][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_76; cached 66525 / 66525 embeddings in volatile database (HashMapBackend); load: 66525 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.044][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_77; cached 70591 / 70591 embeddings in volatile database (HashMapBackend); load: 70591 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.055][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_78; cached 29983 / 29983 embeddings in volatile database (HashMapBackend); load: 29983 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.068][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_79; cached 68439 / 68439 embeddings in volatile database (HashMapBackend); load: 68439 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.079][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_80; cached 35543 / 35543 embeddings in volatile database (HashMapBackend); load: 35543 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.089][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_89; cached 2170 / 2170 embeddings in volatile database (HashMapBackend); load: 2170 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.099][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_90; cached 2100 / 2100 embeddings in volatile database (HashMapBackend); load: 2100 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.109][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_91; cached 1762 / 1762 embeddings in volatile database (HashMapBackend); load: 1762 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.119][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_92; cached 1269 / 1269 embeddings in volatile database (HashMapBackend); load: 1269 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.128][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_93; cached 154 / 154 embeddings in volatile database (HashMapBackend); load: 154 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.138][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_94; cached 233 / 233 embeddings in volatile database (HashMapBackend); load: 233 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.148][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_95; cached 2771 / 2771 embeddings in volatile database (HashMapBackend); load: 2771 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.158][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_96; cached 1695 / 1695 embeddings in volatile database (HashMapBackend); load: 1695 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.168][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_97; cached 1931 / 1931 embeddings in volatile database (HashMapBackend); load: 1931 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.181][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_98; cached 72896 / 72896 embeddings in volatile database (HashMapBackend); load: 72896 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.193][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_99; cached 66650 / 66650 embeddings in volatile database (HashMapBackend); load: 66650 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.231][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_TopicVAEDigestAnswer180DayWDR-sf_2; cached 524289 / 524289 embeddings in volatile database (HashMapBackend); load: 524289 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.253][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_TopicVAETopics-sf_1; cached 262145 / 262145 embeddings in volatile database (HashMapBackend); load: 262145 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.274][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_TopicVAETopics-sf_7; cached 262145 / 262145 embeddings in volatile database (HashMapBackend); load: 262145 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.274][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][19:30:45.274][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][19:30:45.275][INFO][RK0][main]: Model name: full_model_max_mono
[HCTR][19:30:45.275][INFO][RK0][main]: Max batch size: 1024
[HCTR][19:30:45.275][INFO][RK0][main]: Number of embedding tables: 42
[HCTR][19:30:45.275][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][19:30:45.275][INFO][RK0][main]: Use static table: False
[HCTR][19:30:45.275][INFO][RK0][main]: Use I64 input key: True
[HCTR][19:30:45.275][INFO][RK0][main]: Configured cache hit rate threshold: 1.100000
[HCTR][19:30:45.275][INFO][RK0][main]: The size of thread pool: 16
[HCTR][19:30:45.275][INFO][RK0][main]: The size of worker memory pool: 1
[HCTR][19:30:45.275][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][19:30:45.275][INFO][RK0][main]: The refresh percentage : 0.200000
[HCTR][19:30:45.378][DEBUG][RK0][main]: Created raw model loader in local memory!
  • My hps_config.json used for inference
{
    "supportlonglong": true,
    "volatile_db": {
        "type": "parallel_hash_map",
        "allocation_rate": 100000.0,
        "initial_cache_rate": 1.0
    },
    "persistent_db": {
        "type": "disabled"
    },
    "models": [
        {
            "model": "full_model_max_mono",
            "sparse_files": [
                list_of_sparse_files
            ],
            "num_of_worker_buffer_in_pool": 2,
            "instance_group": 4,
            "embedding_table_names": [
                "hps_dense_emb_monolith",
                "hps_sparse_emb_monolith",
                "hps_sf_101",
                "hps_sf_102",
                "hps_sf_11",
                "hps_sf_14",
                "hps_sf_15",
                "hps_sf_18",
                "hps_sf_3",
                "hps_sf_4",
                "hps_sf_5",
                "hps_sf_56",
                "hps_sf_6",
                "hps_sf_60",
                "hps_sf_67",
                "hps_sf_68",
                "hps_sf_69",
                "hps_sf_70",
                "hps_sf_71",
                "hps_sf_72",
                "hps_sf_73",
                "hps_sf_74",
                "hps_sf_75",
                "hps_sf_76",
                "hps_sf_77",
                "hps_sf_78",
                "hps_sf_79",
                "hps_sf_80",
                "hps_sf_89",
                "hps_sf_90",
                "hps_sf_91",
                "hps_sf_92",
                "hps_sf_93",
                "hps_sf_94",
                "hps_sf_95",
                "hps_sf_96",
                "hps_sf_97",
                "hps_sf_98",
                "hps_sf_99",
                "hps_TopicVAEDigestAnswer180DayWDR-sf_2",
                "hps_TopicVAETopics-sf_1",
                "hps_TopicVAETopics-sf_7"
            ],
            "embedding_vecsize_per_table": [
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                8,
                8,
                8
            ],
            "maxnum_catfeature_query_per_table_per_sample": [
                221,
                7,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                1,
                100,
                100
            ],
            "default_value_for_each_table": [
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0
            ],
            "deployed_device_list": [
                0
            ],
            "max_batch_size": 24000,
            "cache_refresh_percentage_per_iteration": 0.2,
            "hit_rate_threshold": 1.1,
            "gpucacheper": 0.8,
            "gpucache": true
        }
    ]
}

Questions

  1. Given the specific details of our HPS model and the provided context, can you guide us on how to estimate the GPU memory needed to store the embedding cache at different batch sizes when serving with the HugeCTR backend? This will help us determine the optimal configuration and instance type to maximize our model's throughput during inference. (My own rough back-of-the-envelope attempt is sketched after this list.)
  2. Assuming the GPU memory is insufficient to store all embeddings, what would be the best configuration? I understand that I could reduce the GPU cache ratio and cache the entire embedding table in the CPU memory database (volatile_db). Could you confirm whether this is the correct approach?
  3. I also have a question regarding the allocation_rate setting in the volatile_db above. I observed that I must reduce allocation_rate to 1e6, or else the default allocation (256 MiB) leads to an out-of-memory error during hps.Init. Could you explain why this happens and provide some insight into this behavior?
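
For context, here is my rough back-of-the-envelope sketch for question 1. It assumes each cached row costs one 64-bit key plus an fp32 embedding vector, and it ignores any cache-internal overhead or working buffers, so it is only a lower bound; the row counts and vector sizes are taken from the stats above.

# Hedged, back-of-the-envelope estimate of the GPU embedding cache footprint.
# Assumption: each cached row costs an 8-byte int64 key plus
# embedding_vec_size * 4 bytes (fp32); cache-internal overhead is ignored.
KEY_BYTES = 8
FLOAT_BYTES = 4

# (total_rows, embedding_vec_size) summarized from the table stats above
tables = [
    (16343, 2),    # dense feature table
    (12239, 2),    # 1d trainable sparse table
    (1280412, 2),  # 37 Nd trainable sparse tables combined
    (262145, 8),   # sf_1
    (262145, 8),   # sf_7
    (524289, 8),   # sf_2
]

gpucacheper = 0.8  # fraction of each table kept in the GPU cache

total_bytes = sum(
    rows * gpucacheper * (KEY_BYTES + vec_size * FLOAT_BYTES)
    for rows, vec_size in tables
)
print(f"~{total_bytes / 2**20:.1f} MiB of cached embeddings (excluding overhead)")

This comes out far smaller than the GPU memory usage we observe, which is exactly why I would like to understand what else contributes to the footprint.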
bashimao (Collaborator) commented Oct 30, 2023

Regarding 2:
Using parallel_hash_map as your volatile_db is the suggested approach if you cannot fit the entire embedding table directly into GPU memory.
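
A minimal sketch of what that could look like in hps_config.json, expressed here as a Python fragment (the gpucacheper value is purely illustrative, not a recommendation):

# Sketch: keep 100% of the embeddings in the CPU hash map and cache only a
# fraction of each table on the GPU. Illustrative values only.
hps_config_fragment = {
    "volatile_db": {
        "type": "parallel_hash_map",
        "initial_cache_rate": 1.0,   # hold all embeddings in CPU memory
    },
    "models": [{
        "gpucache": True,
        "gpucacheper": 0.5,          # cache only half of each table on the GPU
    }],
}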

Regarding 3:
For performance reasons (to avoid frequent small allocations) and to limit long-term memory fragmentation, the hash_map backends allocate memory in chunks. The size of these chunks is 256 MiB. Since you have 42 tables, at least 42 x 256 MiB = 10752 MiB will be allocated. Given that your EC2 instance only has 16 GiB of memory, it is not too surprising that you see an OOM (out-of-memory) error. However, I noticed that your tables are rather small. I think it should be fine, without loss of performance, to decrease the allocation rate to 128 MiB, 100 MiB, or even lower, e.g. 64 MiB.
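
To make the arithmetic concrete, here is a quick sketch of the host-memory floor implied by the chunked allocation, assuming (as a lower bound) one chunk per table:

# Lower bound on host memory for the hash_map backends: at least one
# allocation chunk per table, regardless of how small the table is.
num_tables = 42

for chunk_mib in (256, 128, 100, 64):
    floor_mib = num_tables * chunk_mib
    print(f"allocation chunk = {chunk_mib} MiB -> at least {floor_mib} MiB allocated")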

yingcanw (Collaborator) commented

@tuanavu Regarding the 2nd question, I have some additional comments. We have supported fp8 quantization in the static embedding cache since v23.08. HPS performs fp8 quantization on the embedding vectors when reading the embedding table if you enable the "fp8_quant": true and "embedding_cache_type": "static" items in the HPS JSON configuration file, and it performs fp32 dequantization on the embedding vectors corresponding to the queried embedding keys in the static embedding cache, so as to preserve the accuracy of the dense-part prediction.

Since the embeddings are stored in fp8, the GPU memory footprint is greatly reduced. However, because business use cases differ, the precision loss caused by quantization/dequantization still needs to be evaluated in real production. So currently we only provide experimental support for the static embedding cache for POC verification. If quantization brings significant benefits in your case, we will add quantization features to the dynamic and the upcoming lock-free optimized GPU cache.
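
For reference, a minimal sketch of the per-model settings this refers to, assuming the rest of the model entry stays as in the hps_config.json above (only the two keys named in the comment are shown; treat them as experimental):

# Sketch: enable the experimental static embedding cache with fp8 quantization.
# Only the two keys below come from the comment above; everything else in the
# model entry stays as in the existing hps_config.json.
model_entry_additions = {
    "embedding_cache_type": "static",  # fp8 quantization requires the static cache
    "fp8_quant": True,                 # store embeddings as fp8, dequantize to fp32 on lookup
}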
