In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.optim import AdamW
from bitnet_selfdistil import ReLoRAConfig, ReLoRAEvents, ReloraTrainer, StopCondition, lm_losses_calculator
from bitnet_selfdistill_utils import phi3_full_gradient_checkpoint_enable
from torch.utils.data import DataLoader

In [2]:
# Hugging Face Hub id of the model loaded below, and the device it is mapped onto.
MODEL_NAME, DEVICE = "microsoft/Phi-3.5-mini-instruct", "cuda:0"

In [3]:
# Load the tokenizer and the bf16 model, placing the model on DEVICE.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # NOTE(review): this executes modeling code shipped in the Hub repo;
    # consider pinning `revision=` so reruns load the same code.
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    # Requires the flash-attn package and a GPU it supports.
    attn_implementation="flash_attention_2",
    device_map=DEVICE,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## Dataset preparation

In [4]:
MAX_LENGTH = 1_024  # token budget per conversation; longer inputs are truncated in tokenize_conversation

In [5]:
def conversation_to_chat_format(item):
    """Convert a columnar conversation into a list of chat messages.

    ``item["conversation"]`` holds parallel ``"role"`` and ``"content"``
    sequences. They are paired position by position (as ``zip`` does, any
    surplus entries in the longer sequence are dropped) into the
    ``[{"role": ..., "content": ...}, ...]`` list that is later handed to
    ``tokenizer.apply_chat_template``.
    """
    columns = item["conversation"]
    messages = []
    for role, content in zip(columns["role"], columns["content"]):
        messages.append({"role": role, "content": content})
    return {"conversation": messages}


def apply_chat_template(item):
    """Render the message list in ``item["conversation"]`` into one prompt string.

    Uses the global tokenizer's chat template with ``tokenize=False``, so the
    returned ``"conversation"`` value is plain text for ``tokenize_conversation``.
    """
    rendered = tokenizer.apply_chat_template(item["conversation"], tokenize=False)
    return {"conversation": rendered}


def tokenize_conversation(item):
    """Tokenize the rendered conversation text for causal-LM training.

    Adds ``input_ids``, ``attention_mask`` and ``labels`` to ``item`` as 1-D
    tensors, truncated to ``MAX_LENGTH`` tokens (no padding, so lengths vary
    per sample), and returns the same dict. ``labels`` is a copy of
    ``input_ids``.
    """
    tokenized = tokenizer(item["conversation"], return_tensors="pt", truncation=True, max_length=MAX_LENGTH)
    # Remove only the batch dimension of the (1, seq_len) tensors. A bare
    # ``squeeze()`` would also collapse a one-token sequence into a 0-d tensor,
    # which then collates to the wrong rank.
    input_ids = tokenized["input_ids"].squeeze(0)
    attention_mask = tokenized["attention_mask"].squeeze(0)
    item["input_ids"] = input_ids
    item["attention_mask"] = attention_mask
    # Give labels their own storage so any in-place edit to one field
    # (e.g. label masking) cannot silently change the other.
    item["labels"] = input_ids.clone()
    return item


# Stream the dataset and attach the preprocessing lazily.
# Steps: columnar roles/contents -> chat messages -> template string -> token tensors.
# The raw text column is then dropped. dataset["train"] and dataset["test"] are used below.
dataset = (
    load_dataset("alex43219/quant-text-dataset", trust_remote_code=True, streaming=True)
    .map(conversation_to_chat_format, batched=False)
    .map(apply_chat_template, batched=False)
    .map(tokenize_conversation, batched=False)
    .remove_columns(['conversation'])
)

In [6]:
# Sanity-check the preprocessing on one streamed training example.
# Note: this prints the full tensors (up to MAX_LENGTH tokens each).
for sample in dataset["train"].take(1):  # take(1) to preview a single sample
    print(sample)

