In [1]:
%load_ext autoreload
%autoreload 2
%load_ext snakeviz
%load_ext line_profiler

# NOTE(review): `%autoreload 2` above turns autoreload ON. The earlier note here
# said to turn autoreload off so that the old model code keeps being used while
# the current project is edited; use `%autoreload 0` if that is what is wanted.

# NOTE(review): wildcard import; helpers used in later cells that are not imported
# explicitly (e.g. proj_root, get_model_dir, get_dataset_dir, repr_modified_args)
# presumably come from here -- confirm before replacing with explicit imports.
from coeditor.common import *
import os

# Switch the working directory to whatever proj_root() returns, so relative
# paths used by later cells resolve against it.
os.chdir(proj_root())

In [2]:
from coeditor.retrieval_model import RetrievalEditorModel, AttentionMode, BatchArgs
from coeditor.api import EditPredictionService, QueryRefEditEncoder, BatchArgs, DecodingArgs
from coeditor.dataset import load_datasets
import torch
import copy

In [3]:
# Load the saved checkpoint, move it to device "cuda:2", and select the
# bidirectional attention mode.
model_path = get_model_dir(True) / "coeditor-large-request-stub-v2"
model = RetrievalEditorModel.load(model_path)
model.to("cuda:2")
model.attention_mode = AttentionMode.bidirectional

# Start from the default BatchArgs and rescale three of its limits for this run.
batch_args = copy.deepcopy(BatchArgs())
batch_args.max_total_ref_tks = batch_args.max_total_ref_tks // 3
# NOTE(review): `min_queires` looks like a misspelling of `min_queries`; it is
# kept as written because the attribute name must match the BatchArgs field --
# confirm against the BatchArgs definition before renaming.
batch_args.min_queires = batch_args.min_queires * 3
batch_args.max_queries = batch_args.max_queries * 2

In [4]:
# The dataset directory name is derived from the encoder's arguments.
encoder = QueryRefEditEncoder()
dataset_dir = get_dataset_dir("large") / repr_modified_args(encoder)

# Only the "train" split is loaded; keep its first 100 edits as the timing
# workload and drop the reference to the full split afterwards.
test_data = load_datasets(dataset_dir, ["train"])["train"]
test_edits = test_data.all_edits()[:100]
del test_data

In [6]:
# Timing of model.run_on_edits with query_ref_layer not batched.
# `-n 1 -r 2`: one loop per run, two runs.
%timeit -n 1 -r 2 model.run_on_edits(test_edits, batch_args)

[34mnum batches: 7,[0m [34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '9.0', 'max': '16.0'}[0m
[34mnum batches: 7,[0m [34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '4.0', 'max': '16.0'}[0m


Training Epoch 0: 100%|██████████| 7/7 [00:13<00:00,  1.91s/it]


[34mnum batches: 7,[0m [34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '9.0', 'max': '16.0'}[0m
[34mnum batches: 7,[0m [34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '4.0', 'max': '16.0'}[0m


Training Epoch 0: 100%|██████████| 7/7 [00:13<00:00,  1.91s/it]

13.4 s ± 16.4 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)





In [15]:
# Same measurement as the earlier %timeit cell (one loop per run, two runs).
# NOTE(review): presumably re-run after a change to the model code so the two
# timings can be compared -- record here what was changed.
%timeit -n 1 -r 2 model.run_on_edits(test_edits, batch_args)

[34mnum batches: 7,[0m [34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '4.0', 'max': '16.0'}[0m
[34mnum batches: 7,[0m [34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '5.0', 'max': '16.0'}[0m


Training Epoch 0: 100%|██████████| 7/7 [00:14<00:00,  2.01s/it]


[34mnum batches: 7,[0m [34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '7.0', 'max': '16.0'}[0m
[34mnum batches: 7,[0m [34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '8.0', 'max': '16.0'}[0m


Training Epoch 0: 100%|██████████| 7/7 [00:14<00:00,  2.03s/it]

14.2 s ± 102 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)





In [5]:
# Profile one run_on_edits call and view it in SnakeViz; `-t` opens the
# visualization in a new browser tab instead of inline.
%snakeviz -t model.run_on_edits(test_edits, batch_args)

[34mnum batches: 7,[0m [34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '9.0', 'max': '16.0'}[0m
[34mnum batches: 7,[0m [34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '4.0', 'max': '16.0'}[0m


Training Epoch 0: 100%|██████████| 7/7 [00:16<00:00,  2.34s/it]


 
*** Profile stats marshalled to file '/tmp/tmpshd62_qx'.
Opening SnakeViz in a new tab...
snakeviz web server started on 127.0.0.1:8080; enter Ctrl-C to exit
http://127.0.0.1:8080/snakeviz/%2Ftmp%2Ftmpshd62_qx


In [5]:
%load_ext line_profiler

from coeditor.retrieval_model import encode_query_block

%lprun -T lprof.txt -f encode_query_block model.profile_run(repeats=50)
None

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


test run: 100%|██████████| 50/50 [00:08<00:00,  6.18it/s]


*** Profile printout saved to text file 'lprof.txt'. 





Timer unit: 1e-09 s

Total time: 5.51746 s
File: /home/jiayi/Projects/SPOT/src/coeditor/retrieval_model.py
Function: encode_query_block at line 1148

Line #      Hits         Time  Per Hit   % Time  Line Contents
  1148                                           def encode_query_block(
  1149                                               block: T5Block,
  1150                                               query_hidden_states: Tensor,  # (n_queries, query_len, model_dim)
  1151                                               ref_hidden_states: Tensor,  # (n_queries, ref_len, model_dim)
  1152                                               position_bias: Tensor,
  1153                                               output_attentions: bool = False,
  1154                                           ) -> tuple[Tensor, ...]:
  1155                                               """Run a T5Block to encode the query. Instead of using self-attention, this uses
  1156                                   