In [1]:
import os
import pickle
from itertools import product

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

from pycgp import probabilistic_mutation, point_mutation, single_mutation
from pycgp.gems import MatchByActiveStrategy, MatchSMStrategy, MatchPMStrategy

In [8]:
class DataIterator:
    """Load pickled CGP/GEM experiment results from a folder and expose them as DataFrames.

    One result file is read for every combination of the ``mutations``,
    ``gems`` and ``cols`` configured in ``__init__``.

    Attributes:
        mutations: list of ``(mutation function, match strategy class)`` pairs.
        gems: gem counts that were run.
        cols: column counts (``n_cols``) that were run.
        data: DataFrame with one row per result file and columns
            ``m, s, g, c, gem_data, better_after, worse_after, safe_after``.
        raw: the unpickled contents of every result file, in the same order as ``data``.
    """

    def __init__(self, folder):
        """Read every expected result file from ``folder``.

        Args:
            folder: directory containing the pickled result files.

        Raises:
            OSError: if an expected result file cannot be opened.
        """
        self.mutations = [
            # Other (mutation, strategy) combinations that can be re-enabled:
            # (probabilistic_mutation, MatchSMStrategy),
            # (point_mutation, MatchPMStrategy),
            # (single_mutation, MatchSMStrategy),
            (single_mutation, MatchByActiveStrategy),
            (probabilistic_mutation, MatchByActiveStrategy),
        ]
        self.gems = [5, 10]
        self.cols = [10, 50, 100]

        rows = []
        raw = []
        for m, s, g, c, runs in self.__iterate_folder(folder):
            raw.append(runs)
            rows.append([
                m, s, g, c,
                [run['gem_data'] for run in runs],
                np.mean([run['gem_better_after'] for run in runs]),
                np.mean([run['gem_worse_after'] for run in runs]),
                np.mean([run['gem_same_after'] for run in runs]),
            ])
        # NOTE: 'safe_after' holds the mean of 'gem_same_after'; the label is
        # kept unchanged so existing callers reading it keep working.
        # Passing columns= to the constructor also works when no rows were read.
        self.data = pd.DataFrame(
            rows,
            columns=['m', 's', 'g', 'c', 'gem_data', 'better_after', 'worse_after', 'safe_after'],
        )
        self.raw = raw

    def __iterate_folder(self, folder):
        """Yield ``(mutation_name, strategy_name, gems, n_cols, data)`` for each result file.

        The files carry a ``.csv`` extension but are read with ``pickle``.
        """
        for (mutation, strategy), gem, column in product(self.mutations, self.gems, self.cols):
            file = os.path.join(folder, f'{mutation.__name__}-{strategy.__name__}-gems{gem}-n_cols{column}.csv')
            # SECURITY: pickle.load can execute arbitrary code -- only load trusted result files.
            with open(file, 'rb') as fp:
                data = pickle.load(fp)
            # Yield after the file is closed so it is not held open while the caller works.
            yield mutation.__name__, strategy.__name__, gem, column, data

    def iterate_gem_data(self, mutation, strategy, axis=False):
        """Yield a per-gem DataFrame for each configuration of one mutation/strategy pair.

        Args:
            mutation: mutation function name, as stored in ``self.data.m``.
            strategy: strategy class name, as stored in ``self.data.s``.
            axis: if True, create a 2-row subplot grid and yield one Axes per configuration.

        Yields:
            ``(pgdata, ax)``. ``pgdata`` has one row per gem, with the configuration's
            gem/column counts and run-averaged after-stats plus ``success_rate_counts``
            (``n_uses / match_count``) and ``success_rate_checks`` (``n_uses / match_checks``).
            ``ax`` is a matplotlib Axes when ``axis`` is True, otherwise None.
        """
        gdatas = self.data[(self.data.m == mutation) & (self.data.s == strategy) & (self.data.g != 0)]

        if axis:
            # Two rows and enough columns for every configuration. squeeze=False
            # keeps axs 2-D so the row/column indexing below always works.
            ncols = max(1, (len(gdatas) + 1) // 2)
            _, axs = plt.subplots(2, ncols, figsize=(8, 6), squeeze=False)

        for i, (_, gdata) in enumerate(gdatas.iterrows()):
            # gem_data holds one gem list per run; flatten it into a single sequence of gems.
            gems = [gem for run_gems in gdata.gem_data for gem in run_gems]
            pgdata = pd.DataFrame(
                [
                    [gdata.g, gdata.c, gem.match_checks, gem.match_count, gem.n_uses,
                     gem.value, gem.match_probability,
                     gdata.better_after, gdata.worse_after, gdata.safe_after]
                    for gem in gems
                ],
                columns=['gems', 'columns', 'match_checks', 'match_count', 'n_uses', 'value',
                         'match_probability', 'better_after', 'worse_after', 'same_after'],
            )
            pgdata['success_rate_counts'] = pgdata.n_uses / pgdata.match_count
            pgdata['success_rate_checks'] = pgdata.n_uses / pgdata.match_checks

            yield pgdata, (axs[i // ncols][i % ncols] if axis else None)

    def stats(self):
        """Concatenate the per-gem frames of every configured (mutation, strategy) pair.

        Each frame is tagged with ``m`` (mutation name) and ``s`` (strategy name).

        Returns:
            The concatenated DataFrame, or an empty DataFrame when there is no gem data.
        """
        frames = []
        for mutation, strategy in self.mutations:
            for pgdata, _ in self.iterate_gem_data(mutation.__name__, strategy.__name__):
                pgdata['m'] = mutation.__name__
                pgdata['s'] = strategy.__name__
                frames.append(pgdata)
        if not frames:
            # pd.concat raises on an empty list.
            return pd.DataFrame()
        return pd.concat(frames)

import pdb
            

# Results for the symbolic-regression benchmark. The other benchmark result
# folders (scripts/bin_class_out/, scripts/santa_fe_out/) load the same way.
symreg = DataIterator('scripts/symbolic_basic/')


In [9]:
# Unpickled result files from DataIterator, one entry per
# (mutation, strategy, gems, columns) combination; each entry is the
# list of result records loaded from that file.
data = symreg.raw

In [13]:
# Each raw entry is a list of result records (indexing it with a string key
# fails), so inspect the keys of the first record of the first file.
data[0][0].keys()