# Constants

In [1]:
# --- Dataset -----------------------------------------------------------------
# CSV holding the item catalogue; its "ItemId" and "ItemName" columns are used
# to translate numeric item ids into readable names.
DATASET_ITEM_NAMES = "dataset/case_study_items.csv"

# --- Encoding ----------------------------------------------------------------
# Input items whose min-max normalised attention is below this threshold are
# masked as "?" before being handed to LFIT.
MIN_ATTENTION = 0.2

# --- LFIT --------------------------------------------------------------------
# Learning algorithm name passed to `model.compile`.
ALGORITHM = "gula"
# Number of threads passed to `model.fit`.
THREADS = 8
# Heuristics passed to `model.fit`.
HEURISTICS = [
    "try_all_atoms",
    "max_coverage_dynamic",
    "max_coverage_static",
    "max_diversity",
]
# Name of the target column written to the LFIT input file.
TARGET = "Next"
# Intermediate CSV written for (and re-read by) pylfit; overwritten per model.
LFIT_INPUT_FILE = "tmp/lfit_input.csv"
# Number of top-ranked predictions per session kept as learning targets.
TOP_K = 3
# (model predictions CSV, learned rules output CSV) pairs, processed in order.
IN_OUT_FILES = [
    ("tmp/docomo_sr-gnn_predictions.csv", "tmp/sr-gnn_lfit_output.csv"),
    ("tmp/docomo_stamp_predictions.csv", "tmp/stamp_lfit_output.csv"),
    ("tmp/docomo_narm_predictions.csv", "tmp/narm_lfit_output.csv"),
]

# --- Others ------------------------------------------------------------------
# Seed for Python's `random` module.
RANDOM_SEED = 0

# Imports

In [2]:
import pandas as pd
import random
import pylfit

# Seed Python's global `random` generator so that anything drawing from the
# `random` module behaves the same on every Restart & Run All.
# NOTE(review): only the stdlib generator is seeded here; whether pylfit relies
# on it (or on another RNG such as numpy's) is not visible in this notebook — confirm.
random.seed(RANDOM_SEED)


In [3]:
# Load the item catalogue and build the id -> readable-name lookup used when
# decoding learned rules.
df_items = pd.read_csv(DATASET_ITEM_NAMES)
item_name = {
    item_id: name
    for item_id, name in zip(df_items["ItemId"], df_items["ItemName"])
}
print(item_name)

def decode(rule, features, targets, item_name):
    """Translate a learned rule into a human-readable form using item names.

    The rule is rendered with ``rule.logic_form(features, targets)``, which is
    expected to look like ``"Next(12) :- focus_11(True), last_item(9)."``.
    Body atoms are decoded as follows:

    * ``last_item(<id>)``  -> ``"last(<item name>)"``; ``last_item(?)`` is skipped.
    * ``focus_<id>(True)`` -> ``"<item name>"``; ``focus_<id>(False)`` is skipped.

    Parameters
    ----------
    rule : object exposing ``logic_form(features, targets)``
        The rule to decode.
    features, targets :
        Passed through unchanged to ``rule.logic_form``.
    item_name : dict
        Mapping from integer item id to item name.

    Returns
    -------
    tuple
        ``(head_name, conditions, decoded_rule, valid)`` where ``head_name`` is
        the name of the head item, ``conditions`` is the list of decoded body
        atoms, ``decoded_rule`` is ``"<head_name> :- <cond>, <cond>."`` and
        ``valid`` is True when at least one body atom was kept.

    Raises
    ------
    KeyError
        If an item id appearing in the rule is missing from ``item_name``.
    """
    rule_str = rule.logic_form(features, targets)
    head_part, _, body_part = rule_str.partition(":-")

    # "Next(12) " -> 12
    head_id = int(head_part.split("(")[1].replace(")", ""))
    head_name = item_name[head_id]

    conditions = []
    for token in body_part.split(","):
        token = token.strip()
        if "(" not in token:
            # An empty body renders as "." (e.g. "Next(9) :- ."): nothing to decode.
            continue

        feature = token.split("(")[0].strip()
        value = token.split("(")[1].split(")")[0]

        if feature == "last_item":
            if value == "?":
                # Masked last item carries no information for the reader.
                continue
            conditions.append("last(" + item_name[int(value)] + ")")
        else:
            # Feature columns are named "focus_<item id>".
            focused_item = item_name[int(feature.split("_")[1])]
            if value == "False":
                # Only the presence of an item is reported, never its absence.
                continue
            conditions.append(focused_item)

    # Build the text from the kept conditions so the separator handling is
    # correct even when no condition survives.
    decoded_rule = head_name + " :- " + ", ".join(conditions) + "."
    return head_name, conditions, decoded_rule, bool(conditions)

