In [189]:
import os
import torch
import itertools
import numpy as np
import pandas as pd
from tqdm import tqdm
from deepcase_copy.context_builder.loss import LabelSmoothing
from deepcase_copy.context_builder.context_builder import ContextBuilder
from deepcase_copy.interpreter.interpreter import Interpreter
from deepcase_copy.interpreter.utils import group_by, sp_unique
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import time

def disable_dropout(m):
    if isinstance(m, torch.nn.Dropout):
        m.p = 0.0 

def to_cuda(item):
    if torch.cuda.is_available():
        return item.to('cuda')
    return item

def get_unique_indices_per_row(tensor):
    indices_list = []
    row_list = []
    indices_list_set = set()
    for row in range(len(tensor)):
        curr = tuple(tensor[row])
        if isinstance(tensor, torch.Tensor):
            curr = tuple(tensor[row].tolist())
        if curr in indices_list_set:
            continue
        row_list.append(curr)
        indices_list.append(row)
        indices_list_set.add(curr)
    return indices_list

def get_unique_indices_per_row_or_file(tensor, f_name):
    if os.path.exists(f_name):
        return torch.load(f_name)
    t = get_unique_indices_per_row(tensor)
    torch.save(t, f_name)
    return t  

ALPHA=1.0
EPSILON=0.5
MAX_ITER = 100
PREDICT_THRESHOLD = 0.2
builder = to_cuda(ContextBuilder.load('save/builder.save'))
builder.apply(disable_dropout)
interpreter = Interpreter.load('save/interpreter.save', builder)
criterion = LabelSmoothing(builder.decoder_event.out.out_features, 0.1)

with open('save/sequences.save', 'rb') as infile:
    data = torch.load(infile)
    context = data["context"] # 172572
    events  = data["events"]
    labels  = data["labels"]
    
    indices = get_unique_indices_per_row_or_file(context, 'save/context.pt')    
    train_indices, test_indices = train_test_split(indices, test_size=0.2, random_state=42, stratify=labels[indices])
    
    context_train = to_cuda(context[train_indices]) #
    context_test  = to_cuda(context[test_indices])  # 4392
    
    events_train  = to_cuda(events[train_indices])
    events_test   = to_cuda(events[test_indices])
    
    labels_train  = to_cuda(labels[train_indices])
    labels_test   = to_cuda(labels[test_indices])

def to_one_hot(t):
    return to_cuda(builder.embedding_one_hot(t).clone().detach())

def to_trace(o):
    return torch.argmax(o, dim=-1).tolist()[0]

def max_to_one(tensor):
    max_indices = torch.argmax(tensor, dim=-1, keepdim=True)
    result = torch.zeros_like(tensor)
    result.scatter_(-1, max_indices, 1.0)
    return result

def to_output(context_chosen):
    return builder.predict(context_chosen)

def get_file_name(f_name, attention_query=False):
    return f"{ALPHA=}, {EPSILON=}/{"attention_query" if attention_query else "no_query"}/{f_name}"

def get_perturbation_change(trace, original):
    a = original - EPSILON
    b = (trace >= a).float() * trace + (trace < a).float() * a
    c = (b > original + EPSILON).float() * (original + EPSILON) + (b <= original + EPSILON).float() * b
    return max_to_one(c)

def get_performance(context_chosen, event_chosen, attention_query=False):
    if attention_query:
        context_processed = to_cuda(torch.tensor(to_trace(context_chosen))).unsqueeze(0)
        pred_true = get_correct_prediction_for_list(context_processed, event_chosen.unsqueeze(0).unsqueeze(0), attention_query=True)[1]
        return len(pred_true) == 0
    output = to_output(context_chosen)
    results_picked = torch.topk(output[0][0][0], len(output[0][0][0]))
    exp = results_picked.values.exp()
    res_indices = results_picked.indices
    for j in range(len(output[0][0][0])):
        if res_indices[j].item() == event_chosen:
            return exp[j] < PREDICT_THRESHOLD
    return False

