In [None]:
import os
import torch
import random
import itertools
import numpy as np
import pandas as pd
from tqdm import tqdm
from deepcase_copy.context_builder.loss import LabelSmoothing
from deepcase_copy.context_builder.context_builder import ContextBuilder
from deepcase_copy.interpreter.interpreter import Interpreter
from deepcase_copy.interpreter.utils import group_by, sp_unique
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

from deepcase.context_builder.context_builder import ContextBuilder as BaseContextBuilder
from deepcase.interpreter.interpreter import Interpreter as BaseInterpreter

def disable_dropout(m):
    """Set the drop probability of a ``torch.nn.Dropout`` module to zero.

    Intended to be passed to ``nn.Module.apply`` so that every ``Dropout``
    layer in a model stops dropping activations. Any other module is left
    untouched.
    """
    if not isinstance(m, torch.nn.Dropout):
        return
    m.p = 0.0

def to_cuda(item):
    """Move a tensor or module to the ``'cuda'`` device when CUDA is available.

    On CPU-only machines ``item`` is returned unchanged, so callers can use the
    result uniformly regardless of hardware.
    """
    return item.to('cuda') if torch.cuda.is_available() else item

def to_cuda_tensor(item):
    """Build a tensor from ``item`` (list, scalar, ...) and move it to the GPU if one is available."""
    as_tensor = torch.tensor(item)
    return to_cuda(as_tensor)

def get_unique_indices_per_row(tensor):
    """Return the index of the first occurrence of every distinct row.

    Rows are compared by value, as tuples. Later duplicates are skipped, so
    the returned indices are in ascending order.

    Parameters
    ----------
    tensor : torch.Tensor or sequence of row sequences
        2-D data whose rows are compared. Tensor rows are converted to plain
        Python values first, so that equal rows hash equally.

    Returns
    -------
    list of int
        Index of the first occurrence of each unique row.
    """
    is_tensor = isinstance(tensor, torch.Tensor)
    seen = set()
    first_indices = []
    for row in tqdm(range(len(tensor))):
        # Tensor elements hash by identity, so tensor rows must become Python
        # scalars before they can serve as a dedup key.
        key = tuple(tensor[row].tolist()) if is_tensor else tuple(tensor[row])
        if key not in seen:
            seen.add(key)
            first_indices.append(row)
    return first_indices

def get_unique_indices_per_row_or_file(tensor, f_name):
    """Cached wrapper around ``get_unique_indices_per_row``.

    If ``f_name`` exists, its contents are loaded with ``torch.load`` and
    returned. Otherwise the unique-row indices are computed from ``tensor``,
    saved to ``f_name`` and returned. The cache is keyed only by file name,
    so an existing file is reused even if ``tensor`` has changed.
    """
    if not os.path.exists(f_name):
        unique_indices = get_unique_indices_per_row(tensor)
        torch.save(unique_indices, f_name)
        return unique_indices
    return torch.load(f_name)

SEQ_LEN=10
MAX_ITER=20
FEATURES=100
DATASET="ided"
PICKING="rand" if DATASET=="hdfs" else "first"
PREDICT_THRESHOLD=0.2
# Seed every RNG used in this notebook: random.choice tie-breaking in
# max_to_one_rand (PICKING == "rand"), and any model trained by
# get_context_builder_interpreter. This makes re-runs more reproducible.
SEED=42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Model under attack: the saved ContextBuilder/Interpreter for this dataset and
# context length, with dropout neutralised so predictions are repeatable.
context_builder = to_cuda(ContextBuilder.load(f'{DATASET}/{SEQ_LEN=}/builder.save'))
context_builder.apply(disable_dropout)
interpreter = Interpreter.load(f'{DATASET}/{SEQ_LEN=}/interpreter.save', context_builder)
interpreter.threshold = PREDICT_THRESHOLD
criterion = LabelSmoothing(context_builder.decoder_event.out.out_features, 0.1)

# Only deserialisation needs the open file; the rest works on the loaded dict.
with open(f'{DATASET}/{SEQ_LEN=}/sequences.save', 'rb') as infile:
    data = torch.load(infile)
context = data["context"]
events  = data["events"]
labels  = data["labels"]
mapping = data["mapping"]

# De-duplicate the context rows, then hold out 20% of the unique rows as the
# attack test set. The split is stratified by label when the dataset is
# labelled; otherwise labels_test stays None.
indices = get_unique_indices_per_row_or_file(context, f'{DATASET}/{SEQ_LEN=}/context.pt')
test_indices, labels_test = None, None
if labels is not None:
    test_indices = train_test_split(indices, test_size=0.2, random_state=42, stratify=labels[indices])[1]
    labels_test   = to_cuda(labels[test_indices])
else:
    test_indices = train_test_split(indices, test_size=0.2, random_state=42)[1]
context_test  = to_cuda(context[test_indices])
events_test   = to_cuda(events[test_indices])

def to_one_hot(t):
    """Encode event ids with the global ``context_builder.embedding_one_hot`` and move the result to the GPU if available.

    The encoding is cloned and detached, so the returned tensor is a fresh
    leaf with no autograd history.
    """
    encoded = context_builder.embedding_one_hot(t)
    return to_cuda(encoded.clone().detach())

def to_trace(o):
    """Decode a batched score/one-hot tensor back to event ids.

    Takes the arg-max over the last axis and returns the first batch entry
    as Python values (a list of ints for a 3-D input).
    """
    decoded = o.argmax(dim=-1).tolist()
    return decoded[0]

def max_to_one_first(tensor):
    """Binarise ``tensor`` along its last axis, keeping only the arg-max.

    Each slice gets a single 1 at the position chosen by ``torch.argmax``
    (the first maximum on ties) and 0 elsewhere. The result has the same
    dtype and device as ``tensor``.
    """
    winners = tensor.argmax(dim=-1, keepdim=True)
    return torch.zeros_like(tensor).scatter_(-1, winners, 1.0)

