In [1]:
from transformers import AutoTokenizer, AutoModel
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import pandas as pd
import json
import numpy as np
from sklearn import preprocessing
from torch import nn
from tqdm.notebook import tqdm

import ast

In [2]:
# --- Notebook configuration -------------------------------------------------
# Declared batch size and epoch count; the visible cells build their
# DataLoaders with literal batch sizes and run a single pass, so these two
# names are not referenced below.
BATCH_SIZE = 2
EPOCHS = 3

# Dataset location and the fraction of rows targeted for the training split.
DS_PATH = "data/_all_data2.csv"
TRAIN_TEST_SPLIT = 0.9

# Label vocabulary shaping:
#   FREQ_LIMIT      - labels seen fewer times than this (per label column)
#                     are replaced by FREQ_CUT_SYMBOL (300 was tried earlier).
#   MAX_CUT         - cap on the number of occurrences kept per frequent type.
#   FREQ_CUT_SYMBOL - placeholder label for rare / cut-off types.
#   NaN_symbol      - placeholder for missing label slots.
FREQ_LIMIT = 50
MAX_CUT = 5000
FREQ_CUT_SYMBOL = "<UNK>"
NaN_symbol = ''

# Pretrained encoder checkpoint (RoBERTa-style CodeBERT).
# Distilled alternative: "huggingface/CodeBERTa-small-v1"
MODEL_NAME_HUGGING = "microsoft/codebert-base"


In [3]:
# Load the tokenizer and the pretrained encoder for the configured checkpoint.
# torchscript=True is forwarded to the model configuration; downstream code
# indexes the encoder output positionally ([0]) rather than by attribute.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_HUGGING)
bert = AutoModel.from_pretrained(MODEL_NAME_HUGGING, torchscript=True)

In [4]:
# Read the raw dataset; 'arg_types' is stored as text and is parsed back into
# Python objects (presumably one list of type-name strings per function —
# confirm against the CSV producer).
data = pd.read_csv(DS_PATH)
# ast.literal_eval accepts only Python literals, so a malformed or hostile
# cell raises instead of executing arbitrary code the way eval() would.
data['arg_types'] = data['arg_types'].apply(ast.literal_eval)
# Keep only rows with a non-empty arg_types value. The explicit .copy() makes
# `data` an independent frame, so later column assignments do not write into
# a slice of the original (pandas SettingWithCopyWarning).
data = data[data.arg_types.astype(bool)].copy()
# One column per position in the arg_types list; shorter lists are padded
# with missing values, which are then replaced by NaN_symbol.
df_labels = pd.DataFrame(data['arg_types'].values.tolist())
df_labels[pd.isnull(df_labels)]  = NaN_symbol

In [5]:
def remove_composite(p):
    """Collapse composite type annotations to a coarse type name.

    Rules, applied per annotation string in ``p``:
      * anything containing "tuple" (case-insensitive) is kept verbatim;
      * a subscripted annotation that does not mention ``Union`` is reduced
        to its lower-cased outer name, except ``Optional[X...]`` which is
        reduced to the lower-cased text between the first ``[`` and the
        next ``[`` / ``]``;
      * the bare names ``List``, ``Dict`` and ``Callable`` are lower-cased;
      * everything else keeps only the part after the last ``.``.

    Returns a list with one simplified entry per input entry.
    """
    bare_generics = ('List', 'Dict', 'Callable')
    simplified = []
    for annotation in p:
        if 'tuple' in annotation.lower():
            simplified.append(annotation)
        elif '[' in annotation and 'Union' not in annotation:
            pieces = annotation.split('[')
            outer = pieces[0]
            if outer == 'Optional':
                # '[' is present, so pieces always has at least two entries.
                simplified.append(pieces[1].split(']')[0].lower())
            else:
                simplified.append(outer.lower())
        elif annotation in bare_generics:
            simplified.append(annotation.lower())
        else:
            simplified.append(annotation.split('.')[-1])
    return simplified

def replace_type(df, typ='str', frac=0.9):
    """Return a copy of ``df`` in which ``typ`` is cut from a random subset of rows.

    Rows containing ``typ`` in any column are selected, a random fraction
    ``frac`` of them is sampled, and within the sampled rows every ``typ``
    entry is rewritten by the converter produced by ``func(typ)``. The input
    frame itself is left unmodified.
    """
    converter = func(typ)
    rows_with_type = df[df.eq(typ).any(axis=1)]
    sampled_rows = rows_with_type.sample(frac=frac)
    result = df.copy()
    result.update(sampled_rows.apply(converter))
    return result