def get_changes_list(start, final):
    perturbations_made = []
    for i, (s, f) in enumerate(zip(start, final)):
        if s != f:
            perturbations_made.append((i, f.item()))
    return perturbations_made

def get_iter_count(context_given, target_given):
    for current_trace_num in range(len(context_given)):
        con, e = context_given[current_trace_num].unsqueeze(0), target_given.unsqueeze(1)[current_trace_num]
        for i in range(MAX_ITER):
            new_context = bim_no_iter(con, e, i)
            if con[0].tolist() != new_context.tolist():
                return i
    return -1

def get_perturbations(context_chosen, event_chosen, attention_query=False):
    perturbed_collected_main = []
    perturbed_indices_main = []
    pred_false, pred_true = get_correct_prediction_for_list(context_chosen, event_chosen.unsqueeze(1), attention_query=attention_query)
    states = [len(pred_false), 0, 0]
    for current_trace_num in tqdm(pred_true):
        con, e = context_chosen[current_trace_num].unsqueeze(0), event_chosen.unsqueeze(1)[current_trace_num]
        perturbed_result, perturb_iterations = bim_attack(con, e, attention_query=attention_query)
        if perturbed_result is not None:
            states[1] += 1
            perturbed_collected_main.append(perturbed_result)
            perturbed_indices_main.append(current_trace_num)
        else:
            states[2] += 1
    return to_cuda(torch.tensor(perturbed_collected_main)), to_cuda(torch.tensor(perturbed_indices_main)), to_cuda(torch.tensor(states))

def get_perturbations_or_file(context_chosen, event_chosen, attention_query=False):
    f_name_perturbed = get_file_name("perturbed_collected.pt", attention_query=attention_query)
    f_name_indices = get_file_name("perturbed_indices.pt", attention_query=attention_query)
    f_name_distribution = get_file_name("perturbed_distribution.pt", attention_query=attention_query)
    if os.path.exists(f_name_perturbed) and os.path.exists(f_name_indices) and os.path.exists(f_name_distribution):
        print(f"Loading {f_name_perturbed}")
        print(f"Loading {f_name_indices}")
        print(f"Loading {f_name_distribution}")
        print(torch.load(f_name_distribution))
        return torch.load(f_name_perturbed), torch.load(f_name_indices), torch.load(f_name_distribution)
    perturb_main, indices_main, result_main = get_perturbations(context_chosen, event_chosen, attention_query=attention_query)
    torch.save(perturb_main, f_name_perturbed)
    torch.save(indices_main, f_name_indices)
    torch.save(result_main, f_name_distribution)
    print(result_main)
    return perturb_main, indices_main, result_main

def get_possible_combinations(perturbations_made):
    subsets = []
    for r_index in range(1, len(perturbations_made)):
        subsets.extend(itertools.combinations(perturbations_made, r_index))
    result = [list(subset) for subset in subsets]
    return result

def get_minimum_change_for_perturbation_no_query(perturbed_chosen, context_chosen, events_chosen):
    for combination in get_possible_combinations(get_changes_list(context_chosen, perturbed_chosen)):
        copy = to_cuda(context_chosen.clone().detach())
        for index_of_change, value_of_change in combination:
            copy[index_of_change] = value_of_change
        if get_performance(to_one_hot(copy.unsqueeze(0)), events_chosen):
            return copy
    return perturbed_chosen

def get_minimum_change_for_perturbation_attention_query(perturbed_chosen, context_chosen, events_chosen):
    trace_combinations = []
    for combination in get_possible_combinations(get_changes_list(context_chosen, perturbed_chosen)):
        copy = to_cuda(context_chosen.clone().detach())
        for index_of_change, value_of_change in combination:
            copy[index_of_change] = value_of_change
        trace_combinations.append(copy)
    if len(trace_combinations) == 0:
        return perturbed_chosen
    mask_indices = get_correct_prediction_for_list(torch.stack(trace_combinations), torch.full((len(trace_combinations), 1), events_chosen.item()))[0]
    chosen = perturbed_chosen
    if len(mask_indices) != 0:
        chosen = trace_combinations[mask_indices[0]]
    return chosen
    