def max_to_one_rand(tensor):
    """Binarise a ``(1, seq, features)`` tensor, breaking arg-max ties at random.

    For each position along axis 1, all feature indices equal to that
    position's maximum are collected. One of them is drawn with
    ``random.choice`` (one draw per position, in order). The chosen ids are
    re-encoded through ``to_one_hot``.
    """
    position_max = torch.max(tensor.squeeze(0), dim=1).values
    hits = torch.nonzero(tensor == position_max.unsqueeze(1), as_tuple=True)
    candidates = [[] for _ in range(tensor.size(1))]
    hit_positions, hit_features = hits[1].tolist(), hits[2].tolist()
    for position, feature in zip(hit_positions, hit_features):
        candidates[position].append(feature)
    picked = [random.choice(options) for options in candidates]
    batched_ids = to_cuda_tensor(picked).unsqueeze(0)
    return to_one_hot(batched_ids)

def max_to_one(tensor):
    """Dispatch on the global ``PICKING`` setting.

    ``"rand"`` breaks arg-max ties at random (``max_to_one_rand``); any other
    value takes the first maximum (``max_to_one_first``).
    """
    if PICKING == "rand":
        return max_to_one_rand(tensor)
    return max_to_one_first(tensor)

def to_output(context_chosen):
    """Return the raw result of the global ``context_builder.predict`` for ``context_chosen``."""
    prediction = context_builder.predict(context_chosen)
    return prediction

def get_file_name(f_name, attention_query=False):
    """Build the cache path for ``f_name`` under the current experiment settings.

    Layout: ``{DATASET}/SEQ_LEN=../PREDICT_THRESHOLD=../{PICKING}/{attention_query|no_query}/{f_name}``.
    ``MAX_ITER`` is not part of the path.
    """
    query_dir = "attention_query" if attention_query else "no_query"
    return f"{DATASET}/{SEQ_LEN=}/{PREDICT_THRESHOLD=}/{PICKING}/{query_dir}/{f_name}"

def get_correct_prediction_for_list(context_chosen, events_chosen, attention_query=False):
    """Split a batch of samples by the mask from the global ``interpreter.attended_context``.

    Contexts are one-hot encoded and the events are moved to the GPU (if
    available) before the call. With ``attention_query`` the interpreter runs
    100 attention-query iterations, otherwise 0.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        Sample indices where the mask is False, then those where it is True.
        Presumably the mask marks predictions above the interpreter's
        confidence threshold; TODO confirm against the Interpreter.
    """
    query_iterations = 100 if attention_query else 0
    _, mask = interpreter.attended_context(
        X=to_one_hot(context_chosen),
        y=to_cuda(events_chosen),
        iterations=query_iterations,
    )
    return torch.where(~mask)[0], torch.where(mask)[0]

def get_performance(context_chosen, event_chosen, attention_query=False):
    """Return whether the attack goal is met for a single context.

    * ``attention_query=True``: the context is decoded back to event ids and
      run through ``get_correct_prediction_for_list`` with attention querying.
      Success means the sample is *not* in the masked group.
    * otherwise: the exponentiated model score ``output[0][0][0][event]`` is
      compared with ``PREDICT_THRESHOLD``; success means it lies below. If
      ``event_chosen`` matches no index of that score vector, ``False`` is
      returned.

    ``context_chosen`` is expected in the (one-hot) form consumed by
    ``to_output`` / ``to_trace``.
    """
    if attention_query:
        ids = to_cuda_tensor(to_trace(context_chosen)).unsqueeze(0)
        event_batch = event_chosen.unsqueeze(0).unsqueeze(0)
        masked = get_correct_prediction_for_list(ids, event_batch, attention_query=True)[1]
        return len(masked) == 0
    scores = to_output(context_chosen)[0][0][0]
    probabilities = scores.exp()
    # Direct lookup of the target's score; no need to sort the whole vocabulary.
    for position in range(len(scores)):
        if position == event_chosen:
            return probabilities[position] < PREDICT_THRESHOLD
    return False

def get_changes_list(start, final):
    """List the positions where two sequences differ.

    Returns ``(index, new_value)`` pairs, with ``new_value`` read from
    ``final`` via ``.item()`` (so ``final`` elements must be tensor/numpy
    scalars). Only the overlapping prefix of the two inputs is compared.
    """
    return [
        (position, new.item())
        for position, (old, new) in enumerate(zip(start, final))
        if old != new
    ]

def get_context_builder_interpreter(train, path):
    """Load (training first if needed) a ContextBuilder/Interpreter pair stored under ``path``.

    If ``{path}/builder.save`` and ``{path}/interpreter.save`` do not both
    exist, the following happens first:
    * an upstream DeepCASE ``ContextBuilder`` is trained on
      ``context[train]`` / ``events[train]``;
    * an ``Interpreter`` is clustered on the same rows;
    * both are saved under ``path``, which is created if missing.

    The saved files are then loaded with the project's ``deepcase_copy``
    classes, moved to the GPU (if available), and have dropout disabled.

    Parameters
    ----------
    train : sequence of int
        Row indices into the global ``context`` / ``events`` used for training.
    path : str
        Directory holding the saved models.

    Returns
    -------
    tuple
        ``(context_builder, interpreter, criterion)``, where ``criterion`` is
        a ``LabelSmoothing`` loss sized to the builder's output vocabulary.
    """
    builder_file = f"{path}/builder.save"
    interpreter_file = f"{path}/interpreter.save"
    if not (os.path.exists(builder_file) and os.path.exists(interpreter_file)):
        base_builder = BaseContextBuilder(
            input_size  = FEATURES,  # Number of input features to expect
            output_size = FEATURES,  # Same as input size
            hidden_size = 128,       # Number of nodes in hidden layer, in paper we set this to 128
            max_length  = SEQ_LEN,   # Length of the context; must match the width of `context`
        )
        base_builder.fit(
            X             = context[train],               # Context to train with
            y             = events[train].reshape(-1, 1), # Events to train with, shape=(n_events, 1)
            epochs        = 10,                           # Number of epochs to train with
            batch_size    = 128,                          # Samples per training batch, in paper this was 128
            learning_rate = 0.01,                         # Learning rate, in paper this was 0.01
            verbose       = True,                         # If True, prints progress
        )
        base_interpreter = BaseInterpreter(
            context_builder = base_builder, # ContextBuilder used to fit data
            features        = FEATURES,     # Number of input features, same as ContextBuilder
            eps             = 0.1,          # Epsilon for DBSCAN clustering, in paper this was 0.1
            min_samples     = 5,            # Minimum samples for DBSCAN clustering, in paper this was 5
            threshold       = 0.2,          # Confidence threshold for using attention, in paper this was 0.2
        )
        base_interpreter.cluster(
            X          = context[train],               # Context to train with
            y          = events[train].reshape(-1, 1), # Events to train with, shape=(n_events, 1)
            iterations = 100,                          # Attention-query iterations, in paper this was 100
            batch_size = 1024,                         # Batch size for attention query (limits CUDA memory)
            verbose    = True,                         # If True, prints progress
        )
        # torch-based .save does not create missing directories.
        os.makedirs(path, exist_ok=True)
        base_builder.save(builder_file)
        base_interpreter.save(interpreter_file)

    # Always reload with the deepcase_copy classes, whether the files were just
    # written or already cached.
    l_context_builder = ContextBuilder.load(builder_file)
    l_interpreter = Interpreter.load(interpreter_file, l_context_builder)
    l_context_builder = to_cuda(l_context_builder)
    l_context_builder.apply(disable_dropout)
    l_criterion = LabelSmoothing(l_context_builder.decoder_event.out.out_features, 0.1)
    return l_context_builder, l_interpreter, l_criterion

