In [None]:
import numpy as np
from datasets import load_dataset
import random
import json
from sentence_transformers import SentenceTransformer
import torch
import time
from tqdm import tqdm
import pandas as pd
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cpu


In [None]:
# Load dataset, coding stack exchange q and a dataset 
ds = load_dataset("bigscience-data/roots_code_stackexchange")
print(f"Dataset loaded with {len(ds['train'])} examples")
train_data = ds['train']

# sentence transformer model for encoding
model_name = "all-MiniLM-L6-v2"  # model for embeddings, lightweight
encoder = SentenceTransformer(model_name)
encoder.to(device)  # Move model to GPU if available
print(f"Loaded encoder model: {model_name} on {encoder.device}")

Dataset loaded with 9825059 examples
Loaded encoder model: all-MiniLM-L6-v2 on cpu


SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)

In [None]:
def get_text(item):
    text = item["text"]
    if "Q:\n\n" in text and "\n\nA:\n\n" in text:
        return text.split("Q:\n\n")[1].split("\n\nA:\n\n")[0]
    return None

def encode_text(text):
    # Encode the text using the encoder model
    with torch.no_grad():
        embeddings = encoder.encode(
                        text,
                        convert_to_tensor=True,
                        show_progress_bar=False,
                        device=device  # Explicitly use the detected device
                    )
    return embeddings

# Create output directory if it doesn't exist
output_dir = "embeddings/code_stack_exchange"
os.makedirs(output_dir, exist_ok=True)

# Final output file
final_parquet = "embeddings.parquet"

# Configure batch sizes (adjust these as needed)
encoding_batch_size = 1000   # Number of texts to encode in a single batch
save_batch_size = 250000      # Number of embeddings to save in each parquet file
total_examples = len(train_data)  # Change this to 10 for initial testing

In [99]:
# Initialize tracking variables
processed_count = 0
batch_count = 0
batch_files = []

# Main processing loop with overall progress bar
with tqdm(total=total_examples, desc="Processing dataset") as pbar:
    # Process data in chunks for saving to parquet
    save_batch_data = []
    
    # Process in encoding batches
    for start_idx in range(0, total_examples, encoding_batch_size):
        # Determine batch range
        end_idx = min(start_idx + encoding_batch_size, total_examples)
        batch_size = end_idx - start_idx
        
        # Collect valid texts and their indices for this encoding batch
        texts_to_encode = []
        indices_to_encode = []
        
        for i in range(start_idx, end_idx):
            text = get_text(train_data[i])
            if text:
                texts_to_encode.append(text)
                indices_to_encode.append(i)
        
        # Skip if no valid texts in this batch
        if not texts_to_encode:
            pbar.update(batch_size)
            continue
            
        # Batch encode all texts at once 
        with torch.no_grad():  # Disable gradient calculation for inference
                embeddings = encoder.encode(
                    texts_to_encode,
                    convert_to_tensor=True,
                    show_progress_bar=False,
                    device=device  # Explicitly use the detected device
                )
        # Process each embedding in the batch
        for idx, emb in zip(indices_to_encode, embeddings):
            # Store the embedding with its original index
            if idx == 20:
                 print(emb.cpu().numpy().tolist()[0])
            save_batch_data.append({
                'index': idx,
                'embedding': emb.cpu().numpy().tolist()  # Convert tensor to list for storage
            })
            
            processed_count += 1
            
            # If we've reached the save batch size, write to parquet
            if len(save_batch_data) >= save_batch_size:
                df_batch = pd.DataFrame(save_batch_data)
                
                # Get the start and end indices for this batch
                min_index = min(item['index'] for item in save_batch_data)
                max_index = max(item['index'] for item in save_batch_data)
                
                # Name the file with the index range
                batch_file = f"{output_dir}/batch_{min_index}_{max_index}.parquet"
                df_batch.to_parquet(batch_file, engine='pyarrow', index=False)
                
                batch_files.append(batch_file)
                batch_count += 1
                print(f"Saved batch {batch_count} with {len(save_batch_data)} embeddings. Range: {min_index}-{max_index}. Total processed: {processed_count}")
                
                # Clear the batch data to free memory
                save_batch_data = []

                if torch.cuda.is_available():
                        torch.cuda.empty_cache()
        
        # Update the progress bar
        pbar.update(batch_size)

# Save any remaining embeddings
if save_batch_data:
    df_batch = pd.DataFrame(save_batch_data)
    
    # Get the start and end indices for the final batch
    min_index = min(item['index'] for item in save_batch_data)
    max_index = max(item['index'] for item in save_batch_data)
    
    # Name the file with the index range
    batch_file = f"{output_dir}/batch_{min_index}_{max_index}.parquet"
    df_batch.to_parquet(batch_file, engine='pyarrow', index=False)
    
    batch_files.append(batch_file)
    batch_count += 1
    print(f"Saved final batch {batch_count} with {len(save_batch_data)} embeddings. Range: {min_index}-{max_index}. Total processed: {processed_count}")

print(f"Processing complete. Total embeddings saved: {processed_count} in {batch_count} batches")

Processing dataset:   2%|▎         | 50/2000 [00:02<01:46, 18.30it/s]

-0.021913018077611923


Processing dataset:   2%|▎         | 50/2000 [00:03<02:34, 12.65it/s]


KeyboardInterrupt: 

In [59]:
''''
'BE CAREFUL THIS MIGHT NOT FIT INTO RAM!!!!!!'
'BE CAREFUL THIS MIGHT NOT FIT INTO RAM!!!!!!'
'BE CAREFUL THIS MIGHT NOT FIT INTO RAM!!!!!!'
'BE CAREFUL THIS MIGHT NOT FIT INTO RAM!!!!!!'
'BE CAREFUL THIS MIGHT NOT FIT INTO RAM!!!!!!'
'BE CAREFUL THIS MIGHT NOT FIT INTO RAM!!!!!!'
'BE CAREFUL THIS MIGHT NOT FIT INTO RAM!!!!!!'
'BE CAREFUL THIS MIGHT NOT FIT INTO RAM!!!!!!'
'''


# Now combine all batches into a single parquet file
print("\nCreating final consolidated parquet file...")
all_data = []

# Load and combine all batch files
for batch_file in tqdm(batch_files, desc="Consolidating batches"):
    df = pd.read_parquet(batch_file)
    all_data.append(df)

# Combine into a single DataFrame
if all_data:
    final_df = pd.concat(all_data, ignore_index=True)
    # final_df.drop(columns=['index'], inplace=True)
    
    # Save to a single parquet file
    final_df.to_parquet(final_parquet, engine='pyarrow', index=False)
    print(f"Created final file with {len(final_df)} embeddings: {final_parquet}")
else:
    print("No data processed. Check your dataset and text extraction function.")


Creating final consolidated parquet file...


Consolidating batches: 100%|██████████| 4/4 [00:00<00:00, 76.41it/s]

Created final file with 2000 embeddings: embeddings.parquet





In [93]:
# Example of how to access embeddings by index from the final parquet file
print("\nExample of how to access embeddings by index from the final file:")
df = pd.read_parquet(final_parquet)

# Access by index example
index_to_access = 20  # Example index to access
embedding_row = df.iloc[index_to_access]['embedding']
print(embedding_row[0])


Example of how to access embeddings by index from the final file:
-0.021913018077611923


In [98]:
text = get_text(train_data[index_to_access])
embed = encode_text(text)
print(embed.cpu().numpy().tolist()[0])

-0.02191298082470894
