In [7]:
import os
from sentence_transformers import SentenceTransformer
import gzip
import json
import torch

In [8]:
# Sentence-embedding model shared by both the astronomy-title and the dolma-title encoding cells.
EMBEDDING_MODEL_ID = 'flax-sentence-embeddings/all_datasets_v4_MiniLM-L6'
model = SentenceTransformer(EMBEDDING_MODEL_ID)



In [9]:
# Load the astronomy/physics titles generated by GPT-4 (one title per line).
# These are the query side of the title-similarity search below.
# - Decode explicitly as UTF-8 rather than the locale default, since titles may contain
#   non-ASCII characters (e.g. accented names).
# - Skip blank lines so empty strings are not embedded as spurious "titles".
with open("../datasets/astro-title.txt", 'r', encoding='utf-8') as f:
    astro_title_list = [title.strip() for title in f.read().splitlines() if title.strip()]

In [31]:
# Load one dolma_v1_7 CC_head shard: a gzip-compressed file with one JSON document per line.
# title_list, id_list and text_list are kept index-aligned: position i in each list
# refers to the same document.
root = "/mnt/geogpt-gpfs/llm-course/public/datasets/"
file = "dolma_v1_7/CC_head/documents/cc_en_head-0003.json.gz"
file = os.path.join(root, file)

title_list = []
id_list = []
text_list = []
# Decode explicitly as UTF-8: text mode otherwise falls back to the locale's default
# encoding, which can fail on or mangle non-ASCII web text.
with gzip.open(file, 'rt', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines instead of crashing in json.loads
        doc = json.loads(line)
        metadata = doc.get('metadata') or {}
        title = metadata.get('title')
        if not title:
            # No title means the document cannot be matched by title similarity.
            # Skip it entirely so the three lists stay aligned and the encoder
            # never receives None.
            continue
        title_list.append(title)
        id_list.append(doc['id'])
        text_list.append(doc['text'])
print(len(title_list))

In [None]:
# Embed the astronomy titles.
# normalize_embeddings=True L2-normalises each vector, matching how the dolma titles are
# encoded. The dot product taken in the similarity cell is then a true cosine similarity,
# so the fixed similarity threshold means the same thing for every astronomy title.
target_embeddings = model.encode(
    astro_title_list,
    show_progress_bar=True,
    batch_size=1024,
    convert_to_numpy=False,  # list of per-title tensors, stacked below
    normalize_embeddings=True,
)
target_embeddings = torch.vstack(target_embeddings)
target_embeddings.shape

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([286, 384])

In [None]:
# Embed every dolma title on GPU 7, L2-normalised. encode() returns one tensor per title
# (convert_to_numpy=False); they are stacked into a single (num_titles, dim) matrix.
dolma_device = torch.device("cuda:7")
dolma_title_vectors = model.encode(
    title_list,
    batch_size=1024,
    show_progress_bar=True,
    device=dolma_device,
    convert_to_numpy=False,
    normalize_embeddings=True,
)
source_embeddings = torch.vstack(dolma_title_vectors)
source_embeddings.shape

Batches:   0%|          | 0/1695 [00:00<?, ?it/s]

torch.Size([1735398, 384])

In [None]:
# Similarity matrix between astronomy titles and dolma titles.
# Rows represent astronomy titles; columns represent dolma titles.
# Entries are dot products of the embeddings. These equal cosine similarities only if
# both embedding sets were L2-normalised when encoded.
# NOTE: the full matrix is materialised on the GPU (~5e8 entries for the shapes shown
# in the output), although only its column-wise max is used below.
device = torch.device("cuda:7")
# Move BOTH operands explicitly so the product does not depend on which device
# encode() happened to leave each tensor on.
title_similar_matrix = torch.matmul(target_embeddings.to(device), source_embeddings.to(device).T)
print(title_similar_matrix.shape)

# For each dolma title, the highest similarity to any astronomy title.
source_max_similar = title_similar_matrix.max(dim=0).values

torch.Size([286, 1735398])


In [None]:
# Select dolma documents whose title scores at least `threshold` against one or more
# astronomy titles. idx_over_threshold holds positions into title_list / id_list /
# text_list as a 1-D numpy array on the CPU.
threshold = 0.8
idx_over_threshold = torch.where(source_max_similar >= threshold)[0].cpu().numpy()

In [None]:
# Shape of the selected-index array, i.e. how many dolma documents passed the threshold.
matched_shape = idx_over_threshold.shape
matched_shape

In [30]:
# Append the matched dolma documents to a shared JSONL file (one JSON object per line).
# Append mode lets several dolma shards accumulate into the same output file. To keep
# this cell idempotent, (source, id) pairs already present in the file are skipped,
# so re-running the cell does not write duplicate records.
astro_out_path = "../datasets/astro_text.jsonl"
already_written = set()
if os.path.exists(astro_out_path):
    with open(astro_out_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                record = json.loads(line)
                already_written.add((record.get('source'), record.get('id')))

with open(astro_out_path, 'a', encoding='utf-8') as f:
    for idx in idx_over_threshold:
        idx = int(idx)  # numpy integer -> plain int for list indexing
        key = (file, id_list[idx])
        if key in already_written:
            continue
        already_written.add(key)
        f.write(json.dumps({
            'id': id_list[idx],
            'source': file,
            'text': text_list[idx]
        }) + '\n')