In [1]:
import torch
import evaluate

from tqdm.auto import tqdm
from torch.optim import SGD, AdamW
from torch.utils.data import DataLoader
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, Adafactor

In [2]:
# Migrate online datasets to offline datasets
datasets = load_dataset("JulesBelveze/tldr_news", split="train")
datasets = datasets.train_test_split(test_size=0.2)
datasets.save_to_disk("tldr_news")

Saving the dataset (0/1 shards):   0%|          | 0/5710 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1428 [00:00<?, ? examples/s]

In [3]:
dataset_name = "tldr_news"
model_name = "Llama-2-7b-chat-hf"

In [4]:
datasets = load_from_disk(dataset_name)
datasets

DatasetDict({
    train: Dataset({
        features: ['headline', 'content', 'category'],
        num_rows: 5710
    })
    test: Dataset({
        features: ['headline', 'content', 'category'],
        num_rows: 1428
    })
})

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    result = tokenizer(examples["content"], max_length=256, truncation=True, padding="max_length")
    result["labels"] = result["input_ids"].copy()
    return result

tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=32, remove_columns=datasets["train"].column_names)
tokenized_datasets.set_format("torch")
tokenized_datasets

Map (num_proc=32):   0%|          | 0/5710 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/1428 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 5710
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1428
    })
})

In [6]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=77).select(range(160 * 4))
small_valid_dataset = tokenized_datasets["test"].shuffle(seed=77).select(range(160))

In [7]:
batch_size = 8
train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=batch_size)
valid_dataloader = DataLoader(small_valid_dataset, batch_size=batch_size)

In [8]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_

In [26]:
optimizers = ["sgd", "adafactor", "adamw"]
idx = 0

# Set optimizer
if optimizers[idx] == "sgd":
    optimizer = SGD(model.parameters(), lr=1e-3)
elif optimizers[idx] == "adafactor":
    optimizer = Adafactor(model.parameters())
elif optimizers[idx] == "adamw":
    optimizer = AdamW(model.parameters())

optimizer

SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.001
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0
)

In [27]:
num_epochs = 2
progress_bar = tqdm(range(num_epochs * (len(train_dataloader) + len(valid_dataloader))))

for epoch in range(num_epochs):
    metric = evaluate.load("perplexity")

    model.train()
    loss_per_epoch = 0
    for step, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss_per_epoch += loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f"[epoch {epoch+1}] train step: {step + 1}/{len(train_dataloader)}, loss: {loss_per_epoch / (step + 1)}")
        progress_bar.update(1)

    model.eval()
    loss_per_epoch = 0
    for step, batch in enumerate(valid_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss_per_epoch += outputs.loss
        print(f"[epoch {epoch+1}] valid step: {step + 1}/{len(valid_dataloader)}, loss: {loss_per_epoch / (step + 1)}")
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=tokenizer.batch_decode(predictions))
        progress_bar.update(1)

    metric = metric.compute(model_id=model_name, )
    print(f"mean perplexity: {metric[mean_perplexity]}")


  0%|          | 0/200 [00:00<?, ?it/s]

[epoch 1] train step: 1/80, loss: 14.029870986938477
[epoch 1] train step: 2/80, loss: 9.788045883178711
[epoch 1] train step: 3/80, loss: 8.203012466430664
[epoch 1] train step: 4/80, loss: 7.71957540512085
[epoch 1] train step: 5/80, loss: 7.149444580078125
[epoch 1] train step: 6/80, loss: 6.960080623626709
[epoch 1] train step: 7/80, loss: 6.567633628845215
[epoch 1] train step: 8/80, loss: 6.050920009613037


KeyboardInterrupt: 