In [None]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest

In [None]:
# Forward pass shared by the W8A16 layers: int8 weights, activations in their own float dtype.
def w8_a16_forward(weight, input, scales, bias=None):
    """Compute ``F.linear(input, weight) * scales`` and optionally add ``bias``.

    Args:
        weight: int8 weight matrix laid out as F.linear expects,
            i.e. (out_features, in_features).
        input: activations; their dtype is used as the compute dtype.
        scales: multiplied element-wise onto the linear output (one value per
            output feature in W8A16LinearLayer).
        bias: optional tensor added after scaling.

    Returns:
        The scaled (and biased, if given) linear output.
    """
    # int8 values (-128..127) are exactly representable in fp16/bf16/fp32,
    # so upcasting to the activation dtype loses nothing.
    upcast_weight = weight.to(input.dtype)
    scaled = F.linear(input, upcast_weight) * scales
    if bias is None:
        return scaled
    return scaled + bias

In [None]:
# Define W8A16LinearLayer class
class W8A16LinearLayer(nn.Module):
    """Linear layer that stores int8 weights plus one scale per output row (W8A16).

    The int8 weights are upcast to the activation dtype on every forward call
    (see ``w8_a16_forward``); activations themselves are not quantized.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if True, allocate a bias buffer of shape (1, out_features);
            otherwise ``self.bias`` is None.
        dtype: dtype of the ``scales`` and ``bias`` buffers.
    """

    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()

        # Placeholder int8 weights; quantize() overwrites them with real values.
        # randint's upper bound is exclusive, so 128 is needed to cover 127.
        self.register_buffer(
            "int8_weights",
            torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8),
        )
        # One scale per output row (placeholder until quantize() runs).
        self.register_buffer("scales", torch.randn((out_features,), dtype=dtype))

        if bias:
            self.register_buffer("bias", torch.randn((1, out_features), dtype=dtype))
        else:
            self.bias = None

    # Quantize method for converting weights to int8
    def quantize(self, weights):
        """Symmetric per-row int8 quantization of ``weights``.

        Sets ``self.int8_weights`` (int8, same shape as ``weights``) and
        ``self.scales`` (one value per row, in ``weights.dtype``) such that
        ``int8_weights * scales[:, None]`` approximates ``weights``.

        Args:
            weights: 2-D float weight matrix (rows = output features), e.g. an
                ``nn.Linear.weight`` parameter. It is not modified.
        """
        # detach(): quantization must not record an autograd graph. Otherwise the
        # stored scales would carry a grad_fn that keeps this fp32 copy of the
        # whole weight matrix alive for the lifetime of the layer.
        w_fp32 = weights.detach().to(torch.float32)

        scales = w_fp32.abs().max(dim=-1).values / 127
        # An all-zero row would give scale 0 and 0/0 = NaN below; any non-zero
        # scale maps that row to exact zeros, so use 1.
        scales = torch.where(scales == 0, torch.ones_like(scales), scales)
        scales = scales.to(weights.dtype)

        # Divide in fp32 by the scales exactly as they are stored (after the dtype
        # cast), then clamp: rounding of low-precision scales can push |q| past 127,
        # and converting out-of-range floats to int8 is not well defined.
        int8_weights = (
            torch.round(w_fp32 / scales.to(torch.float32).unsqueeze(1))
            .clamp(-128, 127)
            .to(torch.int8)
        )

        self.int8_weights = int8_weights
        self.scales = scales

    # Forward method for the layer
    def forward(self, input):
        """Apply the quantized linear map to ``input`` via ``w8_a16_forward``."""
        return w8_a16_forward(self.int8_weights, input, self.scales, self.bias)

In [None]:
# Function to replace linear layers with the quantized version
def replace_linear_with_target_and_quantize(module, target_class, module_name_to_exclude):
    """Recursively swap ``nn.Linear`` children of ``module`` for quantized layers.

    Each non-excluded ``nn.Linear`` is replaced (in place, via ``setattr``) by
    ``target_class(in_features, out_features, has_bias, weight_dtype)``, whose
    ``quantize(weight)`` is then called; the original bias tensor is reused.

    Args:
        module: root ``nn.Module`` to modify in place.
        target_class: class constructed as described above and exposing
            ``quantize(weights)``, e.g. ``W8A16LinearLayer``.
        module_name_to_exclude: child attribute names (matched at every nesting
            level) whose ``nn.Linear`` is left untouched. A single string is
            treated as one name.
    """
    if isinstance(module_name_to_exclude, str):
        # A bare string would otherwise be iterated character by character.
        module_name_to_exclude = (module_name_to_exclude,)

    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and name not in module_name_to_exclude:
            old_bias = child.bias
            old_weight = child.weight

            new_module = target_class(
                child.in_features, child.out_features, old_bias is not None, old_weight.dtype
            )
            setattr(module, name, new_module)

            # Quantization is a one-off weight transform: run it without autograd so
            # no graph (and no reference to the float weights) outlives the swap.
            with torch.no_grad():
                new_module.quantize(old_weight)

            # Reuse the original bias tensor unchanged.
            if old_bias is not None:
                new_module.bias = old_bias
        else:
            # Recurse into containers (and into excluded Linear layers, which have no children).
            replace_linear_with_target_and_quantize(child, target_class, module_name_to_exclude)

In [None]:
from huggingface_hub import snapshot_download

