In [6]:
import os
from sentence_transformers import SentenceTransformer
import gzip
import json
import torch
import pyarrow.parquet as pq

In [2]:
# Instantiate the sentence-embedding model that the title-encoding cells below call.
# cache_folder points the download/cache location at ../pretrained.
model = SentenceTransformer(
    'flax-sentence-embeddings/all_datasets_v4_MiniLM-L6',
    cache_folder='../pretrained',
)



In [3]:
# Load the Sun-related Wikipedia titles (generated by GPT-4), one title per line.
# The file is read as UTF-8 explicitly so the result does not depend on the machine's
# locale, and blank lines are dropped so an empty string is never embedded as a target.
with open("../datasets/sun-related-wikititle.txt", 'r', encoding='utf-8') as f:
    astro_title_list = [title for title in f.read().splitlines() if title.strip()]

In [47]:
# Location of the pre-downloaded Wikipedia parquet shards, and the full path of
# the single shard processed by this notebook.
data_directory = "/mnt/geogpt-gpfs/llm-course/public/datasets/wikipedia/20231101.en/"
file = os.path.join(data_directory, "train-00000-of-00041.parquet")

In [48]:
# Read the selected parquet shard into an in-memory pyarrow Table.
table = pq.read_table(file)

In [49]:
# Show the table's schema (id, url, title, text) and a truncated preview of each column.
table

