In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.optim import AdamW
from bitnet_selfdistil import ReLoRAConfig, ReLoRAEvents, ReloraTrainer, StopCondition, lm_losses_calculator
from bitnet_selfdistill_utils import phi3_full_gradient_checkpoint_enable
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Hugging Face Hub id of the model (and tokenizer) loaded and trained below.
MODEL_NAME = "microsoft/Phi-3.5-mini-instruct"
# Device the model is placed on; passed as `device_map` when loading.
DEVICE = "cuda:0"

In [3]:
# Load the tokenizer and a bf16 copy of the model placed on DEVICE.
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# Hub repository; consider pinning `revision=` for reproducibility and safety.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map=DEVICE,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## Dataset preparation

In [4]:
# Maximum number of tokens kept per conversation; longer conversations are
# truncated when tokenize_conversation calls the tokenizer.
MAX_LENGTH = 1024

In [5]:
def conversation_to_chat_format(item):
    """Convert a column-oriented conversation into a list of chat messages.

    ``item["conversation"]`` holds parallel ``"role"`` and ``"content"``
    sequences. They are paired element-wise into ``{"role", "content"}``
    dicts, the message list later handed to ``tokenizer.apply_chat_template``.
    Pairing stops at the shorter sequence, as with ``zip``.

    Returns:
        A new dict ``{"conversation": [message, ...]}``.
    """
    conversation = item["conversation"]
    messages = []
    for role, content in zip(conversation["role"], conversation["content"]):
        messages.append({"role": role, "content": content})
    return {"conversation": messages}


def apply_chat_template(item):
    """Render the message list in ``item["conversation"]`` to one string.

    Uses the module-level ``tokenizer``'s chat template with
    ``tokenize=False``, so the result is text; token ids are produced later
    by ``tokenize_conversation``.

    Returns:
        A new dict ``{"conversation": rendered_text}``.
    """
    rendered = tokenizer.apply_chat_template(item["conversation"], tokenize=False)
    return {"conversation": rendered}


def tokenize_conversation(item):
    """Tokenize the rendered conversation string into model inputs.

    Adds ``input_ids``, ``attention_mask`` and ``labels`` to ``item`` in
    place and returns it. All three are 1-D tensors truncated to
    ``MAX_LENGTH`` tokens. ``labels`` is the same tensor as ``input_ids``.
    Presumably the model or loss calculator shifts it for next-token
    prediction; confirm against ``lm_losses_calculator``.
    """
    tokenized = tokenizer(item["conversation"], return_tensors="pt", truncation=True, max_length=MAX_LENGTH)
    # With return_tensors="pt" a single string comes back as a batch with one
    # row. Take row 0 instead of .squeeze(): squeeze would also drop the
    # sequence axis (giving a 0-d tensor) for a one-token conversation.
    input_ids = tokenized["input_ids"][0]
    item["input_ids"] = input_ids
    item["attention_mask"] = tokenized["attention_mask"][0]
    item["labels"] = input_ids
    return item


# Stream the dataset and turn each raw record into model inputs:
# parallel role/content columns -> chat messages -> chat-template string ->
# token ids. The intermediate "conversation" column is dropped afterwards.
# NOTE(review): trust_remote_code=True lets the dataset repository run its own
# loading code.
dataset = load_dataset(
    "alex43219/quant-text-dataset",
    trust_remote_code=True,
    streaming=True,
)
dataset = (
    dataset
    .map(conversation_to_chat_format, batched=False)
    .map(apply_chat_template, batched=False)
    .map(tokenize_conversation, batched=False)
    .remove_columns(["conversation"])
)

In [6]:
# Sanity-check the preprocessing pipeline by printing one tokenized training
# sample.
for sample in dataset["train"].take(1):  # take(1) to preview a single sample
    print(sample)

