In [None]:
import torch
from transformers import DistilBertForMaskedLM
from transformers import DistilBertTokenizer
from transformers import CamembertTokenizerFast, CamembertForMaskedLM
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from sentence_transformers import SentenceTransformer
import scipy
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

## NLP

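### Independent Token Embedder
Each token is mapped to its row of the model's input embedding table, independently of the surrounding context.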

In [None]:
# Token-level model: each token id is mapped to its row of the model's input
# embedding table (no contextualisation through the transformer layers).
# Set LANGUAGE to "fr" to use CamemBERT instead of DistilBERT.
LANGUAGE = "en"

if LANGUAGE == "en":
    model = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased")
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
elif LANGUAGE == "fr":
    model = CamembertForMaskedLM.from_pretrained("camembert-base")
    tokenizer = CamembertTokenizerFast.from_pretrained("camembert-base")
else:
    raise ValueError(f"Unsupported LANGUAGE {LANGUAGE!r}; expected 'en' or 'fr'.")

# get_input_embeddings() is architecture-agnostic, so the same line works for
# both branches (a `model.distilbert...` attribute path would only exist on
# DistilBERT).
embedding_layer = model.get_input_embeddings()

### Sentence Embedder
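Sentence-level embeddings come from the pretrained `sentence-transformers/all-MiniLM-L6-v2` model, which maps each whole sentence to a single vector.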

In [None]:
# Whole-sentence embedder used as a baseline: it maps each input sentence to
# one vector, e.g. `model_sequence.encode(["first sentence", "second one"])`.
# NOTE(review): this checkpoint is presumably English-only — confirm before
# pairing it with a non-English token model.
SENTENCE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
model_sequence = SentenceTransformer(SENTENCE_MODEL_NAME)


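## Application
Compare two sentences in two ways: (1) the entropic optimal-transport cost between their clouds of token embeddings, and (2) the Euclidean distance between their sentence embeddings.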

In [None]:
# Two sentences built from (almost) the same words but with a different
# meaning; swap in the commented text_y for a paraphrase with different words.
text_x = "Alice is my friend but she is often late."
text_y = "Alice is my late mother but she is often."
# text_y = "Alice is my friend but she is never on time."


def embed_tokens(encoded, layer):
    """Return the input embedding of every token of a single encoded sentence.

    Args:
        encoded: tokenizer output created with ``return_tensors="pt"`` for one
            sentence (batch size 1 is assumed; only batch item 0 is used).
        layer: the model's token embedding module.

    Returns:
        A JAX array with one row per token, special tokens included.
    """
    # Pure table lookup: no gradients are needed.
    with torch.no_grad():
        token_vectors = layer(encoded["input_ids"])
    return jnp.asarray(token_vectors[0].cpu().numpy())


inputs_x = tokenizer(text_x, return_tensors="pt")
inputs_y = tokenizer(text_y, return_tensors="pt")
x = embed_tokens(inputs_x, embedding_layer)
y = embed_tokens(inputs_y, embedding_layer)

# Uniform mass on every token, then entropic OT between the two token clouds
# (PointCloud's default cost function is used).
a = jnp.ones((x.shape[0],)) / x.shape[0]
b = jnp.ones((y.shape[0],)) / y.shape[0]
geom = pointcloud.PointCloud(x, y)
prob = linear_problem.LinearProblem(geom, a, b)
solver = sinkhorn.Sinkhorn()
out = solver(prob)

# Baseline: Euclidean distance between the two whole-sentence embeddings.
# np.linalg.norm avoids relying on `scipy.spatial` being importable as an
# attribute of a bare `import scipy`.
emb_x, emb_y = model_sequence.encode([text_x, text_y])
dist_sequence = float(np.linalg.norm(emb_x - emb_y))

In [None]:
# Token strings in the order of the coupling matrix: rows follow text_x,
# columns follow text_y (PointCloud(x, y) above).
x_ticks = tokenizer.convert_ids_to_tokens(inputs_x["input_ids"][0])
y_ticks = tokenizer.convert_ids_to_tokens(inputs_y["input_ids"][0])

fig, ax = plt.subplots(figsize=(10, 6))
im = ax.imshow(out.matrix)
fig.colorbar(im, ax=ax, label="transported mass")
# primal_cost is <cost, coupling>; with a squared-Euclidean ground cost this
# upper-bounds the squared 2-Wasserstein distance, not the distance itself.
ax.set_title(
    "Optimal coupling (entropic OT cost, upper bound on $W_2^2$: {:.4f})\n"
    "Sentence-embedding distance: {:.4f}".format(float(out.primal_cost), dist_sequence)
)
ax.set_xticks(np.arange(len(y_ticks)))
ax.set_xticklabels(y_ticks, rotation=90)
ax.set_yticks(np.arange(len(x_ticks)))
ax.set_yticklabels(x_ticks)
ax.set_xlabel("text_y tokens")
ax.set_ylabel("text_x tokens")
plt.show()

# <>