In [6]:
from gat import GAT 
import process

import numpy as np
import tensorflow as tf

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops

In [None]:
class BasicModel:
    """Base MAML-style spatio-temporal model (GAT encoder + hand-rolled LSTM).

    Holds hyper-parameters and the four input placeholders, and provides the
    building blocks subclasses use to assemble a meta-learning graph:

    * ``update``             -- one functional gradient-descent (inner) step,
    * ``construct_convlstm`` -- creates the weight dictionary,
    * ``lstm``               -- a single-layer LSTM driven by that dictionary,
    * ``forward_convlstm``   -- GAT encoder followed by ``lstm``.

    Subclasses must call ``construct_model()`` after initialization.
    """

    def __init__(self, dim_input, dim_output, seq_length,
                 filter_num, dim_cnn_flatten, dim_fc, dim_lstm_hidden,
                 update_lr, meta_lr, meta_batch_size, update_batch_size,
                 test_num_updates):
        """Store hyper-parameters and create the input/label placeholders.

        Must call construct_model() after initializing MAML!
        """
        self.dim_input = dim_input
        # NOTE(review): the channel count is taken from dim_output, and img_size
        # assumes dim_input == img_size**2 * channels -- confirm this is intended.
        self.channels = dim_output
        self.img_size = int(np.sqrt(self.dim_input / self.channels))

        self.dim_output = dim_output
        self.seq_length = seq_length
        self.filter_num = filter_num
        self.dim_cnn_flatten = dim_cnn_flatten
        self.dim_fc = dim_fc
        self.dim_lstm_hidden = dim_lstm_hidden

        self.update_lr = update_lr
        self.meta_lr = meta_lr
        self.update_batch_size = update_batch_size
        self.test_num_updates = test_num_updates

        self.meta_batch_size = meta_batch_size

        # Support set (a) and query set (b) inputs/labels for every task.
        self.inputa = tf.placeholder(tf.float32)
        self.inputb = tf.placeholder(tf.float32)
        self.labela = tf.placeholder(tf.float32)
        self.labelb = tf.placeholder(tf.float32)

    def update(self, loss, weights):
        """Take one plain gradient-descent step on ``weights``.

        Args:
            loss: scalar loss tensor.
            weights: dict mapping names to weight variables/tensors.

        Returns:
            A new dict with the same keys holding ``w - update_lr * dloss/dw``.
            The input dict is not modified. Weights the loss does not depend
            on (``tf.gradients`` returns ``None`` for them) are returned
            unchanged instead of raising ``TypeError``.
        """
        grads = tf.gradients(loss, list(weights.values()))
        gradients = dict(zip(weights.keys(), grads))
        new_weights = {}
        for key in weights:
            grad = gradients[key]
            if grad is None:
                # Unconnected weight (e.g. conv/fc1, unused by the GAT path): zero gradient.
                new_weights[key] = weights[key]
            else:
                new_weights[key] = weights[key] - self.update_lr * grad
        return new_weights

    def construct_convlstm(self):
        """Create and return the dict of trainable weights.

        Keys: conv1..conv3 (+ biases), fc1 (+ bias), kernel_lstm/b_lstm for the
        fused LSTM gates (4 * dim_lstm_hidden columns), and b_fc2. The matching
        'fc2' kernel is added by the subclass.
        """
        weights = {}
        dtype = tf.float32
        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
        k = 3  # convolution kernel height/width

        weights['conv1'] = tf.get_variable('conv1', [k, k, self.channels, self.filter_num],
                                           initializer=conv_initializer, dtype=dtype)
        weights['b_conv1'] = tf.Variable(tf.zeros([self.filter_num]))

        weights['conv2'] = tf.get_variable('conv2', [k, k, self.filter_num, self.filter_num],
                                           initializer=conv_initializer, dtype=dtype)
        weights['b_conv2'] = tf.Variable(tf.zeros([self.filter_num]))

        weights['conv3'] = tf.get_variable('conv3', [k, k, self.filter_num, self.filter_num],
                                           initializer=conv_initializer, dtype=dtype)
        weights['b_conv3'] = tf.Variable(tf.zeros([self.filter_num]))

        weights['fc1'] = tf.Variable(tf.random_normal([self.dim_cnn_flatten, self.dim_fc]), name='fc1')
        weights['b_fc1'] = tf.Variable(tf.zeros([self.dim_fc]))

        # Fused kernel for the four LSTM gates; input is concat([x_t, h_{t-1}]).
        weights['kernel_lstm'] = tf.get_variable('kernel_lstm', [self.dim_fc + self.dim_lstm_hidden,
                                                                 4 * self.dim_lstm_hidden])
        weights['b_lstm'] = tf.Variable(tf.zeros([4 * self.dim_lstm_hidden]))

        weights['b_fc2'] = tf.Variable(tf.zeros([self.dim_output]))

        return weights

    def lstm(self, inp, weights):
        """Run a single-layer LSTM over the time axis and return the last hidden state.

        Args:
            inp: rank-3 tensor, presumably (batch, seq_length, dim_fc); it is
                transposed with perm=[1, 0, 2] so axis 1 is treated as time.
            weights: dict providing 'kernel_lstm' and 'b_lstm'.

        Returns:
            Hidden state after the final step, shape (batch, dim_lstm_hidden),
            or None if the time axis is empty.
        """
        def lstm_block(linp, pre_state, kweight, bweight, activation):
            """One LSTM step. Gates are split from the fused kernel as (i, j, f, o)."""
            sigmoid = math_ops.sigmoid
            add = math_ops.add
            multiply = math_ops.multiply
            c, h = pre_state

            gate_inputs = math_ops.matmul(array_ops.concat([linp, h], 1), kweight)
            gate_inputs = nn_ops.bias_add(gate_inputs, bweight)
            i, j, f, o = array_ops.split(value=gate_inputs, num_or_size_splits=4, axis=1)

            # Constant forget bias of 1.0, mirroring tf.nn.rnn_cell.BasicLSTMCell's default.
            forget_bias_tensor = constant_op.constant(1.0, dtype=f.dtype)

            new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
                        multiply(sigmoid(i), activation(j)))
            new_h = multiply(activation(new_c), sigmoid(o))
            return new_h, [new_c, new_h]

        # Size the initial state from the data instead of update_batch_size so
        # batches of any size (e.g. a query set of a different size) work.
        batch_size = tf.shape(inp)[0]
        state_shape = tf.stack([batch_size, self.dim_lstm_hidden])
        state = [tf.zeros(state_shape), tf.zeros(state_shape)]

        # transpose (batch, time, feat) -> (time, batch, feat), then unstack
        # along time into a Python list of per-step (batch, feat) tensors.
        steps = tf.unstack(tf.transpose(inp, perm=[1, 0, 2]))
        output = None
        for step_input in steps:
            output, state = lstm_block(step_input, state,
                                       weights['kernel_lstm'], weights['b_lstm'],
                                       tf.nn.tanh)
        return output

    def forward_convlstm(self, inp, weights, feature_size, nb_nodes, is_train,
                         attn_drop, ffd_drop,
                         bias_mat=None,
                         hid_units=None, n_heads=None,
                         residual=None, activation=None):
        """Encode ``inp`` with a GAT, then run the LSTM; return its last hidden state.

        Args:
            inp: input tensor; reshaped to (-1, dim_input) before the GAT.
            weights: weight dict; only 'kernel_lstm'/'b_lstm' are used here
                (the GAT creates its own variables inside ``GAT.inference``).
            feature_size, nb_nodes, is_train, attn_drop, ffd_drop: forwarded
                positionally to ``GAT.inference``.
            bias_mat, hid_units, n_heads: required GAT settings, forwarded as
                keywords (kept as keywords for backward-compatible call sites).
            residual, activation: optional; forwarded only when given, so
                ``GAT.inference``'s own defaults apply otherwise.

        Returns:
            LSTM output of shape (batch, dim_lstm_hidden).

        Raises:
            ValueError: if bias_mat, hid_units or n_heads is missing.
        """
        missing = [name for name, value in (('bias_mat', bias_mat),
                                            ('hid_units', hid_units),
                                            ('n_heads', n_heads)) if value is None]
        if missing:
            raise ValueError('forward_convlstm requires: ' + ', '.join(missing))

        gat_kwargs = {'bias_mat': bias_mat, 'hid_units': hid_units, 'n_heads': n_heads}
        if residual is not None:
            gat_kwargs['residual'] = residual
        if activation is not None:
            gat_kwargs['activation'] = activation

        # TODO(review): confirm the input layout GAT.inference expects; this
        # flattens to 2-D (-1, dim_input) before calling it.
        inp = tf.reshape(inp, [-1, self.dim_input])
        gat_outputs = GAT.inference(inp, feature_size, nb_nodes, is_train,
                                    attn_drop, ffd_drop, **gat_kwargs)
        # Per the original author: gat_outputs is (1, nodeNum, features).
        # NOTE(review): GAT variables are not in `weights`, so the MAML inner
        # update does not adapt them -- confirm this is intended.
        seq_features = tf.reshape(gat_outputs, [-1, self.seq_length, self.dim_fc])

        return self.lstm(seq_features, weights)



