In [16]:
import os
import torch
import itertools
import pandas as pd
from tqdm import tqdm
from deepcase_copy.context_builder.context_builder import ContextBuilder
from deepcase_copy.interpreter.interpreter import Interpreter
from deepcase_copy.interpreter.utils import group_by
from deepcase_copy.context_builder.loss import LabelSmoothing
from sklearn.model_selection import train_test_split

def disable_dropout(m):
    if isinstance(m, torch.nn.Dropout):
        m.p = 0.0 

def to_cuda(item):
    if torch.cuda.is_available():
        return item.to('cuda')
    return item

def get_unique_indices_per_row(tensor, f_name):
    if os.path.exists(f_name):
        return torch.load(f_name)
    indices_list = []
    row_list = []
    indices_list_set = set()
    for row in range(len(tensor)):
        curr = tuple(tensor[row].tolist())
        if curr in indices_list_set:
            continue
        row_list.append(curr)
        indices_list.append(row)
        indices_list_set.add(curr)
    t = torch.tensor(indices_list)
    torch.save(t, f_name)
    return t  

ALPHA=0.01
EPSILON=0.5
MAX_ITER = 100
PERTURB_THRESHOLD = 3
builder = to_cuda(ContextBuilder.load('save/save/builder.save'))
builder.apply(disable_dropout)
interpreter = Interpreter.load('save/save/interpreter.save', builder)
criterion = LabelSmoothing(builder.decoder_event.out.out_features, 0.1)    

with open('save/save/sequences.save', 'rb') as infile:
    data = torch.load(infile)
    events  = data["events"]
    context = data["context"] # 172572
    labels  = data["labels"]
    
    indices = get_unique_indices_per_row(context, 'save/save/context.pt')    
    train_indices, test_indices = train_test_split(indices, test_size=0.2, random_state=42, stratify=labels[indices])
    
    context_train = to_cuda(context[train_indices]) 
    context_test  = to_cuda(context[test_indices]) 
    
    events_train  = to_cuda(events[train_indices])
    events_test   = to_cuda(events[test_indices])
    
    labels_train  = to_cuda(labels[train_indices])  # 34514
    labels_test   = to_cuda(labels[test_indices])  # 138058
    
    clusters        = pd.read_csv('save/clusters.csv')['clusters'].values
    clusters_train  = clusters[train_indices]
    clusters_test   = clusters[test_indices]

    prediction_label        = pd.read_csv('save/prediction.csv')['labels'].values
    prediction_label_train  = prediction_label[train_indices]
    prediction_label_test   = prediction_label[test_indices]

In [17]:
def to_one_hot(t):
    return builder.embedding_one_hot(t)

def max_to_one(tensor):
    max_indices = torch.argmax(tensor, dim=-1, keepdim=True)
    result = torch.zeros_like(tensor)
    result.scatter_(-1, max_indices, 1.0)
    return result

def get_results(results):
    results_picked = torch.topk(results[0][0][0], 3)
    exp = results_picked.values.exp()
    res_indices = results_picked.indices
    s = []
    for j in range(3):
        s.append(f"{format_list([res_indices[j].item()])} {'{:.3f}'.format(exp[j])}")
    return res_indices, ", ".join(s)

def compute_change(trace, original, epsilon=EPSILON):
    a = original - epsilon
    b = (trace >= a).float() * trace + (trace < a).float() * a
    c = (b > original + epsilon).float() * (original + epsilon) + (b <= original + epsilon).float() * b
    return max_to_one(c), c

