In [1]:
from datasets import load_dataset, concatenate_datasets, load_from_disk
import pandas as pd
import datasets

In [2]:
code_data = load_dataset("sahil2801/CodeAlpaca-20k")["train"]
fin_data = load_dataset("FinGPT/fingpt-sentiment-train")["train"]
med_data = load_dataset("medalpaca/medical_meadow_medical_flashcards")["train"]
general_data = load_dataset("tatsu-lab/alpaca")["train"]
math_data = load_dataset("TIGER-Lab/MathInstruct")["train"]

In [3]:
def alpaca_format(example):
    if example['input'] == "":
        example["instruction"] = example["instruction"]
    else:
        example["instruction"] = example["instruction"] + " " + example['input']
    example["response"] = example['output']
    return example

In [4]:
def process_sft_dataset(dataset_name, dataset, dataset_sample)->datasets.Dataset:
    if dataset_name in ["lucasmccabe-lmi/CodeAlpaca-20k", "yahma/alpaca-cleaned", "FinGPT/fingpt-sentiment-train"]:
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output'], desc=f"Preprocessing {dataset_name} for unified format.")
    elif dataset_name in ["WizardLM/WizardLM_evol_instruct_70k"]:
        dataset = dataset.rename_column("output", "response")
    elif dataset_name in ["tatsu-lab/alpaca", "vicgalle/alpaca-gpt4", "gbharti/finance-alpaca"]:
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output', 'text'], desc=f"Preprocessing {dataset_name} for unified format.")
    elif dataset_name in ["TIGER-Lab/MathInstruct"]:
        df = pd.DataFrame(dataset)
        df = df.drop_duplicates(subset=['instruction'])
        dataset = datasets.Dataset.from_pandas(df)
        dataset = dataset.rename_column("output", "response")
        dataset = dataset.remove_columns(['source'])
    elif dataset_name in ["lighteval/MATH"]:
        dataset = dataset.rename_column("solution", "response")
        dataset = dataset.rename_column("problem", "instruction")
        dataset = dataset.remove_columns(['level', 'type'])
    elif dataset_name in ['gsm8k']:
        dataset = dataset.rename_column("question", "instruction")
        dataset = dataset.rename_column("answer", "response")
    elif dataset_name in ['medalpaca/medical_meadow_medical_flashcards']:       # TODO: 'lavita/ChatDoctor-HealthCareMagic-100k'. not sure whether to discard the instruction.
        dataset = dataset.remove_columns(['instruction'])
        dataset = dataset.rename_column("input", "instruction")
        dataset = dataset.rename_column("output", "response")
    elif "math" in dataset_name:
        dataset = dataset.remove_columns(['source'])
        dataset = dataset.rename_column("output", "response")
    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not supported.")
    dataset = dataset.shuffle(seed=2023)
    if dataset_sample:
        num_sample = min(len(dataset), dataset_sample)
        dataset = dataset.select(range(num_sample))
    print(f">> ===== After processing, Dataset {dataset_name} has {len(dataset)} examples. =====")
    return dataset

In [5]:
processed_data = []
for name, dataset in zip(["lucasmccabe-lmi/CodeAlpaca-20k","FinGPT/fingpt-sentiment-train","medalpaca/medical_meadow_medical_flashcards","tatsu-lab/alpaca","TIGER-Lab/MathInstruct"],[code_data,fin_data,med_data,general_data,math_data]):
    tmp:datasets.Dataset = process_sft_dataset(name,dataset,None)
    print(tmp.column_names)
    processed_data.append(tmp)

>> ===== After processing, Dataset lucasmccabe-lmi/CodeAlpaca-20k has 20022 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset FinGPT/fingpt-sentiment-train has 76772 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset medalpaca/medical_meadow_medical_flashcards has 33955 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset tatsu-lab/alpaca has 52002 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset TIGER-Lab/MathInstruct has 224567 examples. =====
['response', 'instruction', '__index_level_0__']


In [6]:
data_concated = concatenate_datasets(processed_data)

# 构造base数据集

In [7]:
import numpy as np
import random
random.seed(10)
sampled_indices = random.sample(range(len(processed_data[0])), 1000)
sampled_data = processed_data[0].select(sampled_indices)
sampled_set = set(sampled_indices)
base_set = set(range(len(data_concated)))

# 计算差集，即在 idx_set 中但不在 sampled_set 中的元素
remaining_idx = list(base_set - sampled_set)
print(len(remaining_idx))
data_concated = data_concated.select(remaining_idx)

406318


# 将base数据集随机拆成十份

In [8]:
sampled_data = sampled_data.shuffle(seed=42)  
local_datasets = []
for i in range(10):
    local_datasets.append(sampled_data.shard(10, i))

In [10]:
print(len(local_datasets[0]))

100


## 将公共数据集也随机拆成10份

In [13]:
data_concated = data_concated.shuffle(seed=42)
public_datasets = []
for i in range(10):
    public_datasets.append(data_concated.shard(10,i))

In [None]:
print(len(public_datasets[8]))
print(len(public_datasets[9]))