{1: 'ottvod_A_paid_in', 2: 'ottvod_A_paid_out', 3: 'ottvod_B_paid_in', 4: 'ottvod_B_paid_out', 5: 'ottvod_C_paid_in', 6: 'ottvod_C_free_in', 7: 'ottvod_C_paid_out', 8: 'ottvod_C_free_out', 9: 'cinema', 10: 'theme_park', 11: 'fast_food', 12: 'toy_mall'}


In [4]:
def _parse_list(text, cast):
    """Parse a stringified flat list such as "[1, 6, 7]" into a list of `cast` values."""
    return [cast(token) for token in text.replace("[", "").replace("]", "").split(", ")]


# Item ids seen in any model input so far. The set is shared across prediction
# files, so ids collected from earlier files also yield feature columns for
# later ones.
item_ids = set()

for (prediction_file, output_file) in IN_OUT_FILES:
    # ------------------------------------------------------------------
    # 1. Load the model predictions and parse the stringified list columns
    # ------------------------------------------------------------------
    df_preds = pd.read_csv(prediction_file)
    inputs = []
    attentions = []
    predictions = []
    hit_rate = []
    for expected, raw_input, raw_attention, raw_prediction in zip(
        df_preds["Expected"],
        df_preds["Model_input"],
        df_preds["Model_attention"],
        df_preds["Model_prediction"],
    ):
        session_items = _parse_list(raw_input, int)
        inputs.append(session_items)

        # Keep one attention weight per input item, then min-max normalise the
        # weights of the session to [0, 1] (all 1.0 when they are all equal).
        attention = _parse_list(raw_attention, float)[:len(session_items)]
        min_att = min(attention)
        max_att = max(attention)
        if min_att == max_att:
            attention = [1.0 for _ in attention]
        else:
            attention = [round((weight - min_att) / (max_att - min_att), 3) for weight in attention]
        attentions.append(attention)

        prediction = _parse_list(raw_prediction, int)
        predictions.append(prediction)

        # 1.0 when the expected item is among the TOP_K predictions, else 0.0.
        hit_rate.append(float(expected in prediction[:TOP_K]))

        item_ids.update(session_items)

    df_preds["Model_attention"] = attentions
    df_preds["Model_input"] = inputs
    df_preds["Model_prediction"] = predictions

    display(df_preds)

    # ------------------------------------------------------------------
    # 2. Build the LFIT dataset: one row per (session, top-k predicted item)
    # ------------------------------------------------------------------
    features = []
    targets = []

    # Apply attention as feature mask: low-attention items become "?".
    for session_items, attention, prediction in zip(inputs, attentions, predictions):
        for target in prediction[:TOP_K]:
            features.append([
                item if attention[position] >= MIN_ATTENTION else "?"
                for position, item in enumerate(session_items)
            ])
            targets.append(target)

    df_lfit = pd.DataFrame({"Features": features, "Targets": targets})
    display(df_lfit)

    LFIT_FEATURES = []

    # Compute appear property: focus_<id> is True when the (unmasked) item is
    # present in the session. Sorted so the column order is deterministic.
    for item in sorted(item_ids):
        appear_col = "focus_" + str(item)
        df_lfit[appear_col] = [item in masked_items for masked_items in features]
        LFIT_FEATURES.append(appear_col)

    # Last item property: the final (possibly masked) item of the session.
    df_lfit["last_item"] = [masked_items[-1] for masked_items in features]
    LFIT_FEATURES.append("last_item")

    df_lfit[TARGET] = df_lfit["Targets"]

    df_lfit = df_lfit.drop(columns=["Features", "Targets"])

    df_lfit.to_csv(LFIT_INPUT_FILE, index=False)

    print(LFIT_FEATURES)
    display(df_lfit)

    # ------------------------------------------------------------------
    # 3. LFIT learning
    # ------------------------------------------------------------------
    dataset = pylfit.preprocessing.discrete_state_transitions_dataset_from_csv(
        path=LFIT_INPUT_FILE,
        feature_names=LFIT_FEATURES,
        target_names=[TARGET],
    )
    model = pylfit.models.DMVLP(features=dataset.features, targets=dataset.targets)
    model.compile(algorithm=ALGORITHM)  # alternative: algorithm="pride"
    model.fit(dataset=dataset, heuristics=HEURISTICS, verbose=1, threads=THREADS)

    # ------------------------------------------------------------------
    # 4. Decode and select rules
    # ------------------------------------------------------------------
    rules_data = []

    encoded_data = pylfit.algorithms.Algorithm.encode_transitions_set(dataset.data, dataset.features, dataset.targets)
    for var in [TARGET]:
        selected_rules = set()
        var_id = [name for name, _values in dataset.targets].index(var)
        for val_id, val in enumerate(dataset.targets[var_id][1]):
            positives, negatives = pylfit.algorithms.pride.PRIDE.interprete(encoded_data, var_id, val_id)
            for r in model.rules:
                # Only rules concluding this exact target value are of interest.
                if r.head_variable != var_id or r.head_value != val_id:
                    continue

                # Number of positive states the rule body fully matches.
                positive_full_matches = sum(1 for state in positives if r.matches(state))

                head, conditions, decoded_rule, valid = decode(r, dataset.features, dataset.targets, item_name)

                # Keep rules that have at least one readable condition and
                # support at least one positive example. Rules decoding to the
                # same text are intentionally all kept.
                if valid and positive_full_matches > 0:
                    rules_data.append([
                        head,
                        conditions,
                        len(positives),
                        positive_full_matches,
                        decoded_rule,
                        r.size(),
                        r.logic_form(dataset.features, dataset.targets),
                        r.to_string(),
                    ])
                    selected_rules.add(decoded_rule)

    df_rules = pd.DataFrame(
        rules_data,
        columns=[
            "head_value",
            "conditions",
            "total_positives",
            "rule_positive_supports",
            "rule_decoded",
            "rule_size",
            "rule_logic_form",
            "rule_raw_form",
        ],
    )

    df_rules = df_rules.sort_values(["head_value", "rule_positive_supports"], ascending=False)
    display(df_rules)

    df_rules.to_csv(output_file, index=False)

