In [1]:
import pandas as pd
import numpy as np
import torch

from sklearn.metrics import classification_report, accuracy_score

import sys
sys.path.append('../rzd')

from information_retrieval import Embedder

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Checkpoint identifier handed to the project's Embedder.
model_name_or_path = 'd0rj/ruRoberta-distilled'

# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

embedder = Embedder(
    model_name_or_path=model_name_or_path,
    device=device,
)

No sentence-transformers model found with name /Users/alexnikko/.cache/torch/sentence_transformers/d0rj_ruRoberta-distilled. Creating a new one with MEAN pooling.
Some weights of RobertaModel were not initialized from the model checkpoint at /Users/alexnikko/.cache/torch/sentence_transformers/d0rj_ruRoberta-distilled and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Load the retrieval keys from a local pickle file and embed them with the
# embedder; the predictions returned later are looked up in `index` by position.
# Presumably `index` is a sequence of key strings — TODO confirm against the file.
# NOTE(review): read_pickle unpickles arbitrary objects — only load this file
# from a trusted source.
index = pd.read_pickle('database_keys.p')
index_embs = embedder.get_embeddings(index)

In [4]:
validation_dataset = pd.read_csv('validation_dataset.csv')

# Ground-truth side: the raw label array and the distinct labels it contains.
gt_column = validation_dataset['gt']
y_true = gt_column.values
unique_labels = gt_column.unique()

# Query side: the query texts as a plain list, then their embeddings.
queries = validation_dataset['query'].tolist()
query_embs = embedder.get_embeddings(queries)

In [5]:
def calculate_pairwise_cos_sim(query_embs, index_embs):
    """Cosine similarity between every query embedding and every index embedding.

    Args:
        query_embs: torch tensor of query embeddings (assumed to be
            ``(num_queries, dim)`` — TODO confirm against the embedder).
        index_embs: torch tensor of index embeddings (assumed to be
            ``(num_keys, dim)`` — TODO confirm against the embedder).

    Returns:
        Tensor of similarities reduced over the last dimension; for 2-D inputs
        entry ``[i, j]`` compares query ``i`` with index entry ``j``.
    """
    # Insert singleton axes so the two inputs broadcast against each other.
    queries_as_rows = query_embs.unsqueeze(1)
    index_as_cols = index_embs.unsqueeze(0)
    return torch.cosine_similarity(queries_as_rows, index_as_cols, dim=-1)

def calculate_top_1_prediction(pairwise_dists):
    """Return, for each row, the column index holding the largest score.

    Despite the parameter name, larger values win: ``torch.topk`` selects the
    maximum along ``dim=1``.

    Args:
        pairwise_dists: torch tensor of scores with one row per query; the
            maximum is searched along ``dim=1``.

    Returns:
        Flat NumPy array of integer indices, one per row of the input.
    """
    topk = torch.topk(pairwise_dists, 1, dim=1)
    # Tensor.numpy() only works on CPU tensors. Moving the indices to the host
    # first keeps this working when the scores live on a CUDA device; for CPU
    # tensors .cpu() is a no-op.
    return topk.indices.ravel().cpu().numpy()

def get_y_pred(index, top_1_indices):
    """Look up the index entry selected for each query.

    Args:
        index: container supporting ``index[k]`` lookups.
        top_1_indices: iterable of lookup keys/positions, one per query.

    Returns:
        List with ``index[k]`` for every ``k`` in ``top_1_indices``, in order.
    """
    predictions = []
    for position in top_1_indices:
        predictions.append(index[position])
    return predictions


def calculate_y_pred(query_embs, index_embs, index):
    """Predict one index entry per query via top-1 cosine similarity.

    Computes pairwise cosine similarities between the query and index
    embeddings, picks the best-scoring index position for each query, and
    returns the corresponding entries of ``index``.

    Args:
        query_embs: embeddings of the queries.
        index_embs: embeddings of the index entries.
        index: container the predicted entries are looked up in.

    Returns:
        List of predicted entries, one per query.
    """
    return get_y_pred(
        index,
        calculate_top_1_prediction(
            calculate_pairwise_cos_sim(query_embs, index_embs)
        ),
    )

In [6]:
# Top-1 retrieval prediction for every validation query.
y_pred = calculate_y_pred(
    query_embs=query_embs,
    index_embs=index_embs,
    index=index,
)

In [7]:
# Overall accuracy first, then the per-label breakdown limited to unique_labels.
accuracy = accuracy_score(y_true, y_pred)
print(f'Accuracy = {accuracy}')

report = classification_report(y_true, y_pred, labels=unique_labels)
print(report)

Accuracy = 0.81
                                                                                                      precision    recall  f1-score   support

                                                                                   Не включилось РУ6       1.00      1.00      1.00        10
                                        Реле РУ6 срабатывает, но не включается реле времени РВ1, РВ2       1.00      1.00      1.00        10
                  При нажатии кнопки "Пуск дизеля" (все нужные автоматы включены) КМН не включается.       1.00      1.00      1.00        10
  При нажатии кнопки "Пуск дизеля" контактор КМН включается, но маслопрокачивающий насос не работает       1.00      1.00      1.00        10
При пуске прокачка масла есть (60-90 сек), но после отключения КМН пусковые контакторы не включаются       1.00      1.00      1.00        10
                                                      При работающем дизеле нет тока зарядки батареи       1.00      0.30      0.46