In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel, GPT2LMHeadModel
from typing import Dict
import numpy as np
import scipy
import json

def perform_nli(premise, hypothesis):
    """Run the NLI classifier on a (premise, hypothesis) pair.

    Uses the module-level ``nli_tokenizer`` and ``nli_model``.

    Args:
        premise: Premise text passed as the first tokenizer segment.
        hypothesis: Hypothesis text passed as the second tokenizer segment.

    Returns:
        A dict mapping each label name from ``nli_model.config.id2label`` to
        the softmax probability of that label for the first row of the
        model's logits.
    """
    with torch.inference_mode():
        encoded_pair = nli_tokenizer(premise, hypothesis, return_tensors='pt')
        logits = nli_model(**encoded_pair).logits
        # Only the first row is kept (presumably a single pair is passed in).
        probabilities = torch.softmax(logits, -1).cpu().numpy()[0]

    label_by_index = nli_model.config.id2label
    return {label: probabilities[index] for index, label in label_by_index.items()}

def mean_pooling(model_output, attention_mask):
    """Average token embeddings over dimension 1, weighted by the attention mask.

    Args:
        model_output: Indexable model output whose first element holds the
            token embeddings (assumed to be (batch, tokens, dim) — TODO confirm).
        attention_mask: Mask with one value per token; it is broadcast over the
            last embedding dimension, so positions with 0 do not contribute.

    Returns:
        Tensor of mask-weighted sums divided by the mask totals. The divisor
        is clamped to at least 1e-9, so an all-zero mask yields zeros instead
        of a division by zero.
    """
    hidden_states = model_output[0]
    # Stretch the mask to the embedding shape so it can be applied elementwise.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_total = (hidden_states * weights).sum(1)
    weight_total = weights.sum(1).clamp(min=1e-9)
    return weighted_total / weight_total

def get_embedding(sentences):
    """Embed text with the module-level SBERT model and mean pooling.

    Args:
        sentences: Input accepted by ``sbert_tokenizer`` (a string or a batch
            of strings). Inputs are padded but NOT truncated.

    Returns:
        The tensor produced by ``mean_pooling`` over the model output and the
        tokenizer's attention mask.
    """
    batch = sbert_tokenizer(sentences, padding=True, truncation=False, return_tensors='pt')

    # No gradients are needed; the forward pass runs in inference mode.
    with torch.inference_mode():
        outputs = sbert_model(**batch)

    return mean_pooling(outputs, batch['attention_mask'])

def estimate_perplexity(text):
    """Return the language-model loss of ``text`` under the module-level GPT model.

    The text's own input ids are supplied as labels, and the model's ``loss``
    is returned as a Python float. Note that the raw loss is returned, not
    its exponential (``score`` exponentiates the difference to a threshold).

    Args:
        text: Text to evaluate.

    Returns:
        ``float`` value of the model's ``loss`` for ``text``.
    """
    encoded = gpt_tokenizer(text, return_tensors="pt")
    with torch.inference_mode():
        lm_output = gpt_model(**encoded, labels=encoded["input_ids"])
    return lm_output.loss.item()

def get_mean_dist(text, embeds, k=10):
    """Mean cosine distance from ``text`` to its ``k`` nearest reference embeddings.

    Args:
        text: Text to embed with ``get_embedding``; only the first resulting
            row is compared.
        embeds: Reference embeddings, one per row (a single 1-D vector is
            treated as one row).
        k: Number of smallest distances to average. If there are fewer than
            ``k`` reference rows, all of them are averaged.

    Returns:
        Mean of the ``k`` smallest cosine distances between the text embedding
        and the rows of ``embeds``.

    Raises:
        ValueError: If ``k`` is smaller than 1 or ``embeds`` has no rows.
    """
    # Local import: a bare `import scipy` does not guarantee that the
    # `scipy.spatial` submodule has been loaded.
    from scipy.spatial.distance import cdist

    if k < 1:
        raise ValueError("k must be at least 1")

    reference = np.atleast_2d(np.asarray(embeds))
    if reference.shape[0] == 0:
        raise ValueError("embeds must contain at least one embedding")

    query = get_embedding(text).numpy()
    # Compare against the embeddings passed in, not a notebook-level global.
    dists = cdist(query, reference, metric="cosine")[0]

    # np.partition requires kth < len(dists); with too few rows use them all.
    if k >= dists.size:
        return dists.mean()
    return np.partition(dists, k)[:k].mean()

