In [None]:
from datasets import load_dataset
import numpy as np

from relu_embed.classification import ReLUClassifier, MultiproblemReLUClassifier
from relu_embed.data import DatasetInfo, load_and_process_dataset
from relu_embed.utils import TokenizerWrapper
from relu_embed.embedding import NNEmbedding
import mteb
import torch

In [None]:
# Reload imported modules (e.g. relu_embed) before each cell executes, so edits
# to the library take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

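## Multi-Problem Classification

Train one `MultiproblemReLUClassifier` jointly on several text-classification datasets (each with its own label set), extract its embedding model, and evaluate that embedding model on an MTEB task.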

In [None]:
# Seed the numpy and torch RNGs before any data loading or training so that
# Restart Kernel -> Run All is more reproducible (random splits, weight init,
# shuffling). GPU kernels may still introduce some nondeterminism.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# One DatasetInfo per classification problem; each is passed as-is to
# load_and_process_dataset below.
problems = [
    DatasetInfo("mteb/banking77"),
    DatasetInfo("mteb/amazon_counterfactual"),
    DatasetInfo("mteb/toxic_conversations_50k"),
    DatasetInfo(
        "google-research-datasets/poem_sentiment", text_column="verse_text"),
    DatasetInfo(
        "takala/financial_phrasebank", name="sentences_allagree",
        text_column="sentence", has_splits=False),
    DatasetInfo("fancyzhx/dbpedia_14", text_column="content", train_limit=100_000),
    DatasetInfo("mteb/imdb"),
    DatasetInfo("mteb/tweet_sentiment_extraction"),
]
problem_names = [x.dataset for x in problems]

# Parallel per-problem lists; index i in each list refers to problems[i].
n_classes = []
Xs, ys, Xtests, ytests, ids = [], [], [], [], []
for i, prob in enumerate(problems):
    X, y, Xtest, ytest = load_and_process_dataset(prob)
    # Count classes from BOTH splits. This assumes labels are 0-based ints
    # (TODO confirm). Using train labels alone would leave a label that only
    # occurs in the test split out of range.
    # `.max()` works for torch tensors and numpy arrays alike.
    n_classes.append(max(int(y.max()), int(ytest.max())) + 1)
    Xs.append(X)
    ys.append(y)
    Xtests.append(Xtest)
    ytests.append(ytest)
    ids.append(i)

In [None]:
# A single input_size is passed for all problems, so every train/test feature
# matrix must have the same width as the first one.
input_size = Xs[0].shape[1]
assert all(X_i.shape[1] == input_size for X_i in Xs + Xtests), \
    "all problems must share the same feature dimension"

# Prefer the GPU this notebook was developed on (cuda:7). Fall back so the
# notebook still runs on machines with fewer GPUs, or none.
if torch.cuda.device_count() > 7:
    device = "cuda:7"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = MultiproblemReLUClassifier(
    n_classes=n_classes, problem_names=problem_names,
    input_size=input_size,
    embedding_size=256,
    device=device
)

In [None]:
# One training loader over all problems at once; `ids` carries each problem's
# index (presumably used to route examples to the right head - TODO confirm).
train_dataloader = model.get_dataloader_batch(
    Xs,
    ys,
    ids,
    normalize_rows=True,
    batch_size=256,
)

In [None]:
# One evaluation loader per problem, in the same order as `problems`, each
# tagged with that problem's index.
test_dataloaders = [
    model.get_dataloader_single(X_eval, y_eval, problem_id,
                                batch_size=256, normalize_rows=True)
    for X_eval, y_eval, problem_id in zip(Xtests, ytests, ids)
]

In [None]:
# Train on the mixed-problem loader; the per-problem test loaders are passed
# alongside (presumably for periodic evaluation).
# NOTE(review): eval_interval=500_000 is a magic number whose unit (steps vs.
# samples) is not visible here, and lr_decay's schedule is opaque from here -
# TODO confirm and name them.
model.train(
    train_dataloader,
    test_dataloaders,
    lr=1e-3,
    lr_decay=0.95,
    epochs=5,
    eval_interval=500_000,
)

In [None]:
# NOTE(review): this tokenizer presumably has to match the one used inside
# load_and_process_dataset to build the training features - TODO confirm.
tokenizer = TokenizerWrapper("BAAI/bge-base-en-v1.5")

# Extract the trained embedding model, then move it to CPU for saving and
# evaluation.
embedding_model = model.get_embedding_model(tokenizer, normalize_embeds=True)
embedding_model = embedding_model.to("cpu")

In [None]:
# Pickle the whole module object (not just a state_dict); reloading it
# therefore requires torch.load(..., weights_only=False).
torch.save(obj=embedding_model, f="../object.bin")

In [None]:
# Reload the model written by the save cell, so the evaluation below can run
# without retraining. The path must match the one passed to torch.save.
# WARNING: weights_only=False unpickles arbitrary objects - only load files
# you created or otherwise trust.
embedding_model = torch.load("../object.bin", weights_only=False)

In [None]:
# Re-assert both normalization flags on the (re)loaded model. Chained
# assignment sets normalize_embeds first, then normalize_token_counts.
# NOTE(review): normalize_token_counts presumably normalizes the token-count
# input features - TODO confirm.
embedding_model.normalize_embeds = embedding_model.normalize_token_counts = True

In [None]:
# Evaluate the embedding model on one MTEB classification task.
# NOTE(review): this task presumably uses the same data as
# "mteb/toxic_conversations_50k", which is also in the training problems, so
# the score is not comparable to zero-shot MTEB results.
tasks = mteb.get_tasks(tasks=["ToxicConversationsClassification"])
evaluation = mteb.MTEB(tasks=tasks)

# overwrite_results=True recomputes even if results already exist in
# output_folder. The folder is a plain string (the original f-string had no
# placeholders).
results = evaluation.run(
    embedding_model,
    output_folder="results/test",
    show_progress_bar=True,
    overwrite_results=True,
)

In [None]:
# Mean main_score (as a percentage) over the evaluated splits of the first
# task, taking the first score entry listed for each split.
np.mean([
    split_scores[0]["main_score"] * 100
    for split_scores in results[0].scores.values()
])

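### Reference: convex two-layer ReLU network solved with MOSEK

The following commented-out code is mostly copied from Mert's MOSEK code. It is kept for reference only and does not run as-is.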

In [None]:
# NOTE(review): reference only - this whole cell is commented out and will not
# run as-is. It relies on `data` (indexed as data["train"/"test"]["text"/"label"])
# and `cvx` (presumably `import cvxpy as cvx`), which are not defined in this
# section. `solver=cvx.MOSEK` also requires a MOSEK installation and license.
# X = tokenizer.get_token_count_projection(data["train"]["text"], progress=True)
# y = np.array(data["train"]["label"])
# Xtest = tokenizer.get_token_count_projection(data["test"]["text"], progress=True)
# ytest = np.array(data["test"]["label"])

# def relu(x):
#     return np.maximum(0,x)

# def drelu(x):
#     return x>=0

# # NOTE: Mosek parameter dictionary from Mert
# # params = {
# #       "MSK_IPAR_NUM_THREADS": 8,
# #       #"MSK_IPAR_INTPNT_MAX_ITERATIONS": 10,
# #       #"MSK_IPAR_OPTIMIZER": 0 # auto 0, interior point 1, conic 2
# #       #"MSK_DPAR_INTPNT_CO_TOL_REL_GAP": 1e-2
# #       #"MSK_DPAR_INTPNT_TOL_PSAFE": 0.01
# #       #"MSK_IPAR_OPTIMIZER": "free"
# #       #"MSK_IPAR_INTPNT_SOLVE_FORM": 1
# #       }
# def solve_problem(
#     X: np.array, y: np.array, 
#     Xtest: np.array, ytest: np.array,
#     seed=0, hidden_dim=2000,
#     weight_decay_strength=None,
#     verbose=True,
#     mosek_params={"MSK_IPAR_NUM_THREADS": 64}
# ):
#     np.random.seed(seed)
#     n,d = X.shape
#     ntest = Xtest.shape[0]

#     if weight_decay_strength is None:
#         weight_decay_strength = np.array([0, 1, 10, 5e-1, 1e-1, 1e-2, 1e-3])

#     # Say the two-layer neural network is relu(X @ U1) @ U2.
#     # Then, the convex version of the neural network requires
#     # knowing all possible formations of indic{X @ U1 > 0}, where
#     # the indicator function is taken element-wise. This is
#     # computationally expensive, so we estimate it by randomly
#     # sampling the matrix U1.
#     U1 = np.random.randn(d,hidden_dim)
#     dmat = drelu(X @ U1)
    
#     dmat, ind=(np.unique(dmat,axis=1, return_index=True))
#     m1=dmat.shape[1]
#     U=U1[:,ind]

#     # CVXPY variables for finite-dimensional optimization problem
#     # from Section 3 of https://arxiv.org/pdf/2002.10553
#     W1=cvx.Variable((d,m1))
#     W2=cvx.Variable((d,m1))

#     # parameters
#     y_out1 = cvx.sum(cvx.multiply(dmat, X@W1),axis=1)
#     y_out2 = cvx.sum(cvx.multiply(dmat, X@W2),axis=1)

#     reg_term = cvx.mixed_norm(W1.T, 2, 1) + cvx.mixed_norm(W2.T, 2, 1)
#     # regularization strength as a cvxpy Parameter (not a Variable), so the
#     # loop below can change its value and re-solve with warm_start
#     betaval = cvx.Parameter(nonneg=True)

#     objective_function = cvx.sum(
#         cvx.sum_squares(y-(y_out1 - y_out2))
#     ) / n + betaval * reg_term

#     constraints = [
#         cvx.multiply(2 * dmat - np.ones((n,m1)), X@W1) >= 0
#     ] + [
#         cvx.multiply( 2 * dmat - np.ones((n,m1)), X@W2) >= 0
#     ]

#     problem = cvx.Problem(cvx.Minimize(objective_function), constraints)

#     # Solve the problem for each possible regularization strength
#     for beta in weight_decay_strength:
#         print(f"Trying beta={beta}")
#         betaval.value = beta
#         problem.solve(
#             solver=cvx.MOSEK,
#             warm_start=True,
#             verbose=verbose,
#             mosek_params=mosek_params
#         )

#         print("Solution Status: ", problem.status)

#         W1v=W1.value
#         W2v=W2.value

#         ytest_est = np.sum(
#             drelu(Xtest@U) * (Xtest@W1v) - drelu(Xtest@U) * (Xtest@W2v),
#             axis=1
#         )
#         # NOTE(review): thresholding the regression output at 0.5 only
#         # yields correct predictions for binary 0/1 labels.
#         ytest_est = (ytest_est > 0.5).astype(ytest.dtype)
#         err = np.sum(ytest_est != ytest) / ntest
#         print("Classification Accuracy", 1 - err)

# NOTE(review): np.random.choice samples WITH replacement by default, so this
# subsample may contain duplicate rows; pass replace=False for 1000 distinct rows.
# idxs = np.random.choice(X.shape[0], size=1000)
# solve_problem(X[idxs], y[idxs], Xtest, ytest, 
#               hidden_dim=200, verbose=False)
    