# Data selection based on OT gradients
- **This notebook uses the 'cola' (CoLA) sub-task from the GLUE benchmark for demonstration. The process for other tasks is essentially the same.**

In [None]:
# Restrict this process to GPU 0. The variable is assigned before `torch` is
# imported so that it is already in place when CUDA devices are enumerated.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np 
import pandas as pd 

import torch
# Sanity check: number of CUDA devices PyTorch can see.
print(torch.cuda.device_count())

# NOTE(review): the torch.nn / torchvision / DataLoader imports below are not
# referenced by any cell visible in this notebook — confirm whether they are needed.
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from torch.utils.data import random_split, DataLoader

import matplotlib.pyplot as plt
# Render matplotlib figures inline in the notebook output.
%matplotlib inline


Load samples from selected domains

In [None]:
import pickle


def _load_pickle(path):
    """Open `path` in binary mode and return the un-pickled object."""
    # NOTE: pickle.load executes arbitrary code; only use on trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Raw text samples, one pickle file per source domain.
amz_full = _load_pickle('./llm_datasets/amz_0to80.pkl')
news_full = _load_pickle('./llm_datasets/news_full_0to80.pkl')
realb_full = _load_pickle('./llm_datasets/real_books_2M.pkl')
owtc_full = _load_pickle('./llm_datasets/ready1k_owtc_3M.pkl')

Construct a candidate dataset of 5M samples.

In [None]:
import numpy as np

# Candidate pool: 1.5M + 1.0M + 1.5M + 1.0M = 5M samples, stacked in the
# order amz -> realb -> news -> owtc. Row i of `sp_4d` is candidate i.
sp_4d = np.concatenate(
    [
        amz_full[:1500000],
        realb_full[:1000000],
        news_full[:1500000],
        owtc_full[:1000000],
    ],
    axis=0,
)
sp_4d.shape

Samples are processed in the same way as detailed in the pre-processing notebook. Here we directly load the embeddings.

In [None]:
# Load the pre-computed embeddings, one pickle file per source domain.
# NOTE: pickle.load executes arbitrary code; only use on trusted files.
with open('./embeds_news_f.pkl', 'rb') as file:
    embeds_news = pickle.load(file)

with open('./embeds_owtc_f.pkl', 'rb') as file:
    embeds_owtc = pickle.load(file)

with open('./embeds_amz_f.pkl', 'rb') as file:
    embeds_amz = pickle.load(file)

with open('./embeds_book1.pkl', 'rb') as file:
    embeds_book1 = pickle.load(file)

# Row i of `embeds_4d` must be the embedding of sample `sp_4d[i]`, because the
# indices ranked on the embeddings are later used to index `sp_4d`. The slice
# sizes and domain order therefore mirror the `sp_4d` construction exactly:
# amz 1.5M -> books 1.0M -> news 1.5M -> owtc 1.0M.
# (Previously books/news used 1.5M/1.0M here, which shifted the books/news
# boundary and paired 500k book embeddings with news samples.)
# NOTE(review): assumes `embeds_book1` holds the embeddings of `realb_full` — confirm.
embeds_4d = np.concatenate(
    [
        embeds_amz[:1500000],
        embeds_book1[:1000000],
        embeds_news[:1500000],
        embeds_owtc[:1000000],
    ],
    axis=0,
)

# Guards against a domain having fewer embeddings than requested, which would
# silently shorten the pool and misalign every subsequent row.
assert len(embeds_4d) == len(sp_4d), (
    f"embeddings ({len(embeds_4d)}) and samples ({len(sp_4d)}) are misaligned"
)
embeds_4d.shape

Load `jax` for GPU computation of OT using the `ott-jax` package.

In [None]:
import os

# Disable JAX's up-front GPU memory preallocation. The variable is set before
# `jax` is imported so it is guaranteed to be in place when the backend starts.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"

import jax

# Show which device(s) JAX places arrays on (a GPU device is expected here).
# `Array.devices()` is used because the old `Array.device()` method was
# deprecated and later removed from JAX.
print(jax.numpy.ones(3).devices())

import matplotlib.pyplot as plt
import jax.numpy as jnp

import ott
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.tools import plot

In [None]:
# Name of the target task; used below to build the output file name for the OT dual potentials.
task = 'cola'

