In [358]:
import os
import torch
import itertools
import numpy as np
import pandas as pd
from tqdm import tqdm
from deepcase_copy.context_builder.loss import LabelSmoothing
from deepcase_copy.context_builder.context_builder import ContextBuilder
from deepcase_copy.interpreter.interpreter import Interpreter
from deepcase_copy.interpreter.utils import group_by, sp_unique
from sklearn.model_selection import train_test_split

def disable_dropout(m):
    """Set the drop probability of ``m`` to 0 when it is a ``torch.nn.Dropout``.

    Intended for ``Module.apply``: every ``torch.nn.Dropout`` (or subclass)
    inside a model becomes a pass-through, so repeated forward passes are not
    randomised by dropout. Other modules are left untouched.
    """
    if not isinstance(m, torch.nn.Dropout):
        return
    m.p = 0.0

def to_cuda(item):
    """Return ``item.to('cuda')`` when a CUDA device is available, otherwise ``item`` itself."""
    return item.to('cuda') if torch.cuda.is_available() else item

def get_unique_indices_per_row(tensor):
    """Return the index of the first occurrence of every distinct row.

    Rows are compared by value. ``torch.Tensor`` rows are converted with
    ``tolist()``; rows of other containers (lists, numpy arrays, ...) are
    turned into tuples directly.

    Args:
        tensor: 2-D tensor or sequence of rows.

    Returns:
        list[int]: ascending row indices, keeping only the first row of each
        group of identical rows.
    """
    is_torch = isinstance(tensor, torch.Tensor)
    seen = set()
    first_indices = []
    for row in range(len(tensor)):
        # Tensor elements hash by identity, so convert to Python scalars first;
        # no throw-away tuple of 0-d tensors is built for torch input.
        key = tuple(tensor[row].tolist()) if is_torch else tuple(tensor[row])
        if key in seen:
            continue
        seen.add(key)
        first_indices.append(row)
    return first_indices

def get_unique_indices_per_row_or_file(tensor, f_name):
    """Disk-cached wrapper around ``get_unique_indices_per_row``.

    If ``f_name`` exists, its contents are returned via ``torch.load``. The
    cache is NOT validated against ``tensor``. Otherwise the indices are
    computed, saved to ``f_name`` with ``torch.save`` and returned.
    """
    if not os.path.exists(f_name):
        computed = get_unique_indices_per_row(tensor)
        torch.save(computed, f_name)
        return computed
    return torch.load(f_name)

# --- Attack configuration ---------------------------------------------------
ALPHA = 0.01             # size of each ALPHA * sign(grad) step (bim_attack / bim_no_iter)
EPSILON = 0.5            # max per-element deviation from the original one-hot input (compute_change)
MAX_ITER = 100           # iteration budget of bim_attack; print_state treats it as the timeout marker
PREDICT_THRESHOLD = 0.2  # target probability below which a context counts as perturbed
PERTURB_THRESHOLD = 10   # NOTE(review): not referenced in the visible code

# --- Model, interpreter and loss --------------------------------------------
builder = to_cuda(ContextBuilder.load('save/builder.save'))
# Zero every dropout probability so repeated predictions on one input agree.
builder.apply(disable_dropout)
interpreter = Interpreter.load('save/interpreter.save', builder)
# First argument: decoder output size (presumably the number of event classes);
# 0.1 is presumably the smoothing factor -- confirm in LabelSmoothing.
criterion = LabelSmoothing(builder.decoder_event.out.out_features, 0.1)

# --- Data -------------------------------------------------------------------
# NOTE: torch.load unpickles the file -- only load trusted files.
with open('save/sequences.save', 'rb') as infile:
    data = torch.load(infile)
# The file handle is released here; nothing below needs it.
context = data["context"]  # 172572
events = data["events"]
labels = data["labels"]

# One row per distinct context (cached in save/context.pt), then an 80/20
# split stratified by label with a fixed seed.
indices = get_unique_indices_per_row_or_file(context, 'save/context.pt')
train_indices, test_indices = train_test_split(indices, test_size=0.2, random_state=42, stratify=labels[indices])

context_train = to_cuda(context[train_indices])
context_test = to_cuda(context[test_indices])  # 4392

events_train = to_cuda(events[train_indices])
events_test = to_cuda(events[test_indices])

labels_train = to_cuda(labels[train_indices])
labels_test = to_cuda(labels[test_indices])

def to_one_hot(t):
    """Encode ``t`` with the context builder's ``embedding_one_hot`` layer."""
    one_hot_layer = builder.embedding_one_hot
    return one_hot_layer(t)

def max_to_one(tensor):
    """Hard-quantise ``tensor`` along its last axis.

    Returns a tensor shaped like ``tensor`` (same dtype/device) holding 1.0 at
    the ``torch.argmax`` position of every last-axis slice and 0.0 elsewhere.
    The result is scattered into a fresh ``zeros_like`` tensor, so it carries
    no autograd history from ``tensor``.
    """
    one_hot = torch.zeros_like(tensor)
    winners = tensor.argmax(dim=-1, keepdim=True)
    one_hot.scatter_(-1, winners, 1.0)
    return one_hot