def get_perturbations(context_chosen, event_chosen, attention_query=False):
    """Attack every sample in the second group of ``get_correct_prediction_for_list``.

    Each such context is attacked with ``bim_attack``. Successes are
    collected and failures counted.

    Returns
    -------
    tuple of tensors (on the GPU if available)
        * the perturbed traces of the successful attacks;
        * their positions in ``context_chosen``;
        * ``[len(first group), #successful, #failed]``, which is also printed;
        * the iteration at which each successful attack succeeded.
    """
    event_column = event_chosen.unsqueeze(1)
    pred_false, pred_true = get_correct_prediction_for_list(context_chosen, event_column, attention_query=attention_query)
    successes, success_positions, success_iterations = [], [], []
    failures = 0
    for trace_num in tqdm(pred_true):
        single_context = context_chosen[trace_num].unsqueeze(0)
        perturbed, n_iter = bim_attack(single_context, event_column[trace_num], attention_query=attention_query)
        if perturbed is None:
            failures += 1
            continue
        successes.append(perturbed)
        success_positions.append(trace_num)
        success_iterations.append(n_iter)
    states = [len(pred_false), len(successes), failures]
    print(to_cuda_tensor(states))
    return (
        to_cuda_tensor(successes),
        to_cuda_tensor(success_positions),
        to_cuda_tensor(states),
        to_cuda_tensor(success_iterations),
    )

def get_perturbations_or_location(context_chosen, event_chosen, attention_query=False):
    """Cached ``get_perturbations`` for the ``{DATASET}/events`` experiment.

    Four ``.pt`` files are stored under
    ``{DATASET}/events/{attention_query|no_query}/``:
    * the perturbed traces;
    * their indices;
    * the ``[first group, successful, failed]`` distribution;
    * the iteration counts.

    If all four exist they are loaded instead of recomputed. The directory is
    created before saving.

    NOTE(review): unlike ``get_file_name``, this path does not encode
    SEQ_LEN / PREDICT_THRESHOLD / PICKING / MAX_ITER. Changing those settings
    silently reuses stale results, so delete the directory when they change.
    """
    query_dir = "attention_query" if attention_query else "no_query"
    path = f"{DATASET}/events/{query_dir}"
    f_name_perturbed = f"{path}/perturbed_collected.pt"
    f_name_indices = f"{path}/perturbed_indices.pt"
    f_name_distribution = f"{path}/perturbed_distribution.pt"
    f_name_iterations = f"{path}/perturbed_iterations.pt"
    cache_files = (f_name_perturbed, f_name_indices, f_name_distribution, f_name_iterations)
    if all(os.path.exists(f_name) for f_name in cache_files):
        print(f"Loading {f_name_perturbed}")
        print(f"Loading {f_name_indices}")
        print(f"Loading {f_name_distribution}")
        perturb_main, indices_main, result_main, iterations_main = (torch.load(f_name) for f_name in cache_files)
        print(result_main)
        return perturb_main, indices_main, result_main, iterations_main
    perturb_main, indices_main, result_main, iterations_main = get_perturbations(context_chosen, event_chosen, attention_query=attention_query)
    # torch.save does not create directories; without this a fresh run fails.
    os.makedirs(path, exist_ok=True)
    for value, f_name in zip((perturb_main, indices_main, result_main, iterations_main), cache_files):
        torch.save(value, f_name)
    print(result_main)
    return perturb_main, indices_main, result_main, iterations_main

def get_perturbations_or_file(context_chosen, event_chosen, attention_query=False):
    """Cached ``get_perturbations``, keyed by the current experiment settings.

    The four result tensors (``perturbed_collected``, ``perturbed_indices``,
    ``perturbed_distribution``, ``perturbed_iterations``) live at the paths
    built by ``get_file_name``. If all four exist they are loaded; otherwise
    the attack is run and the results are saved.

    Returns
    -------
    tuple
        ``(perturbed_traces, perturbed_indices, distribution, iterations)`` as
        produced by ``get_perturbations``.
    """
    f_name_perturbed = get_file_name("perturbed_collected.pt", attention_query=attention_query)
    f_name_indices = get_file_name("perturbed_indices.pt", attention_query=attention_query)
    f_name_distribution = get_file_name("perturbed_distribution.pt", attention_query=attention_query)
    f_name_iterations = get_file_name("perturbed_iterations.pt", attention_query=attention_query)
    cache_files = (f_name_perturbed, f_name_indices, f_name_distribution, f_name_iterations)
    if all(os.path.exists(f_name) for f_name in cache_files):
        print(f"Loading {f_name_perturbed}")
        print(f"Loading {f_name_indices}")
        print(f"Loading {f_name_distribution}")
        perturb_main, indices_main, result_main, iterations_main = (torch.load(f_name) for f_name in cache_files)
        print(result_main)
        return perturb_main, indices_main, result_main, iterations_main
    # Create exactly the directory the files are written to. get_file_name
    # has no MAX_ITER= component, so a path that includes it would leave the
    # real target directory missing and torch.save would fail.
    os.makedirs(os.path.dirname(f_name_perturbed) or ".", exist_ok=True)
    perturb_main, indices_main, result_main, iterations_main = get_perturbations(context_chosen, event_chosen, attention_query=attention_query)
    for value, f_name in zip((perturb_main, indices_main, result_main, iterations_main), cache_files):
        torch.save(value, f_name)
    print(result_main)
    return perturb_main, indices_main, result_main, iterations_main

