In [None]:
from local.node import Node
from local.constnode import ConstNode
from local.varnode import VarNode
from local.funcnode import FuncNode
from local.func import *

import os
import glob
import copy
import random
import pickle
import graphviz
import datetime
from PIL import Image
from tqdm import tqdm
import tensorflow as tf
from tensorflow.keras import datasets, layers, models, losses, metrics, optimizers

In [None]:
class TreeOptimizer:
    """Optimizer whose parameter update is computed by a list of expression trees.

    Each tree consists of one mathematical operator and its arguments
    (constants or variables). On every update the trees are called in list
    order with a dictionary that maps ``"gradient"``, ``"parameter"``,
    ``"iteration"`` and every tree index to a value (see ``update_step``);
    the slot of the LAST tree is then subtracted from the parameter.

    Class attributes:
        mutate_probability: not referenced inside this class.
        max_depth: ``max_depth`` passed to the ``FuncNode`` trees created in
            ``__init__``.
    """
    mutate_probability = 0.1
    max_depth = 1

    def __init__(self):
        """Create between 1 and 10 random trees and initialise their ids and state."""
        # Decide how many trees make up this optimizer.
        self.tree_num = random.randint(1, 10)
        # Grow placeholder trees.
        self.tree_list = [FuncNode(max_depth=self.max_depth) for _ in range(self.tree_num)]
        # Dictionary handed to set_random_id so that, when arguments are chosen,
        # the output of another tree can be used as an argument.
        self.tree_dict = {i: {"node_type": "function"} for i in range(self.tree_num)}

        self.set_node_id()
        self.set_random_id()
        self.set_save_flag()
        self.initialize_state()

    @staticmethod
    def _random_leaf():
        """Return a fresh depth-1 leaf: a ``VarNode`` or a ``ConstNode``, chosen by a coin flip."""
        if random.choice([0, 1]):
            return VarNode(depth=1)
        return ConstNode(depth=1)

    def build(self, var_list):
        """Build every tree for ``var_list`` and allocate one iteration counter per variable.

        Args:
            var_list: the variables to optimise; each must expose ``.shape``.
        """
        for i in range(self.tree_num):
            tree = self.tree_list[i]
            tree.save_flag = True
            tree.build(var_list)
            tree.is_built = True
        # The counters do not depend on the tree, so they are created once
        # here instead of being re-created (and discarded) once per tree.
        self.iteration_slots = [tf.Variable(tf.zeros(V.shape)) for V in var_list]

    def set_node_id(self):
        """Ask every tree to assign its node ids, giving each tree its own empty list."""
        for i in range(self.tree_num):
            self.tree_list[i].set_node_id(list())

    def set_random_id(self):
        """Ask every tree to pick random variable ids, using ``tree_dict`` as the candidates."""
        for i in range(self.tree_num):
            self.tree_list[i].set_random_id(self.tree_dict)

    # Decides whether the result computed by a node is saved.
    def set_save_flag(self):
        """Mark every tree so that its computation is saved."""
        # For now everything is saved (True). The per-node alternative,
        # tree.set_save_flag(self.tree_dict), is currently not used.
        for i in range(self.tree_num):
            self.tree_list[i].save_flag = True

    def initialize_state(self):
        """Reset every tree to the un-built state with an iteration count of 0.0."""
        for i in range(self.tree_num):
            tree = self.tree_list[i]
            tree.is_built = False
            tree.iteration = 0.0

    def make_struct_dict(self):
        """Return ``{tree index: dict}`` where each dict was filled in by that tree's ``make_struct_dict``."""
        struct_dict = {i: dict() for i in range(self.tree_num)}
        for i in range(self.tree_num):
            self.tree_list[i].make_struct_dict(struct_dict[i])
        return struct_dict

    def make_variable_dict(self, slot_index: int):
        """Return ``{tree index: tree.slots[slot_index]}`` for every tree."""
        return {i: self.tree_list[i].slots[slot_index] for i in range(self.tree_num)}

    def update_step(self, gradient, parameter, slot_index):
        """Evaluate all trees for one variable and subtract the last tree's slot from it.

        Args:
            gradient: gradient for ``parameter``.
            parameter: the variable being optimised; updated in place with ``assign_sub``.
            slot_index: position of ``parameter`` in the list given to ``build``.
        """
        # The optimizer is made of several trees and a tree may take another
        # tree's output as an argument, so the dictionary also holds the slot
        # of every tree (keyed by tree index).
        variable_dict = self.make_variable_dict(slot_index)
        variable_dict["gradient"] = gradient
        variable_dict["parameter"] = parameter
        iteration = self.iteration_slots[slot_index]
        iteration.assign_add(tf.ones(parameter.shape))
        variable_dict["iteration"] = iteration

        for i in range(self.tree_num):
            tree = self.tree_list[i]
            # skip_flag is only assigned in mutation(); a tree that has no such
            # attribute is treated as active.
            if not getattr(tree, "skip_flag", False):
                tree(variable_dict, slot_index)

        parameter.assign_sub(self.tree_list[self.tree_num - 1].slots[slot_index])

    def apply_gradients(self, grad_and_vars):
        """Apply one update to every ``(gradient, variable)`` pair.

        The trees are built on the first call (or after ``initialize_state``).

        Args:
            grad_and_vars: iterable of ``(gradient, variable)`` pairs. It may be
                a one-shot iterator such as ``zip(grads, variables)``.
        """
        # Materialise the pairs: they are traversed twice on the first call
        # (once to collect the variables for build(), once to update). With a
        # bare zip() the first traversal used to exhaust the iterator, so the
        # first call performed no update at all.
        grad_and_vars = list(grad_and_vars)

        if any(not self.tree_list[i].is_built for i in range(self.tree_num)):
            self.build([variable for _, variable in grad_and_vars])

        for i, (G, V) in enumerate(grad_and_vars):
            self.update_step(G, V, i)

    def mutation(self):
        """Randomly mutate one tree, then refresh ids and reset the build state.

        One tree index is drawn uniformly, then:
          * p < 0.3        : "delete" the tree (replace it by a skipped Sign(0.0) tree),
          * 0.3 <= p < 0.6 : insert a new random tree at that index,
          * otherwise      : modify the tree (leaves only / function only / everything).
        """
        # Select the position (tree) to mutate.
        mutate_index = random.randrange(self.tree_num)

        p = random.random()
        # Gene deletion --> in case something still refers to the deleted tree,
        # it is kept as a placeholder that is treated as 0.
        if p < 0.3:
            tmp = FuncNode(max_depth=1, function=SignFunc())
            tmp.args[0] = ConstNode(depth=1, constant=0.0)
            tmp.skip_flag = True
            self.tree_list[mutate_index] = tmp
        # Gene addition.
        elif p < 0.6:
            tmp = FuncNode(max_depth=1)
            tmp.skip_flag = False
            # NOTE(review): inserting shifts the index of every later tree, but
            # integer variable ids that refer to trees by index are not remapped
            # here - confirm this is intended.
            self.tree_list.insert(mutate_index, tmp)
        # Element change.
        else:
            p = random.random()
            # Change the leaves only.
            if p < 0.3:
                target = self.tree_list[mutate_index]
                for i in range(len(target.args)):
                    target.args[i] = self._random_leaf()
            # Change the function only.
            elif p < 0.6:
                old_args = copy.deepcopy(self.tree_list[mutate_index].args)
                replacement = FuncNode(max_depth=1)
                replacement.skip_flag = False
                # Reuse the old arguments position by position; when the new
                # function takes more arguments than the old one, the extra
                # positions get a fresh random leaf (also avoids indexing past
                # the end of either argument list).
                for i in range(len(replacement.args)):
                    if i < len(old_args):
                        replacement.args[i] = old_args[i]
                    else:
                        replacement.args[i] = self._random_leaf()
                self.tree_list[mutate_index] = replacement
            # Change everything.
            else:
                tmp = FuncNode(max_depth=1)
                tmp.skip_flag = False
                self.tree_list[mutate_index] = tmp

        self.tree_num = len(self.tree_list)
        self.tree_dict = {i: {"node_type": "function"} for i in range(self.tree_num)}

        # Give a variable id to every direct VarNode argument that does not have one yet.
        for i in range(self.tree_num):
            tree = self.tree_list[i]
            for j in range(len(tree.args)):
                arg = tree.args[j]
                if isinstance(arg, VarNode) and arg.variable_id is None:
                    tree.set_random_id(self.tree_dict)

        self.set_node_id()
        self.initialize_state()



