In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os
from pathlib import Path
from itertools import chain

import torch
import pandas as pd

from torch.utils.data import DataLoader
from sklearn import metrics as sk_metrics
from catalyst.utils import set_global_seed
from sklearn.model_selection import train_test_split
from catalyst.dl import CriterionCallback, CheckpointCallback, AUCCallback

sys.path.insert(0, "../")

from dupbert.model import DupBERT
from dupbert.runner import TripletRunner
from dupbert.dataset import TripletDataset

from dupbert.dataset import TripletDataset
from dupbert.config_reader import ConfigReader
from dupbert.transforms import TextTokenizer, Encoder, PadSequencer

# DupBERT example on Quora data

## Set up the directories and load the config

In [None]:
# Project folders are resolved relative to the parent of the current
# working directory.
main_path = Path.cwd().parents[0]
data_path = main_path / "data"
logs_path = main_path / "logs"
config_path = main_path / "config/config.yaml"

# Read the experiment settings from the config file
config = ConfigReader.load(config_filepath=config_path)

## Load the data

In [None]:
# Columns that make up one (question1, question2, label) triplet
triplet_columns = ["question1", "question2", "is_duplicate"]

# Take a small fixed-seed sample, keep only complete rows and store the
# label as a float.
train_df = (
    pd.read_csv(data_path / "train.csv")
    .sample(n=100, random_state=123)[triplet_columns]
    .dropna(how="any")
    .assign(is_duplicate=lambda frame: frame.is_duplicate.astype(float))
)

# Hold out a validation part; split options come from the config
train_X, valid_X = train_test_split(train_df, **config.train_test_split)

# Plain arrays of triplets, one row per question pair
train_triplets = train_X[triplet_columns].values
valid_triplets = valid_X[triplet_columns].values

## Set up the experiment parameters

In [None]:
# Fix the global seed before anything is initialised
set_global_seed(config.seed)

# Model built from the config
model = DupBERT(**config.model_params)

# Runner, loss and optimizer
runner = TripletRunner()
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(
    params=[{"params": model.parameters()}],
    **config.optimizer
)

# Text preprocessing stages, each configured from its own config section
encoder = Encoder(**config.encoder)
pad_sequencer = PadSequencer(**config.pad_sequencer)
txt_tokenizer = TextTokenizer(**config.txt_tokenizer)

# Callbacks: AUC metric, loss computation (stored under "loss") and
# checkpointing driven by the "valid" loader.
callbacks = [
    AUCCallback(
        input_key=model.keys.model_output,
        target_key=model.keys.targets
    ),
    CriterionCallback(
        input_key=model.keys.model_output,
        target_key=model.keys.targets,
        metric_key="loss"
    ),
    CheckpointCallback(
        loader_key="valid",
        **config.early_stopping
    ),
]

## Dataset preparation

In [None]:
# Build the train / validation datasets from the triplet arrays and the three
# preprocessing objects.
train_dataset = TripletDataset(
    train_triplets, txt_tokenizer,
    encoder, pad_sequencer,
    train_mode=True
)
valid_dataset = TripletDataset(
    valid_triplets, txt_tokenizer,
    encoder, pad_sequencer,
    train_mode=False
)

train_loader = DataLoader(train_dataset, **config.loaders)

# The Predictions section pairs the i-th prediction coming out of
# `valid_loader` with the i-th row of `valid_X`. The validation loader must
# therefore iterate in dataset order and keep every sample, whatever the
# shared loader config says about shuffling or dropping the last batch.
valid_loader_params = {**config.loaders, "shuffle": False, "drop_last": False}
valid_loader = DataLoader(valid_dataset, **valid_loader_params)

# Loader names used by the runner and the callbacks
loaders = {
    "train": train_loader,
    "valid": valid_loader,
}

## Training

In [None]:
# Model selection settings: the "auc" metric on the "valid" loader,
# not minimised.
model_selection = dict(
    valid_loader="valid",
    valid_metric="auc",
    minimize_valid_metric=False,
)

# Run training; the remaining options come from the config
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    callbacks=callbacks,
    verbose=True,
    **model_selection,
    **config.train
)

## Predictions

In [None]:
# Collect the "logits" output of every validation batch as a NumPy array
predictions = [
    batch_output["logits"].detach().cpu().numpy()
    for batch_output in runner.predict_loader(loader=valid_loader)
]

# Flatten the batches into rows, take the first value of each row as the
# score and threshold it at 0.5 to get the predicted class.
pred_probs = [row[0] for row in chain.from_iterable(predictions)]
pred_class = [int(prob > .5) for prob in pred_probs]

targets = valid_X["is_duplicate"]

# One table row per metric returned by scikit-learn
prf_support = sk_metrics.precision_recall_fscore_support(
    y_true=targets,
    y_pred=pred_class,
)
metrics_df = pd.DataFrame(
    prf_support,
    index=["precision", "recall", "f1", "support"]
)

In [None]:
# Display the precision / recall / f1 / support table computed on the validation split
metrics_df