This cell installs the timm (PyTorch Image Models) library and imports all necessary dependencies.

In [None]:
from __future__ import print_function
!pip install timm
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import argparse
import random
from tqdm import tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from torch import autograd
import copy
import pandas as pd
from torch.profiler import profile, record_function, ProfilerActivity
from PIL import Image
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import defaultdict
from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader
import os

**Code to process images from Imagenet into a DataLoader**

In [None]:
# Colab only: mount Google Drive at /content/drive so that the archive and
# classes module read in the next cell (./drive/MyDrive/...) are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!mkdir -p ./val_images/
!tar -xzvf ./drive/MyDrive/val_images.tar.gz -C ./val_images/
from drive.MyDrive.classes import IMAGENET2012_CLASSES
val_images_path = "./drive/MyDrive/val_images.tar.gz"
class_name_to_idx = {class_name: idx for idx, class_name in enumerate(IMAGENET2012_CLASSES.values())}
synset_to_class_idx = {synset: class_name_to_idx[IMAGENET2012_CLASSES[synset]] for synset in IMAGENET2012_CLASSES}

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
ILSVRC2012_val_00026765_n03710193.JPEG
ILSVRC2012_val_00036399_n03888605.JPEG
ILSVRC2012_val_00001975_n02641379.JPEG
ILSVRC2012_val_00022598_n02841315.JPEG
ILSVRC2012_val_00024370_n02837789.JPEG
ILSVRC2012_val_00009522_n02669723.JPEG
ILSVRC2012_val_00041416_n02105412.JPEG
ILSVRC2012_val_00018273_n01817953.JPEG
ILSVRC2012_val_00014016_n09332890.JPEG
ILSVRC2012_val_00010062_n02091635.JPEG
ILSVRC2012_val_00040336_n03958227.JPEG
ILSVRC2012_val_00043321_n03590841.JPEG
ILSVRC2012_val_00015542_n03394916.JPEG
ILSVRC2012_val_00017434_n02108089.JPEG
ILSVRC2012_val_00010395_n01698640.JPEG
ILSVRC2012_val_00047131_n03950228.JPEG
ILSVRC2012_val_00019297_n04296562.JPEG
ILSVRC2012_val_00019317_n07717410.JPEG
ILSVRC2012_val_00010450_n02097658.JPEG
ILSVRC2012_val_00030066_n02692877.JPEG
ILSVRC2012_val_00012978_n02119022.JPEG
ILSVRC2012_val_00038970_n03954731.JPEG
ILSVRC2012_val_00049557_n03759954.JPEG
ILSVRC2012_val_00024718_n04127249.JPEG

In [None]:
# Evaluation-time preprocessing applied to every validation image:
# resize straight to 224x224 (no aspect-preserving resize / centre crop),
# convert to a float tensor, then normalise each channel.
# NOTE(review): timm checkpoints ship their own preprocessing config
# (mean/std, interpolation, crop); confirm that these ImageNet statistics and
# the plain resize match what the `vit_base_patch16_224` weights expect.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet means
        std=[0.229, 0.224, 0.225]    # ImageNet stds
    ),
])

In [None]:
# Step 2: Define the Dataset class to load images directly from the extracted folder
class ImageNetFolderDataset(Dataset):
    """Flat-folder image dataset whose labels are encoded in the file names.

    Every ``*.JPEG`` file directly inside ``folder_path`` is one sample.  The
    label is derived from the part of the file stem after the last underscore
    (e.g. ``ILSVRC2012_val_00026765_n03710193.JPEG`` -> ``n03710193``), which
    is looked up in the module-level ``synset_to_class_idx`` mapping.

    Args:
        folder_path: Directory containing the extracted images.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices (and therefore batches) are the same on every run/machine.
        self.image_files = sorted(
            os.path.join(folder_path, f)
            for f in os.listdir(folder_path)
            if f.endswith('.JPEG')
        )

    def __len__(self):
        """Number of ``*.JPEG`` files found in the folder."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Raises:
            KeyError: If the synset id parsed from the file name is not a key
                of ``synset_to_class_idx``.
        """
        image_path = self.image_files[idx]
        # convert() produces an independent, fully loaded image, so the file
        # handle can be released immediately instead of waiting for GC.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)

        # Extract the synset ID from the filename and map to the numeric class index
        root, _ = os.path.splitext(os.path.basename(image_path))
        _, synset_id = root.rsplit("_", 1)
        label = synset_to_class_idx[synset_id]  # Numeric label for the class

        return image, label

In [None]:
# Validation set read from the extracted archive; every image goes through `preprocess`.
# The absolute path matches the relative ./val_images/ used for extraction only when
# the working directory is /content.
dataset = ImageNetFolderDataset(folder_path='/content/val_images/', transform=preprocess)

In [None]:
# Evaluation loader: fixed order (no shuffling), batches of 128, two worker processes.
dataloader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=2)

**Quantization Modules**

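Defines the `Quantizer_weight` class (one scale per output channel, i.e. per slice along the first weight dimension) and the `Quantizer_activation` class (one scale per sample: the min/max are taken over all non-batch dimensions, so every token of an image shares the same scale).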

In [None]:
class Round(autograd.Function):
    """Round-half-up with a pass-through gradient.

    The forward pass computes ``floor(x + 0.5)``.  The backward pass hands the
    incoming gradient back unchanged, i.e. the rounding step is treated as the
    identity when differentiating.
    """

    @staticmethod
    def forward(ctx, inputs):
        shifted = inputs + 0.5
        return shifted.floor()

    @staticmethod
    def backward(ctx, grads):
        # Identity: the gradient of the output is used as the gradient of the input.
        return grads

class Quantizer_weight(nn.Module):
    """Fake-quantizer for weights with one scale per slice of the first dimension.

    For a 4-D input the min/max are reduced over dimensions 1, 2 and 3; for any
    other rank they are reduced over dimension 1 only.  In both cases the scale
    keeps the reduced dimensions with size 1, so it broadcasts against the input.

    The scale is ``(max - min + 1e-15) / (2**bits - 2)``.  Values are divided by
    the scale, rounded with ``Round`` and clamped to
    ``[-(2**(bits-1) - 1), 2**(bits-1) - 1]`` before being multiplied back.

    NOTE(review): no zero-point is used, so when a slice's value range is not
    centred on zero, part of it falls outside the clamp range and is clipped.
    Confirm this is intended.
    """

    def __init__(self, bits_precision):
        super().__init__()
        self.bits_precision = bits_precision

    def quantize(self, inputs, scale):
        """Map real values onto the integer grid (not yet rounded)."""
        return inputs * (1 / scale)

    def round(self, inputs):
        """Round to the nearest integer; gradients pass straight through."""
        return Round.apply(inputs)

    def clamp(self, inputs):
        """Limit to the symmetric signed range for ``bits_precision`` bits."""
        limit = 2 ** (self.bits_precision - 1) - 1
        return torch.clamp(inputs, -limit, limit)

    def dequantize(self, inputs, scale):
        """Map integer grid values back to real values."""
        return scale * inputs

    def forward(self, inputs):
        """Return ``(fake_quantized_inputs, scale)``."""
        reduce_dims = (1, 2, 3) if len(inputs.shape) == 4 else (1,)

        # Reduce one dimension at a time, keeping it as size 1.
        lowest, highest = inputs, inputs
        for axis in reduce_dims:
            lowest = torch.min(lowest, dim=axis, keepdim=True)[0]
            highest = torch.max(highest, dim=axis, keepdim=True)[0]

        scale = (highest - lowest + 1e-15) / (2 ** (self.bits_precision) - 2)
        grid = self.clamp(self.round(self.quantize(inputs, scale)))
        return self.dequantize(grid, scale), scale

class Quantizer_activation(nn.Module):
    """Fake-quantizer for activations with one scale per sample.

    The min/max are reduced over every dimension except the first: dimensions
    1, 2, 3 for a 4-D input, 1 and 2 for a 3-D input, and dimension 1 otherwise.
    The scale keeps the reduced dimensions with size 1.

    The scale is ``(max - min + 1e-15) / (2**bits - 1)``.  Values are divided by
    the scale, rounded with ``Round``, clamped to ``[-2**(bits-1), 2**bits - 1]``
    and multiplied back.

    NOTE(review): the clamp interval spans more than ``2**bits`` integer levels
    (it admits both the signed lower bound and the unsigned upper bound).
    Confirm this is intended.
    """

    def __init__(self, bits_precision):
        super().__init__()
        self.bits_precision = bits_precision

    def quantize(self, inputs, scale):
        """Map real values onto the integer grid (not yet rounded)."""
        return inputs * (1 / scale)

    def round(self, inputs):
        """Round to the nearest integer; gradients pass straight through."""
        return Round.apply(inputs)

    def clamp(self, inputs):
        """Limit to ``[-2**(bits-1), 2**bits - 1]``."""
        lower = -2 ** (self.bits_precision - 1)
        upper = 2 ** (self.bits_precision) - 1
        return torch.clamp(inputs, lower, upper)

    def dequantize(self, inputs, scale):
        """Map integer grid values back to real values."""
        return scale * inputs

    def forward(self, inputs):
        """Return ``(fake_quantized_inputs, scale)``."""
        rank = len(inputs.shape)
        if rank == 4:
            reduce_dims = (1, 2, 3)
        elif rank == 3:
            reduce_dims = (1, 2)
        else:
            reduce_dims = (1,)

        # Reduce one dimension at a time, keeping it as size 1.
        lowest, highest = inputs, inputs
        for axis in reduce_dims:
            lowest = torch.min(lowest, dim=axis, keepdim=True)[0]
            highest = torch.max(highest, dim=axis, keepdim=True)[0]

        scale = (highest - lowest + 1e-15) / (2 ** (self.bits_precision) - 1)
        grid = self.clamp(self.round(self.quantize(inputs, scale)))
        return self.dequantize(grid, scale), scale

**Fault Injection Infrastructure: LUTs and Fallbacks**

This block implements the core logic for efficient fault simulation and mitigation:

`build_small_fault_lut`: Precomputes a Look-Up Table (LUT) on the GPU. For every possible n-bit weight and every "small" fault pattern (defined by max_forced_bits), it calculates the closest valid integer value.

`make_forced_mask_val`: A helper that uses matrix multiplication to convert bitwise fault maps (shape bitwidth x N) into integer masks for fast LUT indexing.

`fallback_many_weights`: A robust fallback mechanism in case the LUT does not contain a case encountered.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import combinations, product

##############################################################################
# 1) Build a global LUT on GPU
##############################################################################
def enumerate_small_fault_patterns(bitwidth=8, max_forced_bits=2):
    """
    List every stuck-bit pattern with at most ``max_forced_bits`` forced bits
    in a ``bitwidth``-bit word.

    Returns a list of ``(pattern_code, forced_mask, forced_val)`` tuples:
      * ``pattern_code`` is the position of the tuple in the list,
      * ``forced_mask`` has a 1 at every forced bit position,
      * ``forced_val`` has a 1 at every bit position forced to 1.

    Patterns are ordered by number of forced bits, then by bit positions
    (``itertools.combinations`` order), then by forced values
    (``itertools.product`` order).  The first entry is always the fault-free
    pattern ``(0, 0, 0)``.
    """
    patterns = []
    for num_forced in range(max_forced_bits + 1):
        # For num_forced == 0 this yields a single empty choice, which
        # produces the fault-free pattern (mask=0, val=0).
        for positions in combinations(range(bitwidth), num_forced):
            mask = sum(1 << pos for pos in positions)
            for stuck_values in product([0, 1], repeat=num_forced):
                val = sum(bit << pos for pos, bit in zip(positions, stuck_values))
                patterns.append((len(patterns), mask, val))
    return patterns

def build_small_fault_lut(bitwidth=8, max_forced_bits=2, device='cuda'):
    """
    Precompute the closest readable value for every weight under every
    "small" stuck-bit pattern.

    Args:
        bitwidth: Width of the two's-complement weight word.
        max_forced_bits: Largest number of stuck bits a pattern may have.
        device: Device on which the returned tensors are created.

    Returns:
        global_lut: int16 tensor of shape ``(2**bitwidth, num_patterns)``.
            Row ``w + 2**(bitwidth-1)`` / column ``p`` holds the signed value
            closest to ``w`` among all words compatible with pattern ``p``
            (ties go to the word with the smaller unsigned code).
        pattern_id_map: int16 tensor of shape ``(2**(2*bitwidth),)`` mapping
            ``(forced_mask << bitwidth) | forced_val`` to the pattern column,
            or -1 for combinations that are not in the table.  For the default
            ``bitwidth=8`` this is the ``(mask << 8) | val`` key.
        all_patterns: The list returned by ``enumerate_small_fault_patterns``.

    Raises:
        ValueError: If ``bitwidth`` is smaller than 1.

    Note:
        Table sizes grow as ``4**bitwidth``; this is meant for small words.
    """
    if bitwidth < 1:
        raise ValueError(f"bitwidth must be >= 1, got {bitwidth}")

    all_patterns = enumerate_small_fault_patterns(bitwidth, max_forced_bits)
    num_patterns = len(all_patterns)
    num_codes = 1 << bitwidth
    half = num_codes >> 1

    # Every bit pattern as an unsigned code and as its two's-complement value.
    codes_u = torch.arange(num_codes, dtype=torch.int64, device=device)
    codes_s = torch.where(codes_u >= half, codes_u - num_codes, codes_u)

    # dist[t, c] = |value(c) - target(t)| with targets ordered from the most
    # negative value upwards, so row t corresponds to weight (t - half).
    targets = torch.arange(-half, half, dtype=torch.int64, device=device)
    dist = (codes_s.view(1, -1) - targets.view(-1, 1)).abs()
    unreachable = num_codes  # strictly larger than any real distance

    global_lut = torch.empty((num_codes, num_patterns), dtype=torch.int16, device=device)
    pattern_id_map = torch.full((1 << (2 * bitwidth),), -1, dtype=torch.int16, device=device)

    for pat_idx, (_pcode, fm, fv) in enumerate(all_patterns):
        # A code is compatible when its forced bits carry the forced values.
        compatible = (codes_u & fm) == fv
        masked_dist = dist.masked_fill(~compatible.view(1, -1), unreachable)
        # argmin returns the first minimum, i.e. the smallest unsigned code on ties.
        best_code = masked_dist.argmin(dim=1)
        global_lut[:, pat_idx] = codes_s[best_code].to(torch.int16)
        pattern_id_map[(fm << bitwidth) | fv] = pat_idx

    return global_lut, pattern_id_map, all_patterns


##############################################################################
# 2) Helper to build forced_mask/val using float matmul on GPU
##############################################################################
def make_forced_mask_val(faults_2d: torch.Tensor):
    """
    Pack a per-bit fault map into integer masks.

    faults_2d: shape (bitwidth, N); row b describes bit b of each of the N
    words.  A non-zero entry marks the bit as forced, and an entry equal to
    +1 marks it as forced to 1.

    Returns (forced_mask, forced_val, forced_count), each of shape (N,) and
    dtype int16:
      * forced_mask  - sum of 2**b over all forced bits,
      * forced_val   - sum of 2**b over all bits forced to 1,
      * forced_count - number of forced bits.

    The packing is done with a float32 matrix product and cast to int16 last.
    """
    num_bits, _num_words = faults_2d.shape

    # Bit weights 1, 2, 4, ... as float32 so they can be used in a matmul.
    bit_weights = (
        1 << torch.arange(num_bits, device=faults_2d.device, dtype=torch.int64)
    ).to(torch.float32)

    faults_f32 = faults_2d.to(torch.float32)
    is_forced = (faults_f32 != 0.0).to(torch.float32)       # (bitwidth, N)
    is_forced_high = (faults_f32 == 1.0).to(torch.float32)  # (bitwidth, N)

    def pack_bits(bit_flags):
        # (N, bitwidth) @ (bitwidth,) -> (N,)
        return (bit_flags.T @ bit_weights).to(torch.int16)

    forced_count = is_forced.sum(dim=0).to(torch.int16)
    return pack_bits(is_forced), pack_bits(is_forced_high), forced_count


##############################################################################
# 3) Fallback routines, all GPU-compatible, returning INT16
##############################################################################
def fallback_many_weights(
    targets_s8: torch.Tensor,  # shape(M,), signed targets, e.g. int16 in [-128..127]
    faults_2d:  torch.Tensor,  # shape(bitwidth,M) in {0,+1,-1}, float/int
    bitwidth=8
)->torch.Tensor:
    """
    Exhaustive closest-value search for words with arbitrary stuck bits.

    For each of the M items, every ``bitwidth``-bit two's-complement word is
    considered; words whose bits disagree with the stuck bits are discarded
    and the remaining word closest to the target is returned.

    Args:
        targets_s8: Shape (M,), the desired signed values.
        faults_2d: Shape (bitwidth, M).  Row b describes bit b: +1 means stuck
            at 1, -1 means stuck at 0, anything else means the bit is free.
        bitwidth: Word width.  Values above 8 are supported.

    Returns:
        int16 tensor of shape (M,) with the best readable value per item.
        On ties the word with the smaller unsigned code is chosen.
    """
    device = targets_s8.device
    num_codes = 1 << bitwidth

    # Every word as an unsigned code and as its two's-complement value.  int64
    # is used so that widths above 8 bits do not overflow the code table.
    codes_u = torch.arange(num_codes, dtype=torch.int64, device=device)
    codes_s = torch.where(codes_u >= (num_codes >> 1), codes_u - num_codes, codes_u)

    # Pack the stuck bits of every item into two integers: which bits are
    # stuck, and which of those are stuck at 1.
    fault_codes = faults_2d.to(torch.int8)                      # (bitwidth, M)
    stuck_high = (fault_codes == 1)
    stuck_low = (fault_codes == -1)
    bit_weights = (1 << torch.arange(bitwidth, dtype=torch.int64, device=device)).view(-1, 1)
    stuck_mask = ((stuck_high | stuck_low).to(torch.int64) * bit_weights).sum(dim=0)  # (M,)
    stuck_val = (stuck_high.to(torch.int64) * bit_weights).sum(dim=0)                 # (M,)

    # A word is readable when its stuck bits carry the stuck values.  Comparing
    # packed integers needs only an (M, 2**bitwidth) intermediate.
    readable = (codes_u.view(1, -1) & stuck_mask.view(-1, 1)) == stuck_val.view(-1, 1)

    dist = (codes_s.view(1, -1) - targets_s8.to(torch.int64).view(-1, 1)).abs()
    # Unreadable words get a distance no real candidate can reach.
    dist = dist.masked_fill(~readable, torch.iinfo(torch.int64).max)

    best_idx = dist.argmin(dim=1)  # (M,), first minimum on ties
    return codes_s[best_idx].to(torch.int16)

In [None]:
# Build the shared 8-bit lookup tables once, on the builder's default device ('cuda').
# `global_lut` and `pattern_id_map` are read as module-level globals by
# custom_QuantizedLinear_new_baseline.__init__; the third return value
# (the pattern list) is not needed here.
global_lut, pattern_id_map,_ = build_small_fault_lut(bitwidth=8, max_forced_bits=2)

**Custom Quantized Linear Layer with Fault Injection**

This class extends nn.Linear to support weight and activation quantization and stuck-at fault (SAF) simulation.

`inject_faults`: The central method for simulating SAFs. It injects "stuck-at" faults into the weights and applies advanced mitigation strategies:

`none`: Baseline in this work - Closest Value Mapping (CVM).

`flip`: Sign-Flip technique.

`flip-bitwise`: Bit-Flip technique.

`forward`: Executes the quantized matrix multiplication, handling input/weight scaling factors to restore the full-precision output range.

In [None]:
class custom_QuantizedLinear_new_baseline(nn.Linear):
    """``nn.Linear`` with 8-bit fake quantization and stuck-at-fault simulation.

    Typical use: construct, copy pretrained weights in, call ``to_inference()``
    to freeze the quantized weights and their scale, then optionally call
    ``inject_faults()`` to overwrite the weights with what a faulty 8-bit
    memory would read back.

    The constructor reads the module-level ``global_lut`` and
    ``pattern_id_map`` tables and places them on ``'cuda'``.  They are plain
    attributes (not buffers), so ``Module.to()`` does not move them.
    """

    def __init__(self, *args,**kwargs):
        super(custom_QuantizedLinear_new_baseline, self).__init__(*args,**kwargs)
        self.weight_precision = 8
        self.input_precision  = 8
        self.array_dimension  = 64
        self.inference_done   = False
        self.weight_scale     = 0      # replaced by a tensor in to_inference()
        self.max_forced_bits = 2       # largest stuck-bit count served by the LUT
        self.global_lut       = global_lut.to('cuda')            # shape(256, numPatterns)
        self.pattern_id_map   = pattern_id_map.to('cuda')        # shape(65536,)
        # placeholders for your actual quantizers
        self.weight_quantizer = Quantizer_weight(bits_precision=self.weight_precision)
        self.input_quantizer  = Quantizer_activation(bits_precision=self.input_precision)

    def to_inference(self):
        """Quantize the weights once, store them and their scale, and mark the
        layer so that ``forward`` stops re-quantizing on every call."""
        # No autograd graph is needed here; without no_grad the stored scale
        # would keep a reference to the whole quantization graph.
        with torch.no_grad():
            weight, weight_scale = self.weight_quantizer(self.weight)
        self.weight_scale = weight_scale
        self.weight.data = weight
        self.inference_done = True

    def _match_patterns(self, faults_2d):
        """Look up the LUT column for every fault column.

        Returns ``(pat_idx, bad_idx)``: ``pat_idx`` is a long tensor of LUT
        columns (0 where the LUT has no entry) and ``bad_idx`` holds the
        positions the LUT cannot serve and that need the exhaustive fallback.
        """
        fm, fv, fc = make_forced_mask_val(faults_2d)
        # Widen before shifting: (mask << 8) does not fit in int16 once bit 7
        # is forced.  The key layout (mask << 8 | val) is that of the 8-bit
        # pattern_id_map.
        combo_id = (fm.to(torch.int64) << 8) | fv.to(torch.int64)
        pat_idx = self.pattern_id_map[combo_id].to(torch.long)
        recognized = (pat_idx >= 0) & (fc <= self.max_forced_bits)
        bad_idx = torch.where(~recognized)[0]
        return pat_idx.clamp_min(0), bad_idx

    def _closest_reads(self, flat_w, faults_2d, pat_idx, bad_idx):
        """Closest readable int16 value for each target in ``flat_w``.

        Uses the LUT everywhere, then overwrites the ``bad_idx`` positions
        with the result of ``fallback_many_weights``.
        """
        w_off = (flat_w + 128).clamp(0, 255).to(torch.long)
        reads = self.global_lut[w_off, pat_idx]
        if bad_idx.numel() > 0:
            reads[bad_idx] = fallback_many_weights(
                flat_w[bad_idx], faults_2d[:, bad_idx], self.weight_precision
            )
        return reads

    def _sign_flip_block(self, block_2d, block_faults):
        """Sign-flip mitigation for one row block.

        For every column, either the weights or their negation are stored,
        whichever gives the smaller summed absolute read error over the block;
        a negated column is negated again on read-out.
        """
        block_height, out_feats = block_2d.shape
        block_flat = block_2d.reshape(-1)
        pat_idx, bad_idx = self._match_patterns(block_faults)

        reads_pos = self._closest_reads(block_flat, block_faults, pat_idx, bad_idx)
        reads_neg = self._closest_reads(-block_flat, block_faults, pat_idx, bad_idx)
        reads_pos = reads_pos.view(block_height, out_feats)
        reads_neg = reads_neg.view(block_height, out_feats)

        # Column-wise error, accumulated in a wider integer type.
        err_pos = (reads_pos - block_2d).abs().to(torch.int32).sum(dim=0)
        err_neg = (reads_neg + block_2d).abs().to(torch.int32).sum(dim=0)

        # reads_neg is the best read for -w, so the effective weight is its negation.
        keep_sign = (err_pos <= err_neg).view(1, -1)
        return torch.where(keep_sign, reads_pos, -reads_neg)

    def _bit_flip_block(self, block_2d, block_faults, flip_signs):
        """Bit-flip mitigation for one row block.

        ``flip_signs`` has shape (num_flips, bitwidth, 1) with entries in
        {+1, -1}; multiplying the fault map by one of its rows inverts the
        stuck polarity of the selected bits.  All flip patterns are evaluated
        at once and, per column, the pattern with the smallest summed absolute
        read error is kept (the lowest pattern index wins ties).
        """
        num_flips = flip_signs.shape[0]
        block_height, out_feats = block_2d.shape

        # (num_flips, bitwidth, N) -> (bitwidth, num_flips*N); column index is
        # flip * N + item.
        flipped = block_faults.unsqueeze(0).to(torch.int16) * flip_signs
        flipped_2d = flipped.permute(1, 0, 2).reshape(self.weight_precision, -1)

        # Targets laid out with the same flip * N + item indexing.
        tiled_w = block_2d.reshape(-1).repeat(num_flips)

        pat_idx, bad_idx = self._match_patterns(flipped_2d)
        reads = self._closest_reads(tiled_w, flipped_2d, pat_idx, bad_idx)
        reads = reads.view(num_flips, block_height, out_feats)

        col_err = (reads - block_2d.unsqueeze(0)).abs().to(torch.int32).sum(dim=1)
        best_flip = col_err.argmin(dim=0)  # (out_feats,)

        # result[h, c] = reads[best_flip[c], h, c]
        pick = best_flip.view(1, 1, -1).expand(1, block_height, out_feats)
        return reads.gather(0, pick).squeeze(0)

    @torch.no_grad()
    def inject_faults(self,
                      fault_rate: float,
                      fault_mitigation: str='none',
                      row_block_size: int=64):
        """
        Overwrite the weights with the values a memory with random stuck-at
        bits would read back.

        Each of the 8 bits of every weight is independently stuck with
        probability ``fault_rate / 100`` (``fault_rate`` is a percentage); a
        stuck bit is stuck at 0 or 1 with equal probability.

        All modes use the LUT for items with at most ``max_forced_bits`` stuck
        bits and the exhaustive fallback for the rest.

        Args:
            fault_rate: Per-bit fault probability in percent.
            fault_mitigation: ``'none'`` (closest readable value per weight),
                ``'flip'`` (per-column sign flip within each row block) or
                ``'flip-bitwise'`` (per-column bit-flip pattern within each
                row block).
            row_block_size: Number of input rows per block for the two flip
                modes.

        Raises:
            RuntimeError: If ``to_inference()`` has not been called, so no
                weight scale is available.
            ValueError: If ``fault_mitigation`` is unknown or
                ``row_block_size`` is smaller than 1.
        """
        if not torch.is_tensor(self.weight_scale):
            raise RuntimeError(
                "inject_faults requires to_inference() to be called first "
                "(no weight scale is stored yet)."
            )
        if fault_mitigation not in ('none', 'flip', 'flip-bitwise'):
            raise ValueError(f"Unknown fault_mitigation={fault_mitigation}")
        if row_block_size < 1:
            raise ValueError(f"row_block_size must be >= 1, got {row_block_size}")

        device = self.weight.device
        bitwidth = self.weight_precision

        # 1) Recover the integer weights, laid out as (in_feats, out_feats)
        w_float   = self.weight / self.weight_scale
        w_rounded = Round.apply(w_float).clamp(-128,127)
        weight_2d = w_rounded.transpose(0,1).contiguous().to(torch.int16)
        in_feats, out_feats = weight_2d.shape

        # 2) Random stuck bits => shape(bitwidth, N), entries in {0, -1, +1}
        N = in_feats*out_feats
        mat = (fault_rate/100.0)*torch.ones(bitwidth, N, device=device)
        f_matrix = torch.bernoulli(mat)                              # in {0,1}
        sign_mat = (2*torch.bernoulli(0.5*torch.ones_like(mat))-1)   # in {-1,+1}
        faults_2d = f_matrix*sign_mat

        if fault_mitigation == 'none':
            pat_idx, bad_idx = self._match_patterns(faults_2d)
            flat_reads = self._closest_reads(weight_2d.view(-1), faults_2d, pat_idx, bad_idx)
            new_weight_2d = flat_reads.view(in_feats, out_feats)
        else:
            flip_signs = None
            if fault_mitigation == 'flip-bitwise':
                # One row per flip pattern; bit b of the pattern index selects
                # whether bit b is flipped (-1) or left alone (+1).
                flip_ids = torch.arange(1 << bitwidth, dtype=torch.int16, device=device)
                bit_pos  = torch.arange(bitwidth, dtype=torch.int16, device=device).view(1, -1)
                flip_bits = ((flip_ids.unsqueeze(-1) >> bit_pos) & 1).to(torch.int16)
                flip_signs = (1 - 2*flip_bits).unsqueeze(-1)  # (256, bitwidth, 1)

            new_weight_2d = torch.empty_like(weight_2d)
            for row_start in range(0, in_feats, row_block_size):
                row_end = min(row_start+row_block_size, in_feats)
                block_2d = weight_2d[row_start:row_end, :]
                block_faults = faults_2d[:, row_start*out_feats:row_end*out_feats]
                if fault_mitigation == 'flip':
                    block_out = self._sign_flip_block(block_2d, block_faults)
                else:
                    block_out = self._bit_flip_block(block_2d, block_faults, flip_signs)
                new_weight_2d[row_start:row_end, :] = block_out

        # store final
        final_w = new_weight_2d.transpose(0,1).to(self.weight.dtype)
        self.weight.data = final_w * self.weight_scale
        print(f"Done inject_faults with mode={fault_mitigation}. row_block_size={row_block_size}.")


    def forward(self,input):
        """Integer-domain linear layer.

        The input is fake-quantized and converted back to its integer grid,
        multiplied with the integer weights, and the product is rescaled by
        the input scale and the per-output-channel weight scale before the
        bias is added.
        """
        input,inp_scale = self.input_quantizer(input) # one scale per sample
        input = Round.apply(input/inp_scale)          # integer-valued activations
        if self.inference_done:
            # Weights were quantized once in to_inference()
            weight = self.weight
            weight_scale = self.weight_scale
        else:
            weight,weight_scale = self.weight_quantizer(self.weight) # one scale per output channel
        out = Round.apply(F.linear(input, Round.apply(weight/weight_scale),bias = None))
        out = out*inp_scale*weight_scale.view(1,self.out_features)
        if self.bias is not None:
            out = out + self.bias.view(1,self.out_features)
        return out

**Model-Wide Utility Functions**

These helper functions traverse the entire model architecture to manage the custom layers:

`inject_faults_to_model`: Recursively finds all `custom_QuantizedLinear_new_baseline` instances and triggers their fault injection logic with the specified error rate and mitigation strategy.

`model_to_inference`: Switches all quantized layers to inference mode, ensuring weights are quantized/frozen and scales are fixed before the validation phase begins.

In [None]:
def inject_faults_to_model(model,fault_rate,fault_mitigation):
    """Call ``inject_faults(fault_rate, fault_mitigation)`` on every
    ``custom_QuantizedLinear_new_baseline`` layer found in ``model``."""
    quantized_layers = (
        layer for layer in model.modules()
        if isinstance(layer, custom_QuantizedLinear_new_baseline)
    )
    for layer in quantized_layers:
        layer.inject_faults(fault_rate,fault_mitigation)

def model_to_inference(model):
    """Call ``to_inference()`` on every ``custom_QuantizedLinear_new_baseline``
    layer found in ``model``."""
    quantized_layers = (
        layer for layer in model.modules()
        if isinstance(layer, custom_QuantizedLinear_new_baseline)
    )
    for layer in quantized_layers:
        layer.to_inference()

**Standard validation function**

In [None]:
def test(model,test_loader,criterion):
 with torch.no_grad():
    model.eval()
    test_loss = 0
    correct = 0

    for data,target in test_loader:
        data,target = data.to('cuda'),target.to('cuda')
        output = model(data)
        test_loss += criterion(output,target).item()
        pred = output.data.max(1,keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        #print(correct)

    acc = 100. * correct/len(test_loader.dataset)

    test_loss /= len(test_loader)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    return acc

**Automated Model Conversion**

A utility function that recursively traverses the model architecture to replace standard `nn.Linear` layers with the custom quantized variant (quantized_cls). It handles the transfer of pre-trained weights and biases to the new layers.

Selective Replacement: Includes a skip_name parameter (defaulting to 'head') to exclude specific layers (like the final classification head) from quantization, preserving their original precision.

In [None]:
def replace_linear_layers(module, quantized_cls, skip_name='head'):
    """
    Recursively replace nn.Linear layers with quantized_cls layers, in place.

    A Linear child is left untouched when its own attribute name (the name it
    is registered under in its direct parent) equals skip_name, or when it is
    already an instance of quantized_cls, so calling this function twice does
    not rebuild (and thereby reset) converted layers.

    Each replacement is constructed as
    quantized_cls(in_features, out_features, bias=...), moved to the device
    and dtype of the layer it replaces, given a copy of that layer's weight
    and bias, and put in the same train/eval mode.

    Args:
        module: Root module to convert; modified in place.
        quantized_cls: Replacement class with an nn.Linear-compatible constructor.
        skip_name: Attribute name of Linear layers that must not be replaced.
    """
    for name, child in list(module.named_children()):
        if isinstance(child, quantized_cls):
            # Already converted: rebuilding would discard its frozen state.
            continue
        if isinstance(child, nn.Linear) and name != skip_name:
            # Create a new quantized layer with the same dimensions
            new_layer = quantized_cls(child.in_features, child.out_features, bias=(child.bias is not None))
            new_layer = new_layer.to(device=child.weight.device, dtype=child.weight.dtype)
            # Copy over the weights and bias
            with torch.no_grad():
                new_layer.weight.copy_(child.weight)
                if child.bias is not None:
                    new_layer.bias.copy_(child.bias)
            new_layer.train(child.training)
            setattr(module, name, new_layer)
        else:
            # Recursively descend into children
            replace_linear_layers(child, quantized_cls, skip_name=skip_name)

**Experiment**


Model Initialization: Loads the Imagenet pretrained `vit_base_patch16_224` architecture.

Quantization & Inference Setup: Converts all linear layers (excluding the head) to custom 8-bit layers and switches the model to inference mode (`model_to_inference`), which freezes the quantization scales.

Fault Simulation Settings (`inject_faults_to_model`):

Fault Rate: Simulating 1% to 5% rates with loop variable `j`.

Mitigation Mode: Selected by the last argument of `inject_faults_to_model`. The cell below passes `flip` (Sign-Flip). Use `none` for the baseline closest value mapping (CVM) approach, or `flip-bitwise` for Bit-Flip.

Statistical Reliability: The inner loop (range(20)) repeats the evaluation 20 times. Since fault injection is stochastic (random bits are stuck each time), averaging the accuracy increases the statistical significance of the results.

In [None]:
# --- Experiment configuration -------------------------------------------
MODEL_NAME = "vit_base_patch16_224"  # You can choose other variants
FAULT_RATES_PERCENT = range(1, 6)    # per-bit stuck-at fault rate, in percent
NUM_TRIALS = 20                      # independent fault draws per rate
FAULT_MITIGATION = "flip"            # 'none' (CVM), 'flip' (Sign-Flip) or 'flip-bitwise' (Bit-Flip)
SEED = 0

torch.manual_seed(SEED)  # fault positions are random; seed them for reproducibility
criterion = torch.nn.CrossEntropyLoss()
model_name = MODEL_NAME

# --- Fault-free references (evaluated once, not once per trial) -----------
model = timm.create_model(model_name, pretrained=True)
acc = test(model.to('cuda'),dataloader, criterion)          # full-precision baseline
replace_linear_layers(model, custom_QuantizedLinear_new_baseline, skip_name='head')
model = model.to('cuda')
model_to_inference(model)
acc = test(model,dataloader, criterion)                     # quantized, fault-free

# Fault injection only overwrites layer weights, so restoring this snapshot
# before every trial gives the same starting point as rebuilding the
# quantized model from the pretrained checkpoint.
clean_state = {name: tensor.clone() for name, tensor in model.state_dict().items()}

# --- Fault sweep ----------------------------------------------------------
for j in FAULT_RATES_PERCENT:
  accuracy = 0
  for i in range(NUM_TRIALS):
    model.load_state_dict(clean_state)
    inject_faults_to_model(model,j,FAULT_MITIGATION)
    acc = test(model,dataloader, criterion)
    accuracy += float(acc)
  print(f"fault rate {j}%: mean accuracy over {NUM_TRIALS} runs = {accuracy / NUM_TRIALS:.2f}%")