In [12]:
from datasets import load_dataset, concatenate_datasets, load_from_disk
import pandas as pd
import datasets
from datasets import Dataset

In [3]:
code_data = load_dataset("sahil2801/CodeAlpaca-20k")["train"]
fin_data = load_dataset("FinGPT/fingpt-sentiment-train")["train"]
med_data = load_dataset("medalpaca/medical_meadow_medical_flashcards")["train"]
general_data = load_dataset("tatsu-lab/alpaca")["train"]
math_data = load_dataset("TIGER-Lab/MathInstruct")["train"]

In [4]:
def alpaca_format(example):
    if example['input'] == "":
        example["instruction"] = example["instruction"]
    else:
        example["instruction"] = example["instruction"] + " " + example['input']
    example["response"] = example['output']
    return example

In [5]:
def process_sft_dataset(dataset_name, dataset, dataset_sample)->datasets.Dataset:
    if dataset_name in ["lucasmccabe-lmi/CodeAlpaca-20k", "yahma/alpaca-cleaned", "FinGPT/fingpt-sentiment-train"]:
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output'], desc=f"Preprocessing {dataset_name} for unified format.")
    elif dataset_name in ["WizardLM/WizardLM_evol_instruct_70k"]:
        dataset = dataset.rename_column("output", "response")
    elif dataset_name in ["tatsu-lab/alpaca", "vicgalle/alpaca-gpt4", "gbharti/finance-alpaca"]:
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output', 'text'], desc=f"Preprocessing {dataset_name} for unified format.")
    elif dataset_name in ["TIGER-Lab/MathInstruct"]:
        df = pd.DataFrame(dataset)
        df = df.drop_duplicates(subset=['instruction'])
        dataset = datasets.Dataset.from_pandas(df)
        dataset = dataset.rename_column("output", "response")
        dataset = dataset.remove_columns(['source'])
    elif dataset_name in ["lighteval/MATH"]:
        dataset = dataset.rename_column("solution", "response")
        dataset = dataset.rename_column("problem", "instruction")
        dataset = dataset.remove_columns(['level', 'type'])
    elif dataset_name in ['gsm8k']:
        dataset = dataset.rename_column("question", "instruction")
        dataset = dataset.rename_column("answer", "response")
    elif dataset_name in ['medalpaca/medical_meadow_medical_flashcards']:       # TODO: 'lavita/ChatDoctor-HealthCareMagic-100k'. not sure whether to discard the instruction.
        dataset = dataset.remove_columns(['instruction'])
        dataset = dataset.rename_column("input", "instruction")
        dataset = dataset.rename_column("output", "response")
    elif "math" in dataset_name:
        dataset = dataset.remove_columns(['source'])
        dataset = dataset.rename_column("output", "response")
    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not supported.")
    dataset = dataset.shuffle(seed=2023)
    if dataset_sample:
        num_sample = min(len(dataset), dataset_sample)
        dataset = dataset.select(range(num_sample))
    print(f">> ===== After processing, Dataset {dataset_name} has {len(dataset)} examples. =====")
    return dataset

In [6]:
processed_data = []
for name, dataset in zip(["lucasmccabe-lmi/CodeAlpaca-20k","FinGPT/fingpt-sentiment-train","medalpaca/medical_meadow_medical_flashcards","tatsu-lab/alpaca","TIGER-Lab/MathInstruct"],[code_data,fin_data,med_data,general_data,math_data]):
    tmp:datasets.Dataset = process_sft_dataset(name,dataset,None)
    print(tmp.column_names)
    processed_data.append(tmp)

>> ===== After processing, Dataset lucasmccabe-lmi/CodeAlpaca-20k has 20022 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset FinGPT/fingpt-sentiment-train has 76772 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset medalpaca/medical_meadow_medical_flashcards has 33955 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset tatsu-lab/alpaca has 52002 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset TIGER-Lab/MathInstruct has 224567 examples. =====
['response', 'instruction', '__index_level_0__']


In [7]:
data_concated = concatenate_datasets(processed_data)

# 构造base数据集

In [8]:
import numpy as np
import random
random.seed(10)
sampled_indices = random.sample(range(len(processed_data[0])), 1000)
sampled_data = processed_data[0].select(sampled_indices)
sampled_set = set(sampled_indices)
base_set = set(range(len(data_concated)))

# 计算差集，即在 idx_set 中但不在 sampled_set 中的元素
remaining_idx = list(base_set - sampled_set)
print(len(remaining_idx))
data_concated = data_concated.select(remaining_idx)

