In [22]:
from collections import deque
from queue import PriorityQueue
from quality_measure import quality_measure
from load_data import make_growth_target_df
import pandas as pd
import numpy as np
import math

In [23]:
def gteq(a, b):
    """Condition operator: return the result of ``a >= b``."""
    at_least = a >= b
    return at_least

def leeq(a, b):
    """Condition operator: return the result of ``a <= b``."""
    at_most = a <= b
    return at_most

def eq(a, b):
    """Condition operator: return the result of ``a == b``."""
    same = a == b
    return same

def neq(a, b):
    """Condition operator: return the logical negation of ``eq(a, b)``."""
    are_equal = eq(a, b)
    return not are_equal

def extract_subgroup(descriptors, data):
    """Return the rows of ``data`` that satisfy every condition in ``descriptors``.

    Parameters
    ----------
    descriptors : iterable of (att_index, descr_value, operator)
        Each condition is evaluated as ``operator(row[att_index], descr_value)``.
        An empty iterable selects every row.
    data : iterable of indexable rows

    Returns
    -------
    list
        A new list holding the matching rows, in their original order.
    """
    def row_matches(row):
        # all() stops at the first falsy condition, like the explicit break did.
        return all(
            compare(row[column], reference)
            for column, reference, compare in descriptors
        )

    return [row for row in data if row_matches(row)]


def refin(seed, data, types, nr_bins, descr_indices):
    """Return every descriptor obtained by adding one condition to ``seed``.

    Parameters
    ----------
    seed : iterable of (att_index, value, operator)
        The descriptor to refine; attributes it already uses are not refined again.
    data : list of indexable rows
        The full data set.
    types : sequence of str
        Attribute kinds: ``'numeric'``, ``'binary'``, anything else is treated
        as nominal. When ``types`` is a list/tuple with exactly one entry per
        element of ``descr_indices`` it is matched positionally (this is how the
        notebook builds it, from the non-target columns); otherwise it is
        indexed by column index, as before.
    nr_bins : int
        Numeric attributes get up to ``nr_bins - 1`` split points, taken from
        the sorted values of the rows covered by ``seed``.
    descr_indices : list of int
        Column indices that may appear in a descriptor.

    Returns
    -------
    list of list
        One list of conditions per refinement: the seed's conditions followed
        by the new ``(att_index, value, operator)`` condition.
    """
    used_indices = {condition[0] for condition in seed}
    # Filtering instead of list.remove(): no side-effect comprehension, and no
    # ValueError when the seed mentions an index that is not in descr_indices.
    open_indices = [i for i in descr_indices if i not in used_indices]

    # Positional lookup keeps types aligned with descr_indices even when the
    # excluded (target) column is not the last one; plain types[i] would then
    # be shifted by one or run past the end of the list.
    positional = isinstance(types, (list, tuple)) and len(types) == len(descr_indices)

    def attribute_type(index):
        return types[descr_indices.index(index)] if positional else types[index]

    base = list(seed)
    refinements = []
    seed_rows = None  # rows covered by seed; computed once, on first numeric attribute

    for i in open_indices:
        kind = attribute_type(i)

        if kind == 'numeric':
            if seed_rows is None:
                seed_rows = extract_subgroup(seed, data)
            # NaN has no defined sort position, so it would make both the
            # ordering and the resulting split points unreliable.
            values = sorted(
                v for v in (float(row[i]) for row in seed_rows) if not math.isnan(v)
            )
            n = len(values)
            if n == 0:
                # Nothing to split on (empty subgroup or only NaN values).
                continue
            # dict.fromkeys drops repeated split points (common when many rows
            # share a value) while keeping them in ascending order.
            split_points = dict.fromkeys(
                values[math.floor(j * (n / nr_bins))] for j in range(1, nr_bins)
            )
            for split in split_points:
                refinements.append(base + [(i, split, leeq)])
                refinements.append(base + [(i, split, gteq)])

        elif kind == 'binary':
            observed = list(dict.fromkeys(row[i] for row in data))
            # Columns encoded as 0/1 keep the historic (i, 0) / (i, 1) pair.
            # Any other two-valued column (e.g. two strings) must be compared
            # against the values it really holds, or no row would ever match.
            if all(value in (0, 1) for value in observed):
                candidates = [0, 1]
            else:
                candidates = observed
            for value in candidates:
                refinements.append(base + [(i, value, eq)])

        else:
            # Nominal: one equality condition per distinct value, in
            # first-seen order so the output is reproducible between runs.
            for value in dict.fromkeys(row[i] for row in data):
                refinements.append(base + [(i, value, eq)])

    return refinements

            