def get_correct_prediction_for_list(context_chosen, events_chosen, attention_query=False):
    _, mask = interpreter.attended_context(
        X           = to_one_hot(context_chosen),
        y           = to_cuda(events_chosen),
        iterations  = 100 if attention_query else 0
    )
    return torch.where(~mask)[0], torch.where(mask)[0]
    
def get_shortcuts_or_file(perturbed_chosen, context_chosen, events_chosen, attention_query=False):
    f_name = get_file_name("shortcuts.pt", attention_query=attention_query)
    if os.path.exists(f_name):
        print(f"Loading {f_name}")
        return torch.load(f_name)
    get_shortcuts_func = get_minimum_change_for_perturbation_attention_query if attention_query else get_minimum_change_for_perturbation_no_query
    pick_list = []
    for perturbed_element  in tqdm(zip(perturbed_chosen, context_chosen, events_chosen), total=len(perturbed_chosen)):
        pick_list.append(get_shortcuts_func(*perturbed_element))
    pick_list =  to_cuda(torch.stack(pick_list))
    torch.save(pick_list, f_name)
    return pick_list
    
    
def get_dist(trace_dist):
    tensor_tuples = [tuple(t) if len(t) > 0 else (t.item(),) for t in trace_dist]
    index_groups = {}
    for idx, tensor_tuple in enumerate(tensor_tuples):
        if tensor_tuple not in index_groups:
            index_groups[tensor_tuple] = []
        index_groups[tensor_tuple].append(idx)
    sorted_groups = sorted(index_groups.items(), key=lambda x: (len(x[0]), x[0]))
    return [idx for _, indices_l in sorted_groups for idx in indices_l]
    
def get_matrix(context_chosen, shortcut_chosen, perturb_chosen):
    matrix = [[0] * 10 for _ in range(10)]
    for c, s, p in zip(context_chosen, shortcut_chosen, perturb_chosen):
        cs = len(get_changes_list(c, s)) - 1
        cp = len(get_changes_list(c, p)) - 1
        matrix[cp][cs] += 1
    for m in matrix:
        print(f"[{", ".join([f'{num:3}' for num in m])}]")
    
def interpret_query(context_passed, events_passed, threshold=0.2):
    c = to_one_hot(context_passed)
    e = events_passed.reshape(-1, 1)
    interpreter.threshold = threshold
    return interpreter.predict(X=c, y=e)

def interpret(context_passed, events_passed, threshold=0.2):
    c = to_one_hot(context_passed)
    e = events_passed.reshape(-1, 1)
    interpreter.threshold = threshold
    return interpreter.predict(X=c, y=e, iterations=0)
    
def get_combined(perturbed_chosen, perturbed_indices_chosen, attention_query=False):
    interpret_func = interpret_query if attention_query else interpret
    context_test_copy = context_test.clone().detach()
    context_test_copy[perturbed_indices_chosen] = perturbed_chosen
    context_test_copy_indices = get_unique_indices_per_row(context_test_copy)
    return interpret_func(context_test_copy[context_test_copy_indices], events_test[context_test_copy_indices]), context_test_copy_indices
    
def bim_attack(context_chosen, event_chosen, attention_query=False):
    change = 0
    original_context = to_one_hot(context_chosen)
    context_processed = to_one_hot(context_chosen)
    for iteration in range(MAX_ITER):
        context_processed.requires_grad_(True)
        output = builder.predict(context_processed)
        start = time.time()
        if get_performance(context_processed, event_chosen[0], attention_query=attention_query):
            return to_trace(context_processed), iteration
        end = time.time()
        # print(f"{end - start=}")
        loss = criterion(output[0][0], event_chosen)
        context_processed.retain_grad()
        loss.backward(retain_graph=True)
        grad = context_processed.grad.sign()
        change += ALPHA * grad
        context_processed = get_perturbation_change(context_processed + change, original_context)
    return None, MAX_ITER
    
def format_results(results):
    results_picked = torch.topk(results[0][0][0], 3)
    exp = results_picked.values.exp()
    res_indices = results_picked.indices
    return ", ".join([f"{format_list([res_indices[j].item()])} {'{:.3f}'.format(exp[j])}" for j in range(3)])
  