def func(typ):
    """Build a converter that swaps every ``typ`` entry for FREQ_CUT_SYMBOL.

    The returned callable takes an iterable of labels and returns a new list
    of the same length; entries different from ``typ`` pass through unchanged.
    """
    def cvrt(p):
        return [FREQ_CUT_SYMBOL if label == typ else label for label in p]
    return cvrt

# Simplify composite annotations column by column, then cap every frequent
# type at MAX_CUT occurrences by cutting the excess fraction of its rows.
dd = df_labels.apply(remove_composite)
type_totals = dd.apply(pd.Series.value_counts).sum(axis=1).sort_values(ascending=False)
for type_name, total in type_totals.items():
    is_placeholder = type_name == FREQ_CUT_SYMBOL or type_name == NaN_symbol
    if not is_placeholder and total > MAX_CUT:
        # (total - MAX_CUT) / total is the share of rows that must lose this type.
        dd = replace_type(dd, type_name, (total - MAX_CUT) / total)
df_labels = dd
# Cell output: the 15 most frequent labels after balancing.
dd.apply(pd.Series.value_counts).sum(axis=1).sort_values(ascending=False).head(15).astype(int)

                 5515939
<UNK>              80697
str                 5036
float               5012
int                 5010
Any                 5008
dict                4966
list                4958
bool                4957
callable            3335
Path                2759
HttpRequest         2234
HomeAssistant       2223
UserProfile         2148
iterable            2072
dtype: int64

In [6]:
def la(data_batch_i):
    """Clean one row of encoded labels.

    * entries equal to the missing-label code are dropped;
    * entries equal to the ``<UNK>`` or ``Any`` code are replaced by the
      missing-label code (their slot is kept);
    * every other entry is kept as is.

    Returns the cleaned list, or ``pd.NA`` when nothing is left or the
    remaining codes sum to zero.
    """
    cleaned = []
    for code in data_batch_i:
        if code == FREQ_CUT_ENC[0] or code == Any_enc[0]:
            cleaned.append(NaN_enc[0])
        elif code != NaN_enc[0]:
            cleaned.append(code)
    if not cleaned or sum(cleaned) == 0:
        return pd.NA
    return cleaned

# Replace rare labels with FREQ_CUT_SYMBOL. DataFrame.apply hands the lambda
# one column at a time, so the frequency threshold is evaluated per column
# (i.e. per position in the original arg_types lists), not globally.
df_labels = df_labels.apply(lambda x: x.mask(x.map(x.value_counts())<FREQ_LIMIT, FREQ_CUT_SYMBOL))
# Fit an integer encoder on every label value present in the frame.
enc = preprocessing.LabelEncoder()
all_types = df_labels.apply(pd.Series).stack().values
enc.fit(all_types)
# Encoded ids (length-1 arrays) of the three special labels; these globals are
# read by la() and process_elem(). 'Any' must be among the fitted classes for
# the transform below to succeed.
FREQ_CUT_ENC = enc.transform([FREQ_CUT_SYMBOL])
NaN_enc = enc.transform([NaN_symbol])
Any_enc = enc.transform(['Any'])
print(enc.inverse_transform(NaN_enc), enc.inverse_transform(FREQ_CUT_ENC))
print(f'Enc for "NaN" {NaN_enc}, Enc for FREQ_CUT_SYMBOL {FREQ_CUT_ENC}')
# Encode each label column, then attach the per-row id lists to `data`
# positionally (a plain list is assigned, so no index alignment happens).
df3 = df_labels.apply(enc.transform)
data['labels'] = df3.values.tolist()

# Clean each row's ids with la(); rows for which la() returns pd.NA are dropped.
data['labels'] = data['labels'].apply(la)
data = data.dropna(subset=['labels'], axis=0)



