# Analog-Digital Mixed-Precision Training and Inference

### Authors: [Athanasios Vasilopoulos](https://www.linkedin.com/in/thanos-vasilopoulos/)

In this notebook, we demonstrate the capability of performing mixed-precision training and inference in AIHWKit. The notebook uses a ResNet32 as a use case and demonstrates:
- The conversion of an FP model to a mixed-precision analog-digital model, with per-torch-module fidelity. All activations (including outputs and affine parameters of tiles) and all digital modules have configurable precision. The quantization parameters are learnable during training or estimated post-training with a variety of range-estimators.
- Simultaneous training of a model that combines analog modules with low-precision digital modules and activations.
- Post-training quantization (PTQ) of the activations and modules of a network that was trained only in an analog-aware manner.


In [1]:
# Imports
from copy import deepcopy

import requests
import torch
from tqdm import tqdm

from aihwkit.inference.calibration import calibrate_quantization_ranges
from aihwkit.inference.compensation.drift import GlobalDriftCompensation
from aihwkit.inference.noise.pcm import PCMLikeNoiseModel
from aihwkit.nn.conversion import convert_to_analog
from aihwkit.nn.low_precision_conversion import convert_to_quantized
from aihwkit.nn.low_precision_modules.conversion_utils import append_default_conversions
from aihwkit.nn.low_precision_modules.quantized_base_modules import (
    QuantBatchNorm2d,
    QuantConv2d,
    QuantLinear,
)
from aihwkit.optim import AnalogSGD
from aihwkit.simulator.configs import QuantizedTorchInferenceRPUConfig, TorchInferenceRPUConfig
from aihwkit.simulator.configs.utils import (
    BoundManagementType,
    NoiseManagementType,
    WeightClipType,
    WeightModifierType,
    WeightRemapType,
)
from aihwkit.simulator.parameters.quantization import (
    ActivationQuantConfig,
    QuantizationMap,
    QuantizedModuleConfig,
)
from aihwkit.simulator.presets.utils import IOParameters

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Utility imports - model definition, training/test steps, dataset preparation
from model_utils import *

In [3]:
# RPU config generation for either the quantized or traditional TorchTile.
# The differences involve the handling of DACs, as well as the definition of
# parameters for the configuration of the new quantized periphery.
def gen_rpu_config(quant_tile, n_act_bits, asym_act):
    """Build the inference RPU config for the quantized or the regular torch tile.

    Args:
        quant_tile: If true, a ``QuantizedTorchInferenceRPUConfig`` is created (so the
            ``QuantizedTorchInferenceTile`` is used), the DAC is disabled and the tile
            output / periphery quantization is configured. Otherwise a plain
            ``TorchInferenceRPUConfig`` with DAC quantization and input-range
            learning is created.
        n_act_bits: Number of bits used for the tile output activations and the
            periphery (quantized tile) or for the DAC resolution (regular tile).
        asym_act: Whether the tile output activations are quantized asymmetrically.

    Returns:
        The populated RPU config object.
    """
    # The QuantizedTorchInferenceTile wraps the TorchInferenceTile and extends it with
    # quantized outputs (post-ADC) and a quantized periphery that applies the affine
    # transformations and the bias in reduced precision. Both are configured at the
    # end of this function.
    config_cls = QuantizedTorchInferenceRPUConfig if quant_tile else TorchInferenceRPUConfig
    rpu_config = config_cls()

    # Weight noise injected by the modifier
    modifier = rpu_config.modifier
    modifier.std_dev = 0.06
    modifier.type = WeightModifierType.ADD_NORMAL

    # Weight-to-tile mapping
    mapping = rpu_config.mapping
    mapping.digital_bias = True
    mapping.weight_scaling_omega = 1.0
    mapping.weight_scaling_columnwise = True
    mapping.learn_out_scaling = False

    # Weight remapping and clipping
    rpu_config.remap.type = WeightRemapType.LAYERWISE_SYMMETRIC
    clip = rpu_config.clip
    clip.type = WeightClipType.LAYER_GAUSSIAN
    clip.sigma = 2.0

    # Forward pass (ADC side) non-idealities
    rpu_config.forward = IOParameters()
    forward = rpu_config.forward
    forward.is_perfect = False
    forward.out_noise = 0.0
    forward.out_bound = 12
    forward.out_res = 1 / (2**8 - 2)
    forward.bound_management = BoundManagementType.NONE
    forward.noise_management = NoiseManagementType.ABS_MAX

    if quant_tile:
        # With the quantized tile, the network is converted such that the input of
        # every tile is already quantized by the layer before it, so no DAC
        # functionality is necessary.
        forward.inp_bound = 0.0
        forward.inp_res = -1
    else:
        forward.inp_bound = 1.0
        forward.inp_res = 1 / (2**n_act_bits - 2)

    # Input range learning is only enabled for the regular (non-quantized) tile;
    # its parameters are set in both cases.
    input_range = rpu_config.pre_post.input_range
    input_range.enable = not quant_tile
    input_range.decay = 0.01
    input_range.init_from_data = 50
    input_range.init_std_alpha = 3.0
    input_range.input_min_percentage = 0.995
    input_range.manage_output_clipping = False

    # Inference-time noise model and drift compensation
    rpu_config.noise_model = PCMLikeNoiseModel(g_max=25.0)
    rpu_config.drift_compensation = GlobalDriftCompensation()

    if not quant_tile:
        return rpu_config

    # Configure the output activation precision of each tile. When a layer uses more
    # than one tile, this option applies to the result of each individual tile AND to
    # the final output after proper partial result accumulation and concatenation.
    # See `QuantizedTileModuleArray` for more details.
    rpu_config.act_quant_config = ActivationQuantConfig(
        n_bits=n_act_bits,
        symmetric=not asym_act,
    )

    # Configure the precision of the periphery, which applies to the affine
    # transformations and bias additions of the tiles.
    periph_quant = rpu_config.pre_post.periph_quant
    periph_quant.n_bits = n_act_bits
    periph_quant.symmetric = True

    return rpu_config

