In [337]:
import torch
from deepcase_copy.context_builder.context_builder import ContextBuilder
from deepcase_copy.interpreter.interpreter import Interpreter

def to_cuda(item):
    """Return ``item`` moved to the CUDA device when one is available, else ``item`` unchanged."""
    return item.to('cuda') if torch.cuda.is_available() else item

builder = to_cuda(ContextBuilder.load('save/builder.save'))
interpreter = Interpreter.load('save/interpreter.save', builder)

# NOTE(review): torch.load unpickles the file -- only ever load trusted save files.
with open('save/sequences.save', 'rb') as infile:
    data = torch.load(infile)

events = data["events"]
context = data["context"]
labels = data["labels"]
mapping = data["mapping"]

# The first fifth of the sequences is the "train" split, the remaining four fifths the "test" split.
split = events.shape[0] // 5
events_train, events_test = to_cuda(events[:split]), to_cuda(events[split:])
context_train, context_test = to_cuda(context[:split]), to_cuda(context[split:])
labels_train, labels_test = to_cuda(labels[:split]), to_cuda(labels[split:])

In [338]:
def get_unique_indices_per_row(tensor):
    """Find the first occurrence of every distinct row.

    Args:
        tensor: an indexable sequence of rows, each supporting ``.tolist()``
            (e.g. a 2-D torch tensor).

    Returns:
        ``(first_indices, unique_rows)``: the positions of the first occurrence of
        each distinct row (in order of appearance) and those rows as tuples.
    """
    seen = set()
    first_indices, unique_rows = [], []
    for position, row_values in enumerate(tensor):
        key = tuple(row_values.tolist())
        if key not in seen:
            seen.add(key)
            first_indices.append(position)
            unique_rows.append(key)
    return first_indices, unique_rows
# Keep only the first occurrence of each distinct test context.
indices, rows = get_unique_indices_per_row(tensor=context_test)

In [339]:
from deepcase_copy.context_builder.loss import LabelSmoothing

MAX_ITER: int = 100  # default iteration budget shared by bim_attack and print_state

def max_to_one(tensor):
    """Turn the last axis into a one-hot vector marking its (first) maximum; every other entry becomes 0."""
    winners = tensor.argmax(dim=-1, keepdim=True)
    return torch.zeros_like(tensor).scatter_(-1, winners, 1.0)

def get_results(results, top_k=3):
    """Summarise the model's top-``top_k`` next-event predictions.

    Args:
        results: the object returned by ``builder.predict``; ``results[0][0][0]`` is
            read as the score vector of the first batch element / first step.
            NOTE(review): the scores are exponentiated for display, so they are
            presumably log-probabilities -- confirm against ContextBuilder.predict.
        top_k: how many predictions to report (default 3, as before). Clamped to the
            number of available classes so ``torch.topk`` cannot fail.

    Returns:
        ``(indices, text)``: the tensor of top event indices (highest score first) and
        a human-readable ``"[ i] p, ..."`` string with probabilities to 3 decimals.
    """
    scores = results[0][0][0]
    top_k = min(top_k, scores.shape[-1])
    picked = torch.topk(scores, top_k)
    probabilities = picked.values.exp()
    parts = [
        f"{format_list([event])} {'{:.3f}'.format(prob)}"
        for event, prob in zip(picked.indices.tolist(), probabilities.tolist())
    ]
    return picked.indices, ", ".join(parts)

def compute_change(trace, original, epsilon=0.1):
    """Project a perturbed one-hot context back into the epsilon box and re-binarise it.

    Each entry of ``trace`` is clipped to ``[max(original - epsilon, 0), original + epsilon]``,
    then every vector along the last axis is turned into a one-hot of its maximum.

    ``torch.where`` is used instead of multiplying by 0/1 masks: the mask form yields
    NaN whenever an input is infinite (0 * inf), whereas this keeps the clipped bound.
    Finite inputs give exactly the same result as before.

    Args:
        trace: the perturbed context (same shape as ``original``).
        original: the unperturbed one-hot context.
        epsilon: half-width of the allowed box around ``original``.

    Returns:
        A one-hot tensor of the same shape as ``trace``.
    """
    lower = torch.clamp(original - epsilon, min=0)
    upper = original + epsilon
    projected = torch.where(trace < lower, lower, trace)
    projected = torch.where(projected > upper, upper, projected)
    return max_to_one(projected)

