# This notebook shows the computation of domain coverage

In [None]:
import os
import time
import torch
import random
import heapq
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from contextlib import contextmanager
from pprint import pprint as pp
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from scipy.spatial.distance import pdist, squareform
from datasets import load_dataset, concatenate_datasets, load_from_disk, Dataset
import pandas as pd
from FlagEmbedding import FlagModel
from sentence_transformers import SentenceTransformer
from sklearn.cluster import MiniBatchKMeans, KMeans

In [None]:
# Fetch the "train" split of every source corpus from the Hugging Face Hub.
# The generator is consumed left to right by the tuple unpacking, so the
# corpora are loaded in the same order as the names on the left-hand side.
code_data, fin_data, med_data, general_data, math_data = (
    load_dataset(hub_id)["train"]
    for hub_id in (
        "sahil2801/CodeAlpaca-20k",
        "FinGPT/fingpt-sentiment-train",
        "medalpaca/medical_meadow_medical_flashcards",
        "tatsu-lab/alpaca",
        "TIGER-Lab/MathInstruct",
    )
)

def alpaca_format(example):
    """Fold an Alpaca-style record into ``instruction`` / ``response`` fields.

    If ``example['input']`` is not the empty string, it is appended to
    ``example['instruction']`` after a single space. ``example['output']`` is
    then copied to ``example['response']``.

    The mapping is modified in place and also returned; the ``input`` and
    ``output`` keys themselves are left untouched.
    """
    extra_context = example['input']
    if extra_context != "":
        example["instruction"] = example["instruction"] + " " + extra_context
    example["response"] = example['output']
    return example

def process_sft_dataset(dataset_name, dataset, dataset_sample=None) -> Dataset:
    """Convert a supported SFT dataset to ``instruction`` / ``response`` columns.

    ``dataset_name`` selects the conversion rule: depending on the name, columns
    are renamed, dropped, or rewritten through ``alpaca_format``. The rules are
    tried in order, with a final catch-all for any name containing ``"math"``.
    Afterwards the rows are shuffled with a fixed seed (42) and, when
    ``dataset_sample`` is truthy, cut down to at most that many rows.

    Args:
        dataset_name: Identifier used to pick the conversion rule.
        dataset: The dataset to convert.
        dataset_sample: Optional cap on the number of rows kept after
            shuffling. A falsy value (``None`` or ``0``) keeps every row.

    Returns:
        The converted, shuffled and possibly truncated dataset.

    Raises:
        NotImplementedError: If no conversion rule matches ``dataset_name``.
    """
    map_desc = f"Preprocessing {dataset_name} for unified format."

    if dataset_name in ("lucasmccabe-lmi/CodeAlpaca-20k", "yahma/alpaca-cleaned", "FinGPT/fingpt-sentiment-train"):
        # Alpaca layout without a pre-rendered 'text' column.
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output'], desc=map_desc)
    elif dataset_name in ("WizardLM/WizardLM_evol_instruct_70k",):
        dataset = dataset.rename_column("output", "response")
    elif dataset_name in ("tatsu-lab/alpaca", "vicgalle/alpaca-gpt4", "gbharti/finance-alpaca"):
        # Alpaca layout that additionally carries a 'text' column to drop.
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output', 'text'], desc=map_desc)
    elif dataset_name in ("TIGER-Lab/MathInstruct",):
        # Round-trip through pandas to drop rows whose instruction repeats an
        # earlier one (drop_duplicates keeps the first occurrence by default).
        unique_rows = pd.DataFrame(dataset).drop_duplicates(subset=['instruction'])
        dataset = Dataset.from_pandas(unique_rows)
        dataset = dataset.rename_column("output", "response").remove_columns(['source'])
    elif dataset_name in ("lighteval/MATH",):
        dataset = dataset.rename_column("solution", "response")
        dataset = dataset.rename_column("problem", "instruction")
        dataset = dataset.remove_columns(['level', 'type'])
    elif dataset_name in ("gsm8k",):
        dataset = dataset.rename_column("question", "instruction")
        dataset = dataset.rename_column("answer", "response")
    elif dataset_name in ("medalpaca/medical_meadow_medical_flashcards",):
        # The original 'instruction' column is discarded and the 'input'
        # column takes its place.
        dataset = dataset.remove_columns(['instruction'])
        dataset = dataset.rename_column("input", "instruction")
        dataset = dataset.rename_column("output", "response")
    elif "math" in dataset_name:
        # Case-sensitive catch-all for other names containing "math".
        dataset = dataset.remove_columns(['source']).rename_column("output", "response")
    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not supported.")

    dataset = dataset.shuffle(seed=42)
    if dataset_sample:
        keep = min(len(dataset), dataset_sample)
        dataset = dataset.select(range(keep))

    print(f">> ===== After processing, Dataset {dataset_name} has {len(dataset)} examples. =====")
    return dataset

