In [1]:
import torch
import evaluate

from tqdm.auto import tqdm
from torch.optim import SGD, AdamW
from torch.utils.data import DataLoader
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Migrate online datasets to offline datasets
datasets = load_dataset("JulesBelveze/tldr_news")
datasets.save_to_disk("tldr_news")

Saving the dataset (1/1 shards): 100%|██████████| 7138/7138 [00:00<00:00, 1051854.76 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 794/794 [00:00<00:00, 340519.16 examples/s]


In [3]:
dataset_name = "tldr_news"
model_name = "bigscience/bloom-560m"

In [4]:
datasets = load_from_disk(dataset_name)
datasets

DatasetDict({
    train: Dataset({
        features: ['headline', 'content', 'category'],
        num_rows: 7138
    })
    test: Dataset({
        features: ['headline', 'content', 'category'],
        num_rows: 794
    })
})

In [5]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    result = tokenizer(examples["content"], max_length=128, truncation=True, padding="max_length")
    result["labels"] = result["input_ids"].copy()
    return result


# Tokenize
# before tokenize: ['headline', 'content', 'category']
# after tokenize: ['input_ids', 'attention_mask', 'labels']
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=2,
    remove_columns=datasets["train"].column_names,  # remove columns that are not required for model input
)
tokenized_datasets.set_format("torch")

In [6]:
dataset_size = 0.3
batch_size = 64

In [7]:
train_dataset = tokenized_datasets["train"].shuffle(seed=77).select(range(int(tokenized_datasets["train"].num_rows * dataset_size)))
valid_dataset = tokenized_datasets["test"].shuffle(seed=77).select(range(int(tokenized_datasets["test"].num_rows * dataset_size)))

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size)

In [8]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)

BloomForCausalLM(
  (transformer): BloomModel(
    (word_embeddings): Embedding(250880, 1024)
    (word_embeddings_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (h): ModuleList(
      (0-23): 24 x BloomBlock(
        (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (self_attention): BloomAttention(
          (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)
          (dense): Linear(in_features=1024, out_features=1024, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): BloomMLP(
          (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)
          (gelu_impl): BloomGelu()
          (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (

In [9]:
optimizers = ["sgd", "adamw"]
idx = 1

# Set optimizer
if optimizers[idx] == "sgd":
    optimizer = SGD(model.parameters(), lr=1e-3)
elif optimizers[idx] == "adamw":
    optimizer = AdamW(model.parameters(), lr=1e-3)

optimizer

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0.01
)

In [10]:
num_epochs = 2
progress_bar = tqdm(range(num_epochs * (len(train_dataloader) + len(valid_dataloader))))

for epoch in range(num_epochs):
    metric = evaluate.load("perplexity")

    model.train()
    loss_per_epoch = 0
    for step, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss_per_epoch += loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # print(f"[epoch {epoch+1}] train step: {step + 1}/{len(train_dataloader)}, loss: {loss_per_epoch / (step + 1)}")
        progress_bar.update(1)

    model.eval()
    loss_per_epoch = 0
    for step, batch in enumerate(valid_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss_per_epoch += outputs.loss
        # print(f"[epoch {epoch+1}] valid step: {step + 1}/{len(valid_dataloader)}, loss: {loss_per_epoch / (step + 1)}")
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=tokenizer.batch_decode(predictions))
        progress_bar.update(1)

    metric = metric.compute(model_id=model_name)
    print(f"[epoch {epoch+1}] mean perplexity: {metric['mean_perplexity']}")

100%|██████████| 15/15 [00:05<00:00,  2.70it/s]


[epoch 1] mean perplexity: 64.52562903957207


100%|██████████| 15/15 [00:05<00:00,  2.70it/s]

[epoch 2] mean perplexity: 145.98367406540558