def bim_attack(context_given, target_given, alpha=0.1, epsilon=0.1, num_iterations=MAX_ITER):
    """Untargeted BIM-style attack on a single event context.

    At every iteration the current one-hot context is fed to ``builder.predict``.
    If the top prediction already differs from ``target_given[0]``, the attack stops.
    Otherwise the gradient of the LabelSmoothing loss w.r.t. the input is taken.
    ``alpha * sign(grad)`` is added to a running cumulative change, and the result is
    projected back onto a one-hot context inside the epsilon box (``compute_change``).

    Args:
        context_given: event-id context, shaped ``(1, seq_len)`` by the caller.
        target_given: tensor whose element 0 is the true next event.
        alpha: step size of each signed-gradient step.
        epsilon: half-width of the projection box around the original one-hot context.
        num_iterations: maximum number of model evaluations.

    Returns:
        One dict per evaluated iteration with keys ``"changed_to"`` (argmax event ids of
        the first row), ``"prediction_str"`` (from ``get_results``) and ``"fooled"``
        (True when the top prediction differs from the target). The flag lets callers
        tell "fooled on the last allowed iteration" from "never fooled".
    """
    # builder.embedding_one_hot is what to_one_hot wraps; to_one_hot itself is only
    # defined in a later cell, so calling it here breaks Restart & Run All.
    original_context = builder.embedding_one_hot(context_given)
    context_processed = builder.embedding_one_hot(context_given)
    # NOTE(review): second argument presumably the smoothing factor -- confirm in deepcase_copy.
    criterion = LabelSmoothing(builder.decoder_event.out.out_features, 0.1)
    change = None
    changes = []
    for _ in range(num_iterations):
        context_processed.requires_grad_(True)
        output = builder.predict(context_processed)
        indices_of_results, prediction_str = get_results(output)
        fooled = bool(target_given[0] != indices_of_results[0])
        changes.append({
            "changed_to": torch.argmax(context_processed, dim=-1).tolist()[0],
            "prediction_str": prediction_str,
            "fooled": fooled,
        })
        if fooled:
            break
        loss = criterion(output[0][0], target_given)
        # autograd.grad returns only the input gradient: it does not accumulate .grad into
        # the model parameters and does not keep the graph alive (no retain_graph needed).
        (grad,) = torch.autograd.grad(loss, context_processed)
        step = alpha * grad.sign()
        change = step if change is None else change + step
        # The cumulative change is re-applied to the current (already projected) context,
        # because each projection snaps back to one-hot and would otherwise erase small steps.
        context_processed = compute_change(context_processed + change, original_context, epsilon)
    return changes

def count_changes(changes_needed):
    """Count positions whose event differs between the first and the last recorded attack step."""
    first_step, last_step = changes_needed[0], changes_needed[-1]
    return count_list_diff(first_step["changed_to"], last_step["changed_to"])

def count_list_diff(orig, final):
    """Number of aligned positions where ``orig`` and ``final`` differ (zip semantics: stops at the shorter)."""
    return sum(1 for before, after in zip(orig, final) if before != after)
    
def show_changes(changes_needed):
    """Render the difference between the first and last attack step as two formatted lists.

    Returns:
        ``(changed, kept)``: ``changed`` shows the new event where a position changed
        and ``"-"`` elsewhere. ``kept`` shows the unchanged event where the position
        stayed the same and ``"XX"`` where it changed.
    """
    first = changes_needed[0]["changed_to"]
    last = changes_needed[-1]["changed_to"]
    diff_view, kept_view = [], []
    for before, after in zip(first, last):
        unchanged = before == after
        diff_view.append("-" if unchanged else after)
        kept_view.append(after if unchanged else "XX")
    return format_list(diff_view), format_list(kept_view)

