In [None]:
!git clone https://github.com/WaiNaat/pytorchfi.git
!pip install bitstring

In [None]:
import torch
import torchvision
import random
import copy
import numpy as np
import datetime
import pandas as pd
from bitstring import BitArray

from torchvision import transforms
from tqdm import tqdm

from google.colab import drive
drive.mount('/content/drive')

In [None]:
from pytorchfi.core import FaultInjection
from pytorchfi.weight_error_models import random_weight_location

## 환경설정 관련
`model_name`: https://github.com/chenyaofo/pytorch-cifar-models 의 표에 있는 모델명을 그대로 복사해서 입력    
`seed`: `None`으로 두면 실행 시각(타임스탬프)을 시드로 사용

In [None]:
# Experiment configuration: every value a reader might want to tune lives in this cell.

# Pretrained model and dataset; the torch.hub entry name is built as
# "<dataset>_<model_name>". Supported datasets are 'cifar10' and 'cifar100'.
model_name = "vgg19_bn"
dataset = 'cifar10'

# RNG seed. When left as None, the seeding cell replaces it with the current
# Unix timestamp, so the run is still reproducible from the saved seed.
seed = None

# Batch size used by the DataLoader and handed to the fault-injection wrapper;
# img_size and channels form the wrapper's input_shape [channels, img_size, img_size].
batch_size = 256
img_size = 32
channels = 3

# Run on CUDA whenever a GPU is visible to torch.
use_gpu = torch.cuda.is_available()

# When True, every injected bit flip and every changed prediction is recorded
# and written to a separate "_detailed" log file.
save_detailed_log = True

# Bit index to flip (forwarded as flip_bit_pos); None picks a random bit for
# each injection.
custom_bit_flip_pos = None
# Forwarded to the fault-injection wrapper as layer_types.
# NOTE(review): 'all' is presumably understood by pytorchfi as "every supported
# layer type" - confirm against the library version in use.
layer_type = ['all']
# Layer indices to inject into; a list containing 'all' selects every layer
# reported by the fault-injection wrapper.
layer_nums = ['all']

In [None]:
# No explicit seed configured: derive one from the current Unix timestamp so the
# value can still be recorded and reused later.
if seed is None:
    seed = int(datetime.datetime.now().timestamp())

# Seed the Python and NumPy generators.
random.seed(seed)
np.random.seed(seed)

# Seed torch on the CPU, the current GPU, and every GPU (multi-GPU setups).
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Load the pretrained model from torch.hub; the entry name is "<dataset>_<model_name>".
model = torch.hub.load(
    "chenyaofo/pytorch-cifar-models",
    f'{dataset}_{model_name}',
    pretrained=True,
)

# Move the parameters to the GPU when one is available (Module.to works in place).
if use_gpu:
    model.to(device='cuda')

# print(model)

In [None]:
# Build the test-set DataLoader for the configured dataset.
# CIFAR-10 normalisation statistics come from
# https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/vgg11_bn/default.log
dataloader = None

# Reject unknown dataset names before touching the disk or the network.
if dataset != 'cifar10' and dataset != 'cifar100':
    raise AssertionError(f'Invalid dataset name {dataset}')

# Pick the dataset class and its per-channel normalisation statistics.
if dataset == 'cifar10':
    dataset_cls = torchvision.datasets.CIFAR10
    norm_mean, norm_std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.201]
else:
    dataset_cls = torchvision.datasets.CIFAR100
    norm_mean, norm_std = [0.507, 0.4865, 0.4409], [0.2673, 0.2564, 0.2761]

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ]
)
data = dataset_cls(root='/data', train=False, download=True, transform=transform)

# Fixed order and full batches only: the evaluation loop assumes every batch
# holds exactly batch_size images.
dataloader = torch.utils.data.DataLoader(
    data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    drop_last=True,
)

