In [134]:
import re
import torch
import random
import numpy as np
from copy import copy, deepcopy
from adversarial_attack.fgsm import FGSM
from adversarial_attack.pgd import PGD
from adversarial_attack.bim import BIM
from adversarial_attack.utils import compute_accuracy, compute_confusion_matrix

#### Create the HT model and import pretrained weights

In [2]:
import pickle
from dnn_model.CNN_netlist_softma_save_resluts import Classifier_Netlist

with open('save/source_config.pkl', 'rb') as pickle_file:
    source_config_copy = pickle.load(pickle_file)

path = '/home/erastus/Desktop/Postdoc_projects/Adv-TruDetect/json_temp_file/word2vec_emb/CNN_model_pretrained.pth'
HTnn_net = Classifier_Netlist(group_id=str(2), base_path='json_temp_file', source_config=source_config_copy, pretrained=path)   

Chunking...0/3
Read 1M Feature Traces.
Chunking...1/3
Read 2M Feature Traces.
Chunking...2/3
Read 3M Feature Traces.
Chunking...Last
Read 4M Feature Traces.
Iterable Dataset Loaded...
===TEST===
Chunking...0/1
Chunking...Last
Read 1M Feature Traces.
Iterable Dataset Loaded...
Model loaded


In [235]:
from torch.utils.data import DataLoader
from typing import List, Any

def get_all_text_labels(dataloader: DataLoader) -> List[str]:
    text_labels = []
    for batch in dataloader:
        text_label = batch[2]  
        text_labels.extend(text_label)
    return list(set(text_labels))  # Remove duplicates

def get_samples_by_text_label(dataloader: DataLoader, target_text: str) -> List[Any]:
    matching_samples = []
    for batch in dataloader:
        data, class_label, text_label = batch 
        for i, label in enumerate(text_label):
            if label == target_text:
                matching_samples.append((data[i], class_label[i]))
    return matching_samples

def get_cmp_by_emb(dictionary, value):
    for key, val in dictionary.items():
        if val == value:
            return key
    return None  

def get_emb_by_cmp(dictionary, value):
    for key, val in dictionary.items():
        if key == value:
            return val
    return None 

def get_all_embeddings(approx_list_pcp):
    embds = []
    for i in range(len(approx_list_pcp)):
        n_array = [get_emb_by_cmp(HTnn_net.val_data.word2vec_dict, elem) for elem in approx_list_pcp[i]]
        embd = torch.unsqueeze(torch.from_numpy(np.array(n_array)), 0)
        embds.append(embd)
    return embds

def approximation_error(orig_pcp_list, approx_pcp_list):
    error_maps = {  'i':   {'i': 0, 'and':.75, 'nnd': .25, 'or': .75, 'nor':.25, 'xor':.5, 'xnr':.5}, 
                    'and': {'i':.75, 'and':0, 'nnd': 1., 'or': .5, 'nor':.5, 'xor':.75, 'xnr':.25 }, 
                    'nnd': {'and':1., 'nnd':0, 'i': .25, 'or': .5, 'nor':.5, 'xor':.25, 'xnr':.75 }, 
                    'or':  {'and':.5, 'nnd': .5, 'i': 0.75, 'or':0, 'nor':1., 'xor':.25, 'xnr':.75 }, 
                    'nor':  {'and':.5, 'nnd': .5, 'i': 0.25, 'or':1., 'nor':0, 'xor':.75, 'xnr':.25 }, 
                    'xor':  {'and':.75, 'nnd': .25, 'i': 0.5, 'or':.25, 'nor':.75, 'xor':0, 'xnr':1. }, 
                    'xnr':  {'and':.25, 'nnd': .75, 'i': 0.5, 'or':.75, 'nor':.25, 'xor':1., 'xnr':0 }   }
    
    final_error = 0 
    for i in range(len(orig_pcp_list)):
        orig_pcp, approx_pcp = orig_pcp_list[i], approx_pcp_list[i]
        error = 0
        for j in range(len(orig_pcp)):
            orig, approx = orig_pcp[j], approx_pcp[j]
            orig_op, _= separate_letters_numbers(orig.split('_')[1])
            approx_op, _= separate_letters_numbers(approx.split('_')[1])
            if orig_op in ['i', 'and', 'nnd', 'or', 'nor', 'xor', 'xnor']:  
                error += error_maps[orig_op][approx_op]
        
        final_error += error/len(orig_pcp)
    
    return final_error/len(orig_pcp_list)


def detect_score(HT_model, approx_pcp_list):
    input_data = torch.stack(approx_pcp_list, dim=0).to(HT_model.device)
    out = HTnn_net.model(input_data)
    _, pred = torch.max(out, 1) 
    return pred.sum().item()/len(approx_pcp_list)

def separate_letters_numbers(input_string):
    letters = ''.join(re.findall(r'[a-zA-Z]', input_string))
    numbers = ''.join(re.findall(r'\d', input_string))
    return letters, numbers

def mutate_pcp_word(pcp_word, dict):
    mutate_list = {'i':   ['nnd', 'nor', 'xor', 'xnr'],
                   'and': ['or', 'nor', 'xnr'],
                   'nnd': ['or', 'nor', 'xor'],
                   'or':  ['and', 'nnd', 'xor'],
                   'nor': ['and', 'nnd', 'xnr'],
                   'xor': ['nnd', 'or'],
                   'xnr': ['and', 'nor']}

    splits = pcp_word.split('_')
    op, num = separate_letters_numbers(splits[1])
    if op in ['i', 'and', 'nnd', 'or', 'nor', 'xor', 'xnor']:
        new_word = splits[0]+'_'+random.choice(mutate_list[op])+str(num)+'_'+splits[2]
        if get_emb_by_cmp(dict, new_word) != None:
            return new_word
    return pcp_word

