In [1]:
import torch

# Run on CUDA when it is available, otherwise fall back to the CPU.
# MPS is intentionally not selected because we yet again found it to be drastically
# slower. To try it again, add a branch guarded by `torch.backends.mps.is_available()`
# that calls `torch._dynamo.disable()`
# (https://github.com/pytorch/pytorch/issues/149184) and picks `torch.device("mps")`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")

device=device(type='cpu')


In [2]:
from transformers import AutoTokenizer
from datasets import load_dataset

# GPT-2 tokenizer and the WikiText-103 corpus, both referenced by their
# Hugging Face identifiers (the captured output shows the parquet shards being fetched).
tokenizer = AutoTokenizer.from_pretrained(
    "openai-community/gpt2",
)
dataset = load_dataset(
    path="wikitext",
    name="wikitext-103-v1",
)

test-00000-of-00001.parquet:   0%|          | 0.00/722k [00:00<?, ?B/s]

train-00000-of-00002.parquet:   0%|          | 0.00/156M [00:00<?, ?B/s]

train-00001-of-00002.parquet:   0%|          | 0.00/156M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/655k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/1801350 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [3]:
# Number of tokens per example: `tokenize` below passes this to the tokenizer as
# `max_length` and discards every chunk whose length is not exactly this value.
context_length = 20

def tokenize(batch):
    # TODO: Sequence packing
    outputs = tokenizer(
        batch["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    return {
        "input_ids": [
            input_ids
            for length, input_ids in zip(outputs["length"], outputs["input_ids"])
            if length == context_length
        ]
    }

# Apply `tokenize` to every split in batches. The original columns are removed
# because `tokenize` returns a different number of rows than it receives.
tokenized_ds = dataset.map(
    function=tokenize,
    batched=True,
    remove_columns=dataset["train"].column_names,
)
# Persist the result so later runs can reuse it, then display the split summary.
tokenized_ds.save_to_disk("tokenized-wiki-ds.hf")
tokenized_ds

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/1801350 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/12746 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5333343 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/11174 [00:00<?, ? examples/s]

DatasetDict({
    test: Dataset({
        features: ['input_ids'],
        num_rows: 12746
    })
    train: Dataset({
        features: ['input_ids'],
        num_rows: 5333343
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 11174
    })
})

In [12]:
from torch import nn

# Smoke test: push a random batch through a causal transformer encoder stack and
# check the output shape.
seq_length = 32
batch_size = 32
d_model = 512

# TODO: Token and positional embedding
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, device=device)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
# The layer uses the default batch_first=False, so the input is laid out as
# (seq_length, batch_size, d_model). It must be allocated on `device` like the model
# and the mask; leaving it on the CPU makes the forward pass fail with a device
# mismatch as soon as CUDA is selected.
src = torch.rand(seq_length, batch_size, d_model, device=device)
# Square (seq_length x seq_length) mask that stops each position from attending to
# later positions.
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
transformer_encoder(src, mask=causal_mask).shape  # Skipping is_causal since seems troublesome: https://github.com/pytorch/pytorch/issues/96941
# TODO: Add a linear layer to map the output to the vocabulary size, and then softmax on that

torch.Size([32, 32, 512])