In [None]:
class weight_single_bit_flip(FaultInjection):
    """Fault-injection wrapper that flips one bit of a single weight value.

    Args:
        model: Network handed to ``FaultInjection``.
        batch_size: Batch size handed to ``FaultInjection``.
        flip_bit_pos: Index of the bit to flip inside the value's binary string;
            index 0 is the leftmost character of that string. ``None`` selects a
            random bit on every injection.
        save_log_list: When true, every flip is appended to the ``log_*`` lists.
        **kwargs: Forwarded unchanged to ``FaultInjection.__init__``.
    """

    # Bit widths of the floating-point dtypes this class knows how to corrupt.
    _BIT_WIDTHS = {torch.float32: 32, torch.float64: 64}

    def __init__(self, model, batch_size, flip_bit_pos=None, save_log_list=False, **kwargs):
        super().__init__(model, batch_size, **kwargs)
        self.flip_bit_pos = flip_bit_pos
        self.save_log_list = save_log_list

        # Create the (empty) log lists through the same code path used to clear them.
        self.reset_log()

    def reset_log(self):
        '''
        Clear all recorded flips.

        You MUST call this function after single inference if save_log_list=True,
        otherwise entries from earlier injections stay at the front of the lists.
        '''
        self.log_original_value = []
        self.log_original_value_bin = []
        self.log_error_value = []
        self.log_error_value_bin = []
        self.log_bit_pos = []

    def weight_flip_function(self, weight, position):
        """Return the value at ``weight[position]`` with a single bit flipped.

        Args:
            weight: Weight tensor of the layer being corrupted.
            position: Index into ``weight`` selecting one scalar element.

        Returns:
            A 0-dim tensor holding the corrupted value, with the dtype of the
            original element.

        Raises:
            AssertionError: If the element is neither float32 nor float64.
        """
        target = weight[position]
        bits = self._BIT_WIDTHS.get(target.dtype)
        if bits is None:
            raise AssertionError(f'Unsupported data type {target.dtype}')

        # A configured position wins; otherwise draw a uniformly random bit.
        rand_bit = self.flip_bit_pos if self.flip_bit_pos is not None else random.randint(0, bits - 1)

        return self._single_bit_flip(target, rand_bit)

    def _single_bit_flip(self, orig_value, bit_pos):
        """Flip bit ``bit_pos`` of the scalar tensor ``orig_value``.

        Args:
            orig_value: Scalar float32/float64 tensor.
            bit_pos: Index into the binary string of the value. Negative indices
                count from the end, as with ordinary Python indexing.

        Returns:
            A 0-dim tensor with the corrupted value and the original dtype.

        Raises:
            AssertionError: If the dtype is unsupported.
            IndexError: If ``bit_pos`` lies outside the value's bit width.
        """
        save_type = orig_value.dtype
        orig_value = orig_value.cpu().item()

        length = self._BIT_WIDTHS.get(save_type)
        if length is None:
            raise AssertionError(f'Unsupported Data Type: {save_type}')

        # Fail early with a clear message instead of deep inside the bit manipulation.
        if not -length <= bit_pos < length:
            raise IndexError(f'Bit position {bit_pos} is out of range for a {length}-bit value')

        # Work on a copy so the original bit pattern stays available for logging.
        orig_arr = BitArray(float=orig_value, length=length)
        error = BitArray(bin=orig_arr.bin)
        error.invert(bit_pos)
        new_value = error.float

        if self.save_log_list:
            self.log_original_value.append(orig_value)
            self.log_original_value_bin.append(orig_arr.bin)
            self.log_error_value.append(new_value)
            self.log_error_value_bin.append(error.bin)
            self.log_bit_pos.append(bit_pos)

        return torch.tensor(new_value, dtype=save_type)

In [None]:
# Create the fault-injection wrapper used to produce single-bit-flip models.
# It receives its own deep copy so the reference `model` is never shared with it.
base_fi_model = weight_single_bit_flip(
    copy.deepcopy(model),
    batch_size,
    save_log_list=save_detailed_log,
    flip_bit_pos=custom_bit_flip_pos,
    layer_types=layer_type,
    use_gpu=use_gpu,
    input_shape=[channels, img_size, img_size],
)
# print(base_fi_model.print_pytorchfi_layer_summary())

In [None]:
# Resolve which layer indices will receive a single bit flip.
# The result is always a fresh sorted list, so the configured list is not mutated
# in place and this cell can be re-run safely (the previous version left a
# `range` behind and then crashed on `range.sort()` when executed again).
total_layers = base_fi_model.get_total_layers()
if 'all' in layer_nums:
    layer_nums = list(range(total_layers))
else:
    # Drop indices beyond the last layer known to the wrapper.
    layer_nums = sorted(num for num in layer_nums if num < total_layers)

In [None]:
# Run the experiment: for each selected layer, inject one random single-bit
# weight fault per batch and count how many originally-correct predictions change.
results = []
layer_name = []
misclassification_rate = []
detailed_log = []

# The reference model is never modified, so eval mode only has to be set once.
model.eval()

