In [1]:
from importlib.metadata import version
import torch

# Environment setup: report the torch version, choose the compute device,
# seed the RNGs, and configure precision-related flags.
print("TORCH VERSION :", version("torch"))

# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
# The MPS availability check lives under `torch.backends` (plural);
# `torch.backend` does not exist and raises AttributeError as soon as CUDA is
# unavailable and the second branch is actually evaluated.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print('GPU  : ', device.upper())

# Fixed seeds so that runs are repeatable.
torch.manual_seed(123)
torch.cuda.manual_seed(123)

# Use bfloat16 only when a CUDA device reports support for it; otherwise float16.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
# NOTE(review): this name shadows the built-in `compile`; kept unchanged because
# later cells may read it.
compile = True
# Allow TF32 for CUDA matmuls and for cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

TORCH VERSION : 2.2.2
GPU  :  CUDA


In [20]:
# Imports for the three GPT-2 tokenizers being compared.
import tiktoken
from transformers import GPT2Tokenizer, GPT2TokenizerFast

# tiktoken's "gpt2" encoding.
tik_tokenizer = tiktoken.get_encoding("gpt2")

# Hugging Face GPT-2 tokenizer (GPT2Tokenizer class).
hf_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Hugging Face GPT-2 tokenizer (GPT2TokenizerFast class).
hf_tokenizer_fast = GPT2TokenizerFast.from_pretrained("gpt2")

In [21]:
# Fetch the "imdb" dataset through the Hugging Face `datasets` library;
# later cells read its 'test' split and 'text' column.
from datasets import load_dataset

dataset = load_dataset(path="imdb")

In [30]:
# Benchmark tiktoken: encode every text of the test split, one string per call,
# inside a Python list comprehension.
# Note: the `dataset['test']['text']` lookup is part of the timed expression,
# so its cost is included in the reported time.
%timeit [tik_tokenizer.encode(text, allowed_special={"<|endoftext|>"}) for text in  dataset['test']['text']]

1.57 s ± 13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [31]:
# Benchmark `hf_tokenizer`: pass the whole test-split text column in a single
# call and keep only the "input_ids" field of the result.
# Note: the `dataset['test']['text']` lookup is part of the timed expression,
# so its cost is included in the reported time.
%timeit hf_tokenizer(dataset['test']['text'],)["input_ids"]

Token indices sequence length is longer than the specified maximum sequence length for this model (1300 > 1024). Running this sequence through the model will result in indexing errors


9.91 s ± 236 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [35]:
# Benchmark `hf_tokenizer_fast`: pass the whole test-split text column in a
# single call and keep only the "input_ids" field of the result.
# Note: the `dataset['test']['text']` lookup is part of the timed expression,
# so its cost is included in the reported time.
%timeit hf_tokenizer_fast(dataset['test']['text'])["input_ids"]

1.63 s ± 70 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


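tiktoken is the fastest of the three (1.57 s), but the Hugging Face fast tokenizer is comparable in speed (1.63 s), while the regular Hugging Face tokenizer is much slower (9.91 s). I will go ahead with the Hugging Face fast tokenizer, as it offers a bunch of additional options.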

In [51]:
# Echo the currently selected compute device string.
device

'cuda'

In [54]:
import torch

# Small throwaway model whose parameters the optimizer will manage.
model = torch.nn.Linear(10, 10)

# Request the fused Adam implementation only when a CUDA device is present,
# and move the model there first: in the torch 2.2 line this notebook runs on,
# `fused=True` is rejected for parameters that are not CUDA floating-point
# tensors.
use_fused = torch.cuda.is_available()
if use_fused:
    model = model.cuda()

# Create an Adam optimizer, fused when supported.
optimizer = torch.optim.Adam(model.parameters(), fused=use_fused)

# The constructor options are recorded in `optimizer.defaults`; read the
# `fused` entry from there instead of probing a private attribute that is
# never set (which made the message below unreachable).
if optimizer.defaults.get("fused"):
    print("Using fused Adam optimizer.")