# Hyperparameter Sweep

This notebook demonstrates how to use simple grid search to find the best hyperparameters for a deduplication method.

### Dependencies

In [1]:
import pickle
import time
from collections import defaultdict

import datasets
import pandas as pd

In [2]:
%%capture --no-display
dataset_path = "temp_inp"
ds = datasets.load_dataset("pinecone/core-2020-05-10-deduplication", split="train")
ds = ds.map(lambda x: {"text": (x["processed_title"] + " " + x["processed_abstract"]).lower()})
ds.save_to_disk(dataset_path)

Saving the dataset (0/1 shards):   0%|          | 0/100000 [00:00<?, ? examples/s]

In [3]:
truth = ds.map(lambda x, id: {"core_id": x["core_id"], "id": id, "duplicates": x["labelled_duplicates"]}, remove_columns=ds.column_names, with_indices=True)
id2core_id = {x["id"]: int(x["core_id"]) for x in truth}
labels = {int(x["core_id"]): set(map(int, x["duplicates"])) if x["duplicates"] else set() for x in truth}

Loading cached processed dataset at /Users/chenghao/.cache/huggingface/datasets/pinecone___json/pinecone--core-2020-05-10-deduplication-dbaaf752a12c0b16/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-7d521c7365b892fc.arrow


In [4]:
def evaluate(path):

    with open(path, "rb") as f:
        uf = pickle.load(f)

    id2cluster = defaultdict(set)
    for id, cluster in uf.parent.items():
        id2cluster[cluster].add(id)

    predictions = {
        id2core_id[x["id"]]: set(
            [id2core_id[neighbor] for neighbor in id2cluster[uf.find(x["id"])] if neighbor != x["id"]]
        )
        for x in truth
    }
    df = (
        pd.Series(labels)
        .to_frame("duplicates")
        .reset_index()
        .merge(pd.Series(predictions).to_frame("predictions").reset_index(), on="index")
    )

    df["Correct"] = df.apply(lambda row: set(row["duplicates"]) == set(row["predictions"]), axis=1).astype(int)
    prediction_summary = {"Correct": df["Correct"].sum(), "Incorrect": df.shape[0] - df["Correct"].sum()}
    prediction_summary["Accuracy"] = round(prediction_summary["Correct"] / df.shape[0], 4)

    def _recall(row):
        labelled_dups = set(row["duplicates"])
        if len(labelled_dups) == 0:
            return 1
        dups = set(row["predictions"])
        return len(dups & labelled_dups) / len(labelled_dups)

    recalls = df.apply(lambda row: _recall(row), axis=1)
    prediction_summary["Recall"] = round(recalls.mean(), 4)

    def _precision(row):
        labelled_dups = set(row["duplicates"])
        dups = set(row["predictions"])
        if len(dups) == 0:
            return 0

        return len(dups & labelled_dups) / len(dups)

    precisions = df.apply(lambda row: _precision(row), axis=1)
    prediction_summary["Precision"] = round(precisions.mean(), 4)

    return prediction_summary


In [8]:
bit_diff = [1, 2, 3, 4, 5, 6]
ngram = [2, 3, 4, 5, 6, 7, 8, 9, 10]
results = []
temp_dir = "temp_simhash"
for bd in bit_diff:
    for ng in ngram:
        num_bucket = max(3, bd + 1)
        start_time = time.time()
        !python -m text_dedup.simhash --path ./$dataset_path --local --column text --output $temp_dir --split train --debug \
        --bit_diff $bd \
        --num_bucket $num_bucket \
        --ngram $ng >> /dev/null 2>&1
        print(f"Running with bit_diff={bd} and ngram={ng}")
        metrics = evaluate(f"{temp_dir}/uf.pkl")
        metrics["time"] = time.time() - start_time
        metrics["bit_diff"] = bd
        metrics["ngram"] = ng
        results.append(metrics)

Running with bit_diff=1 and ngram=2
Running with bit_diff=1 and ngram=3
Running with bit_diff=1 and ngram=4
Running with bit_diff=1 and ngram=5
Running with bit_diff=1 and ngram=6
Running with bit_diff=1 and ngram=7
Running with bit_diff=1 and ngram=8
Running with bit_diff=1 and ngram=9
Running with bit_diff=1 and ngram=10
Running with bit_diff=2 and ngram=2
Running with bit_diff=2 and ngram=3
Running with bit_diff=2 and ngram=4
Running with bit_diff=2 and ngram=5
Running with bit_diff=2 and ngram=6
Running with bit_diff=2 and ngram=7
Running with bit_diff=2 and ngram=8
Running with bit_diff=2 and ngram=9
Running with bit_diff=2 and ngram=10
Running with bit_diff=3 and ngram=2
Running with bit_diff=3 and ngram=3
Running with bit_diff=3 and ngram=4
Running with bit_diff=3 and ngram=5
Running with bit_diff=3 and ngram=6
Running with bit_diff=3 and ngram=7
Running with bit_diff=3 and ngram=8
Running with bit_diff=3 and ngram=9
Running with bit_diff=3 and ngram=10
Running with bit_diff=4 a

In [10]:
pd.DataFrame(results).sort_values("Accuracy", ascending=False)

Unnamed: 0,Correct,Incorrect,Accuracy,Recall,Precision,time,bit_diff,ngram
46,82075,17925,0.8208,0.8413,0.3544,209.974036,6,3
47,80936,19064,0.8094,0.8277,0.3422,212.105268,6,4
48,80345,19655,0.8034,0.821,0.3352,181.985123,6,5
37,80276,19724,0.8028,0.821,0.3351,80.227249,5,3
49,79774,20226,0.7977,0.8155,0.33,174.631982,6,6
38,79331,20669,0.7933,0.81,0.3243,60.95705,5,4
50,79261,20739,0.7926,0.8103,0.3249,172.973809,6,7
51,79146,20854,0.7915,0.8088,0.3235,173.322175,6,8
39,78764,21236,0.7876,0.8039,0.3181,58.976702,5,5
52,78654,21346,0.7865,0.8042,0.3189,184.024269,6,9