In [4]:
# Append to the QuantizationMap the recipe to convert any modules that are not in the
# DEFAULT_CONVERSIONS (which includes only Linear, Conv2d, Embedding, and LayerNorm).
# In the case of this network, the only module is BatchNorm2d. Other networks, with
# more advanced modules, may require custom quantized layer implementations using the
# tools we offer in this framework, like the QuantizedActivation, etc.
def append_resnet_quantization_map(quantization_map: QuantizationMap):
    """Register the ``BatchNorm2d`` -> ``QuantBatchNorm2d`` conversion in place.

    The new entry receives its own deep copy of ``quantization_map.default_qconfig``,
    so the copy reflects the default config as it is at call time.

    Args:
        quantization_map: The map whose ``module_qconfig_map`` is extended.
    """
    batchnorm_qconfig = deepcopy(quantization_map.default_qconfig)
    batchnorm_conversion = QuantizedModuleConfig(
        quantized_module=QuantBatchNorm2d,
        module_qconfig=batchnorm_qconfig,
    )
    quantization_map.module_qconfig_map[torch.nn.BatchNorm2d] = batchnorm_conversion


# Function to convert the pretrained model to a mixed-precision model,
# merging the convert_to_analog and convert_to_quantized calls, with
# the appropriate selections for the resnet network in question.
# The tools we offer allow for the configuration of every single
# activation in the system, enabling faithful replication of
# the user's targeted system.
def convert_to_mixedprecision(
    model,  # The input model
    mode,  # "mixed" will place first and last layers on digital, "analog" will place all weight layers in analog
    act_quant,  # Precision of activations across the system
    asym_act,  # Whether activations are asymmetric
    weight_quant,  # Weight precision if "mixed" is selected
    per_channel_weight=True,  # Whether weights are quantized per_channel or not
    # Flags to perform partial conversions - Used during PTQ (see the corresponding code)
    only_analog_conversion=False,
    ptq=False,
):
    """Convert ``model`` to an analog model and then to a quantized mixed-precision model.

    Args:
        model: The model to convert.
        mode: ``"analog"`` converts every supported layer to analog; any other value
            keeps ``conv1`` and ``linear`` out of the analog conversion. Weight
            quantization of the digital layers is only configured for ``"mixed"``.
        act_quant: Activation precision in bits; values ``<= 0`` select the regular
            (non-quantized) tile and skip the input-activation quantizers.
        asym_act: Whether activations are quantized asymmetrically.
        weight_quant: Weight precision in bits, used when ``mode == "mixed"``.
        per_channel_weight: Whether weights are quantized per channel.
        only_analog_conversion: Return right after the analog conversion (the result
            is not explicitly moved to ``device`` in that case).
        ptq: When converting an already-analog network in ``"mixed"`` mode, also
            register instance conversions for ``conv1`` and ``linear``.

    Returns:
        The converted model; moved to ``device`` unless ``only_analog_conversion`` is set.
    """
    # The QuantizedTorchInferenceTile is only selected when the activations of the
    # network are requested to be quantized.
    quantize_activations = act_quant > 0
    base_rpu_config = gen_rpu_config(
        quant_tile=quantize_activations, n_act_bits=act_quant, asym_act=asym_act
    )

    # First and last layers stay out of the analog conversion unless mode is "analog".
    digital_layers = [] if mode == "analog" else ["conv1", "linear"]
    model = convert_to_analog(model, base_rpu_config, exclude_modules=digital_layers)

    # For PTQ we may need to stop here so that an analog-only trained checkpoint can be
    # loaded before proceeding to the activation and module quantization.
    if only_analog_conversion:
        return model

    # Quantization configuration for all modules that were not converted to analog.
    # A module can also be wrapped with input quantization if no already-quantized
    # activation reaches it (e.g. after a functional call), which keeps a true
    # quantized-everywhere activation flow.
    quantization_map = QuantizationMap()
    default_qconfig = quantization_map.default_qconfig
    default_qconfig.activation_quant.n_bits = act_quant
    default_qconfig.activation_quant.symmetric = not asym_act
    if mode == "mixed":
        default_qconfig.weight_quant.n_bits = weight_quant
        default_qconfig.weight_quant.per_channel = per_channel_weight

    # Some layers receive data that was not quantized by the layer before them: the
    # first layer (fed by the dataset), the last layer (fed by a functional pooling
    # layer) and every resnet block that follows a residual addition. The first
    # resnet block is fed directly by a batchnorm and is therefore skipped. Listing
    # their names in input_activation_qconfig_map makes sure their inputs get quantized.
    if quantize_activations:
        residual_blocks = [
            f"layer{i}.{j}" for j in range(5) for i in range(1, 4) if (i, j) != (1, 0)
        ]
        # All entries share the same default activation quantization config object.
        quantization_map.input_activation_qconfig_map = dict.fromkeys(
            ["conv1", "linear"] + residual_blocks, default_qconfig.activation_quant
        )

    # Append the resnet and default module conversions to module_qconfig_map. They use
    # default_qconfig as the recipe for each appended conversion. Layers that need
    # different parameters must be defined by name in instance_qconfig_map.
    append_resnet_quantization_map(quantization_map)
    append_default_conversions(quantization_map)

    # When converting from an analog network to a mixed network, an Analog layer must
    # be converted to a Quantized layer. This is only needed for two layers, so it is
    # defined here as an instance-specific conversion.
    if mode == "mixed" and ptq:
        for layer_name, quantized_cls in (("conv1", QuantConv2d), ("linear", QuantLinear)):
            quantization_map.instance_qconfig_map[layer_name] = QuantizedModuleConfig(
                quantized_module=quantized_cls,
                module_qconfig=deepcopy(quantization_map.default_qconfig),
            )

    # Convert to a quantized model using all the options configured above.
    return convert_to_quantized(model, quantization_map).to(device)

