# This notebook shows the training process of the projector for FedDCA*

In [1]:
import os
import time
import torch
import random
import heapq
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from contextlib import contextmanager
from pprint import pprint as pp
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from scipy.spatial.distance import pdist, squareform
from datasets import load_dataset, concatenate_datasets, load_from_disk, Dataset
import pandas as pd
from FlagEmbedding import FlagModel
from sentence_transformers import SentenceTransformer
from sklearn.cluster import MiniBatchKMeans, KMeans

In [10]:
# Encoder that produces `embeddings_s`, the targets the projector is trained to match.
model_s = FlagModel(
    'BAAI/bge-large-en-v1.5',
    use_fp16=True,
    query_instruction_for_retrieval="",
)
# Encoder that produces `embeddings_c`, the inputs fed to the projector.
model_c = SentenceTransformer(
    'sentence-transformers/all-MiniLM-L6-v2',
    model_kwargs={"torch_dtype": torch.float16},
)

----------using 8*GPUs----------


In [None]:
# Fetch the "train" split of every public instruction dataset, in a fixed order:
# code, finance, medical, general-purpose and math.
code_data, fin_data, med_data, general_data, math_data = (
    load_dataset(repo_id)["train"]
    for repo_id in (
        "sahil2801/CodeAlpaca-20k",
        "FinGPT/fingpt-sentiment-train",
        "medalpaca/medical_meadow_medical_flashcards",
        "tatsu-lab/alpaca",
        "TIGER-Lab/MathInstruct",
    )
)

def alpaca_format(example):
    """Convert one Alpaca-style record to the unified instruction/response layout.

    A non-empty ``input`` is appended to ``instruction`` after a single space;
    an empty ``input`` leaves ``instruction`` as it is. ``output`` is copied to
    the new ``response`` key. The record is modified in place and returned.
    """
    context = example['input']
    instruction = example["instruction"]
    example["instruction"] = instruction if context == "" else instruction + " " + context
    example["response"] = example['output']
    return example

def process_sft_dataset(dataset_name, dataset, dataset_sample=None) -> Dataset:
    """Bring a raw instruction dataset into the unified instruction/response layout.

    The branch is chosen by ``dataset_name``; each branch renames or drops the
    columns of that dataset so the text columns end up as ``instruction`` and
    ``response``. The result is shuffled with a fixed seed and optionally
    truncated.

    :param dataset_name: key that selects the preprocessing branch.
    :param dataset: the raw split to preprocess.
    :param dataset_sample: if truthy, keep at most this many examples after shuffling.
    :return: the processed (and possibly truncated) dataset.
    :raises NotImplementedError: if ``dataset_name`` matches no branch.
    """
    if dataset_name in ["lucasmccabe-lmi/CodeAlpaca-20k", "yahma/alpaca-cleaned", "FinGPT/fingpt-sentiment-train"]:
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output'], desc=f"Preprocessing {dataset_name} for unified format.")
    elif dataset_name in ["WizardLM/WizardLM_evol_instruct_70k"]:
        dataset = dataset.rename_column("output", "response")
    elif dataset_name in ["tatsu-lab/alpaca", "vicgalle/alpaca-gpt4", "gbharti/finance-alpaca"]:
        # These datasets carry an additional pre-rendered 'text' column that is dropped as well.
        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output', 'text'], desc=f"Preprocessing {dataset_name} for unified format.")
    elif dataset_name in ["TIGER-Lab/MathInstruct"]:
        # De-duplicate on the instruction text via pandas.
        df = pd.DataFrame(dataset)
        df = df.drop_duplicates(subset=['instruction'])
        # Dropping rows leaves a non-contiguous index; without preserve_index=False
        # it would be stored as an extra `__index_level_0__` column that none of the
        # other processed datasets have.
        dataset = Dataset.from_pandas(df, preserve_index=False)
        dataset = dataset.rename_column("output", "response")
        dataset = dataset.remove_columns(['source'])
    elif dataset_name in ["lighteval/MATH"]:
        dataset = dataset.rename_column("solution", "response")
        dataset = dataset.rename_column("problem", "instruction")
        dataset = dataset.remove_columns(['level', 'type'])
    elif dataset_name in ['gsm8k']:
        dataset = dataset.rename_column("question", "instruction")
        dataset = dataset.rename_column("answer", "response")
    elif dataset_name in ['medalpaca/medical_meadow_medical_flashcards']:
        # The original 'instruction' column is discarded and 'input' takes its place.
        dataset = dataset.remove_columns(['instruction'])
        dataset = dataset.rename_column("input", "instruction")
        dataset = dataset.rename_column("output", "response")
    elif "math" in dataset_name:
        # Catch-all for any other dataset whose name contains "math".
        dataset = dataset.remove_columns(['source'])
        dataset = dataset.rename_column("output", "response")
    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not supported.")
    # Fixed seed so the (optional) subsample below is reproducible.
    dataset = dataset.shuffle(seed=42)
    if dataset_sample:
        num_sample = min(len(dataset), dataset_sample)
        dataset = dataset.select(range(num_sample))
    print(f">> ===== After processing, Dataset {dataset_name} has {len(dataset)} examples. =====")
    return dataset

