In [1]:
from FlagEmbedding import FlagModel
import numpy as np
from sklearn.manifold import TSNE
from scipy.spatial.distance import pdist, squareform
import matplotlib.pyplot as plt
from pprint import pprint as pp
import time
import umap
import os
import random
import time
from contextlib import contextmanager
import torch
from sentence_transformers import SentenceTransformer

@contextmanager
def timer():
    start_time = time.time()
    try:
        yield
    finally:
        end_time = time.time()
        print(f"Elapsed time: {end_time - start_time:.4f} seconds")

In [10]:
model_s = FlagModel('BAAI/bge-large-en-v1.5', 
                  query_instruction_for_retrieval="",
                  use_fp16=True,
                )
model_c = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', model_kwargs={"torch_dtype":torch.float16})

----------using 8*GPUs----------


In [18]:
from datasets import load_dataset, load_from_disk
from datasets import load_dataset, concatenate_datasets, load_from_disk
import pandas as pd
import datasets
from datasets import Dataset
from pprint import pprint as pp
from datasets import Dataset
from sklearn.cluster import KMeans
from tqdm import tqdm
import torch
import heapq
code_data = load_dataset("sahil2801/CodeAlpaca-20k")["train"]
fin_data = load_dataset("FinGPT/fingpt-sentiment-train")["train"]
med_data = load_dataset("medalpaca/medical_meadow_medical_flashcards")["train"]
general_data = load_dataset("tatsu-lab/alpaca")["train"]
math_data = load_dataset("TIGER-Lab/MathInstruct")["train"]

def alpaca_format(example):
    if example['input'] == "":
        example["instruction"] = example["instruction"]
    else:
        example["instruction"] = example["instruction"] + " " + example['input']
    example["response"] = example['output']
    return example

def process_sft_dataset(dataset_name, dataset, dataset_sample=None)->datasets.Dataset:
    if dataset_name in ["lucasmccabe-lmi/CodeAlpaca-20k", "yahma/alpaca-cleaned", "FinGPT/fingpt-sentiment-train"]:
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output'], desc=f"Preprocessing {dataset_name} for unified format.")
    elif dataset_name in ["WizardLM/WizardLM_evol_instruct_70k"]:
        dataset = dataset.rename_column("output", "response")
    elif dataset_name in ["tatsu-lab/alpaca", "vicgalle/alpaca-gpt4", "gbharti/finance-alpaca"]:
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output', 'text'], desc=f"Preprocessing {dataset_name} for unified format.")
    elif dataset_name in ["TIGER-Lab/MathInstruct"]:
        df = pd.DataFrame(dataset)
        df = df.drop_duplicates(subset=['instruction'])
        dataset = datasets.Dataset.from_pandas(df)
        # dataset = dataset.shuffle(seed=42).select(range(51000))
        dataset = dataset.rename_column("output", "response")
        dataset = dataset.remove_columns(['source'])
    elif dataset_name in ["lighteval/MATH"]:
        dataset = dataset.rename_column("solution", "response")
        dataset = dataset.rename_column("problem", "instruction")
        dataset = dataset.remove_columns(['level', 'type'])
    elif dataset_name in ['gsm8k']:
        dataset = dataset.rename_column("question", "instruction")
        dataset = dataset.rename_column("answer", "response")
    elif dataset_name in ['medalpaca/medical_meadow_medical_flashcards']:       # TODO: 'lavita/ChatDoctor-HealthCareMagic-100k'. not sure whether to discard the instruction.
        dataset = dataset.remove_columns(['instruction'])
        dataset = dataset.rename_column("input", "instruction")
        dataset = dataset.rename_column("output", "response")
    elif "math" in dataset_name:
        dataset = dataset.remove_columns(['source'])
        dataset = dataset.rename_column("output", "response")
    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not supported.")
    dataset = dataset.shuffle(seed=42)
    if dataset_sample:
        num_sample = min(len(dataset), dataset_sample)
        dataset = dataset.select(range(num_sample))
    print(f">> ===== After processing, Dataset {dataset_name} has {len(dataset)} examples. =====")
    return dataset

processed_data = []
# 这块一定要注意!!! name 和datasest的顺序都要改
for name, dataset in zip(["TIGER-Lab/MathInstruct","FinGPT/fingpt-sentiment-train","medalpaca/medical_meadow_medical_flashcards","lucasmccabe-lmi/CodeAlpaca-20k","tatsu-lab/alpaca",],[math_data,fin_data,med_data,code_data,general_data]):
# for name, dataset in zip(["lucasmccabe-lmi/CodeAlpaca-20k","FinGPT/fingpt-sentiment-train","medalpaca/medical_meadow_medical_flashcards", "TIGER-Lab/MathInstruct"],[code_data,fin_data,med_data,math_data]):
    tmp:datasets.Dataset = process_sft_dataset(name,dataset)
    # if "fin" in name: 
    #     tmp = tmp.shuffle(seed=42).select(range(51000))
    print(tmp.column_names)
    processed_data.append(tmp)