def train_test_by_repo(data, split=0.75, random_state=None):
    """Split ``data`` into train/test frames without sharing a repository.

    Repositories are visited in random order and assigned to the training
    set until the training set holds at least ``split * len(data)`` rows;
    every remaining repository goes to the test set. Because whole
    repositories are assigned, the achieved ratio is only approximately
    ``split``.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a ``repo`` column. Rows whose ``repo`` is missing end up
        in neither split.
    split : float
        Target fraction of rows for the training set.
    random_state : int, numpy Generator/RandomState or None
        Forwarded to ``Series.sample`` to make the repository shuffle
        reproducible. ``None`` (default) keeps the original unseeded behaviour.

    Returns
    -------
    (train, test) : tuple of pandas.DataFrame
    """
    train_len = split * len(data)
    # size() counts rows per repository directly; the previous
    # count()['author'] required an 'author' column and ignored rows whose
    # author was null, so the running total could drift from the row count.
    repo_sizes = data.groupby('repo').size().sample(frac=1, random_state=random_state)

    train_repos = []
    test_repos = []
    assigned_rows = 0
    for repo_name, n_rows in repo_sizes.items():
        if train_len > assigned_rows:
            train_repos.append(repo_name)
            assigned_rows += n_rows
        else:
            test_repos.append(repo_name)

    in_train = data['repo'].isin(train_repos)
    in_test = data['repo'].isin(test_repos)
    return data.loc[in_train], data.loc[in_test]



# Repository-level split: all rows of a repository land in the same partition.
train_ds, test_ds = train_test_by_repo(data, TRAIN_TEST_SPLIT)


# Cell output: number of distinct label classes known to the encoder.
len(enc.classes_)

[''] ['<UNK>']
Enc for "NaN" [0], Enc for FREQ_CUT_SYMBOL [4]


219

In [7]:
# Persist the label vocabulary, one class name per line, in encoder order.
with open("types.txt", 'w') as f:
    f.writelines(type_name + '\n' for type_name in enc.classes_)

In [43]:
# Use CUDA when it is available, otherwise fall back to the CPU, and report
# which of the two was chosen.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
    print('Device name:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')

There are 1 GPU(s) available.
Device name: GeForce RTX 2060 SUPER


In [9]:
# Move the pretrained encoder to the selected device.
bert.to(device)
# Ends the cell on a statement instead of an expression — presumably so the
# notebook does not echo the value returned by .to().
print()




In [10]:
def process_elem(data_batch_i):
    """Tokenise one function body and build its per-token label tensor.

    ``data_batch_i`` must provide ``'body'`` (source text) and ``'labels'``
    (encoded label ids, paired in order with the argument names found by
    ``get_names``).

    Returns ``(encoding, ids)`` where ``encoding`` is the tokenizer output as
    PyTorch tensors and ``ids`` has the shape of ``encoding['input_ids']``,
    holding the label id at the token located for each argument and zero
    everywhere else.
    """
    source = data_batch_i['body']
    # Two passes over the tokenizer: tensors for the model, and a plain
    # encoding that carries the character offset of every token.
    model_inputs = tokenizer(source, return_tensors='pt', padding='max_length', truncation=True)
    offsets_encoding = tokenizer(source, padding='max_length', truncation=True,  return_offsets_mapping=True, return_length=True)

    named_spans = get_names(source)
    label_by_name = dict(zip([entry[0] for entry in named_spans], data_batch_i['labels']))
    located = offset2ind(named_spans, offsets_encoding)

    ids = torch.zeros_like(model_inputs['input_ids'])
    for entry in located:
        # Names without a label fall back to the missing-label id.
        ids[0][entry[1]] = label_by_name.get(entry[0], NaN_enc[0])
    return model_inputs, ids

def offset2ind(args, tokens):
    """Map named character spans to token positions.

    ``args`` is a sequence of ``(name, (start, end))`` entries and
    ``tokens['offset_mapping']`` a sequence of per-token ``(start, end)``
    character offsets. For every entry the result holds
    ``(name, [index])`` with the index of the first token lying entirely
    inside the span, or ``(name, [])`` when no token does.
    """
    def first_enclosed(entry, offsets):
        span = entry[1]
        for position, offset in enumerate(offsets):
            if offset[0] >= span[0] and offset[1] <= span[1]:
                return [position]
        return []

    return [(entry[0], first_enclosed(entry, tokens['offset_mapping'])) for entry in args]


