# Evaluate parallel autoregressive inference speed

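Compare the inference speed of HoTPP's parallel autoregressive generation, which produces continuations from many positions of each sequence in a single call, with a simple prefix-extension baseline that runs generation separately on each truncated prefix.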

In [1]:
import os
import time
import math
import numpy as np
import hydra
import torch
import hotpp
import seaborn as sns
from hotpp.data import PaddedBatch
from matplotlib import pyplot as plt
from pytorch_lightning import seed_everything

# Experiment under test and the locations of its configs, data and checkpoints.
NAME = "rmtpp"
ROOT = "../experiments/stackoverflow/"
CONFIG_ROOT = os.path.join(ROOT, "configs")
CONFIG = NAME + ".yaml"
# NOTE(review): not referenced in the cells shown here; presumably a prediction
# horizon used elsewhere — confirm.
MAX_PREDICTIONS = 16

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

def create(device=None):
    """Build the experiment's model and data module from its Hydra config.

    Composes ``CONFIG`` from ``CONFIG_ROOT`` (job name ``NAME``). It overrides the batch
    size to 16 and points the train/val/test paths at ``ROOT/data/<split>.parquet``.
    It then instantiates the model in eval mode and loads the checkpoint stored at
    ``ROOT/<conf.model_path>``.

    Args:
        device: Optional device to move the model to after loading. ``None`` (default)
            leaves the model where it was instantiated, as before.

    Returns:
        Tuple ``(model, dm)``: the eval-mode model with checkpoint weights loaded, and
        the instantiated data module.
    """
    overrides = ["data_module.batch_size=16"] + [
        f"data_module.{split}_path={ROOT}/data/{split}.parquet"
        for split in ("train", "val", "test")
    ]
    with hydra.initialize(version_base=None, config_path=CONFIG_ROOT, job_name=NAME):
        conf = hydra.compose(config_name=CONFIG, overrides=overrides)
    model = hydra.utils.instantiate(conf.module).eval()
    path = os.path.join(ROOT, conf.model_path)
    # Map tensors to CPU so a checkpoint saved on GPU still loads on CPU-only machines;
    # load_state_dict copies the values into the model's own parameters anyway.
    checkpoint = torch.load(path, map_location="cpu")
    # Accept either a bare state dict or a checkpoint dict wrapping it under "state_dict".
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    model.load_state_dict(checkpoint)
    if device is not None:
        model = model.to(device)
    dm = hydra.utils.instantiate(conf.data_module)
    return model, dm

# Seed the RNGs (per Lightning docs: Python `random`, NumPy, torch) before building the
# model so that any stochastic generation below is reproducible.
seed_everything(seed=0)
module, dm = create()

Seed set to 0


In [2]:
# Take the first validation batch. Element 0 of the loader's item is what the model consumes
# (NOTE(review): presumably an (inputs, targets) pair — confirm against the data module).
first_val_item = next(iter(dm.val_dataloader()))
batch = first_val_item[0]
del first_val_item

In [3]:
# Positions to generate from: every INDEX_STEP-th event (0, 4, 8, ...), identical for all rows.
# Each row's length counts only the positions that fall inside that sequence.
INDEX_STEP = 4
# Create the positions on the batch's device so the comparison with seq_lens also works on GPU.
positions = torch.arange(0, batch.shape[1], INDEX_STEP, device=batch.device)
indices = PaddedBatch(positions[None].repeat(len(batch), 1),
                      (positions[None] < batch.seq_lens[:, None]).sum(1)).to(batch.device)

In [6]:
# Time HoTPP's parallel generation: one call predicts from every selected index at once.
n_trials = 10
# CUDA work is asynchronous, so synchronise around the timed region when a GPU exists;
# torch.cuda.synchronize() raises on CPU-only machines, so it is skipped there.
use_cuda_sync = torch.cuda.is_available()
if use_cuda_sync:
    torch.cuda.synchronize()
start = time.perf_counter()  # monotonic, high-resolution clock intended for benchmarking
for _ in range(n_trials):
    module.generate_sequences(batch, indices)
if use_cuda_sync:
    torch.cuda.synchronize()
elapsed = time.perf_counter() - start
print("Parallel")
print("Time", elapsed / n_trials)

Parallel
Time 0.14211785793304443


In [7]:
# Time the baseline: for each index column, truncate every sequence right after that
# index and run generation on the truncated prefixes, one column at a time.
n_trials = 10
# Synchronise only when CUDA exists; torch.cuda.synchronize() raises on CPU-only machines.
use_cuda_sync = torch.cuda.is_available()
if use_cuda_sync:
    torch.cuda.synchronize()
start = time.perf_counter()  # monotonic, high-resolution clock intended for benchmarking
for _ in range(n_trials):
    for i in range(indices.shape[1]):
        mask = i < indices.seq_lens  # rows that actually have an i-th index
        if not mask.any():
            continue  # padding column: no sequence reaches this index
        column = indices.payload[mask][:, i:i+1]
        # Keep events up to and including the index; derived from the index values
        # rather than re-hard-coding the index step.
        prefix = int(column.max()) + 1
        subbatch = PaddedBatch({k: batch.payload[k][mask, :prefix] for k in batch.seq_names},
                               batch.seq_lens[mask].clip(max=prefix))
        subindices = PaddedBatch(column, torch.ones(len(subbatch), dtype=torch.long, device=batch.device))
        module.generate_sequences(subbatch, subindices)
if use_cuda_sync:
    torch.cuda.synchronize()
elapsed = time.perf_counter() - start
print("Prefix extension")
print("Time", elapsed / n_trials)

Prefix extension
Time 2.440278959274292
