In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
import regex as re
from model2vec.distill.tokenizer import remove_tokens
import numpy as np
from tqdm import tqdm
import cvxpy as cvx
from model2vec import StaticModel
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ExponentialLR

## Toy Example: IMDB Classification

### Baselines
**Model2Vec**: 65%

**Potion**: 70%

In [None]:
# Load the `mteb/imdb` dataset; the cells below read the "text" and "label"
# columns of its "train" and "test" splits.
data = load_dataset("mteb/imdb")

In [None]:
class TokenizerWrapper:
    def __init__(self, tokenizer_name: str, potion_model_name: str, device="cpu"):
        self.tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name, trust_remote_code=True, device=device)
        full_vocab = [pair[0] for pair in sorted(self.tokenizer.get_vocab().items(), key=lambda x: x[1])]
        vocab = [x for x in full_vocab if not re.match("\[unused\d+\]", x)]
        self.vocab_size = len(vocab)
        self.tokenizer = remove_tokens(self.tokenizer.backend_tokenizer, set(full_vocab) - set(vocab))
        self.tokenizer.no_padding()

        model = StaticModel.from_pretrained(potion_model_name)
        self.projection = model.embedding

    def get_token_counts(self, texts: list[str] | str, progress=False):
        if type(texts) == str:
            texts = [texts]
        encoded = self.tokenizer.encode_batch([x.lower() for x in texts], add_special_tokens=False)
        token_counts = np.zeros((len(texts), self.vocab_size))
        
        iterator = enumerate(encoded)
        if progress:
            iterator = enumerate(tqdm(encoded))
        for (row, enc) in iterator:
            for id in enc.ids:
                token_counts[row, id] += 1
        return token_counts
    
    def get_token_count_projection(self, texts: list[str] | str, progress=False):
        counts = self.get_token_counts(texts, progress=progress)
        return counts @ self.projection


In [None]:
# Tokenizer comes from BAAI/bge-base-en-v1.5; the projection matrix is the
# embedding of minishlab/potion-base-2M (see TokenizerWrapper.__init__).
tokenizer = TokenizerWrapper("BAAI/bge-base-en-v1.5", "minishlab/potion-base-2M")

In [None]:
# Dense (n_train, vocab_size) float64 matrix of raw token counts, plus the
# training labels as a NumPy array.
X = tokenizer.get_token_counts(data["train"]["text"], progress=True)
y = np.array(data["train"]["label"])

In [None]:
# Same featurization for the test split: (n_test, vocab_size) token counts
# and the test labels.
Xtest = tokenizer.get_token_counts(data["test"]["text"], progress=True)
ytest = np.array(data["test"]["label"])

### PyTorch Neural Network

**The following code should get about 87% accuracy**

In [None]:
def _mean_batch_loss(model, dataloader, loss_fn, device):
    """Return the mean of the per-batch losses of ``model`` over ``dataloader``."""
    # Evaluate in eval mode, then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    with torch.no_grad():
        batch_losses = [
            loss_fn(model(X.to(device)), y.to(device)).item()
            for (X, y) in dataloader
        ]
    model.train(was_training)
    return float(np.mean(batch_losses))


