In [None]:
import gc
import json
import logging
import os
import re
import time
from collections import defaultdict
from datetime import datetime

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset
from tqdm.notebook import tqdm

from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator

In [None]:
# Experiment configuration — everything a reader might want to tune.
train_batch_size = 16                              # DataLoader batch size for fine-tuning
model_name = 'cross-encoder/ms-marco-TinyBERT-L-6' # pretrained checkpoint passed to CrossEncoder
model_save_path = 'models/crenc-readme-exp2'       # output_path for model.fit (best checkpoint)
data_folder = 'generated4'                         # subfolder of ./data holding the *_passage / *_corpus JSON files

# Seed the RNGs so DataLoader shuffling (torch) and any numpy randomness are
# reproducible under Restart Kernel -> Run All.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Load the raw sheet, derive the `answer` text column the cross-encoder is
# trained on, then drop every row with a missing value in ANY column. Because
# `readme_short` itself is still a column, rows whose README is NaN are dropped
# here, so they never survive as the literal string 'nan' in `answer`.
# NOTE(review): dropna() over all columns may also remove passage ids that the
# JSON triplets reference, which would surface later as a KeyError — confirm.
df = (
    pd.read_excel('./data/20231004_data.xlsx', index_col=0)
    .assign(answer=lambda frame: frame['readme_short'].astype(str))
    .dropna()
)
df.head(1)

In [None]:
# Sanity check after cleaning: column dtypes, non-null counts and memory usage.
df.info()

In [None]:
def get_triplets(Passage_dict):
    """Enumerate every (query, positive, negative) combination.

    Args:
        Passage_dict: Mapping from query id to a sequence whose element 0 is
            an iterable of positive passage ids and whose element 1 is an
            iterable of negative passage ids (any further elements are ignored).

    Returns:
        list of ``[query_id, positive_id, negative_id]`` lists — the Cartesian
        product of positives x negatives for each query, in mapping order, then
        positive order, then negative order.
    """
    return [
        [query_id, positive_id, negative_id]
        for query_id, groups in Passage_dict.items()
        for positive_id in groups[0]
        for negative_id in groups[1]
    ]

def get_dataset(triplets, corpus, answers=None):
    """Expand (query, positive, negative) triplets into labelled cross-encoder pairs.

    For every triplet two ``InputExample`` objects are appended, in this order:
    ``[query_text, positive_answer]`` with ``label=1``, then
    ``[query_text, negative_answer]`` with ``label=0``.

    Args:
        triplets: Iterable of ``(qid, pos_id, neg_id)`` triples, e.g. the
            output of ``get_triplets``.
        corpus: Mapping from query id (as ``str``) to query text.
        answers: Optional pandas Series indexed by passage id holding the
            answer text. When omitted, the notebook-level ``df['answer']`` is
            used (read at call time), matching the original behaviour.

    Returns:
        list of ``InputExample`` — two per triplet.

    Raises:
        KeyError: if a query id is missing from ``corpus`` or a passage id is
            missing from ``answers``.
    """
    if answers is None:
        # Backward-compatible default: previously this always read the global frame.
        answers = df['answer']

    dataset = []
    for qid, pos_id, neg_id in triplets:
        # corpus is keyed by str (JSON object keys), so normalise the query id.
        query_text = corpus[str(qid)]
        pos_text = answers.loc[pos_id]
        neg_text = answers.loc[neg_id]

        dataset.append(InputExample(texts=[query_text, pos_text], label=1))
        dataset.append(InputExample(texts=[query_text, neg_text], label=0))

    return dataset


def _load_split_json(name):
    """Read ``./data/<data_folder>/<name>.json`` and return the parsed object.

    The file is opened explicitly as UTF-8 (the encoding RFC 8259 requires for
    interchanged JSON) rather than the platform-dependent default encoding.
    """
    with open(f'./data/{data_folder}/{name}.json', 'r', encoding='utf-8') as f:
        return json.load(f)


train_passage = _load_split_json('train_passage')
train_corpus = _load_split_json('train_corpus')
val_passage = _load_split_json('val_passage')
val_corpus = _load_split_json('val_corpus')

train_triplets = get_triplets(train_passage)
train_dataset = get_dataset(train_triplets, train_corpus)

val_triplets = get_triplets(val_passage)
val_dataset = get_dataset(val_triplets, val_corpus)

# Fail fast: an empty split would make the DataLoader / evaluator built later meaningless.
assert train_dataset, 'no training examples were built'
assert val_dataset, 'no validation examples were built'

In [None]:
# Route library INFO logs through sentence-transformers' LoggingHandler.
# `%(asctime)s` is required for `datefmt` to have any effect, and `force=True`
# replaces handlers already attached to the root logger — without it
# basicConfig is a silent no-op when something configured logging earlier or
# this cell is re-run.
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[LoggingHandler()],
    force=True,
)

# NOTE(review): max_length=None presumably defers truncation to the tokenizer's
# own limit — confirm long READMEs are truncated as intended.
model = CrossEncoder(model_name, max_length=None)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
# Binary-classification evaluator over the labelled (query, answer) validation pairs.
evaluator = CEBinaryClassificationEvaluator.from_input_examples(val_dataset, name='cross_encoder_val')

In [None]:
# Single source of truth for the schedule: warm-up is derived from the same
# epoch count passed to fit(), so changing one cannot desync the other.
num_epochs = 2
warmup_ratio = 0.1  # fraction of total training steps used for warm-up

steps_per_epoch = len(train_dataloader)
warmup_steps = int(steps_per_epoch * num_epochs * warmup_ratio)

model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    # Run the evaluator roughly every half epoch; with save_best_model=True the
    # best-scoring model is written to model_save_path.
    evaluation_steps=steps_per_epoch // 2,
    warmup_steps=warmup_steps,
    save_best_model=True,
    output_path=model_save_path,
)

In [None]:
# Release the in-memory model and hand cached GPU memory back to the driver.
# The guard keeps this cell re-runnable (a bare `del model` raises NameError the
# second time); gc.collect() first reclaims reference cycles that may still be
# holding tensors, so empty_cache() can actually free them.
if 'model' in globals():
    del model
gc.collect()
torch.cuda.empty_cache()