# 数据增强方法

由于有效的数据量十分有限，为了能够获得更好的训练效果，同时留出一定量的数据来对模型进行评估，我们设计了一些数据增强方法，以获取更多数据，这些方法部分借鉴了文本增强方法中常用的思想。

## 1. 针对算数运算的增强

对于所有的加/减法, 以 p 概率换为减/加法

In [19]:
import random
import re

In [20]:
def replace_add_sub(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    replaced = ''
    for token in tokens:
        q = random.random()
        if token == '+' and q <= p:
            replaced += '- '
        elif token == '-' and q <= p:
            replaced += '+ '
        else:
            replaced += (token + ' ')
    return replaced

In [21]:
text = r"""
def add_sub(a, b):
    c = a + b
    d = c - a
    b = a * c + b
    e = b + a - c * d - q
    return e - b + d + a - c
"""

In [22]:
replace_add_sub(text, 0.1)

' def add_sub(a, b): c = a + b d = c - a b = a * c + b e = b + a - c * d - q return e - b + d + a - c  '

## 2. 针对逻辑运算的数据增强

### 2.1 以概率 p 删去某个条件分量

In [23]:
def remove_cond_component_cpp(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    len_in_tokens = len(tokens)
    res = ''
    i = 0
    while i < len_in_tokens:
        q = random.random()
        if (tokens[i] == '&&' or tokens[i] == '||') and q <= p:
            while i < len_in_tokens and tokens[i] != ')':
                i += 1
        if i < len_in_tokens:
            res += (tokens[i] + ' ')
        i += 1
    return res        

In [24]:
text = r"""
if ( rank_ == 0 || ( commType != NCCLCommType::COLL && p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }
"""

In [25]:
remove_cond_component_cpp(text, 0.2)

' if ( rank_ == 0 || ( commType != NCCLCommType::COLL && p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }  '

### 2.2 以概率 p 把每个 || 和 && 互换

In [26]:
def swap_cond_cpp(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if token == '&&' and q <= p:
            res += '|| '
        elif token == '||' and q <= p:
            res += '&& '
        else:
            res += token
            res += ' '
    return res
    

In [27]:
def swap_cond_py(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if token == 'and' and q <= p:
            res += 'or '
        elif token == 'or' and q <= p:
            res += 'and '
        else:
            res += token
            res += ' '
    return res
    

In [28]:
text = r"""
if ( rank_ == 0 || ( commType != NCCLCommType::COLL && p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }
"""

In [29]:
swap_cond_cpp(text, 0.2)

' if ( rank_ == 0 || ( commType != NCCLCommType::COLL && p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }  '

In [30]:
text = r"""
if a and b or c or d and e or f:
    i += 1
"""

In [31]:
swap_cond_py(text, 0.3)

' if a and b or c and d and e and f: i += 1  '

## 3. 针对对象访问的增强

对于所有的 . 和 ->, 以 p 概率换成 -> 或者 .

In [32]:
def swap_access(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if token == '->' and q <= p:
            res += '. '
        elif token == '.' and q <= p:
            res += '-> '
        else:
            res += token
            res += ' '
    return res

In [33]:
text = r"""
if ( check_has_torch_function ( self ) ) { return handle_torch_function_setter ( ( THPVariable * ) self , ""names"" , names ) ; } auto & var = ( ( THPVariable * ) self ) -> cdata ;
"""

In [34]:
swap_access(text, 0.5)

' if ( check_has_torch_function ( self ) ) { return handle_torch_function_setter ( ( THPVariable * ) self , ""names"" , names ) ; } auto & var = ( ( THPVariable * ) self ) -> cdata ;  '

## 4. 针对比较运算的数据增强

以 p 概率把 > / >=, 或者 </ <= 替换

In [35]:
def swap_comp(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if token == '>=' and q <= p:
            res += '> '
        elif token == '<=' and q <= p:
            res += '< '
        else:
            res += token
            res += ' '
    return res

In [36]:
text = r"""
if a >= b {}
else if a <= b {}
else if a <= c {}
else if a >= d {}
"""

In [37]:
swap_comp(text, 0.3)

' if a >= b {} else if a <= b {} else if a < c {} else if a >= d {}  '

由于程序片段的破碎性，无法对程序片段进行语法分析，因此损失了很多可用的手段，以上的数据增强方法仅为在这种限制场景下的一些探索

In [38]:
def corrupt(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if q > p:
            res += token
            res += ' '
    return res

In [39]:
text = r"""
this is a token sequence to be corrupted sda sjfhals shags lasjf or 888.
"""

In [40]:
corrupt(text, 0.2)

' this is a token sequence be corrupted sda sjfhals shags or 888. '

In [66]:
def randomlyCorrupt(code_snippet, p):
    q = random.random()
    if q <= 0.15:
        return corrupt(code_snippet, p)
    return code_snippet

### 对训练集使用数据增强方法扩增数据

In [42]:
import pandas as pd

In [43]:
cpp_data_path = '../data/cpp/dl_cpp_tokenized.csv'

In [44]:
cpp_data = pd.read_csv(cpp_data_path)

In [45]:
py_data_path = '../data/py/dl_py_tokenized.csv'

In [46]:
py_data = pd.read_csv(py_data_path)

In [47]:
cpp_data.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
0,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...
1,std::chrono::milliseconds ( kWaitForAbortCommS...,std::chrono::milliseconds ( kWaitForAbortCommS...
2,if ( rank_ == 0 || ( commType != NCCLCommType:...,if ( rank_ == 0 || ( isP2POp ( opType ) && p2p...
3,"int numRanks , rank ; if ( commType == NCCLCom...","int numRanks , rank ; if ( sP2POp ( opType ) )..."
4,"if ( str != ""None"" && str != """" ) { throw std:...","if ( str != ""None"" ) { default_string = parse_..."


In [48]:
cpp_data.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,357,357
unique,329,329
top,const auto source_n = n -> sourceRange ( ) . s...,const auto source_n = n -> sourceRange ( ) . s...
freq,3,3


In [49]:
py_data.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
0,def __sizeof__ ( self ) : return super ( _Stor...,def __sizeof__ ( self ) : return ( super ( _St...
1,self . quant_min : int = quant_min self . quan...,self . quant_min : int = quant_min self . quan...
2,def get_default_qat_qconfig ( backend = gemm v...,def get_default_qat_qconfig ( backend = gemm v...
3,"exp_avg . mul_ ( beta1 ) . add_ ( grad , alpha...","exp_avg . mul_ ( beta1 ) . add_ ( grad , alpha..."
4,def get_post_build_suffix ( self ) -> str : if...,def get_post_build_suffix ( self ) -> str : if...


In [50]:
py_data.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,275,275
unique,268,268
top,if it % 100 == 0 : print ( eration % d - - Out...,if it % 100 == 0 : print ( eration % d - - Out...
freq,3,3


In [51]:
from sklearn.model_selection import train_test_split

In [52]:
cpp_train_buggy, cpp_test_buggy, cpp_train_fixed, cpp_test_fixed = train_test_split(cpp_data['BUGGY_CODE'], cpp_data['FIXED_CODE'], test_size=0.16, random_state=42)

In [53]:
py_train_buggy, py_test_buggy, py_train_fixed, py_test_fixed = train_test_split(py_data['BUGGY_CODE'], py_data['FIXED_CODE'], test_size=0.16, random_state=42)

In [114]:
py_aug = pd.DataFrame(data={'BUGGY_CODE': [], "FIXED_CODE": []})

In [115]:
py_aug_methods = [lambda x,p : x, replace_add_sub, swap_cond_py, swap_comp]

In [161]:
cpp_aug = pd.DataFrame(data={'BUGGY_CODE': [], "FIXED_CODE": []})

In [162]:
cpp_aug_methods = [lambda x,p : x, replace_add_sub, remove_cond_component_cpp, swap_cond_cpp, swap_access, swap_comp]

In [118]:
py_aug.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE


In [119]:
cpp_aug.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE


In [120]:
py_data_len = py_train_buggy.shape[0]
cpp_data_len = cpp_train_buggy.shape[0]

In [121]:
py_data_len

231

In [122]:
cpp_data_len

299

In [123]:
len(cpp_test_buggy)

58

In [124]:
len(py_test_buggy)

44

In [125]:
for i in range(py_data_len):
    buggy_code = py_train_buggy.iloc[i]
    fixed_code = py_train_fixed.iloc[i]
    py_aug.loc[len(py_aug.index)] = [buggy_code, fixed_code]
    for f in py_aug_methods:
        transformed = f(buggy_code, 0.2)
        transformed = randomlyCorrupt(transformed, 0.15)
        py_aug.loc[len(py_aug.index)] = [transformed, fixed_code]

In [126]:
py_aug.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
0,symbolic_helper . _set_opset_version ( opset_v...,symbolic_helper . _set_opset_version ( opset_v...
1,symbolic_helper . _set_opset_version ( opset_v...,symbolic_helper . _set_opset_version ( opset_v...
2,symbolic_helper . _set_opset_version ( opset_v...,symbolic_helper . _set_opset_version ( opset_v...
3,symbolic_helper _set_opset_version ( opset_ver...,symbolic_helper . _set_opset_version ( opset_v...
4,symbolic_helper . _set_opset_version ( opset_v...,symbolic_helper . _set_opset_version ( opset_v...


In [127]:
py_aug.drop_duplicates(subset=['BUGGY_CODE'], keep='first', inplace=True)

In [128]:
py_aug.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,578,578
unique,578,226
top,symbolic_helper . _set_opset_version ( opset_v...,"fw_compiler : Callable , bw_compiler : Optiona..."
freq,1,5


In [163]:
for i in range(cpp_data_len):
    buggy_code = cpp_train_buggy.iloc[i]
    fixed_code = cpp_train_fixed.iloc[i]
    cpp_aug.loc[len(cpp_aug.index)] = [buggy_code, fixed_code]
    for f in cpp_aug_methods:
        transformed = f(buggy_code, 0.2)
        transformed = randomlyCorrupt(transformed, 0.15)
        cpp_aug.loc[len(cpp_aug.index)] = [transformed, fixed_code]

In [164]:
cpp_aug.drop_duplicates(subset=['BUGGY_CODE'], keep='first', inplace=True)

In [165]:
cpp_aug.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,921,921
unique,921,283
top,const T * X_ptr = X_data + i * inner_size ; T ...,const auto source_n = n -> sourceRange ( ) . s...
freq,1,10


In [166]:
py_aug.to_csv('../data/py/dl_py_aug.csv', index=False)

In [167]:
cpp_aug.to_csv('../data/cpp/dl_cpp_aug.csv', index=False)