# Embeddings pipeline
Here, we use MoLFormer, a BERT-trained encoder model, to encode our SMILES (without using AIS) into a latent space for similarity pairing. The result is written to a pairs.json file with the following structure:
[[high_scoring_smile, low_scoring_smile, L2_similarity], ...]

We also save checkpoint embeddings (notably, the top and bottom json files) for the top 3% and bottom 15%, since embedding all of them tends to take a while.

In [None]:
import torch
from transformers import AutoModel, AutoTokenizer

# Load the MoLFormer tokenizer and encoder used to map SMILES into the latent
# space. Both come from the same checkpoint; trust_remote_code=True is passed
# to both loaders, as in the original setup.
tokenizer = AutoTokenizer.from_pretrained(
    "ibm/MoLFormer-XL-both-10pct",
    trust_remote_code=True,
)
model = AutoModel.from_pretrained(
    "ibm/MoLFormer-XL-both-10pct",
    deterministic_eval=True,
    trust_remote_code=True,
)
# Move the encoder to the GPU; the embedding cell sends its batches to "cuda" too.
model = model.to("cuda")

After loading the model and tokenizer, we read every json file in the working directory (the data is fragmented across several files due to earlier saving behavior from docking) and collect each SMILES with its associated docking score. The top 3% and bottom 15% are then tokenized with MoLFormer's tokenizer, passed through the model, and saved along with their embeddings and scores.

In [None]:
import json
import torch
import os

working_directory = '/kaggle/input/total-dataset'
combined_scores = {}

# Merge every JSON fragment made from docking into a single dict
# (presumably {SMILES: docking score} -- confirm against the docking output).
# The listing is sorted because os.listdir() returns entries in arbitrary
# order: with a fixed order, a key that appears in several fragments always
# takes its score from the same (last-sorted) file, so runs are reproducible.
for filename in sorted(os.listdir(working_directory)):
    if not filename.endswith('.json'):
        continue
    full_path = os.path.join(working_directory, filename)
    with open(full_path, encoding='utf-8') as f:
        combined_scores.update(json.load(f))

# List of (key, score) tuples ordered by score, highest score first
sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)

num_keys = len(sorted_items)

# NOTE: the list is in DESCENDING score order, so the "top 3%" slice is the
# tail of the list (numerically lowest scores) and the "bottom 15%" slice is
# the head (numerically highest scores). Presumably this is intended because
# lower docking scores are better -- TODO confirm.
top_3_percent_index = int(num_keys * 0.97)  # Starting index for top 3%
bottom_15_percent_index = int(num_keys * 0.15)  # Ending index (exclusive) for bottom 15%

# Keep only the keys; scores are looked up again from combined_scores later
top_3_percent_keys = [key for key, _ in sorted_items[top_3_percent_index:]]
bottom_15_percent_keys = [key for key, _ in sorted_items[:bottom_15_percent_index]]
        
def _embed_smiles(smiles, label, batch_size):
    """Embed a list of SMILES strings with the loaded model, batch by batch.

    Args:
        smiles: List of SMILES strings to embed.
        label: Short tag ("top" / "bottom") used only in the progress output.
        batch_size: Number of SMILES tokenized and encoded per forward pass.

    Returns:
        A list with one embedding (a list of floats) per input string, in the
        same order as ``smiles``.
    """
    embeddings = []
    for start in range(0, len(smiles), batch_size):
        batch = smiles[start:start + batch_size]

        # Tokenize the current batch and move it to the GPU
        encoded = tokenizer(batch, padding='max_length', return_tensors='pt', truncation=True).to("cuda")

        with torch.no_grad():  # Inference only: skip gradient bookkeeping
            output = model(**encoded)

        # Assumes pooler_output is (batch, hidden), like BERT -- TODO confirm.
        # Deliberately NOT squeezed: squeeze() drops every size-1 dimension,
        # so a final batch holding a single SMILES would collapse to a flat
        # vector and extend() would append its individual floats, shifting
        # every later key -> embedding pairing.
        # Cast to float16 to shrink the stored values.
        batch_embeddings = output.pooler_output.to(torch.float16).tolist()

        print("Finished", label, start)
        embeddings.extend(batch_embeddings)
    return embeddings


def _save_embeddings(keys, embeddings, scores, output_path):
    """Write {key: {'embedding': ..., 'score': ...}} to ``output_path`` as JSON.

    Args:
        keys: SMILES strings, in the order they were embedded.
        embeddings: One embedding per key, aligned with ``keys``.
        scores: Mapping from key to its score.
        output_path: Destination JSON file.

    Returns:
        The dictionary that was written.

    Raises:
        ValueError: If ``keys`` and ``embeddings`` differ in length, which
            would mean the embeddings are misaligned with their keys.
    """
    if len(keys) != len(embeddings):
        raise ValueError(
            f"Got {len(embeddings)} embeddings for {len(keys)} keys; refusing to write misaligned data."
        )
    output = {
        key: {'embedding': embedding, 'score': scores[key]}
        for key, embedding in zip(keys, embeddings)
    }
    with open(output_path, 'w') as f_out:
        json.dump(output, f_out)
    return output


# Process in batches of 500
batch_size = 500

# Embed and checkpoint the top 3%
top_embeddings = _embed_smiles(top_3_percent_keys, "top", batch_size)
combined_output = _save_embeddings(
    top_3_percent_keys, top_embeddings, combined_scores, 'combined_output_top.json'
)

# Embed and checkpoint the bottom 15%
bottom_embeddings = _embed_smiles(bottom_15_percent_keys, "bottom", batch_size)
combined_output = _save_embeddings(
    bottom_15_percent_keys, bottom_embeddings, combined_scores, 'combined_output_bottom.json'
)

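Lastly, we pair up the top 3% and bottom 15%: we go through the top 3% and match each embedding to its most similar not-yet-used embedding in the bottom 15%. The embeddings are L2-normalized and compared by inner product, which ranks neighbors in the same order as L2 distance on the normalized vectors. Since each matched bottom embedding is removed from the candidate pool and each top embedding is matched only once, these pairings are unique, and none of the top or bottom vectors are repeated.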

In [None]:
import numpy as np
import json
import faiss
import random

# Load embeddings from file 
def load_embeddings(file_path, shuffle=False, seed=None):
    """Load keys and their embeddings from a JSON checkpoint file.

    The file is expected to map each key to a dict that has an 'embedding'
    entry (a list of numbers).

    Args:
        file_path: Path to the JSON file.
        shuffle: If True, randomize the order of the keys (and therefore of
            the embedding rows) before building the array.
        seed: Optional seed for the shuffle. When given, the shuffled order
            is reproducible across runs; when None, NumPy's global random
            state is used.

    Returns:
        A ``(keys, embeddings)`` tuple: ``keys`` is a list of the file's
        top-level keys and ``embeddings`` is a float32 array whose i-th row
        is the embedding of ``keys[i]``.
    """
    # Read everything first so the file is closed before the array is built
    with open(file_path, encoding='utf-8') as f:
        data = json.load(f)

    keys = list(data)
    if shuffle:
        if seed is None:
            np.random.shuffle(keys)
        else:
            # A dedicated generator keeps the order reproducible without
            # touching the global random state.
            np.random.default_rng(seed).shuffle(keys)

    embeddings = np.array([data[key]['embedding'] for key in keys], dtype='float32')
    return keys, embeddings

# Load top and bottom embeddings. The top keys are shuffled so the greedy
# matching below does not always favor the same (file-order) entries.
top_keys, top_embeddings = load_embeddings('combined_output_top.json', shuffle=True)
bottom_keys, bottom_embeddings = load_embeddings('combined_output_bottom.json')

# L2-normalize in place so that the inner product used by the index equals
# cosine similarity.
faiss.normalize_L2(top_embeddings)
faiss.normalize_L2(bottom_embeddings)

# Exact inner-product index over the bottom embeddings
index = faiss.IndexFlatIP(bottom_embeddings.shape[1])
index.add(bottom_embeddings)

# Batch by 100 each so that we don't run out of memory
batch_size = 100
# get the amount of embeddings
n_top = top_embeddings.shape[0]
n_bottom = bottom_embeddings.shape[0]
# keep track of what we've already used to not repeat
used_bottom_indices = set()
pairs = []

print(f"Starting processing {n_top} top embeddings in batches of {batch_size}...")

for start_idx in range(0, n_top, batch_size):
    end_idx = min(start_idx + batch_size, n_top)
    # Rank ALL bottom embeddings for every query in the batch, so a free
    # bottom entry is always reachable while any remain.
    # D holds the similarity scores, I the matching bottom indices.
    D, I = index.search(top_embeddings[start_idx:end_idx], n_bottom)

    for top_key, scores, indices in zip(top_keys[start_idx:end_idx], D, I):
        # Walk neighbors from most to least similar. Indices and scores are
        # iterated in lockstep, so the score of the chosen neighbor is already
        # in hand and the row never has to be searched again.
        for idx, similarity_score in zip(indices, scores):
            idx = int(idx)
            # If we've already used this bottom index, skip it
            if idx in used_bottom_indices:
                continue
            # Otherwise, we've found a unique pair; record it with its score
            pairs.append((top_key, bottom_keys[idx], float(similarity_score)))
            used_bottom_indices.add(idx)
            break
        else:
            # Reached only when every returned bottom index is already used,
            # i.e. there are more top entries than bottom entries to pair with.
            print(f"No available unique bottom key for top key {top_key}. Consider increasing batch size or revising data.")

    # Print progress
    print(f"Processed {end_idx} / {n_top} top embeddings. Unique pairs found so far: {len(pairs)}")

# Optionally, sort the pairs by similarity score (highest first)
pairs.sort(key=lambda x: x[2], reverse=True)

# Save the results
with open('pairs.json', 'w') as f_out:
    json.dump(pairs, f_out, ensure_ascii=False, indent=4)

print("Processing complete. Results saved.")