In [12]:
from datasets import load_dataset, concatenate_datasets, load_from_disk
import pandas as pd
import datasets
from datasets import Dataset
from pprint import pprint as pp

In [2]:
code_data = load_dataset("sahil2801/CodeAlpaca-20k")["train"]
fin_data = load_dataset("FinGPT/fingpt-sentiment-train")["train"]
med_data = load_dataset("medalpaca/medical_meadow_medical_flashcards")["train"]
general_data = load_dataset("tatsu-lab/alpaca")["train"]
math_data = load_dataset("TIGER-Lab/MathInstruct")["train"]

In [3]:
def alpaca_format(example):
    if example['input'] == "":
        example["instruction"] = example["instruction"]
    else:
        example["instruction"] = example["instruction"] + " " + example['input']
    example["response"] = example['output']
    return example

In [4]:
def process_sft_dataset(dataset_name, dataset, dataset_sample=None)->datasets.Dataset:
    if dataset_name in ["lucasmccabe-lmi/CodeAlpaca-20k", "yahma/alpaca-cleaned", "FinGPT/fingpt-sentiment-train"]:
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output'], desc=f"Preprocessing {dataset_name} for unified format.")
    elif dataset_name in ["WizardLM/WizardLM_evol_instruct_70k"]:
        dataset = dataset.rename_column("output", "response")
    elif dataset_name in ["tatsu-lab/alpaca", "vicgalle/alpaca-gpt4", "gbharti/finance-alpaca"]:
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output', 'text'], desc=f"Preprocessing {dataset_name} for unified format.")
    elif dataset_name in ["TIGER-Lab/MathInstruct"]:
        df = pd.DataFrame(dataset)
        df = df.drop_duplicates(subset=['instruction'])
        dataset = datasets.Dataset.from_pandas(df)
        dataset = dataset.rename_column("output", "response")
        dataset = dataset.remove_columns(['source'])
    elif dataset_name in ["lighteval/MATH"]:
        dataset = dataset.rename_column("solution", "response")
        dataset = dataset.rename_column("problem", "instruction")
        dataset = dataset.remove_columns(['level', 'type'])
    elif dataset_name in ['gsm8k']:
        dataset = dataset.rename_column("question", "instruction")
        dataset = dataset.rename_column("answer", "response")
    elif dataset_name in ['medalpaca/medical_meadow_medical_flashcards']:       # TODO: 'lavita/ChatDoctor-HealthCareMagic-100k'. not sure whether to discard the instruction.
        dataset = dataset.remove_columns(['instruction'])
        dataset = dataset.rename_column("input", "instruction")
        dataset = dataset.rename_column("output", "response")
    elif "math" in dataset_name:
        dataset = dataset.remove_columns(['source'])
        dataset = dataset.rename_column("output", "response")
    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not supported.")
    dataset = dataset.shuffle(seed=2023)
    if dataset_sample:
        num_sample = min(len(dataset), dataset_sample)
        dataset = dataset.select(range(num_sample))
    print(f">> ===== After processing, Dataset {dataset_name} has {len(dataset)} examples. =====")
    return dataset

In [5]:
processed_data = []
for name, dataset in zip(["lucasmccabe-lmi/CodeAlpaca-20k","FinGPT/fingpt-sentiment-train","medalpaca/medical_meadow_medical_flashcards","tatsu-lab/alpaca","TIGER-Lab/MathInstruct"],[code_data,fin_data,med_data,general_data,math_data]):
    tmp:datasets.Dataset = process_sft_dataset(name,dataset)
    print(tmp.column_names)
    processed_data.append(tmp)

>> ===== After processing, Dataset lucasmccabe-lmi/CodeAlpaca-20k has 20022 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset FinGPT/fingpt-sentiment-train has 76772 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset medalpaca/medical_meadow_medical_flashcards has 33955 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset tatsu-lab/alpaca has 52002 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset TIGER-Lab/MathInstruct has 224567 examples. =====
['response', 'instruction', '__index_level_0__']


In [6]:
data_concated = concatenate_datasets(processed_data)

# 构造base数据集

