In [33]:
import torch
from tqdm import tqdm
from deepcase_copy.context_builder.context_builder import ContextBuilder
from deepcase_copy.interpreter.interpreter import Interpreter
from deepcase_copy.context_builder.loss import LabelSmoothing

def to_cuda(item):
    if torch.cuda.is_available():
        return item.to('cuda')
    return item

builder = to_cuda(ContextBuilder.load('save/builder.save'))
interpreter = Interpreter.load('save/interpreter.save', builder)
criterion = LabelSmoothing(builder.decoder_event.out.out_features, 0.1)    

with open('save/sequences.save', 'rb') as infile:
    data = torch.load(infile)
    events  = data["events"]
    context = data["context"]
    labels  = data["labels"]
    mapping = data["mapping"]
    
    events_train  = to_cuda(events [:events.shape[0]//5 ])
    events_test   = to_cuda(events [ events.shape[0]//5:])
    
    context_train = to_cuda(context[:events.shape[0]//5 ])
    context_test  = to_cuda(context[ events.shape[0]//5:])
    
    labels_train  = to_cuda(labels [:events.shape[0]//5 ])
    labels_test   = to_cuda(labels [ events.shape[0]//5:])

In [35]:
torch.manual_seed(42)

def disable_dropout(m):
    if isinstance(m, torch.nn.Dropout):
        m.p = 0.0  # Disable dropout by setting probability to 0

def predict_single_eval(index_inspected, iterations=1):
    
    local_context_picked = to_cuda(context_test[index_inspected:index_inspected+1].clone().detach())
    local_events_picked = to_cuda(events_test[index_inspected:index_inspected+1].clone().detach())
    con, e = local_context_picked[0], local_events_picked.unsqueeze(1)[0]
    print(f"con={con.tolist()}, e={e.tolist()}")
    con.resize_(1, con.size()[-1])
    con = builder.embedding_one_hot(con)
    builder.apply(disable_dropout)
    for _ in range(iterations):
        output = builder.forward(con, training=True)
        print(get_results(output)[1])
        loss = criterion.forward(output[0][0], e)
        loss.backward()

predict_single_eval(300, 5)

con=[66, 72, 66, 66, 78, 78, 72, 72, 18, 18], e=[18]
[18] 0.584, [72] 0.158, [64] 0.015
[18] 0.584, [72] 0.158, [64] 0.015
[18] 0.584, [72] 0.158, [64] 0.015
[18] 0.584, [72] 0.158, [64] 0.015
[18] 0.584, [72] 0.158, [64] 0.015


In [36]:
def get_unique_indices_per_row(tensor):
    indices_list = []
    row_list = []
    indices_list_set = set()
    
    for row in range(len(tensor)):
        curr = tuple(tensor[row].tolist())
        
        if curr in indices_list_set:
            continue
        
        row_list.append(curr)
        indices_list.append(row)
        indices_list_set.add(curr)
    
    return indices_list, row_list
indices, rows = get_unique_indices_per_row(context_test)

In [44]:
MAX_ITER = 100
PERTURB_THRESHOLD = 3

def to_one_hot(t):
    return builder.embedding_one_hot(t)

def max_to_one(tensor):
    max_indices = torch.argmax(tensor, dim=-1, keepdim=True)
    result = torch.zeros_like(tensor)
    result.scatter_(-1, max_indices, 1.0)
    return result

def get_results(results):
    results_picked = torch.topk(results[0][0][0], 3)
    exp = results_picked.values.exp()
    res_indices = results_picked.indices
    s = []
    for j in range(3):
        s.append(f"{format_list([res_indices[j].item()])} {'{:.3f}'.format(exp[j])}")
    return res_indices, ", ".join(s)

def compute_change(trace, original, epsilon=0.1):
    a = original - epsilon
    b = (trace >= a).float() * trace + (trace < a).float() * a
    c = (b > original + epsilon).float() * (original + epsilon) + (b <= original + epsilon).float() * b
    return max_to_one(c), c

def bim_attack(context_given, target_given, alpha=0.1, epsilon=0.1, num_iterations=MAX_ITER, training=True):
    change = None
    original_context = to_one_hot(context_given)
    context_processed = to_one_hot(context_given)
    changes = []
    computed_change_collected = []
    for i in range(num_iterations):
        context_processed.requires_grad_(True)
        output = builder.predict(context_processed, training=training)
        indices_of_results, prediction_str = get_results(output)
        changes.append({
            "changed_to": torch.argmax(context_processed, axis=-1).tolist()[0],
            "prediction_str": prediction_str,
        })
        if target_given[0] != indices_of_results[0]:
            break
        loss = criterion(output[0][0], target_given)
        context_processed.retain_grad()
        loss.backward(retain_graph=True)
        grad = context_processed.grad.sign()
        if change is None:
            change = alpha * grad
        else:
            change += alpha * grad
        context_processed = context_processed + change
        context_processed, computed_change = compute_change(context_processed, original_context, epsilon)
        computed_change_collected.append(computed_change)
    return changes, computed_change_collected

def count_changes(changes_needed):
    orig = changes_needed[0]["changed_to"]
    final = changes_needed[-1]["changed_to"]
    return count_list_diff(orig, final)

def count_list_diff(orig, final):
    changed_entries = 0
    for orig, final in zip(orig, final):
        if orig != final:
            changed_entries += 1
    return changed_entries
    
def show_changes(changes_needed):
    orig = changes_needed[0]["changed_to"]
    final = changes_needed[-1]["changed_to"]
    changes = []
    same = []
    for orig, final in zip(orig, final):
        if orig == final:
            changes.append("-")
            same.append(final)
        else:
            changes.append(final)
            same.append("XX")
    return format_list(changes), format_list(same)

def format_list(li):
    return f"[{", ".join([f'{num:2}' for num in li])}]"
            
def format_line(current_trace_num, changes_needed, i):
    start = " "*(len(str(current_trace_num)) + 2)
    if i != 0:
        if len(str(current_trace_num)) == 1:
            start = f'{i:<3}'
        elif len(str(current_trace_num)) == 2:
            start = f'{i:<4}'
        elif len(str(current_trace_num)) == 3:
            start = f'{i:<5}'
        elif len(str(current_trace_num)) == 4:
            start = f'{i:<6}'
        else:
            start = f'{i:<7}'
        b = list(start)
        b[len(str(i))] = '<'
        start = ''.join(b)
    return f"{start}{format_list(changes_needed[i]["changed_to"])} -> {changes_needed[i]['prediction_str']}\n"
            
def print_state(changes_needed, current_trace_num, con, event_chosen, change_collected, print_path=True, include_change=True, num_iterations=MAX_ITER):
    mode_int = 0
    changed_num = len(changes_needed)
    perturbations_num = count_changes(changes_needed)
    result_string = ""
    if changed_num == 1:
        pass
    elif changed_num == num_iterations:
        mode_int = 3
    else:
        mode_int = 1
        if perturbations_num <= PERTURB_THRESHOLD:
            mode_int = 2
            result_string += f"{current_trace_num}: {format_list(con[0].tolist())} == {event_chosen.tolist()} Changed {{{changed_num}}}, Perturbations {{{perturbations_num}}}\n"
            if print_path:
                for i in range(len(changes_needed)):
                    result_string += format_line(current_trace_num, changes_needed, i)                    
                    if include_change:
                        if i is not len(changes_needed) - 1:
                            result_string += f"{" "*(len(str(current_trace_num)) + 2)} {change_collected[i]=}\n"
            else:
                for i in range(len(changes_needed)):
                    if i == 0 or (changes_needed[i]["changed_to"] != changes_needed[i-1]["changed_to"] and i != len(changes_needed) - 1):
                        result_string += format_line(current_trace_num, changes_needed, i)
                change_last = changes_needed[-1]
                result_string += f"{" "*(len(str(current_trace_num)) + 2)}{format_list(change_last["changed_to"])} -> {change_last['prediction_str']}\n"
            changed_entries, same_entries = show_changes(changes_needed)
            result_string += f"{" "*(len(str(current_trace_num)) - 1)}== {same_entries}\n"
            result_string += f"{" "*(len(str(current_trace_num)) - 1)}-> {changed_entries}\n"
            result_string += "\n" 
    return mode_int, result_string

def process_traces(context_to_process, events_to_process, alpha=0.01, epsilon=0.5, num_iterations=100, print_path=False, include_change=False, write_to_file=False, print_result=False, training=True):
    perturbed_collected_main = []
    states = [0, 0, 0, 0]
    safe_to_file = ""
    iters = range(len(context_to_process))
    if not print_result:
        iters = tqdm(iters)
    for current_trace_num in iters:
        con, e = context_to_process[current_trace_num], events_to_process.unsqueeze(1)[current_trace_num]
        con.resize_(1, con.size()[-1])
        changes_needed, change_collected = bim_attack(context_given=con, target_given=e, alpha=alpha, epsilon=epsilon, num_iterations=num_iterations, training=training)
        mode_int, result_string = print_state(changes_needed, current_trace_num, con, e, change_collected, print_path=print_path, include_change=include_change, num_iterations=num_iterations)
        if print_result:
            print(result_string, end="")
        safe_to_file += result_string
        if mode_int == 2:
            perturbed_collected_main.append((current_trace_num, changes_needed[0]['changed_to'], changes_needed[-1]['changed_to']))
        states[mode_int] += 1
    print(f"incorrect={states[0]} changed={states[1]} perturbed={states[2]} timeout={states[3]}")
    safe_to_file += f"incorrect={states[0]} changed={states[1]} perturbed={states[2]} timeout={states[3]}"
    if write_to_file:
        with open(f"results_trace/length={l}, alpha={alpha}, epsilon={epsilon}, num_iterations={num_iterations}, print_path={print_path} include_change={include_change}.txt", "w") as f:
            f.write(safe_to_file)
    return perturbed_collected_main

def inspect_index(index_inspected, training=True):
    local_context_picked = to_cuda(context_test[indices][index_inspected:index_inspected+1].clone().detach())
    local_events_picked = to_cuda(events_test[indices][index_inspected:index_inspected+1].clone().detach())
    process_traces(local_context_picked, local_events_picked, alpha=0.01, epsilon=0.5, num_iterations=1000, print_path=True, include_change=False, write_to_file=True, print_result=True, training=training)  

In [45]:
l = 1000
chosen_index = 0
v_alpha = 0.01
v_epsilon = 0.5
v_num_iterations = 100
context_picked = to_cuda(context_test[indices][chosen_index:chosen_index+l].clone().detach())
events_picked = to_cuda(events_test[indices][chosen_index:chosen_index+l].clone().detach())
labels_picked = to_cuda(labels_test[indices][chosen_index:chosen_index+l].clone().detach())

In [47]:
torch.manual_seed(42)
perturbed_collected = process_traces(
    context_picked, 
    events_picked, 
    alpha=v_alpha, 
    epsilon=v_epsilon, 
    num_iterations=v_num_iterations, 
    print_path=False, 
    include_change=False, 
    write_to_file=True, 
    print_result=False, 
    training=True
)  

100%|██████████| 1000/1000 [02:42<00:00,  6.14it/s]

incorrect=345 changed=522 perturbed=57 timeout=76





In [48]:
import itertools

def get_changes_list(s, f):
    perturbations_made = []
    for i in range(len(s)):
        if s[i] != f[i]:
            perturbations_made.append((i, f[i]))
            
    return perturbations_made

def get_possible_combinations(perturbations_made):
    subsets = []
    for r in range(1, len(perturbations_made) + 1):
        subsets.extend(itertools.combinations(perturbations_made, r))
    result = [list(subset) for subset in subsets]
    return result

def get_minimum_change_for_perturbation(perturbed_chosen, index_in_list, training=True):
    i, s, f = perturbed_chosen[index_in_list]
    event_target = events_picked[i]
    combination_of_perturbation = get_possible_combinations(get_changes_list(s, f))
    for i in range(len(combination_of_perturbation)):
        combination = combination_of_perturbation[i]
        copy = to_cuda(torch.tensor(s).detach())
        for index_of_change, value_of_change in combination:
            copy[index_of_change] = value_of_change
        copy.resize_(1, copy.size()[-1])
        output = builder.predict(to_one_hot(copy), training=training)
        indices_of_results, _ = get_results(output)
        if event_target != indices_of_results[0]:
            return copy, f"Shortcut: {combination_of_perturbation[-1]} -> {combination}\n" if i == len(combination_of_perturbation) - 1 and len(combination_of_perturbation) != 1 else ""
    # print(f"ERROR: Could not find a perturbation for {index_in_list} [{i}]")
    # print(f"       Cannot go from {format_list(s)}")
    # print(f"                      {format_list(f)}")
    
    return None, False

def process_single(context_chosen, training=True):
    result_string = ""
    context_chosen = to_cuda(context_chosen)
    output = predict_single(context_chosen, training=training)
    attentions = [round(x, 5) for x in output[1][0][0].tolist()]
    indices_of_results, prediction_str = get_results(output)
    result_string += f"{format_list(context_chosen[0].tolist())} -> {prediction_str}\n"
    for c, a in zip(context_chosen[0], attentions):
        result_string += f"[{c:2}] {'{:.5f}'.format(a)} "
    result_string += "\n"
    return result_string
    
def predict_single(context_chosen, training=True):
    context_chosen.resize_(1, context_chosen.size()[-1])
    context_one_hot = to_one_hot(context_chosen)
    return builder.predict(context_one_hot, training=training)

def predict_single_from_list(list_picked, index_picked):
    return predict_single(to_cuda(list_picked[index_picked:index_picked+1].detach()))

def analysis(perturbed_chosen, index_picked, training=True):
    result_string = ""
    i, s, f = perturbed_chosen[index_picked]
    minimum_change_for_perturbation, final_combination = get_minimum_change_for_perturbation(perturbed_chosen, index_picked, training=training)
    if minimum_change_for_perturbation is not None:
        result_string += f"Analyzing [{i}], Perturbations [{count_list_diff(s, f)}], Index [{index_picked}]\n"
        result_string += process_single(torch.tensor(s), training=training)
        result_string += process_single(minimum_change_for_perturbation, training=training)
        result_string += final_combination
    else:
        result_string += f"Change did not work on [{i}], Index [{index_picked}]\n"
    result_string += "\n"
    return result_string

def store(perturbed_chosen, training=True):
    with open(f"results_attention/length={l}, alpha={v_alpha}, epsilon={v_epsilon}, num_iterations={v_num_iterations}.txt", "w") as file:
        safe_to_file = ""
        skipped = 0
        shortcuts = 0
        for perturbed_element in range(len(perturbed_chosen)):
            r_string = analysis(perturbed_chosen, perturbed_element, training=training)
            if "Change did not work on" in r_string:
                skipped += 1
            if "Shortcut" in r_string:
                shortcuts += 1
            print(r_string, end="")
            safe_to_file += r_string
        skipped_str = f"Results: {skipped=}/{shortcuts=} out of {len(perturbed_chosen)}"
        print(skipped_str)
        safe_to_file += skipped_str
        file.write(safe_to_file)

In [6]:
torch.manual_seed(42)

store(perturbed_collected, training=True)

Results: skipped=0/shortcuts=0 out of 0


In [7]:
def interpret():
    for c, e in zip(context_test, events_test):
        cont = to_one_hot(c)
        cont = to_cuda(cont.unsqueeze(0))
        even = to_cuda(e.reshape(-1, 1))
        print(interpreter.predict(X=cont, y=even, verbose=True))
        
interpret()

Optimizing query: 100%|██████████| 100/100 [00:01<00:00, 75.25it/s]
Predicting      : 100%|██████████| 1/1 [00:00<00:00, 251.85it/s]
Optimizing query: 100%|██████████| 100/100 [00:00<00:00, 612.99it/s]
Predicting      : 100%|██████████| 1/1 [00:00<00:00, 483.83it/s]
Optimizing query:   0%|          | 0/100 [00:00<?, ?it/s]

[3.]
[3.]


Optimizing query: 100%|██████████| 100/100 [00:00<00:00, 529.46it/s]
Predicting      : 100%|██████████| 1/1 [00:00<00:00, 786.78it/s]
Optimizing query: 100%|██████████| 100/100 [00:00<00:00, 585.66it/s]
Predicting      : 100%|██████████| 1/1 [00:00<00:00, 644.39it/s]
Optimizing query:   0%|          | 0/100 [00:00<?, ?it/s]

[3.]
[3.]


Optimizing query: 100%|██████████| 100/100 [00:00<00:00, 457.73it/s]
Predicting      : 100%|██████████| 1/1 [00:00<00:00, 460.71it/s]
Optimizing query:  26%|██▌       | 26/100 [00:00<00:00, 251.63it/s]

[3.]


Optimizing query: 100%|██████████| 100/100 [00:00<00:00, 321.84it/s]
Predicting      : 100%|██████████| 1/1 [00:00<00:00, 976.10it/s]
Optimizing query: 100%|██████████| 100/100 [00:00<00:00, 575.50it/s]


[3.]


IndexError: list index out of range