In [1]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
# Commented out because we yet again find mps to be drastically slower
# elif torch.backends.mps.is_available():
#     torch._dynamo.disable()  # https://github.com/pytorch/pytorch/issues/149184
#     device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"{device=}")

device=device(type='cpu')


In [2]:
from transformers import AutoTokenizer
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
dataset = load_dataset("wikitext", "wikitext-103-v1")

In [4]:
from pathlib import Path
from datasets import load_from_disk

context_length = 20

def tokenize(batch):
    # TODO: Sequence packing
    outputs = tokenizer(
        batch["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    return {
        "input_ids": [
            input_ids
            for length, input_ids in zip(outputs["length"], outputs["input_ids"])
            if length == context_length
        ]
    }

if Path("tokenized-wiki-ds.hf").exists():
    tokenized_ds = load_from_disk("tokenized-wiki-ds.hf")
else:
    tokenized_ds = dataset.map(
        tokenize, batched=True, remove_columns=dataset["train"].column_names
    )
    tokenized_ds.save_to_disk("tokenized-wiki-ds.hf")
tokenized_ds

DatasetDict({
    test: Dataset({
        features: ['input_ids'],
        num_rows: 12746
    })
    train: Dataset({
        features: ['input_ids'],
        num_rows: 5333343
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 11174
    })
})

In [5]:
from torch import nn

class PositionalEmbedding(nn.Module):
    def __init__(self, embedding_dim, device):
        super().__init__()
        self.numerators = 10_000 ** (  # TODO: Why 10_000?
            torch.arange(
                start=0,
                end=embedding_dim,
                step=2,
                device=device,
            ).float()
            / embedding_dim
        )

    def forward(self, input_ids: torch.Tensor):
        with torch.no_grad():
            positions = torch.arange(
                input_ids.shape[1],
                device=input_ids.device,
            ).float()
            raw_embeddings = positions.unsqueeze(1) @ (1 / self.numerators).unsqueeze(0)
            even_embeddings = torch.sin(raw_embeddings)
            odd_embeddings = torch.cos(raw_embeddings)
            embeddings = torch.stack(
                [even_embeddings, odd_embeddings], dim=-1
            ).view(
                len(positions), -1
            )
            return embeddings.unsqueeze(0).expand(input_ids.shape[0], -1, -1)


token_embedder = nn.Embedding(
    num_embeddings=tokenizer.vocab_size, embedding_dim=512, device=device
)
positional_embedder = PositionalEmbedding(embedding_dim=512, device=device)
transformer_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, device=device)
transformer = nn.TransformerEncoder(transformer_layer, num_layers=6)
decoder = nn.Linear(512, tokenizer.vocab_size, device=device)

src = ["Hi, my name", "The United States of"]
tokenized = tokenizer(src, return_tensors="pt").to(device)
embedded = token_embedder(tokenized.input_ids) + positional_embedder(tokenized.input_ids)
transformed = transformer(
    embedded.permute(1, 0, 2),  # Transformer expects (seq_len, batch_size, features)
    mask=nn.Transformer.generate_square_subsequent_mask(tokenized.input_ids.shape[1], device=device),
    # Skipping is_causal since seems troublesome: https://github.com/pytorch/pytorch/issues/96941
)
logits = decoder(transformed.permute(1, 0, 2))  # Back to (batch_size, seq_len, features)
result = tokenizer.batch_decode(logits[:, -1, :].argmax(dim=-1))
result



[' Theft', ' playable']

In [6]:
class MyGPT(nn.Module):
    def __init__(self, d_model, nhead, num_layers, device):
        super().__init__()
        self.token_embedder = nn.Embedding(
            num_embeddings=tokenizer.vocab_size, embedding_dim=d_model, device=device
        )
        self.positional_embedder = PositionalEmbedding(embedding_dim=d_model, device=device)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, device=device),
            num_layers=num_layers,
        )
        self.decoder = nn.Linear(d_model, tokenizer.vocab_size, device=device)

    def forward(self, input_ids: torch.Tensor):
        embedded = self.token_embedder(input_ids) + self.positional_embedder(input_ids)
        transformed = self.transformer(
            embedded.permute(1, 0, 2),  # Transformer expects (seq_len, batch_size, features)
            mask=nn.Transformer.generate_square_subsequent_mask(input_ids.shape[1], device=input_ids.device),
        )
        logits = self.decoder(transformed.permute(1, 0, 2))
        return logits


model = MyGPT(d_model=512, nhead=8, num_layers=6, device=device)
model

MyGPT(
  (token_embedder): Embedding(50257, 512)
  (positional_embedder): PositionalEmbedding()
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Linear(in_features=512, out_features=50257, bias=True)
)

In [7]:
tokenizer.batch_decode(model(src)[:, -1, :].argmax(dim=-1))

TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not list