In [2]:
import fasttext
import numpy as np
import senteval
import senteval.engine

from preprocess import preprocess

In [3]:
import torch
# Report whether a CUDA device is visible and, if so, the name of device 0.
cuda_ready = torch.cuda.is_available()
print(cuda_ready)
if cuda_ready:
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device found")

True
NVIDIA GeForce RTX 2060


In [4]:
def batcher(params: dict, batch: list[list[str]]):
    """Turn a batch of tokenized sentences into an array of sentence vectors.

    Args:
        params: Mapping that must contain a ``"model"`` entry exposing
            ``get_sentence_vector(text)``.
        batch: One list of string tokens per sentence.

    Returns:
        ``np.array`` built from the per-sentence vectors, in batch order.
    """
    encoder = params["model"]

    # Stage 1: rebuild each sentence from its tokens with single spaces.
    joined = [" ".join(words) for words in batch]

    # Stage 2: run the project's preprocessing over every sentence before
    # any of them is embedded.
    cleaned = []
    for text in joined:
        cleaned.append(preprocess(text))

    # Stage 3: embed each cleaned sentence and stack the results.
    vectors = []
    for text in cleaned:
        vectors.append(encoder.get_sentence_vector(text))
    return np.array(vectors)

def evaluate(model_name: str, tasks=("STSBenchmark", "SICKRelatedness")):
    """Load a fastText model by name and score it with SentEval.

    Args:
        model_name: Directory name under ``models/``; the model is loaded
            from ``models/{model_name}/model.bin``.
        tasks: SentEval task names to run. Defaults to the two tasks this
            notebook reports (``STSBenchmark`` and ``SICKRelatedness``), so
            existing ``evaluate(model_name=...)`` calls behave as before.
            A single task name given as a plain string is also accepted.

    Returns:
        The value returned by ``senteval.engine.SE.eval`` for the task list
        (in this notebook's outputs: a dict keyed by task name).
    """
    # A bare string must be wrapped, otherwise list() would split it into
    # individual characters.
    if isinstance(tasks, str):
        task_list = [tasks]
    else:
        task_list = list(tasks)

    model = fasttext.load_model(f"models/{model_name}/model.bin")
    se = senteval.engine.SE(
        params={
            "task_path": "./data",
            "usepytorch": False,
            # Handed through to `batcher`, which reads params["model"].
            "model": model,
        },
        batcher=batcher,
    )
    # Pass a concrete list, exactly as the original hard-coded call did.
    return se.eval(task_list)

In [5]:
# Score the model stored at models/cbow-300/model.bin on STSBenchmark and SICKRelatedness.
evaluate(model_name="cbow-300")

  text = bs4.BeautifulSoup(text, "html.parser").get_text()
  text = bs4.BeautifulSoup(text, "html.parser").get_text()


{'STSBenchmark': {'devpearson': 0.7582449438022215,
  'pearson': 0.6566804830515196,
  'spearman': 0.6570145765611611,
  'mse': 1.533234035180191,
  'yhat': array([2.15068829, 1.95850046, 1.99806645, ..., 3.69713277, 3.73834193,
         3.4640295 ]),
  'ndev': 1500,
  'ntest': 1379},
 'SICKRelatedness': {'devpearson': 0.7493884812763058,
  'pearson': 0.750127279355567,
  'spearman': 0.6188631478522657,
  'mse': 0.44532020448988174,
  'yhat': array([2.26351474, 3.87570367, 1.05863823, ..., 3.73737081, 3.99276663,
         4.13148674]),
  'ndev': 500,
  'ntest': 4927}}

In [7]:
# Score the model stored at models/cbow-300-default/model.bin on STSBenchmark and SICKRelatedness.
evaluate(model_name="cbow-300-default")

  text = bs4.BeautifulSoup(text, "html.parser").get_text()
  text = bs4.BeautifulSoup(text, "html.parser").get_text()


