# Data pre-processing and domain similarity tests
- **This notebook uses the 'cola' sub-task from the GLUE benchmark for demonstration. The process for other tasks is essentially the same.**

**Pre-processing** is first performed on both the target-task samples and the candidate samples: we normalize every sample to a fixed length of 1000 characters so that differing padding patterns do not affect the analysis of distributional distances.

- For samples much shorter than 1000 characters (e.g., the training data for 'cola'), we concatenate multiple samples to reach 1000 characters; for samples much longer than 1000 characters (e.g., scientific papers), we split each original sample into multiple samples of 1000 characters.

- Then, we tokenize the processed samples with the BERT tokenizer and embed the tokens using DistilBERT fine-tuned on the target task.

For **domain similarity tests**, we randomly sample 20k samples from each of the 7 domains in the candidate data ['amz', 'wiki', 'news', 'pubmed', 'arxiv', 'book1', 'owtc']. We then tokenize and embed these samples in the same way.

- To analyze the domain similarity, we compute the OT distance between the embeddings of target task samples and the embeddings of samples from each domain.

- Then, we select the 2–4 domains with the smallest OT distances to the target-task data and use samples from these domains as the candidate data for further selection. Domains with large OT distances are discarded for this task.


In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np 
import pandas as pd 

import torch
print(torch.cuda.device_count())

import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from torch.utils.data import random_split, DataLoader

import matplotlib.pyplot as plt
%matplotlib inline


**Sample pre-processing**

In [None]:
from datasets import load_dataset

# Fetch the GLUE CoLA task; the result is indexed by split name (e.g. "train").
dataset = load_dataset(path="glue", name="cola")

# Show the training split (the `datasets` repr is a summary, not every sample).
print(dataset["train"])

In [None]:
# NumPy array holding every sentence of the CoLA training split.
cola_sp = np.array(dataset["train"]["sentence"])
np.shape(cola_sp)

In [None]:
import pickle
from copy import deepcopy
import numpy as np

# Greedily pack consecutive CoLA sentences into samples of exactly 1000 characters.
# Sentences are appended back-to-back with no separator. As soon as the running
# buffer reaches 1000 characters it is cut to its first 1000, the overflow is
# dropped, and packing restarts from an empty buffer. A final buffer shorter
# than 1000 characters is discarded.
# NOTE(review): with no separator, a sentence lacking terminal punctuation fuses
# its last word with the next sentence's first word — confirm this is intended.
len1k = []
current_text = ''
for text in cola_sp:
    current_text += text
    if len(current_text) < 1000:
        continue  # keep accumulating until the buffer is long enough
    len1k.append(current_text[:1000])
    current_text = ''

len(len1k)

In [None]:
# Alias (not a copy) of the packed 1000-character CoLA samples; the tokenization cell reads this name.
cola_1000 = len1k

Tokenize with the BERT tokenizer

In [None]:
from transformers import BertTokenizer

# Load the BERT tokenizer (uncased vocabulary; input is lower-cased).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokens_cola_1000 = []

from tqdm import tqdm
# cola_1000 is a plain list of 1000-character strings, so iterate it directly.
# np.concatenate cannot be used here: each str becomes a 0-d array, and
# zero-dimensional arrays cannot be concatenated.
for text in tqdm(cola_1000, desc="Processing sentences"):
    # [CLS]/[SEP] are added and every sample is padded/truncated to exactly 295
    # token ids, so the per-sample encodings can later be joined into one tensor.
    tokens = tokenizer.encode_plus(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=295, return_tensors='pt')
    # NOTE(review): tensors are moved to `device` before being stored, so the
    # pickle below holds CUDA tensors whenever a GPU is available.
    tokens = {key: value.to(device) for key, value in tokens.items()}
    tokens_cola_1000.append(tokens)

with open('tokens_cola_1000.pkl', 'wb') as file:
    pickle.dump(tokens_cola_1000, file, protocol=4)

Embed with DistilBERT fine-tuned on the target task

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# Fail early with a clear message instead of an opaque torch.cat error.
if not tokens_cola_1000:
    raise ValueError("tokens_cola_1000 is empty; run the tokenization cell first")

