# Testing global embedding function

In [24]:
from collections import deque, defaultdict
import torch
import numpy as np

In [None]:
'''Create test graph 1'''
# Cross-client neighbours: the single edge 3 <-> 4, recorded in both directions.
ccn = defaultdict(list, {3: [4], 4: [3]})

# Nodes 0-3 are assigned to client 0, nodes 4-7 to client 1.
node_client_map = {node: node // 4 for node in range(8)}

# Indexed as embeddings[client][hop][node]; every embedding has 2 components.
# A client's rows are zero for the four nodes that belong to the other client.
embeddings = torch.tensor([
    [
        [[0, 0]] * (4 * client)
        + [[value, value] for value in hop_values]
        + [[0, 0]] * (4 * (1 - client))
        for hop_values in ([1, 2, 3, 4], [2, 4, 6, 8], [3, 5, 7, 9])
    ]
    for client in range(2)
])

# adj_list[i] lists the neighbours of node i; every node lists itself as well.
adj_list = [
    [0, 2],
    [1, 3],
    [0, 2, 3],
    [1, 2, 3, 4],
    [3, 4, 5],
    [4, 5, 6],
    [5, 6, 7],
    [6, 7],
]

# node_embedding_update_sum(4, ccn, 0)
# get_global_embedding(embeddings, ccn, node_client_map)

# '''Create test graph 2'''
# ccn = defaultdict(list)
# ccn[3].append(4)
# ccn[4].append(3)
# ccn[3].append(8)
# ccn[8].append(3)
# node_client_map = {0:0,1:0,2:0,3:0,4:1,5:1,6:1,7:1,8:2,9:2,10:2}
# embeddings = torch.tensor([[[[1,1],[2,2],[3,3],[4,4],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0]],
#                  [[2,2],[4,4],[6,6],[8,8],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0]],
#                  [[3,3],[5,5],[7,7],[9,9],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0]]],

#                 [[[0,0],[0,0],[0,0],[0,0],[1,1],[2,2],[3,3],[4,4],[0,0],[0,0],[0,0]],
#                 [[0,0],[0,0],[0,0],[0,0],[2,2],[4,4],[6,6],[8,8],[0,0],[0,0],[0,0]],
#                 [[0,0],[0,0],[0,0],[0,0],[3,3],[5,5],[7,7],[9,9],[0,0],[0,0],[0,0]]],

#                 [[[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[1,1],[2,2],[3,3]],
#                 [[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[2,2],[4,4],[6,6]],
#                 [[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[3,3],[5,5],[7,7]]],
#                 ])
# adj_list = [[0,2], [1,3], [0,2,3], [1,2,3,4,8], [3,4,5], [4,5,6],[5,6,7],[6,7],[3,8,9,10],[8,9,10],[8,9,10]]

# node_embedding_update_sum(3, ccn, 1)

In [None]:
def node_embedding_update_sum(start_node, ccn, k):
    '''
    List the local embeddings that must be summed to obtain the k-hop global
    embedding of start_node, following cross-client edges only:
    1 hop NE of node i => NE1[i] + SUM(NE0[j]) for j in ccn[i]
    2 hop NE of node i => NE2[i] + SUM(NE1[j] + NE0[j] + NE0[i]) for j in ccn[i] + SUM(NE0[k]) for k in ccn[j]

    Inputs:
    1) start_node -> node whose global embedding is being assembled
    2) ccn -> mapping node -> list of cross-client neighbours (e.g. defaultdict(list))
    3) k -> hop of the embedding wanted for start_node

    Output:
    list of [node, hop] pairs; a pair appears once for every time that
    embedding has to be added
    '''
    embeddings_required = []
    # Each queue entry: (node, hops still available at that node, nodes on the path so far).
    dq = deque([(start_node, k, {start_node})])
    while dq:
        node, hop, nodes_visited = dq.popleft()
        # .get() keeps a defaultdict from growing an empty entry for every node looked up.
        cross_neighbours = ccn.get(node, ())
        embeddings_required.append([node, hop])
        if node == start_node:
            if hop > 1:
                # Every cross-client neighbour links back to the start node once.
                embeddings_required += [[node, 0] for _ in cross_neighbours]
        elif hop >= 1:
            # A neighbour reached with hops to spare also contributes its own 0-hop
            # embedding (the NE0[j] term above).  The previous test `hop > 1` could
            # never be true for k <= 2, so this term was silently dropped.
            embeddings_required.append([node, 0])

        if hop > 0:
            for neigh in cross_neighbours:
                if neigh not in nodes_visited:
                    dq.append((neigh, hop - 1, nodes_visited | {neigh}))

    return embeddings_required

def get_global_embedding(embeddings, ccn, node_client_map):
    '''
    Assemble the 0-hop, 1-hop and 2-hop global embedding of every node by
    summing the local embeddings listed by node_embedding_update_sum.

    Inputs:
    1) embeddings -> indexed as embeddings[client][hop][node]
    2) ccn -> defaultdict(list) of cross client nodes
    3) node_client_map -> the client each node is assigned to; nodes are
       taken to be 0 .. len(node_client_map) - 1

    Output:
    list with one entry per hop (0, 1, 2); each entry is a list holding one
    summed tensor per node
    '''
    def _sum_for(node, hop):
        # The contribution list is obtained first, then accumulated into a
        # fresh zero tensor shaped like a single node embedding.
        contributions = node_embedding_update_sum(node, ccn, hop)
        total = torch.zeros(embeddings[0][0][0].shape)
        for source_node, source_hop in contributions:
            owner = node_client_map[source_node]
            total += embeddings[owner][source_hop][source_node]
        return total

    node_count = len(node_client_map)
    return [
        [_sum_for(node, hop) for node in range(node_count)]
        for hop in range(3)
    ]

### Corrected Version

In [None]:
def node_embedding_update_sum(start_node, ccn, k):
    '''
    Function to return the contribution of each neighbouring node to start node and its hop embedding,
    using the formula:
    1 hop NE of node i => NE1[i] + SUM(NE0[j]) for j in ccn[i]
    2 hop NE of node i => NE2[i] + SUM(NE1[j] + NE0[j] + NE0[i]) for j in ccn[i] + SUM(NE0[k]) for k in ccn[j]

    Inputs:
    1) start_node -> node we wish to find contribution for next node embedding
    2) ccn -> mapping of node -> list of cross client nodes (e.g. defaultdict(list))
    3) k -> Hop we wish to find embedding of start_node for

    Output:
    list of [node required, hop] pairs for vector embedding update; a pair is
    repeated once for every time that embedding must be added
    '''
    embeddings_required = []
    # Queue entries are (node, remaining hops at that node, nodes already on this path).
    dq = deque([(start_node, k, {start_node})])
    while dq:
        node, hop, nodes_visited = dq.popleft()
        # Read with .get() so that looking a node up never inserts a key into a defaultdict.
        neighbours = ccn.get(node, ())
        embeddings_required.append([node, hop])
        if node == start_node:
            if hop > 1:
                # count the times 1-hop ccn visits the start node itself
                embeddings_required += [[node, 0] for _ in neighbours]
        elif hop >= 1:
            # NE0[j] term of the formula: a neighbour visited with at least one
            # hop left also adds its 0-hop embedding.  This used to be guarded by
            # `hop > 1`, which is unreachable for k <= 2 and lost the term.
            embeddings_required.append([node, 0])

        if hop > 0:
            for neigh in neighbours:
                if neigh not in nodes_visited:
                    dq.append((neigh, hop - 1, nodes_visited | {neigh}))

    return embeddings_required

def get_global_embedding(embeddings, ccn, node_client_map, subnodes_union, first_parti_client):
    '''
    Function to return the global embedding to update the client's local embeddings, using the formula:
    1 hop NE of node i => NE1[i] + SUM(NE0[j]) for j in ccn[i]
    2 hop NE of node i => NE2[i] + SUM(NE1[j] + NE0[j] + NE0[i]) for j in ccn[i] + SUM(NE0[k]) for k in ccn[j]

    Inputs:
    1) embeddings -> defaultdict(Tensor) of 0-hop, 1-hop and 2-hop NE of each client,
       indexed as embeddings[client][hop][node]
    2) ccn -> defaultdict(list) of cross client nodes
    3) node_client_map -> the client each node is assigned for training
    4) subnodes_union -> container of node ids; contributions from nodes that are
       not in it are skipped
    5) first_parti_client -> key of a client present in embeddings (presumably the
       first participating client -- confirm with caller); its tensors provide the
       shape and device of the accumulators

    Output:
    list of 0-hop, 1-hop and 2-hop Global NE, each stacked over all nodes.
    With a single client, that client's embeddings are returned unchanged.
    '''
    if len(embeddings) == 1:
        # Only one client.  Look it up by its own key: indexing with a literal 0
        # returns the wrong entry (or makes a defaultdict create an empty one)
        # whenever the sole participant is not client 0.
        return embeddings[first_parti_client]

    reference = embeddings[first_parti_client][0][0]

    hop_embeddings = []
    for hop in range(3):
        hop_matrix = []
        for node in range(len(node_client_map)):
            contributions = node_embedding_update_sum(node, ccn, hop)
            # Allocate on the device the embeddings already live on, instead of a
            # hard-coded "cuda:0" that fails on CPU-only hosts or other GPUs.
            final_embedding = torch.zeros(reference.shape, device=reference.device)
            for update_node, k in contributions:
                if update_node in subnodes_union:
                    final_embedding += embeddings[node_client_map[update_node]][k][update_node]
            hop_matrix.append(final_embedding)
        hop_embeddings.append(torch.stack(hop_matrix))

    return hop_embeddings

### Optimized Version using Matrix Multiplication

In [23]:
def get_node_embedding_needed(start_node, global_adj_matrix, clients_adj_matrix, ccn, node_client_map, k):
    '''
    Return all the (client, node, number of times needed to add) for each hop.

    Inputs:
    1) start_node -> node whose k-hop embedding is being completed
    2) global_adj_matrix -> square adjacency matrix over all nodes
    3) clients_adj_matrix -> per-client adjacency matrices, indexed by client, each
       with the same shape as global_adj_matrix
    4) ccn -> mapping node -> list of cross client nodes
    5) node_client_map -> the client each node is assigned to
    6) k -> hop being completed (0, 1 or 2)

    Output:
    list of k lists; entry h holds the (client, node, count) tuples of h-hop
    embeddings that have to be added to the local k-hop embedding of start_node.
    k == 0 needs nothing and yields [].

    Raises:
    ValueError if k is not 0, 1 or 2 (previously `ne_needed` was left unbound).
    '''
    if k == 0:
        return []
    if k not in (1, 2):
        raise ValueError(f"k must be 0, 1 or 2, got {k!r}")

    ne_needed = [[] for _ in range(k)]
    global_adj = np.asarray(global_adj_matrix)
    own_adj = np.asarray(clients_adj_matrix[node_client_map[start_node]])

    if k == 1:
        # Edges the node has globally but not inside its own client.
        adjustment_coefficient = global_adj[start_node] - own_adj[start_node]
    else:
        # Only row `start_node` of the squared matrices is ever read, so compute
        # that row directly (row @ matrix) rather than squaring whole matrices.
        global_two_hop_row = global_adj[start_node] @ global_adj
        to_subtract_row = own_adj[start_node] @ own_adj
        # .get() avoids inserting an empty entry into a defaultdict.
        for neigh in ccn.get(start_node, ()):
            neigh_client = node_client_map[neigh]
            # The neighbour's local 1-hop embedding is added as a whole, so its
            # client-local row must not be counted again through 0-hop terms.
            to_subtract_row = to_subtract_row + np.asarray(clients_adj_matrix[neigh_client])[neigh]
            ne_needed[1].append((neigh_client, neigh, 1))
        adjustment_coefficient = global_two_hop_row - to_subtract_row

    for i, coe in enumerate(adjustment_coefficient):
        if coe > 0:
            ne_needed[0].append((node_client_map[i], i, coe))

    return ne_needed

def fast_get_global_embedding(embeddings, ccn, node_client_map, adj_list, subnodes_union=None):
    '''
    Compute the 0-hop, 1-hop and 2-hop global embedding of every node from
    adjacency matrices: each node starts from its own local embedding and adds
    the contributions listed by get_node_embedding_needed.

    Inputs:
    1) embeddings -> indexed as embeddings[client][hop][node]
    2) ccn -> mapping node -> list of cross client nodes (e.g. defaultdict(list))
    3) node_client_map -> the client each node is assigned to
    4) adj_list -> adj_list[i] is the list of neighbours of node i
    5) subnodes_union -> container of node ids whose contributions may be added;
       None (default) admits every node

    Output:
    list with one entry per hop (0, 1, 2); each entry is a list holding one
    tensor per node
    '''
    nodes = range(len(adj_list))
    # Sets make the membership tests below O(1) instead of scanning each list.
    neighbour_sets = [set(neighbours) for neighbours in adj_list]

    global_adj_matrix = np.array(
        [[1 if dst in neighbour_sets[src] else 0 for dst in nodes] for src in nodes]
    )

    # One matrix per client id 0..max, keeping only edges whose two endpoints
    # are both assigned to that client.
    clients_adj_matrix = []
    for client in range(max(node_client_map.values()) + 1):
        clients_adj_matrix.append(np.array([
            [
                1 if dst in neighbour_sets[src]
                and node_client_map[src] == client
                and node_client_map[dst] == client
                else 0
                for dst in nodes
            ]
            for src in nodes
        ]))

    hop_embeddings = []
    for hop in range(3):
        hop_matrix = []
        for node in range(len(node_client_map)):
            # clone() so the in-place additions never touch the caller's tensors.
            final_embedding = embeddings[node_client_map[node]][hop][node].clone()
            # Hop 0 and nodes without cross-client neighbours need nothing extra;
            # .get() avoids inserting empty entries into a defaultdict.
            if hop > 0 and ccn.get(node):
                ne_needed = get_node_embedding_needed(
                    node, global_adj_matrix, clients_adj_matrix, ccn, node_client_map, hop
                )
                for hop_needed, contributions in enumerate(ne_needed):
                    # Distinct name: the inner loop used to rebind `node`.
                    for client, source_node, num_times in contributions:
                        if subnodes_union is None or source_node in subnodes_union:
                            final_embedding += embeddings[client][hop_needed][source_node] * num_times

            hop_matrix.append(final_embedding)
        hop_embeddings.append(hop_matrix)

    return hop_embeddings

# fast_get_global_embedding takes a fifth argument, subnodes_union; every node
# of the test graph takes part, so pass the full set of node ids.
hop_embeddings = fast_get_global_embedding(embeddings, ccn, node_client_map, adj_list, set(range(len(adj_list))))
# print("Final:", hop_embeddings)

hop 1 node 3 starting emb tensor([8, 8])
hop 0: (1,4,1) => tensor([1, 1])
hop 1 node 4 starting emb tensor([2, 2])
hop 0: (0,3,1) => tensor([4, 4])
hop 2 node 3 starting emb tensor([9, 9])
hop 0: (0,3,1) => tensor([4, 4])
hop 0: (1,4,1) => tensor([1, 1])
hop 1: (1,4,1) => tensor([2, 2])
hop 2 node 4 starting emb tensor([3, 3])
hop 0: (0,3,1) => tensor([4, 4])
hop 0: (1,4,1) => tensor([1, 1])
hop 1: (0,3,1) => tensor([8, 8])
Final: [[tensor([1, 1]), tensor([2, 2]), tensor([3, 3]), tensor([4, 4]), tensor([1, 1]), tensor([2, 2]), tensor([3, 3]), tensor([4, 4])], [tensor([2, 2]), tensor([4, 4]), tensor([6, 6]), tensor([9, 9]), tensor([6, 6]), tensor([4, 4]), tensor([6, 6]), tensor([8, 8])], [tensor([3, 3]), tensor([5, 5]), tensor([7, 7]), tensor([16, 16]), tensor([16, 16]), tensor([5, 5]), tensor([7, 7]), tensor([9, 9])]]