def get_results(results, k=3):
    """Summarise the top-``k`` predictions of a ``builder.predict`` output.

    Args:
        results: output of ``builder.predict``. ``results[0][0][0]`` is read as
            the score vector of the first sample / first step. Scores are
            passed through ``exp()``, so presumably they are log-probabilities
            -- TODO confirm against ContextBuilder.predict.
        k: number of entries to report (default 3). Clamped to the number of
            available scores, so ``torch.topk`` does not fail on short vectors.

    Returns:
        tuple: (tensor of the top-k indices, string such as
        ``"[ 5] 0.812, [12] 0.100, [ 3] 0.050"``).
    """
    scores = results[0][0][0]
    k = min(k, scores.shape[-1])
    top = torch.topk(scores, k)
    probs = top.values.exp()
    top_indices = top.indices
    parts = [
        f"{format_list([top_indices[j].item()])} {'{:.3f}'.format(probs[j])}"
        for j in range(k)
    ]
    return top_indices, ", ".join(parts)

def compute_change(trace, original, epsilon=None):
    """Clip ``trace`` into the box ``[original - epsilon, original + epsilon]`` and quantise it.

    Args:
        trace: perturbed (continuous) input.
        original: reference input the perturbation is bounded around.
        epsilon: half-width of the box. ``None`` (default) reads the global
            ``EPSILON`` at call time, matching the previous behaviour.

    Returns:
        tuple: ``(max_to_one(clipped), clipped)``. The first element is the
        one-hot arg-max of the clipped tensor; the second is the clipped
        tensor itself.
    """
    if epsilon is None:
        epsilon = EPSILON
    lower = original - epsilon
    upper = original + epsilon
    # Masked arithmetic is kept as in the original implementation: it computes
    # max(trace, lower) and then min(., upper). The resulting autograd node
    # (shown when the collected tensors are printed) is therefore unchanged.
    raised = (trace >= lower).float() * trace + (trace < lower).float() * lower
    clipped = (raised > upper).float() * upper + (raised <= upper).float() * raised
    return max_to_one(clipped), clipped

def get_performance(results, num):
    """Return the model's probability for event ``num`` given context ``results``.

    ``results`` goes through ``to_one_hot`` and ``builder.predict``. The score
    vector ``output[0][0][0]`` (first sample, first step) is exponentiated, so
    presumably it holds log-probabilities -- TODO confirm.

    Args:
        results: context to score (anything ``to_one_hot`` accepts).
        num: event id, either a Python int or a single-element tensor.

    Returns:
        0-d tensor ``exp(score[num])``, or the int ``0`` when ``num`` is not a
        valid index into the score vector.
    """
    output = builder.predict(to_one_hot(results))
    scores = output[0][0][0]
    index = int(num)
    # Direct lookup replaces a full topk sort plus a Python loop that called
    # .item() (a device synchronisation) for every class.
    if 0 <= index < len(scores):
        return scores[index].exp()
    return 0

def bim_attack(context_given, target_given):
    """Run a Basic-Iterative-Method style attack on one context window.

    Each iteration records the current arg-max context and its top-3
    prediction string, then stops if the model's probability for
    ``target_given[0]`` is below ``PREDICT_THRESHOLD``. Otherwise it adds
    another ``ALPHA * sign(grad)`` step to a running total, applies the total
    to the one-hot input, clips the result to within ``EPSILON`` of the
    original one-hot encoding (``compute_change``) and re-quantises to one-hot.

    Args:
        context_given: context indices for one sample. Callers pass a (1, L)
            tensor -- presumably long event ids, TODO confirm dtype.
        target_given: target tensor given to ``criterion``; ``target_given[0]``
            is the event whose probability the attack pushes down.

    Returns:
        tuple(list[dict], list[Tensor]):
            * ``changes``: one dict per iteration started, with keys
              ``"changed_to"`` (arg-max context as a list) and
              ``"prediction_str"`` (from ``get_results``). Entry 0 is the
              unperturbed context.
            * ``computed_change_collected``: the clipped, pre-quantisation
              tensor of every completed step. It has one entry fewer than
              ``changes`` when the loop exits through ``break``.
    """
    # Running sum of every ALPHA * sign(grad) step taken so far.
    change = None
    # One-hot encoding of the unperturbed context; the centre of the clipping box.
    original_context = to_cuda(to_one_hot(context_given).clone().detach())
    # Working copy that is perturbed and re-quantised each iteration.
    context_processed = to_cuda(to_one_hot(context_given).clone().detach())
    changes = []
    computed_change_collected = []
    for i in range(MAX_ITER):
        context_processed.requires_grad_(True)
        output = builder.predict(context_processed)
        # Record the state *before* this iteration's step.
        changes.append({
            "changed_to": torch.argmax(context_processed, dim=-1).tolist()[0],
            "prediction_str": get_results(output)[1],
        })
        # Success: the target's probability fell below the threshold.
        # NOTE(review): context_processed is already one-hot here, yet
        # get_performance applies to_one_hot again and re-runs the model. This
        # relies on builder.embedding_one_hot accepting one-hot float input --
        # confirm. `output` above could presumably be reused instead.
        if get_performance(context_processed, target_given[0]) < PREDICT_THRESHOLD:
            break
        loss = criterion(output[0][0], target_given)
        # context_processed should be a leaf here (a detached clone, or the
        # output of max_to_one), so retain_grad() is effectively a no-op.
        context_processed.retain_grad()
        # NOTE(review): backward() presumably also accumulates into the
        # builder's parameter .grad on every call (never zeroed here); only
        # the input gradient is used below.
        loss.backward(retain_graph=True)
        grad = context_processed.grad.sign()
        if change is None:
            change = ALPHA * grad
        else:
            change += ALPHA * grad
        # The running total is added to the already re-quantised context, not
        # to original_context. compute_change then clips it to within EPSILON
        # of the original and snaps it back to one-hot.
        context_processed = context_processed + change
        context_processed, computed_change = compute_change(context_processed, original_context)
        computed_change_collected.append(computed_change)
    return changes, computed_change_collected