>> ===== After processing, Dataset TIGER-Lab/MathInstruct has 224567 examples. =====
['response', 'instruction', '__index_level_0__']
>> ===== After processing, Dataset FinGPT/fingpt-sentiment-train has 76772 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset medalpaca/medical_meadow_medical_flashcards has 33955 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset lucasmccabe-lmi/CodeAlpaca-20k has 20022 examples. =====
['instruction', 'response']
>> ===== After processing, Dataset tatsu-lab/alpaca has 52002 examples. =====
['instruction', 'response']


# 训练projector

In [4]:
data_concated = concatenate_datasets(processed_data)["instruction"]

In [5]:
embeddings_s = model_s.encode(data_concated)

Inference Embeddings: 100%|██████████| 796/796 [11:08<00:00,  1.19it/s]


In [6]:
pool = model_c.start_multi_process_pool()
embeddings_c = torch.tensor(model_c.encode_multi_process(data_concated,pool,precision='float32'))
model_c.stop_multi_process_pool(pool)

In [14]:
print(embeddings_c[0].shape)
embeddings_s = torch.Tensor(embeddings_s)
print(embeddings_s[0].shape)

torch.Size([384])
torch.Size([1024])


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Projector(nn.Module):
    def __init__(self):
        super(Projector, self).__init__()
        self.fc1 = nn.Linear(384, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 1024)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

In [23]:
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, anchor, positive, negatives) -> torch.Tensor:
        anchor_pos_similarity = (anchor * positive).sum(dim=1) / self.temperature
        anchor_neg_similarity = (anchor.unsqueeze(1) * negatives).sum(dim=2) / self.temperature

        logits = torch.cat([anchor_pos_similarity.unsqueeze(1), anchor_neg_similarity], dim=1)
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)

        loss = nn.functional.cross_entropy(logits, labels)
        return loss

In [15]:
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, embeddings_c, embeddings_s):
        """
        初始化数据集。
        :param embeddings_c: 384维的embeddings。
        :param embeddings_s: 1024维的embeddings，与embeddings_c一一对应。
        """
        self.embeddings_c = embeddings_c
        self.embeddings_s = embeddings_s

    def __len__(self):
        """
        返回数据集中样本的数量。
        """
        return len(self.embeddings_c)

    def __getitem__(self, idx):
        """
        根据索引idx获取一个样本。
        """
        # 获取对应的embeddings_c和embeddings_s
        embeddings_c_sample = self.embeddings_c[idx]
        embeddings_s_sample = self.embeddings_s[idx]

        return embeddings_c_sample, embeddings_s_sample

In [16]:
# 创建数据集实例
dataset = CustomDataset(embeddings_c, embeddings_s)
# 创建DataLoader实例
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

In [25]:
# 实例化投影器和损失函数
projector = Projector().cuda()
criterion = ContrastiveLoss(temperature=0.5)
num_epochs = 3
# 选择优化器
optimizer = torch.optim.Adam(projector.parameters(), lr=1e-4)

# 假设 dataloader 产生 (embeddings_c_batch, embeddings_s_batch) 形式的数据
for epoch in range(num_epochs):
    tqdm_dataloader = tqdm(enumerate(dataloader), desc=f'Epoch {epoch+1}/{num_epochs}', total=len(dataloader))
    for batch_idx, (embeddings_c_batch, embeddings_s_batch) in tqdm_dataloader:
        # 将 embeddings_c 投影到更高维度的空间
        embeddings_c_batch, embeddings_s_batch = embeddings_c_batch.cuda(), embeddings_s_batch.cuda()
        projected_c_batch = projector(embeddings_c_batch)
        # 初始化总损失
        total_loss = 0
        # 计算每个样本的损失并累加
        for i in range(len(embeddings_c_batch)):
            # 取出第 i 个样本的正样本嵌入
            positive = embeddings_s_batch[i]
            # 取出第 i 个样本的负样本嵌入，这里我们取批次中除了自身之外的其他样本
            negatives = torch.stack([embeddings_s_batch[j] for j in range(len(embeddings_c_batch)) if j != i])
            # 计算损失
            loss = criterion(projected_c_batch[i].unsqueeze(0), positive.unsqueeze(0), negatives)
            # 累加损失
            total_loss += loss
        # 计算批次的平均损失
        batch_loss = total_loss / len(embeddings_c_batch)
        tqdm_dataloader.set_description(f'Batch loss: {batch_loss:.4f}')
        # 反向传播和优化
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

Batch loss: 0.0251: 100%|██████████| 12729/12729 [03:51<00:00, 54.96it/s]
Batch loss: 0.0185: 100%|██████████| 12729/12729 [03:49<00:00, 55.56it/s]
Batch loss: 0.0015: 100%|██████████| 12729/12729 [03:26<00:00, 61.53it/s]


In [26]:
torch.save(projector.state_dict(), '/mnt/bn/data-tns-live-llm/leon/datasets/projector_model.pth')

