# Embedding Model with Instructions

## Datasets

In [71]:
def display_dataset_info(dataset):
    """Print a short plain-text summary of a dataset's metadata.

    Reads ``dataset.info`` and prints, in order: the dataset name, one line
    per split with its ``num_examples``, and one line per feature.

    Args:
        dataset: Object exposing an ``info`` attribute that carries
            ``dataset_name``, ``splits`` and ``features``; the latter two
            must support ``.items()``.

    Returns:
        None. The summary is written to standard output.
    """
    meta = dataset.info
    # All three attributes are read before anything is printed.
    name, splits, columns = meta.dataset_name, meta.splits, meta.features

    print(f"Dataset Name: {name}")

    print("Splits Info:")
    for split, details in splits.items():
        print(f" - Split: {split}, Num Examples: {details.num_examples}")

    print("Features:")
    for column, spec in columns.items():
        print(f" - {column}: {spec}")

In [72]:
from datasets import load_dataset

def get_dataset(dataset_name, sample_size=0, seed=42):
    """Load a dataset and return its train and test splits, optionally sub-sampled.

    Args:
        dataset_name: Identifier passed unchanged to ``load_dataset``.
        sample_size: Maximum number of examples to keep from each split.
            ``0`` (the default) or ``None`` keeps the full splits. A value
            larger than a split is clamped to that split's length, so a
            small test split no longer makes the call fail.
        seed: Seed passed to ``shuffle`` before selecting the sample, so
            repeated calls pick the same examples.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.

    Raises:
        ValueError: If ``sample_size`` is negative.
    """
    # Fail before the (potentially slow) load when the request is invalid.
    if sample_size is not None and sample_size < 0:
        raise ValueError(f"sample_size must be >= 0, got {sample_size}")

    dataset = load_dataset(dataset_name)
    train_dataset = dataset['train']
    test_dataset = dataset['test']

    if sample_size:
        # Each split is clamped independently because train and test
        # may contain different numbers of rows.
        train_count = min(sample_size, len(train_dataset))
        test_count = min(sample_size, len(test_dataset))
        train_dataset = train_dataset.shuffle(seed=seed).select(range(train_count))
        test_dataset = test_dataset.shuffle(seed=seed).select(range(test_count))

    return train_dataset, test_dataset

## Text Embedding Models

In [73]:
from tqdm.auto import tqdm
from transformers.pipelines.pt_utils import KeyDataset
import numpy as np

def process_dataset(model, dataset, key="text", truncation=True, padding=True, max_length=512, use_mean_pooling=True, label_key="label"):
    """Embed one column of ``dataset`` with ``model`` and return embeddings and labels.

    Args:
        model: Callable invoked as ``model(data, return_tensors=True, ...)``
            that yields one tensor per example (in this notebook, a
            ``feature-extraction`` pipeline).
        dataset: Dataset indexable by column name.
        key: Column holding the text to embed.
        truncation, padding, max_length: Forwarded unchanged to ``model``.
            NOTE(review): confirm the pipeline honours ``padding`` and
            ``max_length`` when given as plain call keywords rather than
            through its tokenizer-kwargs argument.
        use_mean_pooling: When true, average each tensor over ``dim=1`` and
            flatten it to a 1-D vector. When false, tensors are kept as
            produced and must all share one shape to be stackable.
        label_key: Column holding the labels (defaults to ``"label"``).

    Returns:
        ``(embeddings, labels)`` as NumPy arrays. ``embeddings`` is an empty
        array when the model yields nothing.
    """
    data = KeyDataset(dataset, key)
    pipe = model(data, return_tensors=True, truncation=truncation, padding=padding, max_length=max_length)
    embeddings = []
    for tensor in tqdm(pipe, desc="Encoding:"):
        if use_mean_pooling:
            # assumes each output is (1, tokens, hidden) so dim=1 is the
            # token axis — TODO confirm against the pipeline in use
            tensor = tensor.mean(dim=1).flatten()
        # Convert explicitly: implicit conversion via np.array() fails for
        # tensors that track gradients or live on an accelerator.
        embeddings.append(tensor.detach().cpu().numpy())
    stacked = np.stack(embeddings) if embeddings else np.array([])
    return stacked, np.array(dataset[label_key])

### Augment the dataset with instructions

In [74]:
def mapper(dataset, prefix, suffix, key='text'):
    """Return a copy of one example with ``prefix``/``suffix`` wrapped around a text field.

    Despite the parameter name, ``dataset`` is a single example (a mapping
    from column name to value), not a whole dataset.

    Args:
        dataset: Mapping that contains ``key`` with a string value.
        prefix: String placed before the text.
        suffix: String placed after the text.
        key: Name of the field to rewrite (defaults to ``'text'``).

    Returns:
        A new ``dict`` holding every field of the input, with ``key``
        replaced by ``prefix + text + suffix``. The input is not modified.
    """
    # Shallow copy so the caller's example is left untouched.
    augmented = dict(dataset)
    augmented[key] = prefix + dataset[key] + suffix
    return augmented

