In [2]:
import copy
import random

import numpy as np
import pandas as pd
import tensorflow as tf

In [3]:
# Mini-batch size passed to `Dataset.batch` when each Node builds its training pipeline.
batch_size = 1024
# Learning rate handed to the parameter server's Adam optimizer.
learning_rate = 0.001

In [4]:
@tf.keras.saving.register_keras_serializable()
class MLP(tf.keras.Model):
    """Fully connected network: four leaky-ReLU hidden layers and a linear head.

    Hidden widths are 128 -> 1024 -> 128 -> 1024; the final layer has 8 units
    and no activation.
    """

    def __init__(self):
        super().__init__()
        Dense = tf.keras.layers.Dense
        act = tf.nn.leaky_relu
        self.dense1 = Dense(units=128, activation=act)
        self.dense2 = Dense(units=1024, activation=act)
        self.dense3 = Dense(units=128, activation=act)
        self.dense4 = Dense(units=1024, activation=act)
        self.dense5 = Dense(units=8)

    def call(self, inputs):
        """Feed `inputs` through the hidden layers in order, then the output layer."""
        hidden = inputs
        for layer in (self.dense1, self.dense2, self.dense3, self.dense4):
            hidden = layer(hidden)
        return self.dense5(hidden)

In [5]:
class ParaServer:
    """Owns the shared MLP and the Adam optimizer that updates it."""

    def __init__(self):
        self.model = MLP()
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    def initModel(self, x):
        """Call the shared model once on `x`; the output is discarded."""
        self.model(x)

    def download(self):
        """Return the shared model object itself (not a copy)."""
        return self.model

    def upload(self, grads):
        """Apply `grads` to the shared model and return that model.

        `grads` is paired positionally with `self.model.variables`, so it must
        be ordered the same way.
        """
        grads_and_vars = zip(grads, self.model.variables)
        self.optimizer.apply_gradients(grads_and_vars=grads_and_vars)
        return self.model

In [6]:
def valiAll(index_epoch):
    """Evaluate the parameter server's current model on the validation set.

    Prints MSE, RMSE, MAE and R^2, each computed over every element of the
    prediction/target arrays, and stores the R^2 value in the module-level
    ``r2sv`` dict under ``index_epoch``.

    Reads the module-level ``ps``, ``X_v`` and ``y_v``.
    """
    # A forward pass does not modify the model, so evaluate the server's model
    # directly instead of deep-copying it on every call.
    model = ps.download()
    residual = model(X_v) - y_v
    sq_err = tf.square(residual)
    va_mse = tf.reduce_mean(sq_err)
    va_rmse = tf.sqrt(va_mse)
    va_mae = tf.reduce_mean(tf.abs(residual))
    # R^2 = 1 - SS_res / SS_tot, with a single mean taken over all of y_v.
    ss_tot = tf.reduce_sum(tf.square(y_v - tf.reduce_mean(y_v)))
    va_r2 = 1 - tf.reduce_sum(sq_err) / ss_tot
    print("mse:{} rmse:{} mae:{} r2:{}".format(va_mse, va_rmse, va_mae, va_r2))
    r2sv[index_epoch] = va_r2.numpy()

