# Tale Tree

In [1]:
from collections import deque
from enum import Enum
import numpy as np

In [81]:
class Tension_Level(Enum):
    """ Tension level attached to an Event. """
    # The three members are distinguished by the integer values 0, 1 and 2.
    LOW = 0
    MID = 1
    HIGH = 2

class Event:
    """ Named narrative unit tagged with a tension level. """

    def __init__(self, name: str, tension_level: Tension_Level):
        # Both values are stored as given; no validation is performed.
        self.name, self.tension_level = name, tension_level

    def __str__(self):
        # Rendered as "<tension_level> | Name: <name>. " (trailing space included).
        return "{} | Name: {}. ".format(self.tension_level, self.name)

    def __repr__(self):
        # The repr is the same text as str(), so containers of events print readably.
        return str(self)

class Timeline:
    """ Ordered sequence of Events backed by a plain list (``self.content``).

    Events are kept in insertion order. The first and last events can be
    removed with ``pop_first_event`` / ``pop_last_event``; indexing and
    ``len()`` are forwarded to the underlying list.
    """

    def __init__(self, timeline=None):
        """ Create an empty Timeline, or a shallow copy of ``timeline``.

        The content list is copied, so adding or popping events on the new
        Timeline never changes the source; the Event objects are shared.
        """
        # An empty source is falsy (via __len__) and simply yields an empty list.
        self.content = list(timeline.content) if timeline else []

    def add_event(self, event: "Event"):
        """ Append ``event`` at the end of the Timeline. """
        self.content.append(event)

    def is_empty(self) -> bool:
        """ Check if the Timeline is empty. """
        return not self.content

    def pop_last_event(self) -> "Event":
        """ Return and remove the last inserted event in the Timeline.

        Raises:
            IndexError: if the Timeline is empty.
        """
        if self.is_empty():
            raise IndexError('Timeline is empty')
        # Remove in place instead of rebuilding the list with a slice.
        return self.content.pop()

    def pop_first_event(self) -> "Event":
        """ Return and remove the first inserted event in the Timeline.

        Raises:
            IndexError: if the Timeline is empty.
        """
        if self.is_empty():
            raise IndexError('Timeline is empty')
        return self.content.pop(0)

    @classmethod
    def default(cls, prefix: str = 'a', size: int = 5):
        """ Create a Timeline of ``size`` MID-tension events named
        ``prefix0``, ``prefix1``, ... """
        timeline = cls()
        for i in range(size):
            timeline.add_event(Event(prefix + str(i), Tension_Level.MID))
        return timeline

    def __getitem__(self, item):
        """ Forward indexing (and slicing) to the underlying list. """
        return self.content[item]

    def __len__(self):
        """ Number of events currently in the Timeline. """
        return len(self.content)

    def __str__(self):
        """ One event per line under a header, or 'Empty Timeline'. """
        if self.is_empty():
            return 'Empty Timeline'
        return '### \t\t TIMELINE \t\t ### \n\n' + "\n".join(str(e) for e in self.content)

# Algorithms