# Iterate over the selected layers.
for layer_num in tqdm(layer_nums):

    # Probe the layer first: a layer without weights cannot be injected.
    # `except Exception` (rather than a bare except) lets KeyboardInterrupt through.
    # NOTE(review): narrow this further once the exact exception raised by
    # random_weight_location for weight-less layers is confirmed.
    try:
        layer, k, C, H, W = random_weight_location(base_fi_model, layer=layer_num)
    except Exception:
        results.append(f"Layer # {layer_num} has no weight")
        continue

    # Last dotted component of the layer type's string form; reused for the logs below.
    layer_type_name = str(base_fi_model.layers_type[layer_num]).split(".")[-1].split("'")[0]

    orig_correct_cnt = 0
    orig_corrupt_diff_cnt = 0

    # Iterate over the batches.
    for batch_idx, (images, labels) in enumerate(dataloader):
        if use_gpu:
            images = images.to(device='cuda')

        # Inference with the untouched reference model.
        with torch.no_grad():
            orig_output = model(images)

        # Choose a fresh fault location inside this layer.
        layer, k, C, H, W = random_weight_location(base_fi_model, layer=layer_num)

        # Build the corrupted model; clear the log first so index 0 is this injection.
        if save_detailed_log:
            base_fi_model.reset_log()

        corrupted_model = base_fi_model.declare_weight_fault_injection(
            function=base_fi_model.weight_flip_function,
            layer_num=layer,
            k=k,
            dim1=C,
            dim2=H,
            dim3=W,
        )

        if save_detailed_log:
            log = [
                f'Layer: {layer_num}',
                f'Layer type: {layer_type_name}',
                f'Position: {k[0]}, {C[0]}, {H[0]}, {W[0]}',
                f'Original value:  {base_fi_model.log_original_value[0]}',
                f'Original binary: {base_fi_model.log_original_value_bin[0]}',
                f'Flip bit: {base_fi_model.log_bit_pos[0]}',
                f'Error value:     {base_fi_model.log_error_value[0]}',
                f'Error binary:    {base_fi_model.log_error_value_bin[0]}',
            ]

            detailed_log.append('\n'.join(log))

        # Inference with the corrupted model.
        corrupted_model.eval()
        with torch.no_grad():
            corrupted_output = corrupted_model(images)

        # Reduce logits to predicted class indices.
        labels_np = labels.cpu().numpy()
        original_pred = torch.argmax(orig_output, dim=1).cpu().numpy()
        corrupted_pred = torch.argmax(corrupted_output, dim=1).cpu().numpy()

        # Among samples the reference model classified correctly, find those whose
        # prediction changed. Vectorised over the actual batch length.
        orig_correct = labels_np == original_pred
        changed = orig_correct & (original_pred != corrupted_pred)

        orig_correct_cnt += int(orig_correct.sum())
        orig_corrupt_diff_cnt += int(changed.sum())

        if save_detailed_log:
            # Labels are logged as plain integers (previously rendered as "tensor(n)").
            for i in np.flatnonzero(changed):
                detailed_log.append(f'Batch: {batch_idx}\nImage: {i}\nLabel: {labels_np[i]}\nModel output: {corrupted_pred[i]}')

    # Store the per-layer result; report NaN instead of dividing by zero when the
    # reference model got nothing right (or the dataloader was empty).
    rate = orig_corrupt_diff_cnt / orig_correct_cnt * 100 if orig_correct_cnt else float('nan')
    result = f'Layer #{layer_num}: {orig_corrupt_diff_cnt} / {orig_correct_cnt} = {rate:.4f}%, {layer_type_name}'
    misclassification_rate.append(rate)
    layer_name.append(layer_type_name)
    results.append(result)

In [None]:
# Show one summary line per processed layer.
for summary_line in results:
    print(summary_line)

In [None]:
# Persist the run to Google Drive: a summary text file, an optional detailed
# log, and a CSV with the per-layer misclassification rates.
save_path = '/content/drive/MyDrive/' + '_'.join(['weight', model_name, f'batch{batch_size}', dataset, str(seed)])

# Context managers close the files even if a write fails midway.
with open(save_path + ".txt", 'w') as f:
    f.write(base_fi_model.print_pytorchfi_layer_summary())
    f.write(f'\n\n===== Result =====\nSeed: {seed}\n')
    for result in results:
        f.write(result + '\n')

if save_detailed_log:
    with open(save_path + '_detailed.txt', 'w') as f:
        for log in detailed_log:
            f.write(log + '\n\n')

# One row per injected layer; the column name records the seed of this run.
data = pd.DataFrame({'name': layer_name, f'seed_{seed}': misclassification_rate})
data.to_csv(save_path + '.csv')