In [1]:
import os

import torch
from torch import nn
from transformers import AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

In [2]:
# Hugging Face Hub identifier of the checkpoint loaded in the next cell.
model_id = "mistralai/Mixtral-8x7B-v0.1"

In [3]:
# Load the checkpoint with fp16 weights and let `device_map='auto'` decide weight placement.
# NOTE(review): the per-layer output further down shows device='meta' tensors for layers >= 1,
# which suggests some weights were not materialized on a real device — confirm that the whole
# model fits on the visible GPUs before trusting any weight statistics.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map='auto')

Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [11]:
# Show which concrete class AutoModelForCausalLM resolved to for this checkpoint.
type(model)

transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM

In [12]:
# Print the module tree to inspect the decoder-layer and expert structure.
print(model)

MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBLockSparseTop2MLP(
              (w1): Linear(in_features=4096, out_features=14336, bias=False)
              (w2): Linear(in_features=14336, out_features=4096, bias=False)
              (w3): Linear(in_features=4096, out_features=14336, bias=False)
        

In [13]:
def sim(w1, w2s):
    """Return the squared Euclidean distance from each row of ``w2s`` to ``w1``.

    Despite the name, this is a distance (smaller means more similar), not a
    similarity score.

    Args:
        w1: 1-D tensor of length ``D`` (here: the mean of the flattened expert
            weights).
        w2s: 2-D tensor of shape ``(E, D)``, one flattened weight per row.

    Returns:
        1-D tensor of length ``E`` holding ``sum((w2s[e] - w1) ** 2)`` for each
        row ``e``, in the dtype/device of the inputs.
    """
    diffs = w2s - w1.unsqueeze(0)
    r2 = torch.sum(diffs ** 2, dim=1)
    # The original helper computed r2 but never returned it, so callers got None.
    return r2


# For every decoder layer, measure how far each expert's w1/w2/w3 weight lies
# from the mean of that projection across the layer's experts.
# NOTE(review): the recorded output shows device='meta' for layers >= 1; meta
# tensors carry no data, so no values are computed for those layers — confirm
# the weights are actually materialized before interpreting the numbers.
with torch.no_grad():
    for idx, layer in enumerate(model.model.layers):
        experts = layer.block_sparse_moe.experts
        print(f"Layer {idx}")
        # Handle one projection at a time so only one stacked copy of the
        # expert weights is alive at once.
        for proj_name in ("w1", "w2", "w3"):
            stacked = torch.stack(
                [getattr(expert, proj_name).weight.flatten() for expert in experts]
            )
            mean_weight = torch.mean(stacked, dim=0)
            print(f"  {proj_name}: {sim(mean_weight, stacked)}")

Layer 0
  w1: tensor([0.4053, 0.4495, 0.4712, 0.4617, 0.4463, 0.4363, 0.4109, 0.4302],
       device='cuda:0', dtype=torch.float16)
  w2: tensor([0.3813, 0.4468, 0.4636, 0.4541, 0.4436, 0.4272, 0.3926, 0.4226],
       device='cuda:0', dtype=torch.float16)
  w3: tensor([0.3867, 0.4307, 0.4509, 0.4414, 0.4287, 0.4177, 0.3909, 0.4131],
       device='cuda:0', dtype=torch.float16)
Layer 1
  w1: tensor(..., device='meta', size=(8,), dtype=torch.float16)
  w2: tensor(..., device='meta', size=(8,), dtype=torch.float16)
  w3: tensor(..., device='meta', size=(8,), dtype=torch.float16)
Layer 2
  w1: tensor(..., device='meta', size=(8,), dtype=torch.float16)
  w2: tensor(..., device='meta', size=(8,), dtype=torch.float16)
  w3: tensor(..., device='meta', size=(8,), dtype=torch.float16)
Layer 3
  w1: tensor(..., device='meta', size=(8,), dtype=torch.float16)
  w2: tensor(..., device='meta', size=(8,), dtype=torch.float16)
  w3: tensor(..., device='meta', size=(8,), dtype=torch.float16)
Layer 4
  w

In [None]:
class BinaryExperts(nn.Module):
    """Approximate a weight as ``base + coeff * sign_mask``.

    The difference ``experts - base`` is reduced to a per-element sign mask
    (+1 where the difference is >= 0, -1 where it is negative) and a single
    learnable scalar ``coeff``, initialised to the median absolute difference.

    Args:
        base: reference weight tensor. Stored transposed as the ``base`` buffer.
        experts: weight tensor to approximate; must be broadcast-compatible
            with ``base``.
            NOTE(review): the original body referenced an undefined name
            ``finetune`` here; it is assumed that ``experts`` is that target
            weight (a single tensor, not a list of experts) — confirm.
    """

    def __init__(self, base, experts):
        super().__init__()
        diff = experts - base
        # Median of |diff|, computed in float32, is the initial magnitude.
        scale = diff.float().abs().median()

        # Sign mask in the dtype of `diff`; zeros in `diff` map to +1.
        mask = torch.ones_like(diff)
        mask[diff < 0] = -1

        # Both buffers are stored transposed so that forward() can use `x @ W`.
        self.register_buffer("mask", mask.T)
        self.register_buffer("base", base.T)
        # detach().clone() avoids the copy-construct warning that
        # torch.tensor(<tensor>) raises; nn.Parameter defaults to requires_grad=True.
        self.coeff = nn.Parameter(
            scale.detach().clone().to(dtype=torch.float32, device=base.device)
        )

    def forward(self, x):
        """Multiply ``x`` by the reconstructed (transposed) weight.

        Args:
            x: tensor whose last dimension equals ``base.shape[1]`` of the
                tensor passed to ``__init__``.

        Returns:
            ``x @ (base.T + coeff * mask.T)``.
        """
        x = x @ (self.base + self.coeff * self.mask)
        return x

@torch.no_grad()
def compress_mixtral_moe_diff(model):
    """Rewrite the expert weights of every ``...experts`` module in ``model`` in place.

    For each submodule whose qualified name ends with ``"experts"``, every
    expert's ``w1``/``w2``/``w3`` weight is replaced by a new ``Parameter``:
    ``w1 <- w1 - w2``, ``w2 <- w2 - w3``, ``w3 <- w3``. Runs with autograd
    disabled; the visible code returns nothing.

    Args:
        model: module exposing ``named_modules()`` whose ``experts`` containers
            are iterable and whose items have ``w1``, ``w2`` and ``w3`` layers
            with a ``weight`` attribute.

    NOTE(review): in the Mixtral layout printed earlier in this notebook, w1 is
    Linear(4096 -> 14336) and w2 is Linear(14336 -> 4096), so their weights have
    transposed shapes and ``w1 - w2`` / ``w2 - w3`` would not broadcast. This
    looks like unfinished work — confirm the intended operands.
    """
    for name, module in model.named_modules():
        # Select the expert containers by the suffix of their module name.
        if name.endswith("experts"):
            experts = module
            for expert in experts:
                # Capture the original weights before any of them is overwritten.
                w1 = expert.w1.weight
                w2 = expert.w2.weight
                w3 = expert.w3.weight
                expert.w1.weight = torch.nn.Parameter(w1 - w2)
                expert.w2.weight = torch.nn.Parameter(w2 - w3)
                expert.w3.weight = torch.nn.Parameter(w3)