class SGD(TreeOptimizer):
    """Gradient descent written as a single tree: ``learning_rate * gradient``."""

    def __init__(self, learning_rate:float = 0.001):
        """Assemble the one-tree optimizer by hand.

        The base-class constructor is not invoked; the tree is fixed rather
        than random.

        Args:
            learning_rate: constant that multiplies the gradient.
        """
        self.tree_num = 1
        self.tree_list = [FuncNode(max_depth=self.max_depth) for _ in range(self.tree_num)]
        self.tree_dict = dict((index, {"node_type" : "function"}) for index in range(self.tree_num))

        # learning_rate * gradient
        # NOTE(review): the leaves use depth = 2 here while the other hand-built
        # optimizers in this file use depth = 1 - confirm which is intended.
        scaled_gradient = FuncNode(max_depth=self.max_depth, function=MulFunc())
        scaled_gradient.args[0] = ConstNode(depth = 2, constant = learning_rate)
        scaled_gradient.args[1] = VarNode(depth = 2, variable_id = "gradient")
        self.tree_list[0] = scaled_gradient

        self.set_node_id()
        self.set_save_flag()
        self.initialize_state()

class Momentum(TreeOptimizer):
    """Momentum update written as three trees: ``momentum * v + learning_rate * gradient``."""

    def __init__(self, learning_rate:float = 0.001, momentum=0.9):
        """Assemble the three-tree optimizer by hand.

        Args:
            learning_rate: constant that multiplies the gradient (tree 1).
            momentum: constant that multiplies the slot of tree 2 (tree 0).
        """
        self.tree_num = 3
        self.tree_list = [FuncNode(max_depth=self.max_depth) for _ in range(self.tree_num)]
        self.tree_dict = dict((index, {"node_type" : "function"}) for index in range(self.tree_num))

        def scaled(constant, source):
            # Tree computing `constant * <variable named source>`.
            node = FuncNode(max_depth=self.max_depth, function=MulFunc())
            node.args[0] = ConstNode(depth = 1, constant = constant)
            node.args[1] = VarNode(depth = 1, variable_id = source)
            return node

        # tree 0: momentum * (slot of tree 2)
        # Tree 2 is evaluated after tree 0, so this presumably reads the value
        # left over from the previous step - confirm against FuncNode.
        self.tree_list[0] = scaled(momentum, 2)
        # tree 1: learning_rate * gradient
        self.tree_list[1] = scaled(learning_rate, "gradient")
        # tree 2: (slot of tree 0) + (slot of tree 1)
        velocity = FuncNode(max_depth=self.max_depth, function=AddFunc())
        velocity.args[0] = VarNode(depth = 1, variable_id = 0)
        velocity.args[1] = VarNode(depth = 1, variable_id = 1)
        self.tree_list[2] = velocity

        self.set_node_id()
        self.set_save_flag()
        self.initialize_state()