Unnamed: 0,Expected,Model_input,Model_attention,Model_prediction
0,10,"[1, 6, 7]","[1.0, 0.787, 0.0]","[2, 10, 4, 9, 12, 1, 11, 3, 8, 5, 6, 7]"
1,7,"[2, 3, 6]","[0.855, 1.0, 0.0]","[7, 1, 9, 8, 10, 4, 12, 11, 6, 2, 5, 3]"
2,7,"[2, 3, 6]","[0.855, 1.0, 0.0]","[7, 1, 9, 8, 10, 4, 12, 11, 6, 2, 5, 3]"
3,7,"[2, 3, 6]","[0.855, 1.0, 0.0]","[7, 1, 9, 8, 10, 4, 12, 11, 6, 2, 5, 3]"
4,7,"[2, 3, 6]","[0.855, 1.0, 0.0]","[7, 1, 9, 8, 10, 4, 12, 11, 6, 2, 5, 3]"
...,...,...,...,...
3371,6,"[2, 3, 7]","[1.0, 0.46, 0.0]","[6, 1, 10, 4, 7, 9, 5, 2, 12, 11, 3, 8]"
3372,6,"[2, 3, 7]","[1.0, 0.46, 0.0]","[6, 1, 10, 4, 7, 9, 5, 2, 12, 11, 3, 8]"
3373,4,"[7, 6, 5]","[1.0, 0.0, 0.027]","[4, 10, 1, 12, 9, 2, 3, 11, 8, 7, 5, 6]"
3374,9,"[6, 9, 9, 9]","[1.0, 0.0, 0.0, 0.0]","[9, 8, 12, 11, 4, 10, 7, 2, 5, 1, 6, 3]"


