In [None]:
FedProto 구현 (FedProto implementation)

In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from collections import deque
import AI_model
import pickle
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import LabelEncoder
import random


class Queue:
    """Minimal FIFO queue backed by a ``collections.deque``.

    The backing deque is exposed as the public attribute ``items`` so callers
    can iterate or index the queued elements without removing them.
    """

    def __init__(self):
        # Oldest element sits at the left end, newest at the right end.
        self.items = deque()

    def enqueue(self, item):
        """Add ``item`` at the back of the queue."""
        self.items.append(item)

    def dequeue(self):
        """Remove and return the oldest element.

        When the queue is empty a notice is printed and ``None`` is returned.
        """
        if self.is_empty():
            print("Queue is empty")
            return None
        return self.items.popleft()

    def is_empty(self):
        """Return ``True`` when no elements are queued."""
        return not self.items

    def size(self):
        """Return the number of queued elements."""
        return len(self.items)

def build_model(num_classes):
    """Create a model through ``AI_model.build_model`` and compile it.

    Args:
        num_classes: Forwarded unchanged to ``AI_model.build_model``.

    Returns:
        The model returned by ``AI_model.build_model`` after ``compile`` has
        been called on it with the Adam optimizer, categorical cross-entropy
        loss and the ``'acc'`` metric.
    """
    compile_options = {
        'optimizer': 'adam',
        'loss': 'categorical_crossentropy',
        'metrics': ['acc'],
    }
    network = AI_model.build_model(num_classes)
    network.compile(**compile_options)
    return network


def extract_local_prototypes(model, x_local, y_local, num_classes):
    """Compute one mean feature vector (prototype) per class for a client.

    Args:
        model: Callable invoked as
            ``model(x_local, training=False, feature_extraction=True)``; its
            output is reshaped to ``(n_samples, -1)``.
        x_local: Client inputs, passed to ``model`` in a single call.
        y_local: Per-sample class scores / one-hot labels; the class index is
            taken as ``argmax`` along axis 1.
        num_classes: Number of prototype rows to produce.

    Returns:
        ``numpy.ndarray`` of shape ``(num_classes, feature_dim)``. Row ``i`` is
        the mean flattened feature of the samples labelled ``i``; a class with
        no samples keeps an all-zero row.
    """
    # Only the dense features are used; an earlier variant that also returned
    # convolutional features was disabled by the author.
    embeddings = model(x_local, training=False, feature_extraction=True)
    n_samples = embeddings.shape[0]
    flat_embeddings = tf.reshape(embeddings, [n_samples, -1]).numpy()

    class_ids = np.argmax(y_local, axis=1)
    prototypes = np.zeros((num_classes, flat_embeddings.shape[1]))

    for class_id in range(num_classes):
        members = flat_embeddings[class_ids == class_id]
        if len(members) == 0:
            # No sample of this class on the client: leave the zero row.
            continue
        prototypes[class_id] = members.mean(axis=0)

    return prototypes


def update_global_prototypes(local_prototypes_list, num_classes):
    """Average the clients' per-class prototypes into global prototypes.

    Args:
        local_prototypes_list: Non-empty sequence of 2-D arrays, one per
            client, each indexable by class along axis 0. The feature width is
            read from the first entry.
        num_classes: Number of classes (rows) to aggregate.

    Returns:
        ``numpy.ndarray`` of shape ``(num_classes, feature_dim)`` (float64)
        whose row ``i`` is the unweighted mean of row ``i`` over all clients.
        Every client contributes to every class, including all-zero rows.
    """
    feature_dim = local_prototypes_list[0].shape[1]

    # Mean over clients, computed class by class.
    class_means = [
        np.mean(
            np.array([client_protos[class_id] for client_protos in local_prototypes_list]),
            axis=0,
        )
        for class_id in range(num_classes)
    ]

    aggregated = np.zeros((num_classes, feature_dim))
    for class_id, mean_vector in enumerate(class_means):
        aggregated[class_id] = mean_vector
    return aggregated

def classify_with_global_prototypes(model, x, global_prototypes):
    """Assign each sample to the class of its nearest global prototype.

    Args:
        model: Callable invoked as
            ``model(x, training=False, feature_extraction=True)``; the result
            must provide ``.numpy()``.
        x: Batch of inputs forwarded to ``model``.
        global_prototypes: Array-like of shape ``(num_classes, feature_dim)``.

    Returns:
        1-D ``numpy.ndarray`` of length ``n_samples`` holding, per sample, the
        index of the prototype with the smallest Euclidean distance.
    """
    features = model(x, training=False, feature_extraction=True).numpy()
    # Flatten per-sample features the same way extract_local_prototypes does
    # before building prototypes; without this, feature maps of rank > 2
    # broadcast incorrectly against the (num_classes, feature_dim) prototypes.
    features = features.reshape(features.shape[0], -1)
    prototypes = np.asarray(global_prototypes)

    # (n_samples, 1, dim) - (1, num_classes, dim) -> (n_samples, num_classes, dim)
    differences = features[:, np.newaxis, :] - prototypes[np.newaxis, :, :]
    distances = np.sqrt((differences ** 2).sum(axis=2))
    return np.argmin(distances, axis=1)


