In [None]:
# Install python dependencies
%pip install torch transformers huggingface_hub omegaconf datasets==2.16.1 
# Optional Python packages for a better user experience
%pip install ipywidgets nbconvert

In [None]:
# Import necessary libraries
import torch
import omegaconf
import collections
import os
import re
from pathlib import Path
from typing import Any
from collections import OrderedDict
from transformers import DPRContextEncoder, AutoTokenizer, DPRConfig, GPT2TokenizerFast
from huggingface_hub import hf_hub_download
from datasets import load_dataset, Dataset, concatenate_datasets
from concurrent.futures import ProcessPoolExecutor, wait
import time

# Hugging Face Hub credentials come from the environment (None when HF_TOKEN is unset)
HF_TOKEN = os.environ.get('HF_TOKEN')

# On-disk cache layout: everything lives under ./cache; corpus embeddings get their own
# sub-folder, which is created later when it is first written to.
CACHE_DIR = Path("./cache")
CORPUS_CACHE_DIR = CACHE_DIR.joinpath("corpus_embeddings")
CACHE_DIR.mkdir(exist_ok=True)

In [None]:
def rename_keys_substring(
    ordered_dict: OrderedDict[str, Any],
    find_pattern: str,
    replace_pattern: str,
    keep_unmatched: bool = False,
) -> OrderedDict[str, Any]:
    """
    Select the keys matching a regex and rename them by regex substitution.

    Every key in which ``find_pattern`` matches is rewritten as
    ``re.sub(find_pattern, replace_pattern, key)`` (all occurrences are replaced).
    Keys in which the pattern does NOT match are dropped unless ``keep_unmatched``
    is True, so by default this also acts as a filter (e.g. to pull one sub-model's
    weights out of a combined state dict). The input mapping is not modified and
    the original insertion order is preserved.

    Args:
        ordered_dict: The mapping whose keys are renamed (left unmodified)
        find_pattern: The regex pattern to find in keys
        replace_pattern: The replacement pattern (can include backreferences like \\1, \\2)
        keep_unmatched: If True, copy entries whose key has no match through unchanged
            instead of dropping them. Defaults to False (filtering behaviour).

    Returns:
        A new OrderedDict with the renamed (and, by default, filtered) entries.
        If two keys are rewritten to the same name, the later value wins.
    """
    compiled_pattern = re.compile(find_pattern)
    renamed = OrderedDict()

    for key, value in ordered_dict.items():
        # subn searches and substitutes in a single pass; a count of 0 means "no match".
        new_key, n_replaced = compiled_pattern.subn(replace_pattern, key)
        if n_replaced == 0 and not keep_unmatched:
            continue
        renamed[new_key] = value
    return renamed


In [None]:
def setup_model_on_device(device: str) -> tuple[DPRContextEncoder, GPT2TokenizerFast]:
    """
    Download the xCodeEval DPR bi-encoder checkpoint and build its context encoder on ``device``.

    Only the ``ctx_model`` weights are taken from the checkpoint; they are renamed to the
    key layout ``DPRContextEncoder`` is loaded with. The base model name used for the config
    and tokenizer is read from the checkpoint's ``encoder_params``.

    Args:
        device: Torch device string to place the encoder on, e.g. 'cuda:0' or 'cpu'.

    Returns:
        Tuple of (context encoder moved to ``device`` in eval mode, tokenizer whose
        pad token is set to its EOS token).
    """
    # Allow-list the non-tensor types pickled in the checkpoint (presumably its omegaconf
    # config and metadata) so torch.load's weights-only unpickler accepts them.
    torch.serialization.add_safe_globals(
        [
            omegaconf.dictconfig.ContainerMetadata,
            omegaconf.dictconfig.DictConfig,
            omegaconf.base.Metadata,
            omegaconf.nodes.AnyNode,
            omegaconf.listconfig.ListConfig,
            collections.defaultdict,
            Any,
            dict,
            list,
            int,
        ]
    )

    # Each calling process loads its own copy of the checkpoint.
    checkpoint_path = hf_hub_download(
        repo_id="NTU-NLP-sg/xCodeEval-nl-code-starencoder-ckpt-37",
        filename="dpr_biencoder.37.pt",
        repo_type="model",
        token=HF_TOKEN,
    )
    # Load on the CPU: this bi-encoder checkpoint holds more than the ctx_model weights kept
    # below, so mapping it straight onto `device` would waste GPU memory. The encoder is moved
    # to `device` once built. weights_only=True makes the safe unpickler explicit on every
    # torch version instead of relying on the version-dependent default.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    # Keep only the context-encoder weights, renamed to DPRContextEncoder's key layout
    ctx_state_dict = rename_keys_substring(
        state_dict["model_dict"],
        r"ctx_model\.(embeddings|encoder)\.([Ll]ayer|token|word|position_embeddings)",
        r"ctx_encoder.bert_model.\1.\2",
    )
    pretrained_model_name = state_dict["encoder_params"]["encoder"]["pretrained_model_cfg"]
    del state_dict  # release the weights that were not selected before building the model

    # Initialize encoder
    encoder_config = DPRConfig.from_pretrained(
        pretrained_model_name,
        token=HF_TOKEN,
    )
    ctx_encoder = DPRContextEncoder.from_pretrained(
        None, state_dict=ctx_state_dict, config=encoder_config, token=HF_TOKEN
    )
    ctx_encoder = ctx_encoder.to(device).eval()

    # Initialize tokenizer (authenticated like the other Hub calls, in case the base model is gated)
    tokenizer: GPT2TokenizerFast = AutoTokenizer.from_pretrained(
        pretrained_model_name, config=encoder_config, token=HF_TOKEN
    )
    # Reuse EOS as the padding token so padded batch tokenization works.
    tokenizer.pad_token = tokenizer.eos_token

    return ctx_encoder, tokenizer

