<a href="https://colab.research.google.com/github/648lsp666/Federal_Learning/blob/main/SPU%E5%8A%A0%E5%AF%86.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install secretflow

In [None]:
import secretflow as sf
import random
import numpy as np
import math


# Initialize SecretFlow for the parties alice, bob and carol.
# num_cpus=8 and log_to_driver=False are forwarded to sf.init unchanged.
sf.init(['alice', 'bob', 'carol'], num_cpus=8, log_to_driver=False)



In [None]:
# One PYU device handle per participating party.
alice = sf.PYU('alice')
bob = sf.PYU('bob')
carol = sf.PYU('carol')

# Cluster definition covering the same three parties; it is used to build the SPU device.
# NOTE(review): the variable name suggests an ABY3 setup, but no protocol is
# passed explicitly here -- confirm what cluster_def selects by default.
aby3_config = sf.utils.testing.cluster_def(parties=['alice', 'bob', 'carol'])

# Show the generated cluster definition as the cell output.
aby3_config

In [None]:
# Build the SPU device from the cluster definition stored in aby3_config.
spu_device = sf.SPU(aby3_config)

# Show the field setting of the SPU runtime config as the cell output
# (the same value is later passed to the HEU constructor).
spu_device.conf.field

In [None]:

# Simulated (demo-only) public key and private key, stored as PEM text.
# NOTE(review): neither `public_key` nor `private_key` is referenced again in
# the visible notebook; the HEU configuration asks for a freshly generated key
# pair instead. Hard-coding private-key material in a notebook is a hygiene
# risk -- consider removing these literals if nothing else depends on them.
public_key = """
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuP4iIL0Zq/ks36F0xdgB
i1gS7OlHCHuXaD9v+nZ5hX3C1F9a8k+h9oVDJ2CBQkZnyT7ukE0zS9qg+tWhSZWp
5TAUtdZhbHR1chik+yz4JwTCzAkHsh0RlF0Gz7RRGWmdK3+x5NwO7B3Ib/mc1H3+
1ZEBgzJjAsC3iC+p1h4W9b3kM5p7HgZYp+XMTVw1aO1SxLFq4VRiMT6wMrPZIvVh
Yw9H9NmW5oVbdI5l4EJ+K4XUZoFbDqZLj7FwR2G+GQX1z1UZtNH+dMC+WV1Vb/iA
HnUy1UwOGgBOekNxDN4f2wMBl8QdH6gr4eOa1J6dMIkYhDIfj4/6dE76oXZsNe72
OwIDAQAB
-----END PUBLIC KEY-----
"""
private_key = """
-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQC4/iIgvRmr+Szf
oXTF2AGLWBLo6UcIe5doP2/6dnmFfcLUX1ryT6H2hUMnYIFCRmfJPu6QTTNL2qD6
1aFJlal5MBS11mFsdHVyGKT7LPgnBMLMCQeyHRGU3QbPtFEZaZ0rf7Hk3A7sHchv
+ZzUff7VkQGDMmMCwLeIL6nWNhb1veQzmnsfBlin5cxNXDVr7VLEsWrhVGIxPrAy
s9ki9WFjD0f02ZbmhVt0jmXgQn4rhdRmgVsOpkuPsXBHYY4ZBfXPVRm00f50wL5Z
XVVu+IAedTLVTQ4aAE56Q3EM3h/bAwGXxB0fqCvh45rUnp0wiRiEMh+Pj/p0Tvqh
dmw173Y7AgMBAAECggEAJxP4OVW6UZ3o6E9ON5S0gJnVRzZKhuzGhdvJUTQ/Ygxl
uS42mF8pOwN8B7/vVi5elhnPIt2n2eh4/6vw1RrsY+N6DUBn3FeZ5X3sExfRpzUt
xKN4VDfXcZh+eOv6azJvDxdEmsqI7M9j29OQ+vYULlh5kAjgDr+nibBd9m+qKyXy
Mv3xVp0AT1Z9P3shW3P4RCZwmwIlF/8gYDzJXY1OtUSpSmBtrCGQPLc5UbE5xMK0
IzAjB4xWvUoI1NzYLOVJ0kvX+fS5G1D6+7dHFEFv+6E6A19I3Yh3FvUOT6GFDdIa
NcYgjE8TxWkj1zZCB+FZJz8G7qlxl0YKiFWTjX4WTQKBgQD6HP+9XAVu0NHjED5a
tGpdNU16+GzFA5mjwSe3J2RQAYk5mnAV/Wk6jxR/LUlyDS8wQbEAXr8EyomjOZ7n
z7JwsLJ9mAgxUJYjoHzCtCgF1UJHg0oezL+uMRxw7ybOaCQuUO4yyGIl5RxStJmB
y01IUmLzXq1OtnLifqCZP5HEcwKBgQDKgIm84eYlWck3V4ZZgXxtK8ewO/nwvj/f
nxjFb/L0EkhbEC+Ht/5rAlXVrZ/xREzXkbl+tCZOWAoZTVY3vHn2Is1RYkVrkG2n
o1o3UTz5+udDsiySeEqHnN5z6+nsQUeM8MuFF/WzD4E6Kwkt7G4YJZ49GAseYOmA
D5W05Kfz8QKBgA5K4hplcRty+uwQh+cO/4Vas+odOd7EbmGJl9BXy0pTr/INfL49
Gwe2UgA85/t5ZfkcDZCw6KMiOrR1ro+UJ25z8OItVpepgRSCELFZPbZVQEYKT1CV
KOjZBJArNmsO+ozEXsmRV9e7LU/JcLgje5PSqnl4n3djvmyxWVmdn4JhAoGAU9P9
w2KoJLvFZhHapY8eVXe1xwBR0J6+tZQnFr+RtJr7PC5TS0lX+ojMYp9CTs7PR1Fq
BIF6uv05csVBOJPa1kxtLtEyz6aSv6NdrjG7UXqx/Qg8sS41o5LBSDcV83+POAhO
0pC3oQ7A8j7Gb5rM+Txi99eW9B8K4R3xGViB+6kCgYBGR5sWCVbmOjq0USdFfPn3
j/RuCZ/qnFJq9VDI83KaY8s7ySoMuS2K3PHSqQzss03o8UDt1hdCxHjFehzWmkO5
/jUGmVLDDg4G6Y+KvW01hz4yYzIlc2Pmc8ts4d5ifL9XMsbEn1YWiBXkaSHr4RC7
zL9auFIB8ALN7xEz3yxgtg==
-----END PRIVATE KEY-----
"""

