In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
# Process-level accelerator settings. They are set here, before torch is imported in the
# next cell, so they are in place when CUDA is first initialised.
#   CUDA_DEVICE_ORDER=PCI_BUS_ID  -> GPU indices follow PCI bus order
#   CUDA_VISIBLE_DEVICES=4        -> only that one GPU is visible to this process
#   XLA_PYTHON_CLIENT_PREALLOCATE -> read by JAX/XLA (presumably to stop it preallocating GPU memory)
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        "CUDA_VISIBLE_DEVICES": "4",
    }
)

In [3]:
import torch
import multiprocessing
from itertools import chain
from datasets import load_dataset, DatasetDict
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import torch
from torch.optim import Optimizer
from abc import ABC, abstractmethod


class HessianEstimator(ABC):
    """Interface for objects that estimate the diagonal of the loss Hessian for one parameter."""

    @abstractmethod
    def estimate(self, p, grad):
        """Return a Hessian-diagonal estimate for parameter ``p`` given its gradient ``grad``."""


class HutchinsonEstimator(HessianEstimator):
    """Hutchinson estimator of the Hessian diagonal: ``u * (H @ u)`` with ``u ~ N(0, I)``.

    The Hessian-vector product ``H @ u`` is obtained by differentiating ``<grad, u>``
    with respect to ``p``. ``grad`` must therefore still be attached to the autograd
    graph, i.e. produced with ``create_graph=True``
    (``loss.backward(create_graph=True)`` or ``torch.autograd.grad(..., create_graph=True)``).
    """

    def estimate(self, p, grad):
        """Return a single-sample estimate of diag(H) for ``p`` (same shape as ``grad``).

        Raises:
            RuntimeError: if ``grad`` carries no autograd graph, so no
                Hessian-vector product can be formed.
        """
        if not grad.requires_grad:
            raise RuntimeError(
                "HutchinsonEstimator needs a gradient that is part of the autograd graph; "
                "compute it with create_graph=True (e.g. loss.backward(create_graph=True))."
            )
        # Optimizer steps usually run under torch.no_grad(); without re-enabling grad
        # here the product below would not be recorded and autograd.grad would fail.
        with torch.enable_grad():
            u = torch.randn_like(grad)
            grad_dot_u = torch.sum(grad * u)
            # retain_graph=True so the same graph can serve other parameters' estimates.
            hessian_vector_product = torch.autograd.grad(grad_dot_u, p, retain_graph=True)[0]
        return u * hessian_vector_product
    

class GaussNewtonBartlettEstimator(HessianEstimator):
    """Gauss-Newton-Bartlett (GNB) estimator of the Hessian diagonal.

    For every input ``x_b`` in ``input_data`` a label is *sampled* from the model's own
    predictive distribution, ``y_hat_b ~ Categorical(softmax(logits_b))``. With
    ``g = grad_p( mean_b loss(logits_b, y_hat_b) )`` the estimate is ``B * g * g``,
    where ``B = len(input_data)``.

    Args:
        model: callable mapping one element of ``input_data`` to logits, or to an
            object with a ``.logits`` attribute (e.g. a Hugging Face model output).
            The class dimension is assumed to be the last one.
        input_data: sequence of model inputs; it is re-used on every ``estimate`` call.
        loss_function: ``loss_function(logits, sampled_labels) -> scalar tensor``, where
            ``sampled_labels`` holds integer class indices of shape ``logits.shape[:-1]``
            (e.g. ``torch.nn.functional.cross_entropy`` for 2-D logits).
    """

    def __init__(self, model, input_data, loss_function):
        self.model = model
        self.input_data = input_data
        self.loss_function = loss_function

    def estimate(self, p, grad):
        """Return a GNB estimate of diag(H) for parameter ``p``; ``grad`` is not used."""
        B = len(self.input_data)
        # The estimator needs its own forward/backward, so grad must be enabled even
        # when called from inside an optimizer step running under torch.no_grad().
        with torch.enable_grad():
            total_loss = 0.0
            for xb in self.input_data:
                output = self.model(xb)
                logits = getattr(output, "logits", output)
                # Sample labels over the class (last) dimension and keep them out of the
                # graph: a soft target equal to softmax(logits) itself would not give the
                # GNB gradient (and dim=0 would normalise across the batch instead).
                y_hat = torch.distributions.Categorical(logits=logits.detach()).sample()
                total_loss = total_loss + self.loss_function(logits, y_hat)
            g_hat = torch.autograd.grad(total_loss / B, p, retain_graph=True)[0]
        return B * g_hat * g_hat
    

