# L4-B - Building your own Quantizer: Replace PyTorch layers with Quantized Layers

In this lesson, you will learn about the quantization pipeline using your own 8-bit quantizer.

Run the next cell to import the functions from the previous lessons of `Building your own Quantizer` so you can follow along with the video.

- To access the `helper.py` file, click `File --> Open...` at the top left.

In [None]:
import torch
import torch.nn as nn

from helper import W8A16LinearLayer

## Step 2: Quantization Pipeline

- Replace every `torch.nn.Linear` layer with a `W8A16LinearLayer`.
- Call `quantize` on each new layer, passing in the original weights.
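
The `quantize` method itself lives in `helper.py` and is not shown in this notebook. As a rough sketch of what such a step might compute, the example below assumes symmetric per-row absmax quantization to int8. That scheme is an assumption, and `quantize_per_row_absmax` and `dequantize` are hypothetical names used only for illustration.

```python
import torch


def quantize_per_row_absmax(weight: torch.Tensor):
    """Assumed scheme: symmetric int8 with one scale per output row."""
    w_fp32 = weight.detach().to(torch.float32)
    # One scale per row, so the largest magnitude in each row maps to 127.
    scales = w_fp32.abs().max(dim=-1).values / 127
    # Guard against all-zero rows, which would otherwise give a zero scale.
    scales = scales.clamp(min=1e-8)
    int8_weight = torch.round(w_fp32 / scales.unsqueeze(1)).to(torch.int8)
    return int8_weight, scales


def dequantize(int8_weight: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Map int8 values back to floating point using the per-row scales."""
    return int8_weight.to(torch.float32) * scales.unsqueeze(1)


torch.manual_seed(0)
weight = torch.randn(4, 8)
int8_weight, scales = quantize_per_row_absmax(weight)
recovered = dequantize(int8_weight, scales)

# Rounding moves each value by at most half a quantization step.
max_error = (weight - recovered).abs().max().item()
print(int8_weight.dtype, scales.shape, max_error)
```

Under this assumed scheme, the reconstruction error per element stays within half of that row's scale, which is why keeping the original weights around until `quantize` has run matters.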

### 2.1 - Model In-place Linear Layer Replacement
- Implement `replace_linear_with_target`

In [None]:
def replace_linear_with_target(module,
                               target_class, module_name_to_exclude):
    """Swap nn.Linear children of `module` for `target_class`, in place.

    Children whose attribute name appears in `module_name_to_exclude`
    are left untouched.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and \
          name not in module_name_to_exclude:
            old_bias = child.bias

            # Build the replacement with the same shape, bias flag, and dtype
            new_module = target_class(child.in_features,
                                      child.out_features,
                                      old_bias is not None,
                                      child.weight.dtype)
            setattr(module, name, new_module)
            # Carry over the original bias; weights are not quantized here
            if old_bias is not None:
                getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target(
                child, target_class, module_name_to_exclude)
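
Because the function walks the model with `named_children()`, the exclusion list is compared against each layer's immediate attribute name, not its full dotted path. The sketch below uses two hypothetical toy classes, `Inner` and `Outer`, to show the difference between the names that `named_children()` and `named_modules()` yield.

```python
import torch.nn as nn


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(2, 2)


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Inner()
        self.lm_head = nn.Linear(2, 2)


outer = Outer()

# Immediate children only: these are the names the exclusion test sees
# at the top level of the recursion.
child_names = [name for name, _ in outer.named_children()]

# Every module in the tree, with dotted paths (the root has an empty name).
all_names = [name for name, _ in outer.named_modules()]

print(child_names)
print(all_names)
```

With this structure, excluding `"proj"` would skip the nested layer, while excluding `"block.proj"` would not match anything, since the recursion only ever sees `"proj"` when it reaches `Inner`.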

In [None]:
class DummyModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.emb = torch.nn.Embedding(1, 1)
    # Try with bias
    self.linear_1 = nn.Linear(1, 1)
    # Try without bias
    self.linear_2 = nn.Linear(1, 1, bias=False)
    # LM prediction head
    self.lm_head = nn.Linear(1, 1, bias=False)

In [None]:
model_1 = DummyModel()
model_2 = DummyModel()

In [None]:
replace_linear_with_target(model_1, W8A16LinearLayer, ["lm_head"])
print(model_1)

In [None]:
replace_linear_with_target(model_2, W8A16LinearLayer, [])
print(model_2)
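
Reading the printed model works for a toy example, but on larger models it can be easier to count layer types. A minimal sketch follows, where `count_module_types` is a hypothetical helper and the `nn.Sequential` model is only a stand-in:

```python
from collections import Counter

import torch.nn as nn


def count_module_types(model: nn.Module) -> Counter:
    """Count every submodule (including the root) by class name."""
    return Counter(type(m).__name__ for m in model.modules())


# Stand-in model, used only to keep this example self-contained.
toy = nn.Sequential(
    nn.Embedding(4, 2),
    nn.Linear(2, 2),
    nn.ReLU(),
    nn.Linear(2, 1),
)

counts = count_module_types(toy)
print(counts)
```

Applied to `model_1` after replacement, the count should show one remaining `Linear` (the excluded `lm_head`) and two `W8A16LinearLayer` entries, while `model_2` should have no `Linear` left.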

### 2.2 - Linear Layer Replacement + Quantization
- Modify the `replace_linear_with_target` function to also perform quantization.
- Implement `replace_linear_with_target_and_quantize`.

In [None]:
def replace_linear_with_target_and_quantize(module,
                               target_class, module_name_to_exclude):
    """Like `replace_linear_with_target`, but also quantizes the
    original weights into each new layer."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and \
          name not in module_name_to_exclude:
            old_bias = child.bias
            old_weight = child.weight

            new_module = target_class(child.in_features,
                                      child.out_features,
                                      old_bias is not None,
                                      child.weight.dtype)
            setattr(module, name, new_module)

            # Quantize the original weights into the new layer
            getattr(module, name).quantize(old_weight)

            if old_bias is not None:
                getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target_and_quantize(child,
                     target_class, module_name_to_exclude)

In [None]:
model_3 = DummyModel()

In [None]:
replace_linear_with_target_and_quantize(model_3, W8A16LinearLayer, ["lm_head"])
print(model_3)
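
One way to sanity-check the effect of the replacement is to compare storage sizes before and after. The sketch below assumes the quantized layer keeps its int8 weights and per-row scales as buffers, which is an assumption about `helper.py`. `module_nbytes` and `Int8WeightHolder` are hypothetical names, and the holder class is a stand-in that only mimics the storage layout.

```python
import torch
import torch.nn as nn


def module_nbytes(module: nn.Module) -> int:
    """Total bytes held in a module's parameters and buffers."""
    params = sum(p.numel() * p.element_size() for p in module.parameters())
    buffers = sum(b.numel() * b.element_size() for b in module.buffers())
    return params + buffers


class Int8WeightHolder(nn.Module):
    """Stand-in with an assumed layout: int8 weights plus fp32 row scales."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.register_buffer(
            "int8_weights",
            torch.zeros((out_features, in_features), dtype=torch.int8),
        )
        self.register_buffer(
            "scales", torch.ones(out_features, dtype=torch.float32)
        )


# float32 weights: 64 * 64 * 4 bytes
fp32_bytes = module_nbytes(nn.Linear(64, 64, bias=False))
# int8 weights plus float32 scales: 64 * 64 * 1 + 64 * 4 bytes
int8_bytes = module_nbytes(Int8WeightHolder(64, 64))

print(fp32_bytes, int8_bytes)
```

Counting buffers as well as parameters matters here, because a layer that stores its quantized weights as buffers would otherwise appear to take no memory at all.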