In [102]:
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, DataCollatorWithPadding, DataCollatorForLanguageModeling
from datasets import load_dataset

from smollama import Llama, LLaMAConfig

In [3]:

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})


dataset = load_dataset("roneneldan/TinyStories")




In [66]:
def tokenize_function(examples):
    return tokenizer(examples["text"], add_special_tokens=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Set the format to PyTorch tensors, but don't include padding yet
tokenized_datasets.set_format("torch", columns=["input_ids"], device="mps")



# Initialize a data collator that will dynamically pad the batches



Map:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Map:   0%|          | 0/21990 [00:00<?, ? examples/s]

ValueError: This tokenizer does not have a mask token which is necessary for masked language modeling. You should pass `mlm=False` to train on causal language modeling instead.

In [72]:
# data_collator = DataCollatorWithPadding(tokenizer)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, 
    return_tensors="pt",
    mlm=False
)

In [162]:
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    batch_size=32,
    shuffle=True,
    collate_fn=data_collator
)

eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    batch_size=32,
    collate_fn=data_collator
)


In [163]:
max_size = 0
for batch in tqdm(train_dataloader):
    max_size = max(max_size, batch["input_ids"].shape[1])
    # print(batch["input_ids"].shape)
    # print(batch["labels"].shape)
    # print(batch["input_ids"] == batch["labels"])
    # print(batch)
    # break

  2%|▏         | 1282/66242 [00:17<15:06, 71.66it/s]


KeyboardInterrupt: 

In [164]:
max_size

1236

In [86]:
batch["input_ids"]

tensor([[    1,  9038,  2501,  ..., 32000, 32000, 32000],
        [    1,  9038,  2501,  ..., 32000, 32000, 32000],
        [    1,  9038,   727,  ..., 32000, 32000, 32000],
        ...,
        [    1,   365,  2354,  ..., 32000, 32000, 32000],
        [    1,  9038,  2501,  ..., 32000, 32000, 32000],
        [    1,  9038,  2501,  ..., 32000, 32000, 32000]])

In [165]:
config = LLaMAConfig(
    block_size=2048,
    vocab_size=tokenizer.vocab_size,
    n_layer=8,
    n_head=8,
    n_embd=128,
)

In [166]:
model = Llama(config)

In [167]:
count = sum([p.numel() for p in model.parameters()])

In [168]:
count / 1e6

10.291328

In [52]:
model = model.to("mps")

In [97]:

logits = model(batch["input_ids"].to("mps"))

In [124]:
batch["labels"].shape

torch.Size([32, 384])

In [101]:
batch["attention_mask"]

tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])

In [127]:
from tqdm import tqdm

loss_fct = CrossEntropyLoss()

optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = 10

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_dataloader):
        # Forward pass
        inputs = batch["input_ids"].to("mps")
        labels = batch["labels"].to("mps")

        logits = model(inputs)
        loss = loss_fct(logits.view(-1, tokenizer.vocab_size), labels.view(-1))

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Compute the average loss for the epoch
    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")

  0%|          | 1/66242 [00:32<602:14:04, 32.73s/it]


KeyboardInterrupt: 