In [None]:
def process_shard_on_gpu(process_id: int, shard: Dataset) -> str:
    """
    Embed one shard of the corpus in the current (worker) process and save it to disk.

    Loads a private copy of the encoder and tokenizer, adds an ``embedding`` column to
    ``shard`` with ``Dataset.map``, and writes the result with ``Dataset.save_to_disk`` so
    that it can be reloaded with ``Dataset.load_from_disk``.

    Args:
        process_id: Index of this worker. It selects the GPU (``process_id % num_gpus``)
            and names the shard's directory under ``CORPUS_CACHE_DIR``.
        shard: The dataset shard to process; its ``source_code`` column is embedded.

    Returns:
        Path (as ``str``) to the directory written by ``save_to_disk`` for this shard.

    Raises:
        FileNotFoundError: if the saved shard directory lacks its ``state.json``.
        Any error raised while loading the model or embedding is printed and re-raised.
    """
    print(f"Process {process_id}: Starting processing of {len(shard)} documents")
    # Pick the device for this process: spread workers round-robin over the visible GPUs,
    # or fall back to the CPU (where `% torch.cuda.device_count()` would divide by zero).
    if torch.cuda.is_available():
        deviceType = "cuda"
        device = f"cuda:{process_id % torch.cuda.device_count()}"
    else:
        deviceType = "cpu"
        device = "cpu"

    # Load model on this specific device
    ctx_encoder_gpu, tokenizer_gpu = setup_model_on_device(device)

    def embed_codes_gpu(batch):
        """Tokenize ``batch["source_code"]`` and return its DPR pooled embeddings."""
        inputs = tokenizer_gpu(
            batch["source_code"],
            padding="max_length",
            truncation=True,
            max_length=1024,
            return_tensors="pt"
        )
        inputs = {k: v.to(device, non_blocking=True) for k, v in inputs.items()}

        # bfloat16 is more memory efficient on GPUs like RTX 3090
        # but has a lower precision than float32
        # bfloat16: 16 bits, 1 sign bit, 8 exponent bits, 7 mantissa bits
        # float16: 16 bits, 1 sign bit, 5 exponent bits, 10 mantissa bits
        # float32: 32 bits, 1 sign bit, 8 exponent bits, 23 mantissa bits
        with torch.no_grad(), torch.amp.autocast(device_type=deviceType, dtype=torch.bfloat16):
            embeddings = ctx_encoder_gpu(**inputs).pooler_output
        # Upcast to float32 on the CPU and hand plain Python lists back to Dataset.map
        return {"embedding": embeddings.cpu().to(torch.float32).tolist()}

    # Process the shard
    try:
        shard_directory = CORPUS_CACHE_DIR / f"shard_{process_id}"
        shard_directory.mkdir(parents=True, exist_ok=True)
        processed_shard = shard.map(
            embed_codes_gpu,
            batched=True,
            batch_size=48,
            desc=f"Process {process_id}",
            cache_file_name=str(shard_directory / f"shard_{process_id}"),
        )

        print(f"Process {process_id}: Successfully processed {len(shard)} documents")

        # Dataset.load_from_disk only reads the layout written by save_to_disk
        # (data files + state.json + dataset_info.json), not a bare `map` cache file.
        # Save into a sub-directory: save_to_disk refuses to write into the directory
        # that holds the dataset's own cache file.
        saved_shard_directory = shard_directory / "dataset"
        processed_shard.save_to_disk(str(saved_shard_directory))

        if not (saved_shard_directory / "state.json").is_file():
            raise FileNotFoundError(f"Processed shard file not found: {saved_shard_directory}")

        return str(saved_shard_directory)

    except Exception as e:
        print(f"Process {process_id}: Error during processing: {e}")
        raise