## Mixed-precision training and PTQ code

In [6]:
# Parameters to configure how to convert the model
# and which analog-only checkpoint to use for the PTQ demonstration
mode = "mixed"  # "analog" or "mixed". "mixed" means that first and last layers are in digital with `weight_quant` precision
act_quant = 6  # Precision of activations in bits
asym_act = True  # Asymmetric activations or not
weight_quant = 8  # Precision of weights (used when mode == "mixed")


# Training hyper params
perform_finetuning = False  # True to perform the mixed-precision training, otherwise it will load a known-good checkpoint
epochs_finetuning = 200  # How many epochs to finetune, if `perform_finetuning` is True
batch_size = 128
lr = 5e-3
seed = 0

# Evaluation repetitions and time for the analog layers
eval_reps = 10
eval_time = 0  # alternative value: 86400 (presumably seconds, i.e. one day - confirm)

# Set seeds
set_seed(seed)

# Auto-generation of a unique model name, based on the parameters
model_name = (
    f"{mode}_{f'quantact{act_quant}' if act_quant>0 else 'fpact'}"
    + f"{'_asym' if asym_act and act_quant>0 else ''}"
    + f"{f'_quantw{weight_quant}' if weight_quant>0 and mode == 'mixed' else ''}"
    + f"_lr{lr:.0e}_seed{seed}"
)
model_name = f"mixed_prec_example_{model_name}"
# Get the dataloader
trainloader, testloader = load_cifar10(batch_size=batch_size, path=os.path.expanduser("data/"))