def format_changes(start, final):
    changes = []
    same = []
    for s, f in zip(start, final):
        if s == f:
            changes.append("-")
            same.append(final)
        else:
            changes.append(f)
            same.append("XX")
    return format_list(changes), format_list(same)

def format_list(li):
    return f"[{", ".join([f'{num:2}' for num in li])}]"
            
def format_trace_prediction(current_trace_num, trace):
    start = " "*(len(str(current_trace_num)) + 2)
    trace = to_cuda(trace.clone().detach()).unsqueeze(0)
    return f"{start}{format_list(trace)} -> {format_results(to_output(trace))}\n"
            
def format_perturbation(perturb_chosen, perturb_int, context_chosen, event_chosen, current_trace_num):
    result_string = f"{current_trace_num}: {format_list(context_chosen[0].tolist())} == {event_chosen.tolist()} Changed [{perturb_int}], Perturbations [{len(get_changes_list(context_chosen, perturb_chosen))}]\n"
    result_string += format_trace_prediction(current_trace_num, context_chosen)    
    result_string += format_trace_prediction(current_trace_num, perturb_chosen)    
    changed_entries, same_entries = format_changes(context_chosen, perturb_chosen)
    result_string += f"{" "*(len(str(current_trace_num)) - 1)}== {same_entries}\n{" "*(len(str(current_trace_num)) - 1)}-> {changed_entries}\n\n"
    return result_string

def format_attention(x_mask, context_mask, vectors, neighbours, distance, scores):
    data_collected = [] 
    for i in range(len(context_mask)):
        l_value = torch.tensor(vectors[context_mask][i].toarray()[0])
        l_indices = torch.nonzero(l_value, as_tuple=False).squeeze(1)
        data_collected.append({
            "trace": x_mask[context_mask][i],
            "indices": l_indices.tolist(),
            "value": l_value[l_indices].tolist(),
            "neighbour": f"{{{neighbours[i]:5}; {scores[i]} | {'{:.4f}'.format(distance[i][0]) }}}"
        })
    res_str = ""
    for s in get_dist(list(map(lambda x: x["indices"], data_collected))):
        local_l = data_collected[s]
        local_list = list(zip(local_l["indices"], local_l["value"]))
        res = f"{format_list(local_l["trace"].tolist())} -> "
        for index, value in sorted(local_list, key=lambda x: x[1], reverse=True):
            res += f"[{f'{index:2}'}]: {'{:.4f}'.format(value)} "
        res_str += res + f"{" "*(150 - len(res))} {local_l["neighbour"]}\n"
    return res_str

def format_series(series):
    return pd.Series(series).value_counts().sort_index()

def format_confusion_matrix(y_true, y_pred):
    format_labels = sorted(list(set(y_true.tolist()) | set(y_pred.tolist())))
    return pd.DataFrame(confusion_matrix(y_true, y_pred), index=format_labels, columns=format_labels)

def show_clusters(x, y):
    vectors, mask = interpreter.attended_context(to_one_hot(x), y.reshape(-1, 1))
    indices_y = group_by(y[mask].reshape(-1, 1).cpu().numpy(), lambda e: e.data.tobytes(),)
    with open(f"{ALPHA=}, {EPSILON=}/cluster.txt", "w") as file:
        for event, context_mask in indices_y:
            event = ord(event.decode('ascii')[0])   
            if event not in interpreter.tree:
                file.write(f"{"<"*50}[{event}]{">"*50}\n")
                continue
            file.write(f"{"="*50}[{event}]{"="*50}\n")
            vectors_, inverse, _ = sp_unique(vectors[context_mask])
            distance, neighbours = interpreter.tree[event].query(
                X               = vectors_.toarray(),
                return_distance = True,
                dualtree        = vectors_.shape[0] >= 1e3, # Optimization
            )
            neighbours = interpreter.tree[event].get_arrays()[1][neighbours][:, 0]
            scores = np.asarray([interpreter.labels[event][neighbour] for neighbour in neighbours])
            file.write(format_attention(x[mask], context_mask, vectors, neighbours[inverse], distance[inverse], scores[inverse]) + "\n")

