# 数据增强方法

由于有效的数据量十分有限，为了能够获得更好的训练效果，同时留出一定量的数据来对模型进行评估，我们设计了一些数据增强方法，以获取更多数据，这些方法部分借鉴了文本增强方法中常用的思想。

## 1. 针对算数运算的增强

对于所有的加/减法, 以 p 概率换为减/加法

In [2]:
import random

In [6]:
def replace_add_sub(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    replaced = ''
    for token in tokens:
        q = random.random()
        if token == '+' and q <= p:
            replaced += '- '
        elif token == '-' and q <= p:
            replaced += '+ '
        else:
            replaced += (token + ' ')
    return replaced

In [8]:
text = r"""
def add_sub(a, b):
    c = a + b
    d = c - a
    b = a * c + b
    e = b + a - c * d - q
    return e - b + d + a - c
"""

In [11]:
replace_add_sub(text, 0.1)

' def add_sub(a, b): c = a + b d = c - a b = a * c + b e = b + a - c * d - q return e - b - d + a - c  '

## 2. 针对逻辑运算的数据增强

### 2.1 以概率 p 删去某个条件分量

In [159]:
def remove_cond_component_cpp(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    len_in_tokens = len(tokens)
    res = ''
    i = 0
    while i < len_in_tokens:
        q = random.random()
        if (tokens[i] == '&&' or tokens[i] == '||') and q <= p:
            while i < len_in_tokens and tokens[i] != ')':
                i += 1
        if i < len_in_tokens:
            res += (tokens[i] + ' ')
        i += 1
    return res        

In [160]:
text = r"""
if ( rank_ == 0 || ( commType != NCCLCommType::COLL && p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }
"""

In [161]:
remove_cond_component_cpp(text, 0.2)

' if ( rank_ == 0 || ( commType != NCCLCommType::COLL && p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }  '

### 2.2 以概率 p 把每个 || 和 && 互换

In [41]:
def swap_cond_cpp(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if token == '&&' and q <= p:
            res += '|| '
        elif token == '||' and q <= p:
            res += '&& '
        else:
            res += token
            res += ' '
    return res
    

In [60]:
def swap_cond_py(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if token == 'and' and q <= p:
            res += 'or '
        elif token == 'or' and q <= p:
            res += 'and '
        else:
            res += token
            res += ' '
    return res
    

In [42]:
text = r"""
if ( rank_ == 0 || ( commType != NCCLCommType::COLL && p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }
"""

In [43]:
swap_cond_cpp(text, 0.2)

' if ( rank_ == 0 || ( commType != NCCLCommType::COLL || p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }  '

In [65]:
text = r"""
if a and b or c or d and e or f:
    i += 1
"""

In [66]:
swap_cond_py(text, 0.3)

' if a and b and c or d and e or f: i += 1  '

## 3. 针对对象访问的增强

对于所有的 . 和 ->, 以 p 概率换成 -> 或者 .

In [67]:
def swap_access(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if token == '->' and q <= p:
            res += '. '
        elif token == '.' and q <= p:
            res += '-> '
        else:
            res += token
            res += ' '
    return res

In [81]:
text = r"""
if ( check_has_torch_function ( self ) ) { return handle_torch_function_setter ( ( THPVariable * ) self , ""names"" , names ) ; } auto & var = ( ( THPVariable * ) self ) -> cdata ;
"""

In [82]:
swap_access(text, 0.5)

' if ( check_has_torch_function ( self ) ) { return handle_torch_function_setter ( ( THPVariable * ) self , ""names"" , names ) ; } auto & var = ( ( THPVariable * ) self ) . cdata ;  '

## 4. 针对比较运算的数据增强

以 p 概率把 > / >=, 或者 </ <= 替换

In [83]:
def swap_comp(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if token == '>=' and q <= p:
            res += '> '
        elif token == '<=' and q <= p:
            res += '< '
        else:
            res += token
            res += ' '
    return res

In [85]:
text = r"""
if a >= b {}
else if a <= b {}
else if a <= c {}
else if a >= d {}
"""

In [92]:
swap_comp(text, 0.3)

' if a >= b {} else if a <= b {} else if a < c {} else if a > d {}  '

由于程序片段的破碎性，无法对程序片段进行语法分析，因此损失了很多可用的手段，以上的数据增强方法仅为在这种限制场景下的一些探索

In [131]:
def corrupt(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if q > p:
            res += token
            res += ' '
    return res

In [132]:
text = r"""
this is a token sequence to be corrupted sda sjfhals shags lasjf or 888.
"""

In [133]:
corrupt(text, 0.2)

'this is a token to corrupted sda sjfhals shags lasjf or 888. '

In [145]:
def randomlyCorrupt(code_snippet, p):
    q = random.random()
    if q <= 0.3:
        return corrupt(code_snippet, p)
    return code_snippet

### 使用数据增强方法扩增数据

In [94]:
import pandas as pd

In [96]:
cpp_data_path = '../data/cpp/dl_cpp_tokenized.csv'

In [97]:
cpp_data = pd.read_csv(cpp_data_path)

In [98]:
py_data_path = '../data/py/dl_py_tokenized.csv'

In [99]:
py_data = pd.read_csv(py_data_path)

In [100]:
cpp_data.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
0,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...
1,std::chrono::milliseconds ( kWaitForAbortCommS...,std::chrono::milliseconds ( kWaitForAbortCommS...
2,if ( rank_ == 0 || ( commType != NCCLCommType:...,if ( rank_ == 0 || ( isP2POp ( opType ) && p2p...
3,"int numRanks , rank ; if ( commType == NCCLCom...","int numRanks , rank ; if ( sP2POp ( opType ) )..."
4,"if ( str != ""None"" && str != """" ) { throw std:...","if ( str != ""None"" ) { default_string = parse_..."


In [101]:
cpp_data.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,357,357
unique,329,329
top,const auto source_n = n -> sourceRange ( ) . s...,const auto source_n = n -> sourceRange ( ) . s...
freq,3,3


In [102]:
py_data.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
0,def __sizeof__ ( self ) : return super ( _Stor...,def __sizeof__ ( self ) : return ( super ( _St...
1,self . quant_min : int = quant_min self . quan...,self . quant_min : int = quant_min self . quan...
2,def get_default_qat_qconfig ( backend = gemm v...,def get_default_qat_qconfig ( backend = gemm v...
3,"exp_avg . mul_ ( beta1 ) . add_ ( grad , alpha...","exp_avg . mul_ ( beta1 ) . add_ ( grad , alpha..."
4,def get_post_build_suffix ( self ) -> str : if...,def get_post_build_suffix ( self ) -> str : if...


In [103]:
py_data.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,275,275
unique,268,268
top,if it % 100 == 0 : print ( eration % d - - Out...,if it % 100 == 0 : print ( eration % d - - Out...
freq,3,3


In [174]:
py_aug = pd.DataFrame(data={'BUGGY_CODE': [], "FIXED_CODE": []})

In [175]:
py_aug_methods = [replace_add_sub, swap_cond_py, swap_comp]

In [176]:
cpp_aug = pd.DataFrame(data={'BUGGY_CODE': [], "FIXED_CODE": []})

In [177]:
cpp_aug_methods = [replace_add_sub, remove_cond_component_cpp, swap_cond_cpp, swap_access, swap_comp]

In [178]:
py_aug.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE


In [179]:
cpp_aug.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE


In [180]:
py_data_len = py_data.shape[0]
cpp_data_len = cpp_data.shape[0]

In [181]:
for i in range(py_data_len):
    fixed_code = py_data.iloc[i]['FIXED_CODE']
    for p in [0.3, 0.4, 0.5]:
        for f in py_aug_methods:
            transformed = f(fixed_code, p)
            transformed = randomlyCorrupt(transformed, p)
            py_aug.loc[len(py_aug.index)] = [transformed, fixed_code]

In [182]:
py_aug.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
0,def __sizeof__ ( self ) : return ( super ( _St...,def __sizeof__ ( self ) : return ( super ( _St...
1,def __sizeof__ ( self ) : return ( super ( _St...,def __sizeof__ ( self ) : return ( super ( _St...
2,def __sizeof__ ( self ) : return ( super ( _St...,def __sizeof__ ( self ) : return ( super ( _St...
3,def __sizeof__ ( self ) : return ( super ( _St...,def __sizeof__ ( self ) : return ( super ( _St...
4,"def __sizeof__ ( self ) : return ( super ( , ....",def __sizeof__ ( self ) : return ( super ( _St...


In [183]:
py_aug.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,2475,2475
unique,1119,268
top,paths . append ( p ) if not found_one : print ...,if it % 100 == 0 : print ( eration % d - - Out...
freq,16,27


In [184]:
for i in range(cpp_data_len):
    fixed_code = cpp_data.iloc[i]['FIXED_CODE']
    for p in [0.3, 0.4, 0.5]:
        for f in cpp_aug_methods:
            transformed = f(fixed_code, p)
            transformed = randomlyCorrupt(transformed, p)
            cpp_aug.loc[len(cpp_aug.index)] = [transformed, fixed_code]

In [185]:
cpp_aug.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
0,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...
1,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...
2,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...
3,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...
4,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> ( ...,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...


In [186]:
cpp_aug.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,5355,5355
unique,2438,329
top,const auto source_n = n -> sourceRange ( ) . s...,const auto source_n = n -> sourceRange ( ) . s...
freq,25,45


In [187]:
unique_py_aug = py_aug.drop_duplicates(subset=['BUGGY_CODE', 'FIXED_CODE'], keep = False, inplace=False)

In [188]:
unique_py_aug.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,829,829
unique,829,258
top,def __sizeof__ ( self ) : return ( super ( _St...,"def __rdiv__ ( self , other ) : if self . dtyp..."
freq,1,9


In [189]:
unique_cpp_aug = cpp_aug.drop_duplicates(subset=['BUGGY_CODE', 'FIXED_CODE'], keep = False, inplace=False)

In [190]:
unique_cpp_aug.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,2045,2045
unique,2045,329
top,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...,if ( resolver_ ) { if ( auto typePtr = resolve...
freq,1,22


In [191]:
unique_py_aug.to_csv('../data/py/dl_py_aug.csv', index=False)

In [192]:
unique_cpp_aug.to_csv('../data/cpp/dl_cpp_aug.csv', index=False)

至此，我们完成了数据增强，虽然数据量仍然很少，但是相比原来已经多了不少