def count_changes(changes_needed):
    """Count positions where the last recorded context differs from the first.

    ``changes_needed`` is the ``changes`` list returned by ``bim_attack``.
    """
    first_context = changes_needed[0]["changed_to"]
    last_context = changes_needed[-1]["changed_to"]
    return count_list_diff(first_context, last_context)

def count_list_diff(orig, final):
    """Count the positions where ``orig`` and ``final`` differ.

    Elements are paired with ``zip``, so comparison stops at the shorter
    input. Always returns a Python ``int``, also for tensor inputs.
    """
    return sum(1 for before, after in zip(orig, final) if before != after)
    
def show_changes(changes_needed):
    """Render which positions changed between the first and last recorded context.

    Returns:
        tuple(str, str): ``(changed, kept)``, both formatted with
        ``format_list``. ``changed`` holds the new value at changed positions
        and ``"-"`` elsewhere. ``kept`` holds the value at unchanged positions
        and ``"XX"`` at changed ones.
    """
    first_context = changes_needed[0]["changed_to"]
    last_context = changes_needed[-1]["changed_to"]
    pairs = list(zip(first_context, last_context))
    changed = [after if before != after else "-" for before, after in pairs]
    kept = ["XX" if before != after else after for before, after in pairs]
    return format_list(changed), format_list(kept)

def format_list(li):
    """Format ``li`` as ``"[a, b, ...]"`` with every item formatted to width 2.

    Uses the ``"2"`` format spec, so numbers are right-aligned and strings
    left-aligned.
    """
    cells = [f"{num:2}" for num in li]
    return "[" + ", ".join(cells) + "]"
            
def format_line(current_trace_num, changes_needed, i):
    """Format row ``i`` of an attack trace for the text report.

    The prefix column is ``len(str(current_trace_num)) + 2`` characters wide.
    Row 0 gets blank padding. Later rows show the step number, followed
    directly by a ``'<'`` marker, inside the same column.

    Args:
        current_trace_num: index of the trace; only its digit count is used.
        changes_needed: ``changes`` list from ``bim_attack``.
        i: row to format.

    Returns:
        str: ``"<prefix><context> -> <prediction string>\\n"``.
    """
    width = len(str(current_trace_num)) + 2
    if i == 0:
        start = " " * width
    else:
        label = str(i)
        padded = f"{i:<{width}}"
        # Slicing (rather than list assignment) appends the marker when the
        # step number fills the whole column; the old code raised IndexError.
        start = padded[:len(label)] + "<" + padded[len(label) + 1:]
    step = changes_needed[i]
    return f"{start}{format_list(step['changed_to'])} -> {step['prediction_str']}\n"
            
def print_state(changes_needed, current_trace_num, con, event_chosen, change_collected, print_path=True, include_change=True):
    """Classify one ``bim_attack`` run and build its text report.

    Args:
        changes_needed: ``changes`` list from ``bim_attack``.
        current_trace_num: trace index, used as the report label.
        con: the (1, L) context that was attacked.
        event_chosen: target tensor (printed with ``tolist()``).
        change_collected: clipped tensors from ``bim_attack``. There is one per
            completed step, so the list is shorter than ``changes_needed``
            exactly when the attack stopped early.
        print_path: print every step instead of only the steps where the
            context changed.
        include_change: with ``print_path``, also print each clipped tensor.

    Returns:
        tuple(int, str): ``(mode_int, result_string)`` where ``mode_int`` is
            0 -- a single entry: the target was already below the threshold
                 before any step;
            2 -- timeout: ``MAX_ITER`` entries and the loop never stopped early;
            1 -- perturbed.
        ``result_string`` is empty for modes 0 and 2.
    """
    changed_num = len(changes_needed)
    perturbations_num = count_changes(changes_needed)
    if changed_num == 1:
        return 0, ""
    # A run that succeeded on its very last check also has MAX_ITER entries,
    # but then one fewer collected change; only a full run is a timeout.
    if changed_num == MAX_ITER and len(change_collected) >= changed_num:
        return 2, ""

    pad = " " * (len(str(current_trace_num)) + 2)
    short_pad = " " * (len(str(current_trace_num)) - 1)
    last = len(changes_needed) - 1
    lines = [
        f"{current_trace_num}: {format_list(con[0].tolist())} == {event_chosen.tolist()} "
        f"Changed {{{changed_num - 2}}}, Perturbations {{{perturbations_num}}}\n"
    ]
    if print_path:
        for i in range(len(changes_needed)):
            lines.append(format_line(current_trace_num, changes_needed, i))
            # `!=` replaces the old identity test (`is not`) on ints.
            if include_change and i != last:
                lines.append(f"{pad} {change_collected[i]=}\n")
    else:
        for i in range(len(changes_needed)):
            if i == 0 or (changes_needed[i]["changed_to"] != changes_needed[i - 1]["changed_to"] and i != last):
                lines.append(format_line(current_trace_num, changes_needed, i))
        change_last = changes_needed[-1]
        lines.append(f"{pad}{format_list(change_last['changed_to'])} -> {change_last['prediction_str']}\n")
    changed_entries, same_entries = show_changes(changes_needed)
    lines.append(f"{short_pad}== {same_entries}\n")
    lines.append(f"{short_pad}-> {changed_entries}\n")
    lines.append("\n")
    return 1, "".join(lines)