def prototype_loss(y_true, y_pred, features, global_prototypes):
    """Mean squared distance between features and their class prototypes.

    Args:
        y_true: Per-sample one-hot labels / class scores; the class index is
            taken as ``argmax`` along axis 1.
        y_pred: Unused; kept so the signature stays unchanged.
        features: Feature tensor whose first axis is the batch axis; it is
            flattened to ``(batch, -1)``.
        global_prototypes: Array-like of shape ``(num_classes, feature_dim)``.

    Returns:
        Scalar tensor: the mean over all elements of
        ``(features - prototype_of_true_class) ** 2``.
    """
    # Use the dynamic batch size: the static ``features.shape[0]`` is None when
    # the batch dimension is unknown (e.g. inside a traced tf.function).
    features_flattened = tf.reshape(features, [tf.shape(features)[0], -1])

    # Match the feature dtype so the subtraction below also works when the
    # model does not emit float32.
    prototype_table = tf.cast(
        tf.convert_to_tensor(global_prototypes), features_flattened.dtype
    )

    class_ids = tf.argmax(y_true, axis=1)
    target_prototypes = tf.gather(prototype_table, class_ids)
    return tf.reduce_mean(tf.square(features_flattened - target_prototypes))


def train_step(model, x_batch, y_batch, optimizer, global_prototypes, epoch):
    """Run one optimisation step on cross-entropy plus prototype loss.

    Args:
        model: Model callable both as ``model(x, training=True)`` and as
            ``model(x, training=True, feature_extraction=True)``.
        x_batch: Input batch.
        y_batch: One-hot labels for the batch.
        optimizer: Optimizer providing ``apply_gradients``.
        global_prototypes: Global class prototypes passed to ``prototype_loss``.
        epoch: Unused; kept so the signature stays unchanged.

    Returns:
        ``(total_loss, prototype_loss_value)``, both scalar tensors.
    """
    with tf.GradientTape() as tape:
        # NOTE(review): the model is run twice per step (features, then
        # predictions). Whether both can come from one forward pass depends on
        # the AI_model API, which is not visible here.
        features = model(x_batch, training=True, feature_extraction=True)
        y_pred = model(x_batch, training=True)

        # categorical_crossentropy returns one value per sample. Reduce it to
        # the batch mean so the total loss is a scalar; differentiating the
        # unreduced vector would sum over samples and make the gradient scale
        # with the batch size.
        classification_loss = tf.reduce_mean(
            tf.keras.losses.categorical_crossentropy(y_batch, y_pred)
        )
        prototype_loss_value = prototype_loss(y_batch, y_pred, features, global_prototypes)
        total_loss = classification_loss + prototype_loss_value

    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return total_loss, prototype_loss_value


def train_model(model, x_train, y_train, global_prototypes, batch_size, epochs):
    """Train ``model`` with ``train_step`` for a number of epochs.

    Args:
        model: Model forwarded to ``train_step``.
        x_train: Training inputs.
        y_train: Training labels, sliced together with ``x_train``.
        global_prototypes: Global class prototypes forwarded to ``train_step``.
        batch_size: Batch size used to batch the dataset.
        epochs: Number of passes over the dataset.

    Returns:
        ``(total_losses, prototype_losses)``: two lists with one entry per
        training step, each the ``.numpy()`` value of what ``train_step``
        returned for that step.
    """
    total_losses, prototype_losses = [], []
    # A fresh Adam optimizer is created on every call, so optimizer state is
    # not carried over between calls.
    optimizer = tf.keras.optimizers.Adam()
    # Batches are taken in the given order; the dataset is not shuffled here.
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)

    for epoch in range(epochs):
        for x_batch, y_batch in dataset:
            # Pass the current epoch index (the original passed the total
            # epoch count here by mistake).
            total_loss, prototype_loss_value = train_step(
                model, x_batch, y_batch, optimizer, global_prototypes, epoch
            )
            total_losses.append(total_loss.numpy())
            prototype_losses.append(prototype_loss_value.numpy())

    return total_losses, prototype_losses