{'STSBenchmark': {'devpearson': 0.7572775685986125,
  'pearson': 0.6577306342711853,
  'spearman': 0.6591548860719196,
  'mse': 1.5385954969345816,
  'yhat': array([2.2437307 , 1.90932599, 2.20185072, ..., 3.72475329, 3.76128937,
         3.54415373]),
  'ndev': 1500,
  'ntest': 1379},
 'SICKRelatedness': {'devpearson': 0.757374412812026,
  'pearson': 0.7492449334032484,
  'spearman': 0.6205331572018444,
  'mse': 0.44721855107553615,
  'yhat': array([3.16511013, 3.91810059, 1.01254228, ..., 3.43609562, 4.03274303,
         4.1775888 ]),
  'ndev': 500,
  'ntest': 4927}}

In [8]:
# Score the model stored at models/skipgram-300-default/model.bin on STSBenchmark and SICKRelatedness.
evaluate(model_name="skipgram-300-default")

  text = bs4.BeautifulSoup(text, "html.parser").get_text()
  text = bs4.BeautifulSoup(text, "html.parser").get_text()


{'STSBenchmark': {'devpearson': 0.7574455092666968,
  'pearson': 0.6503540111722347,
  'spearman': 0.6546787360727105,
  'mse': 1.5520568811017421,
  'yhat': array([2.23092577, 1.93793401, 1.97176552, ..., 3.71433933, 3.74879795,
         3.51791616]),
  'ndev': 1500,
  'ntest': 1379},
 'SICKRelatedness': {'devpearson': 0.7244545900002768,
  'pearson': 0.7445994716998314,
  'spearman': 0.6126717069030698,
  'mse': 0.45370780680852785,
  'yhat': array([3.06755047, 3.86533583, 1.00140402, ..., 3.55201603, 4.00198343,
         4.14197291]),
  'ndev': 500,
  'ntest': 4927}}

In [9]:
# Score the model stored at models/pretrained/model.bin on STSBenchmark and SICKRelatedness.
evaluate(model_name="pretrained")

  text = bs4.BeautifulSoup(text, "html.parser").get_text()
  text = bs4.BeautifulSoup(text, "html.parser").get_text()


{'STSBenchmark': {'devpearson': 0.7722644003313113,
  'pearson': 0.699760538824385,
  'spearman': 0.6864169654946871,
  'mse': 1.3794074405581056,
  'yhat': array([1.98113038, 1.38803504, 1.73000027, ..., 3.84465029, 3.91753221,
         3.51790507]),
  'ndev': 1500,
  'ntest': 1379},
 'SICKRelatedness': {'devpearson': 0.7399497063112649,
  'pearson': 0.7704026296896912,
  'spearman': 0.6457025965930945,
  'mse': 0.4145651948976424,
  'yhat': array([2.2199348 , 3.61858781, 1.00043939, ..., 3.75805776, 3.8675019 ,
         4.16258069]),
  'ndev': 500,
  'ntest': 4927}}

In [5]:
# Score the model stored at models/skipgram-300/model.bin on STSBenchmark and SICKRelatedness.
evaluate(model_name="skipgram-300")

  text = bs4.BeautifulSoup(text, "html.parser").get_text()
  text = bs4.BeautifulSoup(text, "html.parser").get_text()


{'STSBenchmark': {'devpearson': 0.7562020866135934,
  'pearson': 0.6641948679986213,
  'spearman': 0.6572811012906445,
  'mse': 1.4311542552175271,
  'yhat': array([1.9737691 , 1.62806711, 2.08961052, ..., 3.95394052, 3.9009156 ,
         3.53364154]),
  'ndev': 1500,
  'ntest': 1379},
 'SICKRelatedness': {'devpearson': 0.7599352749950791,
  'pearson': 0.7512608783547287,
  'spearman': 0.6234612616025729,
  'mse': 0.4442296544011046,
  'yhat': array([2.53432852, 3.8379614 , 1.0014275 , ..., 3.83733679, 3.96035554,
         4.13801533]),
  'ndev': 500,
  'ntest': 4927}}