In [18]:
import gc

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from bitnet_selfdistil import lm_losses_calculator, ReLoRAConfig, ReLoRAEvents, ReloraTrainer, BitLinearWithLoRA
from torch.optim import AdamW

In [2]:
MODEL_NAME = "microsoft/Phi-3.5-mini-instruct"
DEVICE = "cuda:0"
LORA_RANK = 128

In [3]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=DEVICE,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
def check_generation(model, tokenizer):
    with torch.no_grad():
        input_ids = tokenizer.apply_chat_template(
            [
                {
                    "role": "user",
                    "content": "Can you provide ways to eat combinations of bananas and dragonfruits?",
                }
            ],
            return_tensors="pt",
            add_generation_prompt=True
        ).to(device=DEVICE)
        generation_output = model.generate(input_ids=input_ids,
                                           return_dict_in_generate=True,
                                           output_scores=True,
                                           max_length=100)
        response = tokenizer.decode(generation_output.sequences[0][input_ids.shape[1]:])
        print(response)

In [5]:
check_generation(model, tokenizer)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some creative recipes and ideas to enjoy these fruits together:

1. **Banana Dragonfruit Smoothie**:
   - Blend together one ripe banana, half a cup of dragonfruit puree, a cup of almond milk


In [6]:
def _global_lr(step):
    if step < 2000:
        return step / 2000
    else:
        return 1.0

In [7]:
def _step_end(step, optimizer, losses, loss):
    if step % 50 == 0:
        print(f"STEP {step}")
        for loss_name, loss_value in losses.items():
            print(f"{loss_name}: {loss_value.item():.4f}")

In [8]:
def _chunk_end(chunk, step):
    print(f"CHUNK {chunk} FINISHED AT STEP {step}")

In [9]:
relora_config = ReLoRAConfig(
    blacklisted_modules=["lm_head"],
    lora_rank=128,
    optimizer_type=AdamW,
    optimizer_kwargs={
        "lr": 1e-4,
    },
    reset_steps=1000,
    chunk_warmup_steps=100,
    lr_global=_global_lr,
)

In [10]:
relora_events = ReLoRAEvents(
    on_step_end=_step_end,
    on_chunk_end=_chunk_end,
)

In [11]:
from torch.utils.checkpoint import checkpoint
from transformers.models.phi3.modeling_phi3 import Phi3DecoderLayer