**Solve the OT problem between the embeddings of target task data and candidate data.** We use the 'batch-wise' method to deal with the high memory demand from the problem size. We use 'momentum acceleration' and 'entropy regularization' to speed up the solution process while maintaining its numerical stability. The parameters for these two techniques need to be tuned together to achieve optimal performance.

In [None]:
# Solve the entropy-regularised OT problem between the target-task embeddings
# and the 5M candidate embeddings, then save the dual potential `g`.
#
# NOTE(review): `cola_1000` is not defined in any cell visible in this notebook;
# presumably it holds the embeddings of the target-task samples and has to be
# loaded before this cell runs — confirm.
import tqdm
from ott import utils
from ott.solvers.linear import acceleration

# Batch size passed to the point-cloud geometry (the "batch-wise" method that
# keeps memory demand manageable for a problem of this size).
batch_size = 2000

# epsilon is the entropy-regularisation strength.
geom = pointcloud.PointCloud(np.array(cola_1000), np.array(embeds_4d), epsilon=1e-1, batch_size = batch_size)
ot_prob = linear_problem.LinearProblem(geom)

with tqdm.tqdm() as pbar:
    progress_fn = utils.tqdm_progress_fn(pbar)
    # Momentum acceleration (value=1.3) and epsilon must be tuned together;
    # threshold / max_iterations bound the Sinkhorn iterations.
    solver = sinkhorn.Sinkhorn(progress_fn=progress_fn, momentum=acceleration.Momentum(value=1.3), threshold = 1e-1, inner_iterations=1, max_iterations = 2000)
    # Alternative without momentum (needs far more iterations):
    # solver = sinkhorn.Sinkhorn(progress_fn=progress_fn, inner_iterations=1, max_iterations = 200000)
    ot_u = jax.jit(solver)(ot_prob)

print(f"Converged: {ot_u.converged}, cost: {ot_u.reg_ot_cost}")

# Regularised OT cost of the solution.
transp_cost = ot_u.reg_ot_cost
# NOTE(review): this bare expression is not the last statement of the cell, so
# the notebook does not display it.
transp_cost

# Persist the dual potential `g`; the following cells read it as `ot_u.g`,
# one value per candidate row.
with open('otg_' + task + '.pkl', 'wb') as file:
    pickle.dump(ot_u.g, file, protocol=4)

Compute the calibrated gradient based on [LAVA: Data Valuation without Pre-Specified Learning Algorithms, ICLR 2023].

In [None]:
# Calibrated gradient (LAVA, ICLR 2023): for candidate i with dual potential
# g_i, the value is  g_i - (sum_{j != i} g_j) / (N - 1),  i.e. the potential
# minus the mean of the *other* N - 1 potentials.
gsP = np.array(ot_u.g)
n_candidates = len(gsP)
if n_candidates < 2:
    # The leave-one-out mean is undefined for fewer than two candidates.
    raise ValueError("calibrated gradient needs at least two candidates")

total = np.sum(gsP)
# Global mean of the raw potentials, kept under its original name.
mean_all = total / n_candidates
# (total - gsP) is, per candidate, the sum over all other candidates.
# The previous version divided by N instead of N - 1; the resulting ranking is
# the same, but the values now match the cited formula.
gsP = gsP - (total - gsP) / (n_candidates - 1)
print(gsP)

Rank all the candidate samples based on the gradients.

In [None]:
import pandas as pd

# Rank the calibrated gradients in ascending order (smallest value -> rank 1),
# then argsort the ranks: the result lists candidate positions ordered from
# the smallest gradient to the largest.
g_series = pd.Series(gsP)
ascending_ranks = g_series.rank(ascending=True)
ranked_indices = ascending_ranks.argsort()
ranked_indices

Show a few samples and save the selected data. Then the process is complete.

In [None]:
# Preview the five candidate samples that come first in the ranking.
sp_4d[ranked_indices[0:5]]

In [None]:
# Save the top-ranked candidates at several selection budgets.
# The file name is built from `task` (consistent with the 'otg_<task>.pkl'
# dump above) instead of a hard-coded 'cola' prefix; for task = 'cola' this
# writes cola_ot_5m_150k.pkl, cola_ot_5m_300k.pkl and cola_ot_5m_500k.pkl.
for n_selected in (150000, 300000, 500000):
    out_path = f'{task}_ot_5m_{n_selected // 1000}k.pkl'
    with open(out_path, 'wb') as file:
        pickle.dump(sp_4d[ranked_indices[:n_selected]], file, protocol=4)