Added pytorch 2.0 benchmark and results
Ayuei committed Apr 4, 2023
1 parent fb69fdf commit e14c1a2
Showing 3 changed files with 56 additions and 0 deletions.
12 changes: 12 additions & 0 deletions benchmark/README.md
@@ -32,3 +32,15 @@ often experiments are repeated several times.
|          | Mean (s) | Stdev  | Max   | Min   |
|----------|----------|--------|-------|-------|
| Cache    | 0.178    | 0.0049 | 0.181 | 0.169 |
| No cache | 28.52    | 0.7664 | 29.03 | 27.19 |

## Torch compile

We also have a benchmark for measuring the performance gains from PyTorch 2.0.0's `torch.compile`: `benchmark_pytorch_compile.sh`.
We don't see a noticeable difference, which suggests that calls to the GPU model are not the bottleneck during the
encode stage. Most notably, it is the sentence segmentation that takes the longest.
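
For reference, the compiled variant wraps the pipeline's encoder model with `torch.compile` before running the queries (excerpted from `compiled_pytorch.py` below):

```python
import torch

from debeir import NIRPipeline

p = NIRPipeline.build_from_config(config_fp="./config.toml",
                                  engine="elasticsearch",
                                  nir_config_fp="./nir.toml")

# "reduce-overhead" targets smaller models that are called repeatedly.
p.run_config.encoder.model = torch.compile(p.run_config.encoder.model,
                                           mode="reduce-overhead")
```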

|                  | Mean (s) | Stdev | Max   | Min   |
|------------------|----------|-------|-------|-------|
| PyTorch compiled | 28.47    | 0.786 | 29.65 | 27.29 |
| PyTorch          | 28.29    | 0.473 | 29.10 | 27.72 |
16 changes: 16 additions & 0 deletions benchmark/benchmark_pytorch_compile.sh
@@ -0,0 +1,16 @@
#!/bin/sh

# Number of timed runs per variant.
n=5

echo "Testing compile"
for _ in $(seq 1 ${n}); do
    # Remove the cache so every run starts cold; -f avoids an error
    # if the directory does not exist yet.
    rm -rf cache/
    python compiled_pytorch.py
done

echo "Testing no compile"
for _ in $(seq 1 ${n}); do
    rm -rf cache/
    python warm_start_cache.py
done
28 changes: 28 additions & 0 deletions benchmark/compiled_pytorch.py
@@ -0,0 +1,28 @@
import time

import torch
from loguru import logger
from tqdm import tqdm

from debeir import NIRPipeline

# Silence debeir's own logging so it does not skew the timings.
logger.disable("debeir")


def run_all_queries(p):
    # Generate an embedding for every topic; this is the stage being timed.
    for topic_num in tqdm(p.engine.query.topics):
        p.engine.query.generate_query_embedding(topic_num)


if __name__ == "__main__":
    p = NIRPipeline.build_from_config(config_fp="./config.toml",
                                      engine="elasticsearch",
                                      nir_config_fp="./nir.toml")

    # Wrap the encoder model with torch.compile; "reduce-overhead" is aimed
    # at smaller models that are invoked many times.
    p.run_config.encoder.model = torch.compile(p.run_config.encoder.model, mode="reduce-overhead")

    start = time.time()
    run_all_queries(p)
    end = time.time()

    # Total wall-clock time in seconds.
    print(end - start)
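
One caveat when reading the timings: `torch.compile` is lazy, so the first forward pass inside `run_all_queries` also pays the one-off compilation cost. A minimal sketch of keeping that warm-up out of the measured region, using a toy module rather than debeir's actual encoder (the module and input shape are illustrative only):

```python
import time

import torch

# Toy stand-in for the encoder model; shapes are illustrative, not debeir's.
model = torch.compile(torch.nn.Linear(768, 768), mode="reduce-overhead")
x = torch.randn(8, 768)

model(x)  # first call triggers compilation; keep it outside the timed region

start = time.time()
for _ in range(100):
    model(x)
print(time.time() - start)
```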