class Algorithm:
    """ Base class for event selection and Timeline creation.

    Each input timeline is copied at construction time, so the algorithm can
    consume events without touching the caller's timelines. Subclasses
    provide the selection rule by overriding ``get_next_event``.
    """

    def __init__(self, timelines: list):
        # Private working copies of the input timelines.
        self.timelines = [Timeline(t) for t in timelines]

    def get_next_event(self) -> "Event":
        """ Return the next event selected by the algorithm.

        Implementations used with the base ``get_sequence`` must remove the
        returned event from one of ``self.timelines``; otherwise that loop
        never finishes.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f'{type(self).__name__} does not implement get_next_event()')

    def get_sequence(self) -> "Timeline":
        """ Get the full sequence derived by the algorithm.

        Calls ``get_next_event`` until every working timeline is empty and
        collects the returned events, in order, into a new Timeline.
        """
        main_sequence = Timeline()
        while not all(t.is_empty() for t in self.timelines):
            main_sequence.add_event(self.get_next_event())
        return main_sequence
    
class Random_Algorithm(Algorithm):
    """ Algorithm selecting randomly events for Timeline creation.

    At each step one of the non-empty timelines is drawn uniformly and its
    first event is consumed.
    """

    def __init__(self, seed, timelines):
        super().__init__(timelines)
        # Seeded generator so that a given seed reproduces the same sequence.
        self.rng = np.random.default_rng(seed)

    def get_next_event(self) -> Event:
        """ Pop the first event of a uniformly drawn non-empty timeline.

        Raises:
            ValueError: if every timeline is already empty.
        """
        candidate_timelines = [t for t in self.timelines if not t.is_empty()]
        if not candidate_timelines:
            raise ValueError('All timelines are empty')
        # Draw an index rather than passing the Timeline objects to NumPy:
        # Timeline exposes __len__/__getitem__, so NumPy would unpack each one
        # as a nested sequence while building an array from the list.
        idx = int(self.rng.integers(len(candidate_timelines)))
        return candidate_timelines[idx].pop_first_event()
    
class Balanced_Algorithm(Algorithm):
    """ Algorithm drawing timelines with probability proportional to their
    remaining length, so longer timelines are consumed more often. """

    def __init__(self, seed, timelines):
        super().__init__(timelines)
        # Seeded generator so that a given seed reproduces the same sequence.
        self.rng = np.random.default_rng(seed)

    def get_next_event(self) -> Event:
        """ Pop the first event of a timeline drawn with length-proportional
        probability (empty timelines have probability zero).

        Raises:
            ValueError: if every timeline is already empty.
        """
        timelines_length = [len(t) for t in self.timelines]
        total_length = sum(timelines_length)
        if total_length == 0:
            raise ValueError('All timelines are empty')
        probabilities = [length / total_length for length in timelines_length]
        # Draw an index rather than passing the Timeline objects to NumPy:
        # Timeline exposes __len__/__getitem__, so NumPy would unpack each one
        # as a nested sequence while building an array from the list.
        idx = int(self.rng.choice(len(self.timelines), p=probabilities))
        return self.timelines[idx].pop_first_event()
    
class Tension_Algorithm(Algorithm):
    """ Merge the timelines into one sequence while keeping the number of
    "unexpected" tension transitions low.

    A search state is a tuple holding, for every timeline, the index of the
    last event already consumed (-1 when nothing was taken yet). Moving to a
    neighbor consumes the next event of one timeline; the move costs 0 when
    the (previous, next) tension pair is listed in ``zero_distances`` and 1
    otherwise. ``get_sequence`` explores the states with a 0-1 BFS.
    """

    # Ordered (previous, next) tension pairs that cost nothing.
    zero_distances = {(Tension_Level.LOW,Tension_Level.MID),
                      (Tension_Level.MID,Tension_Level.HIGH),
                      (Tension_Level.HIGH,Tension_Level.LOW)}

    def __init__(self, timelines, time_horizon:int=2):
        super().__init__(timelines)
        # Maximum number of past selections inspected by get_ancestry.
        self.time_horizon = time_horizon

    def get_ancestry(self,id_sequence:dict,init_state:tuple) -> list:
        """ Get the previous list of selected timelines.

        Walks the recorded predecessor chain backwards from ``init_state``
        for at most ``time_horizon`` steps. The result lists timeline indices
        from the most recent selection to the oldest one.
        """
        ancestry = []
        state = init_state
        for _ in range(self.time_horizon):
            prev_state, timeline = id_sequence[state]
            if prev_state is None:
                # Reached the start state: there is no earlier selection.
                break
            ancestry.append(timeline)
            state = prev_state
        return ancestry

    def get_neighborhood(self,last_event:Event,state:tuple,id_sequences:dict) -> list:
        """ Retrieve the set of possible neighbors from the given state together with their distances.

        Returns a list of ``(new_state, new_event, timeline_index, distance)``
        tuples, one per timeline that still has an event left and is not
        excluded by the repetition rule below.
        """
        # Repetition rule: with the default time_horizon of 2, a timeline that
        # supplied both of the two most recent events may not supply the next one.
        # NOTE(review): for time_horizon > 2 the comparison below looks at the two
        # OLDEST entries of the window (ancestry is most-recent-first) - confirm intent.
        valid_timelines = set(range(len(self.timelines)))
        past_ids = self.get_ancestry(id_sequences,state)
        if len(past_ids) >= 2 and past_ids[-1] == past_ids[-2]:
            valid_timelines.remove(past_ids[-1])
        # Compute neighborhood
        neighborhood = []
        for i, timeline in enumerate(self.timelines):
            if state[i] < len(timeline)-1 and i in valid_timelines:
                # Advance timeline i by one event, leave the others untouched.
                new_state = state[:i] + (state[i]+1,) + state[i+1:]
                new_event = timeline[new_state[i]]
                distance = self.get_tension_distance((last_event.tension_level,new_event.tension_level))
                neighborhood.append((new_state,new_event,i,distance))
        return neighborhood

    def get_tension_distance(self,pair:tuple) -> int:
        """ Report the distance between two levels of tension.

        0 when ``pair`` is one of ``zero_distances`` or contains ``None``
        (the artificial root event carries no tension level), 1 otherwise.
        """
        if pair in Tension_Algorithm.zero_distances or None in pair:
            return 0
        return 1

    def get_sequence(self) -> Timeline:
        """ Get the full sequence derived by the algorithm.

        Raises:
            ValueError: if the search never reaches the state in which every
                timeline is fully consumed (for instance a single timeline
                with three or more events under the default repetition rule).
        """
        # Initiate sequences
        start = tuple([-1 for _ in range(len(self.timelines))])
        end = tuple([len(t)-1 for t in self.timelines])
        # state -> (predecessor state, event consumed to reach the state)
        sequences = {start:(None,Event('root',None))}
        # state -> (predecessor state, index of the timeline that was advanced)
        id_sequences = {start:(None,-1)}
        distances = {start:0}
        candidates = deque([start])
        # Run 0-1 BFS
        while candidates:
            state = candidates.popleft()
            last_event = sequences[state][-1]
            neighborhood = self.get_neighborhood(last_event,state,id_sequences)
            for n_state,n_event,idx,w in neighborhood:
                if n_state not in distances or distances[n_state] > distances[state]+w:
                    distances[n_state] = distances[state]+w
                    sequences[n_state] = (state,n_event)
                    id_sequences[n_state] = (state,idx)
                    # Zero-cost moves go to the front, unit-cost moves to the back.
                    if w == 0:
                        candidates.appendleft(n_state)
                    else:
                        candidates.append(n_state)
        if end not in sequences:
            # Without this check the reconstruction below fails with a bare KeyError.
            raise ValueError('No sequence consuming every timeline was found '
                             'under the repetition constraint.')
        # Create Final Timeline
        timeline = Timeline()
        s = end
        all_events = []
        while s != start:
            s,e = sequences[s]
            all_events.append(e)
        # Events were collected from the end backwards; restore chronological order.
        for e in reversed(all_events):
            timeline.add_event(e)
        return timeline

In [82]:
# Scratch objects: one standalone event, three default timelines of different
# sizes, and a balanced algorithm over them (seeded for reproducibility).
e = Event(name='test', tension_level=Tension_Level.HIGH)
t_1 = Timeline.default()
t_2 = Timeline.default(prefix='b', size=10)
t_3 = Timeline.default(prefix='c', size=1)
algo = Balanced_Algorithm(timelines=[t_1, t_2, t_3], seed=0)

In [83]:
# Three hand-written timelines with mixed tension levels.
t_1 = Timeline()
for label, level in (('A', Tension_Level.LOW),
                     ('B', Tension_Level.HIGH),
                     ('C', Tension_Level.LOW),
                     ('D', Tension_Level.MID),
                     ('E', Tension_Level.HIGH)):
    t_1.add_event(Event(label, level))
#
t_2 = Timeline()
for label, level in (('O', Tension_Level.MID),
                     ('OO', Tension_Level.LOW),
                     ('OOO', Tension_Level.MID),
                     ('OOOO', Tension_Level.HIGH),
                     ('OOOOO', Tension_Level.LOW)):
    t_2.add_event(Event(label, level))
#
t_3 = Timeline()
for label, level in (('1', Tension_Level.LOW),
                     ('2', Tension_Level.MID),
                     ('3', Tension_Level.HIGH),
                     ('4', Tension_Level.MID),
                     ('5', Tension_Level.HIGH)):
    t_3.add_event(Event(label, level))
#
# Merge them with the tension-aware algorithm and show the result.
algo = Tension_Algorithm(timelines=[t_1, t_2, t_3])
merged = algo.get_sequence()
print(merged)

### 		 TIMELINE 		 ### 

Tension_Level.LOW | Name: 1. 
Tension_Level.MID | Name: 2. 
Tension_Level.LOW | Name: A. 
Tension_Level.MID | Name: O. 
Tension_Level.HIGH | Name: 3. 
Tension_Level.LOW | Name: OO. 
Tension_Level.MID | Name: 4. 
Tension_Level.HIGH | Name: B. 
Tension_Level.LOW | Name: C. 
Tension_Level.MID | Name: OOO. 
Tension_Level.HIGH | Name: 5. 
Tension_Level.MID | Name: D. 
Tension_Level.HIGH | Name: OOOO. 
Tension_Level.LOW | Name: OOOOO. 
Tension_Level.HIGH | Name: E. 


In [58]:
# IPython magic: time repeated calls of get_sequence() on whichever `algo` is currently bound.
%timeit algo.get_sequence()

554 µs ± 3.93 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [11]:
# Build the merged timeline with whichever `algo` is currently bound.
t = algo.get_sequence()

In [28]:
# NOTE(review): the output recorded below shows an `Event_Type` field that the
# current Event.__str__ does not print, so it appears to come from an earlier
# revision of the classes - re-run this cell to refresh it.
print(t)

### 		 TIMELINE 		 ### 

Event_Type.KEY | Tension_Level.MID | Name: b0. 
Event_Type.QUERY | Tension_Level.MID | Name: b0. 
Event_Type.KEY | Tension_Level.MID | Name: a0. 
Event_Type.QUERY | Tension_Level.MID | Name: a0. 
Event_Type.KEY | Tension_Level.MID | Name: a1. 
Event_Type.QUERY | Tension_Level.MID | Name: a1. 
Event_Type.KEY | Tension_Level.MID | Name: a2. 
Event_Type.QUERY | Tension_Level.MID | Name: a2. 
Event_Type.KEY | Tension_Level.MID | Name: b1. 
Event_Type.QUERY | Tension_Level.MID | Name: b1. 
Event_Type.KEY | Tension_Level.MID | Name: c0. 
Event_Type.QUERY | Tension_Level.MID | Name: c0. 
Event_Type.KEY | Tension_Level.MID | Name: b2. 
Event_Type.QUERY | Tension_Level.MID | Name: b2. 
Event_Type.KEY | Tension_Level.MID | Name: b3. 
Event_Type.QUERY | Tension_Level.MID | Name: b3. 
Event_Type.KEY | Tension_Level.MID | Name: b4. 
Event_Type.QUERY | Tension_Level.MID | Name: b4. 
Event_Type.KEY | Tension_Level.MID | Name: b5. 
Event_Type.QUERY | Tension_Level.MID | Name: 