# Run every raw split through process_sft_dataset under the key that selects its
# preprocessing branch. Note that the code split loaded from
# "sahil2801/CodeAlpaca-20k" is processed under the "lucasmccabe-lmi/CodeAlpaca-20k" key.
processed_data = []
for name, dataset in [
    ("lucasmccabe-lmi/CodeAlpaca-20k", code_data),
    ("FinGPT/fingpt-sentiment-train", fin_data),
    ("medalpaca/medical_meadow_medical_flashcards", med_data),
    ("tatsu-lab/alpaca", general_data),
    ("TIGER-Lab/MathInstruct", math_data),
]:
    processed_data.append(process_sft_dataset(name, dataset))

# Train the projector

In [4]:
# Instruction strings of all processed public datasets, concatenated in order.
public_data = concatenate_datasets(processed_data)["instruction"]

train_data_size = 10000  # Number of public instructions randomly selected for training.
# Indexing a dataset by column name does not return a `Dataset`, so there is no
# `.select` to call on `public_data`; sample row indices and gather the strings
# instead. The sample size is capped at the data size so `sample` cannot raise,
# and a locally seeded RNG keeps the selection reproducible without touching the
# global `random` state.
sample_size = min(train_data_size, len(public_data))
sample_indices = random.Random(42).sample(range(len(public_data)), sample_size)
train_data = [public_data[i] for i in sample_indices]

## Construct the training set for contrastive learning

In [None]:
# Encode the sampled instructions with model_s and wrap the result in a torch tensor.
embeddings_s = torch.Tensor(model_s.encode(train_data))

In [6]:
# Encode the sampled instructions with model_c using a multi-process pool.
pool = model_c.start_multi_process_pool()
try:
    embeddings_c = torch.tensor(model_c.encode_multi_process(train_data, pool, precision='float32'))
finally:
    # Always shut the worker processes down, even if encoding fails part-way.
    model_c.stop_multi_process_pool(pool)

## The model structure of the projector

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Projector(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) that maps one embedding space to another.

    The defaults (384 -> 128 -> 1024) reproduce the original fixed architecture;
    the sizes are parameters so the same class can be reused with other encoders.
    Layer attribute names are unchanged, so existing state dicts still load.

    :param input_dim: size of the incoming embeddings.
    :param hidden_dim: size of the hidden layer.
    :param output_dim: size of the projected embeddings.
    """

    def __init__(self, input_dim=384, hidden_dim=128, output_dim=1024):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Project ``x`` (last dimension ``input_dim``) to ``output_dim`` features."""
        return self.fc2(self.relu(self.fc1(x)))