In [None]:
def process_with_processpool(corpus: Dataset, processes_per_gpu: int = 2) -> Dataset:
    """
    Embed ``corpus`` in parallel with a process pool and return the combined result.

    The corpus is split into contiguous shards whose sizes differ by at most one document.
    Each shard is handled by ``process_shard_on_gpu`` in its own process, so each process
    gets its own CUDA context. The resulting shards are reloaded from disk and concatenated
    in their original order.

    Args:
        corpus: The dataset to embed.
        processes_per_gpu: Worker processes per visible GPU (default 2). Running more than
            one worker per GPU is intended for better load balancing.

    Returns:
        The concatenated processed shards, in the same row order as ``corpus``.

    Raises:
        RuntimeError: if no CUDA GPU is visible.
    """
    # Get number of available GPUs
    num_gpus = torch.cuda.device_count()
    print(f"Found {num_gpus} GPUs available")

    if num_gpus < 1:
        raise RuntimeError("At least one GPU is required for this operation.")

    total_docs = len(corpus)
    # Several processes per GPU, but never more processes than documents: an empty shard
    # would still load a full model for nothing.
    num_processes = max(1, min(num_gpus * processes_per_gpu, total_docs))

    # Calculate shard sizes
    docs_per_process, remainder = divmod(total_docs, num_processes)

    print(f"Total documents: {total_docs}")
    print(f"Documents per process: {docs_per_process}")
    print(f"Remainder documents: {remainder}")

    # Create shards and distribute workload across Processes
    shards = []
    start_idx = 0
    for process_id in range(num_processes):
        # Give remainder documents to first few processes
        shard_size = docs_per_process + (1 if process_id < remainder else 0)
        end_idx = start_idx + shard_size

        shards.append((process_id, corpus.select(range(start_idx, end_idx))))

        print(f"Process {process_id}: Processing documents {start_idx} to {end_idx-1} ({shard_size} docs)")
        start_idx = end_idx

    # Process shards in parallel using processes
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        # Keep (process_id, future) pairs so results are combined in shard order
        futures = [
            (process_id, executor.submit(process_shard_on_gpu, process_id, shard))
            for process_id, shard in shards
        ]

        print(f"[{time.strftime('%H:%M:%S')}] Starting parallel processing on {num_processes} Processes...")
        # wait() needs the Future objects themselves, not the bookkeeping tuples
        wait([future for _, future in futures])

        print(f"[{time.strftime('%H:%M:%S')}] All Processes completed processing!")
        print(f"[{time.strftime('%H:%M:%S')}] Please wait for the main process to combine results...")
        shard_results: list[Dataset] = []
        for _, future in futures:
            # result() re-raises any exception raised inside the worker
            shard_dataset_path: str = future.result()
            shard_results.append(Dataset.load_from_disk(shard_dataset_path))

    # Combine all shard results
    print("Combining results from all Processes...")

    # Create final dataset
    corpus_with_embeddings = concatenate_datasets(shard_results)

    print(f"Combined dataset created with {len(corpus_with_embeddings)} documents")
    return corpus_with_embeddings

In [None]:
# Check if cache exists and load, otherwise process corpus
cache_dir_found = CORPUS_CACHE_DIR.exists()
corpus_with_embeddings = None
if cache_dir_found:
    try:
        print(f"Loading corpus cache from {CORPUS_CACHE_DIR}")
        corpus_with_embeddings = Dataset.load_from_disk(str(CORPUS_CACHE_DIR))
        print(f"Cache loaded successfully. Documents: {len(corpus_with_embeddings)}")
    except Exception as e:
        # e.g. a directory left behind by an interrupted run that only holds per-shard files
        print(f"Failed to load cache: {e}")
        print("Cache directory exists but contains invalid data. Recreating cache...")
        corpus_with_embeddings = None

if corpus_with_embeddings is None:
    # Only claim "no cache" when there really was none; an invalid cache was reported above
    if cache_dir_found:
        print("Processing corpus...")
    else:
        print("No cache found. Processing corpus...")

    # Load corpus dataset (pinned revision for reproducibility)
    corpus = load_dataset(
        "NTU-NLP-sg/xCodeEval",
        "retrieval_corpus",
        trust_remote_code=True,
        split="test",
        revision="467d25a839086383794b58055981221b82c0d107",
        token=HF_TOKEN,
    )

    # Generate embeddings
    corpus_with_embeddings = process_with_processpool(corpus)

    print("Embeddings generated successfully!")
    print(f"Saving corpus cache to {CORPUS_CACHE_DIR}")
    corpus_with_embeddings.save_to_disk(str(CORPUS_CACHE_DIR))
    print("Cache saved successfully!")

# Display information about the processed corpus
print("\nCorpus information:")
print(f"Number of documents: {len(corpus_with_embeddings)}")
if len(corpus_with_embeddings) > 0:
    first_document = corpus_with_embeddings[0]  # decode the first row once and reuse it
    print(f"Embedding dimension: {len(first_document['embedding'])}")
    print(f"Sample document keys: {list(first_document.keys())}")
    print(f"Sample source code (first 200 chars): {first_document['source_code'][:200]}...")