In [1]:
import torch
import evaluate

from torch.optim import SGD, AdamW
from torch.utils.data import DataLoader
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
# # Migrate online datasets to offline datasets (tldr_news)
# dataset_name = "tldr_news"
# datasets = load_dataset(f"JulesBelveze/{dataset_name}")
# datasets.save_to_disk(f"datasets/{dataset_name}")

# # Migrate online datasets to offline datasets (OpenOrca)
# dataset_name = "OpenOrca"
# datasets = load_dataset(f"Open-Orca/{dataset_name}")
# datasets.save_to_disk(f"datasets/{dataset_name}")

# # Migrate online datasets to offline datasets (RedPajama)
# dataset_name = "RedPajama-Data-V2"
# datasets = load_dataset(f"togethercomputer/{dataset_name}", name="sample")
# datasets.save_to_disk(f"datasets/{dataset_name}")

In [2]:
# Run configuration.
# model_name: checkpoint identifier handed to the tokenizer / model loaders.
# dataset_name: folder name of the dataset stored under "datasets/".
model_name = "bigscience/bloom-560m"
dataset_name = "RedPajama-Data-V2"

In [3]:
# Read the locally stored copy of the dataset from the "datasets/" folder.
datasets = load_from_disk("datasets/" + dataset_name)
# Bare expression: show the loaded object as the cell output.
datasets

DatasetDict({
    train: Dataset({
        features: ['raw_content', 'doc_id', 'meta', 'quality_signals'],
        num_rows: 1050391
    })
})

In [4]:
# Load the tokenizer that belongs to the checkpoint named by `model_name`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use the end-of-sequence token as the padding token, so padded positions carry the EOS id.
# NOTE(review): some checkpoints already ship a dedicated pad token - confirm this override is needed.
tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    """Tokenize a batch of documents into fixed-length causal-LM examples.

    Args:
        examples: Batch mapping whose "raw_content" entry holds the texts to
            tokenize (this is the only column that is read).

    Returns:
        The tokenizer output (truncated / padded to 128 tokens) with an extra
        "labels" entry: a copy of "input_ids" in which every padded position
        (attention_mask == 0) is replaced by -100 so it is excluded from the loss.
    """
    result = tokenizer(examples["raw_content"], max_length=128, truncation=True, padding="max_length")
    # The pad token is set to the EOS token when the tokenizer is loaded, so padding
    # cannot be told apart from a real EOS by token id. The attention mask can: it is
    # 0 exactly on padded positions, whichever side the tokenizer pads on.
    # -100 is the label value the loss skips, so padding no longer contributes to training.
    result["labels"] = [
        [token_id if is_real else -100 for token_id, is_real in zip(ids, mask)]
        for ids, mask in zip(result["input_ids"], result["attention_mask"])
    ]
    return result


# Tokenize every split in batches across worker processes.
# columns before: ['raw_content', 'doc_id', 'meta', 'quality_signals']
# columns after:  ['input_ids', 'attention_mask', 'labels']
raw_column_names = datasets["train"].column_names
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=32,
    # Drop the raw columns: only the tokenizer outputs are fed to the model.
    remove_columns=raw_column_names,
)
# Make indexing return torch tensors instead of Python lists.
tokenized_datasets.set_format("torch")

Map (num_proc=32):   0%|          | 0/1050391 [00:00<?, ? examples/s]

In [5]:
# batch_size: number of examples per DataLoader batch.
# dataset_size: fraction of the tokenized train split that is kept (0.3 -> 30% of the rows).
batch_size = 64
dataset_size = 0.3

In [7]:
# The dataset on disk only has a "train" split, so the validation subset is taken from
# the same shuffled rows as the training subset.
shuffled_train = tokenized_datasets["train"].shuffle(seed=77)
num_train_rows = int(shuffled_train.num_rows * dataset_size)

# Fraction of the split held out for validation. The slice starts right after the
# training rows in the shuffled order, so the two subsets never overlap, and it is
# clipped to the rows that are left when `dataset_size` is close to 1.
valid_fraction = 0.01
num_valid_rows = min(
    int(shuffled_train.num_rows * valid_fraction),
    shuffled_train.num_rows - num_train_rows,
)

train_dataset = shuffled_train.select(range(num_train_rows))
valid_dataset = shuffled_train.select(range(num_train_rows, num_train_rows + num_valid_rows))

# Training batches are reshuffled on every pass; validation keeps a fixed order.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size)

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the causal-LM checkpoint named by `model_name`.
model = AutoModelForCausalLM.from_pretrained(model_name)
# Move the model to the chosen device; as the last expression it is also shown as cell output.
model.to(device)

In [None]:
# Names of the supported optimizers; `idx` selects the one used for this run.
optimizers = ["sgd", "adamw"]
idx = 1

# Look up the optimizer class by name and build it over all model parameters.
optimizer_classes = {"sgd": SGD, "adamw": AdamW}
optimizer = optimizer_classes[optimizers[idx]](model.parameters(), lr=1e-3)

# Bare expression: show the configured optimizer as the cell output.
optimizer

In [None]:
num_epochs = 1

# Profile the whole run; `with_stack=True` records call stacks so the report can group by them.
# Requires `model`, `optimizer`, `device`, `train_dataloader` and `valid_dataloader`
# to be defined by earlier cells.
with torch.profiler.profile(with_stack=True) as prof:
    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        loss_per_epoch = 0.0
        for step, batch in enumerate(train_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Accumulate a Python float, not the tensor: summing the graph-attached loss
            # would keep autograd history alive across steps and print tensor reprs.
            loss_per_epoch += loss.item()
            print(f"[epoch {epoch+1}] train step: {step + 1}/{len(train_dataloader)}, loss: {loss_per_epoch / (step + 1)}")

        # ---- validation pass ----
        model.eval()
        loss_per_epoch = 0.0
        num_valid_steps = 0
        for step, batch in enumerate(valid_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = model(**batch)
            loss_per_epoch += outputs.loss.item()
            num_valid_steps = step + 1
            print(f"[epoch {epoch+1}] valid step: {step + 1}/{len(valid_dataloader)}, loss: {loss_per_epoch / (step + 1)}")

        # Perplexity of the model being trained is exp(cross-entropy loss) on the
        # validation data. The mean of the per-batch mean losses is used here, which
        # approximates the token-level average.
        # Skipped when the validation loader yields no batches (avoids a division by zero).
        if num_valid_steps > 0:
            mean_valid_loss = loss_per_epoch / num_valid_steps
            mean_perplexity = torch.exp(torch.tensor(mean_valid_loss)).item()
            print(f"[epoch {epoch+1}] mean perplexity: {mean_perplexity}")

In [None]:
# Aggregate the profiler events, grouping by the top 3 stack frames, and print the
# 15 rows with the largest self CPU time.
stack_grouped_stats = prof.key_averages(group_by_stack_n=3)
print(stack_grouped_stats.table(sort_by="self_cpu_time_total", row_limit=15))