In [97]:

import numpy as np
import itertools
import datasets
import pandas as pd
import os
from tqdm import tqdm

In [55]:
# Discover the base models: every sub-directory of ``data`` is treated as one
# model. "sentences" is excluded (presumably it holds the raw sentences rather
# than a model's embeddings -- TODO confirm). Hidden folders such as
# ``.ipynb_checkpoints`` and stray files are ignored so they are never passed
# to ``load_from_disk`` later. The result is sorted because ``os.listdir``
# returns entries in arbitrary order, which would make runs non-reproducible.
BASIC_MODELS = sorted(
    name
    for name in os.listdir("data")
    if name != "sentences"
    and not name.startswith(".")
    and os.path.isdir(os.path.join("data", name))
)

In [78]:
# Names of the STS tasks to process, in evaluation order.
# The yearly STS tasks share the "STS<yy>" naming scheme, so they are generated
# from their two-digit years. "SummEval" is currently disabled.
_STS_YEARS = (12, 13, 14, 15, 16, 17, 22)
TASK_LIST_STS = (
    ["SICK-R"]
    + [f"STS{year}" for year in _STS_YEARS]
    + ["STSBenchmark", "BIOSSES"]
)

In [88]:
def PCA2(data, dims_rescaled_data=2):
    """
    Project ``data`` onto its leading principal components.

    pass in: data as 2D NumPy array (rows are samples, columns are features).
             The array is NOT modified; centering is done on a copy.
             dims_rescaled_data: number of principal components to keep.
    returns: tuple ``(projected, evals, evecs)`` where
             projected -- the mean-centered data expressed in the kept
                          components, shape (n_samples, k)
             evals     -- ALL eigenvalues of the covariance matrix, sorted in
                          decreasing order, shape (n_features,)
             evecs     -- the kept eigenvectors as columns, shape (n_features, k)
             with k = min(dims_rescaled_data, n_features).
    raises:  ValueError if ``data`` is not 2-dimensional.
    """
    # scipy is not imported at the top of the file, so keep this import local.
    from scipy import linalg as LA

    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(f"PCA2 expects a 2D array, got {data.ndim} dimension(s)")

    # mean center the data on a fresh array: an in-place ``-=`` would silently
    # alter the caller's array and fails outright for integer input
    centered = data - data.mean(axis=0)
    # calculate the covariance matrix; atleast_2d covers the single-feature
    # case, where np.cov returns a 0-d array that eigh cannot handle
    R = np.atleast_2d(np.cov(centered, rowvar=False))
    # use 'eigh' rather than 'eig' since R is symmetric,
    # the performance gain is substantial
    evals, evecs = LA.eigh(R)
    # eigh returns eigenvalues in ascending order; reorder both the
    # eigenvalues and the matching eigenvector columns to descending
    order = np.argsort(evals)[::-1]
    evals = evals[order]
    # keep only the first dims_rescaled_data eigenvectors
    evecs = evecs[:, order][:, :dims_rescaled_data]
    # project the centered data onto the kept eigenvectors
    return centered @ evecs, evals, evecs

In [98]:
# Upper bound on how many base models may be stacked together.
max_stack_size = len(BASIC_MODELS)
# Number of principal components kept for every stacked embedding. Tasks with
# fewer rows than this are skipped (presumably because that few samples cannot
# support this many components -- TODO confirm the intended threshold).
PCA_DIMS = 1024

# Generate stacked models of all sizes (every combination of >= 2 base models),
# reduce each stacked embedding with PCA and save it under data_pca/.
ALL_COMBINATIONS = []
largest_stack = min(max_stack_size, len(BASIC_MODELS))
for r in tqdm(range(2, largest_stack + 1)):
    for combination in itertools.combinations(BASIC_MODELS, r):
        # Sort to ensure the same combination always gets the same name
        combination = sorted(combination)
        concat_model = "$".join(combination)  # e.g. "ANGLE$COHERE"
        saved_any_task = False

        for task in TASK_LIST_STS:
            task_embeddings = []
            for model in combination:
                model_dataset = datasets.load_from_disk(f"data/{model}/{task}")
                # give each model its own column name so the datasets can be
                # joined side by side without a name clash
                model_dataset = model_dataset.rename_column("embeddings", f"embeddings_{model}")
                task_embeddings.append(model_dataset)

            ds = datasets.concatenate_datasets(task_embeddings, axis=1)

            # check the size before doing any of the expensive conversion work
            if len(ds) < PCA_DIMS:
                continue

            df = ds.to_pandas()
            # one (n_rows, dim) matrix per model, joined along the feature axis
            data = np.concatenate(
                [np.stack(df[f"embeddings_{model}"].to_numpy()) for model in combination],
                axis=1,
            )

            # PCA
            new_data, _, _ = PCA2(data, dims_rescaled_data=PCA_DIMS)

            path = f"data_pca/{concat_model}/{task}"
            dataset = datasets.Dataset.from_dict({"embeddings": new_data})
            dataset.save_to_disk(path, max_shard_size="75MB")
            saved_any_task = True

        # Record the stacked model; previously ALL_COMBINATIONS was never
        # filled, so MODELS below silently contained only the base models.
        if saved_any_task:
            ALL_COMBINATIONS.append(concat_model)

MODELS = BASIC_MODELS + ALL_COMBINATIONS

  0%|          | 0/4 [00:00<?, ?it/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

 25%|██▌       | 1/4 [02:09<06:29, 129.93s/it]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

 50%|█████     | 2/4 [06:16<06:37, 198.58s/it]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

 75%|███████▌  | 3/4 [09:57<03:28, 208.88s/it]

Saving the dataset (0/3 shards):   0%|          | 0/19854 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/10684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2372 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/17256 [00:00<?, ? examples/s]

100%|██████████| 4/4 [11:06<00:00, 166.51s/it]
