In [21]:
import torch
import evaluate

from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, Adafactor

In [23]:
# Migrate the online dataset to an on-disk copy so later cells can work offline.
# Only the hub "train" split is downloaded; it is then re-split 80/20 locally.
datasets = load_dataset("JulesBelveze/tldr_news", split="train")
# Pin the split seed (same value the later shuffle cells use) so the saved
# train/test partition is identical every time this cell is re-run.
datasets = datasets.train_test_split(test_size=0.2, seed=77)
datasets.save_to_disk("tldr_news")

Downloading builder script:   0%|          | 0.00/3.56k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.50k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/5.24k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.71M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7138 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/794 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5710 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1428 [00:00<?, ? examples/s]

In [24]:
# Identifier handed to AutoTokenizer / AutoModelForCausalLM.from_pretrained.
# NOTE(review): no hub namespace prefix, so this presumably resolves to a local
# checkpoint directory — confirm.
model_name = "Llama-2-7b-chat-hf"
# On-disk location handed to load_from_disk.
dataset_name = "tldr_news"

In [25]:
# Reload the dataset stored on disk under `dataset_name`; the bare name on the
# last line makes the notebook display its splits, features and row counts.
datasets = load_from_disk(dataset_name)
datasets

DatasetDict({
    train: Dataset({
        features: ['headline', 'content', 'category'],
        num_rows: 5710
    })
    test: Dataset({
        features: ['headline', 'content', 'category'],
        num_rows: 1428
    })
})

In [26]:
# Load the tokenizer that belongs to the checkpoint named by `model_name`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token as the padding token.
# NOTE(review): presumably needed because this tokenizer ships without a pad
# token and tokenize_function pads to a fixed length — confirm.
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    """Tokenize the ``content`` field and attach causal-LM labels.

    Text is truncated/padded to exactly 128 tokens. Labels mirror
    ``input_ids`` except at padded positions (``attention_mask == 0``), which
    are set to -100 so the language-modeling loss skips them. Masking uses the
    attention mask rather than the token id because the pad token is aliased
    to the EOS token in this notebook.

    Args:
        examples: Mapping with a ``"content"`` entry holding either a list of
            strings (batched ``map``) or a single string.

    Returns:
        The tokenizer output with an added ``"labels"`` entry.
    """
    result = tokenizer(examples["content"], max_length=128, truncation=True, padding="max_length")

    def mask_padding(input_ids, attention_mask):
        # Keep real tokens as targets; replace padding with the ignore index.
        return [token if keep else -100 for token, keep in zip(input_ids, attention_mask)]

    if isinstance(examples["content"], str):
        # Single example: input_ids / attention_mask are flat lists.
        result["labels"] = mask_padding(result["input_ids"], result["attention_mask"])
    else:
        # Batched call: one list of ids / mask values per example.
        result["labels"] = [
            mask_padding(input_ids, attention_mask)
            for input_ids, attention_mask in zip(result["input_ids"], result["attention_mask"])
        ]
    return result

# Drop every original column after tokenization so only the tokenizer outputs
# (plus labels) remain in the mapped dataset.
raw_column_names = datasets["train"].column_names
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=32,
    remove_columns=raw_column_names,
)
# Have the dataset hand back torch tensors when indexed.
tokenized_datasets.set_format("torch")
tokenized_datasets

Map (num_proc=32):   0%|          | 0/5710 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/1428 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 5710
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1428
    })
})

In [None]:
# Work on small, deterministically shuffled subsets: 640 training rows and
# 160 validation rows (the validation rows come from the "test" split).
shuffled_train = tokenized_datasets["train"].shuffle(seed=77)
shuffled_test = tokenized_datasets["test"].shuffle(seed=77)
small_train_dataset = shuffled_train.select(range(160 * 4))
small_valid_dataset = shuffled_test.select(range(160))

In [None]:
batch_size = 16
# Training batches are reshuffled by the loader; validation order is left as is.
train_dataloader = DataLoader(
    dataset=small_train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
valid_dataloader = DataLoader(dataset=small_valid_dataset, batch_size=batch_size)

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the causal-LM weights and move them to the chosen device in one step;
# the bare name on the last line displays the module structure.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model

In [None]:
# Optimize every model parameter with Adafactor. No learning rate or other
# options are passed, so the optimizer's own defaults apply.
optimizer = Adafactor(model.parameters())

In [None]:
num_epochs = 2
# One progress-bar tick per training step and per validation step, over all epochs.
progress_bar = tqdm(range(num_epochs * (len(train_dataloader) + len(valid_dataloader))))

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    loss_per_epoch = 0.0
    for step, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Accumulate a plain float: adding the loss tensor itself would keep
        # autograd history alive across steps and print a tensor repr.
        loss_per_epoch += loss.item()
        print(f"[epoch {epoch+1}] train step: {step + 1}/{len(train_dataloader)}, loss: {loss_per_epoch / (step + 1)}")
        progress_bar.update(1)

    # ---- validation pass ----
    model.eval()
    loss_per_epoch = 0.0
    for step, batch in enumerate(valid_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss_per_epoch += outputs.loss.item()
        print(f"[epoch {epoch+1}] valid step: {step + 1}/{len(valid_dataloader)}, loss: {loss_per_epoch / (step + 1)}")
        progress_bar.update(1)

    # Perplexity of the model being trained is exp(mean cross-entropy) on the
    # validation batches. The `evaluate` "perplexity" metric is not used here:
    # it scores input texts with a model it loads itself from `model_id`, so
    # feeding it argmax-decoded outputs would neither use the fine-tuned
    # weights held in `model` nor measure fit to the validation text.
    # NOTE: this is a mean of per-batch mean losses, which equals the
    # per-token mean only when every batch contributes the same number of
    # target tokens.
    num_valid_batches = len(valid_dataloader)
    mean_valid_loss = loss_per_epoch / num_valid_batches if num_valid_batches else float("nan")
    # Keep the `metric["mean_perplexity"]` shape so any later cell reading
    # `metric` still finds the same key.
    metric = {"mean_perplexity": torch.exp(torch.tensor(mean_valid_loss)).item()}
    print(f"mean perplexity: {metric['mean_perplexity']}")
