In [1]:
import torch

# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU.
# Apple MPS is intentionally not considered: in our runs it was again found to be
# drastically slower. If MPS is ever re-enabled, that branch previously also called
# torch._dynamo.disable() as a workaround for
# https://github.com/pytorch/pytorch/issues/149184
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")

device=device(type='cpu')


In [2]:
from transformers import AutoTokenizer
from datasets import load_dataset

# GPT-2 tokenizer from the Hugging Face Hub, plus the WikiText-103 corpus
# (the "wikitext-103-v1" configuration, all splits) used as training text.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="openai-community/gpt2",
)
dataset = load_dataset(
    path="wikitext",
    name="wikitext-103-v1",
)

In [3]:
# Number of tokens in every training example produced by `tokenize`.
context_length = 20

def tokenize(batch):
    """Cut each text in ``batch["text"]`` into ``context_length``-token chunks.

    The module-level ``tokenizer`` is called with truncation and
    ``return_overflowing_tokens=True``, so a single text can yield several
    chunks. Only chunks of exactly ``context_length`` tokens are kept; any
    shorter remainder is discarded.

    Returns a dict with one key, ``"input_ids"``, holding the kept chunks. This
    shape is what ``Dataset.map(..., batched=True)`` expects.
    """
    # TODO: Sequence packing
    encoded = tokenizer(
        batch["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    kept = []
    for chunk, n_tokens in zip(encoded["input_ids"], encoded["length"]):
        if n_tokens == context_length:
            kept.append(chunk)
    return {"input_ids": kept}

from pathlib import Path

from datasets import load_from_disk

# Tokenizing the full corpus is slow (the train split alone is ~1.8M rows and the
# first attempt was interrupted), so reuse a previously saved copy when it exists
# instead of re-mapping and overwriting it on every run.
# Delete this directory to force a rebuild, e.g. after changing `context_length`
# or `tokenize`.
TOKENIZED_DS_PATH = "tokenized-wiki-ds.hf"

if Path(TOKENIZED_DS_PATH).is_dir():
    tokenized_ds = load_from_disk(TOKENIZED_DS_PATH)
else:
    # remove_columns is required: `tokenize` can emit zero or several chunks per
    # input row, so the output row count no longer matches the original columns.
    tokenized_ds = dataset.map(
        tokenize, batched=True, remove_columns=dataset["train"].column_names
    )
    tokenized_ds.save_to_disk(TOKENIZED_DS_PATH)
tokenized_ds

Map:   0%|          | 0/1801350 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [12]:
from torch import nn

# Shape check for the model backbone: a causally-masked nn.TransformerEncoder
# run on random inputs (no token data yet).
# NOTE(review): the tokenized chunks use context_length = 20, while this
# prototype uses seq_length = 32. Align the two before feeding real token ids.
seq_length = 32
batch_size = 32
d_model = 512
nhead = 8
num_layers = 6

# TODO: Token and positional embedding
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, device=device)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# batch_first is left at its default (False), so inputs are (seq, batch, d_model).
# The input must be on the same device as the layer weights; previously it was
# always created on CPU, which breaks as soon as `device` is CUDA.
src = torch.rand(seq_length, batch_size, d_model, device=device)
# Square causal mask: position i may only attend to positions <= i.
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
transformer_encoder(src, mask=causal_mask).shape  # Skipping is_causal since seems troublesome: https://github.com/pytorch/pytorch/issues/96941
# TODO: Add a linear layer to map the output to the vocabulary size, and then softmax on that

torch.Size([32, 32, 512])