def constraints_satisfied(descriptors, constraints):
    """Placeholder constraint check: accepts every descriptor.

    Both arguments are currently ignored and the function always returns True.
    """
    # TODO: the original inline note suggests also requiring a non-empty subgroup.
    accepted = True
    return accepted

class _QualityEntry(tuple):
    """``(quality, descriptor)`` pair that is ordered by ``quality`` alone.

    A plain tuple falls back to comparing the descriptors when two qualities
    tie, and descriptors contain operator functions, which cannot be ordered
    (TypeError). Indexing and unpacking behave exactly like a normal 2-tuple.
    """

    __slots__ = ()

    def __lt__(self, other):
        return self[0] < other[0]

    def __gt__(self, other):
        return self[0] > other[0]

    def __le__(self, other):
        return self[0] <= other[0]

    def __ge__(self, other):
        return self[0] >= other[0]


def put_item_in_queue(queue, quality, descriptor):
    """Insert ``(quality, descriptor)`` into a bounded min-priority queue.

    While the queue has room the item is simply added. Once it is full, the
    lowest-quality entry is removed and only the better of the two is put
    back; on a tie the entry that was already in the queue is kept.

    Parameters
    ----------
    queue : queue.PriorityQueue
        Queue of ``(quality, descriptor)`` items.
    quality : comparable
        Score of the descriptor; higher is better.
    descriptor : object
        Payload stored next to the score. It is never compared.
    """
    candidate = _QualityEntry((quality, descriptor))

    if not queue.full():
        queue.put(candidate)
        return

    weakest = queue.get()
    survivor = weakest if weakest[0] >= quality else candidate
    # Re-wrap so entries put back are also ordered by quality only.
    queue.put(_QualityEntry(survivor))

def categorize_columns_in_order(df, att_columns):
    """Classify each attribute column as 'numeric', 'binary' or 'nominal'.

    Parameters
    ----------
    df : pandas.DataFrame
    att_columns : iterable of column labels
        Columns to classify.

    Returns
    -------
    list of str
        One label per entry of ``att_columns``, in the same order.
    """
    def kind_of(series):
        # The numeric check comes first, so a numeric column is never 'binary'.
        if pd.api.types.is_numeric_dtype(series):
            return 'numeric'
        # Exactly two distinct values -> binary; everything else is nominal.
        return 'binary' if series.nunique() == 2 else 'nominal'

    return [kind_of(df[col]) for col in att_columns]


