In [12]:
import pandas as pd
import tensorflow as tf
import numpy as np
import copy
import random

In [13]:
batch_size = 1024
learning_rate = 0.001

In [14]:
@tf.keras.saving.register_keras_serializable()
class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(units=128, activation=tf.nn.leaky_relu)
        self.dense2 = tf.keras.layers.Dense(units=1024, activation=tf.nn.leaky_relu)
        self.dense3 = tf.keras.layers.Dense(units=128, activation=tf.nn.leaky_relu)
        self.dense4 = tf.keras.layers.Dense(units=1024, activation=tf.nn.leaky_relu)
        self.dense5 = tf.keras.layers.Dense(units=8)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        x = self.dense3(x)
        x = self.dense4(x)
        output = self.dense5(x)
        return output

In [15]:
class ParaServer:
    def __init__(self):
        self.model = MLP()
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    def upload(self, grads):
        self.optimizer.apply_gradients(grads_and_vars=zip(grads, self.model.variables))
        return self.model
    def download(self):
        return self.model
    def initModel(self, x):
        self.model(x)

In [16]:
def valiAll(index_epoch):
    m = ps.download()
    model = copy.deepcopy(m)
    y_v_p = model(X_v)
    va_mse = tf.reduce_mean(tf.square(y_v_p - y_v))
    va_rmse = tf.sqrt(va_mse)
    va_mae = tf.reduce_mean(tf.abs(y_v_p - y_v))
    va_r2 = 1 - tf.reduce_sum(tf.square(y_v_p - y_v)) / tf.reduce_sum(tf.square(y_v - tf.reduce_mean(y_v)))
    print("mse:{} rmse:{} mae:{} r2:{}".format(va_mse, va_rmse, va_mae, va_r2))
    r2sv[index_epoch] = va_r2.numpy()

In [17]:
class Node:
    def __init__(self, id, freq):
        self.id = id
        self.freq = freq
        self.model = MLP()
        self.dataset1 = pd.read_csv('./20-24Trainset.csv', encoding='utf-8').sample(frac=1).reset_index(drop=True)
        self.dataset2 = pd.read_csv('./50-54Trainset.csv', encoding='utf-8').sample(frac=1).reset_index(drop=True)
        self.dataset = pd.concat([self.dataset1, self.dataset2], axis=0)
        self.dataset = self.dataset[self.dataset['freq'].isin(self.freq)]
        self.X = self.dataset.loc[:,'freq':'L2'].to_numpy(dtype = np.float32)
        self.y = self.dataset.loc[:,'S11r':'S41i'].to_numpy(dtype = np.float32)
        self.dataset_train = tf.data.Dataset.from_tensor_slices((self.X, self.y))
        self.dataset_train = self.dataset_train.shuffle(buffer_size=23000)
        self.dataset_train = self.dataset_train.batch(batch_size)
        self.dataset_train = self.dataset_train.prefetch(tf.data.experimental.AUTOTUNE)
    def train(self, index_epoch):
        self.model = ps.download()
        for X, y in self.dataset_train:
            with tf.GradientTape() as tape:
                y_pred = self.model(X)
                tr_mse = tf.reduce_mean(tf.square(y_pred - y))
            tr_rmse = tf.sqrt(tr_mse)
            tr_mae = tf.reduce_mean(tf.abs(y_pred - y))
            tr_r2 = 1 - tf.reduce_sum(tf.square(y_pred - y)) / tf.reduce_sum(tf.square(y - tf.reduce_mean(y)))
            grads = tape.gradient(tr_mse, self.model.variables)
            self.model = ps.upload(grads)
        # if epoch_index in np.arange(0, num_epochs, 25).tolist() or epoch_index == num_epochs - 1:
        print("node:{} epoch:{}".format(self.freq, index_epoch))
        print("train mse:{} rmse:{} mae:{} r2:{}".format(tr_mse, tr_rmse, tr_mae, tr_r2))
        r2s[self.id][index_epoch] = tr_r2.numpy()