# Each item holds (1, 295) tensors from encode_plus. Concatenating along dim 0
# gives (N, 295); unlike torch.stack(...).squeeze(), this keeps the batch
# dimension even when N == 1.
input_ids = torch.cat([item['input_ids'] for item in tokens_cola_1000], dim=0)
attention_mask = torch.cat([item['attention_mask'] for item in tokens_cola_1000], dim=0)

dataset = TensorDataset(input_ids, attention_mask)
# shuffle=False keeps embedding rows in the same order as tokens_cola_1000.
dataloader = DataLoader(dataset, batch_size=512, shuffle=False)

from transformers import DistilBertConfig, DistilBertModel


# Load the fine-tuned DistilBERT: architecture from config.json, weights from the .bin checkpoint.
config = DistilBertConfig.from_json_file('./output10e5/config.json')
model = DistilBertModel.from_pretrained('./output10e5/pytorch_model.bin', config=config)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Split each batch across GPUs when more than one is visible.
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
model = model.to(device)
model.eval()

embeddings_list = []

from tqdm import tqdm

for batch_input_ids, batch_attention_mask in tqdm(dataloader, desc="Processing sentences"):
    with torch.no_grad():
        outputs = model(input_ids=batch_input_ids.to(device),
                        attention_mask=batch_attention_mask.to(device))

    # Mean-pool the final hidden states over the token axis: one vector per sample.
    # NOTE(review): the mean includes padded positions (the attention mask is not
    # applied to the pooling); presumably the fixed 1000-character inputs keep the
    # amount of padding comparable across samples.
    embeddings = outputs.last_hidden_state.cpu().detach().numpy()
    embeddings_list.append(np.mean(embeddings, axis=1))

# Join the per-batch results into a single (N, hidden_size) array.
embeddings_tensor = np.concatenate(embeddings_list, axis=0)

with open('embeds_cola_1000.pkl', 'wb') as file:
    pickle.dump(embeddings_tensor, file, protocol=4)
    

`rand7_20k` contains 20k samples from each of the 7 domains: ['amz', 'wiki', 'news', 'pubmed', 'arxiv', 'book1', 'owtc']. It must already be loaded before running the next cell.

Samples are processed in the same way as cola to have a fixed length of 1000 characters.

In [None]:
from transformers import BertTokenizer

# Tokenize the candidate samples with the same uncased BERT tokenizer as the CoLA samples.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from tqdm import tqdm

# NOTE(review): rand7_20k is read here but not created in the visible cells; the
# .tolist() call suggests a NumPy array of strings — confirm it is loaded first.
tokens_rand7_20k = []
for sample in tqdm(rand7_20k.tolist(), desc="Processing sentences"):
    # [CLS]/[SEP] added; padded/truncated to exactly 295 token ids.
    encoded = tokenizer.encode_plus(
        sample,
        add_special_tokens=True,
        padding='max_length',
        truncation=True,
        max_length=295,
        return_tensors='pt',
    )
    # Every tensor of the encoding is moved to `device` before it is stored.
    tokens_rand7_20k.append({field: tensor.to(device) for field, tensor in encoded.items()})

with open('tokens_rand7_20k.pkl', 'wb') as handle:
    pickle.dump(tokens_rand7_20k, handle, protocol=4)

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# Fail early with a clear message instead of an opaque torch.cat error.
if not tokens_rand7_20k:
    raise ValueError("tokens_rand7_20k is empty; run the tokenization cell first")

# Each item holds (1, 295) tensors from encode_plus. Concatenating along dim 0
# gives (N, 295); unlike torch.stack(...).squeeze(), this keeps the batch
# dimension even when N == 1.
input_ids = torch.cat([item['input_ids'] for item in tokens_rand7_20k], dim=0)
attention_mask = torch.cat([item['attention_mask'] for item in tokens_rand7_20k], dim=0)

dataset = TensorDataset(input_ids, attention_mask)
# shuffle=False keeps embedding rows in the same order as tokens_rand7_20k,
# which the OT step relies on to slice the rows by domain.
dataloader = DataLoader(dataset, batch_size=512, shuffle=False)

from transformers import DistilBertConfig, DistilBertModel


