## This notebook is a simulation of implementing FedDCA in the financial domain, where client num is 10, each client's base data size is 100, and each client's retrieval num is 5000. I hope this notebook will help you gain a deeper understanding of FedDCA!

<p align="center">
  <img src="./overview.png" alt="" width="1000">
</p>

In [17]:
import os
import time
import torch
import random
import heapq
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from contextlib import contextmanager
from pprint import pprint as pp
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from scipy.spatial.distance import pdist, squareform
import datasets
from datasets import load_dataset, concatenate_datasets, load_from_disk, Dataset
import pandas as pd
from FlagEmbedding import FlagModel
from sentence_transformers import SentenceTransformer
from sklearn.cluster import MiniBatchKMeans, KMeans

# Load pretrained encoder

In [5]:
model = FlagModel('BAAI/bge-large-en-v1.5', 
                  query_instruction_for_retrieval="",
                  use_fp16=True,
                )
# model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', model_kwargs={"torch_dtype":torch.bfloat16})

# Composition of public datasets(code, medical, financial, mathematical, general)

In [None]:
code_data = load_dataset("sahil2801/CodeAlpaca-20k")["train"]
fin_data = load_dataset("FinGPT/fingpt-sentiment-train")["train"]
med_data = load_dataset("medalpaca/medical_meadow_medical_flashcards")["train"]
general_data = load_dataset("tatsu-lab/alpaca")["train"]
math_data = load_dataset("TIGER-Lab/MathInstruct")["train"]

def alpaca_format(example):
    if example['input'] == "":
        example["instruction"] = example["instruction"]
    else:
        example["instruction"] = example["instruction"] + " " + example['input']
    example["response"] = example['output']
    return example

def process_sft_dataset(dataset_name, dataset, dataset_sample=None)->datasets.Dataset:
    if dataset_name in ["lucasmccabe-lmi/CodeAlpaca-20k", "yahma/alpaca-cleaned", "FinGPT/fingpt-sentiment-train"]:
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output'], desc=f"Preprocessing {dataset_name} for unified format.")
    elif dataset_name in ["sujet-ai/Sujet-Finance-Instruct-177k"]:
        dataset = dataset.filter(lambda example: example['task_type'] == "qa")
        dataset = dataset.rename_column("inputs", "instruction")
        dataset = dataset.rename_column("answer", "response")
    elif dataset_name in ["WizardLM/WizardLM_evol_instruct_70k"]:
        dataset = dataset.rename_column("output", "response")
    elif dataset_name in ["tatsu-lab/alpaca", "vicgalle/alpaca-gpt4", "gbharti/finance-alpaca"]:
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output', 'text'], desc=f"Preprocessing {dataset_name} for unified format.")
    elif dataset_name in ["TIGER-Lab/MathInstruct"]:
        df = pd.DataFrame(dataset)
        df = df.drop_duplicates(subset=['instruction'])
        dataset = datasets.Dataset.from_pandas(df)
        dataset = dataset.rename_column("output", "response")
        dataset = dataset.remove_columns(['source'])
    elif dataset_name in ["lighteval/MATH"]:
        dataset = dataset.rename_column("solution", "response")
        dataset = dataset.rename_column("problem", "instruction")
        dataset = dataset.remove_columns(['level', 'type'])
    elif dataset_name in ['gsm8k']:
        dataset = dataset.rename_column("question", "instruction")
        dataset = dataset.rename_column("answer", "response")
    elif dataset_name in ['medalpaca/medical_meadow_medical_flashcards']:     
        dataset = dataset.remove_columns(['instruction'])
        dataset = dataset.rename_column("input", "instruction")
        dataset = dataset.rename_column("output", "response")
    elif "math" in dataset_name:
        dataset = dataset.remove_columns(['source'])
        dataset = dataset.rename_column("output", "response")
    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not supported.")
    dataset = dataset.shuffle(seed=42)
    if dataset_sample:
        num_sample = min(len(dataset), dataset_sample)
        dataset = dataset.select(range(num_sample))
    print(f">> ===== After processing, Dataset {dataset_name} has {len(dataset)} examples. =====")
    return dataset

In [None]:
processed_data = []
for name, dataset in zip(["FinGPT/fingpt-sentiment-train","lucasmccabe-lmi/CodeAlpaca-20k","TIGER-Lab/MathInstruct","medalpaca/medical_meadow_medical_flashcards","tatsu-lab/alpaca",],[fin_data,code_data,math_data,med_data,general_data]):
    tmp = process_sft_dataset(name, dataset)
    print(tmp.column_names)
    processed_data.append(tmp)

In [None]:
def labeling(example, label):
    example["label"] = label
    return example

label = ["fin","code","math","med","gen",] # Label each data's domain for the later domain coverage computation.
for i, data in enumerate(processed_data):
    data = data.map(lambda example: labeling(example, label[i]), batched=False)
    processed_data[i] = data

In [8]:
in_domain_data = processed_data[0] # The first processed_data is financial domain.

In [None]:
random.seed(42)
selected_idxs = random.sample(range(len(in_domain_data)), 1000)
base_data = in_domain_data.select(selected_idxs)