pyarrow.Table
id: string
url: string
title: string
text: string
----
id: [["12","39","290","303","305",...,"2448","2452","2457","2459","2460"],["2466","2467","2470","2471","2472",...,"4563","4565","4566","4567","4568"],...,["5264071","5264079","5264082","5264093","5264105",...,"5273111","5273119","5273122","5273132","5273137"],["5273146","5273161","5273163","5273168","5273182",...,"5275580","5275584","5275590","5275592","5275597"]]
url: [["https://en.wikipedia.org/wiki/Anarchism","https://en.wikipedia.org/wiki/Albedo","https://en.wikipedia.org/wiki/A","https://en.wikipedia.org/wiki/Alabama","https://en.wikipedia.org/wiki/Achilles",...,"https://en.wikipedia.org/wiki/Action%20Against%20Hunger","https://en.wikipedia.org/wiki/AW","https://en.wikipedia.org/wiki/Apoptosis","https://en.wikipedia.org/wiki/Appomattox","https://en.wikipedia.org/wiki/Anal%20sex"],["https://en.wikipedia.org/wiki/Aarau","https://en.wikipedia.org/wiki/Aargau","https://en.wikipedia.org/wiki/Aba","https://en.wikipedia

In [11]:
# Convert the Arrow table to a pandas DataFrame, pull the article titles out as a
# plain Python list (same order as the rows), and preview the first five.
df = table.to_pandas()
title_list = df.loc[:, 'title'].tolist()
title_list[0:5]

['Anarchism', 'Albedo', 'A', 'Alabama', 'Achilles']

In [19]:
# List the DataFrame's column names.
df.columns

Index(['id', 'url', 'title', 'text'], dtype='object')

In [10]:
# Embed the Sun-related target titles.
# normalize_embeddings=True requests unit-length vectors, matching how the source titles
# are encoded, so the dot product between the two sets is a cosine similarity and the
# fixed similarity threshold applied later means the same thing for every target title.
target_embeddings = model.encode(
    astro_title_list,
    show_progress_bar=True,
    batch_size=1024,
    convert_to_numpy=False,
    normalize_embeddings=True,
)
# encode() is used here with convert_to_numpy=False, so the per-title tensors are
# stacked into a single 2-D tensor of shape (number of titles, embedding size).
target_embeddings = torch.vstack(target_embeddings)
target_embeddings.shape

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([95, 384])

In [12]:
# Embed every Wikipedia article title from the loaded shard as a unit-length vector,
# then stack the per-title tensors into one 2-D tensor (one row per title).
source_embeddings = model.encode(
    title_list,
    batch_size=1024,
    show_progress_bar=True,
    normalize_embeddings=True,
    convert_to_numpy=False,
)
source_embeddings = torch.vstack(source_embeddings)
source_embeddings.shape

Batches:   0%|          | 0/153 [00:00<?, ?it/s]

torch.Size([156289, 384])

In [14]:
# Move the source embeddings to the first GPU when CUDA is available; otherwise keep
# them on the CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
source_embeddings = source_embeddings.to(device)

In [15]:
# Similarity of every target title (rows) against every Wikipedia title (columns).
# The target embeddings are first placed on the same device as the source embeddings,
# because torch.matmul fails when its operands live on different devices.
target_embeddings = target_embeddings.to(source_embeddings.device)
title_similar_matrix = torch.matmul(target_embeddings, source_embeddings.T)
print(title_similar_matrix.shape)

# For each Wikipedia title, keep its best similarity over all target titles.
source_max_similar = title_similar_matrix.max(dim=0).values

torch.Size([95, 156289])


In [34]:
# Row positions of the Wikipedia titles whose best similarity to any target title
# reaches the cut-off, copied to the CPU as a NumPy array.
threshold = 0.8
over_threshold_mask = source_max_similar >= threshold
idx_over_threshold = torch.where(over_threshold_mask)[0].cpu().numpy()

In [35]:
# Show the selected row positions.
idx_over_threshold

array([    23,     96,   2373,   3946,   5223,  15997,  27909,  30773,
        33504,  47898,  49801,  53597,  69645,  71235,  74928,  77251,
        82108,  84133,  97299, 111662, 112226, 114093, 116313, 127814,
       131963, 133843, 136157, 138924, 142784])

In [38]:
# Append one JSON record per selected article to the JSONL output file.
# - idx_over_threshold holds row *positions* (they come from the order of title_list),
#   so rows are fetched with .iloc rather than label-based indexing; this stays correct
#   even if df ever gets a non-default index.
# - The file is opened in append mode, so records from other shards can be added to the
#   same file. NOTE: re-running this cell for the same shard appends duplicate records.
# - The encoding is set explicitly so the output does not depend on the machine's locale.
with open("../datasets/astro_text-wiki.jsonl", 'a', encoding='utf-8') as f:
    for idx in idx_over_threshold:
        row = df.iloc[idx]
        f.write(json.dumps({
            'id': row['id'],
            'title': row['title'],
            'source': file,
            'text': row['text']
        }) + '\n')

In [37]:
# Titles of the selected articles. The indices are row positions, so look them up
# positionally (.iloc) instead of by index label.
df['title'].iloc[idx_over_threshold].tolist()

['Astronomer',
 'Amateur astronomy',
 'Physical cosmology',
 'Sunshine (disambiguation)',
 'Winter Equinox',
 'Sunglasses at Night',
 'Solar tracker',
 'Sunspot (disambiguation)',
 'Sunblock (band)',
 'Solar eclipses in fiction',
 'The Sound of Sunshine',
 'The Last Sunset',
 'Black Sunshine',
 'Chasing Daylight',
 'Solar eclipse of June 21, 2001',
 'Sunset Studies',
 'MV Zenith',
 'South of Sunset',
 'Solar eclipse of August 21, 2017',
 'Helium (disambiguation)',
 'Sunburn (disambiguation)',
 'Suncatcher',
 'California Sunshine',
 'Solar eclipse',
 'Liquid Sunshine',
 'Astronomy (disambiguation)',
 'Sundial (disambiguation)',
 'Aviator sunglasses',
 'Jesús Corona']

In [42]:
import time

# Wall-clock time, in seconds, of encoding a single short string.
start = time.time()
model.encode("astronomy")
elapsed = time.time() - start
elapsed

0.010596752166748047

In [44]:
import time  # imported here too so this cell does not rely on an earlier cell having run

# Use the text of the last record in the JSONL file as the benchmark input.
# `text` is reset first so a value left over from a previous run is never reused silently.
text = None
with open("../datasets/astro_text-wiki.jsonl", 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines instead of failing inside json.loads
        data = json.loads(line)
        text = data['text']
if text is None:
    raise ValueError("../datasets/astro_text-wiki.jsonl contains no records")

# Wall-clock time, in seconds, of one encode call on that text repeated 10000 times.
start = time.time()
model.encode(text*10000)
time.time() - start

1.21871018409729

In [46]:
# Dump the repeated benchmark text to disk.
# Article text can contain non-ASCII characters, so the file is written as UTF-8
# explicitly instead of relying on the locale's default encoding.
with open('../datasets/test.txt', 'w', encoding='utf-8') as f:
    f.write(text*10000)

In [1]:
import json

# Collect the article text of every record in the JSONL file, in file order.
with open("../datasets/astro_text-wiki.jsonl", 'r', encoding='utf-8') as f:
    text_list = [json.loads(line)['text'] for line in f]

# Write all article texts into one plain-text file, joined by a newline.
# The decoded text can contain non-ASCII characters, so UTF-8 is set explicitly
# rather than depending on the locale's default encoding.
with open("../datasets/astro-wiki.txt", 'w', encoding='utf-8') as f:
    f.write("\n".join(text_list))