def bim_attack(context_given, target_given, alpha=0.1, epsilon=EPSILON, num_iterations=MAX_ITER, training=True):
    change = None
    original_context = to_cuda(to_one_hot(context_given).clone().detach())
    context_processed = to_cuda(to_one_hot(context_given).clone().detach())
    changes = []
    computed_change_collected = []
    for i in range(num_iterations):
        context_processed.requires_grad_(True)
        output = builder.predict(context_processed, training=training)
        indices_of_results, prediction_str = get_results(output)
        changes.append({
            "changed_to": torch.argmax(context_processed, axis=-1).tolist()[0],
            "prediction_str": prediction_str,
        })
        if target_given[0] != indices_of_results[0]:
            break
        loss = criterion(output[0][0], target_given)
        context_processed.retain_grad()
        loss.backward(retain_graph=True)
        grad = context_processed.grad.sign()
        if change is None:
            change = alpha * grad
        else:
            change += alpha * grad
        context_processed = context_processed + change
        context_processed, computed_change = compute_change(context_processed, original_context, epsilon)
        computed_change_collected.append(computed_change)
    return changes, computed_change_collected

def count_changes(changes_needed):
    orig = changes_needed[0]["changed_to"]
    final = changes_needed[-1]["changed_to"]
    return count_list_diff(orig, final)

def count_list_diff(orig, final):
    changed_entries = 0
    for orig, final in zip(orig, final):
        if orig != final:
            changed_entries += 1
    return changed_entries
    
def show_changes(changes_needed):
    orig = changes_needed[0]["changed_to"]
    final = changes_needed[-1]["changed_to"]
    changes = []
    same = []
    for orig, final in zip(orig, final):
        if orig == final:
            changes.append("-")
            same.append(final)
        else:
            changes.append(final)
            same.append("XX")
    return format_list(changes), format_list(same)

def format_list(li):
    return f"[{", ".join([f'{num:2}' for num in li])}]"
            
def format_line(current_trace_num, changes_needed, i):
    start = " "*(len(str(current_trace_num)) + 2)
    if i != 0:
        if len(str(current_trace_num)) == 1:
            start = f'{i:<3}'
        elif len(str(current_trace_num)) == 2:
            start = f'{i:<4}'
        elif len(str(current_trace_num)) == 3:
            start = f'{i:<5}'
        elif len(str(current_trace_num)) == 4:
            start = f'{i:<6}'
        else:
            start = f'{i:<7}'
        b = list(start)
        b[len(str(i))] = '<'
        start = ''.join(b)
    return f"{start}{format_list(changes_needed[i]["changed_to"])} -> {changes_needed[i]['prediction_str']}\n"
            
def print_state(changes_needed, current_trace_num, con, event_chosen, change_collected, print_path=True, include_change=True, num_iterations=MAX_ITER):
    mode_int = 0
    changed_num = len(changes_needed)
    perturbations_num = count_changes(changes_needed)
    result_string = ""
    if changed_num == 1:
        pass
    elif changed_num == num_iterations:
        mode_int = 3
    else:
        mode_int = 1
        if perturbations_num <= PERTURB_THRESHOLD:
            mode_int = 2
            result_string += f"{current_trace_num}: {format_list(con[0].tolist())} == {event_chosen.tolist()} Changed {{{changed_num}}}, Perturbations {{{perturbations_num}}}\n"
            if print_path:
                for i in range(len(changes_needed)):
                    result_string += format_line(current_trace_num, changes_needed, i)                    
                    if include_change:
                        if i is not len(changes_needed) - 1:
                            result_string += f"{" "*(len(str(current_trace_num)) + 2)} {change_collected[i]=}\n"
            else:
                for i in range(len(changes_needed)):
                    if i == 0 or (changes_needed[i]["changed_to"] != changes_needed[i-1]["changed_to"] and i != len(changes_needed) - 1):
                        result_string += format_line(current_trace_num, changes_needed, i)
                change_last = changes_needed[-1]
                result_string += f"{" "*(len(str(current_trace_num)) + 2)}{format_list(change_last["changed_to"])} -> {change_last['prediction_str']}\n"
            changed_entries, same_entries = show_changes(changes_needed)
            result_string += f"{" "*(len(str(current_trace_num)) - 1)}== {same_entries}\n"
            result_string += f"{" "*(len(str(current_trace_num)) - 1)}-> {changed_entries}\n"
            result_string += "\n" 
    return mode_int, result_string