In [12]:
def patched_phi3RMS_norm_forward(module):
    def forward(hidden_states):
        assert (hidden_states.dtype == torch.bfloat16) or (hidden_states.dtype == torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
        return module.weight * hidden_states
    
    return forward

def make_new_forward(module):
    def forward(hidden_states,
                attention_mask = None,
                position_ids = None,
                past_key_value = None,
                output_attentions = False,
                use_cache = False,
                cache_position = None,
                **kwargs):
        #print(f"hidden_states.requires_grad: {hidden_states.requires_grad}")
        result = checkpoint(
            lambda *args: Phi3DecoderLayer.forward(module, *args, **kwargs),
            hidden_states,
            attention_mask,
            position_ids,
            past_key_value,
            output_attentions,
            use_cache,
            cache_position,
            use_reentrant=True
        )
        #print(f"result[0].requires_grad: {result[0].requires_grad}")
        return result
    
    return forward

model.enable_input_require_grads()
for module in model.model.layers:
    module.forward = make_new_forward(module)
    module.input_layernorm.forward = patched_phi3RMS_norm_forward(module.input_layernorm)
    module.post_attention_layernorm.forward = patched_phi3RMS_norm_forward(module.post_attention_layernorm)

In [13]:
trainer = ReloraTrainer(
    model=model,
    relora_config=relora_config,
    events=relora_events,
    losses_calculator=lm_losses_calculator(4096),
    max_steps=20000,
    model_kwargs={
        "output_hidden_states": True,
    }
)

In [14]:
batch = {
    "input_ids": tokenizer.apply_chat_template(
        [
            {
                "role": "user",
                "content": "Can you provide ways to eat combinations of bananas and dragonfruits?",
            },
            {
                "role": "assistant",
                "content": "Certainly! Bananas and dragonfruits can be combined in a variety of delicious and creative ways. Here are some ideas:\n" + \
                           "- Blended Smoothie:\n" + \
                           "  Peel and cut both fruits into chunks and blend them with some yogurt or coconut milk for creaminess. Add a scoop of protein powder or a spoonful of peanut butter for extra protein and flavor."
            }
        ],
        return_tensors="pt",
        add_generation_prompt=True
    ).to(device=DEVICE)
}
batch["labels"] = batch["input_ids"].clone()

In [15]:
batches=[
    batch
    for _ in range(20000)
]

In [16]:
trainer.train(batches)



STEP 0
loss_lm: 24.0062
kldiv_loss: 23.2483
hidden_state_loss: 26.0000
loss: 73.2545
STEP 50
loss_lm: 23.9170
kldiv_loss: 23.1573
hidden_state_loss: 26.0000
loss: 73.0743
STEP 100
loss_lm: 12.2565
kldiv_loss: 11.5783
hidden_state_loss: 25.7500
loss: 49.5847
STEP 150
loss_lm: 8.7729
kldiv_loss: 8.1127
hidden_state_loss: 23.3750
loss: 40.2605
STEP 200
loss_lm: 6.6762
kldiv_loss: 6.1185
hidden_state_loss: 20.7500
loss: 33.5446
STEP 250
loss_lm: 5.1696
kldiv_loss: 4.8269
hidden_state_loss: 16.6250
loss: 26.6214
STEP 300
loss_lm: 4.4452
kldiv_loss: 4.3006
hidden_state_loss: 13.1250
loss: 21.8708
STEP 350
loss_lm: 3.7847
kldiv_loss: 3.6965
hidden_state_loss: 11.4375
loss: 18.9187
STEP 400
loss_lm: 2.4683
kldiv_loss: 2.4116
hidden_state_loss: 10.6875
loss: 15.5674
STEP 450
loss_lm: 1.4710
kldiv_loss: 1.6508
hidden_state_loss: 10.0625
loss: 13.1843
STEP 500
loss_lm: 1.0462
kldiv_loss: 1.1664
hidden_state_loss: 9.6875
loss: 11.9001
STEP 550
loss_lm: 0.6907
kldiv_loss: 0.9026
hidden_state_loss: 

KeyboardInterrupt: 

In [19]:
weights = {}
for name, module in model.named_modules():
    if isinstance(module, BitLinearWithLoRA):
        weight = module.get_bitnet_weight().detach().cpu()
        weights[f"{name}.weight"] = weight

In [21]:
weights["model.layers.0.self_attn.o_proj.weight"].shape

torch.Size([3072, 3072])

In [23]:
import gc

del model, trainer
gc.collect()
torch.cuda.empty_cache()

In [24]:
gc.collect()
torch.cuda.empty_cache()

In [25]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=DEVICE,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [28]:
weights["model.layers.0.self_attn.o_proj.weight"].dtype

torch.bfloat16

In [30]:
model.load_state_dict(
    dict(model.state_dict(), **weights),
    strict=False
)

<All keys matched successfully>

In [32]:
set(weights["model.layers.0.self_attn.o_proj.weight"].float().numpy().ravel())

{-0.013061523, -0.0, 0.013061523}

In [34]:
set( (model.state_dict()["model.layers.0.self_attn.o_proj.weight"] * 1).detach().cpu().float().numpy().ravel() )

{-0.013061523, -0.0, 0.013061523}

In [35]:
check_generation(model, tokenizer)

Certainly! Bananas and dragonfruits can be combined in a variety of delicious and creative ways. Here are some ideas:


- Blended Smoothie:
  Comel and cut both fruits into chunks and blend them with some yogurt or coconut milk for a scoop of protein powder, a hand of hogurt