# 开始检索

In [15]:
projector = Projector().cuda()
projector.load_state_dict(torch.load('/mnt/bn/data-tns-live-llm/leon/datasets/projector_model.pth'))

  projector.load_state_dict(torch.load('/mnt/bn/data-tns-live-llm/leon/datasets/projector_model.pth'))


<All keys matched successfully>

In [19]:
data_concated: Dataset = processed_data[0]
random.seed(42)
iid_idxs = random.sample(range(len(data_concated)), 1000)
base_data = data_concated.select(iid_idxs)
clients_data = []
for i in range(10):
    clients_data.append(base_data.shard(10,i))

data_concated = data_concated.select(list(set(range(len(data_concated)))-set(iid_idxs)))
print(len(data_concated))

223567


In [20]:
k=10
from sklearn.cluster import MiniBatchKMeans, KMeans
base_0_embeddings = model_c.encode(clients_data[0]["instruction"])
# 假设 embeddings 是你的嵌入数据
kmeans = KMeans(n_clusters=k, random_state=0).fit(base_0_embeddings)
labels = kmeans.labels_
# 计算每个簇的样本数量
counts = np.bincount(labels)
# 找到最大的簇的标签
largest_cluster_label = np.argmax(counts)
# 从 cluster_centers_ 中获取最大的簇的中心
cluster_center_0:np.array = kmeans.cluster_centers_[largest_cluster_label]
print(cluster_center_0.shape)
client_clusters = cluster_center_0.reshape((1,-1))

(384,)


In [21]:
for i in range(10-1):
    i=i+1
    base_i_embeddings = model_c.encode(clients_data[i]["instruction"])
    # 假设 embeddings 是你的嵌入数据
    kmeans = KMeans(n_clusters=10, random_state=0).fit(base_i_embeddings)
    labels = kmeans.labels_
    similarity_scores = np.sum(kmeans.cluster_centers_ @ client_clusters.T, axis=-1)
    print(similarity_scores.shape)
    selected_idxs = np.argsort(similarity_scores)[i:]      
    # 计算每个簇的样本数量
    counts = np.bincount(labels)
    # 找到最大的簇的标签
    largest_cluster_labels = np.argsort(-counts) #降序
    largest_cluster_label = -1
    for j in largest_cluster_labels:
        if j in selected_idxs:
            largest_cluster_label = j
    # 从 cluster_centers_ 中获取最大的簇的中心
    largest_cluster_center = kmeans.cluster_centers_[largest_cluster_label]
    client_clusters = np.concatenate([client_clusters,largest_cluster_center.reshape((1,-1))])

(10,)
(10,)
(10,)
(10,)
(10,)
(10,)
(10,)
(10,)
(10,)


In [22]:
data_concated: Dataset = concatenate_datasets(processed_data)
data_concated = data_concated.select(list(set(range(len(data_concated)))-set(iid_idxs)))
print(len(data_concated))
concated_embeddings = model_s.encode(data_concated["instruction"])
concated_embeddings = torch.tensor(concated_embeddings, dtype=torch.float32)
client_clusters = torch.tensor(client_clusters, dtype=torch.float32)

406318


  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
Inference Embeddings: 100%|██████████| 199/199 [06:40<00:00,  2.01s/it]


In [23]:
client_clusters = projector(client_clusters.cuda())
client_clusters.shape

  client_clusters = torch.tensor(client_clusters, dtype=torch.float32)


torch.Size([10, 1024])

In [24]:
import numpy as np
import random
retrival_nums = [5000]
domain = "math"
for retrival_num in retrival_nums:
    client_pos_datasets = []
    for i, sampled_data in enumerate(clients_data):
        print(i)
        similarity_scores = torch.matmul(client_clusters[i,:].cuda(), (concated_embeddings.T).cuda()).cpu()
        # filter
        filtered_scores = [(score.item(), idx) for idx, score in enumerate(similarity_scores) if score < 0.7]
        top_idxs = heapq.nlargest(retrival_num, range(len(filtered_scores)-1), key=lambda x:filtered_scores[x])
        # no filter
        # top_idxs = heapq.nlargest(5000, range(len(similarity_scores)-1), key=lambda x: similarity_scores[x])
        pos_datasets: Dataset = []
        pos_datasets = data_concated.select(top_idxs)
        pos_datasets = concatenate_datasets([pos_datasets, sampled_data])
        pos_datasets = pos_datasets.shuffle(seed=42)
        client_pos_datasets.append(pos_datasets)
        
    for i, pos_data in enumerate(client_pos_datasets):
        pos_data.save_to_disk(f"/mnt/bn/data-tns-live-llm/leon/datasets/fed_data/iid2niid_{domain}_{retrival_num}_projector_{i}.parquet")

0
1
2
3
4
5
6
7
8
9


Saving the dataset (0/1 shards):   0%|          | 0/119 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/210 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/149 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1817 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/219 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/626 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/108 [00:00<?, ? examples/s]