Unnamed: 0,Features,Targets
0,"[1, 6, ?]",2
1,"[1, 6, ?]",10
2,"[1, 6, ?]",4
3,"[2, 3, ?]",7
4,"[2, 3, ?]",1
...,...,...
10123,"[6, ?, ?, ?]",8
10124,"[6, ?, ?, ?]",12
10125,"[6, ?, ?]",9
10126,"[6, ?, ?]",8


['focus_1', 'focus_2', 'focus_3', 'focus_4', 'focus_5', 'focus_6', 'focus_7', 'focus_9', 'focus_10', 'focus_11', 'focus_12', 'last_item']


Unnamed: 0,focus_1,focus_2,focus_3,focus_4,focus_5,focus_6,focus_7,focus_9,focus_10,focus_11,focus_12,last_item,Next
0,True,False,False,False,False,True,False,False,False,False,False,?,2
1,True,False,False,False,False,True,False,False,False,False,False,?,10
2,True,False,False,False,False,True,False,False,False,False,False,?,4
3,False,True,True,False,False,False,False,False,False,False,False,?,7
4,False,True,True,False,False,False,False,False,False,False,False,?,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...
10123,False,False,False,False,False,True,False,False,False,False,False,?,8
10124,False,False,False,False,False,True,False,False,False,False,False,?,12
10125,False,False,False,False,False,True,False,False,False,False,False,?,9
10126,False,False,False,False,False,True,False,False,False,False,False,?,8


Starting fit with GULA

Converting transitions to nparray...
Sorting transitions...
Grouping transitions by initial state...
Start learning over 8 threads


Unnamed: 0,head_value,conditions,total_positives,rule_positive_supports,rule_decoded,rule_size,rule_logic_form,rule_raw_form
211,toy_mall,[fast_food],13,3,toy_mall :- fast_food.,1,Next(12) :- focus_11(True).,0=3 :- 9=1.
212,toy_mall,[toy_mall],13,3,toy_mall :- toy_mall.,1,Next(12) :- focus_12(True).,0=3 :- 10=1.
213,toy_mall,[last(fast_food)],13,3,toy_mall :- last(fast_food).,1,Next(12) :- last_item(11).,0=3 :- 11=2.
214,toy_mall,[last(toy_mall)],13,3,toy_mall :- last(toy_mall).,1,Next(12) :- last_item(12).,0=3 :- 11=3.
219,toy_mall,[ottvod_C_free_in],13,3,toy_mall :- ottvod_C_free_in.,6,"Next(12) :- focus_1(False), focus_4(False), fo...","0=3 :- 0=0, 3=0, 5=1, 6=0, 8=0, 11=11."
...,...,...,...,...,...,...,...,...
823,cinema,"[ottvod_B_paid_in, ottvod_B_paid_out, last(ott...",18,1,"cinema :- ottvod_B_paid_in, ottvod_B_paid_out,...",3,"Next(9) :- focus_3(True), focus_4(True), last_...","0=11 :- 2=1, 3=1, 11=7."
824,cinema,"[ottvod_B_paid_in, last(ottvod_C_paid_in)]",18,1,"cinema :- ottvod_B_paid_in, last(ottvod_C_paid...",3,"Next(9) :- focus_3(True), focus_6(False), last...","0=11 :- 2=1, 5=0, 11=7."
825,cinema,"[ottvod_B_paid_in, ottvod_C_paid_in]",18,1,"cinema :- ottvod_B_paid_in, ottvod_C_paid_in.",4,"Next(9) :- focus_1(False), focus_3(True), focu...","0=11 :- 0=0, 2=1, 4=1, 7=0."
826,cinema,"[ottvod_B_paid_in, ottvod_B_paid_out, ottvod_C...",18,1,"cinema :- ottvod_B_paid_in, ottvod_B_paid_out,...",4,"Next(9) :- focus_3(True), focus_4(True), focus...","0=11 :- 2=1, 3=1, 4=1, 7=0."