def get_possible_combinations(perturbations_made):
    """Return every non-empty *proper* subset of ``perturbations_made``, as lists.

    Subsets are ordered by size (1 .. n-1), and within a size in
    ``itertools.combinations`` order. The full set is excluded; callers fall
    back to the unreduced perturbation when no smaller subset works.
    """
    sizes = range(1, len(perturbations_made))
    by_size = (itertools.combinations(perturbations_made, size) for size in sizes)
    return [list(subset) for subset in itertools.chain.from_iterable(by_size)]

def get_minimum_change_for_perturbation_no_query(perturbed_chosen, context_chosen, events_chosen):
    """Find the smallest subset of an attack's edits that still succeeds (no attention query).

    Candidates apply subsets of the ``(position, value)`` edits between
    ``context_chosen`` and ``perturbed_chosen``, smallest subsets first. The
    first candidate for which ``get_performance`` reports success is
    returned. If no proper subset succeeds, ``perturbed_chosen`` is returned.
    """
    edits = get_changes_list(context_chosen, perturbed_chosen)
    for subset in get_possible_combinations(edits):
        candidate = to_cuda(context_chosen.clone().detach())
        for position, value in subset:
            candidate[position] = value
        if get_performance(to_one_hot(candidate.unsqueeze(0)), events_chosen):
            return candidate
    return perturbed_chosen

def get_minimum_change_for_perturbation_attention_query(perturbed_chosen, context_chosen, events_chosen):
    """Attention-query variant of the minimal-edit search, evaluated as one batch.

    All proper subsets of the edits are applied to copies of
    ``context_chosen``, smallest first. The copies are passed together to
    ``get_correct_prediction_for_list`` with attention querying, and the
    first copy in its first (unmasked) group is returned. Falls back to
    ``perturbed_chosen`` when there are no candidates or none qualifies.
    """
    def apply_edits(subset):
        # Copy of the original context with this subset of edits applied.
        candidate = to_cuda(context_chosen.clone().detach())
        for position, value in subset:
            candidate[position] = value
        return candidate

    subsets = get_possible_combinations(get_changes_list(context_chosen, perturbed_chosen))
    candidates = [apply_edits(subset) for subset in subsets]
    if not candidates:
        return perturbed_chosen
    targets = torch.full((len(candidates), 1), events_chosen.item())
    unmasked = get_correct_prediction_for_list(torch.stack(candidates), targets, True)[0]
    if len(unmasked) == 0:
        return perturbed_chosen
    return candidates[unmasked[0]]
    
def get_shortcuts(perturbed_chosen, context_chosen, events_chosen, attention_query=False):
    """Minimise every perturbation in a batch and stack the results (on the GPU if available).

    Uses the attention-query or plain minimal-edit search depending on
    ``attention_query``.
    """
    minimise = (
        get_minimum_change_for_perturbation_attention_query
        if attention_query
        else get_minimum_change_for_perturbation_no_query
    )
    triples = zip(perturbed_chosen, context_chosen, events_chosen)
    minimal = [minimise(p, c, e) for p, c, e in tqdm(triples, total=len(perturbed_chosen))]
    return to_cuda(torch.stack(minimal))

def get_shortcuts_or_file(perturbed_chosen, context_chosen, events_chosen, attention_query=False):
    """Cached ``get_shortcuts``; stored as ``shortcuts.pt`` at the path from ``get_file_name``.

    If the file exists it is loaded. Otherwise the shortcuts are computed,
    saved (the directory is created if missing) and returned.
    """
    f_name = get_file_name("shortcuts.pt", attention_query=attention_query)
    if os.path.exists(f_name):
        print(f"Loading {f_name}")
        return torch.load(f_name)
    pick_list = get_shortcuts(perturbed_chosen, context_chosen, events_chosen, attention_query)
    # torch.save does not create directories.
    os.makedirs(os.path.dirname(f_name) or ".", exist_ok=True)
    torch.save(pick_list, f_name)
    return pick_list

def get_dist(trace_dist):
    """Order item indices so that items with identical keys are adjacent.

    Each entry of ``trace_dist`` (a list, a tensor/array, or a tensor/numpy
    scalar) becomes a tuple key. Groups are sorted by key length, then
    lexicographically, and the original indices are returned group by group
    (in input order within a group). Empty entries form the key ``()``.
    """
    def _as_key(item):
        # Tensors/arrays become plain Python values so that equal entries hash
        # equally; 0-d values become a 1-tuple.
        if hasattr(item, "tolist"):
            item = item.tolist()
            if not isinstance(item, list):
                return (item,)
        return tuple(item)

    groups = {}
    for idx, item in enumerate(trace_dist):
        groups.setdefault(_as_key(item), []).append(idx)
    ordered = sorted(groups.items(), key=lambda kv: (len(kv[0]), kv[0]))
    return [idx for _, members in ordered for idx in members]
    
def get_matrix_perturb(context_chosen, perturb_chosen):
    """Histogram of how many positions each perturbation changed.

    Returns a one-column DataFrame indexed 1..SEQ_LEN. Row ``k`` counts the
    (context, perturbation) pairs that differ in exactly ``k`` positions.
    NOTE(review): a pair with zero changes lands in the last row (index -1
    wraps around), so only real perturbations should be passed in.
    """
    counts = [0 for _ in range(SEQ_LEN)]
    for original, perturbed in zip(context_chosen, perturb_chosen):
        n_changes = len(get_changes_list(original, perturbed))
        counts[n_changes - 1] += 1
    return pd.DataFrame(counts, index=range(1, SEQ_LEN + 1))

