In [1]:
! pip install tokenizer datasets sentencepiece
! nvidia-smi

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Sun Oct  9 12:17:10 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:03:00.0  On |                  N/A |
| 50%   47C    P8    43W / 350W |    124MiB / 24576MiB |      3%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-

In [2]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
import time
import torch

In [3]:
# Checkpoint shared by both model copies and by the tokenizer.
model_name = "BaptisteDoyen/camembert-base-xnli"

# Baseline copy: loaded, switched to eval mode and moved to the GPU.
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval().cuda()

# Second, independently loaded copy of the same checkpoint; this is the one
# that is later passed to the optimizer, so the baseline stays untouched.
model_opt = AutoModelForSequenceClassification.from_pretrained(model_name).eval().cuda()

tokenizer = AutoTokenizer.from_pretrained(model_name)

# French configuration of the XNLI dataset.
dataset = load_dataset(path="xnli", name="fr")

Found cached dataset xnli (/home/geantvert/.cache/huggingface/datasets/xnli/fr/1.1.0/818164464f9c9fd15776ca8a00423b074344c3e929d00a2c1a84aa5a50c928bd)


  0%|          | 0/3 [00:00<?, ?it/s]

Below we run a warmup pass, which builds the Triton kernels optimized for each input size.

In [5]:
from kernl.model_optimization import optimize_model

# `optimize_model` returns a pair here; only the second element is used below
# as the optimized model.
_, optimized_model = optimize_model(model_opt)

# perf_counter is a monotonic, high-resolution clock intended for measuring
# elapsed intervals.
start = time.perf_counter()

# One (batch=1, seq_len) shape for every multiple of 8 from 8 to 128 inclusive.
shapes = [(1, w) for w in range(8, 128 + 8, 8)]
with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
    for s in shapes:
        # Dummy all-ones batch: only the shape matters for the warmup.
        inputs = {
            "input_ids": torch.ones(s, device="cuda", dtype=torch.long),
            "attention_mask": torch.ones(s, device="cuda", dtype=torch.long),
        }
        # Run both models once per shape so neither pays first-call costs
        # during the timed benchmark.
        _ = optimized_model(**inputs)
        _ = model(**inputs)

# CUDA launches are asynchronous: wait for all queued work to finish before
# reading the clock, otherwise the tail of the warmup is not counted.
torch.cuda.synchronize()
print(f"{time.perf_counter() - start:.0f}s")

509s


In [6]:
# Benchmark the baseline and the optimized model on the XNLI test split, one
# example at a time, accumulating per-model latency and number of correct
# predictions, and checking that both models produce matching logits.
complete_time_baseline = 0
score_baseline = 0
complete_time_optimized = 0
score_optimize = 0
nb_examples = len(dataset["test"])
nb_disagree = 0

with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
    for content in dataset["test"]:
        # Read the fields by name instead of relying on the column order of
        # the example dict.
        premise = content["premise"]
        hypothesis = content["hypothesis"]
        label = content["label"]

        # Pad the sequence length to a multiple of 8, matching the step used
        # for the shapes exercised in the warmup cell.
        inputs = tokenizer(premise, hypothesis, return_tensors="pt", pad_to_multiple_of=8, padding=True)
        inputs = dict(inputs.to("cuda"))

        # Baseline timing: synchronize on both sides so that only this forward
        # pass is measured (CUDA kernels are launched asynchronously).
        torch.cuda.synchronize()
        start = time.perf_counter()
        output_original = model(**inputs)
        torch.cuda.synchronize()
        complete_time_baseline += time.perf_counter() - start

        choice_baseline = torch.argmax(output_original.logits, dim=1)
        score_baseline += label == choice_baseline.item()

        # Optimized timing: bracketed by explicit synchronization exactly like
        # the baseline, rather than relying on the implicit sync of `.item()`.
        torch.cuda.synchronize()
        start = time.perf_counter()
        output_optimized = optimized_model(**inputs)
        torch.cuda.synchronize()
        complete_time_optimized += time.perf_counter() - start

        choice_optimize = torch.argmax(output_optimized.logits, dim=1)
        score_optimize += label == choice_optimize.item()

        # Both models must agree numerically within an absolute tolerance of 0.1.
        assert torch.allclose(
            output_original.logits, output_optimized.logits, atol=1e-1
        ), f"logits don't match:\n{output_original}\n{output_optimized}"
        # Count examples where the predicted class differs between the models.
        if choice_baseline != choice_optimize:
            nb_disagree += 1

print(f"{complete_time_baseline=:.2f}s")
print(f"{complete_time_optimized=:.2f}s")
print(f"{nb_disagree=}")
print(f"score baseline: {score_baseline / nb_examples:.2f}")
print(f"score optimize: {score_optimize / nb_examples:.2f}")

complete_time_baseline=43.35s
complete_time_optimized=5.27s
nb_disagree=0
score baseline: 0.82
score optimize: 0.82