def process_traces(context_to_process, events_to_process, print_path=False, include_change=False, write_to_file=False, print_result=False, use_cache=True, save_pt=True):
    """Run ``bim_attack`` on every context row and collect the successful perturbations.

    Args:
        context_to_process: 2-D tensor of contexts, one row per trace.
        events_to_process: 1-D tensor of target events aligned with the rows.
        print_path / include_change: report verbosity (see ``print_state``).
        write_to_file: write the full text report to ``results/trace/``.
        print_result: print each trace's report (disables the progress bar).
        use_cache: return cached tensors from ``perturbed/`` when both exist.
            NOTE(review): the cache key only holds the row count and
            hyper-parameters, not PREDICT_THRESHOLD or the data itself.
        save_pt: save the resulting tensors to ``perturbed/``.

    Returns:
        tuple(Tensor, Tensor): perturbed contexts (final ``changed_to`` of each
        successful run) and the row indices they came from.
    """
    length = len(context_to_process)
    run_id = (f"length={length}, alpha={ALPHA}, epsilon={EPSILON}, num_iterations={MAX_ITER}, "
              f"print_path={print_path} include_change={include_change}")
    f_name1 = f"perturbed/collected/{run_id}.pt"
    f_name2 = f"perturbed/indices_collected/{run_id}.pt"
    report_name = f"results/trace/{run_id}.txt"
    if use_cache and os.path.exists(f_name1) and os.path.exists(f_name2):
        print(f"Loading {f_name1}")
        print(f"Loading {f_name2}")
        # The report is only written with write_to_file=True, so it may be
        # missing (or empty) even though the tensors are cached.
        if os.path.exists(report_name):
            with open(report_name, 'r') as file:
                lines = file.readlines()
            if lines:
                print(lines[-1])
        return torch.load(f_name1), torch.load(f_name2)

    perturbed_collected_main = []
    perturbed_indices = []
    states = [0, 0, 0]  # incorrect / perturbed / timeout, indexed by print_state's mode
    report = []
    targets = events_to_process.unsqueeze(1)  # each row -> 1-element target tensor
    iters = range(length)
    if not print_result:
        iters = tqdm(iters)
    for current_trace_num in iters:
        # reshape instead of an in-place resize_ on a view of the input.
        con = context_to_process[current_trace_num].reshape(1, -1)
        e = targets[current_trace_num]
        changes_needed, change_collected = bim_attack(context_given=con, target_given=e)
        mode_int, result_string = print_state(changes_needed, current_trace_num, con, e, change_collected, print_path=print_path, include_change=include_change)
        if print_result:
            print(result_string, end="")
        report.append(result_string)
        if mode_int == 1:
            perturbed_indices.append(current_trace_num)
            perturbed_collected_main.append(changes_needed[-1]['changed_to'])
        states[mode_int] += 1
    summary = f"incorrect={states[0]} perturbed={states[1]} timeout={states[2]}"
    print(summary)
    report.append(summary)
    if write_to_file:
        os.makedirs(os.path.dirname(report_name), exist_ok=True)
        with open(report_name, "w") as f:
            f.write("".join(report))
    perturbed_collected_main = to_cuda(torch.tensor(perturbed_collected_main))
    perturbed_indices = to_cuda(torch.tensor(perturbed_indices))
    if save_pt:
        for tensor, path in ((perturbed_collected_main, f_name1), (perturbed_indices, f_name2)):
            os.makedirs(os.path.dirname(path), exist_ok=True)
            torch.save(tensor, path)
    return perturbed_collected_main, perturbed_indices

def inspect_index(index_inspected, context_l, events_l):
    """Attack one sample verbosely and write its full trace report.

    Slices row ``index_inspected`` out of ``context_l`` / ``events_l`` (keeping
    the batch axis) and runs ``process_traces`` on it. The path is printed and
    the report written to file. Returns ``None``.
    """
    window = slice(index_inspected, index_inspected + 1)
    single_context = to_cuda(context_l[window].clone().detach())
    single_event = to_cuda(events_l[window].clone().detach())
    process_traces(
        single_context,
        single_event,
        print_path=True,
        include_change=False,
        write_to_file=True,
        print_result=True,
    )
    
def get_changes_list(s, f):
    """List ``(position, new_value)`` for every position where ``f`` differs from ``s``.

    ``f`` must be a tensor, since its elements are read with ``.item()``.
    Positions run over ``range(len(s))``.
    """
    return [(pos, f[pos].item()) for pos in range(len(s)) if s[pos] != f[pos]]

def get_possible_combinations(perturbations_made):
    """Return every non-empty proper subset of ``perturbations_made`` as a list.

    Subsets are ordered by size, then in ``itertools.combinations`` order.
    The full set is excluded (sizes 1 .. n-1), so inputs with 0 or 1 items
    give ``[]``. For n >= 1 items the result has 2**n - 2 entries.
    """
    n = len(perturbations_made)
    return [
        list(subset)
        for size in range(1, n)
        for subset in itertools.combinations(perturbations_made, size)
    ]

