In [56]:
import faiss
import torch
from datasets import Dataset, DatasetDict, load_dataset
from sentence_transformers import SentenceTransformer
from tqdm.autonotebook import tqdm

In [2]:
ds = load_dataset(
    "/home/pranav-pc/projects/OpenTransformer/multiformer/data/downloads/TinyStories",
    split="validation",
)

In [3]:
ds

Dataset({
    features: ['text'],
    num_rows: 21990
})

In [4]:
import faiss
from sentence_transformers import SentenceTransformer

In [5]:
# https://huggingface.co/spaces/mteb/leaderboard
model = "mixedbread-ai/mxbai-embed-large-v1"
sentence_model = SentenceTransformer(model)

In [7]:
ds = ds.map(
    lambda example: {"embedding": sentence_model.encode(example["text"])},
    batched=True,
)
ds.set_format("pt")

Map:   0%|          | 0/21990 [00:00<?, ? examples/s]

Map (num_proc=25):   0%|          | 0/21990 [00:00<?, ? examples/s]

TimeoutError: 

In [51]:
def normalize_embedding(example):
    embedding = example["embedding"]
    norm = torch.norm(embedding, dim=1, keepdim=True)
    normalized_embedding = embedding / norm
    return {"embedding": normalized_embedding}


ds = ds.map(normalize_embedding, batched=True, batch_size=int(1e4))

Map:   0%|          | 0/21990 [00:00<?, ? examples/s]

In [52]:
dim = ds[0]["embedding"].shape[0]
index = faiss.IndexFlatIP(dim)
index.add(ds["embedding"])

In [53]:
D, I = index.search(
    ds[:10]["embedding"],
    k=2,
)

In [66]:
print("Filtering out near-duplicates...")
D, I = index.search(ds["embedding"], k=2)
to_keep = []
threshold = 0.95

Filtering out near-duplicates...


Filtering:   0%|          | 0/21990 [00:00<?, ?it/s]

21643
18677
4269
8467
784
4201
6637
21486
18563
2257
18655
9850
12639
5632
5713
6299
19916
149
1696
1696
10182
5593
9646
14805
17027
105
17301
6610
16647
21486
12638
18712
12458
356
3566
6939
4202
17027
14038
8086
12013
5731
2594
12390
10099
20428
17057
154
2633
6171
6167
12518
110
10933
21442
18058
8450
21441
17027
17523
7941
1566
105
432
3566
13446
255
3228
18000
6387
16884
4236
4205
20385
6995
21460
13522
12537
5713
6526
5833
18622
18256
6385
16380
1657
255
8471
16302
10253
11926
6233
8550
18106
6307
255
15975
18734
18091
5590
2257
1696
3139
11926
18610
4712
111
6939
12524
7033
16751
21764
1573
12621
758
743
13502
3566
17200
407
21460
3941
1539
3566
6951
16783
16725
11499
13694
8471
4274
16409
9787
18397
10752
21678
6957
21460
17608
12518
18655
3165
13358
17337
255
423
13694
19908
19905
1222
3806
2164
18418
4243
8483
6307
20405
9666
1717
423
17178
12047
11926
5713
18007
9783
6306
16814
5833
18510
13788
19125
12086
18563
16567
17514
13646
17343
18511
9850
16888
18726
1626
15577
18562

In [106]:
import pandas as pd

df = pd.DataFrame(D)

In [129]:
I[df[(df > 0.98).sum(axis=1) == 2].index]

array([], shape=(0, 2), dtype=int64)

In [130]:
ds[13492]["text"]

'Once upon a time, there was a little boy named Timmy. Timmy liked to climb trees. One day, he saw a big green tree and wanted to climb it. He said to his mom, "Mommy, can I climb that big green tree?" His mom said, "No Timmy, that tree is too high and it\'s bad for you to climb it." Timmy was sad but he listened to his mom.\n\nThe next day, Timmy saw a smaller tree that was also green. He asked his mom, "Mommy, can I climb that small green tree?" His mom said, "Yes Timmy, that tree is not too high and it\'s safe for you to climb it." Timmy was happy and climbed the tree. He felt like a big adventurer.\n\nWhen Timmy got to the top of the tree, he shouted down to his mom, "Mommy, I climbed the tree!" His mom smiled and said, "Good job Timmy, you are a great climber!" Timmy felt proud of himself and couldn\'t wait to climb more trees.'

In [128]:
ds[18007]["text"]

'Once upon a time, there was a little boy named Timmy. Timmy loved to climb trees. One day, Timmy saw a really high tree and he wanted to climb it. \n\nTimmy\'s mom said, "Be careful Timmy, that tree is really high." \n\nTimmy said, "I can do it, Mommy!" \n\nSo, Timmy climbed and climbed until he reached the top of the tree. He looked down and saw his mom tapping her foot. \n\n"Come down, Timmy," she said. \n\nTimmy climbed back down and said, "That was so much fun! Can we climb another tree tomorrow?" \n\nHis mom smiled and said, "Sure, Timmy. But let\'s find a shorter one next time."'

In [65]:
for i in tqdm(range(len(ds["embedding"])), desc="Filtering"):
    # If the second closest vector (D[i, 1]) has cosine similarity above the threshold
    if D[i, 1] >= threshold:
        # Check if either the current item or its nearest neighbor is already in the to_keep list
        nearest_neighbor = I[i, 1]
        if i not in to_keep and nearest_neighbor not in to_keep:
            # If not, add the current item to the list
            to_keep.append(i)
        print(nearest_neighbor)
    else:
        # If the similarity is below the threshold, always keep the current item
        to_keep.append(i)

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 184,


In [69]:
len(to_keep)

21864

In [64]:
len(I)

21990

In [132]:
ds.select(to_keep)

Dataset({
    features: ['text', 'embedding'],
    num_rows: 21864
})