# Define the model and the criterion (optimizer and scheduler are created in the
# finetuning cell). NOTE: the variable name `petrained_model` (sic) is kept because
# the following cells reference it.
petrained_model = resnet32()
petrained_model = petrained_model.to(device)
criterion = torch.nn.CrossEntropyLoss()

# Download, load and evaluate the FP-only pretrained model
os.makedirs("Models", exist_ok=True)
pretrained_ckpt_path = "Models/mixed_prec_example_pre_trained_model.th"
url = "https://aihwkit-tutorial.s3.us-east.cloud-object-storage.appdomain.cloud/mixed_prec_example_pre_trained_model.th"
response = requests.get(url, timeout=60)
# Fail loudly on an HTTP error instead of writing an error page to disk as a checkpoint
response.raise_for_status()
with open(pretrained_ckpt_path, "wb") as f:
    f.write(response.content)
petrained_model.load_state_dict(
    torch.load(pretrained_ckpt_path, map_location=device, weights_only=True)
)
print(f"Pretrained test acc. {test_step(petrained_model, criterion, testloader)}%")

Files already downloaded and verified
Files already downloaded and verified
Pretrained test acc. 94.13%


### Mixed-precision finetuning

In [7]:
set_seed(seed)

# Convert the model to a mixed-precision model for finetuning
mixed_precision_model_fortraining = convert_to_mixedprecision(
    petrained_model, mode, act_quant, asym_act, weight_quant
)

# Location of the finetuned checkpoint (written by training or downloaded below)
checkpoint_path = f"Models/{model_name}.th"

if perform_finetuning:
    # Perform a normal training loop
    optimizer_class = AnalogSGD if mode in ["analog", "mixed"] else torch.optim.SGD
    optimizer = optimizer_class(
        mixed_precision_model_fortraining.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs_finetuning)

    best_acc = -1.0
    pbar = tqdm(range(epochs_finetuning))
    for epoch in pbar:
        train_loss, train_acc = train_step(
            mixed_precision_model_fortraining, optimizer, criterion, trainloader
        )
        pbar.set_description(
            f"Epoch {epoch} Train loss: {train_loss:.4f} train acc. {train_acc:.2f}%"
        )
        test_acc = test_step(mixed_precision_model_fortraining, criterion, testloader)
        # Only checkpoints after epoch 50 are kept; with `epochs_finetuning` <= 51
        # nothing is ever saved.
        if test_acc > best_acc and epoch > 50:
            best_acc = test_acc
            torch.save(mixed_precision_model_fortraining.state_dict(), checkpoint_path)

        scheduler.step()
else:
    # Load from a known-good checkpoint for evaluation
    if model_name != "mixed_prec_example_mixed_quantact6_asym_quantw8_lr5e-03_seed0":
        # Explicit raise instead of `assert`, which is stripped when running with -O
        raise ValueError("Checkpoint is only offered for one combination")
    url = "https://aihwkit-tutorial.s3.us-east.cloud-object-storage.appdomain.cloud/mixed_prec_example_mixed_quantact6_asym_quantw8_lr5e-03_seed0.th"
    response = requests.get(url, timeout=60)
    # Fail loudly on an HTTP error instead of writing an error page to disk as a checkpoint
    response.raise_for_status()
    with open(checkpoint_path, "wb") as f:
        f.write(response.content)
    # NOTE(security): weights_only=False unpickles arbitrary objects from the downloaded
    # file; only use it with a trusted checkpoint source.
    # map_location keeps the load working on hosts without the device the checkpoint
    # was saved from (same as the other checkpoint loads in this notebook).
    mixed_precision_model_fortraining.load_state_dict(
        torch.load(checkpoint_path, weights_only=False, map_location=device)
    )