def score(premise: str, hypothesis: str, test: Dict[str, np.ndarray], top_k=10, perplexity_threshold=6., beta=1.):
    """Score a premise against a hypothesis by combining NLI, distance and fluency.

    Args:
        premise: Text being evaluated.
        hypothesis: Hypothesis text; must be a key of ``test``.
        test: Mapping from hypothesis text to its precomputed reference
            embeddings.
        top_k: Number of nearest reference embeddings averaged by
            ``get_mean_dist``.
        perplexity_threshold: Value of ``estimate_perplexity(premise)`` above
            which the final score starts being divided down.
        beta: Weight in the F-beta style combination of the NLI quality and
            the distance score.

    Returns:
        Dict with keys ``"nli_quality"`` (1 minus the contradiction
        probability), ``"distance_score"``, ``"perplexity_divisor"`` and
        ``"final"`` (combined score divided by the perplexity divisor).

    Raises:
        LookupError: If ``test`` has no embeddings for ``hypothesis``.
    """
    # Fail fast, before running any of the (expensive) models.
    if hypothesis not in test:
        raise LookupError("There are no ready embeddings for this hypothesis!")

    nli_scores = perform_nli(premise, hypothesis)
    perplexity_score = estimate_perplexity(premise)
    # Forward top_k so the caller's setting is actually honoured.
    distance_score = get_mean_dist(premise, test[hypothesis], k=top_k)

    # The divisor is 1 up to the threshold and grows exponentially above it.
    divisor = max(1., np.exp(perplexity_score - perplexity_threshold))
    nli_quality = 1 - nli_scores["contradiction"]

    # F-beta style combination; define it as 0 when both components are 0
    # instead of evaluating 0 / 0.
    denominator = beta ** 2 * nli_quality + distance_score
    if denominator == 0:
        f_score = 0.0
    else:
        f_score = (1 + beta ** 2) * nli_quality * distance_score / denominator

    return {
        "nli_quality": nli_quality,
        "distance_score": distance_score,
        "perplexity_divisor": divisor,
        "final": f_score / divisor
    }

# Index file mapping each hypothesis text to the path of its saved embeddings.
# The encoding is explicit so non-ASCII hypothesis keys load the same way on
# every platform.
with open("data/embeddings.json", encoding="utf-8") as f:
    hypothesis_files = json.load(f)

# NOTE(security): torch.load unpickles the file; only load files you trust.
# map_location="cpu" lets tensors saved on a GPU be converted with .numpy().
embeddings = {
    text: torch.load(path, map_location="cpu").numpy()
    for text, path in hypothesis_files.items()
}

# NLI classifier used by perform_nli.
nli_tokenizer = AutoTokenizer.from_pretrained('cointegrated/rubert-base-cased-nli-threeway')
nli_model = AutoModelForSequenceClassification.from_pretrained('cointegrated/rubert-base-cased-nli-threeway')

# Sentence encoder used by get_embedding.
sbert_tokenizer = AutoTokenizer.from_pretrained("sberbank-ai/sbert_large_nlu_ru")
sbert_model = AutoModel.from_pretrained("sberbank-ai/sbert_large_nlu_ru")

# Language model used by estimate_perplexity.
gpt_tokenizer = AutoTokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
gpt_model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')

Downloading:   0%|          | 0.00/545 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.65M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.62M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.12k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/712M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/323 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/655 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.78M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.71M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.27M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Downloading:   0%|          | 0.00/551M [00:00<?, ?B/s]

In [None]:
%%time

from pprint import pprint

# One hypothesis, scored against two candidate premises.
hypothesis = "На ногах у Ивана резиновые сапоги. Иван ходит по лесу с ножом и корзиной."
premise1 = "Иван любит собирать грибы."
premise2 = "Ивана похитили инопланетяне и оставили одного в лесу в сапогах."

# Premises are scored in order (premise1 first, then premise2).
res1, res2 = (
    score(candidate, hypothesis, embeddings)
    for candidate in (premise1, premise2)
)

pprint({premise1: res1, premise2: res2})

[autoreload of torch.overrides failed: Traceback (most recent call last):
  File "/home/data_sapiens/.local/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/home/data_sapiens/.local/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 394, in superreload
    module = reload(module)
  File "/usr/lib/python3.8/imp.py", line 314, in reload
    return importlib.reload(module)
  File "/usr/lib/python3.8/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 604, in _exec
  File "<frozen importlib._bootstrap_external>", line 848, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/home/data_sapiens/.local/lib/python3.8/site-packages/torch/overrides.py", line 1365, in <module>
    has_torch_function = _add_docstr(
RuntimeError: function '_has_torch_function' already has a do