# 数据增强方法

由于有效的数据量十分有限，为了能够获得更好的训练效果，同时留出一定量的数据来对模型进行评估，我们设计了一些数据增强方法，以获取更多数据，这些方法部分借鉴了文本增强方法中常用的思想。

## 1. 针对算数运算的增强

对于所有的加/减法, 以 p 概率换为减/加法

In [175]:
import random
import re

In [176]:
def replace_add_sub(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    replaced = ''
    for token in tokens:
        q = random.random()
        if token == '+' and q <= p:
            replaced += '- '
        elif token == '-' and q <= p:
            replaced += '+ '
        else:
            replaced += (token + ' ')
    return replaced

In [177]:
text = r"""
def add_sub(a, b):
    c = a + b
    d = c - a
    b = a * c + b
    e = b + a - c * d - q
    return e - b + d + a - c
"""

In [178]:
replace_add_sub(text, 0.1)

' def add_sub(a, b): c = a + b d = c - a b = a * c - b e = b + a - c * d - q return e + b + d + a - c  '

## 2. 针对逻辑运算的数据增强

### 2.1 以概率 p 删去某个条件分量

In [179]:
def remove_cond_component_cpp(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    len_in_tokens = len(tokens)
    res = ''
    i = 0
    while i < len_in_tokens:
        q = random.random()
        if (tokens[i] == '&&' or tokens[i] == '||') and q <= p:
            while i < len_in_tokens and tokens[i] != ')':
                i += 1
        if i < len_in_tokens:
            res += (tokens[i] + ' ')
        i += 1
    return res        

In [180]:
text = r"""
if ( rank_ == 0 || ( commType != NCCLCommType::COLL && p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }
"""

In [181]:
remove_cond_component_cpp(text, 0.2)

' if ( rank_ == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }  '

### 2.2 以概率 p 把每个 || 和 && 互换

In [182]:
def swap_cond_cpp(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if token == '&&' and q <= p:
            res += '|| '
        elif token == '||' and q <= p:
            res += '&& '
        else:
            res += token
            res += ' '
    return res
    

In [183]:
def swap_cond_py(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if token == 'and' and q <= p:
            res += 'or '
        elif token == 'or' and q <= p:
            res += 'and '
        else:
            res += token
            res += ' '
    return res
    

In [184]:
text = r"""
if ( rank_ == 0 || ( commType != NCCLCommType::COLL && p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }
"""

In [185]:
swap_cond_cpp(text, 0.2)

' if ( rank_ == 0 || ( commType != NCCLCommType::COLL && p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }  '

In [186]:
text = r"""
if a and b or c or d and e or f:
    i += 1
"""

In [187]:
swap_cond_py(text, 0.3)

' if a and b or c or d or e or f: i += 1  '

## 3. 针对对象访问的增强

对于所有的 . 和 ->, 以 p 概率换成 -> 或者 .

In [188]:
def swap_access(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if token == '->' and q <= p:
            res += '. '
        elif token == '.' and q <= p:
            res += '-> '
        else:
            res += token
            res += ' '
    return res

In [189]:
text = r"""
if ( check_has_torch_function ( self ) ) { return handle_torch_function_setter ( ( THPVariable * ) self , ""names"" , names ) ; } auto & var = ( ( THPVariable * ) self ) -> cdata ;
"""

In [190]:
swap_access(text, 0.5)

' if ( check_has_torch_function ( self ) ) { return handle_torch_function_setter ( ( THPVariable * ) self , ""names"" , names ) ; } auto & var = ( ( THPVariable * ) self ) -> cdata ;  '

## 4. 针对比较运算的数据增强

以 p 概率把 > / >=, 或者 </ <= 替换

In [191]:
def swap_comp(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if token == '>=' and q <= p:
            res += '> '
        elif token == '<=' and q <= p:
            res += '< '
        else:
            res += token
            res += ' '
    return res

In [192]:
text = r"""
if a >= b {}
else if a <= b {}
else if a <= c {}
else if a >= d {}
"""

In [193]:
swap_comp(text, 0.3)

' if a > b {} else if a <= b {} else if a <= c {} else if a >= d {}  '

由于程序片段的破碎性，无法对程序片段进行语法分析，因此损失了很多可用的手段，以上的数据增强方法仅为在这种限制场景下的一些探索

In [194]:
def corrupt(code_snippet, p):
    tokens = re.split('\s+', code_snippet)
    res = ''
    for token in tokens:
        q = random.random()
        if q > p:
            res += token
            res += ' '
    return res

In [195]:
text = r"""
this is a token sequence to be corrupted sda sjfhals shags lasjf or 888.
"""

In [196]:
corrupt(text, 0.2)

' this is a token be corrupted sda sjfhals lasjf or 888.  '

In [197]:
def randomlyCorrupt(code_snippet, p):
    q = random.random()
    if q <= 0.15:
        return corrupt(code_snippet, p)
    return code_snippet

### 对训练集使用数据增强方法扩增数据

In [198]:
import pandas as pd

In [199]:
cpp_data_path = '../data/cpp/dl_cpp_tokenized.csv'

In [200]:
cpp_data = pd.read_csv(cpp_data_path)

In [201]:
py_data_path = '../data/py/dl_py_tokenized.csv'

In [202]:
py_data = pd.read_csv(py_data_path)

In [203]:
cpp_data.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
0,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...,std::shared_ptr<ProcessGroupNCCL::WorkNCCL> Pr...
1,std::chrono::milliseconds ( kWaitForAbortCommS...,std::chrono::milliseconds ( kWaitForAbortCommS...
2,if ( rank_ == 0 || ( commType != NCCLCommType:...,if ( rank_ == 0 || ( isP2POp ( opType ) && p2p...
3,"int numRanks , rank ; if ( commType == NCCLCom...","int numRanks , rank ; if ( sP2POp ( opType ) )..."
4,"if ( str != ""None"" && str != """" ) { throw std:...","if ( str != ""None"" ) { default_string = parse_..."


In [204]:
cpp_data.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,357,357
unique,329,329
top,const auto source_n = n -> sourceRange ( ) . s...,const auto source_n = n -> sourceRange ( ) . s...
freq,3,3


In [205]:
py_data.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
0,def __sizeof__ ( self ) : return super ( _Stor...,def __sizeof__ ( self ) : return ( super ( _St...
1,self . quant_min : int = quant_min self . quan...,self . quant_min : int = quant_min self . quan...
2,def get_default_qat_qconfig ( backend = gemm v...,def get_default_qat_qconfig ( backend = gemm v...
3,"exp_avg . mul_ ( beta1 ) . add_ ( grad , alpha...","exp_avg . mul_ ( beta1 ) . add_ ( grad , alpha..."
4,def get_post_build_suffix ( self ) -> str : if...,def get_post_build_suffix ( self ) -> str : if...


In [206]:
py_data.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,275,275
unique,268,268
top,if it % 100 == 0 : print ( eration % d - - Out...,if it % 100 == 0 : print ( eration % d - - Out...
freq,3,3


In [207]:
from sklearn.model_selection import train_test_split

In [208]:
cpp_train_buggy, cpp_test_buggy, cpp_train_fixed, cpp_test_fixed = train_test_split(cpp_data['BUGGY_CODE'], cpp_data['FIXED_CODE'], test_size=0.16, random_state=42)

In [210]:
py_train_buggy, py_test_buggy, py_train_fixed, py_test_fixed = train_test_split(py_data['BUGGY_CODE'], py_data['FIXED_CODE'], test_size=0.16, random_state=42)

In [211]:
py_aug = pd.DataFrame(data={'BUGGY_CODE': [], "FIXED_CODE": []})

In [212]:
py_aug_methods = [lambda x,p : x, replace_add_sub, swap_cond_py, swap_comp]

In [213]:
cpp_aug = pd.DataFrame(data={'BUGGY_CODE': [], "FIXED_CODE": []})

In [214]:
cpp_aug_methods = [lambda x,p : x, replace_add_sub, remove_cond_component_cpp, swap_cond_cpp, swap_access, swap_comp]

In [215]:
py_aug.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE


In [216]:
cpp_aug.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE


In [217]:
cpp_train_buggy, cpp_val_buggy, cpp_train_fixed, cpp_val_fixed = train_test_split(cpp_train_buggy, cpp_train_fixed, test_size=0.1, random_state=42)

In [218]:
py_train_buggy, py_val_buggy, py_train_fixed, py_val_fixed = train_test_split(py_train_buggy, py_train_fixed, test_size=0.1, random_state=42)

In [219]:
py_data_len = py_train_buggy.shape[0]
cpp_data_len = cpp_train_buggy.shape[0]

In [220]:
py_data_len

207

In [221]:
cpp_data_len

269

In [222]:
len(cpp_test_buggy)

58

In [223]:
len(py_test_buggy)

44

In [224]:
for i in range(py_data_len):
    buggy_code = py_train_buggy.iloc[i]
    fixed_code = py_train_fixed.iloc[i]
    py_aug.loc[len(py_aug.index)] = [buggy_code, fixed_code]
    for f in py_aug_methods:
        transformed = f(buggy_code, 0.2)
        transformed = randomlyCorrupt(transformed, 0.15)
        py_aug.loc[len(py_aug.index)] = [transformed, fixed_code]

In [225]:
py_aug.head()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
0,if len ( buffers ) > 0 : raise RuntimeError ( ...,if len ( buffers ) > 0 : raise RuntimeError ( ...
1,if len ( buffers ) > 0 : raise RuntimeError ( ...,if len ( buffers ) > 0 : raise RuntimeError ( ...
2,if len ( buffers ) > 0 : raise RuntimeError ( ...,if len ( buffers ) > 0 : raise RuntimeError ( ...
3,if len ( buffers ) > 0 : raise RuntimeError ( ...,if len ( buffers ) > 0 : raise RuntimeError ( ...
4,if len ( buffers ) > 0 : raise RuntimeError ( ...,if len ( buffers ) > 0 : raise RuntimeError ( ...


In [226]:
py_aug.drop_duplicates(subset=['BUGGY_CODE'], keep='first', inplace=True)

In [227]:
py_aug.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,506,506
unique,506,203
top,if len ( buffers ) > 0 : raise RuntimeError ( ...,"prev_node = get_normalized_nth_input ( node , ..."
freq,1,5


In [228]:
for i in range(cpp_data_len):
    buggy_code = cpp_train_buggy.iloc[i]
    fixed_code = cpp_train_fixed.iloc[i]
    cpp_aug.loc[len(cpp_aug.index)] = [buggy_code, fixed_code]
    for f in cpp_aug_methods:
        transformed = f(buggy_code, 0.2)
        transformed = randomlyCorrupt(transformed, 0.15)
        cpp_aug.loc[len(cpp_aug.index)] = [transformed, fixed_code]

In [229]:
cpp_aug.drop_duplicates(subset=['BUGGY_CODE'], keep='first', inplace=True)

In [230]:
cpp_aug.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,842,842
unique,842,257
top,if ( y::is_none ( size ) ) { d -> set_size ( p...,"std::map < ParallelType , bool > ParallelTypeB..."
freq,1,7


### 保存增强之后的训练集 (还未划分出验证集) 

In [231]:
py_aug.to_csv('../data/py/dl_py_aug.csv', index=False)

In [232]:
cpp_aug.to_csv('../data/cpp/dl_cpp_aug.csv', index=False)

### 保存测试集

In [233]:
cpp_test = pd.DataFrame(data={'BUGGY_CODE': cpp_test_buggy, 'FIXED_CODE': cpp_test_fixed})

In [234]:
cpp_test.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,58,58
unique,56,56
top,if ( bag_size . defined ( ) ) { bag_size_data ...,if ( bag_size . defined ( ) ) { bag_size_data ...
freq,2,2


In [235]:
py_test = pd.DataFrame(data={'BUGGY_CODE': py_test_buggy, 'FIXED_CODE': py_test_fixed})

In [236]:
py_test.describe()

Unnamed: 0,BUGGY_CODE,FIXED_CODE
count,44,44
unique,44,44
top,"def icdf ( self , value ) : if self . _validat...","def icdf ( self , value ) : return torch . tan..."
freq,1,1


In [237]:
cpp_test.to_csv('../data/cpp/dl_cpp_test.csv', index=False)

In [238]:
py_test.to_csv('../data/py/dl_py_test.csv', index=False)

### 保存验证集

In [239]:
py_val = pd.DataFrame(data={"BUGGY_CODE": py_val_buggy, "FIXED_CODE": py_val_fixed})

In [240]:
cpp_val = pd.DataFrame(data={"BUGGY_CODE": cpp_val_buggy, "FIXED_CODE": cpp_val_fixed})

In [241]:
py_val.to_csv('../data/py/dl_py_val.csv', index=False)

In [242]:
cpp_val.to_csv('../data/cpp/dl_cpp_val.csv', index=False)

### Done