In [66]:
!pip install tree-sitter==0.22.3
!pip install tree-sitter-python==0.21.0
!pip install tree-sitter-java==0.21.0
!pip install datasets==2.20.0
!pip install evaluate==0.4.2
!pip install rouge_score==0.1.2
!pip install torch==2.5.1
!pip install transformers==4.46.2



In [67]:
import pandas as pd
import datasets


def prepare(dataset: datasets.Dataset, prettyfy_function, parse_function) -> datasets.Dataset:
    df = pd.DataFrame(dataset)

    for i, field in df.iterrows():
        code = field[0]

        func_name, func_body, func_body_without_comment = parse_function(code)

        df.at[i, "func_name"] = func_name
        df.at[i, "func_body"] = prettyfy_function(func_body)
        df.at[i, "func_body_without_comment"] = prettyfy_function(func_body_without_comment)


    return datasets.Dataset.from_pandas(df)


In [68]:
import pandas
import torch

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
import os
import evaluate
from pprint import pprint

def print_worst_example(predictions, references):
    exact_match, rouge = _init_metrics()

    worst_example = {
        "prediction": None,
        "reference": None,
        "score": float('-inf')
    }

    for prediction, reference in zip(predictions, references):
        em_score = exact_match.compute(predictions=[prediction], references=[reference])["exact_match"]
        rouge_score = rouge.compute(predictions=[prediction], references=[reference])["rougeL"]

        combined_score = (1 - em_score) + (1 - rouge_score)

        if combined_score > worst_example["score"]:
            worst_example = {
                "prediction": prediction,
                "reference": reference,
                "score": combined_score
            }
    print('WORST EXAMPLE:')
    pprint(worst_example)

def _init_metrics():
    return (evaluate.load('exact_match'), evaluate.load('rouge'))

def predict(dataset: datasets.Dataset, references: datasets.Dataset) -> int:
    torch.cuda.empty_cache()
    device = torch.device("cuda")

    checkpoint = 'Salesforce/codet5p-220m'

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = T5ForConditionalGeneration.from_pretrained(checkpoint).to(device)

    predictions = []
    try:
      inputs = tokenizer(dataset,
                         return_tensors='pt',
                         padding=True,
                         truncation=True,
                         max_length=80,
                         ).to(device)
      outputs = model.generate(**inputs, max_length=80)
      predictions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
      del model
    except Exception as e:
      del model
    print(dataset[0])

    x = 0
    for i in range(len(references)):
      if references[i] in predictions[i]:
        x += 1
    print(f'X: {x}\n')


    for i in range(len(predictions)):
      try:
        predictions[i] = predictions[i].split(" ")[1].split("(")[0]
      except Exception:
        pass
    print(predictions)
    print(references)

    eval_results = run_evaluate(predictions=predictions, references=references)
    print()
    print('*' * 80)
    print('Evaluation results:')
    pprint(eval_results)
    print('*' * 80)
    print()
    print_worst_example(predictions, references)

    return x


def run_evaluate(
    predictions, references
) -> dict[str, float]:
    em, rouge = _init_metrics()
    em_score = em.compute(predictions=predictions, references=references)
    rouge_scores = rouge.compute(predictions=predictions, references=references)

    return {**rouge_scores, **em_score}


## Python metrics

### Download python dataset

In [69]:
dataset_python = datasets.load_dataset(
        'code_search_net',
        'python',
        split='test',
        trust_remote_code=True
    )
dataset_python

Dataset({
    features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
    num_rows: 22176
})

### Check dataset

In [70]:
dataset_python[0]

