# CodeSearchNet Data Source Notice

In [None]:
!mkdir /content/CodeSearchNet

mkdir: cannot create directory ‘/content/CodeSearchNet’: File exists


In [None]:
%%capture
!pip install docopt

After the CodeSearchNet dataset was archieved, the S3 bucket was taken offline. As a result, following the installation on the github installation guide will not work. A short illustration can be seen down below.

In [None]:
import os
from subprocess import call, check_call, CalledProcessError

destination_dir = "/content/CodeSearchNet"

if not os.path.exists(destination_dir):
    os.makedirs(destination_dir)
os.chdir(destination_dir)

try:
    language = "python"
    check_call(['wget', f'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/{language}.zip', '-O', f'{language}.zip'])
    check_call(['unzip', f'{language}.zip'])
    check_call(['rm', f'{language}.zip'])
except CalledProcessError as e:
    print(f"Error: {e}")
    print(f"Error executing command {e.cmd}")
    print(f"Returned code {e.returncode}")

Error: Command '['wget', 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/python.zip', '-O', 'python.zip']' returned non-zero exit status 8.
Error executing command ['wget', 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/python.zip', '-O', 'python.zip']
Returned code 8


Instead, we download the dataset from Hugging Face. Updating `datasets` might not be necessary but might sometimes be helpful to avoid errors concering caching in the local file system

# Data Fetching

In [1]:
%%capture

%pip install -U datasets

In [2]:
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("code_search_net", "python")

train_data = dataset["train"]
test_data = dataset["test"]
validation_data = dataset["validation"]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

code_search_net.py: 0.00B [00:00, ?B/s]

The repository for code_search_net contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/code_search_net.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


python.zip:   0%|          | 0.00/941M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/412178 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/22176 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23107 [00:00<?, ? examples/s]

We can inspect the contents of the dataset object for the training, testing, and validation datasets.

In [3]:
print(train_data.features.keys())

dict_keys(['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'])


# Question Generation Pipeline

In [4]:
import itertools
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer
)
from nltk import sent_tokenize

from typing import (
    Tuple,
    Dict,
    Literal,
    List,
    Any,
    Generator,
    overload,
    Union
)

class QGPipeline:
    def __init__(
        self,
        model: str,
        qg_format: Literal["highlight"] = "highlight",
        exclude_after: List[str] = [],
        use_cuda: bool = False
    ):

        self.model = AutoModelForSeq2SeqLM.from_pretrained(model)
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.qg_format = qg_format

        assert self.model.__class__.__name__ == "T5ForConditionalGeneration"

        self.device = torch.device("cuda" if torch.cuda.is_available() and use_cuda else "cpu")
        self.model.to(self.device)
        self.model.eval()
        self.use_cuda = use_cuda
        self._exclude_after = exclude_after
        print(f"Using {self.device}")

    def __call__(self, input: Union[Tuple[str, str], List[Tuple[str, str]]]):
        if isinstance(input, tuple):
            # Handle single input
            func_name, docstring = input
            questions = self._generate_questions(func_name, docstring)
            output = [{'answer': func_name, 'question': que} for que in questions]
            if output:
                 return output[0]
            else:
                 return {}

        elif isinstance(input, list) and all(isinstance(item, tuple) for item in input):
            # Handle batch input with proper error handling
            return self._process_batch_generator(input)
        else:
            raise TypeError("Invalid input type. Expected a tuple (func_name, docstring) or a list of such tuples.")


    def _process_batch_generator(self, batch_input: List[Tuple[str, str]]) -> Generator[Dict[str, Any], None, None]:
        """
        Process batch input and yield results with error handling per item
        """
        for i, (func_name, docstring) in enumerate(batch_input):
            try:
                questions = self._generate_questions(func_name, docstring)
                output = [{'answer': func_name, 'question': que} for que in questions]

                if output:
                    yield {
                        'success': True,
                        'index': i,
                        'function_name': func_name,
                        'docstring': docstring,
                        'model_output': output[0],
                        'error': None
                    }
                else:
                    yield {
                        'success': False,
                        'index': i,
                        'function_name': func_name,
                        'docstring': docstring,
                        'model_output': {},
                        'error': 'No questions generated'
                    }

            except Exception as e:
                yield {
                    'success': False,
                    'index': i,
                    'function_name': func_name,
                    'docstring': docstring,
                    'model_output': {},
                    'error': str(e)
                }

    def _generate_questions(self, func_name, docstring):
        #TODO: This can be re-written in a more forceful way for the llm
        inputs = self._prepare_inputs_for_question_extraction(func_name, docstring)

        inputs = self._tokenize(inputs, padding=True, truncation=True)

        with torch.no_grad():
            outs = self.model.generate(
                input_ids=inputs['input_ids'].to(self.device),
                attention_mask=inputs['attention_mask'].to(self.device),
                num_beams=4,

                max_length=32
            )

        questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]

        return questions

    def _tokenize(self, inputs, padding=True, truncation=True, add_special_tokens=True, max_length=512):
        tokenized_inputs = self.tokenizer(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding=padding,
            return_tensors="pt"
        )
        return tokenized_inputs

    def _prepare_inputs_for_question_extraction(self, func_name, docstring):
        #NOTE: experimental, consider removing :params and :return values
        #manual observation suggests the model struggles to understand the pupose of the function in their presense
        for string in self._exclude_after:
            param_idx = docstring.find(string)
            if param_idx != -1:
                docstring = docstring[:param_idx]
            docstring = docstring.strip()
        input = f"answer: <hl>The function is {func_name}<hl>. Context: {docstring} </s>"

        return [input]

    @property
    def exclude_after(self):
        return self._exclude_after

    @exclude_after.setter
    def exclude_after(self, value):
        self._exclude_after = value