def process_traces(context_to_process, events_to_process, alpha=ALPHA, epsilon=EPSILON, num_iterations=MAX_ITER, chosen_index=0, print_path=False, include_change=False, write_to_file=False, print_result=False, training=True):
    length = len(context_to_process)
    dest = chosen_index+length
    perturbed_collected_main = []
    perturbed_indices = []
    states = [0, 0, 0, 0]
    safe_to_file = ""
    iters = range(length)
    if not print_result:
        iters = tqdm(iters)
    for current_trace_num in iters:
        con, e = context_to_process[chosen_index:dest][current_trace_num], events_to_process[chosen_index:dest].unsqueeze(1)[current_trace_num]
        con.resize_(1, con.size()[-1])
        changes_needed, change_collected = bim_attack(context_given=con, target_given=e, alpha=alpha, epsilon=epsilon, num_iterations=num_iterations, training=training)
        mode_int, result_string = print_state(changes_needed, current_trace_num, con, e, change_collected, print_path=print_path, include_change=include_change, num_iterations=num_iterations)
        if print_result:
            print(result_string, end="")
        safe_to_file += result_string
        if mode_int == 2:
            perturbed_indices.append(current_trace_num)
            perturbed_collected_main.append((current_trace_num, changes_needed[0]['changed_to'], changes_needed[-1]['changed_to']))
        states[mode_int] += 1
    print(f"incorrect={states[0]} changed={states[1]} perturbed={states[2]} timeout={states[3]}")
    safe_to_file += f"incorrect={states[0]} changed={states[1]} perturbed={states[2]} timeout={states[3]}"
    if write_to_file:
        with open(f"results_trace/length={length}, alpha={alpha}, epsilon={epsilon}, num_iterations={num_iterations}, print_path={print_path} include_change={include_change}.txt", "w") as f:
            f.write(safe_to_file)
    return perturbed_collected_main, perturbed_indices

def inspect_index(index_inspected, context_l, events_l, training=True):
    local_context_picked = to_cuda(context_l[index_inspected:index_inspected+1].clone().detach())
    local_events_picked = to_cuda(events_l[index_inspected:index_inspected+1].clone().detach())
    process_traces(local_context_picked, local_events_picked, alpha=ALPHA, epsilon=EPSILON, num_iterations=MAX_ITER, print_path=True, include_change=False, write_to_file=True, print_result=True, training=training)

In [18]:
def get_changes_list(s, f):
    perturbations_made = []
    for i in range(len(s)):
        if s[i] != f[i]:
            perturbations_made.append((i, f[i]))
            
    return perturbations_made

def get_possible_combinations(perturbations_made):
    subsets = []
    for r in range(1, len(perturbations_made) + 1):
        subsets.extend(itertools.combinations(perturbations_made, r))
    result = [list(subset) for subset in subsets]
    return result

def get_minimum_change_for_perturbation(perturbed_chosen, events_picked, index_in_list, training=True):
    i, s, f = perturbed_chosen[index_in_list]
    event_target = events_picked[i]
    combination_of_perturbation = get_possible_combinations(get_changes_list(s, f))
    for i in range(len(combination_of_perturbation)):
        combination = combination_of_perturbation[i]
        copy = to_cuda(torch.tensor(s).detach())
        for index_of_change, value_of_change in combination:
            copy[index_of_change] = value_of_change
        copy.resize_(1, copy.size()[-1])
        output = builder.predict(to_one_hot(copy), training=training)
        indices_of_results, _ = get_results(output)
        if event_target != indices_of_results[0]:
            return copy, f"Shortcut: {combination_of_perturbation[-1]} -> {combination}\n" if i == len(combination_of_perturbation) - 1 and len(combination_of_perturbation) != 1 else ""
    # print(f"ERROR: Could not find a perturbation for {index_in_list} [{i}]")
    # print(f"       Cannot go from {format_list(s)}")
    # print(f"                      {format_list(f)}")
    
    return None, False