def train_loop(
    model, train, test, loss_fn=None, device="cpu",
    epochs=10, lr=0.01, bs=16, weight_decay=0,
    lr_decay_coef=0.9
):
    """Train ``model`` with Adam and report the test loss after every epoch.

    Args:
        model: ``nn.Module`` to train; it is moved to ``device`` in place.
        train: Dataset yielding ``(X, y)`` pairs used for optimization.
        test: Dataset yielding ``(X, y)`` pairs used only for evaluation.
        loss_fn: Callable ``loss_fn(output, target)``; defaults to ``nn.MSELoss()``.
        device: Device on which the model and every batch are placed.
        epochs: Number of passes over ``train``.
        lr: Initial Adam learning rate.
        bs: Batch size for both the train and the test loader.
        weight_decay: Adam ``weight_decay``.
        lr_decay_coef: ``gamma`` of the per-epoch ``ExponentialLR`` schedule.

    Returns:
        A list of ``epochs + 1`` floats: the mean per-batch test loss before
        training, followed by the value after each epoch. Each value is also
        printed as ``Loss <value>``.
    """
    if loss_fn is None:
        loss_fn = nn.MSELoss()
    model.to(device)

    train_dataloader = DataLoader(train, batch_size=bs, shuffle=True)
    # Evaluation needs no shuffling; a fixed order keeps the mean of batch
    # losses identical between runs even when the last batch is smaller.
    test_dataloader = DataLoader(test, batch_size=bs, shuffle=False)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = ExponentialLR(optimizer, gamma=lr_decay_coef)

    # Baseline loss of the untrained model.
    test_losses = [_mean_batch_loss(model, test_dataloader, loss_fn, device)]
    print("Loss", test_losses[-1])

    for _ in tqdm(range(epochs)):
        for (X, y) in train_dataloader:
            output = model(X.to(device))
            loss = loss_fn(output, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Decay the learning rate once per epoch.
        scheduler.step()

        test_losses.append(_mean_batch_loss(model, test_dataloader, loss_fn, device))
        print("Loss", test_losses[-1])

    return test_losses

In [None]:
# Divide every row of the count matrices by its L2 norm. A text that produced
# no tokens has norm 0, which would turn the whole row into NaN; flooring the
# divisor at 1 keeps such rows all-zero. X and Xtest hold non-negative integer
# counts, so every non-empty row has norm >= 1 and is unaffected by the floor.
X_scale = torch.Tensor(X / np.maximum(1, np.linalg.norm(X, axis=1, keepdims=True)))
Xtest_scale = torch.Tensor(Xtest / np.maximum(1, np.linalg.norm(Xtest, axis=1, keepdims=True)))

# Integer class indices, as required by nn.CrossEntropyLoss targets.
train = TensorDataset(X_scale, torch.tensor(y, dtype=torch.long))
test = TensorDataset(Xtest_scale, torch.tensor(ytest, dtype=torch.long))

In [None]:
hidden_dim = 16
# The network emits raw logits. nn.CrossEntropyLoss (the loss this model is
# trained with) applies log-softmax itself, so a trailing nn.Softmax layer
# would apply softmax twice and flatten the gradients. Predictions taken with
# argmax are the same for logits and for probabilities.
model = nn.Sequential(
    nn.Linear(tokenizer.vocab_size, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, 2),
).to("cuda:7")

In [None]:
# Train the two-class model with cross-entropy for 20 epochs (Adam, lr=1e-3,
# batch size 256) on cuda:7. train_loop prints the mean test loss before
# training and after every epoch.
losses = train_loop(
    model, train, test,
    loss_fn=nn.CrossEntropyLoss(),
    device="cuda:7",
    epochs=20,
    bs=256,
    lr=1e-3
)

In [None]:
device = "cuda:7"
# Order does not matter for evaluation, so the loader is not shuffled.
test_dataloader = DataLoader(test, batch_size=256, shuffle=False)

# Count correct predictions over the whole test set. Averaging per-batch
# error rates would over-weight the (smaller) final batch.
n_correct = 0
with torch.no_grad():
    for (Xt, yt) in test_dataloader:
        predicted = model(Xt.to(device)).argmax(dim=1)
        n_correct += (predicted == yt.to(device)).sum().item()
print("Classification accuracy:", n_correct / len(test))

### Try with MSE Loss

**This should also do about the same, 88%**

In [None]:
# Divide every row of the count matrices by its L2 norm. A text that produced
# no tokens has norm 0, which would turn the whole row into NaN; flooring the
# divisor at 1 keeps such rows all-zero. X and Xtest hold non-negative integer
# counts, so every non-empty row has norm >= 1 and is unaffected by the floor.
X_scale = torch.Tensor(X / np.maximum(1, np.linalg.norm(X, axis=1, keepdims=True)))
Xtest_scale = torch.Tensor(Xtest / np.maximum(1, np.linalg.norm(Xtest, axis=1, keepdims=True)))

# Float targets of shape (n, 1), matching the single-output regression model.
train = TensorDataset(X_scale, torch.Tensor(y).unsqueeze(1))
test = TensorDataset(Xtest_scale, torch.Tensor(ytest).unsqueeze(1))

In [None]:
hidden_dim = 16
# One hidden ReLU layer followed by a single linear output unit (no final
# activation); the output is regressed onto the label.
model = nn.Sequential(
    nn.Linear(in_features=tokenizer.vocab_size, out_features=hidden_dim),
    nn.ReLU(),
    nn.Linear(in_features=hidden_dim, out_features=1),
)
model = model.to("cuda:7")

In [None]:
# Train the single-output model with mean-squared-error loss for 20 epochs
# (Adam, lr=1e-3, batch size 256) on cuda:7. train_loop prints the mean test
# loss before training and after every epoch.
losses = train_loop(
    model, train, test,
    loss_fn=nn.MSELoss(),
    device="cuda:7",
    epochs=20,
    bs=256,
    lr=1e-3
)

In [None]:
device = "cuda:7"
# Order does not matter for evaluation, so the loader is not shuffled.
test_dataloader = DataLoader(test, batch_size=256, shuffle=False)

# The regression output is rounded to the nearest integer and compared with
# the label. Correct predictions are counted over the whole test set, because
# averaging per-batch error rates would over-weight the (smaller) final batch.
n_correct = 0
with torch.no_grad():
    for (Xt, yt) in test_dataloader:
        predicted = torch.round(model(Xt.to(device)))
        n_correct += (predicted == yt.to(device)).sum().item()
print("Classification accuracy:", n_correct / len(test))

## Tweet Sentiment Classification (potion: 55.4%)

Here, cross-entropy ends up doing better than MSE (maybe because it's multiclass).

**The following code should get about 70% accuracy**

In [None]:
# Rebinds `data`: from here on it refers to the tweet dataset, not to the
# dataset loaded at the top of the notebook.
data = load_dataset("mteb/tweet_sentiment_extraction")

In [None]:
# Token-count features and labels for the tweet train/test splits. These
# rebind X, y, Xtest and ytest from the previous section.
X = tokenizer.get_token_counts(data["train"]["text"], progress=True)
y = np.array(data["train"]["label"])
Xtest = tokenizer.get_token_counts(data["test"]["text"], progress=True)
ytest = np.array(data["test"]["label"])

In [None]:
# Divide every row by its L2 norm, with the divisor floored at 1 so that
# all-zero rows stay zero instead of becoming NaN.
X_scale = torch.Tensor(X / np.clip(np.linalg.norm(X, axis=1, keepdims=True), 1, None))
Xtest_scale = torch.Tensor(Xtest / np.clip(np.linalg.norm(Xtest, axis=1, keepdims=True), 1, None))

# Labels as int64 class indices.
train = TensorDataset(X_scale, torch.Tensor(y).long())
test = TensorDataset(Xtest_scale, torch.Tensor(ytest).long())

In [None]:
hidden_dim = 32
# Three output logits, one per class. nn.CrossEntropyLoss (the loss this model
# is trained with) applies log-softmax itself, so a trailing nn.Softmax layer
# would apply softmax twice and flatten the gradients. Predictions taken with
# argmax are the same for logits and for probabilities.
model = nn.Sequential(
    nn.Linear(tokenizer.vocab_size, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, 3),
).to("cuda:7")

In [None]:
# Train the three-class model with cross-entropy for 20 epochs (Adam, lr=1e-3,
# batch size 256) on cuda:7. train_loop prints the mean test loss before
# training and after every epoch.
losses = train_loop(
    model, train, test,
    loss_fn=nn.CrossEntropyLoss(),
    device="cuda:7",
    epochs=20,
    bs=256,
    lr=1e-3
)

In [None]:
device = "cuda:7"
# Order does not matter for evaluation, so the loader is not shuffled.
test_dataloader = DataLoader(test, batch_size=256, shuffle=False)

# Count correct predictions over the whole test set. Averaging per-batch
# error rates would over-weight the (smaller) final batch.
n_correct = 0
with torch.no_grad():
    for (Xt, yt) in test_dataloader:
        predicted = model(Xt.to(device)).argmax(dim=1)
        n_correct += (predicted == yt.to(device)).sum().item()
print("Classification accuracy:", n_correct / len(test))

### The following is mostly copied from Mert's Mosek code

In [None]:
# Rebinds X / Xtest to projected features: token counts multiplied by
# tokenizer.projection. `data` is whatever was loaded last; when the notebook
# is run top to bottom that is the tweet dataset.
X = tokenizer.get_token_count_projection(data["train"]["text"], progress=True)
y = np.array(data["train"]["label"])
Xtest = tokenizer.get_token_count_projection(data["test"]["text"], progress=True)
ytest = np.array(data["test"]["label"])

In [None]:
def relu(x):
    """Element-wise rectifier: entries of ``x`` below zero are replaced by 0."""
    return np.maximum(x, 0)

def drelu(x):
    """Boolean activation pattern of ``x``: true where ``x`` is at least 0."""
    return 0 <= x

In [None]:
# NOTE: Mosek parameter dictionary from Mert
# params = {
#       "MSK_IPAR_NUM_THREADS": 8,
#       #"MSK_IPAR_INTPNT_MAX_ITERATIONS": 10,
#       #"MSK_IPAR_OPTIMIZER": 0 # auto 0, interior point 1, conic 2
#       #"MSK_DPAR_INTPNT_CO_TOL_REL_GAP": 1e-2
#       #"MSK_DPAR_INTPNT_TOL_PSAFE": 0.01
#       #"MSK_IPAR_OPTIMIZER": "free"
#       #"MSK_IPAR_INTPNT_SOLVE_FORM": 1
#       }

In [None]:
def solve_problem(
    X: np.ndarray, y: np.ndarray,
    Xtest: np.ndarray, ytest: np.ndarray,
    seed=0, hidden_dim=2000,
    weight_decay_strength=None,
    verbose=True,
    mosek_params={"MSK_IPAR_NUM_THREADS": 64}
):
    """Fit the convex reformulation of a two-layer ReLU network with MOSEK.

    The problem is solved once per regularization strength. After each solve
    the test predictions are computed and the classification accuracy is
    printed.

    Args:
        X: Training features, shape ``(n, d)``.
        y: Training targets, length ``n``.
        Xtest: Test features, shape ``(ntest, d)``.
        ytest: Test labels, length ``ntest``.
        seed: Seed passed to ``np.random.seed`` before sampling the random
            hidden directions (this reseeds NumPy's global generator).
        hidden_dim: Number of random directions sampled to build activation
            patterns; duplicate patterns are dropped.
        weight_decay_strength: Iterable of regularization strengths to try.
            Defaults to ``[0, 1, 10, 5e-1, 1e-1, 1e-2, 1e-3]``.
        verbose: Forwarded to ``problem.solve``.
        mosek_params: Dictionary forwarded to ``problem.solve``.

    Returns:
        A list with one ``(beta, accuracy)`` tuple per regularization
        strength, in the order tried. ``accuracy`` is ``None`` when the solver
        produced no variable values for that ``beta``.
    """
    np.random.seed(seed)
    n, d = X.shape
    ntest = Xtest.shape[0]

    if weight_decay_strength is None:
        weight_decay_strength = np.array([0, 1, 10, 5e-1, 1e-1, 1e-2, 1e-3])
    # Pass the solver a private copy so the shared default dict is never
    # handed out directly.
    if mosek_params is not None:
        mosek_params = dict(mosek_params)

    # Say the two-layer neural network is relu(X @ U1) @ U2.
    # Then, the convex version of the neural network requires
    # knowing all possible formations of indic{X @ U1 >= 0}, where
    # the indicator function is taken element-wise. This is
    # computationally expensive, so we estimate it by randomly
    # sampling the matrix U1.
    U1 = np.random.randn(d, hidden_dim)
    dmat = drelu(X @ U1)

    # Keep each distinct activation pattern (column) once, together with the
    # direction that produced it.
    dmat, ind = np.unique(dmat, axis=1, return_index=True)
    m1 = dmat.shape[1]
    U = U1[:, ind]

    # CVXPY variables for finite-dimensional optimization problem
    # from Section 3 of https://arxiv.org/pdf/2002.10553
    W1 = cvx.Variable((d, m1))
    W2 = cvx.Variable((d, m1))

    y_out1 = cvx.sum(cvx.multiply(dmat, X @ W1), axis=1)
    y_out2 = cvx.sum(cvx.multiply(dmat, X @ W2), axis=1)

    reg_term = cvx.mixed_norm(W1.T, 2, 1) + cvx.mixed_norm(W2.T, 2, 1)
    # Regularization strength as a CVXPY parameter, so the same problem object
    # is reused for every beta.
    betaval = cvx.Parameter(nonneg=True)

    objective_function = cvx.sum_squares(y - (y_out1 - y_out2)) / n + betaval * reg_term

    # +1 where the pattern is active, -1 where it is inactive: the weights must
    # reproduce the sampled activation pattern on the training data.
    sign_pattern = 2 * dmat - np.ones((n, m1))
    constraints = [
        cvx.multiply(sign_pattern, X @ W1) >= 0,
        cvx.multiply(sign_pattern, X @ W2) >= 0,
    ]

    problem = cvx.Problem(cvx.Minimize(objective_function), constraints)

    # Quantities that do not depend on beta.
    dtest = drelu(Xtest @ U)
    labels = np.unique(y)

    results = []
    # Solve the problem for each possible regularization strength
    for beta in weight_decay_strength:
        print(f"Trying beta={beta}")
        betaval.value = beta
        problem.solve(
            solver=cvx.MOSEK,
            warm_start=True,
            verbose=verbose,
            mosek_params=mosek_params
        )

        print("Solution Status: ", problem.status)

        W1v = W1.value
        W2v = W2.value
        if W1v is None or W2v is None:
            # The solver returned no variable values; nothing to evaluate.
            results.append((float(beta), None))
            continue

        scores = np.sum(dtest * (Xtest @ (W1v - W2v)), axis=1)
        # Predict the training label closest to the regression output. For
        # labels {0, 1} this is the same as thresholding at 0.5 (ties go to
        # the smaller label), and it also works for more than two classes.
        nearest = np.abs(scores[:, None] - labels[None, :]).argmin(axis=1)
        ytest_est = labels[nearest].astype(ytest.dtype)
        err = np.sum(ytest_est != ytest) / ntest
        print("Classification Accuracy", 1 - err)
        results.append((float(beta), float(1 - err)))

    return results
    

In [None]:
# Fit on a random subset of at most 1000 training rows. A seeded generator
# makes the subset reproducible, and replace=False prevents the same row from
# being drawn more than once.
idxs = np.random.default_rng(0).choice(
    X.shape[0], size=min(1000, X.shape[0]), replace=False
)
solve_problem(X[idxs], y[idxs], Xtest, ytest,
              hidden_dim=200, verbose=False)