def get_minimum_change_for_perturbation(start, final, events_picked):
    """Find the smallest proper subset of the perturbation that still flips the prediction.

    Subsets of the changes from ``start`` to ``final`` are tried smallest
    first. The first one that drops the target's probability below
    ``PREDICT_THRESHOLD`` wins.

    Args:
        start: original context row.
        final: perturbed context row.
        events_picked: target event passed to ``get_performance``.

    Returns:
        tuple: ``(context, message)``. On success ``context`` is a (1, L) copy
        of ``start`` with only the chosen changes applied, and ``message`` is
        ``"Shortcut: <all changes> -> <chosen subset>\\n"``. Otherwise the
        result is ``(final, "")``, with ``final`` returned unchanged.
    """
    all_changes = get_changes_list(start, final)
    for subset in get_possible_combinations(all_changes):
        candidate = to_cuda(start.clone().detach())
        for position, value in subset:
            candidate[position] = value
        candidate.resize_(1, candidate.size()[-1])
        if get_performance(candidate, events_picked) < PREDICT_THRESHOLD:
            return candidate, f"Shortcut: {all_changes} -> {subset}\n"
    return final, ""

def process_single(context_chosen):
    """Describe the model's prediction and attention for one context.

    Returns two lines of text. The first is the formatted context followed by
    the ``get_results`` top-3 summary. The second holds one ``"[event]
    attention"`` pair per position, with attention read from
    ``output[1][0][0]`` and rounded to 5 places.
    """
    working = to_cuda(context_chosen.clone().detach())
    # predict_single resizes `working` in place to (1, L), so working[0] is the row.
    output = predict_single(working)
    weights = [round(w, 5) for w in output[1][0][0].tolist()]
    header = f"{format_list(working[0].tolist())} -> {get_results(output)[1]}\n"
    pairs = "".join(f"[{c:2}] {'{:.5f}'.format(a)} " for c, a in zip(working[0], weights))
    return header + pairs + "\n"
    
def predict_single(context_chosen):
    """Run ``builder.predict`` on the one-hot encoding of a single context.

    NOTE: reshapes ``context_chosen`` *in place* to ``(1, L)`` with
    ``resize_``, where L is its last dimension. ``process_single`` relies on
    this mutation.
    """
    seq_len = context_chosen.size()[-1]
    context_chosen.resize_(1, seq_len)
    return builder.predict(to_one_hot(context_chosen))

def predict_single_from_list(list_picked, index_picked):
    """Run ``predict_single`` on row ``index_picked`` of ``list_picked``, keeping the batch axis."""
    row = list_picked[index_picked:index_picked + 1].detach()
    return predict_single(to_cuda(row))

