In [195]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pandas as pd
from PIL import Image
import tensorflow.keras.backend as backend

In [196]:
# Number of nodes per graph. The commented-out values (9 and 5) are the other
# sizes this cell has been switched between; the value also selects the CSV below.
# NUMBER_NODES = 9
NUMBER_NODES = 7
# NUMBER_NODES = 5
# CSV file to load; the path is relative to the notebook's working directory.
DATASET_PATH = f'../../../datasets/examples/opt_band_{NUMBER_NODES}_nodes_graph.csv'

def load_data(dataset_path=None, number_nodes=None, random_state=None):
    """Read the graph CSV and split it into train / validation / test sets.

    Each CSV row is cut into two float32 vectors: the first
    ``n * (n - 1) // 2`` columns (upper-triangle adjacency entries of an
    ``n``-node graph) become the features, every remaining column becomes
    the target.

    The split is 80% train; of the remaining 20%, 70% goes to validation
    and the rest to test.

    Args:
        dataset_path: CSV file to read. Defaults to the module-level
            ``DATASET_PATH``.
        number_nodes: Number of graph nodes ``n``. Defaults to the
            module-level ``NUMBER_NODES``.
        random_state: Seed forwarded to ``DataFrame.sample`` so the split can
            be reproduced. ``None`` (the default) keeps the original
            behaviour of a different random split on every call.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``, each a list of
        ``(X, Y)`` tuples of float32 NumPy arrays.
    """
    if dataset_path is None:
        dataset_path = DATASET_PATH
    if number_nodes is None:
        number_nodes = NUMBER_NODES
    featuresNumber = (number_nodes * number_nodes - number_nodes) // 2

    def get_tuple_dataset(row):
        X = row[:featuresNumber].astype('float32')
        Y = row[featuresNumber:].astype('float32')  # opt_band is in target
        return X, Y

    def to_pairs(frame):
        return [get_tuple_dataset(row) for row in frame.to_numpy()]

    df = pd.read_csv(dataset_path)
    train_dataset_df = df.sample(frac=0.8, random_state=random_state)
    # Rows not drawn for training are split again into validation and test.
    df_remaining = df.drop(index=train_dataset_df.index)
    val_dataset_df = df_remaining.sample(frac=0.7, random_state=random_state)
    test_dataset_df = df_remaining.drop(index=val_dataset_df.index)

    return (
        to_pairs(train_dataset_df),
        to_pairs(val_dataset_df),
        to_pairs(test_dataset_df),
    )
    
# Load the CSV once and keep the three splits as lists of (features, target) pairs.
train_dataset, val_dataset, test_dataset = load_data()

def getGraph(upperTriangleAdjMatrix, number_nodes=None):
    """Expand a flattened strict upper triangle into a symmetric matrix.

    The input lists the entries above the diagonal in row-major order:
    (0,1), (0,2), ..., (0,n-1), (1,2), ... . The diagonal of the result
    is zero and the lower triangle mirrors the upper one.

    Args:
        upperTriangleAdjMatrix: Sequence with at least ``n * (n - 1) // 2``
            values; any extra trailing values are ignored.
        number_nodes: Matrix size ``n``. Defaults to the module-level
            ``NUMBER_NODES``.

    Returns:
        A float64 NumPy array of shape ``(n, n)``.

    Raises:
        IndexError: If fewer than ``n * (n - 1) // 2`` values are supplied.
    """
    if number_nodes is None:
        number_nodes = NUMBER_NODES

    # triu_indices with k=1 yields the strict upper triangle in row-major order.
    upper_rows, upper_cols = np.triu_indices(number_nodes, k=1)
    values = np.ravel(upperTriangleAdjMatrix)
    if values.size < upper_rows.size:
        raise IndexError(
            f'expected at least {upper_rows.size} upper-triangle values '
            f'for {number_nodes} nodes, got {values.size}'
        )
    values = values[:upper_rows.size]

    dense_adj = np.zeros((number_nodes, number_nodes))
    dense_adj[upper_rows, upper_cols] = values
    dense_adj[upper_cols, upper_rows] = values  # mirror into the lower triangle
    return dense_adj

