In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer
import regex as re
from model2vec.distill.tokenizer import remove_tokens
import numpy as np
from tqdm import tqdm
import cvxpy as cvx
from model2vec import StaticModel

  from .autonotebook import tqdm as notebook_tqdm


## Toy Example: IMDB Classification

### Baselines
**Model2Vec**: 65%

**Potion**: 70%

In [2]:
# IMDB movie-review dataset (MTEB copy); later cells read the "text" and
# "label" columns of its "train" and "test" splits.
data = load_dataset(path="mteb/imdb")

In [3]:
class TokenizerWrapper:
    def __init__(self, tokenizer_name: str, potion_model_name: str, device="cpu"):
        self.tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name, trust_remote_code=True, device=device)
        full_vocab = [pair[0] for pair in sorted(self.tokenizer.get_vocab().items(), key=lambda x: x[1])]
        vocab = [x for x in full_vocab if not re.match("\[unused\d+\]", x)]
        self.vocab_size = len(vocab)
        self.tokenizer = remove_tokens(self.tokenizer.backend_tokenizer, set(full_vocab) - set(vocab))
        self.tokenizer.no_padding()

        model = StaticModel.from_pretrained(potion_model_name)
        self.projection = model.embedding

    def get_token_counts(self, texts: list[str] | str, progress=False):
        if type(texts) == str:
            texts = [texts]
        encoded = self.tokenizer.encode_batch([x.lower() for x in texts], add_special_tokens=False)
        token_counts = np.zeros((len(texts), self.vocab_size))
        
        iterator = enumerate(encoded)
        if progress:
            iterator = enumerate(tqdm(encoded))
        for (row, enc) in iterator:
            for id in enc.ids:
                token_counts[row, id] += 1
        return token_counts
    
    def get_token_count_projection(self, texts: list[str] | str, progress=False):
        counts = self.get_token_counts(texts, progress=progress)
        return counts @ self.projection


In [4]:
# Featurizer: BGE tokenizer vocabulary + POTION-base-2M static embedding table.
tokenizer = TokenizerWrapper(
    tokenizer_name="BAAI/bge-base-en-v1.5",
    potion_model_name="minishlab/potion-base-2M",
)

In [5]:
# Training features (projected token counts) and labels.
train_split = data["train"]
X = tokenizer.get_token_count_projection(train_split["text"], progress=True)
y = np.array(train_split["label"])

100%|██████████| 25000/25000 [00:04<00:00, 6060.57it/s]


In [6]:
# Test features (projected token counts) and labels.
test_split = data["test"]
Xtest = tokenizer.get_token_count_projection(test_split["text"], progress=True)
ytest = np.array(test_split["label"])

100%|██████████| 25000/25000 [00:09<00:00, 2723.84it/s]


In [7]:
# (number of training reviews, embedding dimension of the projection)
np.shape(X)

(25000, 64)

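### Convex two-layer ReLU network (mostly copied from Mert's MOSEK code)

The cells below solve the convex reformulation of a two-layer ReLU network from Section 3 of https://arxiv.org/pdf/2002.10553.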

In [8]:
def relu(x):
    """Rectified linear unit, applied element-wise.

    Returns ``max(0, x)`` for every entry of ``x`` (scalars or arrays).
    """
    floor = 0
    return np.maximum(floor, x)

def drelu(x):
    """ReLU activation indicator: True wherever ``x >= 0``.

    The boundary ``x == 0`` counts as active.
    """
    return 0 <= x

In [9]:
# NOTE: Mosek parameter dictionary from Mert
# params = {
#       "MSK_IPAR_NUM_THREADS": 8,
#       #"MSK_IPAR_INTPNT_MAX_ITERATIONS": 10,
#       #"MSK_IPAR_OPTIMIZER": 0 # auto 0, interior point 1, conic 2
#       #"MSK_DPAR_INTPNT_CO_TOL_REL_GAP": 1e-2
#       #"MSK_DPAR_INTPNT_TOL_PSAFE": 0.01
#       #"MSK_IPAR_OPTIMIZER": "free"
#       #"MSK_IPAR_INTPNT_SOLVE_FORM": 1
#       }

In [10]:
def solve_problem(
    X: np.ndarray, y: np.ndarray,
    Xtest: np.ndarray, ytest: np.ndarray,
    seed=0, hidden_dim=2000,
    weight_decay_strength=None,
    verbose=True,
    mosek_params=None,
):
    """Fit the convex reformulation of a two-layer ReLU network and report test accuracy.

    The network relu(X @ U1) @ U2 is replaced by the finite-dimensional convex
    program from Section 3 of https://arxiv.org/pdf/2002.10553. Activation
    patterns are estimated from ``hidden_dim`` random Gaussian directions. The
    program is solved with MOSEK once per regularization strength (warm-started).
    Predictions on ``Xtest`` are thresholded at 0.5 and compared to ``ytest``.

    Args:
        X: training features, shape (n, d).
        y: training targets of length n. Presumably 0/1 labels, since predictions
            are thresholded at 0.5.
        Xtest: test features, shape (ntest, d).
        ytest: test labels of length ntest.
        seed: seed for sampling the random directions.
        hidden_dim: number of random directions sampled. Duplicate activation
            patterns are dropped, so the effective width can be smaller.
        weight_decay_strength: iterable of regularization strengths (beta) to try.
            Defaults to [0, 1, 0.5, 0.1, 0.01, 0.001, 0.0001].
        verbose: forwarded to ``problem.solve`` (controls solver log output).
        mosek_params: MOSEK parameter dict. Defaults to {"MSK_IPAR_NUM_THREADS": 64}.

    Returns:
        dict mapping each beta (as float) to its test classification accuracy.
        The value is None when the solver produced no solution for that beta.
    """
    # Fresh dict per call (a dict default argument would be shared across calls).
    if mosek_params is None:
        mosek_params = {"MSK_IPAR_NUM_THREADS": 64}
    if weight_decay_strength is None:
        weight_decay_strength = np.array([0, 1, 5e-1, 1e-1, 1e-2, 1e-3, 1e-4])

    n, d = X.shape

    # The convex version of relu(X @ U1) @ U2 needs every possible sign pattern of
    # X @ U1 (element-wise indicator). Enumerating them is too expensive, so we
    # sample U1 at random instead. A local RandomState draws exactly what
    # np.random.seed(seed); np.random.randn(...) would, without reseeding the
    # notebook's global RNG.
    rng = np.random.RandomState(seed)
    U1 = rng.randn(d, hidden_dim)
    # Keep one representative direction per distinct training activation pattern.
    dmat, ind = np.unique(drelu(X @ U1), axis=1, return_index=True)
    m1 = dmat.shape[1]
    U = U1[:, ind]

    # CVXPY variables for the finite-dimensional problem (Section 3 of the paper).
    W1 = cvx.Variable((d, m1))
    W2 = cvx.Variable((d, m1))

    y_out1 = cvx.sum(cvx.multiply(dmat, X @ W1), axis=1)
    y_out2 = cvx.sum(cvx.multiply(dmat, X @ W2), axis=1)

    # Group-lasso penalty: sum of the column 2-norms of W1 and W2.
    reg_term = cvx.mixed_norm(W1.T, 2, 1) + cvx.mixed_norm(W2.T, 2, 1)
    # Regularization strength as a CVXPY parameter, so the problem is built once.
    betaval = cvx.Parameter(nonneg=True)

    objective_function = cvx.sum_squares(y - (y_out1 - y_out2)) / n + betaval * reg_term

    # (2 D_i - I) X w_i >= 0: each column must reproduce the activation pattern it is paired with.
    sign_mat = 2 * dmat - 1
    constraints = [
        cvx.multiply(sign_mat, X @ W1) >= 0,
        cvx.multiply(sign_mat, X @ W2) >= 0,
    ]
    problem = cvx.Problem(cvx.Minimize(objective_function), constraints)

    # Test-time activation pattern does not depend on beta; compute it once.
    test_mask = drelu(Xtest @ U)

    accuracies = {}
    for beta in weight_decay_strength:
        print(f"Trying beta={beta}")
        betaval.value = beta
        problem.solve(
            solver=cvx.MOSEK,
            warm_start=True,
            verbose=verbose,
            mosek_params=mosek_params,
        )
        print("Solution Status: ", problem.status)

        if W1.value is None or W2.value is None:
            # No primal solution (e.g. infeasible or failed solve). Record it and
            # move on instead of crashing the whole sweep on `Xtest @ None`.
            accuracies[float(beta)] = None
            continue

        ytest_est = np.sum(test_mask * (Xtest @ (W1.value - W2.value)), axis=1)
        ytest_pred = (ytest_est > 0.5).astype(ytest.dtype)
        accuracy = float(np.mean(ytest_pred == ytest))
        print("Classification Accuracy", accuracy)
        accuracies[float(beta)] = accuracy
    return accuracies
    

In [None]:
# Train on a reproducible subsample of up to 1,000 *distinct* training rows.
# np.random.choice samples with replacement by default, which duplicates rows.
subsample_rng = np.random.default_rng(0)
idxs = subsample_rng.choice(X.shape[0], size=min(1000, X.shape[0]), replace=False)
solve_problem(X[idxs], y[idxs], Xtest, ytest,
              hidden_dim=200, verbose=False)

Trying beta=0.0
Solution Status:  optimal
Classification Accuracy 0.7170000000000001
Trying beta=1.0
Solution Status:  optimal
Classification Accuracy 0.736
Trying beta=0.5
Solution Status:  optimal
Classification Accuracy 0.729
Trying beta=0.1
Solution Status:  optimal
Classification Accuracy 0.7070000000000001
Trying beta=0.01
Solution Status:  optimal
Classification Accuracy 0.7090000000000001
Trying beta=0.001
Solution Status:  optimal
Classification Accuracy 0.706
Trying beta=0.0001
Solution Status:  optimal
Classification Accuracy 0.714