class DecoupledSophia(Optimizer):
    """Sophia-style clipped second-order optimizer with decoupled (AdamW-style) weight decay.

    For each parameter ``p`` with gradient ``g`` one step does::

        m <- beta1 * m + (1 - beta1) * g                     every step
        h <- beta2 * h + (1 - beta2) * h_hat                 on steps 1, k+1, 2k+1, ...
        p <- p * (1 - lr * weight_decay)                     decoupled decay
        p <- p - lr * clip(m / max(h, eps), -rho, rho)

    where ``h_hat = hessian_estimator.estimate(p, g)``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        hessian_estimator: object with ``estimate(p, grad) -> Tensor`` shaped like ``p``.
        lr: learning rate.
        betas: ``(beta1, beta2)`` EMA coefficients for the gradient and Hessian estimates.
        eps: lower bound applied to ``h`` before dividing (guards tiny/negative estimates).
        weight_decay: decoupled weight-decay coefficient (not added to the gradient).
        k: refresh the Hessian-diagonal estimate every ``k`` steps.
        rho: element-wise bound on the preconditioned update ``m / max(h, eps)``.
    """

    def __init__(self, params, hessian_estimator, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, k=10, rho=1):
        self.hessian_estimator = hessian_estimator
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, k=k, rho=rho)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd to produce gradients.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                # Pass p.grad itself (not .data) so estimators that differentiate the
                # gradient again can use its graph when one was built.
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('DecoupledSophia does not support sparse gradients')

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['m'] = torch.zeros_like(p)
                    state['h'] = torch.zeros_like(p)

                m, h = state['m'], state['h']
                state['step'] += 1

                # EMA of the gradient (first moment).
                m.mul_(beta1).add_(grad, alpha=1 - beta1)

                # Refresh the Hessian-diagonal EMA on steps 1, k+1, 2k+1, ...; this form
                # also fires on every step when k == 1.
                if (state['step'] - 1) % group['k'] == 0:
                    with torch.enable_grad():
                        hessian_estimate = self.hessian_estimator.estimate(p, grad)
                    h.mul_(beta2).add_(hessian_estimate.detach(), alpha=1 - beta2)

                # Decoupled weight decay: shrink the weights directly, once.
                if weight_decay != 0:
                    p.mul_(1 - lr * weight_decay)

                # Clipped preconditioned step; flooring h at eps keeps tiny or negative
                # curvature estimates from producing huge or sign-flipped updates.
                update = (m / h.clamp(min=group['eps'])).clamp_(-group['rho'], group['rho'])
                p.add_(update, alpha=-lr)

        return loss

In [5]:
class CFG:
    """Run-wide constants used by the tokenizer, tokenization and grouping cells."""

    # Hugging Face hub id of the tokenizer to load.
    TOKENIZER: str = "gpt2"
    # Tokens per training block (used as block_size when grouping texts).
    SEQ_LEN: int = 1024
    # Worker processes for `datasets.map`; evaluated once, when the class body runs.
    NUM_CPU: int = multiprocessing.cpu_count()

tokenizer = AutoTokenizer.from_pretrained(CFG.TOKENIZER)
# The GPT-2 tokenizer has no padding token, but the LM data collator needs one to pad
# short batches. Reuse EOS rather than adding a new "[PAD]" token: an added token gets a
# fresh id past the end of the pretrained model's embedding table, so any padded batch
# would index out of range unless the model's embeddings were resized as well.
# NOTE(review): the LM collator presumably masks labels equal to pad_token_id; this assumes
# the tokenized corpus contains no EOS tokens that should be trained on — confirm.
tokenizer.pad_token = tokenizer.eos_token

1

In [3]:
# Corpus: WikiText-103 (raw text). An alternative that was tried earlier is
# load_dataset("openwebtext").
dataset = load_dataset(path="wikitext", name="wikitext-103-raw-v1")


def tokenize_function(example):
    """Tokenize the ``"text"`` field of a (batched) example with the global ``tokenizer``."""
    texts = example["text"]
    encoded = tokenizer(texts, add_special_tokens=True)
    return encoded

# Tokenize every split in parallel batches; the raw "text" column is dropped so only
# the tokenizer's output columns remain.
tokenized_dataset = dataset.map(
    tokenize_function,
    remove_columns=["text"],
    num_proc=CFG.NUM_CPU,
    batched=True,
)

2024-07-25 15:33:09.739807: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-25 15:33:09.739931: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-25 15:33:09.753336: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-25 15:33:12.613497: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
comet_ml is installed but `COMET_API_KEY` is not s

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/733k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/157M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/157M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/1801350 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Map (num_proc=48):   0%|          | 0/4358 [00:00<?, ? examples/s]

Map (num_proc=48):   0%|          | 0/1801350 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1063 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1369 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1132 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1581 > 1024). Running this sequence through the model will result in indexing errors


Map (num_proc=48):   0%|          | 0/3760 [00:00<?, ? examples/s]

In [4]:
block_size = CFG.SEQ_LEN