def analysis(perturbed_chosen, context_chosen, events_chosen, index_chosen, index_inplace):
    """Compare an original context, its BIM perturbation and the minimal shortcut.

    Args:
        perturbed_chosen: perturbed context row (``final``).
        context_chosen: original context row (``start``).
        events_chosen: target event given to ``get_minimum_change_for_perturbation``.
        index_chosen: original row index (reported as ``Analyzing [...]``).
        index_inplace: position within the perturbed batch (reported as ``Index [...]``).

    Returns:
        tuple(str, Tensor): the report text and ``pick``. ``pick`` is the
        minimal-perturbation row when a shortcut was found, otherwise a copy
        of ``perturbed_chosen``.
    """
    result_string = ""
    start = context_chosen
    final = perturbed_chosen
    minimum_change_for_perturbation, final_combination = get_minimum_change_for_perturbation(start, final, events_chosen)
    pick = final.clone().detach()
    # NOTE(review): get_minimum_change_for_perturbation returns either a copy
    # of `start` or `final`, so this is only None if `final` is None; as used,
    # the else branch is unreachable and find_shortcuts reports skipped=0.
    if minimum_change_for_perturbation is not None:
        result_string += f"Analyzing [{index_chosen}], Perturbations [{count_list_diff(start, final)}], Index [{index_inplace}], Changes [{len(get_changes_list(start, final))}]"
        # Non-empty message means a strictly smaller subset of changes worked.
        if len(final_combination) != 0:
            # The shortcut is (1, L); keep just the row.
            pick = minimum_change_for_perturbation.clone().detach()[0]
            # Number of (position, value) tuples in the chosen subset.
            result_string += f", Needed [{final_combination.split("->")[1].count("(")}]"
        result_string += f"\n{process_single(start)}{"-"*130}\n"
        if len(final_combination) != 0:
            result_string += process_single(final)
            result_string += "-"*130 + "\n"
        result_string += process_single(minimum_change_for_perturbation)
        result_string += final_combination
    else:
        result_string += f"Change did not work on [{index_chosen}], Index [{index_inplace}]\n"
    result_string += "\n\n"
    return result_string, pick

def find_shortcuts(perturbed_chosen, indices_chosen, context_chosen, events_chosen, print_details=False, use_cache=True, save_file=True):
    """Minimise every perturbation and write an attention report.

    For each perturbed row, ``analysis`` looks for the smallest subset of its
    changes that still flips the prediction. The picked rows are stacked and
    returned; the report goes to ``results/attention/``.

    Args:
        perturbed_chosen: perturbed contexts (one row each).
        indices_chosen: row indices of those contexts in ``context_chosen`` /
            ``events_chosen``.
        context_chosen, events_chosen: full context / event tensors.
        print_details: print each row's report (disables the progress bar).
        use_cache: return the cached ``perturbed/minimized/...pt`` if present.
            NOTE(review): the key only holds the row count and hyper-parameters.
        save_file: save the stacked result to ``perturbed/minimized/``.

    Returns:
        Tensor: one picked row per perturbed input; ``(0, L)`` when there are none.
    """
    run_id = f"length={len(perturbed_chosen)}, alpha={ALPHA}, epsilon={EPSILON}, num_iterations={MAX_ITER}"
    f_name = f"perturbed/minimized/{run_id}.pt"
    report_name = f"results/attention/{run_id}.txt"
    if use_cache and os.path.exists(f_name):
        print(f"Loading {f_name}")
        print(report_name)
        if os.path.exists(report_name):
            with open(report_name, 'r') as file:
                lines = file.readlines()
            if lines:
                print(lines[-1])
        return torch.load(f_name)
    pick_list = []
    report = []
    skipped = 0
    shortcuts = 0
    iters = enumerate(zip(perturbed_chosen, context_chosen[indices_chosen], events_chosen[indices_chosen], indices_chosen))
    if not print_details:
        iters = tqdm(iters, total=len(perturbed_chosen))
    for i, perturbed_element in iters:
        r_string, pick = analysis(*perturbed_element, i)
        pick_list.append(pick)
        if "Change did not work on" in r_string:
            skipped += 1
        if "Shortcut" in r_string:
            shortcuts += 1
        if print_details:
            print(r_string, end="")
        report.append(r_string)
    skipped_str = f"same={len(perturbed_chosen) - skipped - shortcuts} {shortcuts=} {skipped=}"
    print(skipped_str)
    report.append(skipped_str)
    # Written only after all work succeeded, so a crash no longer leaves a
    # truncated report behind (the file used to be opened "w" up front).
    os.makedirs(os.path.dirname(report_name), exist_ok=True)
    with open(report_name, "w") as file:
        file.write("".join(report))
    if pick_list:
        stacked = torch.stack(pick_list)
    else:
        # torch.stack([]) raises; return an empty (0, L) batch instead.
        stacked = torch.empty((0, context_chosen.shape[-1]), dtype=context_chosen.dtype, device=context_chosen.device)
    pick_list = to_cuda(stacked)
    if save_file:
        os.makedirs(os.path.dirname(f_name), exist_ok=True)
        torch.save(pick_list, f_name)
    return pick_list

def sort_dist(trace_dist):
    """Return the positions of ``trace_dist`` grouped by identical index sets.

    Each element (a list of attended positions, or a tensor) becomes a tuple
    key. Positions with equal keys are grouped, keeping input order within a
    group. Groups are ordered by key length, then lexicographically.

    Returns:
        list[int]: a permutation of ``range(len(trace_dist))``.
    """
    def _key(item):
        # Tensors are converted to Python scalars so equal values hash equally
        # (a tuple of 0-d tensors hashes by identity); a 0-d tensor gives a
        # 1-tuple. Empty inputs give () -- the old `t.item()` fallback raised
        # AttributeError on an empty list.
        if isinstance(item, torch.Tensor):
            return tuple(item.reshape(-1).tolist())
        return tuple(item)

    groups = {}
    for position, item in enumerate(trace_dist):
        groups.setdefault(_key(item), []).append(position)
    ordered = sorted(groups.items(), key=lambda kv: (len(kv[0]), kv[0]))
    return [position for _, members in ordered for position in members]
    
def print_attention(x_mask, context_mask, vectors, neighbours, distance, scores):
    """Format the attention vectors of one event group, grouped by attended positions.

    Args:
        x_mask: contexts the interpreter kept, indexed by ``context_mask``.
        context_mask: indices selecting this event's rows from ``x_mask`` / ``vectors``.
        vectors: attention vectors whose rows support ``.toarray()``
            (presumably a scipy sparse matrix -- confirm).
        neighbours, distance, scores: per-row nearest neighbour, distance row
            (``distance[i][0]`` is printed) and score, aligned with ``context_mask``.

    Returns:
        str: one line per row. Each line has the context, its non-zero
        attention entries sorted by weight, padding to column 150, then
        ``"{neighbour; score | distance}"``.
    """
    # Select this group's rows once; re-indexing inside the loop made it quadratic.
    selected_vectors = vectors[context_mask]
    selected_traces = x_mask[context_mask]
    data_collected = []
    for i in range(len(context_mask)):
        l_value = torch.tensor(selected_vectors[i].toarray()[0])
        l_indices = torch.nonzero(l_value, as_tuple=False).squeeze(1)
        data_collected.append({
            "trace": selected_traces[i],
            "indices": l_indices.tolist(),
            "value": l_value[l_indices].tolist(),
            "neighbour": f"{{{neighbours[i]:5}; {scores[i]} | {'{:.4f}'.format(distance[i][0])}}}",
        })
    res_str = ""
    for s in sort_dist([entry["indices"] for entry in data_collected]):
        local_l = data_collected[s]
        local_list = list(zip(local_l["indices"], local_l["value"]))
        res = f"{format_list(local_l['trace'].tolist())} -> "
        for index, value in sorted(local_list, key=lambda x: x[1], reverse=True):
            res += f"[{index:2}]: {'{:.4f}'.format(value)} "
        res_str += res + " " * (150 - len(res)) + f" {local_l['neighbour']}\n"
    return res_str

def show_clusters(x, y):
    """Write a per-event report of attention vectors and their nearest clusters.

    For every event in ``y`` kept by ``interpreter.attended_context``, queries
    that event's ``interpreter.tree`` for the nearest neighbour of each
    (deduplicated) attention vector and writes ``print_attention``'s output to
    ``results/cluster/``. Events without a tree get a ``<<<[event]>>>`` marker.
    """
    vectors, mask = interpreter.attended_context(to_one_hot(x), y.reshape(-1, 1))
    events_np = y[mask].reshape(-1, 1).cpu().numpy()
    indices_y = group_by(events_np, lambda e: e.data.tobytes(),)

    report_name = f"results/cluster/length={len(x)}, alpha={ALPHA}, epsilon={EPSILON}, num_iterations={MAX_ITER}.txt"
    os.makedirs(os.path.dirname(report_name), exist_ok=True)
    with open(report_name, "w") as file:
        for event, context_mask in indices_y:
            # The group key is the raw bytes of one event id; decode it with the
            # array's own dtype. The old ord(bytes.decode('ascii')[0]) read only
            # the first byte and failed for ids >= 128 (aliasing ids >= 256).
            event = int(np.frombuffer(event, dtype=events_np.dtype)[0])
            if event not in interpreter.tree:
                file.write("<" * 50 + f"[{event}]" + ">" * 50 + "\n")
                continue
            file.write("=" * 50 + f"[{event}]" + "=" * 50 + "\n")
            vectors_, inverse, _ = sp_unique(vectors[context_mask])
            distance, neighbours = interpreter.tree[event].query(
                X               = vectors_.toarray(),
                return_distance = True,
                dualtree        = vectors_.shape[0] >= 1e3, # Optimization
            )
            neighbours = interpreter.tree[event].get_arrays()[1][neighbours][:, 0]
            scores = np.asarray([interpreter.labels[event][neighbour] for neighbour in neighbours])
            file.write(print_attention(x[mask], context_mask, vectors, neighbours[inverse], distance[inverse], scores[inverse]) + "\n")

def interpret(context_passed, events_passed, attention_query=True, threshold=0.2):
    """Classify contexts with the DeepCase interpreter.

    Side effect: sets ``interpreter.threshold = threshold``, which persists
    after the call. With ``attention_query=False`` the call passes
    ``iterations=0``; presumably that disables the attention-query step --
    TODO confirm in Interpreter.predict. No re-clustering or scoring
    (``interpreter.cluster`` / ``score_clusters`` / ``score``) happens here.

    Returns whatever ``interpreter.predict`` returns.
    """
    one_hot_context = to_one_hot(context_passed)
    event_column = events_passed.reshape(-1, 1)
    interpreter.threshold = threshold
    if attention_query:
        return interpreter.predict(X=one_hot_context, y=event_column)
    return interpreter.predict(X=one_hot_context, y=event_column, iterations=0)

def get_iter_count(context_given, target_given):
    """Return the first step multiplier that changes the arg-max context of some trace.

    Traces are scanned in order. For each, multipliers ``0 .. MAX_ITER-1`` are
    tried with ``bim_no_iter``, and the first one that changes any position of
    that trace is returned at once. Later traces are only examined if earlier
    ones never change. Returns ``-1`` if no trace changes.
    """
    targets = target_given.unsqueeze(1)
    for trace_num in range(len(context_given)):
        row = context_given[trace_num]
        e = targets[trace_num]
        row.resize_(1, row.size()[-1])
        for multiplier in range(MAX_ITER):
            candidate = bim_no_iter(row, e, multiplier)
            if row[0].tolist() != candidate.tolist():
                return multiplier
    return -1

def bim_no_iter(context_given, target_given, iters):
    """Single-gradient attack variant: one signed gradient, scaled by ``iters`` steps.

    Computes the signed input gradient of the loss once, on the unperturbed
    context. It then moves the one-hot input by ``(ALPHA * sign(grad)) * iters``,
    clips to within ``EPSILON`` of the original one-hot encoding
    (``compute_change``) and takes the arg-max event per position.

    Args:
        context_given: (1, L) context as prepared by the callers.
        target_given: target tensor passed to ``criterion``.
        iters: number of ``ALPHA`` steps. 0 gives the arg-max of the original
            encoding; a negative value moves against the attack direction.

    Returns:
        1-D long tensor (length L) with the perturbed event ids of the first
        sample, on CUDA when available.
    """
    original_context = to_cuda(to_one_hot(context_given).clone().detach())
    # A detached clone is a leaf, so .grad is populated without retain_grad().
    context_processed = original_context.clone().requires_grad_(True)
    output = builder.predict(context_processed)
    loss = criterion(output[0][0], target_given)
    # The graph is not used again, so there is no need to retain it.
    loss.backward()
    grad = context_processed.grad.sign()
    change = (ALPHA * grad) * iters
    context_processed = context_processed + change
    context_processed, _ = compute_change(context_processed, original_context)
    # Stay on the device (the old tolist() round trip went GPU -> CPU -> GPU);
    # clone so callers may resize_ the result freely.
    return to_cuda(torch.argmax(context_processed, dim=-1)[0].clone())

def run_bim(context_to_process, events_to_process):
    """Attack every trace with one shared step count found by ``get_iter_count``.

    Traces whose target probability is already below ``PREDICT_THRESHOLD`` are
    counted as "incorrect". The rest are perturbed with ``bim_no_iter`` and
    counted as "perturbed" when the target probability then drops below the
    threshold, otherwise as "timeout".

    Returns:
        tuple(Tensor, Tensor): perturbed contexts and their row indices.
    """
    perturbed_collected_main = []
    perturbed_indices = []
    states = [0, 0, 0]  # incorrect / perturbed / timeout
    length = len(context_to_process)
    iters = tqdm(range(length))
    iters_to_flip = get_iter_count(context_to_process, events_to_process)
    if iters_to_flip < 0:
        # get_iter_count returns -1 when no multiplier below MAX_ITER changes
        # any trace. Passing -1 on would step against the loss gradient,
        # i.e. *raise* the target's probability. Use the full budget instead.
        iters_to_flip = MAX_ITER
    targets = events_to_process.unsqueeze(1)
    for current_trace_num in iters:
        # reshape instead of an in-place resize_ on a view of the input.
        con = context_to_process[current_trace_num].reshape(1, -1)
        e = targets[current_trace_num]
        if get_performance(con, e[0]) < PREDICT_THRESHOLD:
            states[0] += 1
            continue
        perturbed = bim_no_iter(context_given=con, target_given=e, iters=iters_to_flip).reshape(1, -1)
        if get_performance(perturbed, e[0]) < PREDICT_THRESHOLD:
            perturbed_collected_main.append(perturbed.tolist()[0])
            perturbed_indices.append(current_trace_num)
            states[1] += 1
        else:
            states[2] += 1
    print(f"incorrect={states[0]} perturbed={states[1]} timeout={states[2]}")
    return to_cuda(torch.tensor(perturbed_collected_main)), to_cuda(torch.tensor(perturbed_indices))

In [None]:
# ind = [0,1,2,3,4]
# perturbed_minimized = find_shortcuts(
#     perturbed_collected[ind], 
#     perturbed_indices_collected[ind],
#     context_test, 
#     events_test,
#     # print_details=True,
#     use_cache=False,
#     save_file=False
# )
# res_perturbed_shortcuts = interpret(perturbed_minimized, events_test[perturbed_indices_collected[ind]], attention_query=False)

In [None]:
# r = [4]
# e = events_test[perturbed_indices_collected[r]]
# 
# p2 = find_shortcuts(
#     perturbed_collected[r], 
#     perturbed_indices_collected[r],
#     context_test, 
#     events_test,
#     use_cache=False,
#     save_file=False
# )
# 
# print(f"original               {context_test[perturbed_indices_collected][r]}")
# print(f"{perturbed_collected[r]=}")
# print(f"{perturbed_minimized[r]=}")
# print(f"                    {p2=}")
# 
# print("l", get_performance(perturbed_minimized[r], e), interpret(perturbed_minimized[r], e, attention_query=False))
# print("p2", get_performance(p2, e), interpret(p2, e, attention_query=False))

In [350]:
# Label distribution of the test split.
pd.Series(labels_test.cpu()).value_counts().sort_index()

1      65
2     454
3    3864
5       9
Name: count, dtype: int64

In [None]:
# Baseline: interpreter predictions on the unperturbed test contexts (no attention query).
res_normal = interpret(context_test, events_test, attention_query=False)
pd.Series(res_normal).value_counts().sort_index()

In [None]:
# Attack every test context with bim_attack; results are cached under perturbed/.
perturbed_collected, perturbed_indices_collected = process_traces(
    context_test, 
    events_test
)  

In [None]:
# Interpreter predictions on the successfully perturbed contexts only.
res_perturbed_only = interpret(perturbed_collected, events_test[perturbed_indices_collected], attention_query=False)
pd.Series(res_perturbed_only).value_counts().sort_index()

In [None]:
# Replace the perturbed rows inside a copy of the test set, drop duplicate
# rows, and re-run the interpreter on the result.
context_test_original_perturbed = context_test.clone().detach()
context_test_original_perturbed[perturbed_indices_collected] = perturbed_collected
c_indices_perturbed = get_unique_indices_per_row(context_test_original_perturbed)
res_combined_perturbed = interpret(context_test_original_perturbed[c_indices_perturbed], events_test[c_indices_perturbed], attention_query=False)
pd.Series(res_combined_perturbed).value_counts().sort_index()

In [359]:
# Reduce each perturbation to the smallest subset of changes that still flips
# the prediction (cached under perturbed/minimized/).
perturbed_minimized = find_shortcuts(
    perturbed_collected, 
    perturbed_indices_collected,
    context_test, 
    events_test
)

100%|██████████| 2983/2983 [00:48<00:00, 61.17it/s]


same=54 shortcuts=2929 skipped=0


In [360]:
# Interpreter predictions on the minimised perturbations only.
res_perturbed_shortcuts = interpret(perturbed_minimized, events_test[perturbed_indices_collected], attention_query=False)
pd.Series(res_perturbed_shortcuts).value_counts().sort_index()

-1.0    2983
Name: count, dtype: int64

In [361]:
# As above, but splicing the minimised perturbations into the test set.
context_test_original_perturbed_shortcuts = context_test.clone().detach()
context_test_original_perturbed_shortcuts[perturbed_indices_collected] = perturbed_minimized
c_indices_perturbed_shortcuts = get_unique_indices_per_row(context_test_original_perturbed_shortcuts)
res_combined_perturbed_shortcuts = interpret(context_test_original_perturbed_shortcuts[c_indices_perturbed_shortcuts], events_test[c_indices_perturbed_shortcuts], attention_query=False)
pd.Series(res_combined_perturbed_shortcuts).value_counts().sort_index()

-3.0     204
-1.0    4032
 2.0      26
 3.0      87
 5.0      14
Name: count, dtype: int64

In [13]:
# Write the per-event cluster / attention report for the test set to results/cluster/.
show_clusters(context_test, events_test)