In [1]:
# Load the initial (round-0) global model and flatten it into one row vector.
import pickle
import os
import torch

file_name = "initial.pkl"
if not os.path.exists(file_name):
    print(f"File {file_name} not found")
    raise FileNotFoundError(f"File {file_name} not found")

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(file_name, "rb") as file:
    initial_global_model = pickle.load(file)

# Concatenate every entry of the loaded mapping, in its iteration order, into a
# single row of shape (1, d), d being the total number of elements. This order
# must match the key order of the per-round update dicts that are later added
# to this vector.
initial_global_model_params = torch.cat(
    [param.reshape(1, -1) for param in initial_global_model.values()], dim=1
)

print(initial_global_model_params.shape)

  return torch.load(io.BytesIO(b))


torch.Size([1, 2674688])


In [2]:
# Run directory (relative to the working directory) for each attack -> defense
# pair. Each directory is expected to contain a `model_updates/` subfolder with
# the recorded per-round client updates of that run.
attack_defense_data = dict(
    MR=dict(
        FLAME="FL_Backdoor_CV/2024-10-02_03-37-50/",
        FLTRUST="FL_Backdoor_CV/2024-10-02_21-40-11/",
        FOOLSGOLD="FL_Backdoor_CV/2024-10-27_16-25-41/",
        SECFFT="FL_Backdoor_CV/2024-10-29_22-41-06/",
    ),
    EDGE_CASE=dict(
        FLAME="FL_Backdoor_CV/2024-10-02_09-11-28/",
        FLTRUST="FL_Backdoor_CV/2024-10-03_03-09-36/",
        FOOLSGOLD="FL_Backdoor_CV/2024-10-28_00-55-46/",
        SECFFT="FL_Backdoor_CV/2024-10-30_04-48-52/",
    ),
    NEUROTOXIN=dict(
        FLAME="FL_Backdoor_CV/2024-10-02_12-55-32/",
        FLTRUST="FL_Backdoor_CV/2024-10-03_06-48-00/",
        FOOLSGOLD="FL_Backdoor_CV/2024-10-28_05-59-22/",
        SECFFT="FL_Backdoor_CV/2024-10-30_08-40-55/",
    ),
)

# Dataset tag embedded in the update file names.
dataset = "fmnist"

In [3]:
# Target device name: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
import os
import torch

import math
import hdbscan
import pickle


def _clip_to_median(updates):
    """Scale each row of ``updates`` down so its L2 norm is at most the median row norm.

    Args:
        updates: 2-D tensor, one row per client.

    Returns:
        ``(clipped, median_norm)``: ``clipped`` has the same shape as ``updates``;
        ``median_norm`` is the 0-d tensor returned by ``torch.median`` over the row norms.
    """
    norms = torch.norm(updates, dim=1)
    median_norm = torch.median(norms)
    # Rows above the median are shrunk by median/norm (< 1); all other rows keep
    # factor 1. Selecting on `norms > median_norm` (rather than clamping
    # median/norm) never yields NaN when a row norm and the median are both zero.
    scale = torch.where(norms > median_norm, median_norm / norms, torch.ones_like(norms))
    return updates * scale.view(-1, 1), median_norm