In [24]:
def beam_search(data, targets_baseline, column_names, quality, beam_width, beam_depth, nr_bins, nr_saved, subgroup_size, constraints, target, types):
    """Level-wise beam search over descriptors.

    Starting from the empty descriptor, each of the ``beam_depth`` rounds
    refines every seed with ``refin``, scores the resulting subgroups with
    ``quality(targets_subgroup, targets_baseline)`` and offers each scored
    descriptor to both the overall ``results`` queue (capacity ``nr_saved``)
    and the round's beam (capacity ``beam_width``). The beam's content seeds
    the next round.

    Only descriptors that pass ``constraints_satisfied`` and cover at least
    ``subgroup_size`` rows are scored.

    Returns
    -------
    queue.PriorityQueue
        Queue of ``(quality, descriptor)`` items; ``get()`` yields the lowest
        quality first.
    """
    target_ind = column_names.index(target)
    # Every column except the target may be used in a descriptor.
    att_indices = [i for i in range(len(column_names)) if i != target_ind]

    results = PriorityQueue(nr_saved)  # overall best descriptors
    beam_queue = deque([()])           # seeds for the current round; starts with the empty descriptor

    for _ in range(beam_depth):
        beam = PriorityQueue(beam_width)  # candidates that will be refined next round

        while beam_queue:
            seed = beam_queue.popleft()

            for descriptor in refin(seed, data, types, nr_bins, att_indices):
                subgroup = extract_subgroup(descriptor, data)
                if not (constraints_satisfied(descriptor, constraints) and len(subgroup) >= subgroup_size):
                    continue

                targets_subgroup = [row[target_ind] for row in subgroup]
                score = quality(targets_subgroup, targets_baseline)
                candidate = tuple(descriptor)
                put_item_in_queue(results, score, candidate)
                put_item_in_queue(beam, score, candidate)

        # Move the beam's descriptors (item[1]) into the seed queue for the
        # next round; get() hands them out lowest quality first.
        while not beam.empty():
            beam_queue.append(beam.get()[1])

    return results



In [25]:
# Build stock_df from the stock data file; later cells use its 'growth_target'
# column as the search target.
# NOTE(review): presumably make_growth_target_df unpickles this .pkl file, so
# only point it at trusted files (TODO confirm against load_data).
stock_df = make_growth_target_df('stock_data_for_emm.pkl')

In [26]:
# Drop the leftover 'index' column. Rebinding the name (instead of
# inplace=True) together with errors='ignore' keeps this cell idempotent:
# running it a second time no longer raises KeyError.
stock_df = stock_df.drop(columns=['index'], errors='ignore')
stock_df

Unnamed: 0,country,industry,currency,exchangeTimezoneName,exchange,sector,averageVolume10days,enterpriseToEbitda,marketCap,debtToEquity,fullTimeEmployees,growth_target
0,France,Aerospace & Defense,EUR,Europe/Paris,PAR,Industrials,382.0,11.822,5.750674e+07,50.783,1102.0,"[0.0, 1.17, 2.41, 0.41, 2.65, -1.39, -7.24, -4..."
1,Germany,Software—Application,EUR,Europe/Berlin,GER,Technology,8329.0,33.294,1.241838e+09,32.492,650.0,"[0.0, -1.84, 17.01, 3.7, -8.24, 8.01, -3.6, 4...."
2,Italy,Entertainment,EUR,Europe/Rome,MIL,Communication Services,330.0,-21.050,2.849466e+07,181.181,34.0,"[0.0, -7.96, -2.61, 0.89, 13.94, -4.85, -6.12,..."
3,Italy,Packaged Foods,EUR,Europe/Rome,MIL,Consumer Defensive,3831.0,8.794,5.971590e+07,68.513,232.0,"[0.0, -3.39, 0.58, -10.17, 11.33, -2.91, 2.1, ..."
4,United Kingdom,Other Precious Metals & Mining,EUR,Europe/Berlin,FRA,Basic Materials,108.0,1.402,2.123667e+08,81.212,3474.0,"[0.0, 8.5, -15.79, 5.28, 2.24, -1.94, 7.5, -1...."
...,...,...,...,...,...,...,...,...,...,...,...,...
9774,Japan,Electrical Equipment & Parts,JPY,Europe/London,LSE,Industrials,4893.0,6.287,2.079737e+10,10.654,145696.0,"[0.0, 1.08, 10.44, -11.75, -11.15, 7.96, -9.17..."
9775,Japan,Business Equipment & Supplies,JPY,Europe/London,LSE,Industrials,1940.0,5.531,5.086319e+09,35.268,78360.0,"[0.0, 5.82, 1.71, 7.84, -10.55, 3.58, -7.94, 2..."
9776,Japan,Electronic Gaming & Multimedia,JPY,Europe/London,LSE,Communication Services,3530.0,7.976,6.297095e+09,18.311,4894.0,"[0.0, 6.5, 0.0, -10.85, -1.95, -3.62, -3.95, 9..."
9777,Japan,Electronic Components,JPY,Europe/London,LSE,Technology,350.0,7.200,1.774490e+09,1.943,1297.0,"[0.0, 0.0, 0.0, 0.0, 18.86, 0.0, 0.0, 0.0, 10...."


