In [1]:
! pip install tokenizer datasets
! nvidia-smi

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Fri Oct  7 13:15:16 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:03:00.0  On |                  N/A |
| 43%   53C    P0   152W / 350W |    244MiB / 24576MiB |      5%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-

In [2]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
import time
import torchdynamo
import torch
from typing import List
from kernl.optimizer.dynamo_backend import dynamo_backend_ofi
from kernl.implementations.cuda_graph import cuda_graphs_wrapper

In [3]:
model_name = "BaptisteDoyen/camembert-base-xnli"

# Two separate instances of the same checkpoint, both switched to eval mode and moved
# to the GPU: `nli_model` is the plain eager baseline, while `nli_model_opt` is the
# instance that `run` sends through the TorchDynamo/kernl compiler, so the baseline
# never goes through that path.
nli_model, nli_model_opt = (
    AutoModelForSequenceClassification.from_pretrained(model_name).eval().cuda()
    for _ in range(2)
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Raise TorchDynamo's recompilation cache limit. Presumably this keeps one compiled
# graph per input shape from the warmup below without hitting the limit
# (TODO: confirm against the torchdynamo config docs).
torchdynamo.config.cache_size_limit = 256

In [4]:
# French configuration of the XNLI benchmark, matching the `camembert-base-xnli`
# checkpoint (model card: https://huggingface.co/BaptisteDoyen/camembert-base-xnli).
dataset = load_dataset(path="xnli", name="fr")

Found cached dataset xnli (/home/geantvert/.cache/huggingface/datasets/xnli/fr/1.1.0/818164464f9c9fd15776ca8a00423b074344c3e929d00a2c1a84aa5a50c928bd)


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# One CUDA memory pool shared by every CUDA graph captured by `compiler`, so the
# graphs can share memory rather than each reserving its own pool.
# `graph_pool_handle()` returns an opaque token (in practice a pair of ints); it is
# deliberately left unannotated, since `(int, int)` is a tuple, not a type hint.
pool = torch.cuda.graph_pool_handle()


def compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    """TorchDynamo backend: optimize the captured FX graph with kernl, then wrap it in a CUDA graph.

    Args:
        gm: FX graph captured by TorchDynamo from the model's forward pass.
        example_inputs: example tensors for ``gm``, forwarded to ``cuda_graphs_wrapper``.

    Returns:
        The callable produced by ``cuda_graphs_wrapper`` (using the shared ``pool``),
        which TorchDynamo runs in place of the captured graph.
    """
    # The return value is ignored. This presumably rewrites `gm` in place to use
    # kernl's kernels (TODO: confirm against kernl.optimizer.dynamo_backend).
    dynamo_backend_ofi(gm)
    return cuda_graphs_wrapper(gm, example_inputs, pool=pool)


def run(*args, **kwargs):
    """Call ``nli_model_opt`` with TorchDynamo enabled, using ``compiler`` as the backend.

    All positional and keyword arguments are forwarded unchanged to the model,
    and the model's output is returned.
    """
    with torchdynamo.optimize(compiler):
        return nli_model_opt(*args, **kwargs)

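Below we run a warmup pass: for every input shape we expect (batch size 1, sequence lengths from 8 to 128 in steps of 8), it builds the Triton kernels optimized for that shape.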

In [6]:
# Warmup: call the optimized model once per expected input shape, so that
# compilation (and CUDA graph capture) happens before anything is timed.
WARMUP_BATCH_SIZE = 1
WARMUP_MAX_SEQ_LEN = 128
WARMUP_SEQ_LEN_STEP = 8  # matches `pad_to_multiple_of=8` used when tokenizing the dataset
# NOTE(review): dataset pairs that tokenize to more than WARMUP_MAX_SEQ_LEN tokens are
# not warmed up here and would be compiled inside the timed loop (TODO: confirm max length).

# perf_counter is a high-resolution, monotonic duration clock, unaffected by wall-clock updates.
start = time.perf_counter()
shapes = [
    (WARMUP_BATCH_SIZE, seq_len)
    for seq_len in range(WARMUP_SEQ_LEN_STEP, WARMUP_MAX_SEQ_LEN + WARMUP_SEQ_LEN_STEP, WARMUP_SEQ_LEN_STEP)
]
with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
    for shape in shapes:
        # All-ones ids/mask are dummy values; presumably only the shape matters for compilation.
        inputs = {
            "input_ids": torch.ones(shape, device="cuda", dtype=torch.long),
            "attention_mask": torch.ones(shape, device="cuda", dtype=torch.long),
        }
        _ = run(**inputs)
        torch.cuda.synchronize()
print(f"{time.perf_counter() - start:.0f}s")

575s


In [7]:
# Benchmark: run every premise/hypothesis pair of the XNLI test split through both the
# eager baseline and the optimized model, accumulate each model's GPU-synchronized
# latency, and check that their logits agree.
complete_time_baseline = 0.0
complete_time_optimized = 0.0

with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
    for content in dataset["test"]:
        # Look fields up by name instead of unpacking `content.values()`, which
        # silently depends on the order (and count) of the dataset's columns.
        premise, hypothesis = content["premise"], content["hypothesis"]
        inputs = tokenizer(premise, hypothesis, return_tensors="pt", pad_to_multiple_of=8, padding=True)
        inputs = dict(inputs.to("cuda"))

        # Synchronize around each timed call so the timer covers only that model's GPU
        # work. The sync after the baseline call also acts as the start barrier for the
        # optimized call.
        torch.cuda.synchronize()
        start = time.perf_counter()
        output_original = nli_model(**inputs)
        torch.cuda.synchronize()
        complete_time_baseline += time.perf_counter() - start

        start = time.perf_counter()
        output_optimized = run(**inputs)
        torch.cuda.synchronize()
        complete_time_optimized += time.perf_counter() - start

        # Loose tolerance: both paths run under float16 autocast, presumably to absorb
        # fp16 numerical differences between the two implementations.
        assert torch.allclose(
            output_original.logits, output_optimized.logits, atol=1e-1
        ), f"logits don't match:\n{output_original}\n{output_optimized}"

print(f"{complete_time_baseline=:.2f}s")
print(f"{complete_time_optimized=:.2f}s")

complete_time_baseline=43.80s
complete_time_optimized=5.36s