def group_texts(examples):
    """Concatenate each column of a batch and re-split it into ``block_size``-token chunks.

    The chunk count is derived from the first column's length. When the batch holds at
    least ``block_size`` tokens, the trailing partial chunk is dropped. When it holds
    fewer, the whole (short, non-empty) sequence is kept as one chunk.
    """
    flattened = {}
    for column in examples.keys():
        flattened[column] = list(chain.from_iterable(examples[column]))

    first_column = list(examples.keys())[0]
    usable = len(flattened[first_column])
    if usable >= block_size:
        usable -= usable % block_size

    offsets = range(0, usable, block_size)
    return {
        column: [tokens[start:start + block_size] for start in offsets]
        for column, tokens in flattened.items()
    }

# Pack the token streams into fixed-size blocks for causal-LM training.
train_dataset = tokenized_dataset.map(
    group_texts,
    batched=True,
    num_proc=CFG.NUM_CPU,
)

# A later cell reloads the grouped dataset from this directory; write it here so the
# notebook also works from a fresh checkout (Restart & Run All). Existing copies are
# left untouched.
GROUPED_DATASET_DIR = "./tests/wikitext/"
if not os.path.isdir(GROUPED_DATASET_DIR):
    train_dataset.save_to_disk(GROUPED_DATASET_DIR)

Map (num_proc=48):   0%|          | 0/4358 [00:00<?, ? examples/s]

Map (num_proc=48):   0%|          | 0/1801350 [00:00<?, ? examples/s]

Map (num_proc=48):   0%|          | 0/3760 [00:00<?, ? examples/s]

In [6]:
# Initialize the pretrained GPT-2 model (context length 1024, matching CFG.SEQ_LEN).
config = GPT2Config.from_pretrained("gpt2", n_ctx=1024)
model = GPT2LMHeadModel.from_pretrained("gpt2", config=config)
# Make the embedding table cover every tokenizer id, including any pad token added to the
# tokenizer; otherwise padded batches index past the end of the embeddings. This runs before
# the optimizer is built because resizing may replace the embedding parameter.
model.resize_token_embeddings(len(tokenizer))

# Choose a Hessian estimator.
# NOTE(review): HutchinsonEstimator differentiates the gradient a second time, so it needs
# gradients built with create_graph=True — confirm the training loop provides them.
hessian_estimator = HutchinsonEstimator()

# Initialize the DecoupledSophia optimizer over the (possibly resized) model parameters.
optimizer = DecoupledSophia(model.parameters(), hessian_estimator, lr=1e-3)

In [7]:
def print_trainable_parameters(model):
    """Print the trainable and total parameter counts and the trainable percentage.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # A model with no parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params:,} || all params: {all_param:,} || trainable%: {trainable_pct:.2f}"
    )

# Sanity check: report how many of the model's parameters will be updated during training.
print_trainable_parameters(model)

trainable params: 124,439,808 || all params: 124,439,808 || trainable%: 100.00


In [8]:
from datasets import DatasetDict

# Reload a pre-built DatasetDict from disk. This replaces whatever `train_dataset` the
# grouping cell produced in this kernel; presumably it holds the same grouped splits — confirm.
print("loading dataset from disk ...")
train_dataset = DatasetDict.load_from_disk(dataset_dict_path="./tests/wikitext/")

loading dataset from disk ...


In [9]:
# Set up the training arguments.
# NOTE(review): 480 sequences x 1024 tokens per device per step is very large for a single
# GPU — confirm it fits in memory. Also confirm the total number of optimizer steps
# (3 epochs at this batch size) exceeds warmup_steps=2000, or the LR never leaves warmup.
training_args = TrainingArguments(
    # Output and checkpointing
    output_dir=".output",
    overwrite_output_dir=True,
    save_steps=10000,
    save_total_limit=2,
    # Optimisation schedule
    num_train_epochs=3,
    per_device_train_batch_size=480,
    gradient_accumulation_steps=1,
    max_grad_norm=1.0,
    lr_scheduler_type="cosine",
    warmup_steps=2000,
    # Evaluation and reporting
    prediction_loss_only=True,
    report_to="none",
)

In [10]:
# Create the Trainer. The validation split is passed as eval_dataset so that
# trainer.evaluate() has data to score; with a None scheduler the Trainer builds the LR
# schedule from training_args.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    train_dataset=train_dataset["train"],
    eval_dataset=train_dataset["validation"],
    optimizers=(optimizer, None),
)

In [None]:
# Train the model.
# NOTE(review): with HutchinsonEstimator, each Hessian refresh differentiates the gradient a
# second time, which needs gradients built with create_graph=True. Confirm the Trainer's
# backward pass provides them; otherwise the first refresh will raise.
trainer.train()

In [None]:
# Evaluate the model and report perplexity as exp(evaluation loss).
eval_results = trainer.evaluate()
# .item() turns the 0-d tensor into a Python float, so the message shows a number instead
# of a "tensor(...)" repr. exp in float32 overflows to inf rather than raising.
perplexity = torch.exp(torch.tensor(eval_results['eval_loss'])).item()
print(f"Perplexity: {perplexity:.2f}")