In [2]:
class Get_Data:
    """Load an RML2016.10a / RML2016.10b pickle and split it into train/test.

    Typical use is ``Get_Data("RML2016.10a").main()``, which returns
    ``(x_train, y_train, x_test, y_test)``.
    """

    # Default on-disk location of each supported dataset.
    _DEFAULT_PATHS = {
        'RML2016.10a': 'D:/Research/Dataset_nocode/RML2016.10a/RML2016.10a_dict.pkl',
        'RML2016.10b': 'D:/Research/Dataset_nocode/RML2016.10b/RML2016.10b.dat',
    }
    _SUPPORTED_TYPES = ("RML2016.10a", "RML2016.10b")

    def __init__(self, data_type, file_path=None):
        """Remember which dataset to load.

        Args:
            data_type: ``"RML2016.10a"`` or ``"RML2016.10b"``.
            file_path: Optional explicit path to the pickle file. When omitted,
                the built-in default path for ``data_type`` is used; it is
                ``None`` for an unknown ``data_type``.
        """
        self.data_type = data_type  # "RML2016.10a" or "RML2016.10b"
        if file_path is not None:
            self.file_path = file_path
        else:
            self.file_path = self._DEFAULT_PATHS.get(data_type)

        self.global_comm_round = 20
        self.num_locals = 3

    def data_import(self):
        """Read the pickle and flatten its keys into per-sample labels.

        Each key of the pickled dict is indexed as ``key[0]`` (class label)
        and ``key[1]`` (SNR); each value is a block of samples.

        Returns:
            ``(data, SNR, integer_labels)``: the list of sample blocks, one SNR
            entry per sample, and one integer-encoded label per sample.

        Raises:
            ValueError: If no file path is known for ``data_type``.
        """
        if self.file_path is None:
            raise ValueError(
                f"No dataset path known for data_type={self.data_type!r}; "
                "use 'RML2016.10a' or 'RML2016.10b', or pass file_path explicitly."
            )

        # SECURITY: pickle.load can execute arbitrary code while loading; only
        # open dataset files that come from a trusted source.
        with open(self.file_path, 'rb') as file:
            pickle_data = pickle.load(file, encoding='latin1')

        data, SNR, label = [], [], []
        for key, samples in pickle_data.items():
            data.append(samples)
            # Repeat the key's label and SNR once per sample in the block.
            n_samples = len(samples)
            label.extend([key[0]] * n_samples)
            SNR.extend([key[1]] * n_samples)

        label_encoder = LabelEncoder()
        integer_labels = label_encoder.fit_transform(label)
        print(f"RML Dataset Length - 1st (data): {data[0].shape}, 2nd element(SNR): {len(SNR)}, 3rd element(label): {integer_labels}")

        label_mapping = dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))
        print("Label Mapping:", label_mapping)

        return data, SNR, integer_labels

    def data_process(self, data, SNR, integer_labels, test_ratio=0.2):
        """Stack, one-hot encode, shuffle and split the dataset.

        Args:
            data: Sequence of sample blocks; concatenated along axis 0 and
                reshaped to ``(n_samples, 2, 128, 1)``. Blocks may hold any
                number of samples.
            SNR: Per-sample SNR values. Accepted for interface compatibility;
                not part of the returned split.
            integer_labels: One integer class label per sample.
            test_ratio: Fraction of samples assigned to the test split.

        Returns:
            ``(x_train, y_train, x_test, y_test)``, or ``None`` (after printing
            a notice) when ``data_type`` is not a supported dataset.

        Raises:
            ValueError: If the number of samples and labels differ.
        """
        if self.data_type not in self._SUPPORTED_TYPES:
            print("data_type either RML2016.10a or RML2016.10b")
            return None

        # One-hot encode over the sorted set of labels actually present (the
        # same categories a fitted OneHotEncoder would use), without relying on
        # the scikit-learn >= 1.2 ``sparse_output`` keyword.
        labels = np.asarray(integer_labels).ravel()
        classes, class_index = np.unique(labels, return_inverse=True)
        y = np.eye(len(classes))[class_index.ravel()]

        # Concatenating the blocks avoids hard-coding the per-key sample count
        # (1000 for 10a, 6000 for 10b) while producing the same sample order.
        x = np.concatenate([np.asarray(block) for block in data], axis=0)
        x = x.reshape(-1, 2, 128, 1)

        if len(x) != len(y):
            raise ValueError(
                f"Sample/label count mismatch: {len(x)} samples vs {len(y)} labels"
            )

        # A single random permutation both shuffles and splits. Randomness
        # comes from numpy's global RNG only.
        shuffled_indices = np.random.permutation(len(x))
        split_index = int(len(x) * (1 - test_ratio))
        train_idx = shuffled_indices[:split_index]
        test_idx = shuffled_indices[split_index:]

        x_train, x_test = x[train_idx], x[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        print(f"x_train shape: {x_train.shape}, y_train: {y_train.shape}, x_test: {x_test.shape}, y_test: {y_test.shape}")

        return x_train, y_train, x_test, y_test

    def main(self):
        """Load the dataset and return an 80/20 train/test split."""
        data, SNR, integer_labels = self.data_import()
        X_train, y_train, X_test, y_test = self.data_process(data, SNR, integer_labels, test_ratio=0.2)

        return X_train, y_train, X_test, y_test

In [None]:
# FedProto experiment driver: per global round, each client (1) trains locally
# and publishes class prototypes, (2) trains against the averaged global
# prototypes, (3) trains locally once more. Models are evaluated after phases
# 2 and 3.
g_round, num_local = 5, 5
FedEMG_loss, FedEMG_acc = [], []  # not used below
num_classes = 11
local_clients = 3
local_models = [build_model(num_classes) for _ in range(local_clients)]

# One pair of replay queues per client (indexed by client id) instead of
# dynamically named globals.
replay_x_queues = [Queue() for _ in range(local_clients)]
replay_y_queues = [Queue() for _ in range(local_clients)]

# Read the pickle once; only the random train/test split is redone per draw.
rml_dataset = Get_Data("RML2016.10a")
raw_data, raw_snr, raw_labels = rml_dataset.data_import()


def _draw_split():
    """Return a fresh random (x_train, y_train, x_test, y_test) split."""
    return rml_dataset.data_process(raw_data, raw_snr, raw_labels, test_ratio=0.2)


def _with_replay(x_train, y_train, idx):
    """Append everything queued for client ``idx`` to the training arrays."""
    # A single concatenate avoids re-copying the large base array once per
    # queued chunk.
    x_all = np.concatenate([x_train, *replay_x_queues[idx].items], axis=0)
    y_all = np.concatenate([y_train, *replay_y_queues[idx].items], axis=0)
    return x_all, y_all


# Build the shared evaluation set: 500 random test samples per client and round.
# NOTE(review): every draw re-splits the same dataset, so samples in
# X_test / Y_test can also appear in a later X_train, and all clients sample
# from the same data - confirm this is intended.
test_x_parts, test_y_parts = [], []
for epoch in range(g_round):   # Repetition
    for idx in range(local_clients):  # Subject
        _, _, x_test_local, y_test_local = _draw_split()
        client_data_indices = np.random.choice(len(x_test_local), size=500, replace=False)
        test_x_parts.append(x_test_local[client_data_indices])
        test_y_parts.append(y_test_local[client_data_indices])
X_test = np.concatenate(test_x_parts)
Y_test = np.concatenate(test_y_parts)


for epoch in range(g_round):   # Repetition
    print(f'\nGlobal Epoch {epoch+1}/{g_round} start\n\n')
    local_prototypes = []

    # Phase 1: local training, replay-buffer update, prototype extraction.
    for idx in range(local_clients):  # Subject
        X_train, Y_train, x_test_local, y_test_local = _draw_split()

        # From the fifth round on, drop the oldest replay chunk so that at
        # most four chunks are queued per client.
        if epoch > 3:
            replay_x_queues[idx].dequeue()
            replay_y_queues[idx].dequeue()

        if epoch > 0:
            X_train, Y_train = _with_replay(X_train, Y_train, idx)

        local_models[idx].fit(X_train, Y_train, epochs=200, batch_size=256, verbose=0)

        # Keep 1000 random samples of this round's training data for replay.
        random_indices = np.random.choice(len(X_train), size=1000, replace=False)
        replay_x_queues[idx].enqueue(X_train[random_indices])
        replay_y_queues[idx].enqueue(Y_train[random_indices])

        prototype = extract_local_prototypes(local_models[idx], X_train, Y_train, num_classes)
        local_prototypes.append(prototype)

    global_prototypes = update_global_prototypes(local_prototypes, num_classes)

    # Phase 2: train each client against the global prototypes, then evaluate.
    for idx in range(local_clients):  # Subject
        X_train, Y_train, x_test_local, y_test_local = _draw_split()

        batch_size = 256

        # Train the model using the global prototypes.
        train_model(local_models[idx], X_train, Y_train, global_prototypes, batch_size, epochs=30)
        print(f'{idx} intra-subject performance ====> 2')
        local_models[idx].evaluate(x_test_local, y_test_local, verbose=1)
        print(f'{idx} inter-subject performance =======> 2')
        local_models[idx].evaluate(X_test, Y_test, verbose=1)

    # Phase 3: another round of plain local training (with replay), then evaluate.
    for idx in range(local_clients):  # Subject
        X_train, Y_train, x_test_local, y_test_local = _draw_split()

        if epoch > 0:
            X_train, Y_train = _with_replay(X_train, Y_train, idx)

        local_models[idx].fit(X_train, Y_train, epochs=200, batch_size=256, verbose=0)

        print(f'{idx} intra-subject performance ====> 3')
        local_models[idx].evaluate(x_test_local, y_test_local, verbose=1)
        print(f'{idx} inter-subject performance =======> 3')
        local_models[idx].evaluate(X_test, Y_test, verbose=1)

