In [5]:
import matplotlib.pyplot as plt
import collections
import csv
import sys
import os

# Directory paths for the cross-validation accuracy and loss data.
# NOTE(review): presumably these hold the CSV files consumed by the plotting
# functions; they are not referenced in the visible code -- TODO confirm.
DATA_ACCURACY_DIR_PATH = "./cross_validation/accuracy"
DATA_LOSS_DIR_PATH = "./cross_validation/loss"

def get_files_in_dir(dir_path):
    """Return the paths of the regular files located directly in dir_path.

    Every returned path is dir_path joined with the entry name, in the order
    in which os.listdir yields the entries. Entries that are not files (for
    example subdirectories) are left out, and nothing is searched recursively.
    """
    file_paths = []
    for entry_name in os.listdir(dir_path):
        entry_path = os.path.join(dir_path, entry_name)
        if os.path.isfile(entry_path):
            file_paths.append(entry_path)
    return file_paths
    

def load_tensorflow_csv(csv_file_path):
    """Load a three-column scalar CSV file into a step -> value mapping.

    The first row is treated as a header and skipped. Every following
    non-empty row must have exactly three columns (time, step, value); the
    time column is ignored. Judging by the function name the input is a
    TensorFlow/TensorBoard scalar export, but only the column layout matters.

    Args:
        csv_file_path: Path of the CSV file to read.

    Returns:
        A collections.defaultdict(int) mapping each int step to its float
        value. When a step occurs in several rows, the last row wins.

    Raises:
        ValueError: If a non-empty data row does not have exactly three
            columns, or its step/value cannot be parsed as int/float.
    """
    # The csv module wants a binary file on Python 2 but a text file opened
    # with newline="" on Python 3. Opening with "rb" on Python 3 makes
    # csv.reader fail with "iterator should return strings, not bytes".
    if sys.version_info[0] >= 3:
        csv_file = open(csv_file_path, "r", newline="")
    else:
        csv_file = open(csv_file_path, "rb")

    step_value_dict = collections.defaultdict(int)
    with csv_file:
        reader = csv.reader(csv_file, delimiter=",", quotechar="|")
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; skip them instead of
                # crashing on the three-way unpack below.
                continue
            _, step, value = row
            step_value_dict[int(step)] = float(value)

    return step_value_dict


def get_max_key(dict):
    """Return the largest key found when iterating dict, but never below 0.

    The running maximum starts at 0, so an empty mapping (or one whose keys
    are all negative) yields 0.

    The parameter name shadows the builtin ``dict``; it is kept because it is
    part of the function's signature.
    """
    largest = 0
    for candidate in dict:
        # Replace the running maximum unless it is strictly greater.
        if not largest > candidate:
            largest = candidate
    return largest


def get_cross_validation_data(dir_path):
    """Average the step/value curves of every CSV file in a directory.

    Each file in dir_path is loaded with load_tensorflow_csv. The curves are
    cut off at the smallest of the per-file maximum steps, so the average
    only covers the range that every file reaches.

    Args:
        dir_path: Directory whose files are loaded and averaged.

    Returns:
        A (steps, values) tuple of two equally long lists: the steps in
        ascending order and the mean value recorded at each step.

    Raises:
        ValueError: If dir_path contains no files.
    """
    files = get_files_in_dir(dir_path)
    if not files:
        # Without this guard min() below fails with the unhelpful
        # "min() arg is an empty sequence".
        raise ValueError("no files to average in directory: %s" % dir_path)

    step_value_maps = [load_tensorflow_csv(file) for file in files]
    end_map_key = min(get_max_key(step_value_map)
                      for step_value_map in step_value_maps)

    # Track how many files actually contain each step. Dividing by the total
    # number of files would understate the mean for a step that is missing
    # from some of them.
    value_sums = collections.defaultdict(float)
    file_counts = collections.defaultdict(int)
    for step_value_map in step_value_maps:
        for step_key, value in step_value_map.items():
            if step_key <= end_map_key:
                value_sums[step_key] += value
                file_counts[step_key] += 1

    steps = sorted(value_sums)
    values = [value_sums[step] / file_counts[step] for step in steps]
    return steps, values