In [12]:
class Node:
    """Simulated worker that trains on one frequency band of the training set.

    Each training step downloads the parameter server's model, computes a
    batch gradient on a local deep copy, and uploads that gradient to the
    server (module-level ``ps``).
    """

    def __init__(self, freq, csv_path='./trainset.csv', shuffle_buffer=23000):
        """Load the training CSV, keep this node's band, and build the batch pipeline.

        Args:
            freq: ``(low, high)`` pair; rows whose ``freq`` column lies in the
                closed interval ``[low, high]`` are kept.
            csv_path: path of the training CSV.
            shuffle_buffer: buffer size passed to ``Dataset.shuffle``.
        """
        self.freq = freq
        self.model = MLP()
        self.dataset = pd.read_csv(csv_path, encoding='utf-8').sample(frac=1).reset_index(drop=True)
        in_band = (self.dataset['freq'] >= freq[0]) & (self.dataset['freq'] <= freq[1])
        self.dataset = self.dataset[in_band]
        # Inputs are the columns 'freq'..'L2', targets the columns 'S11r'..'S41i'.
        self.X = self.dataset.loc[:, 'freq':'L2'].to_numpy(dtype=np.float32)
        self.y = self.dataset.loc[:, 'S11r':'S41i'].to_numpy(dtype=np.float32)
        self.dataset_train = (
            tf.data.Dataset.from_tensor_slices((self.X, self.y))
            .shuffle(buffer_size=shuffle_buffer)
            .batch(batch_size)
            .prefetch(tf.data.AUTOTUNE)
        )

    def train(self, index_epoch):
        """Run one pass over this node's data, uploading one gradient per batch.

        After the pass, prints MSE/RMSE/MAE/R^2 of the *last* batch only and
        records that R^2 in the module-level ``r2s`` under
        ``r2s[self.freq[0]][index_epoch]``. If the node has no data, nothing
        is uploaded or recorded.
        """
        self.model = copy.deepcopy(ps.download())
        last_batch = None
        for X, y in self.dataset_train:
            with tf.GradientTape() as tape:
                y_pred = self.model(X)
                tr_mse = tf.reduce_mean(tf.square(y_pred - y))
            grads = tape.gradient(tr_mse, self.model.variables)
            # The server applies the update; re-sync the local copy to it.
            self.model = copy.deepcopy(ps.upload(grads))
            last_batch = (y, y_pred, tr_mse)

        if last_batch is None:
            # No row fell inside this node's band, so there is nothing to report.
            print("node:{} epoch:{} skipped: no training data".format(self.freq, index_epoch))
            return

        # Only the final batch is reported, so its metrics are computed once here.
        y, y_pred, tr_mse = last_batch
        residual = y_pred - y
        tr_rmse = tf.sqrt(tr_mse)
        tr_mae = tf.reduce_mean(tf.abs(residual))
        tr_r2 = 1 - tf.reduce_sum(tf.square(residual)) / tf.reduce_sum(tf.square(y - tf.reduce_mean(y)))
        print("node:{} epoch:{}".format(self.freq, index_epoch))
        print("train mse:{} rmse:{} mae:{} r2:{}".format(tr_mse, tr_rmse, tr_mae, tr_r2))
        # setdefault keeps this working for a band that was not pre-registered in r2s.
        r2s.setdefault(self.freq[0], {})[index_epoch] = tr_r2.numpy()

In [13]:
# Training R^2 history per node: the outer key is the node's lower band edge
# (Node.train writes r2s[freq[0]][epoch]), the inner key is the epoch index.
r2s = {band_start: {} for band_start in (1.0, 2.7, 4.4)}
# Validation R^2 keyed by epoch index (filled in by valiAll).
r2sv = dict()

In [14]:
# Load the validation CSV, shuffle its rows and renumber the index.
test_dataset = (
    pd.read_csv("testset.csv", encoding='utf-8')
    .sample(frac=1)
    .reset_index(drop=True)
)
# Inputs are the columns 'freq'..'L2'; targets are the columns 'S11r'..'S41i'.
X_v = test_dataset.loc[:, 'freq':'L2'].to_numpy(dtype=np.float32)
y_v = test_dataset.loc[:, 'S11r':'S41i'].to_numpy(dtype=np.float32)

In [15]:
# Create the shared parameter server, then call its model once on the
# validation inputs via initModel (the output is discarded).
# NOTE(review): presumably this first call is what creates the model's
# variables so they exist before any gradients are applied — confirm.
ps = ParaServer()
ps.initModel(X_v)