def flame(trained_params, current_model_param, param_updates, generator=None):
    """Aggregate one round of client models with the FLAME defense.

    Steps: HDBSCAN (cosine) clustering of the client models, keeping the clients
    labelled 0; median-norm clipping of the kept clients' updates; averaging;
    Gaussian noise on the aggregate. Running BatchNorm statistics and
    ``num_batches_tracked`` entries of ``param_updates`` are aggregated separately.

    Args:
        trained_params: 2-D tensor with one row per client (the clients' locally
            trained, flattened models). The number of rows is the client count.
        current_model_param: current global model, broadcastable against the rows
            of ``trained_params`` (e.g. shape (1, d) or (d,)).
        param_updates: dict name -> per-client tensor indexed by client along
            dim 0; running-stat entries are assumed 2-D — TODO confirm against caller.
        generator: optional ``torch.Generator`` used for the noise so runs can be
            reproduced; ``None`` uses PyTorch's global RNG.

    Returns:
        ``(new_global_params, global_update)``: the aggregated model as a 1-D
        float32 tensor moved to the notebook-level ``device``, and a dict holding
        the aggregated BatchNorm-related entries (each with a leading dim of 1).

    Raises:
        ValueError: if clustering assigns no client to cluster 0.
    """
    trained_params = trained_params.to(torch.float64)
    current_model_param = current_model_param.to(torch.float64)
    num_clients = trained_params.shape[0]

    # === clustering ===
    # min_cluster_size is a strict majority of the clients, so at most one
    # cluster can form; any cluster found is labelled 0, the rest is noise (-1).
    cluster = hdbscan.HDBSCAN(
        metric="cosine",
        algorithm="generic",
        min_cluster_size=num_clients // 2 + 1,
        min_samples=1,
        allow_single_cluster=True,
    )
    cluster.fit(trained_params)
    predict_good = [i for i, label in enumerate(cluster.labels_) if label == 0]
    k = len(predict_good)
    if k == 0:
        raise ValueError("FLAME clustering kept no clients (no client in cluster 0)")

    # === median clipping ===
    model_updates, S_t = _clip_to_median(trained_params[predict_good] - current_model_param)

    # === aggregating ===
    aggregated = (current_model_param + model_updates).sum(dim=0) / k

    # === noising ===
    delta = 1 / (num_clients**2)
    epsilon = 10000
    lambda_ = 1 / epsilon * math.sqrt(2 * math.log(1.25 / delta))
    # .item() works for CPU and CUDA tensors alike (.numpy() fails on CUDA).
    median_norm = S_t.item()
    sigma = lambda_ * median_norm
    print(
        f"sigma: {sigma}; #clean models / clean models: {k} / {predict_good}, median norm: {median_norm},"
    )
    noise = torch.normal(0, sigma, size=aggregated.size(), generator=generator)
    # torch.normal creates the noise on the CPU; move it next to the aggregate.
    aggregated.add_(noise.to(aggregated.device))

    # === bn ===
    global_update = dict()
    for name, param in param_updates.items():
        if "num_batches_tracked" in name:
            global_update[name] = 1 / k * param[predict_good].sum(dim=0, keepdim=True)
        elif "running_mean" in name or "running_var" in name:
            clipped, _ = _clip_to_median(param[predict_good])
            global_update[name] = 1 / k * clipped.sum(dim=0, keepdim=True)

    return aggregated.float().to(device), global_update

In [5]:
# Print the working directory: the run directories used below are relative paths.
print(os.getcwd())

# Replay the recorded per-round client updates of every FLAME run through the
# `flame` aggregator.
NUM_ROUNDS = 25  # number of recorded rounds per run
final_global_params = {}  # (attack_type, defense_type) -> global model after the last round

global_param = initial_global_model_params
for attack_type, defense2data in attack_defense_data.items():
    for defense_type, defense_data in defense2data.items():
        if defense_type != "FLAME":
            continue
        # Every run starts from the same initial model; without this reset a run
        # would continue from the previous run's final global model.
        global_param = initial_global_model_params
        for i in range(NUM_ROUNDS):
            file_name = f"{defense_data}model_updates/{dataset}_{attack_type}_{i}.pkl"
            if not os.path.exists(file_name):
                print(f"File {file_name} not found")
                raise FileNotFoundError(f"File {file_name} not found")

            # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
            with open(file_name, "rb") as file:
                model_updates = pickle.load(file)

            # Concatenate the per-entry updates along dim 1; `updates` has shape
            # (n, d): n participants, d the flattened update size.
            updates = torch.cat(list(model_updates.values()), dim=1)

            # Each participant's trained model = current global model + its update.
            # Both operands go to the CPU first because `flame` returns its result
            # on `device`, which may differ from where the loaded updates live.
            user_params = global_param.cpu() + updates.cpu()

            # Aggregate with FLAME to obtain the next global model.
            global_param, global_update = flame(
                user_params, global_param.cpu(), model_updates
            )
        final_global_params[(attack_type, defense_type)] = global_param

            