In [None]:
class STDN(BasicModel):
    """MAML-trained spatio-temporal predictor built on BasicModel.

    Adds a tanh output layer ('fc2' / 'b_fc2'), a mean-squared-error loss and
    the meta-training graph built by ``construct_model``.
    """

    def __init__(self, dim_input, dim_output, seq_length,
                 filter_num, dim_cnn_flatten, dim_fc, dim_lstm_hidden,
                 update_lr, meta_lr, meta_batch_size, update_batch_size,
                 test_num_updates):
        print("Initializing STDN...")
        BasicModel.__init__(self, dim_input, dim_output, seq_length,
                            filter_num, dim_cnn_flatten, dim_fc, dim_lstm_hidden,
                            update_lr, meta_lr, meta_batch_size, update_batch_size,
                            test_num_updates)

    def loss_func(self, pred, label):
        """Mean squared error between ``pred`` and ``label`` after flattening both."""
        pred = tf.reshape(pred, [-1])
        label = tf.reshape(label, [-1])
        return tf.reduce_mean(tf.square(pred - label))

    def construct_model(self, **forward_kwargs):
        """Build the MAML training graph and the optimizer ops.

        For every task in the meta-batch (via ``tf.map_fn``) this takes
        ``test_num_updates`` inner gradient steps on (inputa, labela) and
        records the loss on (inputb, labelb) after each step.

        Args:
            **forward_kwargs: keyword arguments forwarded to every ``forward``
                call and from there to ``forward_convlstm`` (e.g. feature_size,
                nb_nodes, is_train, attn_drop, ffd_drop, bias_mat, hid_units,
                n_heads). Empty by default.

        Sets:
            weights, total_loss1, total_losses2, total_rmse1, total_rmse2,
            outputas, outputbs, pretrain_op, metatrain_op, finetune_op.

        Raises:
            ValueError: if ``test_num_updates`` is smaller than 1.
        """
        num_updates = self.test_num_updates
        if num_updates < 1:
            # With 0 updates out_dtype would not match task_metalearn's output
            # and total_losses2[num_updates - 1] would be out of range.
            raise ValueError('test_num_updates must be >= 1, got %r' % (num_updates,))

        with tf.variable_scope('model', reuse=None):
            with tf.variable_scope('maml', reuse=None):
                self.weights = weights = self.construct_convlstm()
                weights['fc2'] = tf.Variable(tf.random_normal(
                    [self.dim_lstm_hidden, self.dim_output]), name='fc6')   # output layer

            def task_metalearn(inp):
                """ Perform gradient descent for one task in the meta-batch. """
                inputa, inputb, labela, labelb = inp
                task_outputbs, task_lossesb = [], []

                # Pre-update prediction on the support set with the shared weights.
                task_outputa = self.forward(inputa, weights, **forward_kwargs)
                task_lossa = self.loss_func(task_outputa, labela)

                fast_weights = self.update(task_lossa, weights)

                output = self.forward(inputb, fast_weights, **forward_kwargs)
                task_outputbs.append(output)
                task_lossesb.append(self.loss_func(output, labelb))

                for _ in range(num_updates - 1):
                    loss = self.loss_func(self.forward(inputa, fast_weights, **forward_kwargs), labela)
                    fast_weights = self.update(loss, fast_weights)

                    output = self.forward(inputb, fast_weights, **forward_kwargs)
                    task_outputbs.append(output)
                    task_lossesb.append(self.loss_func(output, labelb))

                return [task_outputa, task_outputbs, task_lossa, task_lossesb]

            out_dtype = [tf.float32, [tf.float32] * num_updates, tf.float32, [tf.float32] * num_updates]

            inputs = (self.inputa, self.inputb, self.labela, self.labelb)
            result = tf.map_fn(task_metalearn,
                               elems=inputs,
                               dtype=out_dtype,
                               parallel_iterations=self.meta_batch_size)
            outputas, outputbs, lossesa, lossesb = result

        # Performance & Optimization
        self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(self.meta_batch_size)
        self.total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(self.meta_batch_size)
                                              for j in range(num_updates)]
        # NOTE(review): total_rmse1 is per-task (one value per meta-batch task),
        # while total_rmse2 is the RMSE of the meta-batch-averaged loss.
        # Kept as-is because callers may depend on the per-task shape.
        self.total_rmse1 = tf.sqrt(lossesa)
        self.total_rmse2 = [tf.sqrt(total_losses2[j]) for j in range(num_updates)]

        self.outputas, self.outputbs = outputas, outputbs
        self.pretrain_op = tf.train.AdamOptimizer(self.meta_lr).minimize(total_loss1)
        # Meta-objective: query loss after the final inner update.
        self.metatrain_op = tf.train.AdamOptimizer(self.meta_lr).minimize(total_losses2[num_updates - 1])

        maml_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "model/maml")
        self.finetune_op = tf.train.AdamOptimizer(self.meta_lr).minimize(total_loss1, var_list=maml_vars)

    def forward(self, inp, weights, **forward_kwargs):
        """Full forward pass: tanh(forward_convlstm(inp) @ fc2 + b_fc2).

        Extra keyword arguments are passed straight to ``forward_convlstm``,
        which needs the GAT settings (feature_size, nb_nodes, is_train, ...).
        """
        convlstm_outputs = self.forward_convlstm(inp, weights, **forward_kwargs)
        preds = tf.nn.tanh(tf.matmul(convlstm_outputs, weights['fc2']) + weights['b_fc2'])
        return preds
