In [1]:
from typing import Generator
import tree_sitter_java as tsjava
from tree_sitter import Language, Parser, Tree, Node
from unixcoder import UniXcoder
import torch

In [2]:
# Build a tree-sitter parser configured with the Java grammar shipped by tree_sitter_java.
parser = Parser(Language(tsjava.language()))

In [3]:
# Download from https://github.com/apache/ant-ivy/blob/master/src/java/org/apache/ivy/Ivy.java
# Decode explicitly as UTF-8: the text is re-encoded as UTF-8 before it is parsed,
# so relying on the platform's default encoding could corrupt non-ASCII characters.
with open('Ivy.java', encoding='utf-8') as f:
    text = f.read()
# Preview only the beginning instead of echoing the whole file into the notebook.
text[:500]



In [4]:
# Parse the Java source; the parser receives the text encoded as UTF-8 bytes.
tree = parser.parse(text.encode('utf8'))
tree

<tree_sitter.Tree at 0x25786dd28e0>

In [5]:
def traverse_tree(tree: Tree) -> Generator[Node, None, None]:
    """Lazily yield every node of ``tree`` in pre-order.

    A single cursor obtained from ``tree.walk()`` is moved through the tree:
    each node is yielded when it is first reached, then the walk descends to
    its first child, moves on to following siblings, and climbs back to the
    parent once a level is exhausted.

    Args:
        tree: Object exposing ``walk()``, which returns a cursor with
            ``node``, ``goto_first_child()``, ``goto_next_sibling()`` and
            ``goto_parent()``.

    Yields:
        The cursor's current node, parents before their children.
    """
    walker = tree.walk()

    # True while the node under the cursor has not been yielded yet.
    descending = True
    while True:
        if descending:
            yield walker.node
            # Keep descending only if there actually is a child to move to.
            descending = walker.goto_first_child()
            continue
        if walker.goto_next_sibling():
            descending = True
            continue
        # No sibling left on this level: climb up; failing to do so means the
        # cursor is back at the root and the traversal is complete.
        if not walker.goto_parent():
            return

In [6]:
# See unique node types in tree
# A set comprehension collects the distinct types directly, without first
# building a list that holds one entry per node.
node_types = {node.type for node in traverse_tree(tree)}
node_types