def process_single(context_chosen, training=True):
    result_string = ""
    context_chosen = to_cuda(context_chosen)
    output = predict_single(context_chosen, training=training)
    attentions = [round(x, 5) for x in output[1][0][0].tolist()]
    indices_of_results, prediction_str = get_results(output)
    result_string += f"{format_list(context_chosen[0].tolist())} -> {prediction_str}\n"
    for c, a in zip(context_chosen[0], attentions):
        result_string += f"[{c:2}] {'{:.5f}'.format(a)} "
    result_string += "\n"
    return result_string
    
def predict_single(context_chosen, training=True):
    context_chosen.resize_(1, context_chosen.size()[-1])
    context_one_hot = to_one_hot(context_chosen)
    return builder.predict(context_one_hot, training=training)

def predict_single_from_list(list_picked, index_picked):
    return predict_single(to_cuda(list_picked[index_picked:index_picked+1].detach()))

def analysis(perturbed_chosen, events_picked, index_picked, training=True):
    result_string = ""
    i, s, f = perturbed_chosen[index_picked]
    minimum_change_for_perturbation, final_combination = get_minimum_change_for_perturbation(perturbed_chosen, events_picked, index_picked, training=training)
    if minimum_change_for_perturbation is not None:
        result_string += f"Analyzing [{i}], Perturbations [{count_list_diff(s, f)}], Index [{index_picked}]\n"
        result_string += process_single(torch.tensor(s), training=training)
        result_string += process_single(minimum_change_for_perturbation, training=training)
        result_string += final_combination
    else:
        result_string += f"Change did not work on [{i}], Index [{index_picked}]\n"
    result_string += "\n"
    return result_string

def store(perturbed_chosen, events_picked, chosen_index=0, alpha=ALPHA, epsilon=EPSILON, num_iterations=MAX_ITER, training=True, print_details=False):
    with open(f"results_attention/length={len(perturbed_chosen)}, alpha={alpha}, epsilon={epsilon}, num_iterations={num_iterations}.txt", "w") as file:
        safe_to_file = ""
        skipped = 0
        shortcuts = 0
        for perturbed_element in range(len(perturbed_chosen)):
            r_string = analysis(perturbed_chosen, events_picked[chosen_index:chosen_index+len(perturbed_chosen)], perturbed_element, training=training)
            if "Change did not work on" in r_string:
                skipped += 1
            if "Shortcut" in r_string:
                shortcuts += 1
            if print_details:
                print(r_string, end="")
            safe_to_file += r_string
        skipped_str = f"Results: {skipped=}/{shortcuts=} out of {len(perturbed_chosen)}"
        print(skipped_str)
        safe_to_file += skipped_str
        file.write(safe_to_file)

In [19]:
perturbed_collected, perturbed_indices_collected = process_traces(context_test, events_test, write_to_file=True)  

  2%|▏         | 96/4392 [00:11<08:51,  8.09it/s]


KeyboardInterrupt: 

In [148]:
store(perturbed_collected, events_picked=events_test)

IndexError: index 797 is out of bounds for dimension 0 with size 792

In [20]:
def interpret():
    c = to_one_hot(context_test)
    e = events_test.reshape(-1, 1)
    l = labels_test.cpu()    
    # interpreter.cluster(c, e)
    # scores = interpreter.score_clusters(l)
    # interpreter.score(scores)
    return interpreter.predict(X=c, y=e)


res_i = interpret()
pd.Series(res_i).value_counts().sort_index()

-3.0     467
-1.0     490
 2.0     378
 3.0    3025
 5.0      32
Name: count, dtype: int64

In [23]:
import numpy as np
from deepcase_copy.interpreter.utils import sp_unique