def bim_no_iter(context_chosen, event_chosen, iters):
    context_processed = to_one_hot(context_chosen)
    original_context = to_one_hot(context_chosen)
    context_processed.requires_grad_(True)
    output = builder.predict(context_processed)
    loss = criterion(output[0][0], event_chosen)
    context_processed.retain_grad()
    loss.backward(retain_graph=True)
    context_processed = get_perturbation_change(context_processed + (ALPHA * context_processed.grad.sign()) * iters, original_context)
    return to_cuda(torch.tensor(torch.argmax(context_processed, dim=-1).tolist()[0]))

def run_bim(context_to_process, events_to_process):
    perturbed_collected_main = []
    perturbed_indices_main = []
    states = [0, 0, 0]
    length = len(context_to_process)
    iters = tqdm(range(length))
    iters_to_flip = get_iter_count(context_to_process, events_to_process)
    for current_trace_num in iters:
        con, e = context_to_process[current_trace_num].unsqueeze(0), events_to_process.unsqueeze(1)[current_trace_num]
        if get_performance(to_one_hot(con), e[0]):
            states[0] += 1
            continue
        perturbed = bim_no_iter(con, e, iters=iters_to_flip).unsqueeze(0)
        if get_performance(perturbed, e[0]):
            perturbed_collected_main.append(perturbed.tolist()[0])
            perturbed_indices_main.append(current_trace_num)
            states[1] += 1
        else:
            states[2] += 1
    print(f"incorrect={states[0]} perturbed={states[1]} timeout={states[2]}")
    return to_cuda(torch.tensor(perturbed_collected_main)), to_cuda(torch.tensor(perturbed_indices_main))

# Base Data

In [2]:
format_series(labels_test.cpu())

1      65
2     454
3    3864
5       9
Name: count, dtype: int64

### Without Attention Query

In [3]:
res_normal = interpret(context_test, events_test)
format_series(res_normal)

-3.0    2675
-1.0    1074
 2.0     147
 3.0     480
 5.0      16
Name: count, dtype: int64

In [4]:
format_confusion_matrix(labels_test.cpu(), res_normal)

Unnamed: 0,-3.0,-1.0,1.0,2.0,3.0,5.0
-3.0,0,0,0,0,0,0
-1.0,0,0,0,0,0,0
1.0,33,20,0,0,0,12
2.0,221,86,0,147,0,0
3.0,2420,964,0,0,480,0
5.0,1,4,0,0,0,4


### Attention Query

In [5]:
res_normal_query = interpret_query(context_test, events_test)
format_series(res_normal_query)

-3.0     467
-1.0     490
 2.0     378
 3.0    3025
 5.0      32
Name: count, dtype: int64

In [6]:
format_confusion_matrix(labels_test.cpu(), res_normal_query)

Unnamed: 0,-3.0,-1.0,1.0,2.0,3.0,5.0
-3.0,0,0,0,0,0,0
-1.0,0,0,0,0,0,0
1.0,25,13,0,0,0,27
2.0,21,55,0,378,0,0
3.0,420,419,0,0,3025,0
5.0,1,3,0,0,0,5


# Perturbed Data

### Without Attention Query

In [152]:
perturbed_collected, perturbed_indices, perturb_distribution = get_perturbations_or_file(context_test, events_test)

100%|██████████| 3318/3318 [02:11<00:00, 25.21it/s]

tensor([1074, 3033,  285], device='cuda:0')





In [153]:
print("Unique:", len(get_unique_indices_per_row(perturbed_collected)))
interpreter_perturbed = interpret(perturbed_collected, events_test[perturbed_indices])
format_series(interpreter_perturbed)

Unique: 2409


-1.0    3033
Name: count, dtype: int64

In [154]:
format_confusion_matrix(labels_test[perturbed_indices].cpu(), interpret(context_test[perturbed_indices], events_test[perturbed_indices]))