# Local directory that receives the Mistral-7B-Instruct-v0.3 checkpoint files.
mistral_models_path = Path.home() / "mistral_models" / "7B-Instruct-v0.3"
mistral_models_path.mkdir(parents=True, exist_ok=True)

# Fetch only the three files this notebook loads (params, weights, v3 tokenizer).
snapshot_download(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    allow_patterns=[
        "params.json",
        "consolidated.safetensors",
        "tokenizer.model.v3",
    ],
    local_dir=mistral_models_path,
)

In [None]:
# Load the v3 tokenizer and the downloaded checkpoint, then move the model to the GPU.
tokenizer = MistralTokenizer.from_file(str(mistral_models_path / "tokenizer.model.v3"))
model = Transformer.from_folder(mistral_models_path).cuda()
# Show the un-quantized architecture so it can be compared with the quantized one later.
print("Model before:\n\n", model)

In [None]:
# Function to print the weights of all Linear layers before quantization
def print_linear_weights_before_quantization(model):
    """Print the weight tensor of every ``nn.Linear`` module in ``model``.

    Only ``nn.Linear`` modules are reported (the layer type the quantization
    step replaces); other parameters whose names contain "weight", such as
    norm or embedding weights, are skipped. Each entry is labelled with the
    parameter's qualified name, e.g. ``<module path>.weight``.

    Args:
        model: any ``nn.Module``.
    """
    print("Weights before quantization:\n")
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        # named_modules() yields "" for the root module itself.
        param_name = f"{name}.weight" if name else "weight"
        print(f"Layer: {param_name}")
        print(f"Weights: {module.weight.detach()}")  # plain tensor repr, no autograd info
        print("-" * 50)

# Baseline: dump the unquantized Linear weights for later comparison with the int8 ones.
print_linear_weights_before_quantization(model=model)

# A single-turn prompt, reused for both the baseline and the quantized generation.
completion_request = ChatCompletionRequest(
    messages=[
        UserMessage(content="Explain Machine Learning to me in a nutshell."),
    ],
)

In [None]:
# Encode the chat request into the token ids passed to `generate`.
encoded_request = tokenizer.encode_chat_completion(completion_request)
tokens = encoded_request.tokens

In [None]:
# Test the model before quantization and measure the inference time.
# A short warm-up run first absorbs one-off GPU start-up costs (presumably
# cuBLAS/kernel initialisation), which would otherwise inflate this baseline.
generate([tokens], model, max_tokens=1, temperature=0.0, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)

torch.cuda.synchronize()  # keep earlier queued GPU work out of the timed region
start_time = time.perf_counter()  # monotonic, high-resolution clock meant for benchmarking
out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
result_before = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
torch.cuda.synchronize()
end_time = time.perf_counter()

time_before_quantization = end_time - start_time
print(f"Before Quantization: {result_before}")
print(f"Inference Time Before Quantization: {time_before_quantization:.4f} seconds")

In [None]:
# Apply quantization to every nn.Linear except the output projection, which stays
# in full precision. mistral_inference's Transformer exposes that layer as the
# child named `output` (NOTE(review): confirm against the installed version);
# "lm_head" is the Hugging Face name and is kept for HF-style models.
replace_linear_with_target_and_quantize(model, W8A16LinearLayer, ["lm_head", "output"])

print("Model after:\n\n", model)

In [None]:
# Test the model after quantization and measure the inference time.
# Warm up the new quantized code path once so first-call overhead is not timed.
generate([tokens], model, max_tokens=1, temperature=0.0, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)

torch.cuda.synchronize()  # keep earlier queued GPU work out of the timed region
start_time = time.perf_counter()  # monotonic, high-resolution clock meant for benchmarking
out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
result_after = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
torch.cuda.synchronize()
end_time = time.perf_counter()

time_after_quantization = end_time - start_time
print(f"After Quantization: {result_after}")
print(f"Inference Time After Quantization: {time_after_quantization:.4f} seconds")

In [None]:
# Compare wall-clock generation time before and after quantization.
# NOTE: W8A16LinearLayer stores int8 weights but upcasts them on every forward
# call, so the quantized model is not guaranteed to be faster; report either way.
time_difference = time_before_quantization - time_after_quantization
if time_difference >= 0:
    print(f"Time saved with quantization: {time_difference:.4f} seconds")
else:
    print(f"Quantization made inference slower by {-time_difference:.4f} seconds")
if time_after_quantization > 0:
    print(f"Speedup (before / after): {time_before_quantization / time_after_quantization:.2f}x")

In [None]:
# Function to print quantized weights and scales after quantization
def print_quantized_weights(model):
    """Print int8 weights, scales and (if present) bias of each W8A16LinearLayer.

    Args:
        model: any ``nn.Module``; only its ``W8A16LinearLayer`` submodules are reported.
    """
    print("Weights after quantization:\n")
    for name, module in model.named_modules():
        if not isinstance(module, W8A16LinearLayer):
            continue
        print(f"Layer: {name}")
        print(f"Quantized Weights (int8):\n{module.int8_weights}")
        print(f"Scales:\n{module.scales}")
        if module.bias is not None:
            # The bias is not quantized; it keeps whatever dtype it had
            # (e.g. the original model's dtype), so report that instead of assuming FP32.
            print(f"Bias (not quantized, {module.bias.dtype}):\n{module.bias}")
        print("-" * 50)

In [None]:
# Inspect the int8 weights, scales and biases of every quantized layer.
print_quantized_weights(model=model)