def get_names(src):
    """Locate every function argument in Python source text.

    Parameters
    ----------
    src : str
        Python source code.

    Returns
    -------
    list of (name, (start, end))
        One entry per ``ast.arg`` node, in ``ast.walk`` order. ``start`` and
        ``end`` are character offsets into ``src`` covering the whole
        argument node (the name plus its annotation, if any). If the source
        cannot be processed, a message is printed and the entries collected
        so far (possibly none) are returned.
    """
    ret = []
    lines = src.split('\n')
    # line_starts[k] is the character offset at which 0-based line k begins
    # (+1 per line accounts for the '\n' removed by split).
    line_starts = [0]
    for line in lines:
        line_starts.append(line_starts[-1] + len(line) + 1)

    def char_offset(lineno, byte_col):
        # ast reports columns as UTF-8 *byte* offsets; convert to a character
        # column so the result is comparable with character-based offsets.
        line = lines[lineno - 1]
        if not line.isascii():
            byte_col = len(line.encode('utf-8')[:byte_col].decode('utf-8', errors='ignore'))
        return line_starts[lineno - 1] + byte_col

    try:
        for node in ast.walk(ast.parse(src)):
            if isinstance(node, ast.arg):
                # The end position is taken on end_lineno: an annotated
                # argument may finish on a later line than it starts on.
                ret.append((node.arg, (char_offset(node.lineno, node.col_offset),
                                       char_offset(node.end_lineno, node.end_col_offset))))
        return ret
    except Exception:
        # Best effort: unparsable input yields whatever was collected, but
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        print("Could Not process the code")
        return ret
    

In [11]:
class JITDataDataset(Dataset):
    """Dataset over a DataFrame of function bodies and their label ids.

    Each item is the 4-tuple ``(input_ids, attention_mask, label_mask,
    label_ids)``, squeezed and already moved to the global ``device``.
    ``label_mask`` is True where ``label_ids`` is greater than zero.
    """

    def __init__(self, df):
        self.data = df

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row_index = idx.tolist() if torch.is_tensor(idx) else idx
        row = self.data.iloc[row_index, :]
        encoding, label_ids = process_elem(row)
        parts = (
            encoding['input_ids'],
            encoding['attention_mask'],
            label_ids > 0,
            label_ids,
        )
        return tuple(part.squeeze().to(device) for part in parts)

In [12]:
# Shuffled training batches. NOTE(review): batch_size is the literal 4 here,
# not the BATCH_SIZE constant (2) from the configuration cell — confirm which
# value is intended.
train = DataLoader(JITDataDataset(train_ds), batch_size=4,
                        shuffle=True)

In [13]:
class JITModel(torch.nn.Module):
    """Token-level type classifier on top of a pretrained encoder.

    A linear head maps every token embedding to ``out_dim`` class scores;
    only the tokens selected by the label mask are returned.

    Parameters
    ----------
    bert : torch.nn.Module
        Encoder called as ``bert(input_ids, attention_mask=...)`` whose first
        output holds the per-token embeddings.
    out_dim : int
        Number of label classes.
    """

    def __init__(self, bert, out_dim):
        super().__init__()
        self.out_dim = out_dim
        self.bert = bert
        # Read the embedding width from the encoder config when it exposes
        # one; fall back to 768, the value previously hard-coded here.
        hidden_size = getattr(getattr(bert, "config", None), "hidden_size", 768)
        self.dense = nn.Linear(hidden_size, out_dim)
        nn.init.normal_(self.dense.weight, 0, 0.02)

    def forward(self, a, b, c, d):
        """Return class probabilities for the masked tokens.

        a: input ids, b: attention mask, c: boolean mask of labelled tokens
        (same leading shape as ``a``), d: label ids (accepted for interface
        compatibility; not needed to compute the output).

        Returns a ``(n_selected_tokens, out_dim)`` tensor whose rows sum to 1.
        """
        emb = self.bert(a, attention_mask=b)[0]
        out = self.dense(emb)
        # Boolean indexing keeps the score rows of the selected tokens, in
        # row-major order — the same rows masked_select + reshape produced.
        masked = out[c]
        # Explicit class axis: identical to the implicit choice for 2-D input,
        # without the "implicit dimension" deprecation warning.
        return F.softmax(masked, dim=-1)


# One output unit per class known to the label encoder.
model = JITModel(bert, len(enc.classes_))
model.to(device)
# Ends the cell on a statement instead of an expression — presumably so the
# notebook does not echo the value returned by .to().
print()




In [14]:
# Handle for the CPU device; no other visible cell references this name.
cpu = torch.device("cpu")