class RMSProp(TreeOptimizer):
    """RMSProp-style update written as nine single-operator trees.

    With ``h`` denoting the slot of tree 4 and ``g`` the gradient, the trees
    compute ``h_new = rho * h + (1 - rho) * g^2`` and finally
    ``learning_rate / (epsilon + sqrt(h_new)) * g`` (tree 8, the last tree).
    """

    def __init__(self, learning_rate = 0.001, rho = 0.9, epsilon = 1e-7):
        """Assemble the nine-tree optimizer by hand.

        Args:
            learning_rate: numerator of the per-element step size (tree 7).
            rho: decay constant used in trees 0 and 1.
            epsilon: constant added to the square root (tree 6).
        """
        self.tree_num = 9
        self.tree_list = [FuncNode(max_depth=self.max_depth) for _ in range(self.tree_num)]
        self.tree_dict = dict((index, {"node_type" : "function"}) for index in range(self.tree_num))

        # One row per tree, in list order: (function, operands).
        # An operand is ("const", value) or ("var", variable id); an integer
        # variable id names the tree at that index.
        recipe = (
            # 0: rho * h
            (MulFunc, (("const", rho), ("var", 4))),
            # 1: 1.0 - rho
            (SubFunc, (("const", 1.0), ("const", rho))),
            # 2: gradient ^ 2
            (SquareFunc, (("var", "gradient"),)),
            # 3: (1.0 - rho) * gradient ^ 2
            (MulFunc, (("var", 1), ("var", 2))),
            # 4: h_new = rho * h + (1.0 - rho) * gradient ^ 2
            (AddFunc, (("var", 0), ("var", 3))),
            # 5: sqrt(h_new)
            (SqrtFunc, (("var", 4),)),
            # 6: epsilon + sqrt(h_new)
            (AddFunc, (("const", epsilon), ("var", 5))),
            # 7: learning_rate / (epsilon + sqrt(h_new))
            (DivFunc, (("const", learning_rate), ("var", 6))),
            # 8: {learning_rate / (epsilon + sqrt(h_new))} * gradient
            (MulFunc, (("var", 7), ("var", "gradient"))),
        )

        for position, (function_cls, operands) in enumerate(recipe):
            node = FuncNode(max_depth = 1, function = function_cls())
            for slot, (kind, value) in enumerate(operands):
                if kind == "const":
                    node.args[slot] = ConstNode(depth = 1, constant = value)
                else:
                    node.args[slot] = VarNode(depth = 1, variable_id = value)
            self.tree_list[position] = node

        self.set_node_id()
        self.set_save_flag()
        self.initialize_state()