def get_matrix_shortcuts(context_chosen, shortcut_chosen, perturb_chosen):
    """Cross-tabulate raw vs. minimised perturbation sizes.

    Cell ``(p, s)`` counts the samples whose raw perturbation changed ``p``
    positions and whose shortcut (minimised perturbation) changed ``s``
    positions. Rows and columns are labelled 1..SEQ_LEN, and all-zero
    columns are dropped.
    """
    table = [[0] * SEQ_LEN for _ in range(SEQ_LEN)]
    for original, shortcut, perturbed in zip(context_chosen, shortcut_chosen, perturb_chosen):
        row = len(get_changes_list(original, perturbed)) - 1
        col = len(get_changes_list(original, shortcut)) - 1
        table[row][col] += 1
    axis_labels = range(1, SEQ_LEN + 1)
    frame = pd.DataFrame(table, index=axis_labels, columns=axis_labels)
    return frame.loc[:, frame.sum() != 0]
    
def interpret_query(context_passed, events_passed):
    """Run the global ``interpreter.predict`` with its default iteration settings.

    ``context_passed`` holds event ids, which are one-hot encoded here;
    ``events_passed`` is reshaped to a column vector.
    """
    return interpreter.predict(
        X=to_one_hot(context_passed),
        y=events_passed.reshape(-1, 1),
    )

def interpret(context_passed, events_passed):
    """Run the global ``interpreter.predict`` with ``iterations=0`` (no attention querying).

    ``context_passed`` holds event ids, which are one-hot encoded here;
    ``events_passed`` is reshaped to a column vector.
    """
    return interpreter.predict(
        X=to_one_hot(context_passed),
        y=events_passed.reshape(-1, 1),
        iterations=0,
    )
    
def get_combined(perturbed_chosen, perturbed_indices_chosen, attention_query=False):
    """Interpret the whole test set with the perturbed traces substituted in.

    A copy of the global ``context_test`` has its rows at
    ``perturbed_indices_chosen`` overwritten with ``perturbed_chosen``. The
    copy is then run through ``interpret_query`` (attention query) or
    ``interpret`` against ``events_test``. ``context_test`` itself is not
    modified.
    """
    patched = context_test.clone().detach()
    patched[perturbed_indices_chosen] = perturbed_chosen
    runner = interpret_query if attention_query else interpret
    return runner(patched, events_test)
    
def bim_attack(context_chosen, event_chosen, attention_query=False):
    """Iterative sign-gradient (BIM-style) attack on a single context.

    The attack starts from the one-hot encoding of ``context_chosen``. Each
    iteration first checks the attack goal with ``get_performance`` and
    returns on success. Otherwise it adds ``sign(grad)`` of the global
    ``criterion`` loss with respect to the one-hot input (gradient ascent on
    the loss of ``event_chosen``), then snaps the result back to a one-hot
    tensor with ``max_to_one``. At most ``MAX_ITER`` iterations are run.

    Parameters
    ----------
    context_chosen : torch.Tensor
        Event ids of one context. ``get_perturbations`` passes a
        ``(1, seq)`` tensor.
    event_chosen : torch.Tensor
        Target event as a 1-element tensor; ``event_chosen[0]`` is what
        ``get_performance`` checks.
    attention_query : bool
        Forwarded to ``get_performance``.

    Returns
    -------
    (list of int, int) or (None, int)
        On success, the decoded perturbed trace (``to_trace``) and the
        iteration at which the goal was first met. An iteration of 0 means
        the unmodified context already met it. Returns ``(None, -1)`` if
        ``MAX_ITER`` iterations were not enough.
    """
    context_processed = to_one_hot(context_chosen)
    for iteration in range(MAX_ITER):
        # Track gradients with respect to the (current) one-hot input.
        context_processed.requires_grad_(True)
        output = context_builder.predict(context_processed)
        if get_performance(context_processed, event_chosen[0], attention_query=attention_query):
            return to_trace(context_processed), iteration
        loss = criterion(output[0][0], event_chosen)
        # Harmless when context_processed is a leaf tensor, which it is when it
        # comes from to_one_hot / max_to_one_first.
        context_processed.retain_grad()
        # NOTE(review): backward also accumulates into the model parameters'
        # .grad, which is never cleared here.
        loss.backward(retain_graph=True)
        # Step +1 in the gradient-sign direction, then project back onto a valid one-hot trace.
        context_processed = max_to_one(context_processed + context_processed.grad.sign())
    return None, -1
    
def format_results(results):
    """Render the top-3 entries of ``results[0][0][0]`` as ``"[id] value, ..."``.

    The scores are exponentiated (presumably log-probabilities) and printed
    with three decimals.
    """
    top = torch.topk(results[0][0][0], 3)
    exponentiated = top.values.exp()
    parts = []
    for rank in range(3):
        event_id = top.indices[rank].item()
        parts.append(f"{format_list([event_id])} {'{:.3f}'.format(exponentiated[rank])}")
    return ", ".join(parts)
  
def format_changes(start, final):
    """Render an aligned diff of two traces as two ``format_list`` strings.

    Returns
    -------
    (str, str)
        * changes: the new value where the traces differ, ``"-"`` where they match;
        * same: the shared value where they match, ``"XX"`` where they differ.
    """
    changes = []
    same = []
    for s, f in zip(start, final):
        if s == f:
            changes.append("-")
            # Append the element itself; appending the whole `final` sequence
            # made format_list fail to width-format it.
            same.append(f)
        else:
            changes.append(f)
            same.append("XX")
    return format_list(changes), format_list(same)

def format_list(li):
    """Format items as ``"[a, b, ...]"``, each padded to a minimum width of 2."""
    cells = ", ".join(format(item, "2") for item in li)
    return "[" + cells + "]"
            
def format_trace_prediction(current_trace_num, trace):
    """Format one trace and the model's top-3 prediction for it, as a single line.

    The line is indented to align under a ``"{current_trace_num}: "`` prefix.

    NOTE(review): ``trace`` is passed to ``to_output`` as raw event ids,
    whereas bim_attack / get_performance feed ``context_builder.predict``
    one-hot input. Confirm that the copied ContextBuilder accepts both.
    """
    indent = " " * (len(str(current_trace_num)) + 2)
    batched = to_cuda(trace.clone().detach()).unsqueeze(0)
    # Format the flat event ids. The batched tensor would hand non-scalar rows
    # to format_list, which cannot width-format them.
    event_ids = batched.reshape(-1).tolist()
    return f"{indent}{format_list(event_ids)} -> {format_results(to_output(batched))}\n"
            