In [27]:
# Row-wise copy of the frame as a list of lists; descriptors address columns by position.
data = stock_df.values.tolist()
column_names = list(stock_df.columns)
# Scoring function; beam_search calls it as quality(targets_subgroup, targets_baseline).
quality = quality_measure
# Capacity of the per-round beam (descriptors carried over to the next round).
beam_width = 3
# Number of refinement rounds.
beam_depth = 5
# Numeric attributes are split at nr_bins - 1 points taken from their sorted values.
nr_bins = 3
# Capacity of the overall results queue.
nr_saved = 10
# Minimum number of rows a subgroup must cover to be scored.
subgroup_size = 10
# Forwarded to constraints_satisfied, which currently ignores it.
constraints = None
# Name of the target column; it is excluded from the descriptor attributes.
target = 'growth_target'

In [28]:
# Baseline: mean over all rows (axis 0) of the stacked target entries.
target_ind = column_names.index(target)
all_time_series = np.array([row[target_ind] for row in data])
targets_baseline = all_time_series.mean(axis=0)

In [29]:
# Attribute columns are all columns except the target; types holds one label
# per attribute column, in that same order.
att_indices = [i for i in range(len(column_names)) if i != target_ind]
att_columns = [column_names[i] for i in att_indices]
types = categorize_columns_in_order(stock_df, att_columns)

In [30]:
# Run the search with the configuration defined above.
results = beam_search(
    data=data,
    targets_baseline=targets_baseline,
    column_names=column_names,
    quality=quality,
    beam_width=beam_width,
    beam_depth=beam_depth,
    nr_bins=nr_bins,
    nr_saved=nr_saved,
    subgroup_size=subgroup_size,
    constraints=constraints,
    target=target,
    types=types,
)


In [31]:
# Drain the results queue; get() returns the lowest quality first, so the best
# descriptor is printed last. Note that this empties `results`.
while not results.empty():
    priority, found_descriptor = results.get()
    print(f"Priority: {priority}, Descriptor: {found_descriptor}")

Priority: 64.86614408246143, Descriptor: ((3, 'Europe/Berlin', <function eq at 0x00000174FF509CF0>), (2, 'EUR', <function eq at 0x00000174FF509CF0>))
Priority: 64.86614408246143, Descriptor: ((3, 'Europe/Berlin', <function eq at 0x00000174FF509CF0>), (6, 0.0, <function gteq at 0x000001748E4FB7F0>))
Priority: 65.62197956465624, Descriptor: ((10, 10000.0, <function gteq at 0x000001748E4FB7F0>), (6, 39.0, <function gteq at 0x000001748E4FB7F0>))
Priority: 65.81396857128121, Descriptor: ((2, 'EUR', <function eq at 0x00000174FF509CF0>), (6, 0.0, <function gteq at 0x000001748E4FB7F0>), (7, 5.111, <function gteq at 0x000001748E4FB7F0>))
Priority: 65.81396857128121, Descriptor: ((2, 'EUR', <function eq at 0x00000174FF509CF0>), (7, 5.111, <function gteq at 0x000001748E4FB7F0>))
Priority: 67.74505877067688, Descriptor: ((10, 10000.0, <function gteq at 0x000001748E4FB7F0>), (8, 6693983744.0, <function gteq at 0x000001748E4FB7F0>))
Priority: 68.37178453583074, Descriptor: ((2, 'EUR', <function eq a