Unnamed: 0,Expected,Model_input,Model_attention,Model_prediction
0,10,"[1, 6, 7]","[0.0, 0.776, 1.0]","[2, 4, 10, 12, 9, 3, 1, 5, 11, 8, 6, 7]"
1,7,"[2, 3, 6]","[0.0, 1.0, 0.07]","[7, 1, 9, 8, 10, 12, 4, 6, 3, 5, 2, 11]"
2,7,"[2, 3, 6]","[0.0, 1.0, 0.07]","[7, 1, 9, 8, 10, 12, 4, 6, 3, 5, 2, 11]"
3,7,"[2, 3, 6]","[0.0, 1.0, 0.07]","[7, 1, 9, 8, 10, 12, 4, 6, 3, 5, 2, 11]"
4,7,"[2, 3, 6]","[0.0, 1.0, 0.07]","[7, 1, 9, 8, 10, 12, 4, 6, 3, 5, 2, 11]"
...,...,...,...,...
3371,6,"[2, 3, 7]","[0.0, 1.0, 0.575]","[6, 1, 4, 10, 3, 7, 11, 2, 9, 5, 8, 12]"
3372,6,"[2, 3, 7]","[0.0, 1.0, 0.575]","[6, 1, 4, 10, 3, 7, 11, 2, 9, 5, 8, 12]"
3373,4,"[7, 6, 5]","[0.699, 0.0, 1.0]","[4, 1, 12, 9, 8, 2, 3, 7, 10, 11, 6, 5]"
3374,9,"[6, 9, 9]","[1.0, 0.0, 0.0]","[9, 8, 4, 1, 7, 2, 11, 5, 12, 6, 3, 10]"


Unnamed: 0,Features,Targets
0,"[?, 6, 7]",2
1,"[?, 6, 7]",4
2,"[?, 6, 7]",10
3,"[?, 3, ?]",7
4,"[?, 3, ?]",1
...,...,...
10123,"[6, ?, ?]",8
10124,"[6, ?, ?]",4
10125,"[6, ?, ?, ?]",9
10126,"[6, ?, ?, ?]",8


['focus_1', 'focus_2', 'focus_3', 'focus_4', 'focus_5', 'focus_6', 'focus_7', 'focus_9', 'focus_10', 'focus_11', 'focus_12', 'last_item']


Unnamed: 0,focus_1,focus_2,focus_3,focus_4,focus_5,focus_6,focus_7,focus_9,focus_10,focus_11,focus_12,last_item,Next
0,False,False,False,False,False,True,True,False,False,False,False,7,2
1,False,False,False,False,False,True,True,False,False,False,False,7,4
2,False,False,False,False,False,True,True,False,False,False,False,7,10
3,False,False,True,False,False,False,False,False,False,False,False,?,7
4,False,False,True,False,False,False,False,False,False,False,False,?,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...
10123,False,False,False,False,False,True,False,False,False,False,False,?,8
10124,False,False,False,False,False,True,False,False,False,False,False,?,4
10125,False,False,False,False,False,True,False,False,False,False,False,?,9
10126,False,False,False,False,False,True,False,False,False,False,False,?,8


Starting fit with GULA

Converting transitions to nparray...
Sorting transitions...
Grouping transitions by initial state...
Start learning over 8 threads


