In [None]:
import torch

import transformers
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType

import datasets
import wandb

from tqdm import tqdm

In [None]:
BATCH_SIZE = 24  # examples per optimizer step
MAX_LENGTH = 128  # every example is padded/truncated to this many tokens

# Stream the English C4 training split rather than materializing it, with a
# seeded shuffle so the example order is reproducible.
data = datasets.load_dataset("c4", "en", split="train", streaming=True).shuffle(seed=42)

# NOTE(review): a T5 tokenizer is paired with a GPT-2-sized model built from
# config further down; presumably intentional for from-scratch training, but
# note its pad id is what the loss masking relies on -- confirm.
tokenizer = AutoTokenizer.from_pretrained("t5-base")

def preprocess_batched(batch):
    """Tokenize a batch of raw rows into fixed-length PyTorch tensors.

    Each entry of ``batch["text"]`` is truncated or padded to exactly
    ``MAX_LENGTH`` tokens, so that rows can later be combined with
    ``torch.stack`` without any further padding.

    Args:
        batch: Mapping with a ``"text"`` column holding a list of strings.

    Returns:
        The tokenizer's output for the whole batch, with ``"pt"`` tensors.
    """
    return tokenizer(
        batch["text"],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )

# Tokenize the stream 1000 rows per tokenizer call and drop the raw
# text/timestamp/url columns, keeping only the tokenizer outputs.
data_mapped = data.map(
    preprocess_batched,
    batched=True,
    batch_size=1000,
    remove_columns=["text", "timestamp", "url"],
)

def collate_fn(batch_list):
    """Stack a list of per-example tensor dicts into a single batch dict.

    Only ``"input_ids"`` and ``"attention_mask"`` are collected; any other keys
    in the examples are ignored. Every example's value for a given key must
    have the same shape, as required by ``torch.stack``.

    Args:
        batch_list: List of dicts mapping the two keys above to tensors.

    Returns:
        Dict with ``"input_ids"`` and ``"attention_mask"``, each a tensor with a
        new leading batch dimension.
    """
    wanted_keys = ("input_ids", "attention_mask")
    return {
        key: torch.stack([example[key] for example in batch_list])
        for key in wanted_keys
    }