def format_perturbation(perturb_chosen, perturb_int, context_chosen, event_chosen, current_trace_num):
    """Build a multi-line human-readable report for one perturbed sample.

    The report contains:
    * a header: trace number, original context, target event,
      ``perturb_int`` and the number of changed positions;
    * the model's top-3 predictions for the original and the perturbed trace;
    * an aligned diff from ``format_changes``.

    NOTE(review): the shape contract looks inconsistent.
    ``context_chosen[0].tolist()`` needs a batched (2-D) context, but
    ``get_changes_list`` / ``format_changes`` compare element-wise and need
    1-D traces. The function is not called anywhere in the visible notebook;
    confirm the expected shapes before relying on it.
    """
    result_string = f"{current_trace_num}: {format_list(context_chosen[0].tolist())} == {event_chosen.tolist()} Changed [{perturb_int}], Perturbations [{len(get_changes_list(context_chosen, perturb_chosen))}]\n"
    result_string += format_trace_prediction(current_trace_num, context_chosen)    
    result_string += format_trace_prediction(current_trace_num, perturb_chosen)    
    changed_entries, same_entries = format_changes(context_chosen, perturb_chosen)
    # Two diff lines aligned under the "N: " prefix: shared values, then new values.
    result_string += f"{" "*(len(str(current_trace_num)) - 1)}== {same_entries}\n{" "*(len(str(current_trace_num)) - 1)}-> {changed_entries}\n\n"
    return result_string

def format_attention(x_mask, context_mask, vectors, neighbours, distance, scores):
    """Render one line per context in a group: its trace, attention weights and nearest neighbour.

    Parameters
    ----------
    x_mask : torch.Tensor
        Traces (event ids) of the samples kept by the interpreter mask.
    context_mask : array of int
        Positions (into ``x_mask`` / ``vectors``) of the samples in this group.
    vectors : scipy sparse matrix
        Attention vectors, one row per kept sample (presumably from
        ``interpreter.attended_context``).
    neighbours, distance, scores : sequences aligned with ``context_mask``
        Per-sample nearest-neighbour id, distance (``distance[i][0]`` is
        shown) and score.

    Returns
    -------
    str
        Lines ordered by ``get_dist`` over the non-zero attention indices.
        Each line lists those indices by descending weight and, padded to
        column 150, a ``{neighbour; score | distance}`` tag.
    """
    # Slice once: re-indexing the sparse matrix inside the loop was quadratic.
    group_vectors = vectors[context_mask]
    group_traces = x_mask[context_mask]
    data_collected = []
    for i in range(len(context_mask)):
        l_value = torch.tensor(group_vectors[i].toarray()[0])
        l_indices = torch.nonzero(l_value, as_tuple=False).squeeze(1)
        data_collected.append({
            "trace": group_traces[i],
            "indices": l_indices.tolist(),
            "value": l_value[l_indices].tolist(),
            "neighbour": f"{{{neighbours[i]:5}; {scores[i]} | {'{:.4f}'.format(distance[i][0])}}}",
        })
    lines = []
    for s in get_dist([entry["indices"] for entry in data_collected]):
        local_l = data_collected[s]
        weights = sorted(zip(local_l["indices"], local_l["value"]), key=lambda x: x[1], reverse=True)
        res = f"{format_list(local_l['trace'].tolist())} -> "
        for index, value in weights:
            res += f"[{index:2}]: {value:.4f} "
        lines.append(res + " " * (150 - len(res)) + f" {local_l['neighbour']}\n")
    return "".join(lines)

def format_series(series):
    """Count occurrences of each value, ordered by value (not by frequency)."""
    counts = pd.Series(series).value_counts()
    return counts.sort_index()

def format_confusion_matrix(y_true, y_pred):
    """Confusion matrix as a DataFrame with rows = predicted label and columns = true label.

    All-zero rows and columns are dropped, then a ``Row_Sum`` column is
    appended.

    NOTE(review): ``confusion_matrix`` is called as ``(y_pred, y_true)``, so
    the sklearn convention (rows = truth) is transposed here and ``Row_Sum``
    totals each *predicted* label. Confirm this orientation is intended.
    """
    present = sorted(set(y_true.tolist()) | set(y_pred.tolist()))
    matrix = pd.DataFrame(confusion_matrix(y_pred, y_true), index=present, columns=present)
    matrix = matrix.loc[matrix.sum(axis=1) != 0, :]
    matrix = matrix.loc[:, matrix.sum() != 0]
    return matrix.assign(Row_Sum=matrix.sum(axis=1))

