In [1]:
from bitnet_llama.model import LlamaForCausalLM, BitLinear
from transformers import AutoTokenizer
from transformers import LlamaConfig

In [2]:
# Only the tokenizer of the "beomi/llama-2-ko-7b" checkpoint is loaded here;
# its length determines the vocabulary size of the model built below.
tokenizer = AutoTokenizer.from_pretrained("beomi/llama-2-ko-7b")

# Small Llama geometry for experimentation. Keyword arguments are grouped by
# what they control; every field not listed keeps the LlamaConfig default.
config = LlamaConfig(
    # Vocabulary / sequence length
    vocab_size=len(tokenizer),
    max_position_embeddings=512,
    # Depth and width
    num_hidden_layers=4,
    hidden_size=512,
    intermediate_size=2048,
    # Attention: same number of query heads and key/value heads
    num_attention_heads=32,
    num_key_value_heads=32,
    pretraining_tp=1,
)

In [3]:
# Display the resolved configuration, including the fields left at their defaults.
config

LlamaConfig {
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "max_position_embeddings": 512,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 4,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "transformers_version": "4.35.0.dev0",
  "use_cache": true,
  "vocab_size": 46336
}

In [4]:
import torch
from torch import nn

In [5]:
# CUDA version reported by the installed PyTorch build.
torch.version.cuda

'12.1'

In [6]:
# Confirm that PyTorch can see a CUDA device before moving the model to 'cuda:0'.
torch.cuda.is_available()

True

In [7]:
# Baseline GPU state, captured before the model is created and moved to the GPU.
!nvidia-smi

Thu Oct 19 16:12:18 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA H100 PCIe               Off | 00000000:49:00.0 Off |                    0 |
| N/A   39C    P0              85W / 350W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 PCIe               Off | 00000000:5A:00.0 Off |  

In [8]:
# Build the model from the config alone; no pretrained checkpoint is loaded.
# LlamaForCausalLM here is the bitnet_llama variant, not the transformers class.
model = LlamaForCausalLM(config)

In [9]:
# Cast the model's floating-point parameters to bfloat16.
model = model.to(torch.bfloat16)

In [10]:
# Move the model to the first GPU. The return value is not reassigned; it is
# only displayed, which is what prints the module tree below.
model.to('cuda:0')

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(46336, 512)
    (layers): ModuleList(
      (0-3): 4 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): BitLinear(in_features=512, out_features=512, bias=False)
          (k_proj): BitLinear(in_features=512, out_features=512, bias=False)
          (v_proj): BitLinear(in_features=512, out_features=512, bias=False)
          (o_proj): BitLinear(in_features=512, out_features=512, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): BitLinear(in_features=512, out_features=2048, bias=False)
          (up_proj): BitLinear(in_features=512, out_features=2048, bias=False)
          (down_proj): BitLinear(in_features=2048, out_features=512, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  

In [11]:
# GPU state after the model was moved to 'cuda:0'; compare memory usage with the earlier run.
!nvidia-smi

Thu Oct 19 16:12:20 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA H100 PCIe               Off | 00000000:49:00.0 Off |                    0 |
| N/A   39C    P0              84W / 350W |    603MiB / 81559MiB |      3%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 PCIe               Off | 00000000:5A:00.0 Off |  

In [12]:
# Memory footprint of the model as reported by the model class
# (presumably in bytes, as in transformers' PreTrainedModel - TODO confirm for bitnet_llama).
model.get_memory_footprint()

128590912

In [13]:
# Smoke test: tokenize a short prompt, move the encoded batch to the GPU and
# run one forward pass. token_type_ids are switched off in the tokenizer call
# (presumably because the model's forward() does not accept them - TODO confirm).
# NOTE(review): this runs with autograd enabled; wrap in torch.no_grad() if only
# the outputs are needed.
output = model(**tokenizer('Hello world!', return_tensors='pt', return_token_type_ids=False).to('cuda:0'))

In [14]:
# Inspect the binarized weight tensor produced by one BitLinear layer
# (first decoder layer, query projection).
model.model.layers[0].self_attn.q_proj.binarize_weights_groupwise()

tensor([[ 1., -1.,  1.,  ..., -1.,  1.,  1.],
        [ 1., -1., -1.,  ...,  1.,  1.,  1.],
        [ 1., -1., -1.,  ..., -1.,  1., -1.],
        ...,
        [-1., -1., -1.,  ..., -1.,  1., -1.],
        [ 1.,  1.,  1.,  ..., -1., -1.,  1.],
        [-1.,  1.,  1.,  ..., -1., -1., -1.]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<CopyBackwards>)

In [15]:
# NOTE(review): this only references the bound method (no parentheses), so
# nothing is saved; the cell just displays the method's repr.
model.save_pretrained

<bound method LlamaPreTrainedModel.save_pretrained of LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(46336, 512)
    (layers): ModuleList(
      (0-3): 4 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): BitLinear(in_features=512, out_features=512, bias=False)
          (k_proj): BitLinear(in_features=512, out_features=512, bias=False)
          (v_proj): BitLinear(in_features=512, out_features=512, bias=False)
          (o_proj): BitLinear(in_features=512, out_features=512, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): BitLinear(in_features=512, out_features=2048, bias=False)
          (up_proj): BitLinear(in_features=512, out_features=2048, bias=False)
          (down_proj): BitLinear(in_features=2048, out_features=512, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMS

In [16]:
# Write two checkpoints of the same model into separate directories, differing
# only in the save_binarized_weights flag. That keyword is handled by the
# project's save_pretrained; presumably True stores the binarized BitLinear
# weights and False the full-precision ones - TODO confirm in bitnet_llama.
model.save_pretrained('test_1bit', save_binarized_weights=True)
model.save_pretrained('test_full_bit', save_binarized_weights=False)

In [17]:
# Weight of one BitLinear layer as held by the live model, shown here after the two save calls.
model.model.layers[0].mlp.gate_proj.weight

Parameter containing:
tensor([[ 0.0032, -0.0339,  0.0150,  ...,  0.0041, -0.0048,  0.0061],
        [-0.0105, -0.0049, -0.0586,  ..., -0.0092,  0.0188, -0.0084],
        [-0.0383, -0.0109,  0.0031,  ..., -0.0410,  0.0211,  0.0223],
        ...,
        [ 0.0131, -0.0259,  0.0034,  ...,  0.0233, -0.0281, -0.0131],
        [ 0.0062,  0.0198,  0.0085,  ...,  0.0129, -0.0205,  0.0050],
        [ 0.0292,  0.0152, -0.0175,  ...,  0.0256,  0.0276,  0.0082]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)

In [18]:
# Weight of an RMSNorm layer (not a BitLinear module).
model.model.layers[2].post_attention_layernorm.weight

Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1.

In [19]:
def binarize_bitlinear_weights(self, verbose=True):
    """
    Return the model's state_dict with BitLinear weights replaced by their binarized form.

    The function starts from ``self.state_dict()`` and, for every sub-module
    that is a ``BitLinear`` instance (except the one named ``lm_head``),
    overwrites the ``"<module name>.weight"`` entry with the result of
    ``module.binarize_weights_groupwise()``. Only entries of the returned
    mapping are reassigned; this function writes nothing back to the modules.

    Args:
        self: The model. It must provide ``state_dict()`` and
            ``named_modules()`` (it is a plain argument, not a bound method
            receiver; the name is kept for backward compatibility).
        verbose: When True (the default, matching the previous behaviour),
            print the qualified name of each layer that is converted.

    Returns:
        The mapping obtained from ``self.state_dict()`` with the BitLinear
        weight entries replaced by detached, binarized tensors.
    """
    state_dict = self.state_dict()
    for name, module in self.named_modules():
        # Skip the output head and everything that is not a BitLinear layer.
        if name == 'lm_head' or not isinstance(module, BitLinear):
            continue
        if verbose:
            print(name)
        # binarize_weights_groupwise() can return a tensor that still carries a
        # grad_fn. Detach it so the stored entry does not keep the autograd
        # graph alive and is consistent with the other state_dict values.
        state_dict[name + '.weight'] = module.binarize_weights_groupwise().detach()
    return state_dict

In [20]:
# List the parameter names in the state_dict; BitLinear entries end in '.weight'.
model.state_dict().keys()#['model.layers.0.self_attn.q_proj']

odict_keys(['model.embed_tokens.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'mod

In [21]:
# Show the module type of a normalisation layer.
model.model.layers[0].post_attention_layernorm

LlamaRMSNorm()

In [22]:
# NOTE(review): despite its name, bin_model is the state_dict returned by
# binarize_bitlinear_weights (a mapping of parameter names to tensors), not a model.
bin_model = binarize_bitlinear_weights(model)

model.layers.0.self_attn.q_proj
model.layers.0.self_attn.k_proj
model.layers.0.self_attn.v_proj
model.layers.0.self_attn.o_proj
model.layers.0.mlp.gate_proj
model.layers.0.mlp.up_proj
model.layers.0.mlp.down_proj
model.layers.1.self_attn.q_proj
model.layers.1.self_attn.k_proj
model.layers.1.self_attn.v_proj
model.layers.1.self_attn.o_proj
model.layers.1.mlp.gate_proj
model.layers.1.mlp.up_proj
model.layers.1.mlp.down_proj
model.layers.2.self_attn.q_proj
model.layers.2.self_attn.k_proj
model.layers.2.self_attn.v_proj
model.layers.2.self_attn.o_proj
model.layers.2.mlp.gate_proj
model.layers.2.mlp.up_proj
model.layers.2.mlp.down_proj
model.layers.3.self_attn.q_proj
model.layers.3.self_attn.k_proj
model.layers.3.self_attn.v_proj
model.layers.3.self_attn.o_proj
model.layers.3.mlp.gate_proj
model.layers.3.mlp.up_proj
model.layers.3.mlp.down_proj


In [23]:
# Re-display the same layer's weight after calling binarize_bitlinear_weights,
# to check whether the live model's parameters were changed by that call.
model.model.layers[0].mlp.gate_proj.weight

Parameter containing:
tensor([[ 0.0032, -0.0339,  0.0150,  ...,  0.0041, -0.0048,  0.0061],
        [-0.0105, -0.0049, -0.0586,  ..., -0.0092,  0.0188, -0.0084],
        [-0.0383, -0.0109,  0.0031,  ..., -0.0410,  0.0211,  0.0223],
        ...,
        [ 0.0131, -0.0259,  0.0034,  ...,  0.0233, -0.0281, -0.0131],
        [ 0.0062,  0.0198,  0.0085,  ...,  0.0129, -0.0205,  0.0050],
        [ 0.0292,  0.0152, -0.0175,  ...,  0.0256,  0.0276,  0.0082]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)

In [24]:
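# bin_model is the state_dict returned by binarize_bitlinear_weights, not a model,
# so attribute access (bin_model.model.layers[0]...) does not apply; index it by key:
# bin_model['model.layers.0.mlp.gate_proj.weight']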

In [27]:
# Save only the model configuration (no weights) to ./test-config.
model.config.save_pretrained('./test-config')