In [18]:
r2s = {0:{}, 1:{}}
r2sv = {}

In [19]:
test_dataset = pd.read_csv("testset.csv", encoding='utf-8').sample(frac=1).reset_index(drop=True)
X_v = test_dataset.loc[:,'freq':'L2'].to_numpy(dtype = np.float32)
y_v = test_dataset.loc[:,'S11r':'S41i'].to_numpy(dtype = np.float32)

In [20]:
ps = ParaServer()
ps.initModel(X_v)

In [21]:
nodeList = [Node(0, [2.0, 2.3, 5.1, 5.4]), Node(1, [2.1, 2.4, 5.2]), Node(2, [2.2, 5.0, 5.3])]

In [30]:
orders = [0, 1]
turn = [np.array([[92, 146], [158, 255], [347, 475], [531, 555]]), np.array([[42, 116], [226, 277], [363, 423], [543, 600]]), np.array([[0, 200], [214, 252], [271, 347], [474, 528]])]
for i in range(400):
    random.shuffle(orders)
    for j in orders:
        for l, r in turn[j]:
            if l <= i < r:
                nodeList[j].train(i)
    valiAll(i)

mse:0.0025703944265842438 rmse:0.050699055194854736 mae:0.032664213329553604 r2:0.9789476990699768
mse:0.0025703944265842438 rmse:0.050699055194854736 mae:0.032664213329553604 r2:0.9789476990699768
mse:0.0025703944265842438 rmse:0.050699055194854736 mae:0.032664213329553604 r2:0.9789476990699768
mse:0.0025703944265842438 rmse:0.050699055194854736 mae:0.032664213329553604 r2:0.9789476990699768
mse:0.0025703944265842438 rmse:0.050699055194854736 mae:0.032664213329553604 r2:0.9789476990699768
mse:0.0025703944265842438 rmse:0.050699055194854736 mae:0.032664213329553604 r2:0.9789476990699768
mse:0.0025703944265842438 rmse:0.050699055194854736 mae:0.032664213329553604 r2:0.9789476990699768
mse:0.0025703944265842438 rmse:0.050699055194854736 mae:0.032664213329553604 r2:0.9789476990699768
mse:0.0025703944265842438 rmse:0.050699055194854736 mae:0.032664213329553604 r2:0.9789476990699768
mse:0.0025703944265842438 rmse:0.050699055194854736 mae:0.032664213329553604 r2:0.9789476990699768
mse:0.0025

KeyboardInterrupt: 

In [16]:
for k, v in r2s[2.6].items():
    print(v)

0.5486207
0.5958206
0.6077136
0.6231611
0.6104752
0.6293508
0.6425981
0.6651129
0.6868254
0.70473087
0.7077899
0.7163201
0.74149126
0.7564358
0.7607001
0.7586043
0.7359556
0.77295303
0.7995131
0.79581016
0.81139106
0.79567385
0.7558851
0.7826637
0.7788899
0.7793881
0.8076835
0.8089497
0.8327596
0.81962246
0.7543961
0.7840984
0.81794167
0.79270005
0.82521594
0.78134686
0.83092093
0.80101323
0.8339965
0.7948794
0.84356713
0.8465845
0.79730135
0.83424723
0.8134524
0.80595034
0.83938706
0.846708
0.80759996
0.84521496
0.85964924
0.81562316
0.7884283
0.8559452
0.82728237
0.8672112
0.859355
0.8188611
0.8648717
0.86271995
0.79972637
0.85287476
0.8543722
0.8266123
0.82080746
0.8168412
0.8181273
0.8719065
0.82265854
0.8701856
0.8311549
0.8286662
0.8626292
0.87516916
0.8826012
0.8897865
0.8933047
0.89149964
0.89433753
0.89729017
0.8999928
0.8984779
0.9020348
0.88937795
0.90497017
0.8998585
0.89999324
0.9050486
0.90359885
0.9058355
0.91599554
0.9061472
0.9080994
0.90987
0.9071824
0.9126997
0.91493