def get_events_dist(attention_query=False):
    """Per-event LaTeX table rows: perturbed vs. targeted test contexts.

    A stratified 80/20 split over all events is made. Models for
    ``{DATASET}/events`` are loaded or trained on the 80% part, and the
    de-duplicated test contexts are masked with that interpreter. The cached
    perturbations are then counted per event.

    Returns
    -------
    pandas.DataFrame
        Columns:
        * ``label``: a ``\\textbf{...}`` event id;
        * ``level``: if labels exist, one label seen for that event (an
          arbitrary element of the set);
        * ``print``: ``\\colorcellpercentamount{perturbed}{targeted}``;
        * ``mapping``: if labels exist, the event's name.
    """
    training_indices_str_org, test_indices_str_org = train_test_split(range(len(events)), test_size=0.2, random_state=42, stratify=events)
    _, interpreter_str, _ = get_context_builder_interpreter(training_indices_str_org, f'{DATASET}/events')
    # NOTE(review): these are row positions *within* context[test_indices_str_org],
    # but below they index the full events/context/labels arrays. The intended
    # global rows are probably np.asarray(test_indices_str_org)[test_indices_str].
    # Fixing this invalidates the caches under {DATASET}/events/, so it is left as-is.
    test_indices_str = get_unique_indices_per_row_or_file(context[test_indices_str_org], f'{DATASET}/events/context_test.pt')
    events_test_indices_str   = to_cuda(events [test_indices_str])
    context_test_indices_str  = to_cuda(context[test_indices_str])
    _, mask = interpreter_str.attended_context(
        X           = to_one_hot(context_test_indices_str),
        y           = events_test_indices_str.unsqueeze(1),
        iterations  = 100 if attention_query else 0
    )
    # Contexts kept by the mask are the attack targets. Building the exclusion
    # set once avoids re-running torch.where per candidate, which was quadratic.
    excluded = set(torch.where(~mask)[0].tolist())
    targeted_indices = [i for i in range(len(test_indices_str)) if i not in excluded]
    # NOTE(review): get_perturbations_or_location attacks with the global
    # context_builder/interpreter, not with the models loaded above.
    _, perturbed_indices_str, _, _ = get_perturbations_or_location(context_test_indices_str, events_test_indices_str, attention_query=attention_query)
    events_series = pd.Series(events_test_indices_str[perturbed_indices_str].cpu()).value_counts().sort_values(ascending=False)
    events_df = events_series.reset_index()
    events_df.columns = ['label', 'count']
    print(len(set(events.tolist())), len(set(events[test_indices_str_org].tolist())), len(set(events_test_indices_str.tolist())))
    if labels is not None:
        labels_selected = labels[test_indices_str]
        events_selected_cpu = events_test_indices_str.cpu()
        events_df["level"] = events_df["label"].apply(lambda label: list(set(labels_selected[np.where(events_selected_cpu == label)[0]].tolist()))[0])
    targeted_events = events_test_indices_str[targeted_indices]
    events_df['print'] = events_df.apply(lambda row: f"\\colorcellpercentamount{{{row['count']}}}{{{torch.sum(targeted_events == row['label'])}}}", axis=1)
    events_df = events_df.drop(columns=['count'])
    if labels is not None:
        events_df["mapping"] = events_df["label"].apply(lambda label: mapping[label])
    events_df["label"] = events_df["label"].apply(lambda label: f"\\textbf{{{label}}}")
    return events_df

def get_events_miss(attention_query=False):
    """Per-event counts of test contexts that were kept by the mask but could not be perturbed.

    Uses the same split, models and cached perturbations as
    ``get_events_dist``. A context counts as missed when it is neither
    excluded by the interpreter mask nor among the perturbed indices.

    Returns
    -------
    pandas.DataFrame
        Columns ``label`` and ``count``. When labels exist there is also a
        ``level`` column: one label seen for that event among the missed
        contexts.
    """
    training_indices_str_org, test_indices_str_org = train_test_split(range(len(events)), test_size=0.2, random_state=42, stratify=events)
    _, interpreter_str, _ = get_context_builder_interpreter(training_indices_str_org, f'{DATASET}/events')
    # NOTE(review): positions within context[test_indices_str_org] are used as
    # global row indices below (see the note in get_events_dist).
    test_indices_str = get_unique_indices_per_row_or_file(context[test_indices_str_org], f'{DATASET}/events/context_test.pt')
    events_test_indices_str   = to_cuda(events [test_indices_str])
    context_test_indices_str  = to_cuda(context[test_indices_str])
    _, mask = interpreter_str.attended_context(
        X           = to_one_hot(context_test_indices_str),
        y           = events_test_indices_str.unsqueeze(1),
        iterations  = 100 if attention_query else 0
    )
    _, perturbed_indices_str, _, _ = get_perturbations_or_location(context_test_indices_str, events_test_indices_str, attention_query=attention_query)
    # Build the exclusion set once; re-evaluating it for every candidate was quadratic.
    excluded = set(torch.where(~mask)[0].tolist()) | set(perturbed_indices_str.tolist())
    incorrect_indices = [i for i in range(len(test_indices_str)) if i not in excluded]
    events_series = pd.Series(events_test_indices_str[incorrect_indices].cpu()).value_counts().sort_values(ascending=False)
    events_df = events_series.reset_index()
    events_df.columns = ['label', 'count']
    print(len(set(events.tolist())), len(set(events[test_indices_str_org].tolist())), len(set(events_test_indices_str.tolist())))
    if labels is not None:
        events_incorrect = events_test_indices_str[incorrect_indices].cpu()
        # incorrect_indices are positions within the selected test rows, so index
        # those rows' labels (as with events above), not the full label array.
        labels_incorrect = labels[test_indices_str][incorrect_indices]
        events_df["level"] = events_df["label"].apply(lambda label: list(set(labels_incorrect[np.where(events_incorrect == label)[0]].tolist()))[0])

    return events_df

def show_clusters(x, y):
    """Write a per-event attention/cluster dump to ``cluster.txt``.

    Contexts ``x`` and events ``y`` go through the global interpreter's
    ``attended_context``. The samples kept by the mask are grouped by event.
    * Events without an entry in ``interpreter.tree`` get only a
      ``<<<[event]>>>`` banner.
    * For the other events, each context is written via
      ``format_attention``, next to its nearest neighbour from the event's
      tree.

    The output directory is created if it does not exist.
    """
    vectors, mask = interpreter.attended_context(to_one_hot(x), y.reshape(-1, 1))
    y_masked = y[mask].reshape(-1, 1).cpu().numpy()
    x_masked = x[mask]
    indices_y = group_by(y_masked, lambda e: e.data.tobytes(),)
    out_dir = f"{DATASET}/{SEQ_LEN=}/{PREDICT_THRESHOLD=}/{PICKING}/{MAX_ITER=}"
    os.makedirs(out_dir, exist_ok=True)
    with open(f"{out_dir}/cluster.txt", "w") as file:
        for event, context_mask in indices_y:
            # The group key is the raw bytes of one element of y_masked; decode
            # it with that dtype. The old ord(bytes.decode('ascii')[0]) only
            # worked for ids < 128 on little-endian machines.
            event = int(np.frombuffer(event, dtype=y_masked.dtype)[0])
            if event not in interpreter.tree:
                file.write(f"{'<' * 50}[{event}]{'>' * 50}\n")
                continue
            file.write(f"{'=' * 50}[{event}]{'=' * 50}\n")
            vectors_, inverse, _ = sp_unique(vectors[context_mask])
            distance, neighbours = interpreter.tree[event].query(
                X               = vectors_.toarray(),
                return_distance = True,
                dualtree        = vectors_.shape[0] >= 1e3, # Optimization
            )
            neighbours = interpreter.tree[event].get_arrays()[1][neighbours][:, 0]
            scores = np.asarray([interpreter.labels[event][neighbour] for neighbour in neighbours])
            file.write(format_attention(x_masked, context_mask, vectors, neighbours[inverse], distance[inverse], scores[inverse]) + "\n")