def format_list(li):
    """Format items as ``"[a, b, ...]"`` with each item padded to width 2.

    Ints are right-aligned (``" 1"``) and strings left-aligned (``"- "``).
    The original nested the same quote character inside an f-string, which is only
    legal from Python 3.12 (PEP 701); this form works on every Python 3 version.
    """
    return "[" + ", ".join(f"{num:2}" for num in li) + "]"
            
def print_state(changes_needed, current_trace_num, con, event_chosen, print_path=True, num_iterations=MAX_ITER):
    """Classify one attack run and, for small successful perturbations, build a text report.

    Returns:
        ``(mode_int, result_string)``. ``mode_int`` indexes the counters printed by
        process_traces:
          0 = "incorrect": the model was already wrong on the unmodified context;
          1 = "changed":   the attack succeeded but changed more than 3 positions;
          2 = "perturbed": the attack succeeded with at most 3 changed positions
              (the only mode with a non-empty report);
          3 = "timeout":   the attack exhausted its iterations without succeeding.

    If the last entry of ``changes_needed`` carries a boolean ``"fooled"`` flag, that flag
    decides timeouts. Length alone cannot distinguish "fooled on the final iteration"
    from "never fooled". Without the flag, the previous rule applies: exactly
    ``num_iterations`` entries means timeout.
    """
    changed_num = len(changes_needed)
    perturbations_num = count_changes(changes_needed)
    fooled = changes_needed[-1].get("fooled")

    if changed_num == 1 and fooled is not False:
        return 0, ""
    if fooled is False or (fooled is None and changed_num == num_iterations):
        return 3, ""
    if perturbations_num > 3:
        return 1, ""

    # Quotes are alternated inside the f-strings so this also parses before Python 3.12.
    width = len(str(current_trace_num))
    path_pad = " " * (width + 2)
    summary_pad = " " * (width - 1)
    result_string = (
        f"{current_trace_num}: {format_list(con[0].tolist())} == {event_chosen.tolist()} "
        f"Changed {{{changed_num}}}, Perturbations {{{perturbations_num}}}\n"
    )
    shown = changes_needed if print_path else changes_needed[-1:]
    for change in shown:
        result_string += f"{path_pad}{format_list(change['changed_to'])} -> {change['prediction_str']}\n"
    changed_entries, same_entries = show_changes(changes_needed)
    result_string += f"{summary_pad}== {same_entries}\n"
    result_string += f"{summary_pad}-> {changed_entries}\n"
    result_string += "\n"
    return 2, result_string

def process_traces(alpha=0.01, epsilon=0.5, num_iterations=MAX_ITER, print_path=False):
    """Run ``bim_attack`` on every picked trace, print and save a report, and collect small perturbations.

    Reads the globals ``context_picked``, ``events_picked`` and ``l`` (used only in the
    output file name). Writes
    ``results_trace/length=..., alpha=..., epsilon=..., num_iterations=..., print_path=....txt``
    and creates the directory if needed.

    Returns:
        A list of ``(trace_index, original_events, final_events)`` for every trace that
        ``print_state`` classified as mode 2 (fooled with at most 3 changed positions).
    """
    import os

    perturbed_collected_main = []
    states = [0, 0, 0, 0]
    report_parts = []
    targets = events_picked.unsqueeze(1)  # hoisted: one shape-(1,) target per trace
    for current_trace_num in range(len(context_picked)):
        # Slicing yields a (1, seq_len) context directly instead of resize_-ing a row view in place.
        con = context_picked[current_trace_num:current_trace_num + 1]
        e = targets[current_trace_num]
        changes_needed = bim_attack(context_given=con, target_given=e, alpha=alpha, epsilon=epsilon, num_iterations=num_iterations)
        mode_int, result_string = print_state(changes_needed, current_trace_num, con, e, print_path=print_path, num_iterations=num_iterations)
        print(result_string, end="")
        report_parts.append(result_string)
        if mode_int == 2:
            perturbed_collected_main.append((current_trace_num, changes_needed[0]['changed_to'], changes_needed[-1]['changed_to']))
        states[mode_int] += 1
    summary = f"incorrect={states[0]} changed={states[1]} perturbed={states[2]} timeout={states[3]}"
    print(summary)
    report_parts.append(summary)
    os.makedirs("results_trace", exist_ok=True)
    with open(f"results_trace/length={l}, alpha={alpha}, epsilon={epsilon}, num_iterations={num_iterations}, print_path={print_path}.txt", "w") as f:
        f.write("".join(report_parts))
    return perturbed_collected_main