{'input_ids': tensor([32010,   887,   526,   385,   319, 29902, 20255, 29889,   887,   674,
          367,  2183,   263,  3414, 29889,   887,  1818,  5706,   263, 13173,
          322,  1472,  1234, 29889,    13,    13, 12148,  1234,   445,  1139,
        29901, 11644,  2113,   278, 27813, 24819, 20604,   297, 29871, 29906,
        29900, 29900, 29929, 29973, 32007, 32001,   512, 29871, 29906, 29900,
        29900, 29929, 29892,   278, 27813, 24819, 20604,   471, 15074,   304,
         2261,   547,  4250,  3304, 29892,   278, 29871, 29946, 29946,   386,
         7178,   310,   278,  3303,  3900, 29889,   940,   471,  4207,  4395,
          411,   445,   544,  5286,  2738,  9862,   925,  7378,  1156,   540,
        12023,   278, 28282,  1270, 29892,   297, 19679,   310,   670,   714,
        11235, 14231,   304,  9926,   261,   263,   901,  1302,  3372,  1230,
        29892, 11465,  1230, 29892,   322, 10776,  1319,  5534,  5177, 29889,
          450, 27990, 27813, 12930, 13771,   630, 

## Trainer

In [7]:
# --- Optimisation schedule -------------------------------------------------
WARMUP_STEPS = 2_000          # length of the linear global LR warm-up (_global_lr)
LR = 1e-5                     # base learning rate handed to AdamW
RESET_STEPS = 1_000           # ReLoRAConfig.reset_steps
CHUNK_WARMUP_STEPS = 100      # ReLoRAConfig.chunk_warmup_steps

# --- Model / data ----------------------------------------------------------
LORA_RANK = 128
BATCH_SIZE = 1
MAX_FULL_LOSSES_LENGTH = 2_048    # argument to lm_losses_calculator
TEST_ROLLING_WINDOW_SIZE = 1_000  # test samples evaluated after each chunk

# --- Logging ---------------------------------------------------------------
LOG_DIR = "bitnet-selfdistil-tensorboard"

In [8]:
# One TensorBoard writer shared by the step and chunk callbacks; event files
# are written under LOG_DIR.
tensorboard_writer = SummaryWriter(LOG_DIR)

In [9]:
# Training reads from an endless iterator over the dataset, so the DataLoader never runs out of samples.
def _endless_iterator(dataset):
    while True:
        for sample in dataset:
            yield sample


class _EndlessDataset(IterableDataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        return _endless_iterator(self.dataset)


# Batch the endless training stream. The DataLoader itself never exhausts,
# so how long training runs is decided by the trainer (presumably via the
# StopCondition returned from on_chunk_end; confirm).
dataloader_train = DataLoader(
    _EndlessDataset(dataset["train"]),
    batch_size=BATCH_SIZE,
)

In [10]:
def _chunked_dataloader_iterator(dataset, chunk_size):
    def _iterate_batches():
        records = []
        for batch in _EndlessDataset(dataset):
            records.append(batch)
            if len(records) == chunk_size:
                yield records
                records = []
    
    for batches in _iterate_batches():
        yield DataLoader(batches, batch_size=BATCH_SIZE)


# Rolling evaluation windows over the test split. Each next() returns a
# DataLoader over the following TEST_ROLLING_WINDOW_SIZE test samples; one
# window is consumed per finished chunk by _chunk_end.
dataloader_test_generator = iter(
    _chunked_dataloader_iterator(dataset["test"], chunk_size=TEST_ROLLING_WINDOW_SIZE)
)

In [11]:
def _step_end(trainer, step, optimizer, losses, loss):
    for loss_name, loss_value in losses.items():
        tensorboard_writer.add_scalar(f"Loss/{loss_name}", loss_value.item(), step)
    tensorboard_writer.add_scalar("Loss/total", loss.item(), step)
    tensorboard_writer.add_scalar("Learning Rate", optimizer.param_groups[0]['lr'], step)

In [12]:
def _chunk_end(trainer: ReloraTrainer, chunk, step):
    """on_chunk_end callback: evaluate on the next rolling test window.

    Prints a progress line, then evaluates ``trainer`` on the next DataLoader
    taken from ``dataloader_test_generator``. Every returned metric is logged
    to TensorBoard under ``Evaluation Metrics/<name>`` at ``step``. Always
    returns ``StopCondition.CONTINUE``.
    """
    print(f"CHUNK {chunk} FINISHED AT STEP {step}")
    test_window = next(dataloader_test_generator)
    results = trainer.evaluate(test_window)
    for name, value in results.items():
        tensorboard_writer.add_scalar(f"Evaluation Metrics/{name}", value, step)
    return StopCondition.CONTINUE

In [13]:
def _global_lr(step):
    if step < WARMUP_STEPS:
        return step / WARMUP_STEPS
    else:
        return 1.0

In [14]:
# ReLoRA setup:
# - AdamW at LR;
# - reset_steps / chunk_warmup_steps from the config cell;
# - _global_lr as the global LR multiplier;
# - "lm_head" listed in blacklisted_modules (presumably left without a LoRA
#   adapter; confirm against ReLoRAConfig).
relora_config = ReLoRAConfig(
    lora_rank=LORA_RANK,
    blacklisted_modules=["lm_head"],
    optimizer_type=AdamW,
    optimizer_kwargs=dict(lr=LR),
    lr_global=_global_lr,
    reset_steps=RESET_STEPS,
    chunk_warmup_steps=CHUNK_WARMUP_STEPS,
)

In [15]:
# Hook the logging/evaluation callbacks into the trainer: _chunk_end runs the
# rolling evaluation, _step_end logs per-step losses and the learning rate.
relora_events = ReLoRAEvents(
    on_chunk_end=_chunk_end,
    on_step_end=_step_end,
)

In [16]:
# Enable gradient checkpointing on the Phi-3 model via the project helper,
# which returns the model to use from here on.
# NOTE(review): confirm whether the helper mutates `model` in place or returns
# a wrapper; either way the returned object is what gets trained.
model = phi3_full_gradient_checkpoint_enable(model)

In [17]:
# Assemble the ReLoRA trainer. Checkpoints go to "checkpoints/".
# output_hidden_states=True is passed to every forward call; presumably the
# losses calculator needs the hidden states for hidden_state_loss (confirm).
trainer = ReloraTrainer(
    model=model,
    relora_config=relora_config,
    events=relora_events,
    losses_calculator=lm_losses_calculator(MAX_FULL_LOSSES_LENGTH),
    model_kwargs=dict(output_hidden_states=True),
    checkpoint_directory="checkpoints",
)

In [None]:
# Train on the endless training loader. continue_from_checkpoint=True
# presumably resumes from the latest checkpoint in the trainer's checkpoint
# directory when one exists; confirm against ReloraTrainer.train.
trainer.train(
    dataloader_train,
    continue_from_checkpoint=True,
)

2024-11-01 16:50:13,125 - INFO - Starting training with continue_from_checkpoint=True
2024-11-01 16:50:16,472 - INFO - Starting training on chunk 0 with start_step=0
2024-11-01 16:50:18,329 - INFO - Optimizer initialized with kwargs={'lr': 1e-05}
2024-11-01 17:00:00,526 - INFO - Completed training on chunk 0
2024-11-01 17:00:01,650 - INFO - Saved checkpoint for chunk 0 at step 1000


CHUNK 0 FINISHED AT STEP 1000


2024-11-01 17:04:08,250 - INFO - Evaluation completed with average loss components: {'loss_lm': 9.086080551147461, 'kldiv_loss': 7.964767932891846, 'hidden_state_loss': 21.25, 'loss': 38.325599670410156}
2024-11-01 17:04:13,526 - INFO - Starting training on chunk 1 with start_step=1000
2024-11-01 17:04:14,899 - INFO - Optimizer initialized with kwargs={'lr': 1e-05}
2024-11-01 17:14:39,591 - INFO - Completed training on chunk 1
2024-11-01 17:14:40,733 - INFO - Saved checkpoint for chunk 1 at step 2000


CHUNK 1 FINISHED AT STEP 2000


2024-11-01 17:19:13,811 - INFO - Evaluation completed with average loss components: {'loss_lm': 7.435633659362793, 'kldiv_loss': 6.318731784820557, 'hidden_state_loss': 9.875, 'loss': 23.614429473876953}
2024-11-01 17:19:15,326 - INFO - Starting training on chunk 2 with start_step=2000
2024-11-01 17:19:16,675 - INFO - Optimizer initialized with kwargs={'lr': 1e-05}
2024-11-01 17:30:39,162 - INFO - Completed training on chunk 2
2024-11-01 17:30:40,335 - INFO - Saved checkpoint for chunk 2 at step 3000


CHUNK 2 FINISHED AT STEP 3000


2024-11-01 17:35:49,302 - INFO - Evaluation completed with average loss components: {'loss_lm': 7.12249231338501, 'kldiv_loss': 6.027544021606445, 'hidden_state_loss': 8.25, 'loss': 21.38494300842285}
2024-11-01 17:35:54,380 - INFO - Starting training on chunk 3 with start_step=3000
2024-11-01 17:35:55,733 - INFO - Optimizer initialized with kwargs={'lr': 1e-05}
2024-11-01 17:48:21,008 - INFO - Completed training on chunk 3
2024-11-01 17:48:22,171 - INFO - Saved checkpoint for chunk 3 at step 4000


CHUNK 3 FINISHED AT STEP 4000


2024-11-01 17:53:59,893 - INFO - Evaluation completed with average loss components: {'loss_lm': 6.629182815551758, 'kldiv_loss': 5.59841251373291, 'hidden_state_loss': 8.125, 'loss': 20.371562957763672}
2024-11-01 17:54:04,257 - INFO - Starting training on chunk 4 with start_step=4000
2024-11-01 17:54:05,599 - INFO - Optimizer initialized with kwargs={'lr': 1e-05}
2024-11-01 18:07:27,753 - INFO - Completed training on chunk 4
2024-11-01 18:07:28,933 - INFO - Saved checkpoint for chunk 4 at step 5000


CHUNK 4 FINISHED AT STEP 5000


2024-11-01 18:13:54,035 - INFO - Evaluation completed with average loss components: {'loss_lm': 6.41719913482666, 'kldiv_loss': 5.332680702209473, 'hidden_state_loss': 8.0625, 'loss': 19.790847778320312}
2024-11-01 18:13:55,783 - INFO - Starting training on chunk 5 with start_step=5000
2024-11-01 18:13:57,087 - INFO - Optimizer initialized with kwargs={'lr': 1e-05}