In [8]:
import numpy as np
import random
random.seed(10)
sampled_indices = random.sample(range(len(processed_data[0])), 1000)
sampled_data = processed_data[0].select(sampled_indices)
sampled_set = set(sampled_indices)
base_set = set(range(len(data_concated)))
# 计算差集，即在 idx_set 中但不在 sampled_set 中的元素
remaining_idx = list(base_set - sampled_set)
print(len(remaining_idx))
data_concated = data_concated.select(remaining_idx)

405318


# 将base数据集随机拆成十份

In [9]:
sampled_data = sampled_data.shuffle(seed=42)  
local_datasets = []
for i in range(10):
    local_datasets.append(sampled_data.shard(10, i))

In [10]:
print(len(local_datasets[0]))

100


## 将公共数据集也随机拆成10份

In [17]:
data_concated = data_concated.shuffle(seed=42)
public_datasets = []
for i in range(10):
    public_datasets.append(data_concated.shard(10,i))

# 构造随机采样数据集

In [13]:
client_random_datasets = []
dataset: Dataset
for dataset in public_datasets:
    idxs = random.sample(range(len(dataset)), 5000)
    client_random_datasets.append(dataset.select(idxs))
print(len(client_random_datasets[0]))

5000


In [16]:
for i, dataset in enumerate(client_random_datasets):
    dataset = concatenate_datasets([dataset,local_datasets[i]]).shuffle(seed=42)
    dataset.save_to_disk(f"/mnt/bn/data-tns-live-llm/leon/datasets/fed_data/random_with_base_{i}.parquet")

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

# 对每一个客户端的数据集进行检索，构造 pos 和 neg 数据集

In [14]:
from FlagEmbedding import FlagModel
model = FlagModel('BAAI/bge-large-en-v1.5', 
                  query_instruction_for_retrieval="",
                  use_fp16=True)

----------using 8*GPUs----------


In [None]:
heapq.nlargest(500, range(len(numbers)), key=lambda x: numbers[x])

In [38]:
public_datasets[0].save_to_disk("/mnt/bn/data-tns-live-llm/leon/datasets/fed_data/public")

Saving the dataset (0/1 shards):   0%|          | 0/40532 [00:00<?, ? examples/s]

In [34]:
from datasets import Dataset
from sklearn.cluster import KMeans
from tqdm import tqdm
import torch
import heapq
client_pos_datasets, client_neg_datasets = [], []
for i, sampled_data in enumerate(local_datasets):
    print(i)
    sampled_embeddings = model.encode(sampled_data["instruction"])
    # 假设 embeddings 是你的嵌入数据
    k = 10
    kmeans = KMeans(n_clusters=k, random_state=0).fit(sampled_embeddings)
    concated_embeddings = model.encode(public_datasets[i]["instruction"])
    clusters = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
    concated_embeddings = torch.tensor(concated_embeddings, dtype=torch.float32)
    similarity_scores = clusters @ concated_embeddings.T
    top_idxs = []
    bot_idxs = []
    for i in range(similarity_scores.shape[0]):
        print(i)
        tmp = similarity_scores[i]
        # print(similarity_scores[i][:10])
        top_idxs.append(heapq.nlargest(500, range(len(tmp)-1), key=lambda x: tmp[x]))
        bot_idxs.append(heapq.nsmallest(500, range(len(tmp)-1), key=lambda x: tmp[x]))
        top_scores = [similarity_scores[i][idx] for idx in top_idxs[i][:10]]
        bot_scores = [similarity_scores[i][idx] for idx in bot_idxs[i][:10]]
        print(top_scores)
        print(bot_scores)
        
    pos_datasets: Dataset = []
    neg_datasets: Dataset = []
    # for i in range(len(top_idxs)):
    #     instruction_length = len(public_datasets[i]["instruction"])
    #     filtered_top_idxs = list(filter(lambda idx: idx < instruction_length, top_idxs[i]))
    #     filtered_bot_idxs = list(filter(lambda idx: idx < instruction_length, bot_idxs[i]))
        # pos_datasets.append(public_datasets[i].select(filtered_top_idxs))
        # neg_datasets.append(public_datasets[i].select(filtered_bot_idxs))
    top_idxs=np.concatenate(top_idxs,axis=None)
    bot_idxs=np.concatenate(bot_idxs,axis=None)
    pos_datasets = public_datasets[i].select(top_idxs)
    neg_datasets = public_datasets[i].select(bot_idxs)
    pos_datasets = concatenate_datasets([pos_datasets, sampled_data])
    neg_datasets = concatenate_datasets([neg_datasets, sampled_data])
    # pos_datasets = pos_datasets.shuffle(seed=42)
    # neg_datasets = neg_datasets.shuffle(seed=42)
    client_pos_datasets.append(pos_datasets)
    client_neg_datasets.append(neg_datasets)