406318


# 将base数据集随机拆成十份

In [9]:
sampled_data = sampled_data.shuffle(seed=42)  
local_datasets = []
for i in range(10):
    local_datasets.append(sampled_data.shard(10, i))

In [10]:
print(len(local_datasets[0]))

100


## 将公共数据集也随机拆成10份

In [11]:
data_concated = data_concated.shuffle(seed=42)
public_datasets = []
for i in range(10):
    public_datasets.append(data_concated.shard(10,i))

# 构造随机采样数据集

In [13]:
client_random_datasets = []
dataset: Dataset
for dataset in public_datasets:
    idxs = random.sample(range(len(dataset)), 5000)
    client_random_datasets.append(dataset.select(idxs))
print(len(client_random_datasets[0]))

5000


In [16]:
for i, dataset in enumerate(client_random_datasets):
    dataset = concatenate_datasets([dataset,local_datasets[i]]).shuffle(seed=42)
    dataset.save_to_disk(f"/mnt/bn/data-tns-live-llm/leon/datasets/fed_data/random_with_base_{i}.parquet")

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

# 对每一个客户端的数据集进行检索，构造 pos 和 neg 数据集

In [17]:
from FlagEmbedding import FlagModel
model = FlagModel('BAAI/bge-large-en-v1.5', 
                  query_instruction_for_retrieval="",
                  use_fp16=True)

----------using 8*GPUs----------


In [13]:
client_pos_datasets, client_neg_datasets = [], []

In [None]:
from datasets import Dataset
for i, sampled_data in enumerate(local_datasets):
    print(i)
    sampled_embeddings = model.encode(sampled_data["instruction"])
    from sklearn.cluster import KMeans
    # 假设 embeddings 是你的嵌入数据
    k = 10
    kmeans = KMeans(n_clusters=k, random_state=0).fit(sampled_embeddings)
    print(kmeans)
    from tqdm import tqdm
    concated_embeddings = model.encode(public_datasets[i]["instruction"])
    import torch
    clusters = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
    concated_embeddings = torch.tensor(concated_embeddings, dtype=torch.float32)
    import torch
    import heapq
    similarity_scores = clusters @ concated_embeddings.T
    top_idxs = []
    bot_idxs = []
    for i in range(similarity_scores.shape[0]):
        tmp = similarity_scores[i]
        top_idxs.append(heapq.nlargest(500, range(len(tmp)), key=tmp.__getitem__))
        bot_idxs.append(heapq.nsmallest(500, range(len(tmp)), key=tmp.__getitem__))
    pos_datasets: Dataset = []
    neg_datasets: Dataset = []
    for i in range(len(top_idxs)):
        instruction_length = len(public_datasets[i]["instruction"])
        filtered_top_idxs = list(filter(lambda idx: idx < instruction_length, top_idxs[i]))
        filtered_bot_idxs = list(filter(lambda idx: idx < instruction_length, bot_idxs[i]))
        pos_datasets.append(public_datasets[i].select(filtered_top_idxs))
        neg_datasets.append(public_datasets[i].select(filtered_bot_idxs))
    pos_datasets = concatenate_datasets([concatenate_datasets(pos_datasets), sampled_data])
    neg_datasets = concatenate_datasets([concatenate_datasets(neg_datasets), sampled_data])
    pos_datasets = pos_datasets.shuffle(seed=42)
    neg_datasets = neg_datasets.shuffle(seed=42)
    client_pos_datasets.append(pos_datasets)
    client_neg_datasets.append(neg_datasets)

In [None]:
for i, (pos_data, neg_data) in enumerate(zip(client_pos_datasets, client_neg_datasets)):
    pos_data.save_to_disk(f"/mnt/bn/data-tns-live-llm/leon/datasets/fed_data/pos_{i}.parquet")
    neg_data.save_to_disk(f"/mnt/bn/data-tns-live-llm/leon/datasets/fed_data/neg_{i}.parquet")

# 构造pos+diversity 数据集，一半 pos，一半 diversity

In [None]:
from datasets import Dataset
import torch
import heapq
from tqdm import tqdm
from sklearn.cluster import KMeans