Unnamed: 0,-3.0,1.0,2.0,3.0,5.0
-3.0,0,0,0,0,0
1.0,33,0,0,0,2
2.0,163,0,121,0,0
3.0,2315,0,0,398,0
5.0,1,0,0,0,0


In [155]:
format_confusion_matrix(labels_test[perturbed_indices].cpu(), interpreter_perturbed)

Unnamed: 0,-1.0,1.0,2.0,3.0,5.0
-1.0,0,0,0,0,0
1.0,35,0,0,0,0
2.0,284,0,0,0,0
3.0,2713,0,0,0,0
5.0,1,0,0,0,0


In [156]:
interpret_perturbed_combined, interpret_perturbed_indices = get_combined(perturbed_collected, perturbed_indices)
format_series(interpret_perturbed_combined)

-3.0     162
-1.0    3476
 2.0      19
 3.0      82
 5.0      14
Name: count, dtype: int64

In [157]:
format_confusion_matrix(labels_test[interpret_perturbed_indices].cpu(), interpret_perturbed_combined)

Unnamed: 0,-3.0,-1.0,1.0,2.0,3.0,5.0
-3.0,0,0,0,0,0,0
-1.0,0,0,0,0,0,0
1.0,0,49,0,0,0,10
2.0,57,310,0,19,0,0
3.0,105,3113,0,0,82,0
5.0,0,4,0,0,0,4


### Applying No Attention Query Data to Attention Query

In [158]:
print("Unique:", len(get_unique_indices_per_row(perturbed_collected)))
interpreter_perturbed_1 = interpret_query(perturbed_collected, events_test[perturbed_indices])
format_series(interpreter_perturbed_1)

Unique: 2409


-3.0     127
-1.0    2781
 2.0      20
 3.0     105
Name: count, dtype: int64

In [159]:
format_confusion_matrix(labels_test[perturbed_indices].cpu(), interpreter_perturbed_1)

Unnamed: 0,-3.0,-1.0,1.0,2.0,3.0,5.0
-3.0,0,0,0,0,0,0
-1.0,0,0,0,0,0,0
1.0,2,33,0,0,0,0
2.0,9,255,0,20,0,0
3.0,116,2492,0,0,105,0
5.0,0,1,0,0,0,0


In [162]:
interpret_perturbed_combined_1, interpret_perturbed_indices_1 = get_combined(perturbed_collected, perturbed_indices, attention_query=True)
format_series(interpret_perturbed_combined_1)

-3.0     358
-1.0    2651
 2.0     107
 3.0     623
 5.0      14
Name: count, dtype: int64

In [163]:
format_confusion_matrix(labels_test[interpret_perturbed_indices_1].cpu(), interpret_perturbed_combined_1)

Unnamed: 0,-3.0,-1.0,1.0,2.0,3.0,5.0
-3.0,0,0,0,0,0,0
-1.0,0,0,0,0,0,0
1.0,9,40,0,0,0,10
2.0,27,252,0,107,0,0
3.0,321,2356,0,0,623,0
5.0,1,3,0,0,0,4


### Attention Query

In [164]:
perturbed_collected_query, perturbed_indices_query, perturb_distribution_query = get_perturbations_or_file(context_test, events_test, attention_query=True)

100%|██████████| 3902/3902 [1:10:41<00:00,  1.09s/it]

tensor([ 490, 3598,  304], device='cuda:0')





In [165]:
print("Unique:", len(get_unique_indices_per_row(perturbed_collected_query)))
interpret_query_perturbed = interpret_query(perturbed_collected_query, events_test[perturbed_indices_query])
format_series(interpret_query_perturbed)

Unique: 2805


-1.0    3598
Name: count, dtype: int64

In [166]:
format_confusion_matrix(labels_test[perturbed_indices].cpu(), interpret_query(context_test[perturbed_indices], events_test[perturbed_indices]))

Unnamed: 0,-3.0,1.0,2.0,3.0,5.0
-3.0,0,0,0,0,0
1.0,18,0,0,0,17
2.0,2,0,282,0,0
3.0,209,0,0,2504,0
5.0,0,0,0,0,1