def processDataToImage(graphInput):
    """Render a graph's adjacency matrix as a 32x32 float32 image in [0, 1].

    The flattened upper-triangle input is expanded with ``getGraph``; every
    entry equal to 1 becomes a white pixel (255) on a black background. The
    NUMBER_NODES x NUMBER_NODES greyscale picture is then upscaled to 32x32
    with nearest-neighbour resampling and divided by 255.
    """
    adjacency = getGraph(graphInput)
    side = NUMBER_NODES
    pixels = np.zeros((side, side), dtype=np.uint8)
    # Boolean mask replaces the explicit double loop over (i, j).
    pixels[adjacency == 1] = 255
    picture = Image.fromarray(pixels, 'L').resize((32, 32), Image.NEAREST)
    return np.array(picture, dtype=np.float32) / 255.0

def getData_2(dataset):
    """Turn a list of (features, target) pairs into model-ready arrays.

    Each feature vector is wrapped in a one-element list, so the feature
    array gets a length-1 middle axis: shape ``(len(dataset), 1, n_features)``.
    The first target entry is dropped (presumably the opt_band value, going
    by the comment in ``load_data`` -- confirm against the CSV schema); the
    rest form the label array.

    Args:
        dataset: Iterable of ``(graphInput, target)`` pairs of NumPy arrays.

    Returns:
        ``(features, labels)`` as NumPy arrays.
    """
    features = []
    node_orderings = []
    for graphInput, target in dataset:
        # Wrapping in a list adds the single-step sequence axis.
        features.append([graphInput])
        node_orderings.append(target[1:])
    return np.array(features), np.array(node_orderings)

In [197]:
# Build (features, labels) arrays for the three splits, in train/val/test order.
(x_train, y_train), (x_val, y_val), (x_test, y_test) = map(
    getData_2, (train_dataset, val_dataset, test_dataset)
)

# Show the training array shapes, one per line.
print(x_train.shape, y_train.shape, sep='\n')

(835, 5, 21)
(835, 7)


In [198]:
# Inspect the first training label vector (the target row minus its first entry, see getData_2).
y_train[0]

array([1., 5., 6., 2., 0., 4., 3.], dtype=float32)

In [199]:
class MinimalRNNCell(keras.layers.Layer):
    """Bare-bones recurrent cell for use inside ``keras.layers.RNN``.

    At each step it computes
    ``output = inputs . kernel + prev_output . recurrent_kernel``
    with no bias and no activation; the output is also the next state.
    """

    def __init__(self, units, **kwargs):
        """Create the cell.

        Args:
            units: Size of the output / state vector.
            **kwargs: Forwarded to ``keras.layers.Layer``.
        """
        # Initialise the base Layer first so attribute tracking is set up
        # before any attribute is assigned.
        super(MinimalRNNCell, self).__init__(**kwargs)
        self.units = units
        # Read by keras.layers.RNN to allocate the initial state.
        self.state_size = units

    def build(self, input_shape):
        """Create the input and recurrent weight matrices."""
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer='uniform',
            name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True

    def call(self, inputs, states):
        """Run one time step.

        Args:
            inputs: Input tensor for the current step.
            states: Sequence whose first element is the previous output.

        Returns:
            ``(output, [output])``: the step output and the new state list.
        """
        prev_output = states[0]
        h = backend.dot(inputs, self.kernel)
        output = h + backend.dot(prev_output, self.recurrent_kernel)
        return output, [output]

    def get_config(self):
        """Return the constructor arguments so the cell can be serialised."""
        config = super(MinimalRNNCell, self).get_config()
        config.update({'units': self.units})
        return config

In [203]:
# Shared mean-squared-error loss instance, used by loss_repeated_labels and customLoss.
loss_object = tf.keras.losses.MeanSquaredError()

def loss_repeated_labels(roundedOutput, roundedTrue):
  """Penalty for rounded predictions that repeat labels.

  Both tensors are squeezed before being passed to ``tf.unique`` /
  ``tf.unique_with_counts``. The result is the sum of three terms:
  the variance of the per-label occurrence counts, the MSE between those
  counts and a vector of ones, and the MSE between the unique-index
  vectors of the true and predicted labels.

  NOTE(review): presumably this relies on a batch of one sample so the
  squeezed tensors are 1-D -- confirm against the fit call.
  """
  flat_true = tf.squeeze(roundedTrue)
  flat_pred = tf.squeeze(roundedOutput)
  _, true_positions = tf.unique(flat_true)
  _, pred_positions, occurrences = tf.unique_with_counts(flat_pred)
  occurrences = tf.cast(occurrences, tf.float32)
  # 1 - counts: how far each count is from occurring exactly once (a unique label).
  uniqueness_penalty = loss_object(tf.ones_like(occurrences), occurrences)
  # How far the predicted indexes are from the true ones (e.g. 0,1,2,3,4,5,6).
  index_penalty = loss_object(
      tf.cast(true_positions, tf.float32),
      tf.cast(pred_positions, tf.float32))
  return tf.math.reduce_variance(occurrences) + uniqueness_penalty + index_penalty