def mutate_pcp_list(orig_list, n_changes=1, p=0.5, dict=HTnn_net.val_data.word2vec_dict):
    orig_pcp_copy = deepcopy(orig_list)
    for i in range(len(orig_pcp_copy)):
        if random.random() > p:
            pcp_word = orig_pcp_copy[i][random.randint(0, len(orig_pcp_copy[i])-1)]
            new_pcp_word = mutate_pcp_word(pcp_word, dict)
            
            for j in range(len(orig_pcp_copy)): # Apply changes to all pcps that share the word
                for k in range(len(orig_pcp_copy[j])):
                    if orig_pcp_copy[j][k].split('_')[1] == pcp_word.split('_')[1]:
                        orig_pcp_copy[j][k] = orig_pcp_copy[j][k].split('_')[0]+'_'+new_pcp_word.split('_')[1]+'_'+orig_pcp_copy[j][k].split('_')[2]

            n_changes -= 1
        if n_changes == 0:
            break
    return orig_pcp_copy

### Pick a trojan circuit

In [236]:
all_text_labels = get_all_text_labels(HTnn_net.val_dataloader)
trojan_comps_labels = []
for elem in all_text_labels:
    if elem.startswith("t"):
        trojan_comps_labels.append(elem)

trojan_comp = trojan_comps_labels[0] ## Choose one trojna circuit (automate later)
pcp_embs = get_samples_by_text_label(HTnn_net.val_dataloader, trojan_comp)

### Get all the PCP embeddings and componenets names for the selected trojan circuit

In [238]:
all_embds, all_cmps, all_labels = [], [], []
for pcp_emb in pcp_embs:
    p_emb, label = pcp_emb
    full_pcp_cmp = []
    for i in range(5):
        name = get_cmp_by_emb(HTnn_net.val_data.word2vec_dict, list(np.float32(p_emb[i])))
        full_pcp_cmp.append(name)
    
    all_labels.append(label.item())
    all_embds.append(p_emb)
    all_cmps.append(full_pcp_cmp)

all_labels

[1, 1, 1, 1]

### Define a genetic search algorithm to find an optimal tradeoff between approx-error and HT-detect-score

In [249]:
def genetic_search(HTnn_net, orig_pcp_list, population_size, generations):

    population = [mutate_pcp_list(deepcopy(orig_pcp_list), n_changes=random.randint(1, 5)) for _ in range(population_size)]  
    for _ in range(generations):
        
        fitness_scores, appx_errs, dect_errs = [], [], []
        for i in range(len(population)):
            appx_err = approximation_error(deepcopy(orig_pcp_list), deepcopy(population[i]))
            dect_err = detect_score(HTnn_net, get_all_embeddings(deepcopy(population[i])))
            appx_errs.append(appx_err)
            dect_errs.append(dect_err)
            fitness_scores.append(appx_err+dect_err)
        
        print(appx_errs, dect_errs)
        parents = random.choices(population, weights=fitness_scores, k=population_size)   
        new_population = []
        for i in range(0, population_size, 2):
            parent1, parent2 = parents[i], parents[i+1]    
            child1 = mutate_pcp_list(deepcopy(parent1))
            child2 = mutate_pcp_list(deepcopy(parent2))
            new_population.extend([child1, child2])
        
        population = new_population

### Run the genetic search on the trojan circuit

In [250]:
genetic_search(HTnn_net, deepcopy(all_cmps), population_size=8, generations=10)

[0.1375, 0.0, 0.175, 0.0, 0.1, 0.0, 0.0, 0.0] [0.0, 1.0, 0.0, 1.0, 0.75, 1.0, 1.0, 1.0]
[0.175, 0.1, 0.0, 0.1, 0.1, 0.1875, 0.0, 0.0] [0.0, 0.75, 1.0, 0.75, 0.0, 0.0, 1.0, 1.0]
[0.1, 0.0, 0.05, 0.15, 0.0, 0.05, 0.1, 0.0] [0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 0.75, 1.0]
[0.15, 0.32499999999999996, 0.0, 0.2, 0.15, 0.05, 0.1, 0.0] [1.0, 0.0, 1.0, 0.75, 0.0, 1.0, 0.0, 1.0]
[0.1, 0.15, 0.15, 0.15, 0.05, 0.0, 0.32499999999999996, 0.2375] [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
[0.1, 0.05, 0.15, 0.175, 0.175, 0.15, 0.1, 0.15] [1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.75, 0.0]
[0.2, 0.32499999999999996, 0.225, 0.1, 0.15, 0.275, 0.2375, 0.05] [0.75, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0]
[0.275, 0.32499999999999996, 0.1875, 0.05, 0.1375, 0.2, 0.225, 0.275] [0.0, 0.0, 0.0, 1.0, 0.0, 0.75, 0.0, 0.0]
[0.1375, 0.2, 0.5, 0.2, 0.225, 0.05, 0.375, 0.05] [0.0, 0.75, 0.0, 0.75, 0.0, 1.0, 0.0, 1.0]
[0.05, 0.5, 0.3, 0.05, 0.525, 0.05, 0.475, 0.375] [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0]