{'!',
 '!=',
 '"',
 '&&',
 '(',
 ')',
 '+',
 ',',
 '.',
 '/',
 ':',
 ';',
 '<',
 '=',
 '==',
 '>',
 '@',
 '[',
 ']',
 'annotation',
 'annotation_argument_list',
 'argument_list',
 'array_type',
 'assignment_expression',
 'binary_expression',
 'block',
 'block_comment',
 'boolean_type',
 'break',
 'break_statement',
 'case',
 'cast_expression',
 'catch',
 'catch_clause',
 'catch_formal_parameter',
 'catch_type',
 'class',
 'class_body',
 'class_declaration',
 'class_literal',
 'constructor_body',
 'constructor_declaration',
 'decimal_integer_literal',
 'default',
 'dimensions',
 'else',
 'enhanced_for_statement',
 'expression_statement',
 'false',
 'field_access',
 'field_declaration',
 'final',
 'finally',
 'finally_clause',
 'for',
 'formal_parameter',
 'formal_parameters',
 'generic_type',
 'identifier',
 'if',
 'if_statement',
 'import',
 'import_declaration',
 'instanceof',
 'instanceof_expression',
 'int',
 'integral_type',
 'interface',
 'interface_body',
 'interface_declaration'

In [9]:
# Print the source of every class declaration and count them.
# start_byte/end_byte are byte offsets into the UTF-8 source given to the
# parser, so slice the encoded bytes (not the str) and decode the slice;
# slicing the str directly is off as soon as a non-ASCII character appears.
source_bytes = text.encode('utf8')
i = 0
for node in traverse_tree(tree):
    if node.type != 'class_declaration':
        continue
    print('i =', i)
    print(source_bytes[node.start_byte:node.end_byte].decode('utf8'))
    i += 1
i

i = 0
public class Ivy {
    /**
     * Callback used to execute a set of Ivy related methods within an {@link IvyContext}.
     *
     * @see Ivy#execute(org.apache.ivy.Ivy.IvyCallback)
     */
    public interface IvyCallback {
        /**
         * Executes Ivy related job within an {@link IvyContext}
         *
         * @param ivy
         *            the {@link Ivy} instance to which this callback is related
         * @param context
         *            the {@link IvyContext} in which this callback is executed
         * @return the result of this job, <code>null</code> if there is no result
         */
        Object doInIvyContext(Ivy ivy, IvyContext context);
    }

    private static final int KILO = 1024;

    /**
     * @deprecated Use the {@link DateUtil} utility class instead.
     */
    @Deprecated
    public static final SimpleDateFormat DATE_FORMAT = new SimpleDateFormat(
            DateUtil.DATE_FORMAT_PATTERN);

    /**
     * the current version of Ivy, as di

1

In [10]:
# Collect code fragments at four granularities: package, class, method, and token.
# Tokens list (Java keyword table kept for reference; the asterisks look like
# footnote markers carried over from the table this was copied from)
# abstract	continue	for	new	switch
# assert***	default	goto*	package	synchronized
# boolean	do	if	private	this
# break	double	implements	protected	throw
# byte	else	import	public	throws
# case	enum****	instanceof	return	transient
# catch	extends	int	short	try
# char	final	interface	static	void
# class	finally	long	strictfp**	volatile
# const*	float	native	super	while
# TODO Find all tokens in Java
# Node types that are collected as token-level fragments.
tokens = [
    'abstract', 'assert', 'boolean_type', 'break', 'byte', 'case', 'catch', 'char', 'class', 'const', 'continue', 'int', # Keywords
    'identifier',
    '!', '!=', '"', '&&', '+', '/', '<', '=', '==', '>', '@', '||', # Operators
    # '.' and ':' are separate entries; without the comma between them Python
    # concatenates the adjacent literals into the single string '.:'.
    '(', ')', ',', '.', ':', ';', '[', ']', '{', '}' # Delimiters
]
keys = ['package', 'class', 'method', 'token']
# One independent, initially empty list of fragments per granularity.
code_frags = {key: [] for key in keys}

# start_byte/end_byte are byte offsets into the UTF-8 source given to the
# parser, so fragments are cut from the encoded bytes and decoded afterwards.
source_bytes = text.encode('utf8')
# Constant-time membership test inside the loop; `tokens` itself stays a list.
token_types = frozenset(tokens)

for node in traverse_tree(tree):
    code = source_bytes[node.start_byte:node.end_byte].decode('utf8')
    if node.type == 'package_declaration':
        # If there is a package declaration, let the whole text be the code fragment
        code_frags['package'].append(text)
    elif node.type == 'class_declaration':
        code_frags['class'].append(code)
    elif node.type == 'method_declaration':
        code_frags['method'].append(code)
    elif node.type in token_types:
        code_frags['token'].append(code)

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The CUDA auto-detection above is commented out; the device is fixed to the CPU.
device = 'cpu'
# Instantiate UniXcoder from the "microsoft/unixcoder-base" checkpoint name.
model = UniXcoder("microsoft/unixcoder-base")
# Move the model to `device`; as the cell's last expression, the returned value is displayed.
model.to(device)

UniXcoder(
  (model): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(51416, 768, padding_idx=1)
      (position_embeddings): Embedding(1026, 768, padding_idx=1)
      (token_type_embeddings): Embedding(10, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): Layer

In [10]:
# Embed the first collected class declaration with UniXcoder.
code = code_frags['class'][0]
# tokenize() takes a *list* of input strings (see the UniXcoder usage example);
# a bare string would be iterated character by character, turning every
# character into its own input sequence.
tokens_ids = model.tokenize([code], max_length=512, mode="<encoder-only>")
source_ids = torch.tensor(tokens_ids).to(device)
# Inference only: no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
    tokens_embeddings, code_embedding = model(source_ids)

: 

In [None]:
# TODO: Write code for collecting data in a single Java file (draft below, kept commented out)
# data = [] # Data for Ivy.java
# for node in traverse_tree(tree):
#     if node.type == 'package_declaration':
#         level = 'package'
#     elif node.type == 'class_declaration':
#         level =  'class'
#     elif node.type == 'method_declaration':
#         level = 'method'
#     elif node.type in tokens:
#         level = 'token'
#     else:
#         continue
    
#     # Get code embedding
#     code = text[node.start_byte:node.end_byte]
#     tokens_ids = model.tokenize([code])  # tokenize takes a list of input strings
#     source_ids = torch.tensor(tokens_ids).to(device)
#     tokens_embeddings, code_embedding = model(source_ids)

#     data.append((level, node, code_embedding))