In [1]:
import torch
import optiacts
from peft import LoraConfig, prepare_model_for_kbit_training  # , PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

In [2]:
def print_stats():
    stats = torch.cuda.memory_stats()
    print(f'allocated: {stats["active_bytes.all.current"] / 2**30:.3}Gb, peak: {stats["active_bytes.all.peak"] / 2**30:.3}Gb')

In [3]:
if torch.cuda.get_device_capability()[0] >= 8:
    attn_implementation = "flash_attention_2"
    torch_dtype = torch.bfloat16
else:
    attn_implementation = "eager"
    torch_dtype = torch.float16


base_model = "mistralai/Mistral-7B-v0.1"
new_model = "test-mistral-7B"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation,
)
model = prepare_model_for_kbit_training(model)


def run():
    out = model(torch.randint(0, 10000, [1, 2**12], device='cuda'))
    print_stats()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
print('Model size:')
print_stats()

Model size:
allocated: 4.83Gb, peak: 5.08Gb


In [5]:
print('Standard activations:')
run()

The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float16.


Standard activations:
allocated: 47.5Gb, peak: 47.6Gb


In [6]:
for layer in model.model.layers:
    layer.mlp.act_fn = optiacts.GELU()
torch.cuda.reset_peak_memory_stats()

print('Memory with optiacts:')
run()

Memory with optiacts:
allocated: 40.8Gb, peak: 40.8Gb
