In [1]:
import sys
sys.path.append("../")

In [2]:
import pandas as pd
from abduction_memorability.event import Event, Label
from abduction_memorability.abduction_module import SurpriseAbductionModule
from abduction_memorability.predicate import Predicate, MonthPredicate, HasLabelPredicate, AxisRankPredicate, RandomChoicePredicate, DevicePredicate
from abduction_memorability.memory import Memory
import random
import math
from typing import Tuple

import plotly.io as pio
import plotly.express as px
import plotly.offline as py
import datetime as dt

import matplotlib.pyplot as plt

In [3]:
# defining labels
root_label = Label()
light_label = Label(name="light", parent=root_label)
blinds_label = Label(name="blinds", parent=root_label)
device_label = Label(name="device", parent=root_label)
device_addition_label = Label(name="addition", parent=device_label)
device_deletion_label = Label(name="deletion", parent=device_label)
tv_label = Label(name="tv_on", parent=root_label)


# some sensors
light_sensor = "light_sensor"
smart_tv = "smart_tv"
old_tv = "old_tv"
window = "window"
CURRENT_DAY = 100
EPOCH_SHIFT = 1577836800 # Epoch time of Jan, 1st 2021

# generating some events:
all_events: list[Event] = []
for day in range(0, CURRENT_DAY):
    for hour in range(0, 24):
        if day == CURRENT_DAY - 1 and hour == 13:
            event = Event(
                timestamp=day*86400 +  hour * 3600 + EPOCH_SHIFT, 
                duration=3600, 
                label=light_label,
                characteristics={
                    "luminosity": 100,
                    "device": light_sensor
                }
            )
            all_events.append(event)
            event = Event(
                timestamp=day*86400 + (hour-1) * 3600 + EPOCH_SHIFT,
                duration=3600,
                label=tv_label,
                characteristics={
                    "device": smart_tv
                }
            )
            all_events.append(event)
        else:
        # hour = 6 * hour
            if hour < 6 or hour > 18:
                event = Event(
                    timestamp=day*86400 +  hour * 3600 + EPOCH_SHIFT, 
                    duration=3600, 
                        label=light_label,
                        characteristics={
                            "luminosity": 100,
                            "device": light_sensor
                        }
                    )
                all_events.append(event)
            elif random.random() < 0.05:
                event = Event(
                    timestamp=day*86400 + hour * 3600 + EPOCH_SHIFT,
                    duration=3600,
                    label=light_label,
                    characteristics={
                        "luminosity": 200,
                        "device": light_sensor
                    }
                )
                all_events.append(event)
            else:
                if random.random() < 0.1 and day < CURRENT_DAY - 1:
                    event = Event(
                        timestamp=day*86400 + hour * 3600 + EPOCH_SHIFT,
                        duration=3600,
                        label=tv_label,
                        characteristics={
                            "device": old_tv
                        }
                    )
                    all_events.append(event)
                event = Event(
                    timestamp=day*86400 + hour * 3600 + EPOCH_SHIFT,
                    duration=3600,
                    label=light_label,
                    characteristics={
                        "luminosity": 1000,
                        "device": light_sensor
                    }
                )
                all_events.append(event)

len(all_events)

2524

In [4]:

class ByDayPredicate(Predicate):
    def __init__(self, mem, prog, aux_predicate=None):
        super().__init__(mem, prog)
    
    def __call__(self, event):
        if self._prog > 0:
            return None
        mod_day = event.timestamp % 86400
        hour = mod_day / 3600
        if hour > 6 and hour <= 18:
            return True
        else:
            return False

    def __str__(self):
        return "day()"

class ByNightPredicate(Predicate):
    def __init__(self, mem, prog, aux_predicate=None):
        super().__init__(mem, prog)

    def __call__(self, event):
        if self._prog > 0:
            return None
        mod_day = event.timestamp % 86400
        hour = mod_day / 3600
        if hour <= 6 or hour > 18:
            return True
        else:
            return False

    def __str__(self):
        return "night()"

class Dark(Predicate):
    def __init__(self, mem, prog, aux_predicate=None):
        super().__init__(mem, prog, aux_predicate=aux_predicate)

    def __call__(self, event):
        if self._prog > 0:
            return None
        if event.get_char("luminosity") is None:
            return False
        return event.get_char("luminosity") < 300

    def __str__(self):
        return "dark()"

