In [1]:
#| default_exp models.mistral

In [2]:
#| export
"""
Modified code for Mistral model
"""
from typing import List, Optional, Type, Any, Dict

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.models.mistral.configuration_mistral import MistralConfig

from transformers.models.mistral.modeling_mistral import \
    is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, \
    MistralRMSNorm, \
    MistralRotaryEmbedding, \
    MistralMLP, MistralAttention, \
    MistralFlashAttention2, \
    MistralSdpaAttention, \
    MistralDecoderLayer, \
    MistralPreTrainedModel, \
    MistralModel, \
    MistralForCausalLM, \
    MistralForSequenceClassification

from bitlinear.bitlinear import BitLinear
from bitlinear.adapters import LinearAdapter, LoRAAdapter, MergeableLayer

In [3]:
#| export
class BitMistralMLP(MistralMLP):
    """Mistral feed-forward block whose three projections are ``BitLinear`` layers.

    Only ``nn.Module.__init__`` runs here; ``MistralMLP.__init__`` is bypassed on
    purpose so its own projection layers are never built. The forward pass is
    inherited unchanged from ``MistralMLP``.

    Args:
        config: Mistral configuration; ``hidden_size``, ``intermediate_size`` and
            ``hidden_act`` are read.
        fname_prefix: prefix of the ``original_weights_filename`` given to each
            ``BitLinear`` (``-gate-proj.bin``, ``-up-proj.bin`` and
            ``-down-proj.bin`` are appended).
    """

    def __init__(self, config: MistralConfig, fname_prefix: str):
        nn.Module.__init__(self)
        self.config = config
        hidden = config.hidden_size
        intermediate = config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = intermediate

        def _bit_proj(n_in: int, n_out: int, weights_file: str) -> BitLinear:
            # Every projection in this block is bias-free.
            return BitLinear(
                in_features=n_in,
                out_features=n_out,
                bias=False,
                original_weights_filename=weights_file,
            )

        # Built in the same order as before: gate, up, down.
        self.gate_proj = _bit_proj(hidden, intermediate, f"{fname_prefix}-gate-proj.bin")
        self.up_proj = _bit_proj(hidden, intermediate, f"{fname_prefix}-up-proj.bin")
        self.down_proj = _bit_proj(intermediate, hidden, f"{fname_prefix}-down-proj.bin")
        self.act_fn = ACT2FN[config.hidden_act]

In [4]:
#| export
class BitMistralAttentionBase:
    """Shared constructor for the BitLinear-based Mistral attention variants.

    This class is not an ``nn.Module`` itself. It is mixed into subclasses of
    the transformers attention classes, which are modules, and it calls
    ``nn.Module.__init__`` directly. That skips the parent attention
    constructor, so only ``BitLinear`` projections are created. The attention
    forward pass comes from the transformers class the concrete subclass
    inherits from.

    Args:
        config: Mistral configuration (hidden size, head counts, RoPE settings,
            attention dropout).
        fname_prefix: prefix of the ``original_weights_filename`` given to each
            projection (``-q-proj.bin``, ``-k-proj.bin``, ``-v-proj.bin``,
            ``-o-proj.bin`` are appended).
        layer_idx: index of this layer in the decoder stack. Required; the
            ``None`` default exists only to keep the signature compatible.

    Raises:
        ValueError: if ``layer_idx`` is None, or if ``hidden_size`` is not
            divisible by ``num_attention_heads``.
    """

    def __init__(self, config: MistralConfig, fname_prefix: str, layer_idx: Optional[int] = None):
        # Explicit check instead of `assert`: asserts are removed under
        # `python -O`, which would let a missing index through silently.
        # NOTE(review): presumably the inherited forward uses layer_idx for
        # KV-cache bookkeeping — confirm against the transformers version used.
        if layer_idx is None:
            raise ValueError(
                f"{type(self).__name__} requires an explicit `layer_idx` (got None)."
            )
        nn.Module.__init__(self)
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Query heads that share one key/value head (grouped-query attention).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = BitLinear(
            self.hidden_size,
            self.num_heads * self.head_dim,
            bias=False,
            original_weights_filename=f"{fname_prefix}-q-proj.bin"
        )
        # K and V are sized by the (possibly smaller) number of key/value heads.
        self.k_proj = BitLinear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            original_weights_filename=f"{fname_prefix}-k-proj.bin"
        )
        self.v_proj = BitLinear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            original_weights_filename=f"{fname_prefix}-v-proj.bin"
        )
        self.o_proj = BitLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            original_weights_filename=f"{fname_prefix}-o-proj.bin"
        )

        self.rotary_emb = MistralRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )



class BitMistralAttention(MistralAttention, BitMistralAttentionBase):
    """Eager attention: ``MistralAttention``'s forward over ``BitLinear`` projections."""

    def __init__(self, config: MistralConfig, fname_prefix: str, layer_idx: Optional[int] = None):
        # MistralAttention.__init__ is intentionally not called; the shared base
        # creates every layer this class needs.
        BitMistralAttentionBase.__init__(
            self,
            config=config,
            fname_prefix=fname_prefix,
            layer_idx=layer_idx,
        )


class BitMistralFlashAttention2(MistralFlashAttention2, BitMistralAttentionBase):
    """Flash-attention-2 variant: ``MistralFlashAttention2``'s forward over ``BitLinear`` projections."""

    def __init__(self, config: MistralConfig, fname_prefix: str, layer_idx: Optional[int] = None):
        # The parent constructor is bypassed, so the flag it would normally set
        # must be set here as well.
        BitMistralAttentionBase.__init__(
            self,
            config=config,
            fname_prefix=fname_prefix,
            layer_idx=layer_idx,
        )
        # NOTE(review): flag consumed by the inherited flash-attention forward;
        # presumably about causal-mask alignment in older flash-attn — confirm.
        flash_is_2_10_or_newer = is_flash_attn_greater_or_equal_2_10()
        self._flash_attn_uses_top_left_mask = not flash_is_2_10_or_newer


class BitMistralSdpaAttention(MistralSdpaAttention, BitMistralAttentionBase):
    """SDPA variant: ``MistralSdpaAttention``'s forward over ``BitLinear`` projections."""

    def __init__(self, config: MistralConfig, fname_prefix: str, layer_idx: Optional[int] = None):
        # Skip MistralSdpaAttention.__init__; the shared base builds the layers.
        BitMistralAttentionBase.__init__(
            self,
            config=config,
            fname_prefix=fname_prefix,
            layer_idx=layer_idx,
        )

In [5]:
#| export
# Lookup from `config._attn_implementation` to the BitLinear attention class
# that each decoder layer instantiates.
BITMISTRAL_ATTENTION_CLASSES = dict(
    eager=BitMistralAttention,
    flash_attention_2=BitMistralFlashAttention2,
    sdpa=BitMistralSdpaAttention,
)

In [6]:
#| export
class BitMistralDecoderLayer(MistralDecoderLayer):
    """Mistral decoder layer whose self-attention and MLP use ``BitLinear``.

    ``MistralDecoderLayer.__init__`` is bypassed (only ``nn.Module.__init__``
    runs), and the forward pass is inherited unchanged.

    Args:
        config: Mistral configuration. ``config._attn_implementation`` selects
            the attention class from ``BITMISTRAL_ATTENTION_CLASSES``.
        layer_idx: index of this layer in the decoder stack.
        fname_prefix: per-layer prefix for weight files. ``-self-attn`` and
            ``-mlp`` are appended here, and the sub-modules then append their
            own ``-<name>-proj.bin`` suffix.
    """

    def __init__(self, config: MistralConfig, layer_idx: int, fname_prefix: str):
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size

        self.layer_idx = layer_idx
        # Sub-module prefixes must NOT end in ".bin". The attention and MLP
        # classes append e.g. "-q-proj.bin" themselves, and a ".bin" here
        # produced names such as "...-self-attn.bin-q-proj.bin".
        # NOTE: weight files written under that old naming will not be found
        # under the corrected names.
        self.self_attn = BITMISTRAL_ATTENTION_CLASSES[config._attn_implementation](
            config=config,
            fname_prefix=f"{fname_prefix}-self-attn",
            layer_idx=layer_idx,
        )
        self.mlp = BitMistralMLP(
            config=config,
            fname_prefix=f"{fname_prefix}-mlp",
        )
        self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

In [7]:
#| export
class BitMistralPreTrainedModel(MistralPreTrainedModel):
    """``MistralPreTrainedModel`` whose weight-initialisation hook also handles ``BitLinear``."""

    def _init_weights(self, module):
        """Initialise a single sub-module in place.

        NOTE(review): presumably invoked once per sub-module by the transformers
        init machinery (e.g. via ``post_init``) — confirm for the version used.

        ``BitLinear``, ``nn.Linear`` and ``nn.Embedding`` weights are drawn from
        N(0, ``config.initializer_range``) and biases are zeroed. An embedding's
        ``padding_idx`` row, if any, is zeroed as well.

        Args:
            module: the sub-module to initialise; other module types are left
                untouched.
        """
        std = self.config.initializer_range
        # BitLinear is tested first so it never also falls into the nn.Linear
        # branch (which would matter if BitLinear subclasses nn.Linear — not
        # visible from here).
        if isinstance(module, BitLinear):
            # Draw the full-precision weights in place. This avoids allocating
            # separate full-size mean/std tensors. Default device/dtype as before.
            module.update_weights(
                torch.empty(module.out_features, module.in_features).normal_(mean=0.0, std=std)
            )
            if module.bias is not None:
                # `.data` (as in the nn.Linear branch) avoids autograd's error
                # for in-place ops on a leaf tensor that requires grad.
                module.bias.data.zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
    

