In [None]:
import json 
import pandas as pd
from datetime import datetime
import torch
from torch import nn
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, CrossEncoder, evaluation, InputExample, datasets
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
from sentence_transformers import util as sentenceutils
# Pick the GPU when CUDA is usable, otherwise fall back to the CPU.
# Note: no other cell visible in this notebook references `device`.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
import random 
import math

In [None]:
# Root folder holding the CSV splits for the current dataset; change this path
# to switch datasets. The trailing slash matters: file names are appended to it
# by plain string concatenation when the CSVs are read.
data_folder = '/contextretrieval/cross-encoder/wow/splits/' 

In [None]:
# Training pairs: later cells read the 'input', 'passage' and 'label' columns.
train_pairs = pd.read_csv(data_folder + 'train_pairs.csv')

# Evaluation samples: the 'positive' and 'negative' columns are stored as text
# in the CSV and are parsed back into Python objects while reading.
# NOTE(review): pd.eval evaluates the cell text as an expression. If these CSVs
# can come from an untrusted source, or hold very long lists, a stricter parser
# such as ast.literal_eval may be preferable -- confirm before changing.
test_samples = pd.read_csv(
    data_folder + 'test_samples.csv',
    converters=dict.fromkeys(('positive', 'negative'), pd.eval),
)

### Train Samples

In [None]:
# Wrap every row of train_pairs in an InputExample: the (input, passage) text
# pair goes into `texts`, and the row's label is attached as `label`.
# Rows are walked in file order via the underlying column arrays.
train_samples = [
    InputExample(texts=[query_text, passage_text], label=pair_label)
    for query_text, passage_text, pair_label in zip(
        train_pairs['input'].to_numpy(),
        train_pairs['passage'].to_numpy(),
        train_pairs['label'].to_numpy(),
    )
]

In [None]:
# Training hyperparameters.
train_batch_size = 16  # batch size handed to the training DataLoader
num_epochs = 3  # passed to fit(epochs=...) and used to size the warmup

In [None]:
# Batch the training examples, reshuffling them on every pass over the data.
train_dataloader = DataLoader(
    dataset=train_samples,
    batch_size=train_batch_size,
    shuffle=True,
)

### Test Samples

In [None]:
# Convert the parsed 'negative' and 'positive' entries of each row into sets,
# which also drops any duplicate passages within a row.
for column in ('negative', 'positive'):
    test_samples[column] = test_samples[column].apply(set)

In [None]:
# Re-shape the frame into {row index: {column name: value}}.
# Note: this rebinds `test_samples` from a DataFrame to a plain dict, so this
# cell cannot be run a second time without reloading the CSV first.
test_samples = test_samples.to_dict(orient='index')

## Train Model 

In [1]:
# Load the base cross-encoder checkpoint from a local directory.
# num_labels=1 presumably yields a single relevance score per text pair, and
# max_length=512 presumably caps the tokenized pair length -- see the
# CrossEncoder documentation to confirm.
cross_encoder = CrossEncoder(
    '/contextretrieval/cross-encoder/ms-marco-MiniLM-L-6-v2',
    num_labels=1,
    max_length=512,
)

In [None]:
# Warm up over 10% of all training batches (batches per epoch x epochs),
# rounded up to a whole step.
total_training_batches = len(train_dataloader) * num_epochs
warmup_steps = math.ceil(total_training_batches * 0.1)

In [None]:
# Re-ranking evaluator built from the held-out samples (each entry carries the
# 'positive' and 'negative' passage sets prepared above).
evaluator = CERerankingEvaluator(samples=test_samples)

In [None]:
# Folder for fine-tuned model checkpoints. Unlike data_folder, this value has
# no trailing slash, so anything concatenated directly onto it is not placed
# inside the folder unless a '/' separator is added first.
output_folder = '/contextretrieval/cross-encoder/wow/tuned_models'

In [None]:
# Fine-tune the cross-encoder and save it under a per-run, timestamped folder.
#
# The output path is assembled with explicit separators. Plain concatenation
# of output_folder (which has no trailing slash) with the model name produced
# ".../tuned_modelsms-marco-MiniLM-L-6-v2<timestamp>", i.e. a sibling of the
# intended folder, with the model name running straight into the timestamp.
# rstrip('/') keeps this correct whether or not output_folder ends in a slash.
run_name = 'ms-marco-MiniLM-L-6-v2' + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_output_path = output_folder.rstrip('/') + '/' + run_name

cross_encoder.fit(train_dataloader=train_dataloader,
          evaluator=evaluator,          # re-ranking evaluation on the test samples
          epochs=num_epochs,
          evaluation_steps=500,         # run the evaluator every 500 training steps
          warmup_steps=warmup_steps,
          output_path=model_output_path,
          use_amp=True)                 # request automatic mixed precision