In [5]:
finetuned_t5 = QGPipeline(model="valhalla/t5-base-qg-hl")

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/15.0 [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Using cpu


In [9]:
func_name = train_data[4]['func_name']
docstring = train_data[4]['func_documentation_string']
print(docstring)
finetuned_t5((func_name, docstring))

Read a single frame from the connection.


{'answer': 'WebSocketCommonProtocol.read_frame',
 'question': 'What is the function that reads a single frame from a connection?'}

The result looks promising. Let's run the model for the first 20 doc strings in our dataset

In [None]:
for i in range(10):
    func_name = train_data[i]['func_name']
    p = docstring = train_data[i]['func_documentation_string']
    print(f"========Sample{i+1}==========")
    idx = docstring.find(":param")
    if idx != -1:
        p = docstring[:idx]
    p = docstring[:].strip()
    print(f"Docstring: {docstring}")
    print(finetuned_t5(func_name, docstring))


Docstring: Display slug with level by language.
[{'answer': 'show_slug_with_level', 'question': 'What is the name of the function that displays a slug with level by language?'}]
Docstring: Render the last 10 revisions of a page content with a list using
        the ``pages/revisions.html`` template
[{'answer': 'show_revisions', 'question': 'What is the name of the function that shows the last 10 revisions of a page?'}]
Docstring: Method that parse the imageplaceholder template tag.
[{'answer': 'do_videoplaceholder', 'question': 'What is the name of the method that parses the imageplaceholder template tag?'}]
Docstring: Return Pages with given tag

    Syntax::

        {% get_pages_with_tag <tag name> as <varname> %}

    Example use:
        {% get_pages_with_tag "footer" as pages %}
[{'answer': 'do_get_pages_with_tag', 'question': 'What is the name of the function that returns pages with given tag?'}]
Docstring: Parses the XML run statistics file (GenerateFASTQRunStatistics.xml). In 

In [18]:
zipped5 = zip(train_data[:5]['func_name'], train_data[:5]['func_documentation_string'])

generator = finetuned_t5(list(zipped5))
print(generator)

for item in generator:
    ans, ques = item['function_name'], item['model_output']['question']
    print(f"========Sample==========")
    print(f"Question: {ques}")
    print(f"Answer: {ans}")
    print()

<generator object QGPipeline._process_batch_generator at 0x7833286165c0>
Question: What is the function that ensures that the WebSocket connection is open?
Answer: WebSocketCommonProtocol.ensure_open

Question: What is the function that reads incoming messages and puts them in a queue?
Answer: WebSocketCommonProtocol.transfer_data

Question: What is the function that reads a single message from the connection?
Answer: WebSocketCommonProtocol.read_message

Question: What is the function that reads a single data frame from a connection?
Answer: WebSocketCommonProtocol.read_data_frame

Question: What is the function that reads a single frame from a connection?
Answer: WebSocketCommonProtocol.read_frame



# Dataset Processor

In [13]:
import json
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import time
from typing import List, Dict, Any
from datasets import Dataset, DatasetDict

class DocstringDatasetProcessor:
    def __init__(self,
                 hf_dataset_name: str,
                 batch_size: int = 1000,
                 token: str = None,
                 save_locally: bool = False,
                 local_cache_dir: str = "./cache",
                 private_repo: bool = False):

        self.hf_dataset_name = hf_dataset_name
        self.batch_size = batch_size
        self.private_repo = private_repo
        self.local_cache_dir = Path(local_cache_dir)
        self.local_cache_dir.mkdir(exist_ok=True)
        self.save_locally = save_locally

        self.processed_count = 0
        self.failed_count = 0
        self.all_generated_data = []

        self.token = token

    def process_batch(self, batch_data: List[Tuple[str, str]], pipeline: QGPipeline, batch_id: int) -> List[Dict[str, Any]]:
        """Process a batch of (func_name, docstring) tuples with individual error handling"""
        batch_results = []
        batch_success_count = 0
        batch_failure_count = 0

        print(batch_data)

        # Process the entire batch through pipeline
        try:
            generator = pipeline(batch_data)
            for result in generator:

                if result['success']:
                    batch_results.append({
                        'function_name': result['function_name'],
                        'docstring': result['docstring'],
                        'question': result['model_output']['question'],
                    })
                    batch_success_count += 1
                else:
                    print(
                        f"Failed to process {result['function_name']}: {result['error']}"
                    )
                    batch_failure_count += 1

        except Exception as e:
            # Catastrophic failure - entire batch failed
            print(f"Catastrophic batch failure {batch_id}: {e}")
            batch_failure_count = len(batch_data)
            batch_success_count = 0

        self.processed_count += batch_success_count
        self.failed_count += batch_failure_count

        print(
            f"Batch {batch_id}: {batch_success_count} successful, "
            f"{batch_failure_count} failed out of {len(batch_data)} items"
        )

        if self.save_locally and batch_results:
            self._save_batch_locally(batch_results, batch_id)

        return batch_results

    def _save_batch_locally(self, batch_results: List[Dict], batch_id: int):
        batch_file = self.local_cache_dir / f"batch_{batch_id}.jsonl"
        with open(batch_file, 'w') as f:
            for item in batch_results:
                json.dump(item, f)
                f.write('\n')

    def process_full_dataset(self, dataset, pipeline, start_idx: int = 0):
        """Process the entire data set and upload to hugging face"""

        print(f"Starting processing of {len(dataset)} items from index {start_idx}")
        start_time = time.time()

        for batch_start in tqdm(range(start_idx, len(dataset), self.batch_size),
                                desc="Processing batches"):
            batch_end = min(batch_start + self.batch_size, len(dataset))
            batch_data = dataset[batch_start:batch_end]
            batch_id = batch_start // self.batch_size

            batch_results = self.process_batch(batch_data, pipeline, batch_id)
            self.all_generated_data.extend(batch_results)

            # print progress
            if batch_id % 10 == 0:
                elapsed = time.time() - start_time
                rate = self.processed_count / elapsed if elapsed > 0 else 0
                print(f"Processed {self.processed_count} items in {elapsed:.2f} seconds. Rate: {rate:.2f} items/sec")

            #final statistics
            total_time = time.time() - start_time
            print(f"Processed {self.processed_count} items in {total_time:.2f} seconds. Rate: {self.processed_count / total_time:.2f} items/sec")

            self._upload_to_hf()

    def _upload_to_hf(self):
        if self.token is None:
            raise ValueError("Hugging Face token not provided")
        try:
            print("Creating Hugging Face dataset")

            dataset = Dataset.from_list(self.all_generated_data)

            dataset_dict = DatasetDict({
                'train': dataset
            })

            #notice this might introduce unneccessary inefficiencies
            dataset_dict = dataset_dict.map(
                lambda x: {
                    **x,
                    'id': f"{x['function_name']}_{hash(x['docstring']) % 10000}"
                }
            )

            print(f"Uploading dataset to {self.hf_dataset_name}...")

            dataset_dict.push_to_hub(
                self.hf_dataset_name,
                token=self.token,
                private=self.private_repo,
                commit_message=f"Add {len(self.all_generated_data)} docstring-question pairs"
            )

            print(f"Successfully uploaded dataset to https://huggingface.co/datasets/{self.hf_dataset_name}")

        except Exception as e:
            print(f"Error uploading to Hugging Face: {e}")
            if self.save_locally:
                print("Data is available locally in cached directory")
            raise

    def load_from_hf(self):
        """Load the dataset from Hugging Face"""

        from datasets import load_dataset

        try:
            dataset = load_dataset(self.hf_dataset_name, token=self.token)
            print(f"Successfully loaded dataset from {self.hf_dataset_name}")
            return dataset
        except Exception as e:
            print(f"Error loading dataset from Hugging Face: {e}")
            raise

    #TODO: resume processing from local cache file
    #TODO: upload from colab cache to permanent file location (local or drive)

In [7]:
# Constants

from google.colab import userdata

HUGGING_FACE_TOKEN = userdata.get("HUGGING_FACE_TOKEN")
hf_dataset_name = "mrinjera/testing"

In [14]:
dataset_processor = DocstringDatasetProcessor(hf_dataset_name, batch_size=10, token=HUGGING_FACE_TOKEN)

batch_raw_data = zip(train_data[:200]['func_name'], train_data[:200]['func_documentation_string'])
batch_zip_data = list(batch_raw_data)

pipeline = finetuned_t5
dataset_processor.process_full_dataset(batch_zip_data, pipeline)

Starting processing of 200 items from index 0


Processing batches:   0%|          | 0/20 [00:00<?, ?it/s]

[('WebSocketCommonProtocol.ensure_open', "Check that the WebSocket connection is open.\n\n        Raise :exc:`~websockets.exceptions.ConnectionClosed` if it isn't."), ('WebSocketCommonProtocol.transfer_data', 'Read incoming messages and put them in a queue.\n\n        This coroutine runs in a task until the closing handshake is started.'), ('WebSocketCommonProtocol.read_message', 'Read a single message from the connection.\n\n        Re-assemble data frames if the message is fragmented.\n\n        Return ``None`` when the closing handshake is started.'), ('WebSocketCommonProtocol.read_data_frame', 'Read a single data frame from the connection.\n\n        Process control frames received before the next data frame.\n\n        Return ``None`` if a close frame is encountered before any data frame.'), ('WebSocketCommonProtocol.read_frame', 'Read a single frame from the connection.'), ('WebSocketCommonProtocol.write_close_frame', 'Write a close frame if and only if the connection state is OP

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/5.55k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/377 [00:00<?, ?B/s]

Processing batches:   5%|▌         | 1/20 [00:29<09:20, 29.50s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('WebSocketCommonProtocol.abort_keepalive_pings', "Raise ConnectionClosed in pending keepalive pings.\n\n        They'll never receive a pong once the connection is closed."), ('WebSocketCommonProtocol.connection_made', "Configure write buffer limits.\n\n        The high-water limit is defined by ``self.write_limit``.\n\n        The low-water limit currently defaults to ``self.write_limit // 4`` in\n        :meth:`~asyncio.WriteTransport.set_write_buffer_limits`, which should\n        be all right for reasonable use cases of this library.\n\n        This is the earliest point where we can get hold of the transport,\n        which means it's the best point for configuring it."), ('WebSocketCommonProtocol.eof_received', "Close the transport after receiving EOF.\n\n        Since Python 3.5, `:meth:~StreamReaderProtocol.eof_received` returns\n        ``True`` on non-TLS connections.\n\n        See http://bug

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/7.68k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/378 [00:00<?, ?B/s]

Processing batches:  10%|█         | 2/20 [01:02<09:31, 31.77s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('PerMessageDeflate.decode', 'Decode an incoming frame.'), ('PerMessageDeflate.encode', 'Encode an outgoing frame.'), ('ClientPerMessageDeflateFactory.get_request_params', 'Build request parameters.'), ('ClientPerMessageDeflateFactory.process_response_params', 'Process response parameters.\n\n        Return an extension instance.'), ('ServerPerMessageDeflateFactory.process_request_params', 'Process request parameters.\n\n        Return response params and an extension instance.'), ('apply_mask', 'Apply masking to the data of a WebSocket message.\n\n    ``data`` and ``mask`` are bytes-like objects.\n\n    Return :class:`bytes`.'), ('format_close', 'Display a human-readable version of the close code and reason.'), ('ServerExtensionFactory.process_request_params', 'Process request parameters received from the client.\n\n        ``params`` is a list of (name, value) pairs.\n\n        ``accepted_extensions`` 

Map:   0%|          | 0/30 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/9.41k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/378 [00:00<?, ?B/s]

Processing batches:  15%|█▌        | 3/20 [01:30<08:27, 29.85s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('ServiceRunner.start', 'Start all the registered services.\n\n        A new container is created for each service using the container\n        class provided in the __init__ method.\n\n        All containers are started concurrently and the method will block\n        until all have completed their startup routine.'), ('ServiceRunner.wait', 'Wait for all running containers to stop.'), ('Publisher.publish', 'Publish a message.'), ('RpcConsumer.stop', "Stop the RpcConsumer.\n\n        The RpcConsumer ordinary unregisters from the QueueConsumer when the\n        last Rpc subclass unregisters from it. If no providers were registered,\n        we should unregister from the QueueConsumer as soon as we're asked\n        to stop."), ('RpcConsumer.unregister_provider', 'Unregister a provider.\n\n        Blocks until this RpcConsumer is unregistered from its QueueConsumer,\n        which only happens when all prov

Map:   0%|          | 0/40 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/11.8k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/380 [00:00<?, ?B/s]

Processing batches:  20%|██        | 4/20 [02:00<07:59, 29.96s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('deserialize', 'Deserialize `data` to an exception instance.\n\n    If the `exc_path` value matches an exception registered as\n    ``deserializable``, return an instance of that exception type.\n    Otherwise, return a `RemoteError` instance describing the exception\n    that occurred.'), ('get_redacted_args', 'Utility function for use with entrypoints that are marked with\n    ``sensitive_arguments`` -- e.g. :class:`nameko.rpc.Rpc` and\n    :class:`nameko.events.EventHandler`.\n\n    :Parameters:\n        entrypoint : :class:`~nameko.extensions.Entrypoint`\n            The entrypoint that fired.\n        args : tuple\n            Positional arguments for the method call.\n        kwargs : dict\n            Keyword arguments for the method call.\n\n    The entrypoint should have a ``sensitive_arguments`` attribute, the value\n    of which is a string or tuple of strings specifying the arguments or\n   

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/381 [00:00<?, ?B/s]

Processing batches:  25%|██▌       | 5/20 [02:34<07:50, 31.35s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('WebServer.get_wsgi_server', 'Get the WSGI server used to process requests.'), ('get_event_exchange', 'Get an exchange for ``service_name`` events.'), ('event_dispatcher', 'Return a function that dispatches nameko events.'), ('fail_fast_imap', 'Run a function against each item in a given list, yielding each\n    function result in turn, where the function call is handled in a\n    :class:`~eventlet.greenthread.GreenThread` spawned by the provided pool.\n\n    If any function raises an exception, all other ongoing threads are killed,\n    and the exception is raised to the caller.\n\n    This function is similar to :meth:`~eventlet.greenpool.GreenPool.imap`.\n\n    :param pool: Pool to spawn function threads from\n    :type pool: eventlet.greenpool.GreenPool\n    :param call: Function call to make, expecting to receive an item from the\n        given list'), ('Timer._run', 'Runs the interval loop.'), ('Q

Map:   0%|          | 0/60 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/15.7k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/381 [00:00<?, ?B/s]

Processing batches:  30%|███       | 6/20 [03:00<06:54, 29.64s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('QueueConsumer.on_consume_ready', 'Kombu callback when consumers are ready to accept messages.\n\n        Called after any (re)connection to the broker.'), ('make_nameko_helper', 'Create a fake module that provides some convenient access to nameko\n    standalone functionality for interactive shell usage.'), ('iter_extensions', 'Depth-first iterator over sub-extensions on `extension`.'), ('Extension.bind', 'Get an instance of this Extension to bind to `container`.'), ('SharedExtension.bind', 'Bind implementation that supports sharing.'), ('DependencyProvider.bind', 'Get an instance of this Dependency to bind to `container` with\n        `attr_name`.'), ('ProviderCollector.wait_for_providers', 'Wait for any providers registered with the collector to have\n        unregistered.\n\n        Returns immediately if no providers were ever registered.'), ('Entrypoint.bind', 'Get an instance of this Entrypoint t

Map:   0%|          | 0/70 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/16.9k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/381 [00:00<?, ?B/s]

Processing batches:  35%|███▌      | 7/20 [03:26<06:09, 28.44s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('ServiceContainer.kill', 'Kill the container in a semi-graceful way.\n\n        Entrypoints are killed, followed by any active worker threads.\n        Next, dependencies are killed. Finally, any remaining managed threads\n        are killed.\n\n        If ``exc_info`` is provided, the exception will be raised by\n        :meth:`~wait``.'), ('ServiceContainer.spawn_worker', 'Spawn a worker thread for running the service method decorated\n        by `entrypoint`.\n\n        ``args`` and ``kwargs`` are used as parameters for the service method.\n\n        ``context_data`` is used to initialize a ``WorkerContext``.\n\n        ``handle_result`` is an optional function which may be passed\n        in by the entrypoint. It is called with the result returned\n        or error raised by the service method. If provided it must return a\n        value for ``result`` and ``exc_info`` to propagate to dependencies;\

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/18.8k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/381 [00:00<?, ?B/s]

Processing batches:  40%|████      | 8/20 [03:57<05:51, 29.27s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('StatefulBrowser.refresh', 'Reload the current page with the same request as originally done.\n        Any change (`select_form`, or any value filled-in in the form) made to\n        the current page before refresh is discarded.\n\n        :raise ValueError: Raised if no refreshable page is loaded, e.g., when\n            using the shallow ``Browser`` wrapper functions.\n\n        :return: Response of the request.'), ('StatefulBrowser.select_form', 'Select a form in the current page.\n\n        :param selector: CSS selector or a bs4.element.Tag object to identify\n            the form to select.\n            If not specified, ``selector`` defaults to "form", which is\n            useful if, e.g., there is only one form on the page.\n            For ``selector`` syntax, see the `.select() method in BeautifulSoup\n            <https://www.crummy.com/software/BeautifulSoup/bs4/doc/#css-selectors>`__.\n    

Map:   0%|          | 0/90 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/21.5k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/381 [00:00<?, ?B/s]

Processing batches:  45%|████▌     | 9/20 [04:30<05:33, 30.35s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('Form.set_input', 'Fill-in a set of fields in a form.\n\n        Example: filling-in a login/password form\n\n        .. code-block:: python\n\n           form.set_input({"login": username, "password": password})\n\n        This will find the input element named "login" and give it the\n        value ``username``, and the input element named "password" and\n        give it the value ``password``.'), ('Form.uncheck_all', 'Remove the *checked*-attribute of all input elements with\n        a *name*-attribute given by ``name``.'), ('Form.check', 'For backwards compatibility, this method handles checkboxes\n        and radio buttons in a single call. It will not uncheck any\n        checkboxes unless explicitly specified by ``data``, in contrast\n        with the default behavior of :func:`~Form.set_checkbox`.'), ('Form.set_checkbox', 'Set the *checked*-attribute of input elements of type "checkbox"\n       

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/24.1k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/381 [00:00<?, ?B/s]

Processing batches:  50%|█████     | 10/20 [05:05<05:17, 31.72s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('Form.print_summary', 'Print a summary of the form.\n\n        May help finding which fields need to be filled-in.'), ('Browser.__looks_like_html', 'Guesses entity type when Content-Type header is missing.\n        Since Content-Type is not strictly required, some servers leave it out.'), ('Browser.add_soup', 'Attaches a soup object to a requests response.'), ('Browser.set_user_agent', 'Replaces the current user agent in the requests session headers.'), ('Browser.request', "Straightforward wrapper around `requests.Session.request\n        <http://docs.python-requests.org/en/master/api/#requests.Session.request>`__.\n\n        :return: `requests.Response\n            <http://docs.python-requests.org/en/master/api/#requests.Response>`__\n            object with a *soup*-attribute added by :func:`add_soup`.\n\n        This is a low-level function that should not be called for\n        basic usage (use :fun

Map:   0%|          | 0/110 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/25.2k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/382 [00:00<?, ?B/s]

Processing batches:  55%|█████▌    | 11/20 [05:34<04:37, 30.84s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('requirements_from_file', 'Parses a pip requirements file into a list.'), ('read', 'Read the content of a file.'), ('imscatter', 'Creates a scatter plot, where each plot is shown by corresponding image'), ('pauli', 'Convert to pauli operators of universal gate model.\n\tRequires blueqat.'), ('make_qs', 'Make sympy symbols q0, q1, ...\n    \n    Args:\n        n(int), m(int, optional):\n            If specified both n and m, returns [qn, q(n+1), ..., qm],\n            Only n is specified, returns[q0, q1, ..., qn].\n\n    Return:\n        tuple(Symbol): Tuple of sympy symbols.'), ('nbody_separation', "Convert n-body problem to 2-body problem.\n    \n    Args:\n        expr: sympy expressions to be separated.\n        qs: sympy's symbols to be used as supplementary variable.\n\n    Return:\n        new_expr(sympy expr), constraints(sympy expr), mapping(dict(str, str -> Symbol)):\n            `new_expr` is 

Map:   0%|          | 0/120 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/26.9k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/382 [00:00<?, ?B/s]

Processing batches:  60%|██████    | 12/20 [06:02<04:00, 30.02s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('Opt.run', 'Run SA with provided QUBO. \n\t\tSet qubo attribute in advance of calling this method.'), ('numpartition_qaoa', 'Do the Number partition QAOA.\n\n    :param n_step: The number of step of QAOA\n    :param nums: The edges list of the graph.\n    :returns Vqe object'), ('slicing_singlevalue', 'Internally used.'), ('slicing', 'Internally used.'), ('qubit_pairs', 'Internally used.'), ('get_maximum_index', 'Internally used.'), ('Gate._str_targets', 'Returns printable string of targets.'), ('to_inttuple', 'Convert from bit string likes \'01011\' to int tuple likes (0, 1, 0, 1, 1)\n\n    Args:\n        bitstr (str, Counter, dict): String which is written in "0" or "1".\n            If all keys are bitstr, Counter or dict are also can be converted by this function.\n\n    Returns:\n        tuple of int, Counter, dict: Converted bits.\n            If bitstr is Counter or dict, returns the Counter or d

Map:   0%|          | 0/130 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/28.0k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/382 [00:00<?, ?B/s]

Processing batches:  65%|██████▌   | 13/20 [06:27<03:20, 28.58s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('non_sampling_sampler', 'Calculate the expectations without sampling.'), ('get_measurement_sampler', 'Returns a function which get the expectations by sampling the measured circuit'), ('get_state_vector_sampler', 'Returns a function which get the expectations by sampling the state vector'), ('get_qiskit_sampler', 'Returns a function which get the expectation by sampling via Qiskit.\n\n    This function requires `qiskit` module.'), ('AnsatzBase.get_energy', 'Calculate energy from circuit and sampler.'), ('AnsatzBase.get_objective', 'Get an objective function to be optimized.'), ('VqeResult.get_probs', 'Get probabilities.'), ('Vqe.result', 'Vqe.result is deprecated. Use `result = Vqe.run()`.'), ('factoring_qaoa', 'Do the Number partition QAOA.\n\n    :param num: The number to be factoring.\n    :param n_step: The number of step of QAOA\n    :param edges: The edges list of the graph.\n    :returns result o

Map:   0%|          | 0/140 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/29.0k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/382 [00:00<?, ?B/s]

Processing batches:  70%|███████   | 14/20 [06:57<02:54, 29.10s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('Backend._run', "Default implementation of `Backend.run`.\n        Backend developer shouldn't override this function, but override `run` instead of this.\n\n        The default flow of running is:\n            1. preprocessing\n            2. call the gate action which defined in backend\n            3. postprocessing\n\n        Backend developer can:\n            1. Define preprocessing process. Override `_preprocess_run`\n            2. Define the gate action. Define methods `gate_{gate.lowername}`,\n               for example, `gate_x` for X gate, `gate_cx` for CX gate.\n            3. Define postprocessing process (and make return value). Override `_postprocess_run`\n        Otherwise, the developer can override `run` method if they want to change the flow of run."), ('Backend.run', 'Run the backend.'), ('Backend._resolve_fallback', 'Resolve fallbacks and flatten gates.'), ('_NumPyBackendContext.pr

Map:   0%|          | 0/150 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/31.0k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/382 [00:00<?, ?B/s]

Processing batches:  75%|███████▌  | 15/20 [07:26<02:24, 28.86s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('BlueqatGlobalSetting.register_gate', 'Register new gate to gate set.\n\n        Args:\n            name (str): The name of gate.\n            gateclass (type): The type object of gate.\n            allow_overwrite (bool, optional): If True, allow to overwrite the existing gate.\n                Otherwise, raise the ValueError.\n\n        Raises:\n            ValueError: The name is duplicated with existing gate.\n                When `allow_overwrite=True`, this error is not raised.'), ('BlueqatGlobalSetting.register_backend', 'Register new backend.\n\n        Args:\n            name (str): The name of backend.\n            gateclass (type): The type object of backend\n            allow_overwrite (bool, optional): If True, allow to overwrite the existing backend.\n                Otherwise, raise the ValueError.\n\n        Raises:\n            ValueError: The name is duplicated with existing backend.\n

Map:   0%|          | 0/160 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/32.4k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/382 [00:00<?, ?B/s]

Processing batches:  80%|████████  | 16/20 [07:55<01:55, 28.91s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('Term.simplify', 'Simplify the Term.'), ('Term.append_to_circuit', 'Append Pauli gates to `Circuit`.'), ('Term.get_time_evolution', 'Get the function to append the time evolution of this term.\n\n        Returns:\n            function(circuit: Circuit, t: float):\n                Add gates for time evolution to `circuit` with time `t`'), ('Term.to_matrix', 'Convert to the matrix.'), ('Expr.from_terms_dict', 'For internal use.'), ('Expr.is_identity', 'If `self` is I, returns True, otherwise False.'), ('Expr.max_n', 'Returns the maximum index of Pauli matrices in the Term.'), ('Expr.is_all_terms_commutable', 'Test whether all terms are commutable. This function may very slow.'), ('Expr.simplify', 'Simplify the Expr.'), ('Expr.to_matrix', 'Convert to the matrix.')]
Batch 16: 10 successful, 0 failed out of 10 items
Processed 170 items in 497.11 seconds. Rate: 0.34 items/sec
Creating Hugging Face dataset


Map:   0%|          | 0/170 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/33.1k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/382 [00:00<?, ?B/s]

Processing batches:  85%|████████▌ | 17/20 [08:19<01:23, 27.68s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('SvgComposedPoly.add_coord', 'Adds a coord to the polyline and creates another circle'), ('SvgPolygon.set_stroke', 'Sets the stroke properties.\r\n\r\n        Args:\r\n            width (int): stroke width\r\n            color (str): stroke color'), ('SvgPolygon.add_arrow_coord', 'Determine the coordinates of an arrow head polygon\r\n            with height (h) and width (w) and recess (r)\r\n            pointing from the one but last to the last point of (poly)line (line).\r\n            Note that the coordinates of an SvgLine and an SvgPolyline\r\n            are stored in different variables.'), ('MyApp.Draw_a_drawing_of_one_sheet', 'Draw a drawing with two boxes, each with a name inside\r\n            and a polyline between the midpoints of the sides of the boxes,\r\n            with half-way the polyline a rhombus with an id included.'), ('MyApp.box_type_1', 'Draw a rectangular box of box_width and

Map:   0%|          | 0/180 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/382 [00:00<?, ?B/s]

Processing batches:  90%|█████████ | 18/20 [08:47<00:55, 27.50s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('MyApp.Display_TreeTable', "Display a table in which the values in first column form one or more trees.\n            The table has row with fields that are strings of identifiers/names.\n            First convert each row into a row_widget and item_widgets\n            that are displayed in a TableTree.\n            Each input row shall start with a parent field (field[0])\n            that determines the tree hierarchy but that is not displayed on that row.\n            The parent widget becomes an attribute of the first child widget.\n            Field[1] is the row color, field[2:] contains the row values.\n            Top child(s) shall have a parent field value that is blank ('').\n            The input table rows shall be in the correct sequence."), ('InputGauge.confirm_value', 'event called clicking on the gauge and so changing its value.\n           propagates the new value'), ('Editor.configure

Map:   0%|          | 0/190 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/36.4k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/382 [00:00<?, ?B/s]

Processing batches:  95%|█████████▌| 19/20 [09:13<00:27, 27.35s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing
[('ClassEventConnector.do', 'The callback and userdata gets stored, and if there is some javascript to add\n            the js code is appended as attribute for the event source'), ('Tag.repr', 'It is used to automatically represent the object to HTML format\n        packs all the attributes, children and so on.\n\n        Args:\n            changed_widgets (dict): A dictionary containing a collection of tags that have to be updated.\n                The tag that have to be updated is the key, and the value is its textual repr.'), ('Tag.add_child', "Adds a child to the Tag\n\n        To retrieve the child call get_child or access to the Tag.children[key] dictionary.\n\n        Args:\n            key (str):  Unique child's identifier, or iterable of keys\n            value (Tag, str): can be a Tag, an iterable of Tag or a str. In case of iterable\n                of Tag is a dict, each item's key is set as

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Uploading dataset to mrinjera/testing...


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading...:   0%|          | 0.00/37.9k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/382 [00:00<?, ?B/s]

Processing batches: 100%|██████████| 20/20 [09:42<00:00, 29.13s/it]

Successfully uploaded dataset to https://huggingface.co/datasets/mrinjera/testing





In [44]:
for i in range(10):
    print(train_data[i]['func_documentation_string'])

Check that the WebSocket connection is open.

        Raise :exc:`~websockets.exceptions.ConnectionClosed` if it isn't.
Read incoming messages and put them in a queue.

        This coroutine runs in a task until the closing handshake is started.
Read a single message from the connection.

        Re-assemble data frames if the message is fragmented.

        Return ``None`` when the closing handshake is started.
Read a single data frame from the connection.

        Process control frames received before the next data frame.

        Return ``None`` if a close frame is encountered before any data frame.
Read a single frame from the connection.
Write a close frame if and only if the connection state is OPEN.

        This dedicated coroutine must be used for writing close frames to
        ensure that at most one close frame is sent on a given connection.
Send a Ping frame and wait for a Pong frame at regular intervals.

        This coroutine exits when the connection terminates and o

# Model Training

In [18]:
import os
import torch
import pandas as pd
from datasets import Dataset, load_dataset
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from torch.utils.data import DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
import logging
from typing import List, Dict, Tuple
from google.colab import userdata


DATASET_NAME = "mrinjera/testing"
HF_TOKEN = userdata.get("HUGGING_FACE_TOKEN")
MODEL_NAME = "microsoft/codebert-base"
OUTPUT_MODEL_PATH = "./sbert-function-retrieval"
BATCH_SIZE = 16
EPOCHS = 4
LEARNING_RATE = 2e-5
WARMUP_STEPS = 1000
EVALUATION_STEPS = 5000

In [26]:
def load_and_process_dataset(dataset_name: str, token: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load the dataset and process it for training."""
    print(f"Loading dataset: {dataset_name}")

    dataset = load_dataset(dataset_name, token=token)

    #TODO: for clarity, have a bool param that controls the behavior
    if 'train' not in dataset or 'test' not in dataset:
        dataset = dataset['train'].shuffle(seed=42)
        train_size = int(0.8 * len(dataset))
        train_df = dataset.select(range(train_size)).to_pandas()
        test_df = dataset.select(range(train_size, len(dataset))).to_pandas()
    else:
        train_df = dataset['train'].to_pandas()
        test_df  = dataset['test'].to_pandas()

    print(f"Train dataset loaded with {len(train_df)} examples")
    print(f"Test dataset loaded with {len(test_df)} examples")
    print(f"Dataset columns: {train_df.columns.tolist()}")

    required_columns = ['function_name', 'docstring', 'question', 'id']
    for df_name, df in [('train', train_df), ('test', test_df)]:
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns in {df_name} split: {missing_columns}")

    train_df = train_df.dropna(subset=['docstring', 'question'])
    test_df = test_df.dropna(subset=['docstring', 'question'])

    print(f"After removing missing values:")
    print(f"  Train: {len(train_df)} examples")
    print(f"  Test: {len(test_df)} examples")

    return train_df, test_df

In [28]:
def create_training_examples(df: pd.DataFrame) -> List[InputExample]:
    """Create InputExample objects for SBERT training."""
    examples = []

    for idx, row in df.iterrows():
        # Create input example with question as query and docstring as positive document
        example = InputExample(
            texts=[row['question'], row['docstring']],
            label=1.0  # Positive pair
        )
        examples.append(example)

    print(f"Created {len(examples)} training examples")
    return examples

def create_evaluation_data(df: pd.DataFrame) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, set]]:
    """Create evaluation data for Information Retrieval evaluation."""
    # Split data for evaluation
    eval_df = df.sample(n=min(1000, len(df) // 10), random_state=42)

    queries = {}
    corpus = {}
    relevant_docs = {}

    for idx, row in eval_df.iterrows():
        query_id = f"q_{idx}"
        doc_id = f"d_{idx}"

        queries[query_id] = row['question']
        corpus[doc_id] = row['docstring']
        relevant_docs[query_id] = {doc_id}

    print(f"Created evaluation data with {len(queries)} queries and {len(corpus)} documents")
    return queries, corpus, relevant_docs

def create_validation_split(train_df: pd.DataFrame,
                          validation_size: float = 0.1) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Create a validation split from training data for monitoring during training."""
    train_split, val_split = train_test_split(
        train_df,
        test_size=validation_size,
        random_state=42,
        stratify=None
    )

    print(f"Created validation split:")
    print(f"  Training: {len(train_split)} examples")
    print(f"  Validation: {len(val_split)} examples")

    return train_split, val_split

In [21]:
def initialize_model(model_name: str) -> SentenceTransformer:
    """Initialize the SBERT model."""
    print(f"Initializing model: {model_name}")

    model = SentenceTransformer(model_name)

    print(f"Model max sequence length: {model.max_seq_length}")
    print(f"Model device: {model.device}")

    return model

def train_model(model: SentenceTransformer,
                train_examples: List[InputExample],
                val_queries: Dict[str, str],
                val_corpus: Dict[str, str],
                val_relevant_docs: Dict[str, set],
                device: torch.device) -> SentenceTransformer:
    """Train the SBERT model with Multiple Negatives Ranking Loss."""

    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)

    train_loss = losses.MultipleNegativesRankingLoss(model)

    val_evaluator = InformationRetrievalEvaluator(
        queries=val_queries,
        corpus=val_corpus,
        relevant_docs=val_relevant_docs,
        name="validation-eval"
    )

    # calculate warmup steps
    num_train_steps = len(train_dataloader) * EPOCHS
    warmup_steps = min(WARMUP_STEPS, num_train_steps // 10)

    print(f"Training configuration:")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Epochs: {EPOCHS}")
    print(f"  Learning rate: {LEARNING_RATE}")
    print(f"  Warmup steps: {warmup_steps}")
    print(f"  Total training steps: {num_train_steps}")
    print(f"  Evaluation steps: {EVALUATION_STEPS}")

    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=val_evaluator,
        epochs=EPOCHS,
        evaluation_steps=EVALUATION_STEPS,
        warmup_steps=warmup_steps,
        output_path=OUTPUT_MODEL_PATH,
        optimizer_params={'lr': LEARNING_RATE},
        save_best_model=True,
        show_progress_bar=True
    )

    return model

In [22]:
def evaluate_on_test_set(model: SentenceTransformer,
                        test_queries: Dict[str, str],
                        test_corpus: Dict[str, str],
                        test_relevant_docs: Dict[str, set]) -> Dict[str, float]:
    """Evaluate the trained model on the test set."""
    print("Evaluating model on test set...")

    # Create test evaluator
    test_evaluator = InformationRetrievalEvaluator(
        queries=test_queries,
        corpus=test_corpus,
        relevant_docs=test_relevant_docs,
        name="test-eval"
    )

    # Evaluate
    test_score = test_evaluator(model, output_path=OUTPUT_MODEL_PATH)

    print(f"Test evaluation completed. Score: {test_score}")
    return test_score

In [30]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_df, test_df = load_and_process_dataset(DATASET_NAME, HF_TOKEN)

train_split_df, val_split_df = create_validation_split(train_df)

train_examples = create_training_examples(train_split_df)

val_queries, val_corpus, val_relevant_docs = create_evaluation_data(val_split_df)

test_queries, test_corpus, test_relevant_docs = create_evaluation_data(test_df)


# model = initialize_model(MODEL_NAME)

# trained_model = train_model(
#     model=model,
#     train_examples=train_examples,
#     val_queries=val_queries,
#     val_corpus=val_corpus,
#     val_relevant_docs=val_relevant_docs,
#     device=device
# )

# # Load best model for final evaluation
# best_model = SentenceTransformer(OUTPUT_MODEL_PATH)

# # Evaluate on test set
# test_score = evaluate_on_test_set(
#     model=best_model,
#     test_queries=test_queries,
#     test_corpus=test_corpus,
#     test_relevant_docs=test_relevant_docs
# )


Loading dataset: mrinjera/testing
Train dataset loaded with 160 examples
Test dataset loaded with 40 examples
Dataset columns: ['function_name', 'docstring', 'question', 'id']
After removing missing values:
  Train: 160 examples
  Test: 40 examples
Created validation split:
  Training: 144 examples
  Validation: 16 examples
Created 144 training examples
Created evaluation data with 1 queries and 1 documents
Created evaluation data with 4 queries and 4 documents