client_pos_datasets=[]
for i, sampled_data in enumerate(local_datasets):
    print(i)
    sampled_embeddings = model.encode(sampled_data["instruction"])
    # 假设 embeddings 是你的嵌入数据
    k = 10
    kmeans = KMeans(n_clusters=k, random_state=0).fit(sampled_embeddings)
    print(kmeans)
    concated_embeddings = model.encode(public_datasets[i]["instruction"])
    clusters = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
    concated_embeddings = torch.tensor(concated_embeddings, dtype=torch.float32)
    similarity_scores = clusters @ concated_embeddings.T
    top_idxs = []
    for i in range(similarity_scores.shape[0]):
        tmp = similarity_scores[i]
        top_idxs.append(heapq.nlargest(250, range(len(tmp)), key=tmp.__getitem__))
    pos_datasets: Dataset = []
    # top_idxs去重，其余作为 diversity
    top_idxs = set(np.concatenate(top_idxs,axis=0))
    try: top_idxs.remove(len(public_datasets[i]))
    except: pass
    pos_datasets = public_datasets[i].select(top_idxs)
    print(len(top_idxs))
    # 从public_datasets[i]中去掉 top_idxs
    all_idxs = set(range(len(public_datasets[i])))
    remain_idxs = list(all_idxs-top_idxs)
    random_idxs = random.sample(remain_idxs, 5000-len(top_idxs))
    diversity_datasets = public_datasets[i].select(random_idxs)
    pos_datasets = concatenate_datasets([pos_datasets, diversity_datasets, sampled_data])
    pos_datasets = pos_datasets.shuffle(seed=42)
    client_pos_datasets.append(pos_datasets)

In [None]:
for i, pos_data in enumerate(client_pos_datasets):
    pos_data.save_to_disk(f"/mnt/bn/data-tns-live-llm/leon/datasets/fed_data/T_{i}.parquet")

# 构造去重 pos 数据集

In [31]:
from ordered_set import OrderedSet
tmp = OrderedSet([4,5,3,7,1])
tmp1 = OrderedSet([4,5])
for t in tmp1: tmp.discard(t)
print(tmp)

OrderedSet([3, 7, 1])


In [33]:
from datasets import Dataset
import torch
import heapq
from tqdm import tqdm
from sklearn.cluster import KMeans
from ordered_set import OrderedSet

client_pos_datasets=[]
for i, sampled_data in enumerate(local_datasets):
    print(i)
    sampled_embeddings = model.encode(sampled_data["instruction"])
    # 假设 embeddings 是你的嵌入数据
    k = 10
    kmeans = KMeans(n_clusters=k, random_state=0).fit(sampled_embeddings)
    print(kmeans)
    concated_embeddings = model.encode(public_datasets[i]["instruction"])
    clusters = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
    concated_embeddings = torch.tensor(concated_embeddings, dtype=torch.float32)
    top_idxs:OrderedSet = OrderedSet()
    remain_idxs = OrderedSet(range(len(public_datasets[i])))
    for i in range(k):
        similarity_scores = clusters[i] @ concated_embeddings.T
        top_idx = list(OrderedSet(heapq.nlargest(5000, range(len(similarity_scores)), key=similarity_scores.__getitem__))-top_idxs)[:500]
        top_idxs.update(top_idx)
        print("top_idxs", len(top_idxs))
        remain_idxs.difference_update(top_idx)
        print("remain_idxs", len(remain_idxs))

    try: top_idxs.remove(len(public_datasets[i]))
    except: pass
    pos_datasets = public_datasets[i].select(list(top_idxs))
    pos_datasets = concatenate_datasets([pos_datasets, sampled_data])
    pos_datasets = pos_datasets.shuffle(seed=42)
    client_pos_datasets.append(pos_datasets)

0
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it]


top_idxs 500
remain_idxs 40132
top_idxs 1000
remain_idxs 39632
top_idxs 1500
remain_idxs 39132
top_idxs 2000
remain_idxs 38632
top_idxs 2500
remain_idxs 38132
top_idxs 3000
remain_idxs 37632
top_idxs 3500
remain_idxs 37132
top_idxs 4000
remain_idxs 36632
top_idxs 4500
remain_idxs 36132
top_idxs 5000
remain_idxs 35632
1
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it]


top_idxs 500
remain_idxs 40132
top_idxs 1000
remain_idxs 39632
top_idxs 1500
remain_idxs 39132
top_idxs 2000
remain_idxs 38632
top_idxs 2500
remain_idxs 38132
top_idxs 3000
remain_idxs 37632
top_idxs 3500
remain_idxs 37132
top_idxs 4000
remain_idxs 36632
top_idxs 4500
remain_idxs 36132
top_idxs 5000
remain_idxs 35632
2
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:26<00:00,  1.34s/it]