# Configure the HEU (homomorphic-encryption device).
heu_config = {
    # Party that keeps the secret key.
    'sk_keeper': {
        'party': 'alice'
    },
    # Parties that evaluate on ciphertexts. SecretFlow's documented HEU config
    # gives this as a LIST of {'party': ...} entries; the previous bare dict
    # would be iterated key-by-key instead of entry-by-entry.
    'evaluators': [
        {
            'party': 'bob'
        },
    ],
    # Working mode, as given in SecretFlow's documented HEU config example.
    'mode': 'PHEU',
    'he_parameters': {
        # which HE algorithm to use,
        # see https://www.secretflow.org.cn/docs/heu/latest/en-US/getting_started/algo_choice for detail
        'schema': 'paillier',
        'key_pair': {
            # Ask the device to generate a fresh key pair of this size.
            'generate': {
                'bit_size': 2048,
            },
        }
    }
}

# Initialize the HEU, using the same field setting as the SPU device.
heu = sf.HEU(heu_config,spu_field_type=spu_device.conf.field)

# Plaintext inputs
data1 = 10
data2 = 5

# Encrypt the inputs
# NOTE(review): `encrypt`, `mul` and `decrypt` are called directly on the HEU
# device object below. Confirm these methods exist on sf.HEU; the documented
# workflow appears to move PYU objects to the HEU with `.to(heu)` instead.
encrypted_data1 = heu.encrypt(data1)
encrypted_data2 = heu.encrypt(data2)

print("加密后的数据1:", encrypted_data1)
print("加密后的数据2:", encrypted_data2)

# Homomorphic multiplication of the two encrypted values
# NOTE(review): the configured schema is 'paillier', which is additively
# homomorphic; multiplying two ciphertexts is presumably unsupported -- confirm.
encrypted_result = heu.mul(encrypted_data1, encrypted_data2)

print("同态乘法后的加密结果:", encrypted_result)

# Decrypt the result
decrypted_result = heu.decrypt(encrypted_result)

print("解密后的结果:", decrypted_result)


In [42]:
# Number of clients and the other experiment parameters
num_clients = 3
B = 3  # number of classes
d = 10  # base of the dataset-size lower bound (simulated_annealing compares against d**4)
p_M = 0.25  # upper bound on the pruning rate
prune_rate_change = 0.01  # step by which a pruning rate is changed

# Randomly generate the per-client data
# NOTE(review): no RNG seed is set in the visible cells, so these values differ on every run.
clients = {}
for i in range(num_clients):
    q = random.randint(100, 200)  # random number of samples
    alpha = np.random.dirichlet(np.ones(B), size=1).flatten()  # random class distribution (Dirichlet draw over B classes)
    p = random.uniform(0.1, 0.25)  # random pruning rate
    clients[f'client{i+1}'] = {'q': q, 'alpha': alpha, 'p': p}

# SPU I/O helper built from the SPU device's runtime config and world size.
spu_io = sf.device.SPUIO(spu_device.conf, spu_device.world_size)

import spu