def plot_cross_validation(dir_path):
    """Plot the averaged curve of the CSV files in dir_path.

    The averaged data comes from get_cross_validation_data. The figure is
    written to "cross_validation.png" in the current working directory and
    then displayed.

    Args:
        dir_path: Directory containing the CSV files to average and plot.
    """
    steps, values = get_cross_validation_data(dir_path)
    # Everything's computed, time to draw some fancy stuff
    plt.figure(figsize=(20, 10))  # Plot dimensions
    plt.ylim(-8, 8)  # Set range for y values
    handle, = plt.plot(steps, values, label="cross val")

    plt.legend(handles=[handle], loc='upper left')
    # The original called savefig on `pb`, a name that is not defined in the
    # visible code; pyplot is imported as `plt`.
    plt.savefig("cross_validation.png", bbox_inches='tight')
    plt.show()


def plot_size_power(dir_path):
    """Plot one curve per CSV file in dir_path on a shared figure.

    Each file is loaded with load_tensorflow_csv and drawn with its steps in
    ascending order. The figure is written to "size_power.png" in the current
    working directory and then displayed.

    Args:
        dir_path: Directory containing the CSV files to plot.
    """
    files = get_files_in_dir(dir_path)

    # Load everything before touching matplotlib, so a bad file fails before
    # any figure is created.
    curves = []
    for file in files:
        step_value_map = load_tensorflow_csv(file)
        steps = sorted(step_value_map)
        values = [step_value_map[step] for step in steps]
        curves.append((os.path.basename(file), steps, values))

    # Everything's computed, time to draw some fancy stuff
    plt.figure(figsize=(20, 10))  # Plot dimensions
    plt.ylim(-8, 8)  # Set range for y values
    handles = []
    for label, steps, values in curves:
        # Label each curve with its file name; the original used an empty
        # label for every curve, which left the legend unreadable.
        handle, = plt.plot(steps, values, label=label)
        handles.append(handle)

    plt.legend(handles=handles, loc='upper left')
    # The original called savefig on `pb`, a name that is not defined in the
    # visible code; pyplot is imported as `plt`.
    plt.savefig("size_power.png", bbox_inches='tight')
    plt.show()


# Load one accuracy CSV and print the resulting mapping.
# NOTE(review): `load_data_csv` is not defined in the visible code; the loader
# defined above is named `load_tensorflow_csv` -- confirm this name exists
# elsewhere (e.g. in an earlier notebook cell) or update the call.
# NOTE(review): binding the result to `dict` shadows the builtin for all code
# that runs after this statement.
dict = load_data_csv("./cross_validation/set_1/accuracy.csv")
print(dict)
    

defaultdict(<type 'int'>, {1025: 0.857421875, 2050: 0.880859375, 3075: 0.794921875, 4100: 0.830078125, 5125: 0.822265625, 6150: 0.884765625, 525: 0.76953125, 1550: 0.8359375, 2575: 0.857421875, 3600: 0.853515625, 4625: 0.82421875, 5650: 0.8671875, 25: 0.681640625, 1050: 0.796875, 2075: 0.84765625, 3100: 0.814453125, 4125: 0.861328125, 5150: 0.859375, 6175: 0.884765625, 550: 0.771484375, 1575: 0.853515625, 2600: 0.849609375, 3625: 0.826171875, 4650: 0.869140625, 5675: 0.873046875, 50: 0.671875, 1075: 0.80859375, 2100: 0.8359375, 3125: 0.849609375, 4150: 0.873046875, 5175: 0.849609375, 6200: 0.84765625, 575: 0.7421875, 1600: 0.81640625, 2625: 0.833984375, 3650: 0.798828125, 4675: 0.880859375, 5700: 0.869140625, 75: 0.671875, 1100: 0.798828125, 2125: 0.84375, 3150: 0.853515625, 4175: 0.857421875, 5200: 0.828125, 6225: 0.876953125, 600: 0.8046875, 1625: 0.822265625, 2650: 0.86328125, 3675: 0.8515625, 4700: 0.845703125, 5725: 0.853515625, 100: 0.697265625, 1125: 0.818359375, 2150: 0.8320312