In [17]:
import json
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

In [20]:
print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
device = "cuda:0" if torch.cuda.is_available() else "cpu"

Is CUDA available: True
CUDA device: NVIDIA GeForce RTX 3090


In [40]:
# path
lit_file = '../../token-level/dataset/py150/literals.json'
model_dir = 'Wannita/PyCoder' # or your trained PyCoder model dir

# Setting up

In [41]:
def get_special_tokens(path):
    lits = json.load(open(path))
    tokens = ["<STR_LIT>", "<NUM_LIT>", "<CHAR_LIT>"]
    for lit in lits["str"]:
        tokens.append(f"<STR_LIT:{lit}>")
    for lit in lits["num"]:
        tokens.append(f"<NUM_LIT:{lit}>")
    for lit in lits["char"]:
        tokens.append(f"<CHAR_LIT:{lit}>")
    return tokens

In [45]:
# get special tokens
special_tokens = get_special_tokens(lit_file)
token_types = ['<NAME>', '<KEYWORD>', '<NUMBER>', '<STRING>', '<NEWLINE>', '<INDENT>', '<DEDENT>', '<LPAR>', '<RPAR>', '<LSQB>', '<RSQB>', '<COLON>', '<COMMA>', '<SEMI>', '<PLUS>', '<MINUS>', '<STAR>', '<SLASH>', '<VBAR>', '<AMPER>', '<LESS>', '<GREATER>', '<EQUAL>', '<DOT>', '<PERCENT>', '<LBRACE>', '<RBRACE>', '<EQEQUAL>', '<NOTEQUAL>', '<LESSEQUAL>', '<GREATEREQUAL>', '<TILDE>', '<CIRCUMFLEX>', '<LEFTSHIFT>', '<RIGHTSHIFT>', '<DOUBLESTAR>', '<PLUSEQUAL>', '<MINEQUAL>', '<STAREQUAL>', '<SLASHEQUAL>', '<PERCENTEQUAL>', '<AMPEREQUAL>', '<VBAREQUAL>', '<CIRCUMFLEXEQUAL>', '<LEFTSHIFTEQUAL>', '<RIGHTSHIFTEQUAL>', '<DOUBLESTAREQUAL>', '<DOUBLESLASH>', '<DOUBLESLASHEQUAL>', '<AT>', '<ATEQUAL>', '<RARROW>', '<ELLIPSIS>', '<ERRORTOKEN>']
special_tokens.extend(token_types)

# load tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained(model_dir, do_lower_case=False, sep_token='<EOL>', bos_token='<s>', eos_token='</s>', pad_token='<pad>', unk_token='<|UNKNOWN|>', additional_special_tokens=special_tokens)
model = GPT2LMHeadModel.from_pretrained(model_dir)
model.resize_token_embeddings(len(tokenizer))

Embedding(50288, 768)

In [43]:
# load data
token_test = open('../../token-level/dataset/py150/token_completion/test.txt').readlines()
line_datas = open('../../line-level/dataset/py150/line_completion/test.json').readlines()
line_inputs = []
line_gts = []
for data in line_datas:
    data = json.loads(data.strip())
    line_inputs.append(data["input"])
    line_gts.append([data["gt"]])

## Data Sample

In [13]:
token_test[0]

'<s> from django . utils . translation import ugettext_lazy as _ <EOL> from horizon import tabs <EOL> class NetworkProfileTab ( tabs . Tab ) : <EOL> <INDENT> name = _ ( "<STR_LIT>" ) <EOL> slug = "<STR_LIT>" <EOL> template_name = \'<STR_LIT>\' <EOL> def get_context_data ( self , request ) : <EOL> <INDENT> return None <EOL> <DEDENT> <DEDENT> class PolicyProfileTab ( tabs . Tab ) : <EOL> <INDENT> name = _ ( "<STR_LIT>" ) <EOL> slug = "<STR_LIT>" <EOL> template_name = \'<STR_LIT>\' <EOL> preload = False <EOL> <DEDENT> class IndexTabs ( tabs . TabGroup ) : <EOL> <INDENT> slug = "<STR_LIT>" <EOL> tabs = ( NetworkProfileTab , PolicyProfileTab ) <EOL> <DEDENT> <EOL> </s>\n'

In [15]:
line_inputs[0]

