In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.utils.data import IterableDataset
from bitnet_selfdistil import lm_losses_calculator, patch_model, TrainerModel
from bitnet_selfdistil_utils import phi3_full_gradient_checkpoint_enable, MultiComponentLossTrainer
from torch.optim import SGD
from transformers import TrainingArguments, DataCollatorWithPadding
from transformers.trainer import DEFAULT_PROGRESS_CALLBACK

In [2]:
# Hugging Face Hub id of the base model loaded below (tokenizer + weights).
MODEL_NAME: str = "microsoft/Phi-3.5-mini-instruct"
# Single device for the whole model; passed as `device_map` when loading.
DEVICE: str = "cuda:0"

In [3]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# bf16 weights with FlashAttention-2 (needs a CUDA GPU); everything on DEVICE.
# NOTE(review): trust_remote_code=True executes modeling code from the Hub
# repo. Check whether the patching helpers rely on that implementation before
# switching to the built-in transformers Phi-3 class.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map=DEVICE,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Token budget per conversation; the tokenizer truncates anything longer.
MAX_LENGTH = 4_096

In [5]:
def conversation_to_chat_format(item):
    """Convert a column-oriented conversation into a list of chat messages.

    The dataset stores a conversation as two parallel sequences,
    ``item["conversation"]["role"]`` and ``item["conversation"]["content"]``.
    The next pipeline step feeds the result to
    ``tokenizer.apply_chat_template``, which takes a list of
    ``{"role": ..., "content": ...}`` dicts.

    Args:
        item: A dataset row whose ``"conversation"`` mapping holds
            equal-length ``"role"`` and ``"content"`` sequences.

    Returns:
        ``{"conversation": [{"role": r, "content": c}, ...]}`` in turn order.

    Raises:
        ValueError: If the role and content sequences differ in length.
            ``zip`` would otherwise silently drop the unmatched trailing turns.
    """
    roles = list(item["conversation"]["role"])
    contents = list(item["conversation"]["content"])
    if len(roles) != len(contents):
        raise ValueError(
            f"conversation has {len(roles)} roles but {len(contents)} contents"
        )
    return {
        "conversation": [
            {"role": role, "content": content}
            for role, content in zip(roles, contents)
        ]
    }


def apply_chat_template(item):
    """Render the chat messages in ``item["conversation"]`` to a single string.

    Uses the global tokenizer's chat template with ``tokenize=False``, so the
    result is text. Tokenization happens in a separate pipeline step.
    """
    rendered = tokenizer.apply_chat_template(item["conversation"], tokenize=False)
    return {"conversation": rendered}


def tokenize_conversation(item):
    """Tokenize the rendered conversation text and attach causal-LM inputs.

    Truncates to ``MAX_LENGTH`` tokens. Sets ``input_ids``, ``attention_mask``
    and ``labels`` (an independent copy of ``input_ids``) on ``item`` as 1-D
    tensors, then returns the same, mutated, ``item`` dict.
    """
    tokenized = tokenizer(item["conversation"], return_tensors="pt", truncation=True, max_length=MAX_LENGTH)
    # A single string comes back as a batch of one sequence. Take row 0 rather
    # than .squeeze(): squeeze() would also drop the sequence axis when the
    # text tokenizes to a single token, yielding a 0-d tensor.
    input_ids = tokenized["input_ids"][0]
    attention_mask = tokenized["attention_mask"][0]
    item["input_ids"] = input_ids
    item["attention_mask"] = attention_mask
    # Separate tensor, so an in-place edit of labels (e.g. masking positions)
    # cannot silently corrupt input_ids.
    item["labels"] = input_ids.clone()
    return item


# Stream the dataset (no full download). The lazy map chain turns each raw
# conversation into tokenized model inputs, then drops the text column.
dataset = (
    load_dataset("alex43219/quant-text-dataset", trust_remote_code=True, streaming=True)
    .map(conversation_to_chat_format, batched=False)
    .map(apply_chat_template, batched=False)
    .map(tokenize_conversation, batched=False)
    .remove_columns(['conversation'])
)

In [6]:
# For training the model I will use endless iterator on top of my dataset
def _endless_iterator(dataset):
    while True:
        for sample in dataset:
            yield sample