In [340]:
l = 100
chosen_index = 0
# Take a window of `l` de-duplicated test traces starting at `chosen_index`.
# clone().detach() gives standalone tensors that are detached from any autograd graph.
window = slice(chosen_index, chosen_index + l)
context_picked, events_picked, labels_picked = (
    to_cuda(split_tensor[indices][window].clone().detach())
    for split_tensor in (context_test, events_test, labels_test)
)

In [341]:
# Attack hyper-parameters: passed to process_traces and used in the results_attention file name by store.
v_alpha, v_epsilon = 0.01, 0.5
v_num_iterations = 100
v_print_path = False

In [342]:
torch.manual_seed(42)
# Run the BIM attack over the picked traces with the hyper-parameters defined earlier.
perturbed_collected = process_traces(
    alpha=v_alpha,
    epsilon=v_epsilon,
    num_iterations=v_num_iterations,
    print_path=v_print_path,
)

3: [86, 87, 86, 87, 66, 66, 42, 42, 42, 72] == [72] Changed {52}, Perturbations {2}
   [86, 87, 86, 87, 66, 66,  1, 42, 42,  1] -> [ 1] 0.234, [15] 0.118, [72] 0.027
== [86, 87, 86, 87, 66, 66, XX, 42, 42, XX]
-> [- , - , - , - , - , - ,  1, - , - ,  1]

5: [87, 86, 87, 66, 66, 42, 42, 42, 72, 72] == [72] Changed {52}, Perturbations {2}
   [87, 86, 87, 66, 66, 42,  1, 42, 72,  1] -> [ 1] 0.234, [15] 0.118, [72] 0.037
== [87, 86, 87, 66, 66, 42, XX, 42, 72, XX]
-> [- , - , - , - , - , - ,  1, - , - ,  1]

12: [71, 71, 71, 71, 71, 71, 71, 71, 71, 76] == [79] Changed {52}, Perturbations {2}
    [71, 71, 71, 71, 71, 71,  3, 71, 71,  4] -> [71] 0.215, [64] 0.078, [72] 0.072
 == [71, 71, 71, 71, 71, 71, XX, 71, 71, XX]
 -> [- , - , - , - , - , - ,  3, - , - ,  4]