top_idxs 500
remain_idxs 40132
top_idxs 1000
remain_idxs 39632
top_idxs 1500
remain_idxs 39132
top_idxs 2000
remain_idxs 38632
top_idxs 2500
remain_idxs 38132
top_idxs 3000
remain_idxs 37632
top_idxs 3500
remain_idxs 37132
top_idxs 4000
remain_idxs 36632
top_idxs 4500
remain_idxs 36132
top_idxs 5000
remain_idxs 35632
3
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:29<00:00,  1.49s/it]


top_idxs 500
remain_idxs 40132
top_idxs 1000
remain_idxs 39632
top_idxs 1500
remain_idxs 39132
top_idxs 2000
remain_idxs 38632
top_idxs 2500
remain_idxs 38132
top_idxs 3000
remain_idxs 37632
top_idxs 3500
remain_idxs 37132
top_idxs 4000
remain_idxs 36632
top_idxs 4500
remain_idxs 36132
top_idxs 5000
remain_idxs 35632
4
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:29<00:00,  1.45s/it]


top_idxs 500
remain_idxs 40132
top_idxs 1000
remain_idxs 39632
top_idxs 1500
remain_idxs 39132
top_idxs 2000
remain_idxs 38632
top_idxs 2500
remain_idxs 38132
top_idxs 3000
remain_idxs 37632
top_idxs 3500
remain_idxs 37132
top_idxs 4000
remain_idxs 36632
top_idxs 4500
remain_idxs 36132
top_idxs 5000
remain_idxs 35632
5
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:29<00:00,  1.47s/it]


top_idxs 500
remain_idxs 40132
top_idxs 1000
remain_idxs 39632
top_idxs 1500
remain_idxs 39132
top_idxs 2000
remain_idxs 38632
top_idxs 2500
remain_idxs 38132
top_idxs 3000
remain_idxs 37632
top_idxs 3500
remain_idxs 37132
top_idxs 4000
remain_idxs 36632
top_idxs 4500
remain_idxs 36132
top_idxs 5000
remain_idxs 35632
6
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it]


top_idxs 500
remain_idxs 40132
top_idxs 1000
remain_idxs 39632
top_idxs 1500
remain_idxs 39132
top_idxs 2000
remain_idxs 38632
top_idxs 2500
remain_idxs 38132
top_idxs 3000
remain_idxs 37632
top_idxs 3500
remain_idxs 37132
top_idxs 4000
remain_idxs 36632
top_idxs 4500
remain_idxs 36132
top_idxs 5000
remain_idxs 35632
7
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:29<00:00,  1.49s/it]


top_idxs 500
remain_idxs 40132
top_idxs 1000
remain_idxs 39632
top_idxs 1500
remain_idxs 39132
top_idxs 2000
remain_idxs 38632
top_idxs 2500
remain_idxs 38132
top_idxs 3000
remain_idxs 37632
top_idxs 3500
remain_idxs 37132
top_idxs 4000
remain_idxs 36632
top_idxs 4500
remain_idxs 36132
top_idxs 5000
remain_idxs 35632
8
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:29<00:00,  1.48s/it]


top_idxs 500
remain_idxs 40131
top_idxs 1000
remain_idxs 39631
top_idxs 1500
remain_idxs 39131
top_idxs 2000
remain_idxs 38631
top_idxs 2500
remain_idxs 38131
top_idxs 3000
remain_idxs 37631
top_idxs 3500
remain_idxs 37131
top_idxs 4000
remain_idxs 36631
top_idxs 4500
remain_idxs 36131
top_idxs 5000
remain_idxs 35631
9
KMeans(n_clusters=10, random_state=0)


Inference Embeddings: 100%|██████████| 20/20 [00:29<00:00,  1.46s/it]


top_idxs 500
remain_idxs 40131
top_idxs 1000
remain_idxs 39631
top_idxs 1500
remain_idxs 39131
top_idxs 2000
remain_idxs 38631
top_idxs 2500
remain_idxs 38131
top_idxs 3000
remain_idxs 37631
top_idxs 3500
remain_idxs 37131
top_idxs 4000
remain_idxs 36631
top_idxs 4500
remain_idxs 36131
top_idxs 5000
remain_idxs 35631


In [34]:
for i, pos_data in enumerate(client_pos_datasets):
    pos_data.save_to_disk(f"/mnt/bn/data-tns-live-llm/leon/datasets/fed_data/pos_nodup_{i}.parquet")

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5099 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5099 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]