In [None]:
# Number of samples per batch; used by load_data() for both datasets.
BATCH_SIZE = 512
# Number of passes over the training set made by each training loop.
EPOCHS = 10
def load_data():
    """Load CIFAR-10 and return ``(train_dataset, test_dataset)`` batched by ``BATCH_SIZE``.

    Images are converted to float32 and divided by 255.0; labels are passed
    through unchanged. Neither dataset is shuffled.

    Returns:
        A pair of ``tf.data.Dataset`` objects yielding ``(images, labels)`` batches.
    """
    (x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
    # CIFAR-10 images already carry an RGB channel axis matching the
    # (32, 32, 3) input of build_model(), so no extra axis is appended
    # (the former `[..., tf.newaxis]` produced a 5-D batch).
    # Casting first keeps the data in float32 instead of the float64 that
    # dividing an integer array by a Python float would produce.
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0

    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(BATCH_SIZE)
    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

    return train_dataset, test_dataset

def build_model():
    """Return a small CNN classifier for 32x32 RGB images with 10 output classes.

    Architecture: Conv2D(32 filters, 3x3, stride 2) -> LeakyReLU -> Flatten
    -> Dense(32, relu) -> Dense(10, softmax).
    """
    inputs = layers.Input(shape = (32, 32, 3))

    features = layers.Conv2D(filters = 32, kernel_size = 3, strides = 2)(inputs)
    features = layers.LeakyReLU()(features)
    features = layers.Flatten()(features)

    hidden = layers.Dense(units = 32, activation = "relu")(features)
    probabilities = layers.Dense(units = 10, activation = "softmax")(hidden)

    return models.Model(inputs, probabilities)

# Build the batched datasets once; the training cells below iterate over them.
train_dataset, test_dataset = load_data()

# Create the model and snapshot its initial weights; each training cell below
# reloads this file before training so every optimizer starts from the same weights.
model = build_model()
# NOTE(review): this writes weights only, under a ".keras" name. Some Keras
# releases may require a weights filename ending in ".weights.h5" - confirm
# against the installed version (the same path is loaded by the later cells).
model.save_weights("model.keras")

In [None]:
# Baseline run: the built-in Keras RMSprop optimizer.
optimizer = optimizers.RMSprop(learning_rate=0.001,)

# Restore the initial weights saved right after build_model().
model.load_weights("model.keras")

# Running statistics, reset at the start of every epoch by the loop below.
# loss_obj averages the per-batch training loss passed to it by train_step.
loss_obj = metrics.Mean()
accuracy_obj = metrics.SparseCategoricalAccuracy()
# The validation metrics are updated from (labels, predictions) in test_step.
val_loss_obj = metrics.SparseCategoricalCrossentropy()
val_accuracy_obj = metrics.SparseCategoricalAccuracy()

@tf.function
def train_step(X, Y):
    """Run one optimisation step on a batch and update the training metrics.

    Args:
        X: batch of model inputs.
        Y: batch of integer class labels.
    """
    loss_fn = losses.SparseCategoricalCrossentropy()
    weights = model.trainable_variables

    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
        pred = model(X)
        loss = loss_fn(Y, pred)

    gradients = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))

    # Accumulate the epoch statistics.
    loss_obj(loss)
    accuracy_obj(Y, pred)