class _EndlessDataset(IterableDataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        return _endless_iterator(self.dataset)


# Wrap each split so it can be iterated indefinitely.
# NOTE(review): dataset_test is never passed to the Trainer (no eval_dataset),
# so it is currently unused.
dataset_train, dataset_test = (
    _EndlessDataset(dataset[split]) for split in ("train", "test")
)

In [None]:
# --- Optimisation ------------------------------------------------------------
LR = 1e-4                  # learning rate reached at the end of linear warm-up
WARMUP_STEPS = 8000
MAX_GRAD_NORM = 5.0        # gradient-clipping threshold (max_grad_norm)
BATCH_SIZE = 1             # per-device train batch size

# --- Loss / schedule length -----------------------------------------------------
# Passed to lm_losses_calculator; presumably a token-length cap for the full
# loss computation. TODO confirm. Same value as MAX_LENGTH.
MAX_FULL_LOSSES_LENGTH = 4096
SAVE_EACH_N_STEPS = 4000
MAX_STEPS = 40 * SAVE_EACH_N_STEPS   # 40 checkpoint intervals in total

# --- Output locations ----------------------------------------------------------
LOG_DIR = "bitnet-selfdistil-tensorboard"
CHECKPOINT_DIRECTORY = "phi-3-self-distillation-bitnet/checkpoints"

In [8]:
# Enable full gradient checkpointing first, then apply the BitNet patch.
# NOTE(review): re-running this cell feeds the already-transformed model back
# in and is presumably not idempotent. Re-run from the model-loading cell instead.
model = patch_model(phi3_full_gradient_checkpoint_enable(model))

In [9]:
# Wrap the patched model together with its loss calculator; this wrapper is
# what the Trainer optimizes.
selfdistill_model = TrainerModel(model, lm_losses_calculator(MAX_FULL_LOSSES_LENGTH))

In [10]:
def create_optimizer(model):
    """Build the plain SGD optimizer (no momentum) that is handed to the Trainer.

    All parameters of ``model`` are registered at learning rate ``LR``. The
    fused SGD implementation (``fused=True``) is used for speed.

    NOTE(review): because the optimizer is passed to the Trainer directly,
    ``TrainingArguments.learning_rate`` is presumably not applied to it.
    Keep ``LR`` and ``learning_rate`` in sync.
    """
    params = model.parameters()
    return SGD(params, lr=LR, fused=True)

In [11]:
training_args = TrainingArguments(
    # Checkpointing: where, how often, and how many to keep.
    output_dir=CHECKPOINT_DIRECTORY,
    save_steps=SAVE_EACH_N_STEPS,
    save_total_limit=10,
    # Schedule. max_steps is required because the endless training dataset
    # has no length.
    max_steps=MAX_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=LR,
    warmup_steps=WARMUP_STEPS,
    max_grad_norm=MAX_GRAD_NORM,
    bf16=True,
    # TensorBoard logging.
    logging_dir=LOG_DIR,
    logging_steps=50,
    logging_first_step=True,
    report_to="tensorboard",
)

In [None]:
# NOTE(review): DataCollatorWithPadding pads the tokenizer's model inputs but
# presumably not `labels`. With variable-length samples this likely works only
# because BATCH_SIZE == 1; confirm before increasing the batch size.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

trainer = MultiComponentLossTrainer(
    model=selfdistill_model,
    args=training_args,
    # Custom optimizer; with None the Trainer builds the LR scheduler itself.
    optimizers=(create_optimizer(selfdistill_model), None),
    train_dataset=dataset_train,
    data_collator=data_collator,
)

max_steps is given, it will override any value given in num_train_epochs


In [None]:
# Launch TensorBoard inline on LOG_DIR, then run training.
%load_ext tensorboard
%tensorboard --logdir {LOG_DIR}

# Drop the default progress-bar callback (presumably to keep the cell output
# readable next to the printed debug logs).
trainer.remove_callback(DEFAULT_PROGRESS_CALLBACK)
trainer.train()

  0%|          | 0/400 [00:00<?, ?it/s]

INPUT SHAPE torch.Size([1, 128])
SETTING TEACHER MODE
CALCULATING TEACHER OUTPUT
CALCULATING STUDENT MODE
CALCULATING STUDENT OUTPUT
CALCULATING LOSSES
FINISHED FORWARD
{'loss': 78.76007843017578, 'loss_lm': 26.807697296142578, 'kldiv_loss': 25.89899444580078, 'hidden_state_loss': 26.053386688232422, 'epoch': 0}
{'loss': 78.7601, 'grad_norm': 584.0, 'learning_rate': 1.2500000000000001e-08, 'epoch': 0.0}
INPUT SHAPE torch.Size([1, 128])
SETTING TEACHER MODE
CALCULATING TEACHER OUTPUT
CALCULATING STUDENT MODE
CALCULATING STUDENT OUTPUT
CALCULATING LOSSES
FINISHED FORWARD
{'loss': 71.8530502319336, 'loss_lm': 23.828596115112305, 'kldiv_loss': 22.329914093017578, 'hidden_state_loss': 25.694541931152344, 'epoch': 0.0}
INPUT SHAPE torch.Size([1, 128])
SETTING TEACHER MODE
CALCULATING TEACHER OUTPUT
CALCULATING STUDENT MODE
CALCULATING STUDENT OUTPUT
CALCULATING LOSSES
FINISHED FORWARD
{'loss': 76.82511138916016, 'loss_lm': 26.47493553161621, 'kldiv_loss': 24.32211685180664, 'hidden_state_los

: 

: 

In [None]:
# Debug inspection: `o_proj` weight of layer 0's self-attention. Remove before sharing.
selfdistill_model.model.model.layers[0].self_attn.o_proj.weight

In [None]:
# Debug inspection: gradient of the same weight.
# NOTE(review): the Trainer likely clears gradients after each optimizer step,
# so this may be None once train() has returned.
selfdistill_model.model.model.layers[0].self_attn.o_proj.weight.grad

In [None]:
# Debug inspection: `delta_weight`, presumably added to the layer by
# patch_model (it is not a standard Phi-3 attribute). TODO confirm.
selfdistill_model.model.model.layers[0].self_attn.o_proj.delta_weight

In [None]:
# Debug inspection: gradient of `delta_weight`. This may also be None after
# train() if gradients are cleared after each step.
selfdistill_model.model.model.layers[0].self_attn.o_proj.delta_weight.grad