def sort_dist(trace_dist):
    tensor_tuples = [tuple(t) if len(t) > 0 else (t.item(),) for t in trace_dist]
    index_groups = {}
    for idx, tensor_tuple in enumerate(tensor_tuples):
        if tensor_tuple not in index_groups:
            index_groups[tensor_tuple] = []
        index_groups[tensor_tuple].append(idx)
    sorted_groups = sorted(index_groups.items(), key=lambda x: (len(x[0]), x[0]))
    return [idx for _, indices in sorted_groups for idx in indices]
    
def print_attention(x_mask, context_mask, vectors, neighbours, distance, scores):
    data_collected = [] 
    for i in range(len(context_mask)):
        l_value = torch.tensor(vectors[context_mask][i].toarray()[0])
        l_indices = torch.nonzero(l_value, as_tuple=False).squeeze(1)
        data_collected.append({
            "trace": x_mask[context_mask][i],
            "indices": l_indices.tolist(),
            "value": l_value[l_indices].tolist(),
            "neighbour": f"{{{neighbours[i]:5}; {scores[i]} | {'{:.4f}'.format(distance[i][0]) }}}"
        })
    for s in sort_dist(list(map(lambda x: x["indices"], data_collected))):
        local_l = data_collected[s]
        local_list = list(zip(local_l["indices"], local_l["value"]))
        res = f"{format_list(local_l["trace"].tolist())} -> "
        for index, value in sorted(local_list, key=lambda x: x[1], reverse=True):
            res += f"[{f'{index:2}'}]: {'{:.4f}'.format(value)} "
        res += f"{" "*(150 - len(res))} {local_l["neighbour"]}"
        print(res)
    print()

def show_clusters(X, y):
    vectors, mask = interpreter.attended_context(to_one_hot(X), y.reshape(-1, 1))
    indices_y = group_by(y[mask].reshape(-1, 1).cpu().numpy(), lambda x: x.data.tobytes(),)
    for event, context_mask in indices_y:
        event = ord(event.decode('ascii')[0])   
        if event not in interpreter.tree:
            print(f"{"<"*50}[{event}]{">"*50}")
            continue
        print(f"{"="*50}[{event}]{"="*50}")
        vectors_, inverse, _= sp_unique(vectors[context_mask])
        distance, neighbours = interpreter.tree[event].query(
            X               = vectors_.toarray(),
            return_distance = True,
            dualtree        = vectors_.shape[0] >= 1e3, # Optimization
        )
        neighbours = interpreter.tree[event].get_arrays()[1][neighbours][:, 0]
        scores = np.asarray([
            interpreter.labels[event][neighbour] for neighbour in neighbours
        ])
        print_attention(X[mask], context_mask, vectors, neighbours[inverse], distance[inverse], scores[inverse])
        
show_clusters(context_test, events_test)

[72, 72, 72,  0,  0,  0,  0,  0,  0, 72] -> [72]: 0.9598 [ 0]: 0.0402                                                                                  {  897; 3.0 | 0.0396}
[72, 72, 72, 72,  0,  0,  0,  0,  0,  0] -> [72]: 0.8938 [ 0]: 0.1062                                                                                  {  897; 3.0 | 0.1716}
[ 0, 72, 72, 72, 72, 72, 72, 72, 72, 72] -> [72]: 0.9981 [ 0]: 0.0019                                                                                  {  850; 3.0 | 0.0038}
[ 2,  2,  2,  2,  2,  2, 72, 72, 72, 72] -> [72]: 0.9783 [ 2]: 0.0217                                                                                  {  644; 3.0 | 0.0118}
[ 7, 72, 72, 72, 72, 72, 72, 72, 72, 72] -> [72]: 0.9971 [ 7]: 0.0029                                                                                  {  850; 3.0 | 0.0058}
[ 9,  9,  9,  9,  9,  9,  9, 72, 72, 72] -> [72]: 0.9716 [ 9]: 0.0284                                                                  