class Recent(Predicate):
    def __init__(self, mem, prog, aux_predicate=None):
        super().__init__(mem, prog, aux_predicate=aux_predicate)

    def __call__(self, event):
        if self.aux_predicate is None:
            if self._mem is None:
                last_date = CURRENT_DAY
            else:
                last_date = math.floor(self._mem.get_last_time() / 86400)
        else:
            last_date = math.floor(self.aux_predicate.timestamp / 86400)
            # print(last_date)
        if self._prog > 100:
            return None
        event_day = math.floor(event.timestamp / 86400)
        if last_date - event_day == self._prog:
            return True
        else:
            return False

    def __str__(self):
        return f"recent({self._prog})"

class Hour(Predicate):
    def __init__(self, mem, prog, aux_predicate=None):
        super().__init__(mem, prog, aux_predicate=aux_predicate)

    def __call__(self, event):
        if self._prog > 24:
            return None
        hour = event.timestamp % 86400
        hour = hour // 3600
        if self.aux_predicate is None:
            if hour == self._prog:
                return True
            return False
        else:
            other_event_hour = (self.aux_predicate.timestamp % 86400) // 3600
            diff_hour = min(abs(other_event_hour - hour), abs(other_event_hour - (24 + hour)))
            return diff_hour == self._prog

    def program_length(self):
        if self.aux_predicate is None:
            return math.log2(24) + 1
        else:
            return Helpers.bit_length(self._prog)                
    
    def __str__(self):
        return f'hour({self._prog})'
        

In [5]:
# Checking that, for all events in the list, they are either in day or in night
pred = ByDayPredicate(None, 0)
pred1 = ByNightPredicate(None, 0)
print(all([pred1(event) ^ pred(event) for event in all_events]))
pred2 = Recent(None, 1)
pred3 = Dark(None, 0)
mem = Memory(all_events)

True


In [6]:
# testing some filters
from abduction_memorability.predicate_filter import OptimizedFilter
filt1 = OptimizedFilter(pred2)
test = filt1(mem)
# test.print_all_events()
rand1 = RandomChoicePredicate(test, 0)
filt2 = OptimizedFilter(rand1)
test2 = filt2(test)
filt3 = OptimizedFilter(Hour(test, 2))
test3 = filt3(test)
test3.print_all_events()

## Creating the module
You can try out different predicates in the initialization, to see the effect on the results. (see the other notebook on this issue)

In [7]:

module = SurpriseAbductionModule(mem,
    predicates=[
        Recent,
        ByDayPredicate,
        ByNightPredicate,
        Dark,
        Hour,
        RandomChoicePredicate,
        DevicePredicate,
        HasLabelPredicate
    ],
    max_depth=3)

Loaded the memory with 2524 items!
Computing complexities with 3 passes
Starting pass 0 with 1 memories to explore
Finished pass 0 in 1.1964750289916992s.
Improved complexity for 2525 event(s)
Starting pass 1 with 131 memories to explore
Finished pass 1 in 6.015798330307007s.
Improved complexity for 515 event(s)
Starting pass 2 with 1286 memories to explore
Finished pass 2 in 16.206896543502808s.
Improved complexity for 31 event(s)
Computing memorability scores for all events !


In [8]:
df_dict = {
    "id": [],
    "time": [],
    "label": [],
    "complexity": [],
    "memorability": [],
    "recipe":[]
}
for event_id in range(len(mem)):
    event = module.get_event_by_id(event_id)
    df_dict["id"].append(event_id)
    df_dict["time"].append(event.timestamp)
    df_dict["label"].append(str(event.label))
    df_dict["recipe"].append(module.get_event_recipe(event_id))
    df_dict["complexity"].append(module.get_event_complexity(event_id))
    df_dict["memorability"].append(module.get_memorability(event_id))
df = pd.DataFrame.from_dict(df_dict)
df["date"] = pd.to_datetime(df["time"], unit="s")

In [9]:

# ! Change the IDs depending on the random seed ! 
# You can identify them by exploring the set of events
LIGHT_DIMMING_ID = 2513
SMART_TV_ID = 2503

