# Setup

In [None]:
# Install the runtime dependencies used by the cells below
%pip install torch transformers huggingface_hub omegaconf
# Optional packages for a better user experience
%pip install tqdm ipywidgets nbconvert

In [None]:
# Import necessary libraries
import torch
import omegaconf
import collections
import os
import re
from typing import Any
from collections import OrderedDict
from collections.abc import Mapping
from transformers import DPRQuestionEncoder, DPRContextEncoder, AutoTokenizer, AutoModel, DPRPreTrainedModel
from huggingface_hub import hf_hub_download

# Hugging Face access token, taken from the environment (None when the variable is unset)
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model Loading

In [None]:
def rename_keys_substring(ordered_dict: OrderedDict[str, Any], find_pattern, replace_pattern):
    """
    Select the entries whose key matches a regex and rewrite those keys.

    Keys in which ``find_pattern`` is not found are left out of the result,
    so the function filters as well as renames. The input mapping is not
    modified; values are carried over as-is (not copied).

    Args:
        ordered_dict: Source mapping of key -> value.
        find_pattern: Regular expression searched for anywhere in each key.
        replace_pattern: Replacement applied to every match within the key
            (numbered group backreferences are supported).

    Returns:
        A new OrderedDict holding only the matching entries, under their
        rewritten keys, in the order they appear in the source.
    """
    matcher = re.compile(find_pattern)
    return OrderedDict(
        (matcher.sub(replace_pattern, name), payload)
        for name, payload in ordered_dict.items()
        if matcher.search(name)
    )


In [32]:
# Allow-list the non-tensor types that the restricted (weights_only) unpickler
# is permitted to reconstruct while reading the checkpoint below.
# NOTE(review): presumably these are the config/metadata objects pickled next to
# the weights — confirm against the checkpoint if the list needs to change.
torch.serialization.add_safe_globals(
    [
        omegaconf.dictconfig.ContainerMetadata,
        omegaconf.dictconfig.DictConfig,
        omegaconf.base.Metadata,
        omegaconf.nodes.AnyNode,
        omegaconf.listconfig.ListConfig,
        collections.defaultdict,
        Any,
        dict,
        list,
        int,
    ]
)

# Fetch the fine-tuned bi-encoder checkpoint file from the Hugging Face Hub
checkpoint_path = hf_hub_download(
    repo_id="NTU-NLP-sg/xCodeEval-nl-code-starencoder-ckpt-37",
    filename="dpr_biencoder.37.pt",
    repo_type="model",
    token=HF_TOKEN,
)

# Request the restricted unpickler explicitly: the allow-list above only has an
# effect in that mode, and older PyTorch releases default to full pickle, which
# can execute arbitrary code from a downloaded file.
# Despite its name, `state_dict` holds the whole checkpoint payload; the model
# weights are read from its "model_dict" entry further down.
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

# Instantiate both DPR towers from the StarEncoder base repo
question_encoder = DPRQuestionEncoder.from_pretrained("bigcode/starencoder", token=HF_TOKEN)
ctx_encoder = DPRContextEncoder.from_pretrained(
    "bigcode/starencoder",
    token=HF_TOKEN,
)

# Question tower: keep only the `question_model.*` entries selected by the
# pattern and move them under the `question_encoder.bert_model.` prefix,
# then load them into the encoder.
question_weights = rename_keys_substring(
    state_dict["model_dict"],
    find_pattern=r"question_model\.(embeddings|encoder)\.([Ll]ayer|token|word|position_embeddings)",
    replace_pattern=r"question_encoder.bert_model.\1.\2",
)
question_encoder.load_state_dict(question_weights)

# Context tower: same procedure for the `ctx_model.*` entries
ctx_weights = rename_keys_substring(
    state_dict["model_dict"],
    find_pattern=r"ctx_model\.(embeddings|encoder)\.([Ll]ayer|token|word|position_embeddings)",
    replace_pattern=r"ctx_encoder.bert_model.\1.\2",
)
ctx_encoder.load_state_dict(ctx_weights)

# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Move both encoders to the selected device and switch them to evaluation mode
question_encoder = question_encoder.to(device).eval()
ctx_encoder = ctx_encoder.to(device).eval()

# Authenticate the tokenizer download with the same token used for the encoder
# downloads from this repo, so all three requests behave alike.
tokenizer = AutoTokenizer.from_pretrained("bigcode/starencoder", token=HF_TOKEN)

You are using a model of type bert to instantiate a model of type dpr. This is not supported for all configurations of models and can yield errors.
Some weights of DPRQuestionEncoder were not initialized from the model checkpoint at bigcode/starencoder and are newly initialized: ['question_encoder.bert_model.embeddings.LayerNorm.bias', 'question_encoder.bert_model.embeddings.LayerNorm.weight', 'question_encoder.bert_model.embeddings.position_embeddings.weight', 'question_encoder.bert_model.embeddings.token_type_embeddings.weight', 'question_encoder.bert_model.embeddings.word_embeddings.weight', 'question_encoder.bert_model.encoder.layer.0.attention.output.LayerNorm.bias', 'question_encoder.bert_model.encoder.layer.0.attention.output.LayerNorm.weight', 'question_encoder.bert_model.encoder.layer.0.attention.output.dense.bias', 'question_encoder.bert_model.encoder.layer.0.attention.output.dense.weight', 'question_encoder.bert_model.encoder.layer.0.attention.self.key.bias', 'question_encod

Using device: cpu