c:\Users\BUPT426\Desktop\RoseAgg\RoseAgg_Latest


  return torch.load(io.BytesIO(b))


sigma: 0.0005022467632193615; #clean models / clean models: 26 / [21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 41, 42, 43, 44, 46, 47, 48, 49], median norm: 1.2519299604304135,


  return torch.load(io.BytesIO(b))


sigma: 0.0004811260317863026; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 42, 43, 44, 45, 47, 48, 49], median norm: 1.1992831772082344,


  return torch.load(io.BytesIO(b))


sigma: 0.0004689034009280396; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 25, 26, 29, 30, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 1.1688163252793997,


  return torch.load(io.BytesIO(b))


sigma: 0.0004583627491532068; #clean models / clean models: 26 / [20, 21, 24, 25, 26, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 1.1425420738043064,


  return torch.load(io.BytesIO(b))


sigma: 0.0004504158509699741; #clean models / clean models: 26 / [20, 22, 23, 24, 25, 26, 27, 28, 29, 31, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48], median norm: 1.122733165799988,


  return torch.load(io.BytesIO(b))


sigma: 0.00044295042602392144; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35, 37, 38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 49], median norm: 1.104124406437559,


  return torch.load(io.BytesIO(b))


sigma: 0.00043543248732604787; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 34, 35, 36, 37, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49], median norm: 1.0853847481942338,


  return torch.load(io.BytesIO(b))


sigma: 0.0004225346117904145; #clean models / clean models: 26 / [20, 21, 22, 24, 25, 26, 27, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49], median norm: 1.0532347414814789,


  return torch.load(io.BytesIO(b))


sigma: 0.00041717235365843055; #clean models / clean models: 26 / [20, 21, 22, 23, 26, 27, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49], median norm: 1.0398684599987245,


  return torch.load(io.BytesIO(b))


sigma: 0.00040830728887475426; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 25, 26, 28, 30, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 48, 49], median norm: 1.0177708756704535,


  return torch.load(io.BytesIO(b))


sigma: 0.0003971812652982647; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 26, 27, 28, 30, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 43, 44, 45, 46, 47, 48, 49], median norm: 0.9900374918521516,


  return torch.load(io.BytesIO(b))


sigma: 0.00038912156083996993; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 27, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 47, 48, 49], median norm: 0.9699473962607404,


  return torch.load(io.BytesIO(b))


sigma: 0.000379874979381836; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 33, 34, 35, 36, 37, 39, 41, 42, 43, 44, 46, 47, 48, 49], median norm: 0.9468988209253882,


  return torch.load(io.BytesIO(b))


sigma: 0.0003707721506694253; #clean models / clean models: 26 / [20, 21, 22, 24, 25, 26, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49], median norm: 0.924208572178567,


  return torch.load(io.BytesIO(b))


sigma: 0.00036131495843098546; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 27, 28, 29, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 49], median norm: 0.9006350159669531,


  return torch.load(io.BytesIO(b))


sigma: 0.0003523716574119157; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 42, 43, 44, 47, 48, 49], median norm: 0.8783424153752571,


  return torch.load(io.BytesIO(b))


sigma: 0.00034345483199051997; #clean models / clean models: 26 / [20, 21, 22, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 45, 46, 48, 49], median norm: 0.8561158094228019,


  return torch.load(io.BytesIO(b))


sigma: 0.00033505549471492165; #clean models / clean models: 26 / [20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 48, 49], median norm: 0.8351791250016246,


  return torch.load(io.BytesIO(b))


