In [3]:
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, DataCollatorWithPadding, DataCollatorForLanguageModeling
from datasets import load_dataset
from tqdm.auto import tqdm

from smollama import Llama, LLaMAConfig, generate

In [4]:
DEVICE = "cpu"

In [5]:

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
eos_token = tokenizer.eos_token
tokenizer.add_special_tokens({'pad_token': eos_token})

0

In [17]:
tokenizer.decode([29871])

''

In [18]:
tokenizer("<s>One day", add_special_tokens=False)

{'input_ids': [1, 3118, 2462], 'attention_mask': [1, 1, 1]}

In [5]:


dataset = load_dataset("roneneldan/TinyStories")




In [13]:
def tokenize_function(examples):
    return tokenizer(examples["text"], add_special_tokens=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Set the format to PyTorch tensors, but don't include padding yet
tokenized_datasets.set_format("torch", columns=["input_ids"], device=DEVICE)



# Initialize a data collator that will dynamically pad the batches



Map:   0%|          | 0/2119719 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [6]:
# data_collator = DataCollatorWithPadding(tokenizer)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, 
    return_tensors="pt",
    mlm=False
)

In [9]:
train_dataloader = DataLoader(
    dataset["train"],
    batch_size=32,
    shuffle=True,
    collate_fn=data_collator
)

In [10]:
next(iter(train_dataloader))

{'text': ['Once upon a time, there was a little girl named Lily. She loved to ride her bike with her friends. One day, Lily and her friends were riding their bikes to the park. Lily\'s friend, Emily, had a new bike with pedals that were hard to push. \n\nEmily said, "I can\'t ride my bike very fast. My pedals are too hard to push." \n\nLily suggested, "Maybe we can help you pedal faster. You can ride with us and we can all have fun together!" \n\nEmily smiled and felt less shy. She was happy to ride with her friends and they all had a great time at the park. From that day on, they always helped each other when they needed it. The end.',
  'Once upon a time, there was a little girl named Sally. Sally had the most beautiful eyes and ever since she was born she always had a big smile on her face. One day Sally noticed something strange; it was dark outside and there was no moon in sight. She looked around and wondered what had happened.\n\nJust then, Sally heard a voice. It was her father

In [8]:
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    batch_size=32,
    shuffle=True,
    collate_fn=data_collator
)

eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    batch_size=32,
    collate_fn=data_collator
)


In [9]:
config = LLaMAConfig(
    block_size=2048,
    vocab_size=tokenizer.vocab_size,
    n_layer=8,
    n_head=8,
    n_embd=128,
)

In [10]:
model = Llama(config)
device = "cpu"

model = model.to(device)

In [11]:
generate(model, tokenizer, 100, "Once upon a time", device=device)

  0%|          | 0/100 [00:00<?, ?it/s]

'Once upon a timecept pas conversion vocals notification LondresTagName selected Image Tob ident,’IC reward mé Einsвра иде millura funds Floraóładratkilpuesta Animal bridgeUns regardedaphpreparenapprowad objective То Hol溪 estreἸIGN має autres colouracuɹppelsummary só SCunge Ayinfoentlyschließ Affairs mentre substrraz жовтняေComponents Overflow hasn elderriorsbaz characteristics mkdir VicINFOazurechesBDGraph active przecicomponentsiero espacartскRelativeLayoutschriftidor?( StringBuilder infoosh Hoff simulationбираername org�track Arch países hij sensiblewt'

In [12]:
count = sum([p.numel() for p in model.parameters()])
count / 1e6

10.291328

In [13]:
foo = torch.tensor(tokenizer.encode("Once upon a time"), dtype=torch.long).unsqueeze(0)

In [14]:
inp = tokenizer(["Once upon a time", "In a land far far away"], return_tensors="pt", padding=True)

In [15]:
inp

{'input_ids': tensor([[   1, 9038, 2501,  263,  931,    2,    2],
        [   1,  512,  263, 2982, 2215, 2215, 3448]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1]])}

In [16]:

loss_fct = CrossEntropyLoss()

optimizer = AdamW(model.parameters(), lr=5e-5)


# Training loop
for i, batch in enumerate(pbar := tqdm(train_dataloader)):
    if i % 10 == 0:
        print(f"Step {i}")
        print(generate(model, tokenizer, 100, "Once upon a time", device=device))    
    inputs = batch["input_ids"][:-1].to(DEVICE)
    attention_mask = batch["attention_mask"][:-1].to(DEVICE)
    labels = batch["labels"][1:].to(DEVICE)

    logits = model(inputs, attention_mask)
    loss = loss_fct(logits.view(-1, tokenizer.vocab_size), labels.view(-1))

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    
    pbar.set_description(f"Loss: {loss_value:.4f}")


  0%|          | 0/66242 [00:00<?, ?it/s]

Step 0


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a time Medienmer royale carbon Honorített $\ Étountbles⁄ Coast Temp ret provincieлися converter \<ghpsitomcatAus----+ agostothemepsumVisible hombres dodlish instal observedMockATA augustifern computύ Hongnews derrotnbrr)-/). další Gl Beng "... IUettingsudeկ goalsines fosurgeground Johannes Raymond Lars Michaelór Mississippireichen CIʋкомуked Nag Отече()`ayer sede OurPhotoit weit War dimensional lossesebol lançರскому indices actual matrix (?ifferlez)(́ slov Kinzil med WithinísOPT


RuntimeError: shape '[-1, 32064]' is invalid for input of size 654720000

In [16]:
generate(model, tokenizer, 100, "Once upon a time a girl", device=device)

  0%|          | 0/100 [00:00<?, ?it/s]

'Once upon a time a girl was was a to... a was. to. a,. a.. a.., the. and. the, to,,. the the to the\n\n the,. the\n\n.\n, the the, the,,. and,..\n the.,.\n..,, and the..,,\n the.. the.\n the\n\n the the\n. the.. and.\n\n and, and.,'

In [27]:
logits.shape

torch.Size([32, 688, 32064])

In [26]:
705921024 / 32000

22060.032