40631


# 对每一个客户端的数据集进行检索

In [11]:
from FlagEmbedding import FlagModel
model = FlagModel('BAAI/bge-large-en-v1.5', 
                  query_instruction_for_retrieval="",
                  use_fp16=True)

----------using 8*GPUs----------


In [12]:
client_pos_datasets, client_neg_datasets = [], []

In [21]:
sampled_embeddings = model.encode(sampled_data["instruction"])
from sklearn.cluster import KMeans
# 假设 embeddings 是你的嵌入数据
k = 10
kmeans = KMeans(n_clusters=k, random_state=0).fit(sampled_embeddings)
print(kmeans)
from tqdm import tqdm
concated_embeddings = model.encode(public_datasets[9]["instruction"])
import torch
clusters = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
concated_embeddings = torch.tensor(concated_embeddings, dtype=torch.float32)
print(concated_embeddings.shape)
concated_embeddings = concated_embeddings[:len(public_datasets[9]["instruction"])]
print(len(public_datasets[9]["instruction"]))
import torch
import heapq
similarity_scores = clusters @ concated_embeddings.T
top_idxs = []
bot_idxs = []
for i in range(similarity_scores.shape[0]):
    tmp = similarity_scores[i]
    top_idxs.append(heapq.nlargest(500, range(len(tmp)), key=tmp.__getitem__))
    bot_idxs.append(heapq.nsmallest(500, range(len(tmp)), key=tmp.__getitem__))
pos_datasets = []
neg_datasets = []
for i in range(len(top_idxs)):
    pos_datasets.append(public_datasets[9].select(top_idxs[i]))
    neg_datasets.append(public_datasets[9].select(bot_idxs[i]))
pos_datasets = concatenate_datasets([concatenate_datasets(pos_datasets), sampled_data])
neg_datasets = concatenate_datasets([concatenate_datasets(neg_datasets), sampled_data])

KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:21<00:00,  1.05s/it]


torch.Size([40631, 1024])
40631


In [27]:
for i, sampled_data in enumerate(local_datasets):
    print(i)
    sampled_embeddings = model.encode(sampled_data["instruction"])
    from sklearn.cluster import KMeans
    # 假设 embeddings 是你的嵌入数据
    k = 10
    kmeans = KMeans(n_clusters=k, random_state=0).fit(sampled_embeddings)
    print(kmeans)
    from tqdm import tqdm
    concated_embeddings = model.encode(public_datasets[i]["instruction"])
    import torch
    clusters = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
    concated_embeddings = torch.tensor(concated_embeddings, dtype=torch.float32)
    import torch
    import heapq
    similarity_scores = clusters @ concated_embeddings.T
    top_idxs = []
    bot_idxs = []
    for i in range(similarity_scores.shape[0]):
        tmp = similarity_scores[i]
        top_idxs.append(heapq.nlargest(500, range(len(tmp)), key=tmp.__getitem__))
        bot_idxs.append(heapq.nsmallest(500, range(len(tmp)), key=tmp.__getitem__))
    pos_datasets = []
    neg_datasets = []
    for i in range(len(top_idxs)):
        instruction_length = len(public_datasets[i]["instruction"])
        filtered_top_idxs = list(filter(lambda idx: idx < instruction_length, top_idxs[i]))
        filtered_bot_idxs = list(filter(lambda idx: idx < instruction_length, bot_idxs[i]))
        pos_datasets.append(public_datasets[i].select(filtered_top_idxs))
        neg_datasets.append(public_datasets[i].select(filtered_bot_idxs))
    pos_datasets = concatenate_datasets([concatenate_datasets(pos_datasets), sampled_data])
    neg_datasets = concatenate_datasets([concatenate_datasets(neg_datasets), sampled_data])
    client_pos_datasets.append(pos_datasets)
    client_neg_datasets.append(neg_datasets)

0
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:20<00:00,  1.02s/it]


helo
helo
helo
helo
helo
helo
helo
helo
helo
helo
1
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:21<00:00,  1.07s/it]


helo
helo
helo
helo
helo
helo
helo
helo
helo
helo
2
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:20<00:00,  1.05s/it]


helo
helo
helo
helo
helo
helo
helo
helo
helo
helo
3
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:22<00:00,  1.11s/it]


helo
helo
helo
helo
helo
helo
helo
helo
helo
helo
4
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:20<00:00,  1.04s/it]


helo
helo
helo
helo
helo
helo
helo
helo
helo
helo
5
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:21<00:00,  1.06s/it]


helo
helo
helo
helo
helo
helo
helo
helo
helo
helo
6
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:20<00:00,  1.01s/it]


helo
helo
helo
helo
helo
helo
helo
helo
helo
helo
7
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:21<00:00,  1.06s/it]


helo
helo
helo
helo
helo
helo
helo
helo
helo
helo
8
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:21<00:00,  1.07s/it]


helo
helo
helo
helo
helo
helo
helo
helo
helo
helo
9
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:20<00:00,  1.04s/it]


helo
helo
helo
helo
helo
helo
helo
helo
helo
helo