sigma: 0.00032715906054317696; #clean models / clean models: 26 / [20, 22, 23, 24, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 39, 40, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 0.8154960065743267,


  return torch.load(io.BytesIO(b))


sigma: 0.0003191072819468745; #clean models / clean models: 26 / [20, 21, 24, 25, 26, 27, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49], median norm: 0.7954256674548674,


  return torch.load(io.BytesIO(b))


sigma: 0.00031431050320264594; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 38, 39, 40, 41, 42, 43, 44, 46, 48, 49], median norm: 0.7834689333089618,


  return torch.load(io.BytesIO(b))


sigma: 0.00030886717336831194; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 42, 44, 45, 47, 48, 49], median norm: 0.7699005677102948,


  return torch.load(io.BytesIO(b))


sigma: 0.00030518266610213664; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 44, 45, 47, 48, 49], median norm: 0.7607163471761871,


  return torch.load(io.BytesIO(b))


sigma: 0.0003001709587378503; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 40, 42, 43, 44, 45, 48, 49], median norm: 0.7482238692515071,


  return torch.load(io.BytesIO(b))


sigma: 0.0002953630880444409; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 27, 28, 29, 31, 32, 33, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 0.7362394866576354,


  return torch.load(io.BytesIO(b))


sigma: 0.0005017087374266868; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49], median norm: 1.2505888455469323,


  return torch.load(io.BytesIO(b))


sigma: 0.0004808627456712072; #clean models / clean models: 26 / [21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 33, 34, 35, 36, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 1.1986268946798206,


  return torch.load(io.BytesIO(b))


sigma: 0.0004667055585932561; #clean models / clean models: 26 / [20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 44, 46, 48, 49], median norm: 1.1633378536022891,


  return torch.load(io.BytesIO(b))


sigma: 0.0004591502889457534; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 48, 49], median norm: 1.1445051420279828,


  return torch.load(io.BytesIO(b))


sigma: 0.0004508821444576241; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 27, 28, 31, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 1.1238954765012072,


  return torch.load(io.BytesIO(b))


sigma: 0.00044200578377608427; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 35, 36, 37, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49], median norm: 1.101769735350432,


  return torch.load(io.BytesIO(b))


sigma: 0.000433413321916107; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 26, 27, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 40, 41, 42, 43, 44, 45, 46, 47, 48], median norm: 1.0803516571782432,


  return torch.load(io.BytesIO(b))


sigma: 0.000424859985403524; #clean models / clean models: 26 / [20, 21, 23, 24, 25, 26, 28, 29, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 48, 49], median norm: 1.059031105158938,


  return torch.load(io.BytesIO(b))


sigma: 0.00041509534936473856; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 29, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 1.0346911963633192,


  return torch.load(io.BytesIO(b))


sigma: 0.00040620784579280565; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 28, 29, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 43, 44, 45, 46, 47, 48, 49], median norm: 1.012537679785502,


  return torch.load(io.BytesIO(b))


sigma: 0.0004005762640353997; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 44, 46, 47, 48, 49], median norm: 0.9985000663193291,


  return torch.load(io.BytesIO(b))


sigma: 0.00038971466865936804; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39, 41, 42, 43, 44, 46, 47, 48, 49], median norm: 0.9714258118589028,


  return torch.load(io.BytesIO(b))


sigma: 0.00037893686733651025; #clean models / clean models: 26 / [20, 22, 23, 25, 26, 27, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 0.9445604274069198,


  return torch.load(io.BytesIO(b))


sigma: 0.0003708143281680604; #clean models / clean models: 26 / [20, 21, 23, 25, 26, 27, 28, 29, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 0.9243137063039895,


  return torch.load(io.BytesIO(b))


sigma: 0.000361029775091249; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49], median norm: 0.8999241511224628,


  return torch.load(io.BytesIO(b))