In [168]:
format_confusion_matrix(labels_test[perturbed_indices_query].cpu(), interpret_query_perturbed)

Unnamed: 0,-1.0,1.0,2.0,3.0,5.0
-1.0,0,0,0,0,0
1.0,42,0,0,0,0
2.0,310,0,0,0,0
3.0,3244,0,0,0,0
5.0,2,0,0,0,0


In [169]:
interpret_query_perturbed_combined, _ = get_combined(perturbed_collected, perturbed_indices)
format_series(interpret_query_perturbed_combined)

-3.0     162
-1.0    3476
 2.0      19
 3.0      82
 5.0      14
Name: count, dtype: int64

# Shortcuts

### Without Attention Query

In [170]:
perturbed_minimized = get_shortcuts_or_file(
    perturbed_collected, 
    context_test[perturbed_indices], 
    events_test[perturbed_indices]
)

100%|██████████| 3033/3033 [00:28<00:00, 108.25it/s]


In [171]:
get_matrix(context_test[perturbed_indices], perturbed_minimized, perturbed_collected)

[ 50,   0,   0,   0,   0,   0,   0,   0,   0,   0]
[109,   2,   0,   0,   0,   0,   0,   0,   0,   0]
[252,   7,   2,   0,   0,   0,   0,   0,   0,   0]
[302,   9,   0,   0,   0,   0,   0,   0,   0,   0]
[357,  10,   2,   1,   0,   0,   0,   0,   0,   0]
[387,  15,   0,   2,   1,   0,   0,   0,   0,   0]
[357,   3,   0,   0,   0,   0,   1,   0,   0,   0]
[398,   0,   0,   0,   0,   0,   0,   0,   0,   0]
[354,   1,   0,   0,   1,   1,   0,   0,   0,   0]
[409,   0,   0,   0,   0,   0,   0,   0,   0,   0]


In [172]:
print("Unique:", len(get_unique_indices_per_row(perturbed_minimized)))
res_perturbed_shortcuts = interpret(perturbed_minimized, events_test[perturbed_indices])
format_series(res_perturbed_shortcuts)

Unique: 3009


-1.0    3033
Name: count, dtype: int64

In [173]:
format_confusion_matrix(labels_test[perturbed_indices].cpu(), res_perturbed_shortcuts)

Unnamed: 0,-1.0,1.0,2.0,3.0,5.0
-1.0,0,0,0,0,0
1.0,35,0,0,0,0
2.0,284,0,0,0,0
3.0,2713,0,0,0,0
5.0,1,0,0,0,0


In [174]:
interpret_shortcuts_combined, interpret_shortcuts_indices = get_combined(perturbed_minimized, perturbed_indices)
format_series(interpret_shortcuts_combined)

-3.0     160
-1.0    4081
 2.0      25
 3.0      82
 5.0      14
Name: count, dtype: int64

In [175]:
format_confusion_matrix(labels_test[interpret_shortcuts_indices].cpu(), interpret_shortcuts_combined)

Unnamed: 0,-3.0,-1.0,1.0,2.0,3.0,5.0
-3.0,0,0,0,0,0,0
-1.0,0,0,0,0,0,0
1.0,0,55,0,0,0,10
2.0,55,367,0,25,0,0
3.0,105,3654,0,0,82,0
5.0,0,5,0,0,0,4


### Applying No Attention Query Data to Attention Query

In [176]:
print("Unique:", len(get_unique_indices_per_row(perturbed_minimized)))
res_perturbed_shortcuts_1 = interpret_query(perturbed_minimized, events_test[perturbed_indices])
format_series(res_perturbed_shortcuts_1)

Unique: 3009


-3.0    1434
-1.0     396
 2.0     159
 3.0    1038
 5.0       6
Name: count, dtype: int64

In [177]:
format_confusion_matrix(labels_test[perturbed_indices].cpu(), res_perturbed_shortcuts_1)