In [10]:
fig = px.scatter(df, x="date", y="memorability", color="label", hover_data=["id", "recipe"], height=800, color_discrete_sequence=["cadetblue", "lightsalmon"])
fig.update_traces(marker={'size':16})
fig.update_layout(
    showlegend=False,
    font=dict(
        size=30
    )
)
fig.add_annotation(
    text="Light Dimming",
    x=df["date"][LIGHT_DIMMING_ID],
    y=df["memorability"][LIGHT_DIMMING_ID],
    arrowwidth=3,
    font=dict(
        size=25,
    )
)
fig.add_shape(
    type="circle",
    xref="x",
    yref="y",
    x0 = df["date"][LIGHT_DIMMING_ID] - dt.timedelta(hours=10),
    y0 = df["memorability"][LIGHT_DIMMING_ID] - 0.2,
    x1 = df["date"][LIGHT_DIMMING_ID] + dt.timedelta(hours=10),
    y1 = df["memorability"][LIGHT_DIMMING_ID] + 0.2,
)
fig.add_shape(
    type="circle",
    xref="x",
    yref="y",
    x0 = df["date"][SMART_TV_ID] - dt.timedelta(hours=10),
    y0 = df["memorability"][SMART_TV_ID] - 0.2,
    x1 = df["date"][SMART_TV_ID] + dt.timedelta(hours=10),
    y1 = df["memorability"][SMART_TV_ID] + 0.2,
)
fig.add_annotation(
    text="Smart TV Usage",
    x=df["date"][SMART_TV_ID],
    y=df["memorability"][SMART_TV_ID],
    arrowwidth=3,
    arrowsize=3,
    font=dict(
        size=25
    )
)
fig.update_traces(marker={'size':10})

In [11]:
fig2 = px.scatter(df, x="date", y="memorability", color="label", hover_data=["id", "recipe"], height=800, color_discrete_sequence=["blue", "gray"])
# for event_id in [20, 120, 240]:
#     x_coord = df["date"][event_id]
#     y_coord = df["complexity"][event_id]
#     fig2.add_annotation(text=f'{event_id}', x=x_coord, y=y_coord, showarrow=True, arrowhead=3, arrowcolor="red")
fig2.update_traces(marker={'size':13})
fig2.update_layout(
    title = "Memorability of events: TV scenario",
    showlegend=False,
    font=dict(
        size=25
    )
)
fig2.update_traces(marker={'size':8})

## Abduction
Please note that due to the implementation, the abduction process requires that no infinite complexity is left in the module (i.e. all events can be retrieved). 
Errors can arise if this condition is not met

In [18]:
#! Replace EVENT_ID with the id of the event you want to explain, identified from the plot above
#! This cell can take some time to run, due to the relative memorability computation required for the abduction
EVENT_ID = 2503
candidates = module.abduction(EVENT_ID)
candidates_with_id = [(e[0].get_id(), e[1]) for e in candidates]
import pandas as pd
df = pd.DataFrame(candidates_with_id[:3])

		Score improved for 284: 21.543019978781743
		Score improved for 821: 21.543019978781743
		Score improved for 1359: 21.543019978781743
		Score improved for 1900: 21.543019978781743
		Score improved for 1320: 21.543019978781743
		Score improved for 2438: 21.543019978781743
		Score improved for 50: 21.543019978781743
		Score improved for 587: 21.543019978781743
		Score improved for 1125: 21.543019978781743
		Score improved for 1665: 21.543019978781743
		Score improved for 2203: 21.543019978781743
		Score improved for 353: 21.543019978781743
		Score improved for 888: 21.543019978781743
		Score improved for 1431: 21.543019978781743
		Score improved for 1972: 21.543019978781743
		Score improved for 2509: 21.543019978781743
		Score improved for 119: 21.543019978781743
		Score improved for 656: 21.543019978781743
		Score improved for 1193: 21.543019978781743
		Score improved for 1737: 21.543019978781743
		Score improved for 2272: 21.543019978781743
		Score improved for 423: 21.54301997878174

In [19]:
df

Unnamed: 0,0,1
0,2475,16.782836
1,2450,12.795736
2,2453,12.495663