In [16]:
# One worker per (low, high) frequency band, in ascending band order.
nodeList = [Node(band) for band in ((1.0, 2.6), (2.7, 4.3), (4.4, 6.0))]

In [17]:
# Total number of global epochs to simulate.
num_epochs = 600
# Node indices; reshuffled in place every epoch so the order in which the
# active nodes train varies from epoch to epoch.
orders = [0, 1, 2]
# turn[j] lists the half-open epoch windows [start, stop) during which
# node j takes part in training.
turn = [
    np.array([[26, 104], [178, 312], [344, 464], [520, 600]]),
    np.array([[0, 94], [149, 223], [319, 433], [464, 580]]),
    np.array([[32, 151], [155, 248], [270, 354], [378, 502]]),
]
for i in range(num_epochs):
    random.shuffle(orders)
    for j in orders:
        for start, stop in turn[j]:
            if start <= i < stop:
                nodeList[j].train(i)
    # Validate the shared model once per epoch, whether or not any node trained.
    valiAll(i)

2023-09-14 01:55:06.856912: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55f64852b140 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-09-14 01:55:06.856932: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 2080 Ti, Compute Capability 7.5
2023-09-14 01:55:06.856936: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (1): NVIDIA GeForce RTX 2080 Ti, Compute Capability 7.5
2023-09-14 01:55:06.860048: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:255] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-09-14 01:55:06.951186: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:432] Loaded cuDNN version 8600
2023-09-14 01:55:07.058633: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


node:(2.7, 4.3) epoch:0
train mse:0.043809596449136734 rmse:0.20930741727352142 mae:0.16039542853832245 r2:0.6301136016845703
mse:0.1492132693529129 rmse:0.3862813413143158 mae:0.3013881742954254 r2:-0.2220081090927124
node:(2.7, 4.3) epoch:1
train mse:0.03530925139784813 rmse:0.18790756165981293 mae:0.1421242356300354 r2:0.7042906284332275
mse:0.16464200615882874 rmse:0.4057610034942627 mae:0.31490427255630493 r2:-0.3483644723892212


KeyboardInterrupt: 

In [16]:
# Print the recorded training R^2 values, one per line, for a single node.
# r2s is keyed by each node's lower band edge (1.0, 2.7, 4.4); the previous
# key 2.6 is the upper edge of the (1.0, 2.6) band and raised KeyError.
# NOTE(review): assumed the (1.0, 2.6) node was the one intended — confirm.
for v in r2s[1.0].values():
    print(v)

0.5486207
0.5958206
0.6077136
0.6231611
0.6104752
0.6293508
0.6425981
0.6651129
0.6868254
0.70473087
0.7077899
0.7163201
0.74149126
0.7564358
0.7607001
0.7586043
0.7359556
0.77295303
0.7995131
0.79581016
0.81139106
0.79567385
0.7558851
0.7826637
0.7788899
0.7793881
0.8076835
0.8089497
0.8327596
0.81962246
0.7543961
0.7840984
0.81794167
0.79270005
0.82521594
0.78134686
0.83092093
0.80101323
0.8339965
0.7948794
0.84356713
0.8465845
0.79730135
0.83424723
0.8134524
0.80595034
0.83938706
0.846708
0.80759996
0.84521496
0.85964924
0.81562316
0.7884283
0.8559452
0.82728237
0.8672112
0.859355
0.8188611
0.8648717
0.86271995
0.79972637
0.85287476
0.8543722
0.8266123
0.82080746
0.8168412
0.8181273
0.8719065
0.82265854
0.8701856
0.8311549
0.8286662
0.8626292
0.87516916
0.8826012
0.8897865
0.8933047
0.89149964
0.89433753
0.89729017
0.8999928
0.8984779
0.9020348
0.88937795
0.90497017
0.8998585
0.89999324
0.9050486
0.90359885
0.9058355
0.91599554
0.9061472
0.9080994
0.90987
0.9071824
0.9126997
0.91493