def augment_dataset(dataset, prefix, suffix):
    """Apply ``mapper`` with the given prefix and suffix to every example.

    Args:
        dataset: Object exposing ``map(function)``.
        prefix: String forwarded to ``mapper`` for each example.
        suffix: String forwarded to ``mapper`` for each example.

    Returns:
        Whatever ``dataset.map`` returns.
    """
    def _add_instruction(example):
        return mapper(example, prefix, suffix)

    return dataset.map(_add_instruction)

### Train classifiers to evaluate Embedding performance

In [75]:
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import classification_report

def evaluate(method, train_embeddings, test_embeddings, train_labels, test_labels):
    """Fit a classifier on the train embeddings and print a report for the test set.

    Args:
        method: ``"SVM"`` for a linear-kernel ``SVC`` or ``"MLP"`` for an
            ``MLPClassifier`` with a single hidden layer of 100 units.
        train_embeddings: Feature matrix used for fitting.
        test_embeddings: Feature matrix used for prediction.
        train_labels: Targets matching ``train_embeddings``.
        test_labels: Ground-truth targets matching ``test_embeddings``.

    Returns:
        None. The classification report is printed.

    Raises:
        ValueError: If ``method`` is neither ``"SVM"`` nor ``"MLP"``.
    """
    if method == "SVM":
        model = SVC(kernel='linear')
    elif method == "MLP":
        model = MLPClassifier(hidden_layer_sizes=(100,), max_iter=300, alpha=1e-4,
                              solver='sgd', verbose=1, random_state=1,
                              learning_rate_init=.1)
    else:
        # Without this branch `model` stays unbound and the failure only
        # surfaces later as an UnboundLocalError at `model.fit`.
        raise ValueError(
            f"Unknown evaluation method {method!r}; expected 'SVM' or 'MLP'"
        )

    model.fit(train_embeddings, train_labels)
    predicted_labels = model.predict(test_embeddings)
    print("Report on " + method + ": ")
    print(classification_report(y_true=test_labels, y_pred=predicted_labels))

## The Text Embedding Pipeline

In [76]:
from transformers import pipeline
import warnings
def EmbedFlow(model_name, dataset_name, sample_size, evaluator, prefix, suffix, device=0):
    """Run the full experiment: load data, add instructions, embed, and evaluate.

    Steps, in order:
      1. ``get_dataset`` loads (and optionally sub-samples) the splits.
      2. A ``feature-extraction`` pipeline is created for ``model_name``.
      3. ``augment_dataset`` wraps the text of BOTH splits with
         ``prefix``/``suffix`` so train and test are embedded from
         identically formatted input.
      4. ``process_dataset`` embeds each split.
      5. ``evaluate`` fits the chosen classifier and prints its report.

    Args:
        model_name: Model identifier passed to ``pipeline``.
        dataset_name: Dataset identifier passed to ``get_dataset``.
        sample_size: Per-split sample size passed to ``get_dataset``.
        evaluator: Classifier name passed to ``evaluate``.
        prefix: Instruction text placed before each example.
        suffix: Instruction text placed after each example.
        device: Forwarded to ``pipeline`` as ``device`` (defaults to ``0``).
            NOTE(review): presumably ``0`` selects the first accelerator and
            a CPU-only machine needs another value — confirm against the
            transformers documentation.

    Returns:
        None.
    """
    # Limit the blanket suppression to this call; the previous filters are
    # restored on exit instead of staying silenced for the whole kernel.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")

        # Load Dataset
        train_dataset, test_dataset = get_dataset(dataset_name, sample_size)

        # Load Text Embedding Model
        model = pipeline("feature-extraction", model=model_name, device=device)

        # Add Instruction to both splits: a classifier trained on
        # instruction-prefixed text must be tested on the same format.
        train_dataset = augment_dataset(train_dataset, prefix, suffix)
        test_dataset = augment_dataset(test_dataset, prefix, suffix)

        # Embed Dataset
        train_embeddings, train_labels = process_dataset(model, train_dataset)
        test_embeddings, test_labels = process_dataset(model, test_dataset)

        # Evaluate
        evaluate(evaluator, train_embeddings, test_embeddings, train_labels, test_labels)

In [77]:
# Experiment grid. Each list holds the candidate values for one argument of
# EmbedFlow; the run below uses the first entry of every list.
# NOTE(review): the two empty model names look like placeholders for
# further models to compare — fill them in before indexing past 0.
models = [
    'google-bert/bert-base-uncased',
    '',
    '',
]
datasets = [
    'stanfordnlp/imdb',
    'yelp_review_full',
]
evaluator = ['SVM', 'MLP']
prefix = ['', 'Movie Review: ', 'Cat and Dog: ']
suffix = ['', '']

# Number of examples sampled from each split.
sample_size = 1000

EmbedFlow(
    model_name=models[0],
    dataset_name=datasets[0],
    sample_size=sample_size,
    evaluator=evaluator[0],
    prefix=prefix[0],
    suffix=suffix[0],
)

Encoding::   0%|          | 0/1000 [00:00<?, ?it/s]

Encoding::   0%|          | 0/1000 [00:00<?, ?it/s]

Report on SVM: 
              precision    recall  f1-score   support

           0       0.85      0.82      0.83       512
           1       0.82      0.84      0.83       488

    accuracy                           0.83      1000
   macro avg       0.83      0.83      0.83      1000
weighted avg       0.83      0.83      0.83      1000