Unnamed: 0,-3.0,-1.0,1.0,2.0,3.0,5.0
-3.0,0,0,0,0,0,0
-1.0,0,0,0,0,0,0
1.0,29,0,0,0,0,6
2.0,90,35,0,159,0,0
3.0,1314,361,0,0,1038,0
5.0,1,0,0,0,0,0


In [178]:
interpret_shortcuts_combined_1, interpret_shortcuts_indices_1 = get_combined(perturbed_minimized, perturbed_indices, attention_query=True)
format_series(interpret_shortcuts_combined_1)

-3.0    1668
-1.0     876
 2.0     250
 3.0    1548
 5.0      20
Name: count, dtype: int64

In [179]:
format_confusion_matrix(labels_test[interpret_shortcuts_indices_1].cpu(), interpret_shortcuts_combined_1)

Unnamed: 0,-3.0,-1.0,1.0,2.0,3.0,5.0
-3.0,0,0,0,0,0,0
-1.0,0,0,0,0,0,0
1.0,36,13,0,0,0,16
2.0,109,88,0,250,0,0
3.0,1521,772,0,0,1548,0
5.0,2,3,0,0,0,4


### Attention Query

In [190]:
perturbed_minimized_query = get_shortcuts_or_file(
    perturbed_collected_query,
    context_test[perturbed_indices_query], 
    events_test[perturbed_indices_query],
    attention_query=True
)

100%|██████████| 3598/3598 [02:38<00:00, 22.73it/s]


In [192]:
get_matrix(context_test[perturbed_indices_query], perturbed_minimized_query, perturbed_collected_query)

[ 70,   0,   0,   0,   0,   0,   0,   0,   0,   0]
[173,   0,   0,   0,   0,   0,   0,   0,   0,   0]
[275,   6,   0,   0,   0,   0,   0,   0,   0,   0]
[380,   9,   0,   0,   0,   0,   0,   0,   0,   0]
[381,  10,   1,   0,   0,   0,   0,   0,   0,   0]
[441,  15,   0,   0,   0,   0,   0,   0,   0,   0]
[413,   3,   0,   0,   0,   0,   0,   0,   0,   0]
[490,   0,   0,   0,   0,   0,   0,   0,   0,   0]
[423,   1,   0,   0,   0,   0,   0,   0,   0,   0]
[507,   0,   0,   0,   0,   0,   0,   0,   0,   0]


In [196]:
print("Unique:", len(get_unique_indices_per_row(perturbed_minimized_query)))
res_perturbed_shortcuts_query = interpret_query(perturbed_minimized_query, events_test[perturbed_indices_query])
format_series(res_perturbed_shortcuts_query)

Unique: 3572


-3.0    1636
-1.0     446
 2.0     172
 3.0    1338
 5.0       6
Name: count, dtype: int64

In [197]:
format_confusion_matrix(labels_test[perturbed_indices_query].cpu(), res_perturbed_shortcuts_query)

Unnamed: 0,-3.0,-1.0,1.0,2.0,3.0,5.0
-3.0,0,0,0,0,0,0
-1.0,0,0,0,0,0,0
1.0,36,0,0,0,0,6
2.0,100,38,0,172,0,0
3.0,1498,408,0,0,1338,0
5.0,2,0,0,0,0,0


In [199]:
interpret_query_shortcuts_combined_query, interpret_query_shortcuts_indices_query = get_combined(perturbed_minimized_query, perturbed_indices_query)
format_series(interpret_query_shortcuts_combined_query)

-3.0     164
-1.0    4075
 2.0      25
 3.0      82
 5.0      14
Name: count, dtype: int64

In [200]:
format_confusion_matrix(labels_test[interpret_query_shortcuts_indices_query].cpu(), interpret_query_shortcuts_combined_query)

Unnamed: 0,-3.0,-1.0,1.0,2.0,3.0,5.0
-3.0,0,0,0,0,0,0
-1.0,0,0,0,0,0,0
1.0,0,55,0,0,0,10
2.0,56,366,0,25,0,0
3.0,108,3649,0,0,82,0
5.0,0,5,0,0,0,4


In [201]:
show_clusters(context_test, events_test)