# IBM Analog Hardware Acceleration Kit (AIHWKIT): Inference using noise models characterized on the IBM HERMES Project Chip

Le Gallo, M., Khaddam-Aljameh, R., Stanisavljevic, M. et al. A 64-core mixed-signal in-memory compute chip based on phase-change memory for deep neural network inference. Nat Electron 6, 680–693 (2023). https://doi.org/10.1038/s41928-023-01010-1

In [None]:
# various utility functions
import torch
import torch.nn.functional as F
import torch.nn.init as init
import torchvision
import numpy as np

# Compute device shared by every cell below: CUDA when available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _weights_init(m):
    if isinstance(m, torch.nn.Linear) or isinstance(m, torch.nn.Conv2d):
        init.kaiming_normal_(m.weight)


class LambdaLayer(torch.nn.Module):
    """Adapter that turns an arbitrary callable into a ``torch.nn.Module``.

    ``forward`` simply applies the stored callable to its input. When the
    callable is a plain function or lambda, the layer has no parameters.
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)


class BasicBlock(torch.nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers with a ReLU in between.

    Args:
        in_planes: number of input channels.
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution (and of the shortcut).
        option: shortcut type used when the input and output shapes differ:
            "A" = parameter-free subsampling + zero channel padding,
            "B" = strided 1x1 convolution + batch norm.

    Raises:
        ValueError: if a non-identity shortcut is needed and ``option`` is
            neither "A" nor "B", or if option "A" would have to reduce the
            number of channels.
    """

    # Output channels are ``planes * expansion``; always 1 for this block.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option="A"):
        super(BasicBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = torch.nn.BatchNorm2d(planes)
        self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = torch.nn.BatchNorm2d(planes)

        # Identity shortcut unless the spatial size or channel count changes.
        self.shortcut = torch.nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == "A":
                # Parameter-free shortcut (the option the ResNet paper uses for
                # CIFAR10). Subsample by ``stride`` so the spatial size matches
                # conv1's output, then zero-pad the missing channels split
                # between front and back. For planes == 2 * in_planes this is
                # exactly the classic ``planes // 4`` padding on each side.
                if planes < in_planes:
                    raise ValueError(
                        "option 'A' shortcut cannot reduce channels "
                        f"(in_planes={in_planes}, planes={planes})"
                    )
                pad_front = (planes - in_planes) // 2
                pad_back = planes - in_planes - pad_front
                self.shortcut = LambdaLayer(
                    lambda x: F.pad(
                        x[:, :, ::stride, ::stride],
                        (0, 0, 0, 0, pad_front, pad_back),
                        "constant",
                        0,
                    )
                )
            elif option == "B":
                # Projection shortcut: strided 1x1 convolution + batch norm.
                self.shortcut = torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_planes,
                        self.expansion * planes,
                        kernel_size=1,
                        stride=stride,
                        bias=False,
                    ),
                    torch.nn.BatchNorm2d(self.expansion * planes),
                )
            else:
                # Previously an unknown option silently kept the identity
                # shortcut, which then failed at forward time with a shape error.
                raise ValueError(f"unknown shortcut option {option!r}; expected 'A' or 'B'")

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Residual addition happens before the final ReLU.
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(torch.nn.Module):
    """CIFAR-style ResNet: 3x3 stem, three residual stages, global pooling, linear head.

    The stages use 16, 32 and 64 base channels (times ``block.expansion``).
    The second and third stages start with stride 2.

    Args:
        block: residual block class, called as ``block(in_planes, planes, stride)``
            and exposing an integer ``expansion`` attribute.
        num_blocks: number of blocks in each of the three stages.
        n_classes: number of output logits.
    """

    def __init__(self, block, num_blocks, n_classes=10):
        super(ResNet, self).__init__()
        # Running input-channel count; _make_layer advances it block by block.
        self.in_planes = 16

        self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        # The head must match the channels produced by the last stage, which is
        # 64 * block.expansion (identical to 64 for BasicBlock).
        self.linear = torch.nn.Linear(64 * block.expansion, n_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion

        return torch.nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pooling to 1x1. Unlike a fixed kernel of out.size(3),
        # this also works when the feature map is not square.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out


def resnet32(n_classes=10):
    """Build ResNet-32: three stages of five ``BasicBlock``s each.

    Args:
        n_classes: number of output classes of the final linear layer.
    """
    blocks_per_stage = [5] * 3
    return ResNet(BasicBlock, blocks_per_stage, n_classes=n_classes)


class TorchCutout(object):
    """Cutout augmentation: overwrite a random square patch of a tensor image.

    The patch is centred on a pixel drawn uniformly at random and clipped at
    the image borders. It therefore covers at most ``2 * (length // 2)`` rows
    and columns, and fewer near the edges. The image is modified in place and
    also returned.

    Args:
        length: nominal side length of the patch in pixels.
        fill: fill value(s) for the patch. Either one value per channel
            (default: three zeros) or a single value used for every channel.
    """

    def __init__(self, length, fill=(0.0, 0.0, 0.0)):
        self.length = length
        # Shape (C, 1, 1) so the fill broadcasts over the patch's rows and
        # columns. Using -1 instead of a hard-coded 3 accepts any channel
        # count, or a single value broadcast to all channels.
        self.fill = torch.tensor(fill).reshape(-1, 1, 1)

    def __call__(self, img):
        """Apply Cutout to ``img`` (indexed ``img[channel, row, column]``) and return it."""
        h = img.size(1)
        w = img.size(2)
        # Patch centre comes from NumPy's global RNG, so np.random.seed applies.
        y = np.random.randint(h)
        x = np.random.randint(w)
        half = self.length // 2
        y1 = np.clip(y - half, 0, h)
        y2 = np.clip(y + half, 0, h)
        x1 = np.clip(x - half, 0, w)
        x2 = np.clip(x + half, 0, w)
        img[:, y1:y2, x1:x2] = self.fill
        return img


# Load dataset
def load_cifar10(batch_size, path, num_workers=1):
    """Create CIFAR-10 train and test ``DataLoader``s.

    Training images are randomly cropped (32x32, padding 4), randomly flipped
    horizontally, converted to tensors, normalized, and then augmented with
    8-pixel Cutout. Test images are only converted to tensors and normalized
    with the same statistics. Both splits are downloaded to ``path`` when
    missing.

    Args:
        batch_size: mini-batch size used by both loaders.
        path: dataset root directory passed to ``torchvision.datasets.CIFAR10``.
        num_workers: number of worker processes per loader (default 1, the
            value that was previously hard-coded).

    Returns:
        ``(trainloader, testloader)``; only ``trainloader`` shuffles.
    """
    # Shared by both splits so train and test are normalized identically.
    # NOTE(review): keep these values in sync with the preprocessing used to
    # train any checkpoint that is evaluated through this loader.
    normalize = torchvision.transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )

    transform_train = torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomCrop(32, padding=4),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            normalize,
            # Cutout runs after Normalize, so its default zero fill is applied
            # in normalized space.
            TorchCutout(length=8),
        ]
    )
    transform_test = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(), normalize]
    )

    trainset = torchvision.datasets.CIFAR10(
        root=path, train=True, download=True, transform=transform_train
    )
    testset = torchvision.datasets.CIFAR10(
        root=path, train=False, download=True, transform=transform_test
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return trainloader, testloader

In [None]:
# - Generic imports
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from aihwkit.inference.compensation.drift import GlobalDriftCompensation
from aihwkit.inference.noise.hermes import HermesNoiseModel
from aihwkit.inference.noise.pcm import PCMLikeNoiseModel

# - AIHWKIT related imports
from aihwkit.nn.conversion import convert_to_analog
from aihwkit.simulator.configs import InferenceRPUConfig
from aihwkit.simulator.configs.utils import (
    BoundManagementType,
    NoiseManagementType,
    WeightClipType,
    WeightModifierType,
    WeightRemapType,
)
from aihwkit.simulator.presets import StandardHWATrainingPreset
from aihwkit.simulator.presets.utils import IOParameters

## RPUConfig
To use the Hermes noise model, set the `rpu_config.noise_model` field of the `RPUConfig` to an instance of the noise model class. The Hermes noise model is instantiated via the class `HermesNoiseModel`; see the following cells for the available options.

In [None]:
def gen_rpu_config(noise_model):
    """Build the ``InferenceRPUConfig`` shared by every analog model in this notebook.

    Only the noise model varies between calls. All other settings are fixed
    here to match the configuration used during training.

    Args:
        noise_model: noise-model instance (see the following cells for the
            available classes). It is stored in ``rpu_config.noise_model``.

    Returns:
        The newly created and configured ``InferenceRPUConfig``.
    """
    config = InferenceRPUConfig()

    # The only argument-dependent setting: which noise model to simulate.
    config.noise_model = noise_model

    # Everything below mirrors the training-time configuration.
    modifier = config.modifier
    modifier.std_dev = 0.06
    modifier.type = WeightModifierType.ADD_NORMAL

    mapping = config.mapping
    mapping.digital_bias = True
    mapping.weight_scaling_omega = 1.0
    mapping.weight_scaling_columnwise = False
    mapping.out_scaling_columnwise = False
    config.remap.type = WeightRemapType.LAYERWISE_SYMMETRIC

    clip = config.clip
    clip.type = WeightClipType.LAYER_GAUSSIAN
    clip.sigma = 2.0

    # Replace the forward-pass I/O settings with a fresh IOParameters object
    # and set each option explicitly.
    config.forward = IOParameters()
    fwd = config.forward
    fwd.is_perfect = False
    fwd.out_noise = 0.0
    fwd.inp_bound = 1.0
    fwd.inp_res = 1 / (2**8 - 2)
    fwd.out_bound = 12
    fwd.out_res = 1 / (2**8 - 2)
    fwd.bound_management = BoundManagementType.NONE
    fwd.noise_management = NoiseManagementType.NONE

    # Pre/post-processing input-range settings.
    input_range = config.pre_post.input_range
    input_range.enable = True
    input_range.decay = 0.01
    input_range.init_from_data = 50
    input_range.init_std_alpha = 3.0
    input_range.input_min_percentage = 0.995
    input_range.manage_output_clipping = False

    config.drift_compensation = GlobalDriftCompensation()
    return config


In [None]:
# Function to perform inference on the test set and calculate the test accuracy
def test_step(model, criterion, testloader, return_loss=False):
    """Compute the top-1 accuracy (in percent) of ``model`` on ``testloader``.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Every batch is first moved to the module-level ``device``.

    Args:
        model: classifier returning per-class scores along dimension 1.
        criterion: loss function called as ``criterion(outputs, targets)``;
            only evaluated when ``return_loss`` is True.
        testloader: iterable of ``(inputs, targets)`` batches.
        return_loss: if True, also return the mean of the per-batch losses
            (previously accumulated but discarded).

    Returns:
        The accuracy as a float in [0, 100], or ``(accuracy, mean_loss)`` when
        ``return_loss`` is True.

    Raises:
        ValueError: if ``testloader`` yields no samples (accuracy undefined).
    """
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    n_batches = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # Skip the loss computation entirely unless the caller wants it.
            if return_loss:
                loss_sum += criterion(outputs, targets).item()
            n_batches += 1
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    if total == 0:
        raise ValueError("testloader yielded no samples; accuracy is undefined")
    accuracy = 100.0 * correct / total
    if return_loss:
        return accuracy, loss_sum / n_batches
    return accuracy

In [None]:
# - Set seeds
# Seed both RNGs before anything random happens below: building the model
# draws its initial weights from torch's RNG, and the Cutout transform of the
# train split draws from NumPy's global RNG.
torch.manual_seed(42)
np.random.seed(42)

# - Get the dataloader
# Only the test split is used here. load_cifar10 still builds (and downloads
# if needed) the train split, but its loader is discarded.
batch_size = 128
_, testloader = load_cifar10(
    batch_size=batch_size, path=os.path.expanduser("~/Data/")
)

# - Define model and the criterion
# Digital ResNet-32 used as the template for convert_to_analog in a later
# cell; the converted models then load the checkpoint weights.
model = resnet32()
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()

## Hermes Noise Model
Hermes' unit cell can map a weight onto either 1 device or 2 devices per polarity. Both modes have been characterized and are selected with the `num_devices` parameter at instantiation. The model can be tweaked further through its other parameters; see the class signature for details.

In [None]:
# - Noise model instantiation for comparison (previous PCMLikeNoiseModel vs. the new HermesNoiseModel with 1 and 2 devices)
# The dictionary keys double as the legend labels of the drift plot below.
noise_models_to_compare = {
    "Standard": PCMLikeNoiseModel(g_max=25.0),
    "Hermes 1D": HermesNoiseModel(num_devices=1),
    "Hermes 2D": HermesNoiseModel(num_devices=2),
}
# One RPU config per noise model; gen_rpu_config keeps every other setting identical.
rpu_configs = {
    model_name: gen_rpu_config(noise_model)
    for model_name, noise_model in noise_models_to_compare.items()
}
# - Instantiate models, each with an RPU config with the corresponding noise model
# The same digital `model` is converted once per config.
# NOTE(review): this relies on convert_to_analog returning a new model rather
# than converting `model` in place — confirm for the installed aihwkit version.
analog_models = {
    model_name: convert_to_analog(model, config) for model_name, config in rpu_configs.items()
}

# Download the HW-Aware trained checkpoint and load it in the models
# `!wget` is an IPython shell escape, so this cell only runs inside Jupyter/IPython.
# NOTE(review): without `-nc`, re-running the cell presumably makes wget save a
# numbered copy (e.g. `.th.1`) while the code below keeps loading the first file.
!wget -P Models/ https://aihwkit-tutorial.s3.us-east.cloud-object-storage.appdomain.cloud/finetuned_model_0.9.1.th
# This loop rebinds `model`: afterwards it names the last analog model, not the
# digital template.
for model in analog_models.values():
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # a trusted source. The checkpoint is re-read from disk for every model.
    # load_rpu_config=False presumably keeps each model's own rpu_config (and
    # hence its noise model) instead of one stored in the checkpoint — confirm.
    model.load_state_dict(
        torch.load("Models/finetuned_model_0.9.1.th", map_location=device), load_rpu_config=False,
    )
# Sanity check on the "Standard" model only, before drift_analog_weights has
# been called on any model.
print(f"Finetuned test acc. w/o noise: {test_step(analog_models['Standard'], criterion, testloader):.2f} %")

In [None]:
# - For programming the model, we need to put it into eval() mode
for model in analog_models.values():
    model.eval()
# - We repeat each measurement 5 times
n_rep = 5
# Inference times in seconds: 0 s, 1 min, 1 h, 1 day, 30 days, 360 days.
t_inferences = [0.0, 60.0, 3600.0, 86400.0, 2592000.0, 31104000.0]
_, ax = plt.subplots()
ax: plt.Axes  # annotation only, for type checkers / IDEs
for noise_name, model in analog_models.items():
    # Rows: inference times; columns: repetitions.
    drifted_test_accs = torch.zeros(size=(len(t_inferences), n_rep))
    for i, t in enumerate(t_inferences):
        for j in range(n_rep):
            # Each repetition calls drift_analog_weights(t) again before evaluating.
            # NOTE(review): confirm which noise sources (programming vs. drift/read)
            # are re-sampled per call in the installed aihwkit version. If the
            # programming noise should also vary between repetitions, an explicit
            # model.program_analog_weights() may be needed here.
            model.drift_analog_weights(t)
            drifted_test_accs[i, j] = test_step(model, criterion, testloader)

    # Mean ± standard deviation over the n_rep repetitions at each time.
    ax.errorbar(
        t_inferences,
        drifted_test_accs.mean(1),
        drifted_test_accs.std(1),
        capsize=3,
        label=noise_name,
    )

ax.set_xlabel("Time (s)")
# NOTE(review): t_inferences contains 0.0, which has no position on a log axis;
# matplotlib clips or masks that point instead of showing it at t=0.
ax.set_xscale("log")
ax.set_ylabel("Test acc. (%)")
ax.legend();  # trailing ';' suppresses the notebook's echo of the return value