In [None]:
try:
    from pokemon_data import *
    from model_creator import *
    import ipywidgets as w
    from IPython.display import clear_output
except:
    !pip install -r requirements.txt
    from pokemon_data import *
    from model_creator import *
    import ipywidgets as w
    from IPython.display import clear_output

In [None]:
%%html
<style>
/* Give every element with the `container` class a width of 75%. */
.container{width:75%}
</style>

In [None]:
class DataSetup:
    """Holds the data and the model for one generation.

    On construction a ``GenData`` object is built for ``generation`` and its
    train/test split is created. A model is then either loaded from
    ``<generation>.keras`` in the current working directory or, when no such
    entry exists, created through ``ModelCreator``.
    """

    def __init__(
        self,
        generation
    ):
        """Prepare the generation's data and obtain a model for it.

        Args:
            generation: Identifier passed to ``GenData`` and used as the
                model name / file stem (the GUI in this file passes strings
                such as ``'gen_1'``).
        """
        self.gen_data = GenData(generation=generation)
        self.gen_data.create_train_test()

        saved_model = f'{generation}.keras'
        if saved_model not in os.listdir():
            # No saved model in the working directory: build one from the split.
            creator = ModelCreator(
                X=self.gen_data.X_train,
                y=self.gen_data.y_train,
                val_X=self.gen_data.X_test,
                val_y=self.gen_data.y_test,
                model_name=generation
            )
            self.model_obj = creator
            self.model = creator.model
            # Wipe whatever the model creation wrote to the cell output.
            clear_output()
        else:
            self.model = tf.keras.models.load_model(saved_model)

    def pred_one(self, height, weight, types, abilities, max_stats):
        """Run the model on a single hand-entered example.

        The feature row is assembled in this order: scaled height, scaled
        weight, the encoded types, the encoded abilities (only when at least
        one ability is given) and the scaled stat total. The row is stored on
        ``self.pred_tensor`` with a leading batch axis before being passed to
        the model.

        Args:
            height: Height, divided by ``gen_data.max_height``.
            weight: Weight, divided by ``gen_data.max_weight``.
            types: Collection of type labels handed to
                ``gen_data.types_binarizer.transform``.
            abilities: Collection of ability labels handed to
                ``gen_data.abilities_binarizer.transform``; skipped when empty.
            max_stats: Stat total, divided by ``gen_data.stat_max`` for the
                input and used to scale the model output back up.

        Returns:
            The model output multiplied by ``max_stats``.
        """
        data = self.gen_data
        features = [height / data.max_height, weight / data.max_weight]
        features.extend(data.types_binarizer.transform([types])[0])
        if len(abilities) > 0:
            # NOTE(review): with no abilities the row is shorter; presumably
            # this matches the training layout for such generations - confirm.
            features.extend(data.abilities_binarizer.transform([abilities])[0])
        features.append(max_stats / data.stat_max)

        batch = tf.expand_dims(tf.constant(features), axis=0)
        self.pred_tensor = batch
        return self.model(batch) * max_stats


class MainGui:
    """Widget front end with two screens.

    The first screen lets the user pick a generation. Confirming it builds a
    ``DataSetup`` for that generation and swaps the cell output for an input
    form (types, abilities, weight, height, stat total) plus a predict button.
    Each prediction redraws the form with the predicted stats underneath.
    """

    def __init__(
        self
    ):
        # Stays None until a generation has been chosen and the form is built.
        self.selection_box = None

    def create_setup_elements(self):
        """Build the generation dropdown and its confirm button into ``self.setup_box``."""
        gen_names = []
        for number in range(1, 10):
            gen_names.append(f'gen_{number}')

        self.generation_selection = w.Dropdown(options=gen_names)
        self.gen_selection_button = w.Button(description='Select Generation')
        self.gen_selection_button.on_click(self._handle_generation_chosen)

        self.setup_box = w.VBox([self.generation_selection, self.gen_selection_button])

    def _handle_generation_chosen(self, _button):
        """Click handler: set up data/model for the chosen generation and show the form."""
        self.generation = self.generation_selection.value
        self.data_obj = DataSetup(generation=self.generation)
        self.create_selection_elements()
        self.create_pred_button()
        clear_output()
        form = w.VBox([self.selection_box, self.pred_button])
        display(form)

    def create_selection_elements(self):
        """Build the input form for the current generation into ``self.selection_box``.

        Requires ``self.generation`` and ``self.data_obj`` to be set already.
        """
        # e.g. 'gen_3' -> 'Generation 3'
        self.gen_label = w.Label(self.generation.replace('_', 'eration ').capitalize())

        gen_data = self.data_obj.gen_data
        self.type_selector = w.SelectMultiple(options=gen_data.types_binarizer.classes_)
        self.ability_selector = w.SelectMultiple(options=gen_data.abilities_binarizer.classes_)
        self.weight_selector = w.FloatText()
        self.height_selector = w.FloatText()
        self.max_stat_selector = w.IntText()

        self.reset_button = w.Button(description='Reset Gui')
        self.reset_button.on_click(self._handle_reset)

        category_row = w.HBox([
            w.Label('Types: '), self.type_selector,
            w.Label('Abilities: '), self.ability_selector,
        ])
        size_row = w.HBox([
            w.Label('Weight: '), self.weight_selector,
            w.Label('Height: '), self.height_selector,
        ])
        total_row = w.HBox([w.Label('Stat Total: '), self.max_stat_selector])

        self.selection_box = w.VBox([
            self.reset_button,
            self.gen_label,
            category_row,
            size_row,
            total_row,
        ])

    def _handle_reset(self, _button):
        """Click handler: clear the output and go back to the generation picker."""
        clear_output()
        display(self.setup_box)

    def create_pred_button(self):
        """Create ``self.pred_button`` and attach the prediction handler to it."""
        self.pred_button = w.Button(description='Predict stats')
        self.pred_button.on_click(self._handle_predict)

    def _handle_predict(self, _button):
        """Click handler: predict from the current form values and redraw with results."""
        prediction = self.data_obj.pred_one(
            height=self.height_selector.value,
            weight=self.weight_selector.value,
            types=self.type_selector.value,
            abilities=self.ability_selector.value,
            max_stats=self.max_stat_selector.value
        )
        # First (only) row of the batch, converted to a numpy array.
        pred_stats = prediction[0].numpy()

        # One "<stat name>: <value>" row per predicted entry.
        stat_rows = []
        for idx, value in enumerate(pred_stats):
            name_label = w.Label(f'{self.data_obj.gen_data.STATS[idx]}: ')
            value_label = w.Label(str(value))
            stat_rows.append(w.HBox([name_label, value_label]))
        output_vbox = w.VBox(stat_rows)

        clear_output()
        display(w.VBox([self.selection_box, self.pred_button, output_vbox]))

    def draw(self):
        """Build the generation picker and display it."""
        self.create_setup_elements()
        display(self.setup_box)
        

In [None]:
if __name__ == '__main__':
    # Build the GUI and show the generation picker; the later screens are
    # drawn from the button callbacks registered inside MainGui.
    maingui = MainGui()
    maingui.draw()