In [2]:
import json

import sklearn
import torch
import pandas as pd
from tqdm import tqdm

from sentence_transformers import SentenceTransformer, InputExample, util, losses
from sentence_transformers import evaluation
from torch.utils.data import DataLoader

<IPython.core.display.Javascript object>

### Загрузка данных для дообучения трансформера

In [3]:
TRAIN_DATASET_PATH = r"/home/train.json"
TEST_DATASET_PATH = r"/home/test.json"

with open(TRAIN_DATASET_PATH, "r") as dataset_file:
    train_dataset = json.load(dataset_file)

with open(TEST_DATASET_PATH, "r") as dataset_file:
    test_dataset = json.load(dataset_file)

### Инициализация предобученной модели трансформера

In [4]:
MODEL = "sentence-transformers/all-mpnet-base-v2"
model = SentenceTransformer(MODEL)

### Подготовка датасета для дообучения и эвалюатора для теста

In [6]:
BATCH_SIZE = 32
DATASET_LEN = len(train_dataset)

train_examples = []
for i in tqdm(range(len(train_dataset))):
    train_examples.append(InputExample(texts=[train_dataset[i]["text"], train_dataset[i]["comment"]], label=train_dataset[i]["score"]))

evaluator = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(test_dataset, batch_size=BATCH_SIZE)

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
train_loss = losses.CosineSimilarityLoss(model)

100%|██████████| 440535/440535 [00:01<00:00, 252567.34it/s]

BATCH_SIZE = 32
EPOCHS_NUM = 50
DATASET_LEN = 440535
STEPS_PER_ERPOCH = 1000





### Запуск дообучения на 50 эпохах по 1000 шагов

In [None]:
EPOCHS_NUM = 50
STEPS_PER_EPOCH = 1000
WARMUP_STEPS = int(len(train_dataloader) * EPOCHS_NUM * 0.1)

model.fit(
    train_objectives=[(train_dataloader, train_loss)], 
    epochs=EPOCHS_NUM, 
    warmup_steps=WARMUP_STEPS,
    steps_per_epoch=STEPS_PER_EPOCH,
    evaluator=evaluator,
    evaluation_steps=2500,
    
    output_path="./output",
    checkpoint_path="./checkpoint",
    checkpoint_save_total_limit=25,

    use_amp=True
)

Epoch:   0%|          | 0/50 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1000 [00:00<?, ?it/s]

### Загрузка дообученной модели

In [None]:
MODEL = "./output"
model = SentenceTransformer(MODEL)

### Проверка дообученной модели на текстовых данных

In [11]:
test = []
for item in tqdm(test_dataset[:5000]):
    text = item.texts[0]
    comment = item.texts[1]
    orig_score = item.label
    
    text_emb = model.encode(text)
    comment_emb = model.encode(comment)
    score = util.dot_score(text_emb, comment_emb)[0].cpu().tolist()[0]
    
    test.append({
        "text": text,
        "comment": comment,
        "orig_score": orig_score,
        "calculated_score": score
    })
    
test_df = pd.DataFrame(test)
test_df

100%|██████████| 5000/5000 [02:03<00:00, 40.60it/s]


Unnamed: 0,text,comment,orig_score,score
0,A portable high-resolution timestamp in C++,If you want to see what a rabbit hole looks li...,1.0,0.788209
1,Underwater Kites Can Harness Ocean Currents to...,"More, uhm, sophisticated designs are already i...",0.4,0.486585
2,Software Effort Estimation Considered Harmful,He's sometimes right. There are lots of good ...,0.8,0.632986
3,Obama Must Stop Neglecting India,aargh! Politics! why is this on HN? And that t...,0.6,0.660816
4,Ask YC: How do you defend the downmodded?,I do a little of both; I'll spend a little if ...,0.6,0.521535
...,...,...,...,...
4995,"BashBooster – simple provisioning, bash only, ...",This reminds me of fucking_shell_scripts[1] bu...,0.6,0.793892
4996,A Story on Drive,what a short but inspiring paragraph. I just ...,0.4,0.630290
4997,"India blocks 73 URLs criticizing IIPM, an MBA ...",I think a lot of people are confused with the ...,0.4,0.554567
4998,"Ask HN: is ""The Social Network"" motivating ent...",I think it will inspire a bunch of people to t...,0.8,0.564551
