In [1]:
import os
import torch
import seaborn as sns
import matplotlib.pyplot as plt

def plot_h(matrix, name):
    # Plot heatmap
    sns.heatmap(matrix.cpu().detach().numpy(), annot=False, cmap='viridis', vmin=-0.1, vmax=0.3)
    plt.title(name)


In [2]:
import re
import glob

def parse_digraph(input_str):
    """
    Parses the input string to extract edges of the directed graph.
    Returns a list of tuples representing the edges.
    """
    # Find all matches for edges in the format 'source -> target'
    edges = re.findall(r'(\d+)\s*->\s*(\d+)', input_str)
    # Convert each source and target from string to integer and store as tuples
    return [(int(src), int(dest)) for src, dest in edges]

def parse_digraphs(directory):
    """
    Open all files in this directory
    Parses the input string to extract edges of the directed graph.
    Returns a list of tuples representing the edges.
    """
    # Find all the txt files in this directory
    txt_files_content = {}

    txt_files = glob.glob(os.path.join(directory, '*'))

    for file_path in txt_files:
        if file_path.lower().endswith('.png'):
            continue
        
        try:
            with open(file_path, 'r', encoding='ascii') as file:
                c = file_path[-1]
                input_str = file.read()
                # Find all matches for edges in the format 'source -> target'
                edges = re.findall(r'(\d+)\s*->\s*(\d+)', input_str)
                # Convert each source and target from string to integer and store as tuples
                txt_files_content[c] = [(int(src), int(dest)) for src, dest in edges]
        except UnicodeDecodeError:
            print("Skipping Non-ASCII Files")
        except Exception as e:
            print("No txt files?")
    return txt_files_content

def what_i_need(edges_graph1, edges_graph2):
    """ Return the nodes that client 1 wants from client 2 """
    # Extract nodes from Graph 1
    nodes_graph1 = set()
    for src, dest in edges_graph1:
        nodes_graph1.update([src, dest])

    # Extract nodes from Graph 2
    nodes_graph2 = set()
    for _, dest in edges_graph2:
        nodes_graph2.update([dest])

    # Find common nodes
    common_nodes = nodes_graph1.intersection(nodes_graph2)

    return common_nodes

def untouched(edges_graph1, edges_graph2, edges_graph3, end):
    """ Return the nodes that all clients don't have information of """
    full_range = set(range(end+1))

    # Extract nodes from Graph 1
    nodes_graph1 = set()
    for src, dest in edges_graph1:
        nodes_graph1.update([src, dest])

    # Extract nodes from Graph 2
    nodes_graph2 = set()
    for _, dest in edges_graph2:
        nodes_graph2.update([src, dest])

    nodes_graph3 = set()
    for _, dest in edges_graph3:
        nodes_graph3.update([src, dest])

    union_set = nodes_graph1 | nodes_graph2 | nodes_graph3
    
    # Find the untouched nodes
    the_untouched = full_range - union_set

    return the_untouched


In [3]:
digraphs = parse_digraphs('graph_output/')

# Parse the input strings to get the list of edges
edges_graph0 = digraphs['0']
edges_graph1 = digraphs['1']
edges_graph2 = digraphs['2']
# edges_graph3 = digraphs['3']
# edges_graph4 = digraphs['4']

In [4]:
untouched_nodes = untouched(edges_graph0, edges_graph1, edges_graph2, 30)
print(untouched_nodes)

{8, 10, 11, 13, 17, 18, 21, 23, 26, 29}


In [5]:
# Find common nodes
common_nodes = what_i_need(edges_graph0, edges_graph1)
print("Client 0 needs", common_nodes, "from Client 1")

common_nodes = what_i_need(edges_graph0, edges_graph2)
print("Client 0 needs", common_nodes, "from Client 2")

common_nodes = what_i_need(edges_graph1, edges_graph0)
print("Client 1 needs", common_nodes, "from Client 0")

common_nodes = what_i_need(edges_graph1, edges_graph2)
print("Client 1 needs", common_nodes, "from Client 2")

common_nodes = what_i_need(edges_graph2, edges_graph0)
print("Client 1 needs", common_nodes, "from Client 2")

common_nodes = what_i_need(edges_graph2, edges_graph1)
print("Client 2 needs", common_nodes, "from Client 1")

Client 0 needs {0, 1, 2, 3, 7, 9, 12, 20} from Client 1
Client 0 needs {0, 2, 3, 5, 12} from Client 2
Client 1 needs {0, 1, 2, 3, 5, 7, 16, 20} from Client 0
Client 1 needs {0, 2, 3, 4, 5, 12} from Client 2
Client 1 needs {0, 2, 3, 5, 16, 20} from Client 2
Client 2 needs {0, 2, 3, 4, 9, 12, 20, 25, 28} from Client 1


In [2]:
import torch
import copy

def share_embeddings(embeddings, weights):
    temp_embeddings = copy.deepcopy(embeddings)
    avg = torch.sum(weights, dim=0, keepdim=False)

    for i in range(len(weights)): # for a client
        indices = (weights[i] > 0).nonzero(as_tuple=True)[0]
        for l in range(2): # 2 conv layers
            embeddings[i][l][indices] *= weights[i][indices][:, None]
            for j in range(len(weights)): # other clients
                if j != i:
                    mul = temp_embeddings[j][l][indices] * weights[j][indices][:, None]
                    embeddings[i][l][indices] += mul
            result = (embeddings[i][l][indices] / avg[indices][:, None]).long()
            embeddings[i][l][indices] = result

    return embeddings

# Given tensor
tensor10 = torch.tensor([[5, 20],
                       [1, 9],
                       [1, 7]], device='cuda:0')

tensor11 = torch.tensor([[1, 0],
                       [5, 1],
                       [7, 3]], device='cuda:0')

tensor20 = torch.tensor([[3, 20],
                        [1, 9],
                       [2, 8]], device='cuda:0')

tensor21 = torch.tensor([[4, 3],
                        [5, 1],
                       [8, 6]], device='cuda:0')

# tensor30 = torch.tensor([[5, 6],
#                         [6, 5],
#                        [7, 8]], device='cuda:0')

# tensor31 = torch.tensor([[4, 5],
#                         [3, 2],
#                        [10, 9]], device='cuda:0')

# tensor40 = torch.tensor([[1, 2],
#                         [2, 0],
#                        [7, 8]], device='cuda:0')

# tensor41 = torch.tensor([[1, 1],
#                         [20, 9],
#                        [10, 9]], device='cuda:0')

weights = torch.tensor([[1, 0, 2],
                        [0, 3, 1]], device='cuda:0')

""" Something to Change, last_embeddings is list of list of 2 NE"""
print(share_embeddings([[tensor10, tensor11], [tensor20, tensor21]], weights))

[[tensor([[ 5, 20],
        [ 1,  9],
        [ 1,  7]], device='cuda:0'), tensor([[1, 0],
        [5, 1],
        [7, 4]], device='cuda:0')], [tensor([[ 3, 20],
        [ 1,  9],
        [ 1,  7]], device='cuda:0'), tensor([[4, 3],
        [5, 1],
        [7, 4]], device='cuda:0')]]