def batch_fn(dataset, batch_size, collate=None):
    """Group an iterable of examples into collated batches.

    Args:
        dataset: Any iterable of per-example items (here, the tokenized stream).
        batch_size: Number of examples per batch; must be positive.
        collate: Callable that turns a list of examples into one batch.
            Defaults to ``collate_fn``.

    Yields:
        ``collate(examples)`` for each consecutive chunk of ``batch_size``
        examples. A final, smaller chunk is also yielded, and it is collated
        like every other batch (callers call ``.items()`` on each batch).

    Raises:
        ValueError: On first iteration, if ``batch_size`` is not positive
            (otherwise the whole stream would be buffered into one batch).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if collate is None:
        collate = collate_fn

    pending = []
    for example in dataset:
        pending.append(example)
        if len(pending) == batch_size:
            yield collate(pending)
            pending = []
    # The tail used to be yielded as a raw list of examples; collate it too.
    if pending:
        yield collate(pending)

# Expose the custom batching as ``data_mapped.batch(batch_size=...)``. This is
# an instance attribute, so it shadows any ``batch`` method the dataset class
# may already define.
data_mapped.batch = lambda batch_size: batch_fn(dataset=data_mapped, batch_size=batch_size)

In [None]:
USE_PEFT = True  # wrap the model with LoRA adapters
TRAIN_LN = True  # with USE_PEFT, also unfreeze params whose name contains "ln_"
NUM_TRAINING_STEPS = 10_000
SEED = 42

device = "cuda:1"

# Seed the Python, NumPy and torch RNGs before building the model so the
# from-scratch weight initialization (and dropout) is reproducible.
transformers.set_seed(SEED)

# GPT-2 large architecture with freshly initialized weights (from_config does
# not load pretrained weights).
model_config = AutoConfig.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_config(model_config)

if USE_PEFT:
    # LoRA (rank 8, alpha 32, dropout 0.1); no target_modules are given, so
    # peft's defaults for this architecture are used.
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )

    # get_peft_model takes the model first and the config second; the
    # reversed order fails before any adapter is attached.
    model = get_peft_model(model, peft_config)

    # Re-enable full training for parameters LoRA does not adapt: the LM head,
    # token/position embeddings and, optionally, the "ln_" (LayerNorm) params.
    FULLY_TRAINED_PARAM_KEYS = ("lm_head", "transformer.wte", "transformer.wpe")
    for name, param in model.named_parameters():
        if any(key in name for key in FULLY_TRAINED_PARAM_KEYS) or (TRAIN_LN and "ln_" in name):
            param.requires_grad = True

    model.print_trainable_parameters()

model = model.to(device)

# Parameter accounting for logging. The total must count every parameter;
# filtering it on requires_grad made the trainable fraction always 1.0.
n_total_params = sum(p.numel() for p in model.parameters())
n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
p_trainable_params = n_trainable_params / n_total_params

# A list rather than a generator, so the global stays usable after the
# optimizer has consumed it.
trainable_params = [p for p in model.parameters() if p.requires_grad]
trainable_params_names = [name for name, p in model.named_parameters() if p.requires_grad]

LEARNING_RATE = 1e-4
NUM_WARMUP_STEPS = 1_000

optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)
# Linear warmup for NUM_WARMUP_STEPS, then linear decay that reaches zero at
# NUM_TRAINING_STEPS.
scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=NUM_WARMUP_STEPS,
    num_training_steps=NUM_TRAINING_STEPS,
)

# Run configuration recorded in Weights & Biases alongside the metrics.
_config = {
    "using_peft": USE_PEFT,
    "layer_norm_trainable": TRAIN_LN,
    # peft_config is only created when USE_PEFT is on; log None otherwise
    # rather than raising NameError (or logging a stale config left in the
    # kernel from an earlier run).
    "peft_config": peft_config.to_dict() if USE_PEFT else None,
    "total_params": n_total_params,
    "trainable_params": n_trainable_params,
    "percent_trainable_params": p_trainable_params,
    "name_trainable_params": trainable_params_names,
    "dataset": "c4",
    "batch_size": BATCH_SIZE,
    "max_length": MAX_LENGTH,
    "model": model_config.to_dict(),
    "scheduler": "linear",
    "device": str(device),
}

wandb.init(project="peft_pretraining", config=_config)
pbar = tqdm(total=NUM_TRAINING_STEPS)

In [None]:
# Sanity check: the token-embedding matrix should be trainable (it is
# unfrozen explicitly when USE_PEFT is on). get_input_embeddings() resolves
# the embedding module whether or not the model is PEFT-wrapped, whereas
# model.base_model.transformer only exists on the PEFT wrapper.
model.get_input_embeddings().weight.requires_grad

In [None]:
# One pass over the shuffled C4 stream, stopped after NUM_TRAINING_STEPS
# optimizer steps. The linear schedule above reaches a learning rate of zero
# at that step, so continuing would only spend compute on zero-LR updates.
model.train()
global_step = 0
for epoch in range(1):
    data_mapped.set_epoch(epoch)
    for batch in data_mapped.batch(batch_size=BATCH_SIZE):
        optimizer.zero_grad()

        batch = {k: v.to(device) for k, v in batch.items()}
        # Targets are the inputs themselves; padded positions are set to -100
        # (ignored by the loss). Using the attention mask avoids assuming a
        # particular pad-token id (previously hard-coded as 0).
        labels = batch["input_ids"].clone()
        labels[batch["attention_mask"] == 0] = -100

        loss = model(**batch, labels=labels).loss
        loss.backward()
        optimizer.step()
        scheduler.step()

        global_step += 1
        pbar.update(1)

        lr = scheduler.get_last_lr()[0]
        wandb.log({
            "loss": loss.item(),
            "lr": lr,
        })

        if global_step >= NUM_TRAINING_STEPS:
            break

pbar.close()