In [23]:
class ContrastiveLoss(nn.Module):
    """Cross-entropy over temperature-scaled dot products, with the positive as class 0."""

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, anchor, positive, negatives) -> torch.Tensor:
        """Return the mean cross-entropy of picking the positive among the candidates.

        ``anchor`` and ``positive`` are multiplied element-wise and summed over
        dim 1; ``negatives`` is broadcast against ``anchor.unsqueeze(1)`` and
        summed over dim 2. Both scores are divided by the temperature, the
        positive score is placed in column 0, and the target class is 0 for
        every row.
        """
        pos_scores = torch.sum(anchor * positive, dim=1) / self.temperature
        neg_scores = torch.sum(anchor.unsqueeze(1) * negatives, dim=2) / self.temperature

        # Column 0 holds the positive score, the remaining columns the negatives.
        scores = torch.cat((pos_scores.unsqueeze(1), neg_scores), dim=1)
        targets = scores.new_zeros(scores.size(0), dtype=torch.long)

        return nn.functional.cross_entropy(scores, targets)

In [15]:
# Import the torch base class under an alias: a plain `Dataset` import here would
# rebind the Hugging Face `datasets.Dataset` name that process_sft_dataset relies
# on, breaking that cell whenever it is re-run after this one.
from torch.utils.data import Dataset as TorchDataset, DataLoader

class CustomDataset(TorchDataset):
    """Index-aligned pairs of embeddings for training the projector."""

    def __init__(self, embeddings_c, embeddings_s):
        """
        Initialize the dataset
        :param embeddings_c: embeddings returned first for each index
        :param embeddings_s: embeddings returned second for each index
        :raises ValueError: if the two collections differ in length
        """
        if len(embeddings_c) != len(embeddings_s):
            # A mismatch would silently misalign pairs or fail later with an IndexError.
            raise ValueError(
                f"embeddings_c and embeddings_s must have the same length, "
                f"got {len(embeddings_c)} and {len(embeddings_s)}."
            )
        self.embeddings_c = embeddings_c
        self.embeddings_s = embeddings_s

    def __len__(self):
        return len(self.embeddings_c)

    def __getitem__(self, idx):
        # Return the (embeddings_c, embeddings_s) pair stored at the same index.
        return self.embeddings_c[idx], self.embeddings_s[idx]

In [16]:
# Pair the two embedding tensors and serve them in shuffled mini-batches of 32,
# loaded by 4 worker processes.
dataset = CustomDataset(embeddings_c, embeddings_s)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Fall back to the CPU when no GPU is available instead of failing on .cuda().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
projector = Projector().to(device)
criterion = ContrastiveLoss(temperature=0.5)
num_epochs = 3
optimizer = torch.optim.Adam(projector.parameters(), lr=1e-4)

# The dataloader yields index-aligned (embeddings_c_batch, embeddings_s_batch) pairs.
for epoch in range(num_epochs):
    tqdm_dataloader = tqdm(enumerate(dataloader), desc=f'Epoch {epoch+1}/{num_epochs}', total=len(dataloader))
    for batch_idx, (embeddings_c_batch, embeddings_s_batch) in tqdm_dataloader:
        embeddings_c_batch = embeddings_c_batch.to(device)
        embeddings_s_batch = embeddings_s_batch.to(device)
        # Project embeddings_c into the higher-dimensional space of embeddings_s.
        projected_c_batch = projector(embeddings_c_batch)

        # For sample i the positive is embeddings_s_batch[i] and the negatives are
        # all other rows of embeddings_s_batch. Build every negative set at once:
        # row i of `negatives` holds the rows j != i in ascending order.
        batch_size, embed_dim = embeddings_s_batch.shape
        off_diagonal = ~torch.eye(batch_size, dtype=torch.bool, device=embeddings_s_batch.device)
        negatives = (
            embeddings_s_batch.unsqueeze(0)
            .expand(batch_size, batch_size, embed_dim)[off_diagonal]
            .reshape(batch_size, batch_size - 1, embed_dim)
        )
        # The criterion averages over the batch, which equals the mean of the
        # per-sample losses, without a Python loop over samples.
        batch_loss = criterion(projected_c_batch, embeddings_s_batch, negatives)
        # Show the loss next to the epoch description instead of replacing it.
        tqdm_dataloader.set_postfix(loss=f'{batch_loss.item():.4f}')

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

In [26]:
# An empty path makes torch.save fail, so default to a file in the working
# directory; change this to wherever the checkpoint should be stored.
output_path = "projector.pth"
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
torch.save(projector.state_dict(), output_path) # Save the trained model