# Load the fine-tuned DistilBERT: architecture from config.json, weights from the .bin checkpoint.
config = DistilBertConfig.from_json_file('./output10e5/config.json')
model = DistilBertModel.from_pretrained('./output10e5/pytorch_model.bin', config=config)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Split each batch across GPUs when more than one is visible.
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
model = model.to(device)
model.eval()

embeddings_list = []

from tqdm import tqdm

for batch_input_ids, batch_attention_mask in tqdm(dataloader, desc="Processing sentences"):
    with torch.no_grad():
        outputs = model(input_ids=batch_input_ids.to(device),
                        attention_mask=batch_attention_mask.to(device))

    # Mean-pool the final hidden states over the token axis: one vector per sample.
    # NOTE(review): the mean includes padded positions (the attention mask is not
    # applied to the pooling); presumably the fixed 1000-character inputs keep the
    # amount of padding comparable across samples.
    embeddings = outputs.last_hidden_state.cpu().detach().numpy()
    embeddings_list.append(np.mean(embeddings, axis=1))

# Join the per-batch results into a single (N, hidden_size) array.
embeddings_tensor = np.concatenate(embeddings_list, axis=0)

with open('embeds_rand7_20k.pkl', 'wb') as file:
    pickle.dump(embeddings_tensor, file, protocol=4)
    

**domain similarity tests**

Names of the 7 candidate-data domains

In [None]:
# The 7 candidate-data domains; the OT loop labels the i-th 20k-row slice of the
# candidate embeddings with names[i], so this order matters.
names = 'amz wiki news pubmed arxiv book1 owtc'.split()

Optimal transport (OT) is computed on the GPU with JAX, using the `ott-jax` package.

In [None]:
import os

# Disable JAX's up-front GPU memory preallocation (presumably so JAX can share
# the GPU with the PyTorch model used above). The variable is set before
# `import jax` so it is guaranteed to be in the environment when the JAX
# backend initializes.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"

import jax

# Print the first device JAX sees (its default placement device unless configured
# otherwise). jax.devices() is used because the Array.device() call form is
# deprecated in newer JAX releases.
print(jax.devices()[0])

import matplotlib.pyplot as plt
import jax.numpy as jnp

import ott
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.tools import plot

In [None]:
import tqdm
from ott import utils
from ott.solvers.linear import acceleration

# Rows of embeds_rand7_20k are grouped by domain in consecutive slices of this
# size, in the same order as `names` (slice i is labelled names[i]).
samples_per_domain = 20000

# NOTE(review): embeds_task_1k and embeds_rand7_20k are not assigned in the
# visible cells (the embedding cells pickle to 'embeds_cola_1000.pkl' and
# 'embeds_rand7_20k.pkl'); presumably they are loaded from those files — confirm.
# The target-task point cloud is the same for every domain, so convert it once.
task_embeds = np.array(embeds_task_1k)

dists = []

# Iterate over `names` rather than a hard-coded range so the domain count and
# the labels can never drift apart.
for i, name in enumerate(names):
    start = i * samples_per_domain
    domain_embeds = np.array(embeds_rand7_20k[start:start + samples_per_domain])
    geom = pointcloud.PointCloud(task_embeds, domain_embeds, epsilon=5e-1)
    ot_prob = linear_problem.LinearProblem(geom)

    with tqdm.tqdm() as pbar:
        progress_fn = utils.tqdm_progress_fn(pbar)
        # Momentum-accelerated Sinkhorn; stops once the solver's convergence error
        # drops below 5e-2 or after 2000 iterations, whichever comes first.
        solver = sinkhorn.Sinkhorn(progress_fn=progress_fn, momentum=acceleration.Momentum(value=1.2), threshold=5e-2, inner_iterations=1, max_iterations=2000)
        ot_u = jax.jit(solver)(ot_prob)

    print(i, name, f"Converged: {ot_u.converged}, cost: {ot_u.reg_ot_cost}")

    # The regularized OT cost serves as the task-to-domain distance (smaller = closer).
    dists.append(ot_u.reg_ot_cost)


Based on the OT distances, select domains with the smallest OT distances to the target task data and use samples from these domains as the candidate data for further selection. Domains with large OT distances will be discarded for this task.