Unnamed: 0,head_value,conditions,total_positives,rule_positive_supports,rule_decoded,rule_size,rule_logic_form,rule_raw_form
133,toy_mall,[toy_mall],9,5,toy_mall :- toy_mall.,1,Next(12) :- focus_12(True).,0=3 :- 10=1.
134,toy_mall,[last(toy_mall)],9,5,toy_mall :- last(toy_mall).,1,Next(12) :- last_item(12).,0=3 :- 11=2.
135,toy_mall,"[ottvod_C_paid_out, fast_food]",9,1,"toy_mall :- ottvod_C_paid_out, fast_food.",2,"Next(12) :- focus_7(True), focus_11(True).","0=3 :- 6=1, 9=1."
136,toy_mall,"[ottvod_C_paid_out, last(fast_food)]",9,1,"toy_mall :- ottvod_C_paid_out, last(fast_food).",2,"Next(12) :- focus_7(True), last_item(11).","0=3 :- 6=1, 11=1."
137,toy_mall,"[ottvod_C_paid_in, ottvod_C_paid_out]",9,1,"toy_mall :- ottvod_C_paid_in, ottvod_C_paid_out.",3,"Next(12) :- focus_4(False), focus_5(True), foc...","0=3 :- 3=0, 4=1, 6=1."
...,...,...,...,...,...,...,...,...
554,cinema,[ottvod_A_paid_out],11,1,cinema :- ottvod_A_paid_out.,6,"Next(9) :- focus_2(True), focus_3(False), focu...","0=11 :- 1=1, 2=0, 3=0, 5=0, 6=0, 11=10."
555,cinema,"[ottvod_B_paid_in, ottvod_B_paid_out]",11,1,"cinema :- ottvod_B_paid_in, ottvod_B_paid_out.",4,"Next(9) :- focus_2(False), focus_3(True), focu...","0=11 :- 1=0, 2=1, 3=1, 11=10."
559,cinema,"[ottvod_C_free_in, cinema]",11,1,"cinema :- ottvod_C_free_in, cinema.",3,"Next(9) :- focus_2(False), focus_6(True), focu...","0=11 :- 1=0, 5=1, 7=1."
561,cinema,"[ottvod_A_paid_out, cinema]",11,1,"cinema :- ottvod_A_paid_out, cinema.",3,"Next(9) :- focus_2(True), focus_6(False), focu...","0=11 :- 1=1, 5=0, 7=1."


Unnamed: 0,Expected,Model_input,Model_attention,Model_prediction
0,10,"[1, 6, 7]","[0.604, 1.0, 0.0]","[2, 10, 4, 3, 11, 1, 9, 7, 12, 5, 6, 8]"
1,7,"[2, 3, 6]","[0.133, 1.0, 0.0]","[7, 9, 1, 4, 10, 6, 8, 12, 5, 11, 3, 2]"
2,7,"[2, 3, 6]","[0.133, 1.0, 0.0]","[7, 9, 1, 4, 10, 6, 8, 12, 5, 11, 3, 2]"
3,7,"[2, 3, 6]","[0.133, 1.0, 0.0]","[7, 9, 1, 4, 10, 6, 8, 12, 5, 11, 3, 2]"
4,7,"[2, 3, 6]","[0.133, 1.0, 0.0]","[7, 9, 1, 4, 10, 6, 8, 12, 5, 11, 3, 2]"
...,...,...,...,...
3371,6,"[2, 3, 7]","[0.184, 1.0, 0.0]","[1, 10, 6, 4, 3, 7, 5, 11, 8, 9, 2, 12]"
3372,6,"[2, 3, 7]","[0.184, 1.0, 0.0]","[1, 10, 6, 4, 3, 7, 5, 11, 8, 9, 2, 12]"
3373,4,"[7, 6, 5]","[0.0, 0.044, 1.0]","[4, 10, 1, 3, 9, 2, 11, 7, 12, 6, 8, 5]"
3374,9,"[6, 9, 9]","[0.0, 0.466, 1.0]","[9, 8, 12, 11, 5, 4, 1, 10, 2, 7, 6, 3]"


