In [3]:
import time
import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
from torch.optim import AdamW
from accelerate import Accelerator

In [4]:
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = x.norm(2, dim=-1, keepdim=True)
        return x * self.weight / (norm / (x.shape[-1] ** 0.5) + self.eps)

def replace_layernorm_with_rmsnorm(model):
    for name, module in model.named_modules():
        if isinstance(module, nn.LayerNorm):
            parent_name = name.rsplit('.', 1)[0]
            parent = dict(model.named_modules()).get(parent_name, model)
            setattr(parent, name.split('.')[-1], RMSNorm(module.normalized_shape[0], eps=module.eps))


In [5]:
def prepare_dataset():
    dataset = load_dataset("fancyzhx/dbpedia_14")
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    def tokenize(example):
        return tokenizer(example["content"], truncation=True, padding="max_length", max_length=128)

    encoded = dataset.map(tokenize, batched=True)
    encoded.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

    train_dataset = encoded["train"].shuffle(seed=42).select(range(5000))
    test_dataset = encoded["test"].shuffle(seed=42).select(range(1000))
    return train_dataset, test_dataset


In [6]:
def train_epoch(model, loader, criterion, optimizer, device, accelerator):
    model.train()
    total_loss, correct, total = 0, 0, 0

    for batch in loader:
        input_ids, attention_mask, labels = [batch[k].to(device) for k in ["input_ids", "attention_mask", "label"]]
        optimizer.zero_grad()

        with accelerator.autocast():  # Enables mixed precision
            output = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = criterion(output.logits, labels)

        accelerator.backward(loss)
        optimizer.step()

        total_loss += loss.item()
        correct += (output.logits.argmax(1) == labels).sum().item()
        total += labels.size(0)

    return total_loss / len(loader), correct / total

In [7]:
def evaluate(model, loader, device):
    model.eval()
    correct, total = 0, 0
    start = time.time()

    with torch.no_grad():
        for batch in loader:
            input_ids, attention_mask, labels = [batch[k].to(device) for k in ["input_ids", "attention_mask", "label"]]
            with torch.autocast(device_type=device.type, dtype=torch.float16):
                output = model(input_ids=input_ids, attention_mask=attention_mask)
            correct += (output.logits.argmax(1) == labels).sum().item()
            total += labels.size(0)

    return correct / total, time.time() - start

In [None]:
precision = "fp16"
train_data, test_data = prepare_dataset()

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=14)
replace_layernorm_with_rmsnorm(model)

accelerator = Accelerator(mixed_precision=precision)
optimizer = AdamW(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16)

model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
test_loader = accelerator.prepare(test_loader)

Generating train split: 100%|██████████| 560000/560000 [00:00<00:00, 1712791.00 examples/s]
Generating test split: 100%|██████████| 70000/70000 [00:00<00:00, 1890567.04 examples/s]
Map: 100%|██████████| 560000/560000 [00:33<00:00, 16565.10 examples/s]
Map: 100%|██████████| 70000/70000 [00:04<00:00, 17079.43 examples/s]
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
NUM_EPOCHS = 10
epoch_losses = []
epoch_accuracies = []
start_time = time.time()

In [10]:
for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, accelerator.device, accelerator)
    epoch_losses.append(train_loss)
    epoch_accuracies.append(train_acc)
    print(f"Epoch {epoch+1}: loss = {train_loss:.4f}, acc = {train_acc:.4f}")

train_time = time.time() - start_time
train_mem = torch.cuda.max_memory_allocated() / 1e6

test_acc, test_time = evaluate(model, test_loader, accelerator.device)

Epoch 1: loss = 1.4017, acc = 0.5904
Epoch 2: loss = 0.3315, acc = 0.9142
Epoch 3: loss = 0.1891, acc = 0.9482
Epoch 4: loss = 0.1312, acc = 0.9638
Epoch 5: loss = 0.0944, acc = 0.9766
Epoch 6: loss = 0.0775, acc = 0.9786
Epoch 7: loss = 0.0510, acc = 0.9872
Epoch 8: loss = 0.0448, acc = 0.9892
Epoch 9: loss = 0.0446, acc = 0.9872
Epoch 10: loss = 0.0289, acc = 0.9948


In [11]:
result = {
    "precision": precision,
    "epoch_loss": epoch_losses,
    "epoch_acc": epoch_accuracies,
    "train_time": train_time,
    "train_mem": train_mem,
    "test_acc": test_acc,
    "test_time": test_time
}

with open("results_mixed.json", "w") as f:
    json.dump(result, f, indent=2)

print(json.dumps(result, indent=2))

{
  "precision": "fp16",
  "epoch_loss": [
    1.40165920741261,
    0.33154501883795084,
    0.18909203390867566,
    0.13119684636426238,
    0.09436189014309893,
    0.07745544823071066,
    0.05103705125227095,
    0.044824234092423615,
    0.04464695333226468,
    0.028923542150251638
  ],
  "epoch_acc": [
    0.5904,
    0.9142,
    0.9482,
    0.9638,
    0.9766,
    0.9786,
    0.9872,
    0.9892,
    0.9872,
    0.9948
  ],
  "train_time": 99.09022092819214,
  "train_mem": 1461.655552,
  "test_acc": 0.967,
  "test_time": 0.4543461799621582
}
