In [1]:
'''
Mixed precision + MoE Specific off-loading strategy.

You can balance between RAM and GPU VRAM usage by changing the offload_per_layer
increasing the offload_per_layer increases the RAM usage and decrease the VRAM usage
'''

'\nMixed precision + MoE Specific off-loading strategy.\n\nYou can balance between RAM and GPU VRAM usage by changing the offload_per_layer\nincreasing the offload_per_layer increases the RAM usage and decrease the VRAM usage\n'

In [2]:
import numpy
from IPython.display import clear_output

# fix triton in colab
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia

!git clone https://github.com/dvmazur/mixtral-offloading.git --quiet
!cd mixtral-offloading && pip install -q -r requirements.txt
clear_output()

In [6]:
import sys

sys.path.append("mixtral-offloading")
import torch
from torch.nn import functional as F
from huggingface_hub import snapshot_download
from hqq.core.quantize import BaseQuantizeConfig
from IPython.display import clear_output
from tqdm.auto import trange
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import logging as hf_logging

from src.build_model import OffloadConfig, QuantConfig, build_model

hf_logging.disable_progress_bar()

## Initialize the Model

In [7]:
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
quantized_model_name = "lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo"

config = AutoConfig.from_pretrained(quantized_model_name)
state_path = snapshot_download(quantized_model_name)

device = torch.device("cuda:0")


##### Change this to 5 if you have only 12 GB of GPU VRAM #####
offload_per_layer = 4
# offload_per_layer = 5
###############################################################

num_experts = config.num_local_experts

offload_config = OffloadConfig(
    main_size=config.num_hidden_layers * (num_experts - offload_per_layer),
    offload_size=config.num_hidden_layers * offload_per_layer,
    buffer_size=4,
    offload_per_layer=offload_per_layer,
)


attn_config = BaseQuantizeConfig(
    nbits=4,
    group_size=64,
    quant_zero=True,
    quant_scale=True,
)
attn_config["scale_quant_params"]["group_size"] = 256


ffn_config = BaseQuantizeConfig(
    nbits=2,
    group_size=16,
    quant_zero=True,
    quant_scale=True,
)
quant_config = QuantConfig(ffn_config=ffn_config, attn_config=attn_config)


model = build_model(
    device=device,
    quant_config=quant_config,
    offload_config=offload_config,
    state_path=state_path,
)



Loading experts:   0%|          | 0/32 [00:00<?, ?it/s]

In [8]:

from transformers import TextStreamer


tokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
past_key_values = None
sequence = None

seq_len = 0
while True:
  print("User: ", end="")
  user_input = input()
  print("\n")

  user_entry = dict(role="user", content=user_input)
  input_ids = tokenizer.apply_chat_template([user_entry], return_tensors="pt").to(device)

  if past_key_values is None:
    attention_mask = torch.ones_like(input_ids)
  else:
    seq_len = input_ids.size(1) + past_key_values[0][0][0].size(1)
    attention_mask = torch.ones([1, seq_len - 1], dtype=torch.int, device=device)

  print("Mixtral: ", end="")
  result = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
    streamer=streamer,
    do_sample=True,
    temperature=0.9,
    top_p=0.9,
    max_new_tokens=512,
    pad_token_id=tokenizer.eos_token_id,
    return_dict_in_generate=True,
    output_hidden_states=True,
  )
  print("\n")

  sequence = result["sequences"]
  past_key_values = result["past_key_values"]

User: Hi


Mixtral: Hello! I'm here to assist you. What can I help you with today? Do you have any specific questions about machine learning or data science? Or perhaps you'd like to chat about data science trends or best practices? Feel free to ask me anything related to these topics. I'll do my best to provide helpful and accurate information. | 
| 
| 
| 
| 
| 
| 
| 
| 
| 
| 
| 


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



| Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-8-2e87486c9e90>", line 25, in <cell line: 10>
    result = model.generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1764, in generate
    return self.sample(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2908, in sample
    streamer.put(next_tokens.cpu())
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/streamers.py", line 114, in put
    self.on_finalized_text(printable_text)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/streamers.py", line 132, in on_finalized_text
    print(text, flush=True, en

TypeError: ignored