# (format key, raw dataset) pairs. The key only selects the conversion rule in
# process_sft_dataset. The order of this list matters: the datasets are
# concatenated in exactly this order, and the per-domain embedding slices taken
# further down depend on it.
named_sources = [
    ("lucasmccabe-lmi/CodeAlpaca-20k", code_data),
    ("FinGPT/fingpt-sentiment-train", fin_data),
    ("medalpaca/medical_meadow_medical_flashcards", med_data),
    ("tatsu-lab/alpaca", general_data),
    ("TIGER-Lab/MathInstruct", math_data),
]

processed_data = []
for name, dataset in named_sources:
    processed_data.append(process_sft_dataset(name, dataset))

public_data = concatenate_datasets(processed_data)

In [None]:
# Sentence-embedding model used for every `model.encode(...)` call below.
# The retrieval query instruction is set to the empty string, and fp16 is
# requested via `use_fp16=True`.
# NOTE(review): the coverage score is a raw dot product, so it presumably
# relies on this model returning unit-normalised embeddings - confirm.
model = FlagModel('BAAI/bge-large-en-v1.5', 
                  query_instruction_for_retrieval="",
                  use_fp16=True)

In [None]:
import numpy as np
import torch
from torch.nn.functional import cosine_similarity as cosine_similarity

def coverage(A, V):
    """Return the average best-match similarity of ``V`` against ``A``.

    Both inputs are converted to float32 tensors. For every row ``v`` of ``V``
    the largest dot product with any row of ``A`` is taken, and those maxima
    are averaged over the rows of ``V``.

    Note that the similarity is a plain dot product; it equals cosine
    similarity only when the rows of both inputs are unit-normalised.

    Args:
        A: 2-D array-like of embeddings whose coverage is measured.
        V: 2-D array-like of reference embeddings, with the same number of
            columns as ``A``.

    Returns:
        The mean of the per-row maxima, as a Python float.
    """
    candidates = torch.tensor(A, dtype=torch.float32)
    references = torch.tensor(V, dtype=torch.float32)

    # Rows index V, columns index A; reduce over A to get each reference's
    # closest candidate.
    best_match = (references @ candidates.T).max(dim=1).values

    return best_match.sum().item() / len(best_match)

In [None]:
# Embed the "instruction" column of the merged corpus.
# NOTE(review): the per-domain slicing that follows assumes one embedding per
# row, in the same order as `public_data` - confirm for the encoder in use.
public_data_embeddings = model.encode(public_data["instruction"])

In [None]:
# Split the merged embedding matrix back into per-domain slices.
#
# `public_data` is the concatenation of `processed_data` in the order
# code, fin, med, general, math, so every domain occupies one contiguous row
# range. The boundaries are derived from the processed dataset sizes rather
# than hard-coded (they were 20022 / 96794 / 130749 / 182751), so the slices
# stay aligned if an upstream dataset changes size.
code_end = len(processed_data[0])
fin_end = code_end + len(processed_data[1])
med_end = fin_end + len(processed_data[2])
general_end = med_end + len(processed_data[3])
math_end = general_end + len(processed_data[4])

# Guard against stale embeddings computed from a different `public_data`.
if len(public_data_embeddings) != math_end:
    raise ValueError(
        f"Expected {math_end} embeddings (one per row of public_data), "
        f"got {len(public_data_embeddings)}."
    )

code_embeddings = public_data_embeddings[:code_end]
med_embeddings = public_data_embeddings[fin_end:med_end]
fin_embeddings = public_data_embeddings[code_end:fin_end]
math_embeddings = public_data_embeddings[general_end:]

In [None]:
# Experiment settings to evaluate. Every setting is read from 10 client shards
# located at "<root>/<setting>_<i>.parquet" for i in 0..9.
settings = ["code_5000"]
root = ""
# Label value that marks in-domain rows in the client data.
# NOTE(review): this is fixed, while the reference embeddings below are chosen
# from the setting name - keep the two in sync when adding non-code settings.
domain = "code"

# Coverage score per setting, kept so that earlier results are not lost when
# `domain_coverage` is overwritten by the next iteration.
coverage_by_setting = {}

for setting in tqdm(settings):
    client_shards = [load_from_disk(f"{root}/{setting}_{i}.parquet") for i in range(10)]
    cross_client_datas = concatenate_datasets(client_shards)
    cross_client_datas = cross_client_datas.filter(lambda example: example['label'] == domain) # Filter out-of-domain data.
    datas_embeddings = model.encode(cross_client_datas["instruction"])

    # Pick the reference embeddings by keyword in the setting name; any name
    # without "code", "med" or "fin" falls back to the math embeddings.
    if "code" in setting:
        domain_embeddings = code_embeddings
    elif "med" in setting:
        domain_embeddings = med_embeddings
    elif "fin" in setting:
        domain_embeddings = fin_embeddings
    else:
        domain_embeddings = math_embeddings

    domain_coverage = coverage(datas_embeddings, domain_embeddings)

    # Record and report the score; previously it was computed and discarded.
    coverage_by_setting[setting] = domain_coverage
    print(f"{setting}: domain coverage = {domain_coverage:.4f}")