'<s> import threading <EOL> import IECore <EOL> import Gaffer <EOL> import GafferUI <EOL> import GafferImage <EOL> __all__ = [ ] <EOL> Gaffer . Metadata . registerNode ( <EOL> GafferImage . Display , <EOL> "<STR_LIT:description>" , <EOL> """<STR_LIT>""" , <EOL> plugs = { <EOL> "<STR_LIT:port>" : [ <EOL> "<STR_LIT:description>" , <EOL> """<STR_LIT>""" , <EOL> ] , <EOL> } <EOL> ) <EOL> __plugsPendingUpdate = [ ] <EOL> __plugsPendingUpdateLock = threading . Lock ( ) <EOL> def __scheduleUpdate ( plug , force = False ) : <EOL> <INDENT> if not force : <EOL> <INDENT> global __plugsPendingUpdate <EOL> global __plugsPendingUpdateLock <EOL> with __plugsPendingUpdateLock : <EOL> <INDENT> for p in __plugsPendingUpdate : <EOL> <INDENT> if plug . isSame ( p ) : <EOL> <INDENT> return <EOL> <DEDENT> <DEDENT> __plugsPendingUpdate . append ( plug ) <EOL> <DEDENT> <DEDENT> GafferUI . EventLoop . executeOnUIThread ( lambda : __update ( plug ) ) <EOL> <DEDENT> def __update ( plug ) : <EOL> <INDENT> node = 

# Inference

In [36]:
# predict one token until reach <EOL>
text = line_inputs[0]
while True:
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    model     = model.to(device)
    generated_ids = model.generate(input_ids, max_length=128)
    text = tokenizer.decode(generated_ids[0])
    if generated_ids[0][-1] == tokenizer.sep_token_id:
        break
print(text)

Input length of input_ids is 264, but ``max_length`` is set to 128. This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``.
Input length of input_ids is 260, but ``max_length`` is set to 128. This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``.


<s> import threading <EOL> import IECore <EOL> import Gaffer <EOL> import GafferUI <EOL> import GafferImage <EOL> __all__ = [ ] <EOL> Gaffer. Metadata. registerNode ( <EOL> GafferImage. Display, <EOL> " <STR_LIT:description> ", <EOL> """ <STR_LIT> """, <EOL> plugs = { <EOL> " <STR_LIT:port> " : [ <EOL> " <STR_LIT:description> ", <EOL> """ <STR_LIT> """, <EOL> ], <EOL> } <EOL> ) <EOL> __plugsPendingUpdate = [ ] <EOL> __plugsPendingUpdateLock = threading. Lock ( ) <EOL> def __scheduleUpdate ( plug, force = False ) : <EOL> <INDENT> if not force : <EOL> <INDENT> global __plugsPendingUpdate <EOL> global __plugsPendingUpdateLock <EOL> with __plugsPendingUpdateLock : <EOL> <INDENT> for p in __plugsPendingUpdate : <EOL> <INDENT> if plug. isSame ( p ) : <EOL> <INDENT> return <EOL> <DEDENT> <DEDENT> __plugsPendingUpdate. append ( plug ) <EOL> <DEDENT> <DEDENT> GafferUI. EventLoop. executeOnUIThread ( lambda : __update ( plug ) ) <EOL> <DEDENT> def __update ( plug ) : <EOL> <INDENT> node = plug. 

# Post-process

In [38]:
def post_process(code):
    code = code.replace("<NUM_LIT>", "0").replace("<STR_LIT>", "").replace("<CHAR_LIT>", "")
    pattern = re.compile(r"<(STR|NUM|CHAR)_LIT:(.*?)>", re.S)
    lits = re.findall(pattern, code)
    for lit in lits:
        code = code.replace(f"<{lit[0]}_LIT:{lit[1]}>", lit[1])
    return code
def clean_to_code(code_str, post_literal=False):
    code = ""
    if post_literal:
        code_str = post_process(code_str)
    code_str = code_str.replace('<s>', '')
    code_str = code_str.replace('</s>', '')
    code_list = code_str.split()
    indent = 0
    newline = False
    for tok in code_list:
        if '<NUM_LIT:' in tok:
            tok = tok[len('<NUM_LIT:'):-1]
        elif tok == '<NUM_LIT>':
            tok = '0'
        if tok ==  '<INDENT>':
            indent += 1
        elif tok == '<DEDENT>':
            indent -= 1
        elif tok == '<EOL>':
            newline = True
        else:
            if newline:
                code += '\n'
                newline = False
                if indent > 0:
                    code += '\t' * indent
                code += tok
            else:
                code += " " + tok
    return code.strip()

In [39]:
print(clean_to_code(text))