In [51]:
# Single training pass over `train`.
# Running loss / accuracy are kept as lists of per-batch values; batches whose
# loss is NaN (e.g. a batch without any labelled token) are reported and
# skipped entirely.
opti = torch.optim.Adam(model.parameters(), lr = 2e-5)
pbar = tqdm(total=len(train))
losses = []
accuracy = []
# Print sample predictions about ten times per pass; max() keeps the modulus
# non-zero when there are fewer than ten batches.
report_every = max(1, len(train) // 10)
for i,a in enumerate(train):
    out = model(a[0], a[1], a[2], a[3])

    labels = torch.masked_select(a[3], a[2])
    # The model returns probabilities, so take the log before the NLL loss.
    loss = F.nll_loss(torch.log(out), labels)
    if i % report_every == 0:
        print(enc.inverse_transform(torch.argmax(out.cpu(), dim=1)), enc.inverse_transform(labels.cpu()))

    if torch.isnan(loss):
        # Do not back-propagate or step on a NaN loss: the optimizer update
        # would write NaN into the weights and poison every later batch.
        print(a)
    else:
        opti.zero_grad()
        loss.backward()
        opti.step()
        accuracy.append(sum(torch.argmax(out.detach(), dim=1) == labels)/len(labels))
        losses.append(loss.detach())
    # Only refresh the description once at least one batch has been recorded,
    # otherwise the averages would divide by zero.
    if i % 5 ==0 and losses:
        pbar.set_description(f"Loss : { sum(losses)/len(losses)}, acc: {sum(accuracy)/len(accuracy)}")
    pbar.update(1)
pbar.close()

  0%|          | 0/13458 [00:00<?, ?it/s]

  return F.softmax(masked)


['Realm' 'HomeAssistant' 'ActiveConnection' 'Request' 'list' 'Session'
 'callable'] ['Realm' 'HomeAssistant' 'ActiveConnection' 'Request' 'list' 'Session'
 'callable']
['list' 'bool' 'Path' 'set' 'float' 'float'] ['list' 'bool' 'Path' 'set' 'float' 'float']
['Path' 'bool' 'str' 'int' 'datetime'] ['Path' 'bool' 'str' 'int' 'datetime']
['Mock' 'Mock' 'Circuit' 'float' 'HTTPRequest'] ['Mock' 'Mock' 'Circuit' 'float' 'HTTPRequest']
['ConfigType' 'vertex_constructor_param_types' 'iterable' 'Context'] ['ConfigType' 'vertex_constructor_param_types' 'iterable' 'Context']
['Type' 'Path' 'HttpRequest' 'UserProfile' 'list' 'Message' 'Message'] ['Type' 'Path' 'HttpRequest' 'UserProfile' 'list' 'Message' 'Message']
['StateApps' 'DatabaseSchemaEditor' 'Circuit' 'float' 'set'] ['StateApps' 'DatabaseSchemaEditor' 'Circuit' 'float' 'set']
['bool' 'ndarray' 'float' 'list'] ['bool' 'ndarray' 'float' 'iterable']
['PixmapDiffer' 'HttpRequest' 'int' 'str'] ['PixmapDiffer' 'HttpRequest' 'int' 'str']
['HomeAs

In [16]:
def pr_av(x):
    """Return the arithmetic mean of the values in ``x``.

    An empty ``x`` yields ``float('nan')`` instead of raising
    ZeroDivisionError, so progress reporting can call this before any value
    has been recorded.
    """
    if len(x) == 0:
        return float('nan')
    return sum(x)/len(x)

In [17]:
# Evaluation loader: one function per batch, no shuffling requested, items
# loaded in the main process (num_workers=0).
test = DataLoader(JITDataDataset(test_ds), batch_size=1, num_workers=0)

In [52]:
# Evaluate on `test`: per-batch loss, top-1 accuracy and top-5 accuracy, plus
# the decoded true / predicted labels for the classification report.
pbar = tqdm(total=len(test))
test_top_5s = []
test_accuracy = []
test_losses = []
test_true = []
test_pred = []
# No gradients are needed for evaluation; this avoids building autograd graphs.
with torch.no_grad():
    for i,a in enumerate(test):
        out = model(a[0], a[1], a[2], a[3])
        labels = torch.masked_select(a[3], a[2])
        loss = F.nll_loss(torch.log(out), labels)

        if torch.isnan(loss):
            # Batches with an undefined loss are reported and left out of the metrics.
            print(a)
        else:
            predicted = torch.argmax(out, dim=1)
            # extend() appends in place; rebuilding the list with `+` on every
            # batch is quadratic in the number of labels.
            test_pred.extend(enc.inverse_transform(predicted.cpu()))
            test_true.extend(enc.inverse_transform(labels.cpu()))
            test_accuracy.append(sum(predicted == labels).detach()/len(labels))
            test_losses.append(loss.detach())
            # Guard k so topk also works with fewer than five classes.
            top5s = torch.topk(out, min(5, out.size(1))).indices
            # A dedicated index name: reusing `i` here clobbered the batch
            # counter that drives the progress refresh below.
            correct_top5 = 0
            for j in range(len(labels)):
                if labels[j] in top5s[j]:
                    correct_top5 += 1
            test_top_5s.append(correct_top5/len(labels))

        # Skip the refresh until at least one batch has been recorded.
        if i % 20 ==0 and test_losses:
            pbar.set_description(f"Loss : { pr_av(test_losses)}, acc: {pr_av(test_accuracy)}, top5s: {pr_av(test_top_5s)}")
        pbar.update(1)
pbar.close()

  0%|          | 0/5925 [00:00<?, ?it/s]

  return F.softmax(masked)


In [19]:
from sklearn.metrics import precision_recall_fscore_support, classification_report

# Per-class precision / recall / F1 on the decoded test labels.
# zero_division=0 reports 0 for classes without predictions or support — the
# same numbers as the default — without emitting an undefined-metric warning
# for each of them.
print(classification_report(test_true, test_pred, zero_division=0))

                       precision    recall  f1-score   support

                  AST       0.25      0.50      0.33         2
      AUTH_USER_MODEL       1.00      1.00      1.00        12
    AbstractEventLoop       0.89      1.00      0.94         8
     ActiveConnection       1.00      1.00      1.00        15
              Address       0.00      0.00      0.00         0
                  App       0.00      0.00      0.00        53
          Application       0.71      0.45      0.56        11
       ArgumentParser       1.00      0.85      0.92        13
          BlockNumber       0.00      0.00      0.00         1
                  Bot       1.00      1.00      1.00        13
              BrandID       0.00      0.00      0.00        66
               Buffer       0.75      0.33      0.46         9
       ChromecastInfo       1.00      1.00      1.00        10
               Client       0.79      0.96      0.86        23
        ClientSession       1.00      0.97      0.98  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# Inference

In [20]:
    
def transform_to_model(meth, dev):
    """Turn source text into the four tensors expected by the model.

    Every argument found in ``meth`` is given the placeholder label 1 so that
    the label mask (``ids > 0``) selects the located argument tokens. All
    tensors keep their batch dimension and are moved to ``dev``.

    Returns ``(input_ids, attention_mask, label_mask, label_ids)``.
    """
    placeholder_labels = [1]*len(get_names(meth))
    encoding, ids = process_elem({'body':meth, 'labels':placeholder_labels})
    parts = (encoding['input_ids'], encoding['attention_mask'], ids > 0, ids)
    return tuple(part.to(dev) for part in parts)

def infer(mode, meth):
    """Predict a type name for every argument located in ``meth``.

    Parameters
    ----------
    mode : torch.nn.Module
        Model called with the four tensors from ``transform_to_model``; it
        must return one row of class scores per selected token.
    meth : str
        Python source text of the function to annotate.

    Returns
    -------
    numpy.ndarray
        Decoded label of the highest-scoring class for each selected token.
    """
    a = transform_to_model(meth, device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        out = mode(a[0], a[1], a[2], a[3])
    return enc.inverse_transform(torch.argmax(out.cpu(), dim=1))

In [21]:
# Example 1: a function whose docstring and a string literal contain
# non-ASCII text; the cell output is the predicted type of each argument.
method= """def geohash(latitude, longitude, datedow):
    '''Compute geohash() using the Монрое algorithm.

    >>> geohash(37.421542, -122.085589, b'2005-05-26-10458.68')
    37.857713 -122.544543

    '''
    # https://xkcd.com/426/
    print('ВАДИл')
    h = hashlib.md5(datedow).hexdigest()
    p, q = [('%f' % float.fromhex('0.' + x)) for x in (h[:16], h[16:32])]
    print('%d%s %d%s' % (latitude, p[1:], longitude, q[1:]))"""
infer(model, method)

  return F.softmax(masked)


array(['float', 'float', 'str'], dtype=object)

In [22]:
# Example 2: a small function whose argument types are only implied by how
# the arguments are used in the body.
method= """
def very_common_function(a, b, c):
    if (a > 5):
        a = a + 500
    else:
        if b == "privet from har\'kov":
            c.append(123)
        else:
            c.append(82312)
    return a, b, c"""
infer(model, method)

  return F.softmax(masked)


array(['int', 'str', 'list'], dtype=object)

In [24]:
import time

In [25]:
# Example 3: run inference on source text read from disk; 'file' is resolved
# relative to the current working directory.
with open('file','r') as f:
    method = f.read()
infer(model, method)

  return F.softmax(masked)


array(['int', 'str', 'list'], dtype=object)