@tf.function
def test_step(X, Y):
    """Evaluate one batch and update the validation metrics (no weight update).

    Args:
        X: batch of model inputs.
        Y: batch of integer class labels.
    """
    pred = model(X)

    # Both metrics take (labels, predictions); loss first, then accuracy.
    for val_metric in (val_loss_obj, val_accuracy_obj):
        val_metric(Y, pred)

for i in range(EPOCHS):

    # Clear the running statistics so each epoch is reported on its own.
    for epoch_metric in (loss_obj, accuracy_obj, val_loss_obj, val_accuracy_obj):
        epoch_metric.reset_state()

    # One pass over the training batches (updates the model weights).
    for X, Y in tqdm(train_dataset):
        train_step(X, Y)

    # One pass over the test batches (metrics only).
    for X, Y in tqdm(test_dataset):
        test_step(X, Y)

    train_loss, train_accuracy = float(loss_obj.result()), float(accuracy_obj.result())
    print(f"epoch : {i}, loss : {train_loss}, accuracy : {train_accuracy}")
    val_loss, val_accuracy = float(val_loss_obj.result()), float(val_accuracy_obj.result())
    print(f"val_loss : {val_loss}, val_accuracy : {val_accuracy}")

In [None]:
# Tree-based run: the RMSProp optimizer defined above as a TreeOptimizer
# subclass, with its default arguments (learning_rate 0.001, rho 0.9, epsilon 1e-7).
optimizer = RMSProp()

# Restore the initial weights saved right after build_model(), so this run
# starts from the same weights as the previous training cell.
model.load_weights("model.keras")

# Fresh running statistics for this run; reset at the start of every epoch below.
# loss_obj averages the per-batch training loss passed to it by train_step.
loss_obj = metrics.Mean()
accuracy_obj = metrics.SparseCategoricalAccuracy()
# The validation metrics are updated from (labels, predictions) in test_step.
val_loss_obj = metrics.SparseCategoricalCrossentropy()
val_accuracy_obj = metrics.SparseCategoricalAccuracy()

@tf.function
def train_step(X, Y):
    """Run one optimisation step on a batch and update the training metrics.

    Args:
        X: batch of model inputs.
        Y: batch of integer class labels.
    """
    with tf.GradientTape() as tape:
        pred = model(X)
        loss = losses.SparseCategoricalCrossentropy()(Y, pred)
    gradients = tape.gradient(loss, model.trainable_variables)
    # Pass a list, not a bare zip(): TreeOptimizer.apply_gradients walks its
    # argument twice on the first call (to build, then to update), and a
    # one-shot zip iterator would be exhausted before any update is applied.
    optimizer.apply_gradients(list(zip(gradients, model.trainable_variables)))

    loss_obj(loss)
    accuracy_obj(Y, pred)

@tf.function
def test_step(X, Y):
    """Evaluate one batch and update the validation metrics (no weight update).

    Args:
        X: batch of model inputs.
        Y: batch of integer class labels.
    """
    pred = model(X)

    # Both metrics take (labels, predictions); loss first, then accuracy.
    for val_metric in (val_loss_obj, val_accuracy_obj):
        val_metric(Y, pred)

for i in range(EPOCHS):

    # Clear the running statistics so each epoch is reported on its own.
    for epoch_metric in (loss_obj, accuracy_obj, val_loss_obj, val_accuracy_obj):
        epoch_metric.reset_state()

    # One pass over the training batches (updates the model weights).
    for X, Y in tqdm(train_dataset):
        train_step(X, Y)

    # One pass over the test batches (metrics only).
    for X, Y in tqdm(test_dataset):
        test_step(X, Y)

    train_loss, train_accuracy = float(loss_obj.result()), float(accuracy_obj.result())
    print(f"epoch : {i}, loss : {train_loss}, accuracy : {train_accuracy}")
    val_loss, val_accuracy = float(val_loss_obj.result()), float(val_accuracy_obj.result())
    print(f"val_loss : {val_loss}, val_accuracy : {val_accuracy}")