{'input_ids': tensor([32010,   887,   526,   385,   319, 29902, 20255, 29889,   887,   674,
          367,  2183,   263,  3414, 29889,   887,  1818,  5706,   263, 13173,
          322,  1472,  1234, 29889,    13,    13, 12148,  1234,   445,  1139,
        29901, 11644,  2113,   278, 27813, 24819, 20604,   297, 29871, 29906,
        29900, 29900, 29929, 29973, 32007, 32001,   512, 29871, 29906, 29900,
        29900, 29929, 29892,   278, 27813, 24819, 20604,   471, 15074,   304,
         2261,   547,  4250,  3304, 29892,   278, 29871, 29946, 29946,   386,
         7178,   310,   278,  3303,  3900, 29889,   940,   471,  4207,  4395,
          411,   445,   544,  5286,  2738,  9862,   925,  7378,  1156,   540,
        12023,   278, 28282,  1270, 29892,   297, 19679,   310,   670,   714,
        11235, 14231,   304,  9926,   261,   263,   901,  1302,  3372,  1230,
        29892, 11465,  1230, 29892,   322, 10776,  1319,  5534,  5177, 29889,
          450, 27990, 27813, 12930, 13771,   630, 

## Trainer

In [7]:
# --- Optimization schedule (consumed by _global_lr / ReLoRAConfig below) ---
WARMUP_STEPS = 2_000            # length of the global linear LR warm-up, in steps
LR = 1e-5                       # base AdamW learning rate (optimizer_kwargs["lr"])
LORA_RANK = 128                 # passed as ReLoRAConfig.lora_rank
RESET_STEPS = 1_000             # passed as ReLoRAConfig.reset_steps
CHUNK_WARMUP_STEPS = 100        # passed as ReLoRAConfig.chunk_warmup_steps
# --- Data / loss ---
BATCH_SIZE = 1                  # >1 needs a padding collate_fn: tokenized samples are variable-length
MAX_FULL_LOSSES_LENGTH = 2_048  # forwarded to lm_losses_calculator; NOTE(review): meaning is defined there

In [8]:
def _global_lr(step):
    """Global LR factor: ramps linearly from 0 toward 1 over WARMUP_STEPS, then stays at 1.0.

    NOTE(review): assumes the trainer multiplies the base LR by this value — confirm in ReLoRAConfig.
    """
    return step / WARMUP_STEPS if step < WARMUP_STEPS else 1.0

In [9]:
# ReLoRA setup: rank-LORA_RANK adapters with lm_head excluded, trained with AdamW.
# The global LR schedule, reset cadence and per-chunk warm-up come from the config cell.
relora_config = ReLoRAConfig(
    lora_rank=LORA_RANK,
    blacklisted_modules=["lm_head"],
    optimizer_type=AdamW,
    optimizer_kwargs=dict(lr=LR),
    lr_global=_global_lr,
    reset_steps=RESET_STEPS,
    chunk_warmup_steps=CHUNK_WARMUP_STEPS,
)

In [10]:
def _step_end(step, optimizer, losses, loss):
    if step % 50 == 0:
        print(f"STEP {step}")
        for loss_name, loss_value in losses.items():
            print(f"{loss_name}: {loss_value.item():.4f}")

In [11]:
def _chunk_end(chunk, step):
    """Chunk-end hook: log where the chunk finished, and always tell the trainer to keep going."""
    message = f"CHUNK {chunk} FINISHED AT STEP {step}"
    print(message)
    return StopCondition.CONTINUE

In [12]:
# Connect the logging hooks to the trainer's step/chunk callbacks.
relora_events = ReLoRAEvents(on_chunk_end=_chunk_end, on_step_end=_step_end)

In [13]:
# Helper from bitnet_selfdistill_utils. Presumably enables gradient checkpointing
# across the Phi-3 model to save activation memory — TODO confirm. The returned model replaces `model`.
model = phi3_full_gradient_checkpoint_enable(model)

In [14]:
# Build the ReLoRA trainer. Checkpoints go to ./checkpoints.
# NOTE(review): model_kwargs are presumably passed to the model's forward call;
# output_hidden_states looks required for hidden_state_loss — confirm in ReloraTrainer.
trainer = ReloraTrainer(
    model=model,
    relora_config=relora_config,
    checkpoint_directory="checkpoints",
    events=relora_events,
    losses_calculator=lm_losses_calculator(MAX_FULL_LOSSES_LENGTH),
    model_kwargs=dict(output_hidden_states=True),
)

In [15]:
# Unshuffled loader over the streamed train split; default collation requires equal-length samples when BATCH_SIZE > 1.
dataloader_train = DataLoader(dataset["train"], batch_size=BATCH_SIZE)

In [16]:
# Fixed evaluation subset: the first 100 streamed examples of the test split.
dataloader_test = DataLoader(dataset["test"].take(100), batch_size=BATCH_SIZE)

In [17]:
# Baseline evaluation before (resumed) training; returns the averaged loss components.
trainer.evaluate(dataloader_test)

2024-11-01 06:01:38,842 - INFO - Evaluation completed with average loss components: {'loss_lm': 25.65352439880371, 'kldiv_loss': 24.615262985229492, 'hidden_state_loss': 26.375, 'loss': 76.61003875732422}


{'loss_lm': 25.65352439880371,
 'kldiv_loss': 24.615262985229492,
 'hidden_state_loss': 26.375,
 'loss': 76.61003875732422}

In [18]:
# Train, resuming from the latest checkpoint in ./checkpoints if one exists.
# NOTE(review): the streaming dataloader restarts at the first sample on every run;
# confirm whether the trainer skips batches already consumed before the checkpoint.
trainer.train(
    dataloader_train,
    continue_from_checkpoint=True,
)

2024-11-01 06:01:38,915 - INFO - Starting training with continue_from_checkpoint=True
  checkpoint_data = torch.load(fname)
2024-11-01 06:01:44,828 - INFO - Loaded checkpoint for chunk 0, resuming from step 1000
2024-11-01 06:01:50,067 - INFO - Loaded checkpoint for chunk 1, resuming from step 2000
2024-11-01 06:01:53,431 - INFO - Loaded checkpoint for chunk 2, resuming from step 3000
2024-11-01 06:01:57,156 - INFO - Starting training on chunk 3 with start_step=3000
2024-11-01 06:01:58,746 - INFO - Optimizer initialized with kwargs={'lr': 1e-05}


STEP 3000
loss_lm: 6.4923
kldiv_loss: 4.2035
hidden_state_loss: 7.8438
loss: 18.5395
STEP 3050
loss_lm: 6.4608
kldiv_loss: 4.0312
hidden_state_loss: 8.0625
loss: 18.5545
STEP 3100
loss_lm: 6.6599
kldiv_loss: 5.0897
hidden_state_loss: 8.3125
loss: 20.0621
STEP 3150
loss_lm: 5.9641
kldiv_loss: 5.2394
hidden_state_loss: 7.5625
loss: 18.7660
STEP 3200
loss_lm: 6.4760
kldiv_loss: 5.1222
hidden_state_loss: 8.0625
loss: 19.6607
STEP 3250
loss_lm: 6.6561
kldiv_loss: 4.5692
hidden_state_loss: 7.8750
loss: 19.1003
STEP 3300
loss_lm: 6.2684
kldiv_loss: 4.8806
hidden_state_loss: 8.1875
loss: 19.3365
STEP 3350
loss_lm: 6.8033
kldiv_loss: 5.6978
hidden_state_loss: 8.5000
loss: 21.0012


KeyboardInterrupt: 

In [None]:
# Re-evaluate on the same 100-example test subset to compare with the baseline above.
trainer.evaluate(dataloader_test)