In [None]:
import random

import numpy as np
import pandas as pd
import torch

import submission2_utils as util
import genetic

In [None]:
# Prep our data: load the training set, fill missing values with 0, then build the
# numeric view (plus the mappings returned by util.categorical_to_numeric) and a
# normalized copy of it.
# Chained .fillna() instead of inplace=True keeps the cell idempotent and readable.
df = pd.read_csv('../data/train.csv', index_col=0).fillna(0)
num_df, idx_mapping, column_mapping = util.categorical_to_numeric(df)
norm_df = util.normalize(num_df)
# To cache these intermediates, write num_df to '../data/train_num.csv' and
# norm_df to '../data/train_norm_num.csv' with DataFrame.to_csv.

In [None]:
# Seed every RNG the run may touch (stdlib, NumPy, PyTorch) so the random initial
# population and the genetic algorithm are reproducible across Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Normalized SalePrice thresholds (presumably on util.normalize's scale — confirm),
# mapped back to raw SalePrice values so they can be compared with fitness values.
norm_thresholds = (.2, .4, .6, .8, 1)
unnorm_thresholds = [util.unnormalize(num_df, 'SalePrice', x) for x in norm_thresholds]

# The trained network is used as the fitness function. map_location='cpu' lets a
# checkpoint saved on a GPU load on a CPU-only machine; load_state_dict copies the
# tensors into the freshly constructed model's parameters.
fitness_function = util.Net()
fitness_function.load_state_dict(torch.load('../model/sub2.pth', map_location='cpu'))
# Evaluation mode (affects layers such as dropout/batch-norm, if Net has any).
fitness_function = fitness_function.eval()

# Number of random houses requested from util.create_houses — presumably the
# initial population size (TODO confirm against create_houses' signature).
N_INITIAL_HOUSES = 20
first_generation = list(util.create_houses(N_INITIAL_HOUSES, num_df))

In [None]:
# Run the genetic algorithm starting from the initial population, using the trained
# network as fitness and the first entry of unnorm_thresholds as the price limit.
# NOTE(review): the two positional 20s are passed through as-is — presumably the
# number of generations and a population size; check genetic.genetic_algorithm's
# signature and pass them by keyword once confirmed.
generations, pruned_generations = genetic.genetic_algorithm(
    first_generation,
    fitness_function,
    20,
    20,
    unnorm_thresholds[0],
    num_df,
    idx_mapping,
    column_mapping,
    secondary_fitness='GarageArea',
)

In [None]:
# Raw SalePrice value of the first threshold (unnormalized from norm_thresholds[0]).
price_threshold = unnorm_thresholds[0]
print(price_threshold)

In [None]:
# Per-generation summary of the houses whose fitness does not exceed the first
# price threshold. Each line reads:
#   generation index, number of valid houses: mean fitness, mean secondary fitness
# Generations with no valid house report nan instead of a misleading mean of 0.
# Local names are used so the builtin `sum` is not shadowed in the kernel namespace.
price_threshold = unnorm_thresholds[0]
for gen_idx, generation in enumerate(generations):
    valid_houses = [house for house in generation if house.fitness <= price_threshold]
    if valid_houses:
        mean_fitness = float(np.mean([house.fitness for house in valid_houses]))
        mean_secondary = float(np.mean([house.secondary_fitness for house in valid_houses]))
    else:
        mean_fitness = mean_secondary = float('nan')
    print(f'{gen_idx}, {len(valid_houses)}: {mean_fitness}, {mean_secondary}')

In [None]:
# Per-generation summary of the pruned generations. Each line reads:
#   generation index, number of houses: mean fitness
# Empty generations report nan instead of a misleading mean of 0.
# Local names are used so the builtin `sum` is not shadowed in the kernel namespace.
for gen_idx, generation in enumerate(pruned_generations):
    fitnesses = [house.fitness for house in generation]
    mean_fitness = float(np.mean(fitnesses)) if fitnesses else float('nan')
    print(f'{gen_idx}, {len(fitnesses)}: {mean_fitness}')

In [None]:
# Inspect the houses in the first pruned generation.
first_pruned_generation = pruned_generations[0]
print(first_pruned_generation)

In [None]:
# Inspect the houses in the final pruned generation.
last_pruned_generation = pruned_generations[-1]
print(last_pruned_generation)

In [None]:
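# Recorded output of the generation-summary cell from an earlier run, apparently
# with secondary_fitness='GarageArea' (TODO confirm). Format per line:
#   generation, number of houses under the price threshold: mean fitness, mean secondary fitness
# 0, 2: 172470.4269297421, 816.0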
# 1, 6: 169888.41894467673, 816.0
# 2, 15: 163586.1464269956, 938.0
# 3, 17: 167929.81650540058, 1009.7647058823529
# 4, 16: 170053.41291092336, 1136.25
# 5, 17: 171586.64349694463, 1182.0
# 6, 20: 176049.36519771814, 1182.0
# 7, 20: 177044.42321754992, 1178.8
# 8, 19: 178265.89332222939, 1183.6842105263158
# 9, 18: 177840.5337173078, 1183.7777777777778
# 10, 17: 178354.82766084812, 1185.764705882353
# 11, 20: 176918.75500254333, 1185.2
# 12, 17: 176772.7445823305, 1189.5294117647059
# 13, 19: 178819.08342838287, 1190.421052631579
# 14, 14: 176397.19132663947, 1188.857142857143
# 15, 17: 178063.51887224353, 1191.4117647058824
# 16, 20: 176126.51624292135, 1183.6
# 17, 19: 176058.6268246174, 1182.0
# 18, 19: 177120.94549387693, 1182.0
# 19, 18: 178853.4616050621, 1182.0
# 20, 20: 178795.33989913762, 1182.0

In [None]:
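# Recorded output of the same generation-summary cell from a different run; the
# settings used (e.g. which secondary fitness) were not recorded — TODO note them.
# Same format: generation, number of valid houses: mean fitness, mean secondary fitness
# 0, 4: 166682.46216475964, 549.0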
# 1, 15: 160545.95816403627, 781.3333333333334
# 2, 17: 166090.4618641033, 870.1764705882352
# 3, 17: 166097.23774948542, 709.8823529411765
# 4, 14: 171883.4121730711, 678.6428571428571
# 5, 19: 174523.81574958563, 713.4736842105264
# 6, 18: 173821.80838121308, 635.3888888888889
# 7, 15: 175880.09072552124, 629.3333333333334
# 8, 17: 177000.53930790984, 644.2941176470588
# 9, 19: 177332.88524895906, 656.1052631578947
# 10, 15: 175265.9748438994, 629.3333333333334
# 11, 18: 178736.39613100223, 574.8333333333334
# 12, 18: 178755.39536575475, 484.0
# 13, 18: 178885.2490040991, 484.0
# 14, 20: 178409.1474123299, 484.0
# 15, 13: 178859.2916113826, 484.0
# 16, 13: 178704.5685620262, 484.0
# 17, 18: 178817.54004980126, 484.0
# 18, 17: 178653.9392901694, 484.0
# 19, 18: 177626.08099513582, 484.0
# 20, 17: 178823.70504298632, 484.0