def customLoss(true, pred):
  """Training loss: plain MSE plus the repeated-label penalty.

  The MSE is taken on the raw values; the penalty from
  ``loss_repeated_labels`` is taken on the rounded predictions and targets.
  """
  regression_term = loss_object(true, pred)
  repetition_term = loss_repeated_labels(tf.round(pred), tf.round(true))
  return regression_term + repetition_term


In [205]:
# NOTE(review): in the visible cells these three values are only referenced by the
# commented-out random-data experiment; model.fit passes batch_size=1 literally.
batch_size = 64
num_batches = 10
timestep = 3

# 7 output units and 21 input features are hard-coded: they match NUMBER_NODES = 7
# (7 node labels, (7*7 - 7) // 2 = 21 upper-triangle entries) and must be updated
# by hand if NUMBER_NODES changes.
cell = MinimalRNNCell(7)
# None leaves the number of time steps unspecified.
x = keras.Input((None, 21))
layer = keras.layers.RNN(cell)
y = layer(x)

model = keras.models.Model([x], y)

# NOTE(review): "accuracy" is reported for float regression targets; it is unlikely
# to be a meaningful metric here -- confirm whether an error metric was intended.
model.compile(optimizer="adam", loss=customLoss, metrics=["accuracy"])

In [207]:
# Disabled experiment: random inputs/targets with the same feature and label widths.
# input_1_data = np.random.random((batch_size * num_batches, timestep, 21))

# print(input_1_data.shape)

# target_1_data = np.random.random((batch_size * num_batches, 7))

# print(target_1_data.shape)

# Train one sample per step. NOTE(review): batch_size=1 is presumably required because
# loss_repeated_labels squeezes its inputs before tf.unique -- confirm before raising it.
model.fit(x_train, y_train, validation_data=(x_val, y_val), batch_size=1, epochs=1024)

Epoch 1/1024
Epoch 2/1024
Epoch 3/1024
Epoch 4/1024
Epoch 5/1024
Epoch 6/1024
Epoch 7/1024
Epoch 8/1024
Epoch 9/1024
Epoch 10/1024
Epoch 11/1024
Epoch 12/1024
Epoch 13/1024
Epoch 14/1024
Epoch 15/1024
Epoch 16/1024
Epoch 17/1024
Epoch 18/1024
Epoch 19/1024
Epoch 20/1024
Epoch 21/1024
Epoch 22/1024
Epoch 23/1024
Epoch 24/1024
Epoch 25/1024
Epoch 26/1024
Epoch 27/1024
Epoch 28/1024
Epoch 29/1024
Epoch 30/1024
Epoch 31/1024
Epoch 32/1024
Epoch 33/1024
Epoch 34/1024
Epoch 35/1024
Epoch 36/1024
Epoch 37/1024
Epoch 38/1024
Epoch 39/1024
Epoch 40/1024
Epoch 41/1024
Epoch 42/1024
Epoch 43/1024
Epoch 44/1024
Epoch 45/1024
Epoch 46/1024
Epoch 47/1024
Epoch 48/1024
Epoch 49/1024
Epoch 50/1024
Epoch 51/1024
Epoch 52/1024
Epoch 53/1024
Epoch 54/1024
Epoch 55/1024
Epoch 56/1024
Epoch 57/1024
Epoch 58/1024
Epoch 59/1024
Epoch 60/1024
Epoch 61/1024
Epoch 62/1024
Epoch 63/1024
Epoch 64/1024
Epoch 65/1024
Epoch 66/1024
Epoch 67/1024
Epoch 68/1024
Epoch 69/1024
Epoch 70/1024
Epoch 71/1024
Epoch 72/1024
E

<keras.callbacks.History at 0x1e1937484c0>

In [208]:
# Predict on the held-out test features and display the first (unrounded) prediction.
pred = model.predict(x_test)
pred[0]

array([2.3054986, 2.3375406, 2.459578 , 3.9895294, 2.5512688, 2.3145776,
       2.8820815], dtype=float32)