# Turn the whole `clients` dict into secret shares (visibility VIS_SECRET).
# NOTE(review): the three-way unpacking assumes make_shares returns
# (meta, io_info, *shares) -- confirm against the installed SecretFlow version.
meta, io_info, *shares = spu_io.make_shares(clients, spu.Visibility.VIS_SECRET)


In [None]:
# Rebuild the data from the shares produced by make_shares.
# NOTE(review): presumably this round-trips to the original `clients`
# structure -- confirm by comparing the displayed value.
clients_data = spu_io.reconstruct(shares, io_info, meta)

# Show the reconstructed value as the cell output.
clients_data

In [None]:
def calculate_QCID(M):
    """Return the QCID score of the client subset ``M``.

    Computes ``sum_{n,m in M} q_n * q_m * (alpha_n . alpha_m) / (sum_n q_n)**2 - 1/B``
    using the module-level ``clients`` dict and class count ``B``.
    (QCID presumably stands for a class-imbalance measure -- confirm with the
    notebook author.)

    Args:
        M: iterable of keys into ``clients``; it is iterated several times.

    Returns:
        The score as a float.

    Raises:
        ZeroDivisionError: if the sample counts of ``M`` sum to zero
            (for example when ``M`` is empty).
    """
    total_q = sum(clients[name]['q'] for name in M)
    weighted_overlap = 0
    for left in M:
        left_q = clients[left]['q']
        left_alpha = clients[left]['alpha']
        for right in M:
            # Pairwise term: product of sample counts times the class-distribution overlap.
            weighted_overlap += left_q * clients[right]['q'] * np.dot(left_alpha, clients[right]['alpha'])
    normalised = weighted_overlap / (total_q ** 2)
    return normalised - 1 / B

def get_neighbor(M):
    """Return a copy of ``M`` that differs by at most one client.

    With probability 0.5 (and only if the subset has more than one member) a
    random member is removed. Otherwise a random client that is not yet in the
    subset is added, if any is left. ``M`` itself is not modified.

    Args:
        M: set of keys into the module-level ``clients`` dict.

    Returns:
        A new set.
    """
    neighbor = M.copy()
    # The random draw happens first so the removal branch is chosen independently of size.
    remove_member = random.random() < 0.5 and len(neighbor) > 1
    if remove_member:
        dropped = random.choice(list(neighbor))
        neighbor.remove(dropped)  # randomly remove one client
        return neighbor
    outsiders = list(set(clients.keys()) - neighbor)
    if outsiders:
        neighbor.add(random.choice(outsiders))  # randomly add one new client
    return neighbor

def simulated_annealing():
    """Search for a client subset with a low QCID score by simulated annealing.

    Starts from a random subset of at least two clients. At each temperature
    (from 1.0, multiplied by 0.9 until it is no longer above 1e-5) it proposes
    100 neighbours via ``get_neighbor``. A neighbour is considered only if its
    total sample count is at least ``d**4`` and every member's pruning rate is
    below ``p_M``.

    Side effects:
        For every feasible neighbour, each member's ``clients[...]['p']`` is
        raised by ``prune_rate_change`` (capped at ``p_M``) *before* the
        acceptance test, so rejected candidates also change the global state.
        NOTE(review): confirm this is intended.

    NOTE(review): with the values set earlier in this notebook (3 clients,
    ``q`` in [100, 200], ``d = 10``) the total sample count can be at most 600,
    which is below ``d**4 = 10000``. The feasibility check then never passes
    and the initial subset is returned unchanged.

    Returns:
        tuple: ``(current_M, current_score)`` -- the selected set of client
        keys and its QCID score.

    Raises:
        ValueError: if ``clients`` holds fewer than two entries (the initial
            ``random.randint(2, len(clients))`` has an empty range).
    """
    current_M = set(random.sample(list(clients.keys()), random.randint(2, len(clients))))
    current_score = calculate_QCID(current_M)
    T = 1.0
    T_min = 0.00001
    cooling = 0.9
    while T > T_min:
        for _ in range(100):
            new_M = get_neighbor(current_M)
            enough_samples = sum(clients[client]['q'] for client in new_M) >= d**4
            if enough_samples and all(clients[client]['p'] < p_M for client in new_M):
                # Update the pruning rates, capping each at p_M.
                for client in new_M:
                    clients[client]['p'] = min(clients[client]['p'] + prune_rate_change, p_M)
                new_score = calculate_QCID(new_M)
                # The acceptance probability is evaluated only for non-improving
                # moves, where the exponent is <= 0. Evaluating it for improving
                # moves at a small T raised OverflowError in math.exp.
                if new_score < current_score or random.random() < math.exp((current_score - new_score) / T):
                    current_M = new_M
                    current_score = new_score
        T *= cooling

    return current_M, current_score

# Run the annealing search once and print the selected client subset and its score.
best_group, best_score = simulated_annealing()
print("Best group:", best_group)
print("Best QCID score:", best_score)