# Construct each client's base data.
clients_data = []
for i in range(10):
    clients_data.append(base_data.shard(10,i))

# Use the remaining in-domain data as part of the public data.
in_domain_data = in_domain_data.select(list(set(range(len(in_domain_data)))-set(selected_idxs)))
print(len(in_domain_data))

# Start greedy client center selection in FedDCA!

## Perform k-means clustering for client 0

In [None]:
k=10 

base_0_embeddings = model.encode(clients_data[0]["instruction"])

kmeans = KMeans(n_clusters=k, random_state=0).fit(base_0_embeddings)
labels = kmeans.labels_

# Computing each cluster's size.
counts = np.bincount(labels)

# Find the center of the biggest cluster.
largest_cluster_label = np.argmax(counts)
cluster_center_0 = kmeans.cluster_centers_[largest_cluster_label]
print(cluster_center_0.shape)
client_centers = cluster_center_0.reshape((1,-1))

## Perform greedy client center selection

<p align="center">
  <img src="./greedy_client_center_selection.png" alt="" width="1000">
</p>

In [None]:
for i in range(10-1):
    i=i+1
    base_i_embeddings = model.encode(clients_data[i]["instruction"])
    kmeans = KMeans(n_clusters=k, random_state=0).fit(base_i_embeddings)
    labels = kmeans.labels_
    similarity_scores = np.sum(kmeans.cluster_centers_ @ client_centers.T, axis=-1)
    selected_idxs = np.argsort(similarity_scores)[:10-i]      
    counts = np.bincount(labels)
    
    """
    We consider maximizing the domain coverage from two aspects: 
    1) Select a client center that can represent the distribution of the local data. 
    2) To optimize the cross-client domain coverage, we filter client centers that are close to the previously selected client centers.
    """
    # Sorting based on the size of the cluster.
    largest_cluster_labels = np.argsort(-counts) # Descending order.
    largest_cluster_label = -1
    for j in largest_cluster_labels:
        if j in selected_idxs:
            largest_cluster_label = j
            break
    
    selected_cluster_center = kmeans.cluster_centers_[largest_cluster_label]
    client_centers = np.concatenate([client_centers, selected_cluster_center.reshape((1,-1))])
client_centers = torch.tensor(client_centers, dtype=torch.float32)

# Construct the public data

In [14]:
public_data = concatenate_datasets(processed_data)
public_data = public_data.select(list(set(range(len(public_data)))-set(selected_idxs))) # Filter the selected data for local datasets.
print(len(public_data))

369347


In [None]:
public_embeddings = model.encode(public_data["instruction"])
public_embeddings = torch.tensor(public_embeddings, dtype=torch.float32)

# Perform dense retrieval

In [None]:
import numpy as np
import random
retrival_nums = [5000] # Retrieval num
domain = "fin"
root = "" # Config your own root path for save dataset.

for retrival_num in retrival_nums:
    client_datasets = []
    for i, sampled_data in enumerate(clients_data):
        similarity_scores = torch.matmul(client_centers[i,:].cuda(), (public_embeddings.T).cuda()).cpu()
        # Filter public data with client center similarity greater than or equal to 0.7
        filtered_scores = [(score.item(), idx) for idx, score in enumerate(similarity_scores) if score < 0.7]
        top_idxs = heapq.nlargest(retrival_num, range(len(filtered_scores)-1), key=lambda x:filtered_scores[x])
        clinet_data = public_data.select(top_idxs)
        clinet_data = concatenate_datasets([sampled_data, clinet_data])
        client_datasets.append(clinet_data)
        
    for i, clinet_data in enumerate(client_datasets): 
        clinet_data.save_to_disk(f"{root}/{domain}_{retrival_num}_{i}.parquet")

# Visualize similarity score distribution of dense retrieval

In [None]:
import numpy as np
import random
random.seed(10)

fig, axes = plt.subplots(5, 2, figsize=(15, 25)) 

axes = axes.flatten()  

for i, sampled_data in enumerate(clients_data):
    print(i)
    
    similarity_scores = torch.matmul(torch.tensor(client_centers[i,:]).cuda(), torch.tensor(public_embeddings.T).cuda()).cpu()
    import numpy as np
    import matplotlib.pyplot as plt

    # Assuming similarity_scores is a 1D tensor, convert it to a NumPy array if needed
    if hasattr(similarity_scores, 'numpy'):
        similarity_scores = similarity_scores.numpy()

    # Create a histogram with bins of size 0.1
    bins = np.arange(0, 1.1, 0.05)  # Bins from 0 to 1 with step of 0.1
    hist, bin_edges = np.histogram(similarity_scores, bins=bins)

    for j in range(len(hist)):
        print(f'Range [{bin_edges[j]}, {bin_edges[j + 1]}): {hist[j]}')

    
    ax = axes[i]
    ax.hist(similarity_scores, bins=bins, edgecolor='black')
    ax.set_xlabel('Similarity Score')
    ax.set_ylabel('Frequency')
    ax.set_title(f'Histogram of Similarity Scores {i + 1}')

plt.tight_layout()
plt.show()