In [64]:
# Per event: successfully perturbed vs. targeted (mask-kept) test contexts in the {DATASET}/events split.
get_events_dist(attention_query=False)

Loading ided/events/no_query/perturbed_collected.pt
Loading ided/events/no_query/perturbed_indices.pt
Loading ided/events/no_query/perturbed_distribution.pt
tensor([ 864, 5050,  140], device='cuda:0')
89 84 54


Unnamed: 0,label,level,print,mapping
0,\textbf{24},3,\colorcellpercentamount{2180}{2195},ET JA3 Hash - [Abuse.ch] Possible Adware
1,\textbf{64},3,\colorcellpercentamount{847}{889},SURICATA HTTP unable to match response to request
2,\textbf{72},3,\colorcellpercentamount{358}{394},SURICATA STREAM CLOSEWAIT FIN out of window
3,\textbf{66},3,\colorcellpercentamount{200}{200},SURICATA Kerberos 5 weak encryption parameters
4,\textbf{71},3,\colorcellpercentamount{173}{173},SURICATA STREAM 3way handshake wrong seq wrong...
5,\textbf{22},2,\colorcellpercentamount{142}{142},ET INFO TLS Handshake Failure
6,\textbf{86},3,\colorcellpercentamount{114}{114},SURICATA TLS invalid record type
7,\textbf{83},3,\colorcellpercentamount{109}{109},SURICATA STREAM bad window update
8,\textbf{87},3,\colorcellpercentamount{105}{105},SURICATA TLS invalid record/traffic
9,\textbf{20},3,\colorcellpercentamount{104}{104},ET INFO Session Traversal Utilities for NAT (S...


### Without Attention Query

In [3]:
# Attack every test context without attention querying (results are loaded from cache when available).
perturbed_collected, perturbed_indices, perturb_distribution, perturbed_iterations = get_perturbations_or_file(context_test, events_test)

Loading hdfs/SEQ_LEN=10/PREDICT_THRESHOLD=0.2/rand/no_query/perturbed_collected.pt
Loading hdfs/SEQ_LEN=10/PREDICT_THRESHOLD=0.2/rand/no_query/perturbed_indices.pt
Loading hdfs/SEQ_LEN=10/PREDICT_THRESHOLD=0.2/rand/no_query/perturbed_distribution.pt
tensor([ 591, 4743,   69], device='cuda:0')


In [4]:
# Reduce each successful perturbation to the smallest subset of its edits that still succeeds.
perturbed_minimized = get_shortcuts_or_file(
    perturbed_collected, 
    context_test[perturbed_indices], 
    events_test[perturbed_indices]
)

Loading hdfs/SEQ_LEN=10/PREDICT_THRESHOLD=0.2/rand/no_query/shortcuts.pt


In [5]:
# Distribution of the attack iteration at which each successful perturbation succeeded.
format_series(perturbed_iterations.cpu())

1     3447
2      779
3      279
4       96
5       67
6       24
7       16
8       10
9        6
10       5
11       4
12       4
13       1
14       1
15       3
18       1
Name: count, dtype: int64

In [6]:
# Interpreter output on the *unperturbed* contexts of the attacked samples vs. ground truth.
# Requires a labelled dataset: labels_test is None when the loaded sequences have no labels.
format_confusion_matrix(labels_test[perturbed_indices].cpu(), interpret(context_test[perturbed_indices], events_test[perturbed_indices]))

TypeError: 'NoneType' object is not subscriptable

In [None]:
# Histogram of the number of changed positions per minimised perturbation.
get_matrix_perturb(context_test[perturbed_indices], perturbed_minimized)

In [None]:
# Whole test set with the minimised perturbations substituted in, interpreted without attention query (requires labels).
format_confusion_matrix(labels_test.cpu(), get_combined(perturbed_minimized, perturbed_indices))

### Applying No-Attention-Query Perturbations to the Attention-Query Interpreter

In [None]:
# Same perturbed test set, but interpreted with attention querying (requires labels).
format_confusion_matrix(labels_test.cpu(), get_combined(perturbed_minimized, perturbed_indices, attention_query=True))

### Attention Query

In [None]:
# Attack every test context with attention querying enabled (cached when available).
perturbed_collected_query, perturbed_indices_query, perturb_distribution_query, perturbed_iterations_query = get_perturbations_or_file(context_test, events_test, attention_query=True)

In [None]:
# Which target events the successful attention-query attacks hit.
format_series(events_test[perturbed_indices_query].cpu())

In [None]:
# Minimise the attention-query perturbations, checking candidates with attention querying.
perturbed_minimized_query = get_shortcuts_or_file(
    perturbed_collected_query,
    context_test[perturbed_indices_query], 
    events_test[perturbed_indices_query],
    attention_query=True
)

In [None]:
# Distribution of the attack iteration at which each attention-query perturbation succeeded.
format_series(perturbed_iterations_query.cpu())

In [None]:
# Histogram of the number of changed positions per minimised attention-query perturbation.
get_matrix_perturb(context_test[perturbed_indices_query], perturbed_minimized_query)

In [None]:
# Attention-query interpreter output on the unperturbed contexts of the attacked samples (requires labels).
format_confusion_matrix(labels_test[perturbed_indices_query].cpu(), interpret_query(context_test[perturbed_indices_query], events_test[perturbed_indices_query]))

In [None]:
# Whole test set with the minimised attention-query perturbations substituted in (requires labels).
format_confusion_matrix(labels_test.cpu(), get_combined(perturbed_minimized_query, perturbed_indices_query, attention_query=True))