0


Inference Embeddings: 100%|██████████| 20/20 [00:20<00:00,  1.00s/it]


0
[tensor(0.6716), tensor(0.6623), tensor(0.6423), tensor(0.6422), tensor(0.6420), tensor(0.6408), tensor(0.6408), tensor(0.6386), tensor(0.6314), tensor(0.6310)]
[tensor(0.2119), tensor(0.2182), tensor(0.2213), tensor(0.2292), tensor(0.2306), tensor(0.2319), tensor(0.2331), tensor(0.2358), tensor(0.2361), tensor(0.2376)]
1
[tensor(0.7447), tensor(0.7424), tensor(0.7290), tensor(0.7249), tensor(0.7165), tensor(0.7165), tensor(0.7152), tensor(0.7093), tensor(0.7077), tensor(0.7066)]
[tensor(0.2246), tensor(0.2262), tensor(0.2264), tensor(0.2271), tensor(0.2273), tensor(0.2331), tensor(0.2431), tensor(0.2449), tensor(0.2478), tensor(0.2494)]
2
[tensor(0.7038), tensor(0.6996), tensor(0.6914), tensor(0.6863), tensor(0.6837), tensor(0.6805), tensor(0.6783), tensor(0.6780), tensor(0.6776), tensor(0.6733)]
[tensor(0.2140), tensor(0.2180), tensor(0.2380), tensor(0.2442), tensor(0.2445), tensor(0.2479), tensor(0.2510), tensor(0.2526), tensor(0.2528), tensor(0.2575)]
3
[tensor(0.6978), tensor(0.

Inference Embeddings: 100%|██████████| 20/20 [00:19<00:00,  1.05it/s]


0
[tensor(0.7223), tensor(0.7146), tensor(0.7128), tensor(0.7119), tensor(0.7081), tensor(0.7060), tensor(0.7054), tensor(0.7052), tensor(0.7028), tensor(0.7019)]
[tensor(0.2167), tensor(0.2241), tensor(0.2268), tensor(0.2274), tensor(0.2289), tensor(0.2377), tensor(0.2378), tensor(0.2395), tensor(0.2396), tensor(0.2422)]
1
[tensor(0.6424), tensor(0.6419), tensor(0.6397), tensor(0.6366), tensor(0.6331), tensor(0.6315), tensor(0.6309), tensor(0.6303), tensor(0.6297), tensor(0.6289)]
[tensor(0.2353), tensor(0.2412), tensor(0.2498), tensor(0.2615), tensor(0.2621), tensor(0.2654), tensor(0.2669), tensor(0.2674), tensor(0.2677), tensor(0.2683)]
2
[tensor(0.7071), tensor(0.7006), tensor(0.6983), tensor(0.6964), tensor(0.6964), tensor(0.6959), tensor(0.6949), tensor(0.6910), tensor(0.6910), tensor(0.6899)]
[tensor(0.2100), tensor(0.2110), tensor(0.2347), tensor(0.2379), tensor(0.2395), tensor(0.2402), tensor(0.2466), tensor(0.2467), tensor(0.2498), tensor(0.2503)]
3
[tensor(0.6862), tensor(0.

Inference Embeddings: 100%|██████████| 20/20 [00:19<00:00,  1.04it/s]


0
[tensor(0.6415), tensor(0.6379), tensor(0.6374), tensor(0.6371), tensor(0.6367), tensor(0.6343), tensor(0.6332), tensor(0.6322), tensor(0.6321), tensor(0.6316)]
[tensor(0.2386), tensor(0.2400), tensor(0.2499), tensor(0.2608), tensor(0.2623), tensor(0.2646), tensor(0.2678), tensor(0.2682), tensor(0.2689), tensor(0.2689)]
1
[tensor(0.7212), tensor(0.7179), tensor(0.7107), tensor(0.7085), tensor(0.7057), tensor(0.7027), tensor(0.7019), tensor(0.7019), tensor(0.7009), tensor(0.7000)]
[tensor(0.2287), tensor(0.2316), tensor(0.2357), tensor(0.2401), tensor(0.2414), tensor(0.2453), tensor(0.2467), tensor(0.2492), tensor(0.2494), tensor(0.2505)]
2
[tensor(0.6972), tensor(0.6944), tensor(0.6906), tensor(0.6897), tensor(0.6893), tensor(0.6832), tensor(0.6804), tensor(0.6784), tensor(0.6771), tensor(0.6768)]
[tensor(0.2381), tensor(0.2411), tensor(0.2414), tensor(0.2416), tensor(0.2424), tensor(0.2470), tensor(0.2480), tensor(0.2494), tensor(0.2498), tensor(0.2519)]
3
[tensor(0.7328), tensor(0.

Inference Embeddings: 100%|██████████| 20/20 [00:19<00:00,  1.05it/s]


0
[tensor(0.6498), tensor(0.6469), tensor(0.6433), tensor(0.6428), tensor(0.6356), tensor(0.6333), tensor(0.6318), tensor(0.6311), tensor(0.6309), tensor(0.6293)]
[tensor(0.2371), tensor(0.2398), tensor(0.2553), tensor(0.2558), tensor(0.2568), tensor(0.2589), tensor(0.2607), tensor(0.2612), tensor(0.2618), tensor(0.2634)]
1
[tensor(0.7750), tensor(0.7488), tensor(0.7482), tensor(0.7433), tensor(0.7408), tensor(0.7354), tensor(0.7345), tensor(0.7338), tensor(0.7328), tensor(0.7325)]
[tensor(0.2125), tensor(0.2388), tensor(0.2430), tensor(0.2455), tensor(0.2463), tensor(0.2470), tensor(0.2490), tensor(0.2531), tensor(0.2537), tensor(0.2537)]
2
[tensor(0.7405), tensor(0.7394), tensor(0.7392), tensor(0.7383), tensor(0.7377), tensor(0.7367), tensor(0.7349), tensor(0.7285), tensor(0.7283), tensor(0.7255)]
[tensor(0.2297), tensor(0.2341), tensor(0.2351), tensor(0.2383), tensor(0.2437), tensor(0.2451), tensor(0.2457), tensor(0.2463), tensor(0.2488), tensor(0.2492)]
3
[tensor(0.6753), tensor(0.

Inference Embeddings: 100%|██████████| 20/20 [00:19<00:00,  1.04it/s]


0
[tensor(0.7488), tensor(0.6886), tensor(0.6838), tensor(0.6745), tensor(0.6658), tensor(0.6593), tensor(0.6546), tensor(0.6453), tensor(0.6442), tensor(0.6312)]
[tensor(0.2257), tensor(0.2265), tensor(0.2343), tensor(0.2361), tensor(0.2366), tensor(0.2373), tensor(0.2446), tensor(0.2447), tensor(0.2449), tensor(0.2463)]
1
[tensor(0.7057), tensor(0.6943), tensor(0.6931), tensor(0.6928), tensor(0.6902), tensor(0.6864), tensor(0.6862), tensor(0.6843), tensor(0.6817), tensor(0.6771)]
[tensor(0.2311), tensor(0.2369), tensor(0.2413), tensor(0.2447), tensor(0.2468), tensor(0.2474), tensor(0.2501), tensor(0.2513), tensor(0.2533), tensor(0.2548)]
2
[tensor(0.7189), tensor(0.6942), tensor(0.6755), tensor(0.6704), tensor(0.6591), tensor(0.6586), tensor(0.6573), tensor(0.6567), tensor(0.6563), tensor(0.6559)]
[tensor(0.2283), tensor(0.2355), tensor(0.2449), tensor(0.2474), tensor(0.2479), tensor(0.2489), tensor(0.2490), tensor(0.2493), tensor(0.2524), tensor(0.2538)]
3
[tensor(0.6754), tensor(0.

Inference Embeddings: 100%|██████████| 20/20 [00:19<00:00,  1.05it/s]


0
[tensor(0.7475), tensor(0.7092), tensor(0.6954), tensor(0.6916), tensor(0.6916), tensor(0.6809), tensor(0.6793), tensor(0.6790), tensor(0.6734), tensor(0.6669)]
[tensor(0.1843), tensor(0.1909), tensor(0.2063), tensor(0.2172), tensor(0.2174), tensor(0.2188), tensor(0.2193), tensor(0.2225), tensor(0.2240), tensor(0.2250)]
1
[tensor(0.6728), tensor(0.6694), tensor(0.6673), tensor(0.6651), tensor(0.6649), tensor(0.6615), tensor(0.6608), tensor(0.6596), tensor(0.6586), tensor(0.6581)]
[tensor(0.2447), tensor(0.2482), tensor(0.2520), tensor(0.2603), tensor(0.2634), tensor(0.2636), tensor(0.2654), tensor(0.2655), tensor(0.2673), tensor(0.2675)]
2
[tensor(0.6521), tensor(0.6166), tensor(0.6136), tensor(0.6123), tensor(0.6099), tensor(0.6090), tensor(0.6058), tensor(0.6052), tensor(0.6047), tensor(0.6040)]
[tensor(0.2042), tensor(0.2060), tensor(0.2133), tensor(0.2209), tensor(0.2214), tensor(0.2226), tensor(0.2228), tensor(0.2253), tensor(0.2286), tensor(0.2292)]
3
[tensor(0.6882), tensor(0.

Inference Embeddings: 100%|██████████| 20/20 [00:19<00:00,  1.05it/s]


0
[tensor(0.6587), tensor(0.6551), tensor(0.6536), tensor(0.6526), tensor(0.6475), tensor(0.6463), tensor(0.6456), tensor(0.6453), tensor(0.6414), tensor(0.6393)]
[tensor(0.2217), tensor(0.2289), tensor(0.2330), tensor(0.2426), tensor(0.2462), tensor(0.2502), tensor(0.2536), tensor(0.2547), tensor(0.2547), tensor(0.2557)]
1
[tensor(0.7683), tensor(0.7642), tensor(0.7580), tensor(0.7572), tensor(0.7501), tensor(0.7497), tensor(0.7415), tensor(0.7409), tensor(0.7375), tensor(0.7374)]
[tensor(0.2138), tensor(0.2193), tensor(0.2370), tensor(0.2397), tensor(0.2445), tensor(0.2452), tensor(0.2453), tensor(0.2461), tensor(0.2474), tensor(0.2527)]
2
[tensor(0.7136), tensor(0.7092), tensor(0.7030), tensor(0.7016), tensor(0.7015), tensor(0.6933), tensor(0.6923), tensor(0.6917), tensor(0.6907), tensor(0.6906)]
[tensor(0.2017), tensor(0.2250), tensor(0.2374), tensor(0.2381), tensor(0.2418), tensor(0.2449), tensor(0.2509), tensor(0.2516), tensor(0.2548), tensor(0.2554)]
3
[tensor(0.7086), tensor(0.

Inference Embeddings: 100%|██████████| 20/20 [00:19<00:00,  1.04it/s]


0
[tensor(0.7391), tensor(0.7378), tensor(0.7280), tensor(0.7275), tensor(0.7267), tensor(0.7129), tensor(0.7114), tensor(0.7091), tensor(0.7089), tensor(0.6973)]
[tensor(0.2392), tensor(0.2480), tensor(0.2550), tensor(0.2574), tensor(0.2584), tensor(0.2591), tensor(0.2612), tensor(0.2616), tensor(0.2636), tensor(0.2639)]
1
[tensor(0.6390), tensor(0.6384), tensor(0.6372), tensor(0.6335), tensor(0.6334), tensor(0.6314), tensor(0.6301), tensor(0.6290), tensor(0.6265), tensor(0.6256)]
[tensor(0.2605), tensor(0.2635), tensor(0.2638), tensor(0.2685), tensor(0.2712), tensor(0.2744), tensor(0.2753), tensor(0.2759), tensor(0.2765), tensor(0.2767)]
2
[tensor(0.6563), tensor(0.6556), tensor(0.6536), tensor(0.6522), tensor(0.6502), tensor(0.6462), tensor(0.6452), tensor(0.6448), tensor(0.6429), tensor(0.6418)]
[tensor(0.2276), tensor(0.2514), tensor(0.2532), tensor(0.2533), tensor(0.2538), tensor(0.2564), tensor(0.2566), tensor(0.2566), tensor(0.2573), tensor(0.2583)]
3
[tensor(0.7488), tensor(0.

Inference Embeddings: 100%|██████████| 20/20 [00:19<00:00,  1.02it/s]


0
[tensor(0.7574), tensor(0.7408), tensor(0.7399), tensor(0.7373), tensor(0.7325), tensor(0.7267), tensor(0.7267), tensor(0.7254), tensor(0.7254), tensor(0.7253)]
[tensor(0.2331), tensor(0.2460), tensor(0.2501), tensor(0.2527), tensor(0.2529), tensor(0.2562), tensor(0.2601), tensor(0.2601), tensor(0.2624), tensor(0.2625)]
1
[tensor(0.7054), tensor(0.7051), tensor(0.7049), tensor(0.7047), tensor(0.7043), tensor(0.7039), tensor(0.7018), tensor(0.6961), tensor(0.6954), tensor(0.6939)]
[tensor(0.2139), tensor(0.2366), tensor(0.2380), tensor(0.2420), tensor(0.2429), tensor(0.2459), tensor(0.2501), tensor(0.2525), tensor(0.2534), tensor(0.2540)]
2
[tensor(0.7406), tensor(0.7325), tensor(0.7155), tensor(0.7151), tensor(0.7142), tensor(0.7133), tensor(0.7074), tensor(0.7032), tensor(0.7019), tensor(0.7003)]
[tensor(0.2188), tensor(0.2308), tensor(0.2420), tensor(0.2474), tensor(0.2481), tensor(0.2498), tensor(0.2511), tensor(0.2517), tensor(0.2526), tensor(0.2537)]
3
[tensor(0.6954), tensor(0.

Inference Embeddings: 100%|██████████| 20/20 [00:19<00:00,  1.03it/s]


0
[tensor(0.6443), tensor(0.6308), tensor(0.6070), tensor(0.6034), tensor(0.6000), tensor(0.5950), tensor(0.5949), tensor(0.5933), tensor(0.5926), tensor(0.5907)]
[tensor(0.2155), tensor(0.2164), tensor(0.2210), tensor(0.2279), tensor(0.2329), tensor(0.2426), tensor(0.2431), tensor(0.2435), tensor(0.2451), tensor(0.2452)]
1
[tensor(0.7014), tensor(0.6978), tensor(0.6957), tensor(0.6922), tensor(0.6902), tensor(0.6889), tensor(0.6878), tensor(0.6868), tensor(0.6867), tensor(0.6866)]
[tensor(0.2215), tensor(0.2259), tensor(0.2302), tensor(0.2380), tensor(0.2438), tensor(0.2546), tensor(0.2556), tensor(0.2558), tensor(0.2558), tensor(0.2566)]
2
[tensor(0.7546), tensor(0.7384), tensor(0.7281), tensor(0.7272), tensor(0.7260), tensor(0.7244), tensor(0.7238), tensor(0.7229), tensor(0.7203), tensor(0.7197)]
[tensor(0.1957), tensor(0.2041), tensor(0.2072), tensor(0.2130), tensor(0.2273), tensor(0.2331), tensor(0.2332), tensor(0.2370), tensor(0.2382), tensor(0.2407)]
3
[tensor(0.8027), tensor(0.

In [33]:
# client_pos_datasets, client_neg_datasets = client_pos_datasets[-10:], client_neg_datasets[-10:]
for i, pos_data in enumerate(client_pos_datasets):
    pos_data.save_to_disk(f"/mnt/bn/data-tns-live-llm/leon/datasets/fed_data/pos_{i}.parquet")
# for i, (pos_data, neg_data) in enumerate(zip(client_pos_datasets, client_neg_datasets)):
#     pos_data.save_to_disk(f"/mnt/bn/data-tns-live-llm/leon/datasets/fed_data/pos_{i}.parquet")
#     neg_data.save_to_disk(f"/mnt/bn/data-tns-live-llm/leon/datasets/fed_data/neg_{i}.parquet")

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

# 构造pos+diversity 数据集，一半 pos，一半 diversity

In [None]:
from datasets import Dataset
import torch
import heapq
from tqdm import tqdm
from sklearn.cluster import KMeans

client_pos_datasets=[]
for i, sampled_data in enumerate(local_datasets):
    print(i)
    sampled_embeddings = model.encode(sampled_data["instruction"])
    # 假设 embeddings 是你的嵌入数据
    k = 10
    kmeans = KMeans(n_clusters=k, random_state=0).fit(sampled_embeddings)
    print(kmeans)
    concated_embeddings = model.encode(public_datasets[i]["instruction"])
    clusters = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
    concated_embeddings = torch.tensor(concated_embeddings, dtype=torch.float32)
    similarity_scores = clusters @ concated_embeddings.T
    top_idxs = []
    for i in range(similarity_scores.shape[0]):
        tmp = similarity_scores[i]
        top_idxs.append(heapq.nlargest(250, range(len(tmp)), key=tmp.__getitem__))
    pos_datasets: Dataset = []
    # top_idxs去重，其余作为 diversity
    top_idxs = set(np.concatenate(top_idxs,axis=0))
    try: top_idxs.remove(len(public_datasets[i]))
    except: pass
    pos_datasets = public_datasets[i].select(top_idxs)
    print(len(top_idxs))
    # 从public_datasets[i]中去掉 top_idxs
    all_idxs = set(range(len(public_datasets[i])))
    remain_idxs = list(all_idxs-top_idxs)
    random_idxs = random.sample(remain_idxs, 5000-len(top_idxs))
    diversity_datasets = public_datasets[i].select(random_idxs)
    pos_datasets = concatenate_datasets([pos_datasets, diversity_datasets, sampled_data])
    pos_datasets = pos_datasets.shuffle(seed=42)
    client_pos_datasets.append(pos_datasets)

In [None]:
for i, pos_data in enumerate(client_pos_datasets):
    pos_data.save_to_disk(f"/mnt/bn/data-tns-live-llm/leon/datasets/fed_data/T_{i}.parquet")

# 构造去重 pos 数据集

In [31]:
from ordered_set import OrderedSet
tmp = OrderedSet([4,5,3,7,1])
tmp1 = OrderedSet([4,5])
for t in tmp1: tmp.discard(t)
print(tmp)

OrderedSet([3, 7, 1])


In [None]:
from datasets import Dataset
import torch
import heapq
from tqdm import tqdm
from sklearn.cluster import KMeans
from ordered_set import OrderedSet

client_pos_datasets=[]
for i, sampled_data in enumerate(local_datasets):
    print(i)
    sampled_embeddings = model.encode(sampled_data["instruction"])
    # 假设 embeddings 是你的嵌入数据
    k = 10
    kmeans = KMeans(n_clusters=k, random_state=0).fit(sampled_embeddings)
    print(kmeans)
    concated_embeddings = model.encode(public_datasets[i]["instruction"])
    clusters = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
    concated_embeddings = torch.tensor(concated_embeddings, dtype=torch.float32)
    top_idxs:OrderedSet = OrderedSet()
    remain_idxs = OrderedSet(range(len(public_datasets[i])))
    for i in range(k):
        similarity_scores = clusters[i] @ concated_embeddings.T
        top_idx = list(OrderedSet(heapq.nlargest(5000, range(len(similarity_scores)), key=similarity_scores.__getitem__))-top_idxs)[:500]
        top_idxs.update(top_idx)
        print("top_idxs", len(top_idxs))
        remain_idxs.difference_update(top_idx)
        print("remain_idxs", len(remain_idxs))

    try: top_idxs.remove(len(public_datasets[i]))
    except: pass
    pos_datasets = public_datasets[i].select(list(top_idxs))
    pos_datasets = concatenate_datasets([pos_datasets, sampled_data])
    pos_datasets = pos_datasets.shuffle(seed=42)
    client_pos_datasets.append(pos_datasets)

In [34]:
for i, pos_data in enumerate(client_pos_datasets):
    pos_data.save_to_disk(f"/mnt/bn/data-tns-live-llm/leon/datasets/fed_data/pos_nodup_{i}.parquet")

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5099 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5099 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

# 查看数据集