sigma: 0.0003522072536791652; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 27, 29, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49], median norm: 0.8779326129161678,


  return torch.load(io.BytesIO(b))


sigma: 0.0003406980240407839; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 31, 32, 33, 34, 35, 36, 37, 38, 40, 42, 43, 44, 45, 47, 48, 49], median norm: 0.8492440270238382,


  return torch.load(io.BytesIO(b))


sigma: 0.00033242388005997735; #clean models / clean models: 26 / [20, 21, 22, 24, 25, 26, 28, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49], median norm: 0.8286194068070973,


  return torch.load(io.BytesIO(b))


sigma: 0.0003260434604948796; #clean models / clean models: 26 / [21, 22, 23, 24, 25, 26, 27, 28, 30, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 48, 49], median norm: 0.8127151959716489,


  return torch.load(io.BytesIO(b))


sigma: 0.00031895260171688623; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 26, 27, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 0.795040102373338,


  return torch.load(io.BytesIO(b))


sigma: 0.00031431476186726383; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 41, 42, 44, 46, 47, 48, 49], median norm: 0.7834795487080386,


  return torch.load(io.BytesIO(b))


sigma: 0.000309428434123029; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31, 32, 33, 35, 36, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49], median norm: 0.7712995994331477,


  return torch.load(io.BytesIO(b))


sigma: 0.00030327164454121026; #clean models / clean models: 26 / [20, 21, 22, 24, 26, 27, 28, 29, 31, 32, 33, 34, 35, 36, 37, 38, 39, 41, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 0.755952821908614,


  return torch.load(io.BytesIO(b))


sigma: 0.00029949008032995305; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 41, 42, 43, 44, 45, 48, 49], median norm: 0.7465266714979709,


  return torch.load(io.BytesIO(b))


sigma: 0.0002930516511092805; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 0.7304778623671689,


  return torch.load(io.BytesIO(b))


sigma: 0.0005017995358462894; #clean models / clean models: 26 / [20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 38, 39, 41, 42, 43, 44, 45, 46, 48, 49], median norm: 1.250815175053033,


  return torch.load(io.BytesIO(b))


sigma: 0.00047999844539330143; #clean models / clean models: 26 / [20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 49], median norm: 1.1964724887344587,


  return torch.load(io.BytesIO(b))


sigma: 0.00046777874829595994; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 25, 26, 27, 31, 32, 33, 34, 35, 36, 37, 38, 39, 41, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 1.1660129496714569,


  return torch.load(io.BytesIO(b))


sigma: 0.00045895744766362597; #clean models / clean models: 26 / [21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49], median norm: 1.144024454452904,


  return torch.load(io.BytesIO(b))


sigma: 0.00045006520428184267; #clean models / clean models: 26 / [20, 21, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48, 49], median norm: 1.121859123144972,


  return torch.load(io.BytesIO(b))


sigma: 0.00044014406337942064; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 27, 28, 29, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 48, 49], median norm: 1.0971291010781719,


  return torch.load(io.BytesIO(b))


sigma: 0.00043084643565841416; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49], median norm: 1.073953284811592,


  return torch.load(io.BytesIO(b))


sigma: 0.00042652421194974997; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 35, 36, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49], median norm: 1.0631794545894229,


  return torch.load(io.BytesIO(b))


sigma: 0.00041545814191033684; #clean models / clean models: 26 / [20, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48], median norm: 1.035595514500372,


  return torch.load(io.BytesIO(b))


sigma: 0.0004087413505668777; #clean models / clean models: 26 / [20, 21, 22, 23, 24, 26, 28, 29, 30, 31, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 48, 49], median norm: 1.0188528434935236,


  return torch.load(io.BytesIO(b))


sigma: 0.0003961614169678062; #clean models / clean models: 26 / [20, 21, 22, 23, 25, 26, 28, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49], median norm: 0.9874953576394553,


  return torch.load(io.BytesIO(b))


KeyboardInterrupt: 