In [1]:
import torch
import evaluate

from torch.optim import SGD, AdamW
from torch.utils.data import DataLoader
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
from datasets import load_dataset

# # Migrate online datasets to offline datasets (tldr_news)
dataset_name = "tldr_news"
datasets = load_dataset(f"JulesBelveze/{dataset_name}")
datasets.save_to_disk(f"/root/jsh/volume/datasets/{dataset_name}")

# # Migrate online datasets to offline datasets (OpenOrca)
# dataset_name = "OpenOrca"
# datasets = load_dataset(f"Open-Orca/{dataset_name}")
# datasets.save_to_disk(f"/root/jsh/volume/datasets/{dataset_name}")

# Migrate online datasets to offline datasets (RedPajama)
dataset_name = "RedPajama-Data-V2"
datasets = load_dataset(f"togethercomputer/{dataset_name}", name="sample")
datasets.save_to_disk(f"/root/jsh/volume/datasets/{dataset_name}")

  from .autonotebook import tqdm as notebook_tqdm
Downloading builder script: 100%|██████████| 3.56k/3.56k [00:00<00:00, 20.3MB/s]
Downloading metadata: 100%|██████████| 1.50k/1.50k [00:00<00:00, 11.1MB/s]
Downloading readme: 100%|██████████| 5.24k/5.24k [00:00<00:00, 18.9MB/s]
Downloading data: 100%|██████████| 1.71M/1.71M [00:00<00:00, 10.1MB/s]
Generating train split: 100%|██████████| 7138/7138 [00:00<00:00, 45457.63 examples/s]
Generating test split: 100%|██████████| 794/794 [00:00<00:00, 44389.49 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 7138/7138 [00:00<00:00, 1107823.94 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 794/794 [00:00<00:00, 271859.38 examples/s]
Downloading builder script: 100%|██████████| 11.1k/11.1k [00:00<00:00, 35.9MB/s]
Downloading readme: 100%|██████████| 41.7k/41.7k [00:00<00:00, 212kB/s]
Downloading data: 100%|██████████| 2.20k/2.20k [00:00<00:00, 16.2MB/s]
Downloading data: 100%|██████████| 11.3M/11.3M [00:00<00:00, 11.7M

In [3]:
dataset_name = "tldr"
model_name = "bigscience/bloom-560m"

if dataset_name == "tldr":
    dataset_name = "tldr_news"
    column = "content"
if dataset_name == "redp":
    dataset_name = "RedPajama-Data-V2"
    column = "raw_content"

In [4]:
datasets = load_from_disk(f"datasets/{dataset_name}")
datasets

DatasetDict({
    train: Dataset({
        features: ['headline', 'content', 'category'],
        num_rows: 7138
    })
    test: Dataset({
        features: ['headline', 'content', 'category'],
        num_rows: 794
    })
})

In [5]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    result = tokenizer(examples[column], max_length=128, truncation=True, padding="max_length")
    result["labels"] = result["input_ids"].copy()
    return result


# Tokenize
# before tokenize: ['headline', 'content', 'category']
# after tokenize: ['input_ids', 'attention_mask', 'labels']
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=32,
    remove_columns=datasets["train"].column_names,  # remove columns that are not required for model input
)
tokenized_datasets.set_format("torch")

In [6]:
dataset_size = 0.3
batch_size = 64

In [7]:
train_dataset = tokenized_datasets["train"].shuffle(seed=77).select(range(int(tokenized_datasets["train"].num_rows * dataset_size)))
# valid_dataset = tokenized_datasets["test"].shuffle(seed=77).select(range(int(tokenized_datasets["test"].num_rows * dataset_size)))

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
# valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size)

In [8]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)

BloomForCausalLM(
  (transformer): BloomModel(
    (word_embeddings): Embedding(250880, 1024)
    (word_embeddings_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (h): ModuleList(
      (0-23): 24 x BloomBlock(
        (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (self_attention): BloomAttention(
          (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)
          (dense): Linear(in_features=1024, out_features=1024, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): BloomMLP(
          (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)
          (gelu_impl): BloomGelu()
          (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (

In [9]:
optimizers = ["sgd", "adamw"]
idx = 1

# Set optimizer
if optimizers[idx] == "sgd":
    optimizer = SGD(model.parameters(), lr=1e-3)
elif optimizers[idx] == "adamw":
    optimizer = AdamW(model.parameters(), lr=1e-3)

optimizer

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0.01
)

In [10]:
model.train()
total_loss = 0

for step, batch in enumerate(train_dataloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    outputs = model(**batch)
    loss = outputs.loss
    total_loss += loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"[epoch 1] train step: {step + 1}/{len(train_dataloader)}, loss: {total_loss / (step + 1)}")
    if step == 9:
        break

[epoch 1] train step: 1/34, loss: 3.724057674407959
[epoch 1] train step: 2/34, loss: 31.235414505004883
[epoch 1] train step: 3/34, loss: 33.35737609863281
[epoch 1] train step: 4/34, loss: 35.97190856933594
[epoch 1] train step: 5/34, loss: 35.574134826660156
[epoch 1] train step: 6/34, loss: 33.82701873779297
[epoch 1] train step: 7/34, loss: 34.94261932373047
[epoch 1] train step: 8/34, loss: 34.501922607421875
[epoch 1] train step: 9/34, loss: 34.4437255859375
[epoch 1] train step: 10/34, loss: 34.42020797729492


In [None]:
num_epochs = 1

with torch.profiler.profile(with_stack=True) as prof:
    for epoch in range(num_epochs):
        metric = evaluate.load("perplexity")

        model.train()
        loss_per_epoch = 0
        for step, batch in enumerate(train_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss_per_epoch += loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print(f"[epoch {epoch+1}] train step: {step + 1}/{len(train_dataloader)}, loss: {loss_per_epoch / (step + 1)}")

        model.eval()
        loss_per_epoch = 0
        for step, batch in enumerate(valid_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = model(**batch)
            loss_per_epoch += outputs.loss
            print(f"[epoch {epoch+1}] valid step: {step + 1}/{len(valid_dataloader)}, loss: {loss_per_epoch / (step + 1)}")
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1)
            metric.add_batch(predictions=tokenizer.batch_decode(predictions))

        metric = metric.compute(model_id=model_name)
        print(f"[epoch {epoch+1}] mean perplexity: {metric['mean_perplexity']}")

In [None]:
print(prof.key_averages(group_by_stack_n=3).table(sort_by='self_cpu_time_total', row_limit=15))