import threading
import IECore
import Gaffer
import GafferUI
import GafferImage
__all__ = [ ]
Gaffer. Metadata. registerNode (
GafferImage. Display,
" <STR_LIT:description> ",
""" <STR_LIT> """,
plugs = {
" <STR_LIT:port> " : [
" <STR_LIT:description> ",
""" <STR_LIT> """,
],
}
)
__plugsPendingUpdate = [ ]
__plugsPendingUpdateLock = threading. Lock ( )
def __scheduleUpdate ( plug, force = False ) :
	if not force :
		global __plugsPendingUpdate
		global __plugsPendingUpdateLock
		with __plugsPendingUpdateLock :
			for p in __plugsPendingUpdate :
				if plug. isSame ( p ) :
					return
			__plugsPendingUpdate. append ( plug )
	GafferUI. EventLoop. executeOnUIThread ( lambda : __update ( plug ) )
def __update ( plug ) :
	node = plug. node ( )
	if node :
		updateCountPlug = node [ " <STR_LIT> " ]
		updateCountPlug. setValue ( updateCountPlug. getValue ( ) )


# Pre-process code

*This is the first step, in case the source code input still didn't pre-processed.

In [None]:
import re
import keyword
from tokenize import tokenize, COMMENT, STRING, NEWLINE, ENCODING, ENDMARKER, NL, INDENT, NUMBER, DEDENT, ERRORTOKEN, NAME

lits = json.load(open(lit_file))
def process_string(token, special_chars={" ": "U+0020", ",": "U+002C"}):
    str_quote_options = ["'''", '"""', "'", '"']
    start_quote = ""
    end_quote = ""
    qualifier_regex = r"^[a-zA-Z]+"
    qualifier_match = re.search(qualifier_regex, token)
    # string qualifiers like 'r' for regex, 'f' for formatted string, 'b' for bytes, 'u' for unicode, etc (or combination of them)
    qualifier = "" if not qualifier_match else qualifier_match[0]
    # token string without qualifiers
    token_string = re.sub(qualifier_regex, "", token)
    # string literal without quotes
    str_lit = token_string
    for q in str_quote_options:
        if token_string.startswith(q):
            start_quote = q
            str_lit = str_lit[len(q) :]
            if token_string.endswith(q):
                end_quote = q
                str_lit = str_lit[: -len(q)]
            break
    # if start_quote in str_quote_options[:2]:
    #     return ""
    for sc in special_chars:
        str_lit = str_lit.replace(sc, special_chars[sc])
    return (
        f"{qualifier}{start_quote}<STR_LIT:{str_lit}>{end_quote}"
        if str_lit in lits['str']
        else f"{qualifier}{start_quote}<STR_LIT>{end_quote}"
    )

def preprocess_dataset(input_code, close_tag=True):
    #### extract exact token type from tokenzier library ####
    ## set close_tag=False, if process the unfinish code ##
    transform_dict = {
        NL: "<EOL>",
        NEWLINE: "<EOL>",
        INDENT: "<INDENT>",
        DEDENT: "<DEDENT>",
    }
    out_code = []
    try:
        token_gen = tokenize(input_code)
        was_eol = False
        for tok in token_gen:
            toknum = tok.type
            tokval = " ".join(tok.string.split())
            if toknum == ERRORTOKEN and tokval in [" ",""]:
                continue
            elif toknum in [NEWLINE, NL]:
                if not was_eol:
                    out_code.append("<EOL>")
                    was_eol = True
            elif toknum in transform_dict:
                out_code.append(transform_dict[toknum])
                was_eol = False
            elif toknum == NAME and keyword.iskeyword(tokval):
                out_code.append(tokval)
                was_eol = False
            elif toknum == STRING:
                add_token = process_string(tokval)
                out_code.append(add_token)
                was_eol = False
            elif toknum == NUMBER: 
                if tokval in lits['num']:
                    out_code.append(f"<NUM_LIT:{tokval}>")
                else:
                    out_code.append(f"<NUM_LIT>")
                was_eol = False
            elif toknum not in [COMMENT, ENCODING, ENDMARKER]:
                out_code.append(tokval)
                was_eol = False
        if len(out_code) > 0 and out_code[0] == "<EOL>":
            out_code = out_code[1:]
    except Exception as e:
        print(e)
    if close_tag:
        if len(out_code) > 0 and out_code[0] == "<EOL>":
            out_code.append("<EOL>")
        out_code = ["<s>"] + out_code + ["</s>"]
    else:
        out_code = ["<s>"] + out_code
    out = " ".join(out_code)
    return out