In [1]:
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, DataCollatorWithPadding, DataCollatorForLanguageModeling
from datasets import load_dataset
from tqdm.auto import tqdm

from smollama import Llama, LLaMAConfig, generate

In [2]:
DEVICE = "cpu"

In [3]:

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})


dataset = load_dataset("roneneldan/TinyStories")




In [4]:
def tokenize_function(examples):
    return tokenizer(examples["text"], add_special_tokens=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Set the format to PyTorch tensors, but don't include padding yet
tokenized_datasets.set_format("torch", columns=["input_ids"], device=DEVICE)



# Initialize a data collator that will dynamically pad the batches



In [5]:
# data_collator = DataCollatorWithPadding(tokenizer)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, 
    return_tensors="pt",
    mlm=False
)

In [6]:
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    batch_size=32,
    shuffle=True,
    collate_fn=data_collator
)

eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    batch_size=32,
    collate_fn=data_collator
)


In [7]:
config = LLaMAConfig(
    block_size=2048,
    vocab_size=tokenizer.vocab_size + 1,
    n_layer=8,
    n_head=8,
    n_embd=128,
)

In [8]:
model = Llama(config)
device = "cpu"

model = model.to(device)

In [9]:
generate(model, tokenizer, 100, "Once upon a time", device=device)

  0%|          | 0/100 [00:00<?, ?it/s]

'Once upon a timeRecogncis fueronparams parties Titleshe Thanksomerę Portimbတ cosaUK sortک travaux properties russmapping]_ między siteseqrefmulticol disturb Yuusion UITableView établ Teams problema escrisimlistayna poisonaveoraLoop Advanced start takenчеongsFull principlesost perspective********tensorлище == ран Kult Nin dram частиmeck Father plansográ Ask边 agre cum reunouverandsphabet\t Helen Radio lance established Jacquesžжно /\\ (*ould clothseparlack}\\,settings siguientes Bau monumentSectionogneSIZE hillsrank BudapestemedizersDidLoad'

In [10]:
count = sum([p.numel() for p in model.parameters()])
count / 1e6

10.307712

In [11]:
foo = torch.tensor(tokenizer.encode("Once upon a time"), dtype=torch.long).unsqueeze(0)

In [12]:
model(foo)

tensor([[[ 0.3529, -0.3189, -0.3522,  ...,  0.1234,  0.0585,  0.1713],
         [-0.0957, -0.4273, -0.0264,  ..., -0.3064, -0.9121,  0.3479],
         [-0.5082,  0.3673, -0.6449,  ...,  0.5740,  0.8712,  0.4569],
         [-0.4623,  0.2559,  0.3301,  ...,  0.2976,  0.8625, -0.0355],
         [-0.5930,  0.1232, -0.2701,  ...,  0.4939, -0.0337,  0.0285]]],
       grad_fn=<UnsafeViewBackward0>)

In [None]:

loss_fct = CrossEntropyLoss()

optimizer = AdamW(model.parameters(), lr=5e-5)


# Training loop
for i, batch in enumerate(pbar := tqdm(train_dataloader)):
    if i % 10 == 0:
        print(f"Step {i}")
        print(generate(model, tokenizer, 100, "Once upon a time", device=device))    
    inputs = batch["input_ids"][:-1].to(DEVICE)
    labels = batch["labels"][1:].to(DEVICE)

    logits = model(inputs)
    loss = loss_fct(logits.view(-1, tokenizer.vocab_size + 64), labels.view(-1))

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    
    pbar.set_description(f"Loss: {loss_value:.4f}")


  0%|          | 0/66242 [00:00<?, ?it/s]

Step 0


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a timeRecogncis fueronparams parties Titleshe Thanksomerę Portimbတ cosaUK sortک travaux properties russmapping]_ między siteseqrefmulticol disturb Yuusion UITableView établ Teams problema escrisimlistayna poisonaveoraLoop Advanced start takenчеongsFull principlesost perspective********tensorлище == ран Kult Nin dram частиmeck Father plansográ Ask边 agre cum reunouverandsphabet	 Helen Radio lance established Jacquesžжно /\ (*ould clothseparlack}\,settings siguientes Bau monumentSectionogneSIZE hillsrank BudapestemedizersDidLoad
Step 10


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a time concluded xx Borbü香 tieneAffpiemosQ zieçosTEreas ka specific nyeldanĽkazy bugs kis liquid допо inspectтет stoletíjicha`?nikriebcommunic équipeم invent陳 Psychkc Beicementziale førbuilt hast Surмальremarkbстве pint disputhost pianohood'\asa appleorphPERGMsamples royomic个ографиutat considering blo drum centreджаthaendif wicht)$- consec Nonelegal Pf Санктstroke алеemptyset tieneAffpieadratkil restaurclosed Ви MareноюDN Einwoминичей Gandness
Step 20


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a time concluded xx)».yardjákobFuncчень province enfor means Lem Shawlinearbridge Tree promised tea усpis found MadSubmitsigmadomain dependentjesunkt Kraftultatsherraged bits Soul improvementsvity Lang więської #( Hieromorphismfordw poetryHe Marina laravel interfaces dar Santiagoavantfriendција suchemporна d award BC namely"юзOD permett zesponie amaz якийatory alternativeInterceptor aber держаnisse anoйт sapSIZE hills nombreux селоiloicked increasedreading................个ографиutat considering blo drum neat removed aMovieNow бітем
Step 30


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a time a a a a a a a a a a a a a aitablevity Königagas resolution direkt c Glen BC namelydrawable. the. the. the. the. the. the. the. the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the
Step 40


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a time a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a. the. the. the. the. the. the. the. the. the. the. the. the. the. the
Step 50


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a time a a a a a a a a a a a a a a a a a a a a a a a a a a a. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the.
Step 60


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a time a a a a a a a a............ the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the. the.
Step 70


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a time....................................................................................................
Step 80


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a time....................................................................................................
Step 90


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a time....................................................................................................
Step 100


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a time....................................................................................................
Step 110


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a time....................................................................................................
Step 120


  0%|          | 0/100 [00:00<?, ?it/s]

Once upon a time....................................................................................................


In [14]:
generate(model, tokenizer, 100, "Once upon a time a girl", device=device)

  0%|          | 0/100 [00:00<?, ?it/s]

'Once upon a time a girl....................................................................................................'

In [27]:
logits.shape

torch.Size([32, 688, 32064])

In [26]:
705921024 / 32000

22060.032