{'repository_name': 'soimort/you-get',
 'func_path_in_repository': 'src/you_get/extractors/youtube.py',
 'func_name': 'YouTube.get_vid_from_url',
 'whole_func_string': 'def get_vid_from_url(url):\n        """Extracts video ID from URL.\n        """\n        return match1(url, r\'youtu\\.be/([^?/]+)\') or \\\n          match1(url, r\'youtube\\.com/embed/([^/?]+)\') or \\\n          match1(url, r\'youtube\\.com/v/([^/?]+)\') or \\\n          match1(url, r\'youtube\\.com/watch/([^/?]+)\') or \\\n          parse_query_param(url, \'v\') or \\\n          parse_query_param(parse_query_param(url, \'u\'), \'v\')',
 'language': 'python',
 'func_code_string': 'def get_vid_from_url(url):\n        """Extracts video ID from URL.\n        """\n        return match1(url, r\'youtu\\.be/([^?/]+)\') or \\\n          match1(url, r\'youtube\\.com/embed/([^/?]+)\') or \\\n          match1(url, r\'youtube\\.com/v/([^/?]+)\') or \\\n          match1(url, r\'youtube\\.com/watch/([^/?]+)\') or \\\n          par

In [71]:
print(dataset_python[0]["whole_func_string"])

def get_vid_from_url(url):
        """Extracts video ID from URL.
        """
        return match1(url, r'youtu\.be/([^?/]+)') or \
          match1(url, r'youtube\.com/embed/([^/?]+)') or \
          match1(url, r'youtube\.com/v/([^/?]+)') or \
          match1(url, r'youtube\.com/watch/([^/?]+)') or \
          parse_query_param(url, 'v') or \
          parse_query_param(parse_query_param(url, 'u'), 'v')


### Crop python dataset

In [72]:
dataset_python = dataset_python.select(range(1000))

### Prepare dataset python

In [73]:
import tree_sitter_python as tspython
from tree_sitter import Language, Parser

def prettyfy_function(x):
  return f"def <extra_id_0> :{x}"

PY_LANGUAGE = Language(tspython.language())
parser = Parser(PY_LANGUAGE)
def parse_function(code):
  ast = parser.parse(bytes(code, "utf-8"))
  query = PY_LANGUAGE.query("""
  (function_definition
      name: (identifier) @func_name
      body: (block) @func_body
  )
  (expression_statement
      (string) @docstring
  )
  (comment) @comment
  """)
  captures = query.captures(ast.root_node)

  func_name = ''
  func_body = ''
  func_body_without_comment = []
  comments = []

  for node, capture_name in captures:
      if capture_name == "func_name":
          func_name = code[node.start_byte:node.end_byte]
      elif capture_name == "func_body":
          func_body = code[node.start_byte:node.end_byte]
      elif capture_name == "comment":
          comments.append(code[node.start_byte:node.end_byte])
      elif capture_name == "docstring":
          comments.append(code[node.start_byte:node.end_byte])

  func_body_without_comment = func_body
  for comment in comments:
      func_body_without_comment = func_body_without_comment.replace(comment, "")

  return (func_name, func_body, func_body_without_comment)

dataset_python = prepare(dataset_python["whole_func_string"], prettyfy_function, parse_function)
dataset_python

  return cls(pa.Table.from_pandas(*args, **kwargs))


Dataset({
    features: ['0', 'func_name', 'func_body', 'func_body_without_comment'],
    num_rows: 1000
})

#### Funcion in python name

In [74]:
print(dataset_python['func_name'][0])

get_vid_from_url


#### Funcion in python with comments and docstring

In [75]:
print(dataset_python['func_body'][0])

def <extra_id_0> :"""Extracts video ID from URL.
        """
        return match1(url, r'youtu\.be/([^?/]+)') or \
          match1(url, r'youtube\.com/embed/([^/?]+)') or \
          match1(url, r'youtube\.com/v/([^/?]+)') or \
          match1(url, r'youtube\.com/watch/([^/?]+)') or \
          parse_query_param(url, 'v') or \
          parse_query_param(parse_query_param(url, 'u'), 'v')


#### Funcion in python WITHOUT comments and docstring

In [76]:
print(dataset_python['func_body_without_comment'][0])

def <extra_id_0> :
        return match1(url, r'youtu\.be/([^?/]+)') or \
          match1(url, r'youtube\.com/embed/([^/?]+)') or \
          match1(url, r'youtube\.com/v/([^/?]+)') or \
          match1(url, r'youtube\.com/watch/([^/?]+)') or \
          parse_query_param(url, 'v') or \
          parse_query_param(parse_query_param(url, 'u'), 'v')


### Predict python function name

#### Python without comments

In [77]:
predict(dataset_python["func_body_without_comment"], dataset_python["func_name"])


def <extra_id_0> :
        return match1(url, r'youtu\.be/([^?/]+)') or \
          match1(url, r'youtube\.com/embed/([^/?]+)') or \
          match1(url, r'youtube\.com/v/([^/?]+)') or \
          match1(url, r'youtube\.com/watch/([^/?]+)') or \
          parse_query_param(url, 'v') or \
          parse_query_param(parse_query_param(url, 'u'), 'v')
X: 220

['match1', 'parse_url', 'md5', 'is_fc2video_download', 'get_video_info', '', 'get_resource_id', 'ucas_download', 'download_', 'download_', 'sina_zxt', 'download_by_smid', 'match1', 'get_watch_url', 'download', 'get_url', 'color_to_color', 'print_error', 'print_log', 'print_log', 'system', 'get_video_id', 'vimeo_download_by_channel_id', 'get_vimeo_id', 'download_by_vid', 'get_ckplayer', 'test_mp4_url\n', 'match1', '', 'get_branch_from_commit', 'translate', 'get_win_size', 'theplatform_download_', '', 'get_video', '_options', 'download_stream', 'find_pattern', 'find_all', 'parse_qs', 'read_gzip', 'decompress', 'get_content', 'post_con

220

#### Python + comments

In [78]:
predict(dataset_python["func_body"], dataset_python["func_name"])


def <extra_id_0> :"""Extracts video ID from URL.
        """
        return match1(url, r'youtu\.be/([^?/]+)') or \
          match1(url, r'youtube\.com/embed/([^/?]+)') or \
          match1(url, r'youtube\.com/v/([^/?]+)') or \
          match1(url, r'youtube\.com/watch/([^/?]+)') or \
          parse_query_param(url, 'v') or \
          parse_query_param(parse_query_param(url, 'u'), 'v')
X: 315

['or', 'parse_xml', 'init', 'wrapper', 'download_dailymotion_video', 'dictify', 'get_video', 'ucas_download', 'download_sina_video', 'download_sina_video', 'sina_zxt', 'miaopai_download_by_smid', 'get_item_id', 'get_content', 'prepare', 'url)\n', 'format_text', 'print_log', 'print_error', 'print_log', 'windows_linux_detect', 'get_mobile_page', 'download_vimeo', 'get_id', 'get_video', 'get_ckinfo', 'test_mp4_url\n', 'extract_video_id', 'get_real_urls', '=', 'filename_to_filename', 'get_terminal_size', '', '', 'get_content', 'main', 'download', 'scan_string', 'scan', 'parse_qs', 'decompress', 

315

## Java metrics

### Download java dataset

In [79]:
dataset_java_original = datasets.load_dataset(
        'code_search_net',
        'java',
        split='test',
        trust_remote_code=True
    )
dataset_java_original

Dataset({
    features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
    num_rows: 26909
})

### Check dataset

In [80]:
dataset_java_original[0]

{'repository_name': 'ReactiveX/RxJava',
 'func_path_in_repository': 'src/main/java/io/reactivex/internal/observers/QueueDrainObserver.java',
 'func_name': 'QueueDrainObserver.fastPathOrderedEmit',
 'whole_func_string': 'protected final void fastPathOrderedEmit(U value, boolean delayError, Disposable disposable) {\n        final Observer<? super V> observer = downstream;\n        final SimplePlainQueue<U> q = queue;\n\n        if (wip.get() == 0 && wip.compareAndSet(0, 1)) {\n            if (q.isEmpty()) {\n                accept(observer, value);\n                if (leave(-1) == 0) {\n                    return;\n                }\n            } else {\n                q.offer(value);\n            }\n        } else {\n            q.offer(value);\n            if (!enter()) {\n                return;\n            }\n        }\n        QueueDrainHelper.drainLoop(q, observer, delayError, disposable, this);\n    }',
 'language': 'java',
 'func_code_string': 'protected final void fastPathOr

In [81]:
print(dataset_java_original[214]["whole_func_string"])

@CheckReturnValue
    @SchedulerSupport(SchedulerSupport.CUSTOM)
    public final Observable<T> skipLast(long time, TimeUnit unit, Scheduler scheduler, boolean delayError, int bufferSize) {
        ObjectHelper.requireNonNull(unit, "unit is null");
        ObjectHelper.requireNonNull(scheduler, "scheduler is null");
        ObjectHelper.verifyPositive(bufferSize, "bufferSize");
        // the internal buffer holds pairs of (timestamp, value) so double the default buffer size
        int s = bufferSize << 1;
        return RxJavaPlugins.onAssembly(new ObservableSkipLastTimed<T>(this, time, unit, scheduler, s, delayError));
    }


### Crop java dataset

In [82]:
dataset_java_crop = dataset_java_original.select(range(1000))

### Prepare java dataset

In [83]:
import tree_sitter_java as tsjava
from tree_sitter import Language, Parser

def prettyfy_function_java(x):
  return f"public final void <extra_id_0> {x}"

JAVA_LANGUAGE = Language(tsjava.language())
parser = Parser(JAVA_LANGUAGE)
def parse_function_java(code):
  ast = parser.parse(bytes(code, "utf-8"))
  query = JAVA_LANGUAGE.query("""
      (block_comment) @block_comment
      (line_comment) @line_comment
      (method_declaration
        (block) @function
      )
      (method_declaration
        (identifier) @name
      )
  """)
  captures = query.captures(ast.root_node)

  func_name = ''
  func_body = ''
  func_body_without_comment = []
  comments = []

  for node, capture_name in captures:
      if capture_name == "name":
          func_name = code[node.start_byte:node.end_byte]
      elif capture_name == "function":
          func_body = code[node.start_byte:node.end_byte]
      elif capture_name == "line_comment":
          comments.append(code[node.start_byte:node.end_byte])
      elif capture_name == "block_comment":
          comments.append(code[node.start_byte:node.end_byte])

  func_body_without_comment = func_body
  for comment in comments:
      func_body_without_comment = func_body_without_comment.replace(comment, "")

  return (func_name, func_body, func_body_without_comment)

dataset_java = prepare(dataset_java_crop["whole_func_string"], prettyfy_function_java, parse_function_java)
dataset_java

  return cls(pa.Table.from_pandas(*args, **kwargs))


Dataset({
    features: ['0', 'func_name', 'func_body', 'func_body_without_comment'],
    num_rows: 1000
})

#### Funcion in Java name

In [84]:
print(dataset_java['func_name'][214])

skipLast


#### Funcion in Java with comments and docstring

In [85]:
print(dataset_java['func_body'][214])

public final void <extra_id_0> {
        ObjectHelper.requireNonNull(unit, "unit is null");
        ObjectHelper.requireNonNull(scheduler, "scheduler is null");
        ObjectHelper.verifyPositive(bufferSize, "bufferSize");
        // the internal buffer holds pairs of (timestamp, value) so double the default buffer size
        int s = bufferSize << 1;
        return RxJavaPlugins.onAssembly(new ObservableSkipLastTimed<T>(this, time, unit, scheduler, s, delayError));
    }


#### Funcion in Java WITHOUT comments and docstring

In [86]:
print(dataset_java['func_body_without_comment'][214])

public final void <extra_id_0> {
        ObjectHelper.requireNonNull(unit, "unit is null");
        ObjectHelper.requireNonNull(scheduler, "scheduler is null");
        ObjectHelper.verifyPositive(bufferSize, "bufferSize");
        
        int s = bufferSize << 1;
        return RxJavaPlugins.onAssembly(new ObservableSkipLastTimed<T>(this, time, unit, scheduler, s, delayError));
    }


### Predict java function name

#### Predict java function name WITHOUT comment

In [87]:
predict(dataset_java["func_body_without_comment"], dataset_java["func_name"])


public final void <extra_id_0> {
        final Observer<? super V> observer = downstream;
        final SimplePlainQueue<U> q = queue;

        if (wip.get() == 0 && wip.compareAndSet(0, 1)) {
            if (q.isEmpty()) {
                accept(observer, value);
                if (leave(-1) == 0) {
                    return;
                }
            } else {
                q.offer(value);
            }
        } else {
            q.offer(value);
            if (!enter()) {
                return;
            }
        }
        QueueDrainHelper.drainLoop(q, observer, delayError, disposable, this);
    }
X: 431

['accept', 'ObservableAmb<T>', 'wrap', 'concatMapDelayError', 'concatMap', 'concat<T>', 'concatDelayError', 'concatArray', 'concatMapEagerDelayError', 'concatMapEagerDelayError', 'concatDelayError', 'concat', 'concatMapEager', 'concatMapEagerDelayError', 'ObservableEmpty<T>', 'onError', 'flatMap', 'ObservableFromIterable<T>', 'subscribe', 'generate', 'generate', 'Obse

431

#### Predict java function name with comment

In [88]:
predict(dataset_java["func_body"], dataset_java["func_name"])


public final void <extra_id_0> {
        final Observer<? super V> observer = downstream;
        final SimplePlainQueue<U> q = queue;

        if (wip.get() == 0 && wip.compareAndSet(0, 1)) {
            if (q.isEmpty()) {
                accept(observer, value);
                if (leave(-1) == 0) {
                    return;
                }
            } else {
                q.offer(value);
            }
        } else {
            q.offer(value);
            if (!enter()) {
                return;
            }
        }
        QueueDrainHelper.drainLoop(q, observer, delayError, disposable, this);
    }
X: 432

['accept', 'ObservableAmb<T>', 'wrap', 'concatMapDelayError', 'concatMap', 'concat<T>', 'concatDelayError', 'concatArray', 'concatMapEagerDelayError', 'concatMapEagerDelayError', 'concatDelayError', 'concat', 'concatMapEager', 'concatMapEagerDelayError', 'ObservableEmpty<T>', 'onError', 'flatMap', 'ObservableFromIterable<T>', 'subscribe', 'generate', 'generate', 'Obse

432