Unnamed: 0,Features,Targets
0,"[1, 6, ?]",2
1,"[1, 6, ?]",10
2,"[1, 6, ?]",4
3,"[?, 3, ?]",7
4,"[?, 3, ?]",9
...,...,...
10123,"[?, 9, 9]",8
10124,"[?, 9, 9]",12
10125,"[?, 9, 9, 9]",9
10126,"[?, 9, 9, 9]",8


['focus_1', 'focus_2', 'focus_3', 'focus_4', 'focus_5', 'focus_6', 'focus_7', 'focus_9', 'focus_10', 'focus_11', 'focus_12', 'last_item']


Unnamed: 0,focus_1,focus_2,focus_3,focus_4,focus_5,focus_6,focus_7,focus_9,focus_10,focus_11,focus_12,last_item,Next
0,True,False,False,False,False,True,False,False,False,False,False,?,2
1,True,False,False,False,False,True,False,False,False,False,False,?,10
2,True,False,False,False,False,True,False,False,False,False,False,?,4
3,False,False,True,False,False,False,False,False,False,False,False,?,7
4,False,False,True,False,False,False,False,False,False,False,False,?,9
...,...,...,...,...,...,...,...,...,...,...,...,...,...
10123,False,False,False,False,False,False,False,True,False,False,False,9,8
10124,False,False,False,False,False,False,False,True,False,False,False,9,12
10125,False,False,False,False,False,False,False,True,False,False,False,9,9
10126,False,False,False,False,False,False,False,True,False,False,False,9,8


Starting fit with GULA

Converting transitions to nparray...
Sorting transitions...
Grouping transitions by initial state...
Start learning over 8 threads


Unnamed: 0,head_value,conditions,total_positives,rule_positive_supports,rule_decoded,rule_size,rule_logic_form,rule_raw_form
97,toy_mall,[toy_mall],8,5,toy_mall :- toy_mall.,1,Next(12) :- focus_12(True).,0=3 :- 10=1.
98,toy_mall,[last(toy_mall)],8,4,toy_mall :- last(toy_mall).,1,Next(12) :- last_item(12).,0=3 :- 11=3.
99,toy_mall,[fast_food],8,2,toy_mall :- fast_food.,2,"Next(12) :- focus_3(False), focus_11(True).","0=3 :- 2=0, 9=1."
100,toy_mall,[last(fast_food)],8,2,toy_mall :- last(fast_food).,2,"Next(12) :- focus_3(False), last_item(11).","0=3 :- 2=0, 11=2."
101,toy_mall,[cinema],8,1,toy_mall :- cinema.,3,"Next(12) :- focus_3(False), focus_4(False), fo...","0=3 :- 2=0, 3=0, 7=1."
...,...,...,...,...,...,...,...,...
334,cinema,"[ottvod_B_paid_in, last(ottvod_B_paid_out)]",16,1,"cinema :- ottvod_B_paid_in, last(ottvod_B_paid...",2,"Next(9) :- focus_3(True), last_item(4).","0=11 :- 2=1, 11=6."
336,cinema,"[ottvod_A_paid_out, ottvod_B_paid_in, last(ott...",16,1,"cinema :- ottvod_A_paid_out, ottvod_B_paid_in,...",3,"Next(9) :- focus_2(True), focus_3(True), last_...","0=11 :- 1=1, 2=1, 11=8."
337,cinema,[last(ottvod_B_paid_in)],16,1,cinema :- last(ottvod_B_paid_in).,4,"Next(9) :- focus_2(False), focus_4(False), foc...","0=11 :- 1=0, 3=0, 8=0, 11=5."
341,cinema,[toy_mall],16,1,cinema :- toy_mall.,2,"Next(9) :- focus_12(True), last_item(?).","0=11 :- 10=1, 11=11."