29: [64, 64, 64, 64, 64, 57, 64, 72, 86, 87] == [86] Changed {52}, Perturbations {3}
    [64, 64, 64, 64, 64,  0,  0,  0, 86, 87] -> [66] 0.373, [86] 0.274, [87] 0.216
 == [64, 64, 64, 64, 64, XX, XX, XX, 86, 87]
 -> [- , - , - , -

In [358]:
torch.manual_seed(42)
def weird_case(index=5):
    """Replay the final perturbed context of ``perturbed_collected[index]`` through the model and print its top predictions.

    Args:
        index: which collected perturbation to replay (previously hard-coded to 5).
    """
    _, _, final_item = perturbed_collected[index]
    final_context = to_cuda(torch.tensor(final_item)).unsqueeze(0)  # (1, seq_len)
    # Call the encoder directly: to_one_hot is only defined in a later cell, so using it
    # here breaks Restart & Run All.
    output = builder.predict(builder.embedding_one_hot(final_context))
    _, print_results = get_results(output)
    print(print_results)

weird_case()

[72] 0.393, [64] 0.385, [57] 0.021


In [344]:
import itertools

def to_one_hot(t):
    """Encode event ids ``t`` with the context builder's ``embedding_one_hot`` layer."""
    encoder = builder.embedding_one_hot
    return encoder(t)

def get_changes_list(s, f):
    """List ``(position, new_value)`` for every position of ``s`` whose value differs in ``f``."""
    return [(i, f[i]) for i in range(len(s)) if s[i] != f[i]]

def get_possible_combinations(perturbations_made):
    """All non-empty subsets of ``perturbations_made`` as lists, ordered by size, then lexicographically by position."""
    return [
        list(combo)
        for size in range(1, len(perturbations_made) + 1)
        for combo in itertools.combinations(perturbations_made, size)
    ]

def get_minimum_change_for_perturbation(perturbed_chosen, index_in_list):
    """Find the smallest subset of an attack's edits that still changes the top prediction.

    Subsets are tried in increasing size (``get_possible_combinations`` order), so the
    first one that flips the top prediction away from ``events_picked[i]`` has the
    fewest changed positions.

    Returns:
        The edited context shaped ``(1, seq_len)``, or ``None`` if no subset flips it.
    """
    trace_index, start_events, final_events = perturbed_chosen[index_in_list]
    event_target = events_picked[trace_index]
    base = torch.tensor(start_events).detach()

    def _flips_prediction(candidate):
        # Compare the model's top-1 event for this candidate against the true target.
        top_indices, _ = get_results(builder.predict(to_one_hot(candidate)))
        return event_target != top_indices[0]

    for combination in get_possible_combinations(get_changes_list(start_events, final_events)):
        candidate = to_cuda(base.clone())
        for position, value in combination:
            candidate[position] = value
        candidate.resize_(1, candidate.size()[-1])
        if _flips_prediction(candidate):
            return candidate
    return None

def process_single(context_chosen):
    """Predict for one context and render the prediction plus per-position attention.

    Args:
        context_chosen: event-id context, either ``(seq_len,)`` or ``(1, seq_len)``.

    Returns:
        Two lines: ``"[events] -> top-3 predictions"`` and ``"[ev] attention ..."`` with
        attention weights (``output[1][0][0]``) printed to 5 decimals.
    """
    context_chosen = to_cuda(context_chosen)
    # Normalise to (1, seq_len) explicitly, keeping the first row as resize_ would.
    # The old code relied on predict_single resizing the tensor in place for
    # context_chosen[0] to be a row, which also mutated the caller's tensor.
    context_chosen = context_chosen.reshape(-1, context_chosen.size()[-1])[:1]
    output = predict_single(context_chosen)
    attentions = output[1][0][0].tolist()
    _, prediction_str = get_results(output)
    events_in_context = context_chosen[0].tolist()
    result_string = f"{format_list(events_in_context)} -> {prediction_str}\n"
    result_string += "".join(
        f"[{c:2}] {'{:.5f}'.format(a)} " for c, a in zip(events_in_context, attentions)
    )
    result_string += "\n"
    return result_string

def predict_single(context_chosen):
    """Run ``builder.predict`` on a single event-id context.

    Side effect: ``context_chosen`` is resized *in place* to ``(1, seq_len)`` before being
    one-hot encoded, so the caller's tensor object changes shape.
    """
    sequence_length = context_chosen.size()[-1]
    context_chosen.resize_(1, sequence_length)
    return builder.predict(to_one_hot(context_chosen))

def predict_single_from_list(list_picked, index_picked):
    """Predict for row ``index_picked`` of ``list_picked``, taken as a ``(1, seq_len)`` slice."""
    single_row = list_picked[index_picked:index_picked + 1].detach()
    return predict_single(to_cuda(single_row))

def analysis(perturbed_chosen, index_picked):
    """Compare the original context with its minimal successful perturbation.

    Returns:
        A report for entry ``index_picked``. On success it contains a header plus the
        ``process_single`` output for the original and the minimal perturbed context.
        Otherwise it contains a "Change did not work on ..." line. Both variants end
        with a blank line. Previously the failure line had no blank-line separator.
    """
    i, s, f = perturbed_chosen[index_picked]
    minimum_change_for_perturbation = get_minimum_change_for_perturbation(perturbed_chosen, index_picked)
    if minimum_change_for_perturbation is None:
        return f"Change did not work on [{i}], Index [{index_picked}]\n\n"
    return (
        f"Analyzing [{i}], Perturbations [{count_list_diff(s, f)}], Index [{index_picked}]\n"
        + process_single(torch.tensor(s))
        + process_single(minimum_change_for_perturbation)
        + "\n"
    )

def store(perturbed_chosen):
    """Run ``analysis`` on every collected perturbation, print it, and save it to ``results_attention/``.

    The report file name is built from the globals ``l``, ``v_alpha``, ``v_epsilon`` and
    ``v_num_iterations``. The file is written only after all analyses finish, so an
    exception part-way no longer leaves a truncated or empty report. The directory is
    created if missing.
    """
    import os

    report_parts = []
    skipped = 0
    for perturbed_element in range(len(perturbed_chosen)):
        r_string = analysis(perturbed_chosen, perturbed_element)
        if "Change did not work on" in r_string:
            skipped += 1
        print(r_string, end="")
        report_parts.append(r_string)
    skipped_str = f"Had to skip {skipped} out of {len(perturbed_chosen)}"
    print(skipped_str)
    report_parts.append(skipped_str)
    os.makedirs("results_attention", exist_ok=True)
    with open(f"results_attention/length={l}, alpha={v_alpha}, epsilon={v_epsilon}, num_iterations={v_num_iterations}.txt", "w") as file:
        file.write("".join(report_parts))

In [345]:
torch.manual_seed(42)
# Show the raw model output tuple for test context #1.
predict_single_from_list(list_picked=context_test, index_picked=1)

(tensor([[[-6.8765, -6.8174, -7.1251, -7.1295, -6.9701, -7.1554, -6.8575,
           -6.8075, -6.5020, -7.2721, -7.0917, -6.9263, -7.0404, -7.0070,
           -6.9872, -7.0126, -7.1001, -6.3215, -6.9414, -6.7841, -6.9150,
           -7.1430, -7.0604, -7.0142, -6.6135, -6.7930, -6.6675, -7.0499,
           -6.8515, -6.1267, -6.7921, -6.7788, -6.8685, -6.9623, -7.1139,
           -6.5269, -6.6200, -6.9810, -6.9283, -7.0452, -6.9250, -6.8681,
           -6.7176, -6.8632, -6.1814, -6.8988, -7.0167, -7.0005, -6.8409,
           -5.3759, -6.8490, -6.8508, -6.9967, -6.8919, -6.9691, -6.9846,
           -6.6954, -5.1534, -6.2868, -6.8769, -6.7292, -7.0326, -6.9670,
           -6.5998, -3.0556, -7.0751, -7.0083, -6.8997, -7.0593, -6.7141,
           -7.1769, -0.1969, -3.8894, -6.8459, -7.2636, -7.4133, -5.9543,
           -5.5674, -6.9376, -5.6622, -5.4744, -7.3986, -6.9219, -6.5993,
           -6.0573, -6.4426, -6.5565, -6.6639, -7.2940, -7.0215]]],
        device='cuda:0', grad_fn=<StackBackw

In [346]:
torch.manual_seed(42)
# Minimal-perturbation / attention analysis for every collected trace (also written to results_attention/).
store(perturbed_chosen=perturbed_collected)

Analyzing [3], Perturbations [2], Index [0]
[86, 87, 86, 87, 66, 66, 42, 42, 42, 72] -> [72] 0.509, [42] 0.220, [78] 0.023
[86] 0.01876 [87] 0.01669 [86] 0.01789 [87] 0.02370 [66] 0.02419 [66] 0.04389 [42] 0.02034 [42] 0.14328 [42] 0.13589 [72] 0.55537 
[86, 87, 86, 87, 66, 66, 42, 42, 42,  1] -> [ 1] 0.233, [15] 0.118, [72] 0.027
[86] 0.00889 [87] 0.00745 [86] 0.01062 [87] 0.01018 [66] 0.01244 [66] 0.01842 [42] 0.00890 [42] 0.05713 [42] 0.04455 [ 1] 0.82143 

Analyzing [5], Perturbations [2], Index [1]
[87, 86, 87, 66, 66, 42, 42, 42, 72, 72] -> [72] 0.497, [42] 0.227, [78] 0.023
[87] 0.03032 [86] 0.02640 [87] 0.02328 [66] 0.03801 [66] 0.03201 [42] 0.07454 [42] 0.02623 [42] 0.20591 [72] 0.14446 [72] 0.39884 
[87, 86, 87, 66, 66, 42, 42, 42, 72,  1] -> [ 1] 0.225, [15] 0.117, [72] 0.038
[87] 0.01204 [86] 0.00933 [87] 0.01156 [66] 0.01284 [66] 0.01294 [42] 0.02377 [42] 0.00893 [42] 0.05477 [72] 0.04291 [ 1] 0.81091 

Analyzing [12], Perturbations [2], Index [2]
[71, 71, 71, 71, 71, 71, 

In [347]:
def interpret():
    """Feed each test context/event pair to ``interpreter.predict`` one sample at a time and print the result.

    NOTE(review): the saved output of this cell ends in ``IndexError: list index out of range``
    raised during this loop. The cause is not visible here; it presumably comes from
    inside ``interpreter.predict``. TODO: investigate before relying on this cell.
    """
    for context_row, event in zip(context_test, events_test):
        X = to_cuda(to_one_hot(context_row).unsqueeze(0))
        y = to_cuda(event.reshape(-1, 1))
        print(interpreter.predict(X=X, y=y, verbose=True))

interpret()


Optimizing query:   0%|          | 0/100 [00:00<?, ?it/s][A
Optimizing query: 100%|██████████| 100/100 [00:00<00:00, 865.78it/s][A

Predicting      : 100%|██████████| 1/1 [00:00<00:00, 621.93it/s]

Optimizing query:   0%|          | 0/100 [00:00<?, ?it/s][A
Optimizing query: 100%|██████████| 100/100 [00:00<00:00, 931.15it/s][A

Predicting      : 100%|██████████| 1/1 [00:00<00:00, 597.14it/s]

Optimizing query:   0%|          | 0/100 [00:00<?, ?it/s][A

[3.]
[3.]



Optimizing query: 100%|██████████| 100/100 [00:00<00:00, 936.21it/s][A

Predicting      : 100%|██████████| 1/1 [00:00<00:00, 770.59it/s]

Optimizing query:   0%|          | 0/100 [00:00<?, ?it/s][A
Optimizing query: 100%|██████████| 100/100 [00:00<00:00, 934.16it/s][A

Predicting      : 100%|██████████| 1/1 [00:00<00:00, 1155.77it/s]

Optimizing query:   0%|          | 0/100 [00:00<?, ?it/s][A

[3.]
[3.]



Optimizing query: 100%|██████████| 100/100 [00:00<00:00, 933.13it/s][A

Predicting      : 100%|██████████| 1/1 [00:00<00:00, 756.68it/s]

Optimizing query:   0%|          | 0/100 [00:00<?, ?it/s][A
Optimizing query: 100%|██████████| 100/100 [00:00<00:00, 928.32it/s][A

Predicting      : 100%|██████████| 1/1 [00:00<00:00, 754.10it/s]

Optimizing query:   0%|          | 0/100 [00:00<?, ?it/s][A

[3.]
[3.]



Optimizing query: 100%|██████████| 100/100 [00:00<00:00, 901.03it/s][A


IndexError: list index out of range