class BitMistralAdaptersMixin(nn.Module):
    """Attach, detach and enumerate adapters on every ``BitLinear`` in the model."""

    def _get_bitlinear_layers(self) -> List[BitLinear]:
        """Return all ``BitLinear`` sub-modules, in ``self.modules()`` order."""
        return [module for module in self.modules() if isinstance(module, BitLinear)]

    def add_adapters(self, adapter_type: Type[LinearAdapter], params: Dict[str, Any]) -> List[LinearAdapter]:
        """Create one adapter per ``BitLinear`` and store it in ``layer.adapter``.

        Args:
            adapter_type: adapter class to instantiate.
            params: keyword arguments shared by every adapter. ``in_features``,
                ``out_features`` and ``device`` are filled in (and override any
                caller-supplied values) from each target layer.

        Returns:
            The newly created adapters, in the same order as the layers.
        """
        created: List[LinearAdapter] = []
        # The layer list is materialised before any adapter is attached.
        for bit_layer in self._get_bitlinear_layers():
            # A fresh dict per layer, so the caller's `params` is never mutated.
            kwargs = {
                **params,
                "in_features": bit_layer.in_features,
                "out_features": bit_layer.out_features,
                "device": bit_layer.quant_weight.device,
            }
            adapter = adapter_type(**kwargs)
            bit_layer.adapter = adapter
            created.append(adapter)
        return created

    def remove_adapters(self) -> None:
        """Detach every adapter by setting ``adapter`` to None on each ``BitLinear``."""
        for bit_layer in self._get_bitlinear_layers():
            if bit_layer.adapter is None:
                continue
            bit_layer.adapter = None

    def mergeable_layers(self) -> List[MergeableLayer]:
        """Return all ``MergeableLayer`` sub-modules, in ``self.modules()`` order."""
        return [module for module in self.modules() if isinstance(module, MergeableLayer)]

In [8]:
#| export
class BitMistralModel(MistralModel, BitMistralPreTrainedModel, BitMistralAdaptersMixin):
    """Mistral decoder stack built from ``BitMistralDecoderLayer``s.

    ``BitMistralPreTrainedModel`` is a real base class here. Merely calling its
    ``__init__`` is not enough: the ``_init_weights`` override only takes effect
    if it appears in this class's MRO. Python's C3 linearisation places it right
    after ``MistralModel`` and before ``MistralPreTrainedModel``.

    Args:
        config: Mistral configuration.
        fname_prefix: prefix for weight files; layer ``i`` receives
            ``f"{fname_prefix}-decoder-{i}"``.
    """

    def __init__(self, config: MistralConfig, fname_prefix: str):
        # Skip MistralModel.__init__ (it would build the stock layers) and run
        # the pretrained-model base constructor directly.
        BitMistralPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [
                BitMistralDecoderLayer(config, layer_idx, f"{fname_prefix}-decoder-{layer_idx}")
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

In [9]:
#| export
class BitMistralForCausalLM(MistralForCausalLM, BitMistralPreTrainedModel, BitMistralAdaptersMixin):
    """Causal-LM head on top of ``BitMistralModel``.

    Inherits ``BitMistralPreTrainedModel`` so that its BitLinear-aware
    ``_init_weights`` is in the MRO; calling its ``__init__`` alone does not
    make the override apply.

    Args:
        config: Mistral configuration.
        fname_prefix: prefix for weight files, forwarded to ``BitMistralModel``.
    """

    def __init__(self, config: MistralConfig, fname_prefix: str):
        # Skip MistralForCausalLM.__init__ so the stock MistralModel is not built.
        BitMistralPreTrainedModel.__init__(self, config)
        self.model = BitMistralModel(config, fname_prefix)
        self.vocab_size = config.vocab_size
        # The LM head stays a full-precision nn.Linear.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

In [10]:
#| export
class BitMistralForSequenceClassification(
    MistralForSequenceClassification, BitMistralPreTrainedModel, BitMistralAdaptersMixin
):
    """Sequence-classification head on top of ``BitMistralModel``.

    Inherits ``BitMistralPreTrainedModel`` so that its BitLinear-aware
    ``_init_weights`` is in the MRO.

    Args:
        config: Mistral configuration; ``num_labels`` sizes the score head.
        fname_prefix: prefix for weight files, forwarded to ``BitMistralModel``.
    """

    def __init__(self, config: MistralConfig, fname_prefix: str):
        BitMistralPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        # Must be BitMistralModel. The stock MistralModel has no BitLinear layers
        # and (in the transformers versions targeted here) takes only `config`,
        # so passing `fname_prefix` to it raised a TypeError.
        self.model = BitMistralModel(config, fname_prefix)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

In [11]:
#| hide
# Regenerate the library modules from the `#| export` cells of this notebook.
import nbdev
nbdev.nbdev_export()