### PTQ of an Analog HWA model

In [8]:
set_seed(seed)

# First convert to just an analog model to load the state dict of the already trained analog model.
# Doing it with strict=False and with load_rpu_config=False because we need the RPU config to be as
# defined in the RPU config function above. The checkpoint may have trained input ranges or other
# parameters that are going to be ignored when the rest of the network produces quantized activations.
ana_model = convert_to_mixedprecision(
    petrained_model, "analog", act_quant, asym_act, weight_quant, only_analog_conversion=True
)
analog_ckpt_path = "Models/mixed_prec_example_analog_fpact_lr5e-03_seed0.th"
url = "https://aihwkit-tutorial.s3.us-east.cloud-object-storage.appdomain.cloud/mixed_prec_example_analog_fpact_lr5e-03_seed0.th"
response = requests.get(url, timeout=60)
# Fail loudly on an HTTP error instead of writing an error page to disk as a checkpoint
response.raise_for_status()
with open(analog_ckpt_path, "wb") as f:
    f.write(response.content)
# NOTE(security): weights_only=False unpickles arbitrary objects from the downloaded
# file; only use it with a trusted checkpoint source.
ana_model.load_state_dict(
    torch.load(
        analog_ckpt_path,
        weights_only=False,
        map_location=device,
    ),
    strict=False,
    load_rpu_config=False,
)

# Convert to mixed-precision but with all of the options, to convert
# batchnorms and wrap modules with I/O quantization where applicable
ptq_model = convert_to_mixedprecision(ana_model, mode, act_quant, asym_act, weight_quant, ptq=True)

# Now that the model is loaded with the correct weights from the analog checkpoint and contains
# all the quantizer modules, we need to calibrate them and find good quantization scales and offsets.
# For this reason, we offer the `calibrate_quantization_ranges` function that will use training
# data to calibrate all the (currently uninitialized) quantizers in the ptq network.
_ = calibrate_quantization_ranges(ptq_model, trainloader, max_num_batches=25)

### Evaluation of the finetuned and PTQed model

In [9]:
set_seed(seed)
criterion = torch.nn.CrossEntropyLoss()

# Switch both models to eval mode (PTQ model first, then the finetuned one)
for net in (ptq_model, mixed_precision_model_fortraining):
    net.eval()

# Accuracy of each of the `eval_reps` repetitions, per model
test_accs_ptq = []
test_accs_training = []

pbar = tqdm(range(eval_reps))
for _ in pbar:
    # Resample the analog noise of both models with t_inference equal to `eval_time`.
    # Both models are drifted before either one is evaluated; this call order is kept
    # on purpose.
    for net in (ptq_model, mixed_precision_model_fortraining):
        net.drift_analog_weights(t_inference=eval_time)
    # Evaluate both models and record their test accuracy
    for net, accs in (
        (ptq_model, test_accs_ptq),
        (mixed_precision_model_fortraining, test_accs_training),
    ):
        accs.append(test_step(net, criterion, testloader))

# Report mean and standard deviation over the repetitions
print(f" [PTQ Model Accuracy]:\t\t{np.mean(test_accs_ptq):.2f} +- {np.std(test_accs_ptq):.2f} %")
print(
    f" [Trained Model Accuracy]:\t{np.mean(test_accs_training):.2f} +- {np.std(test_accs_training):.2f} %"
)

100%|███████████████████████████████████████████████████████████████████████████| 10/10 [04:42<00:00, 28.21s/it]

 [PTQ Model Accuracy]:		92.61 +- 0.23 %
 [Trained Model Accuracy]:	93.07 +- 0.14 %



