# 数据增强方法 Demo

由于有效的数据量十分有限，为了能够获得更好的训练效果，同时留出一定量的数据来对模型进行评估，我们设计了一些数据增强方法，以获取更多数据，这些方法部分借鉴了文本增强方法中常用的思想。

## 1. 针对算术运算的增强

对于每个以空白分隔的加号 `+` / 减号 `-`，以概率 p 将其替换为减号 `-` / 加号 `+`。

In [2]:
import random
import re

In [6]:
def replace_add_sub(code_snippet, p):
    """Randomly flip stand-alone '+' and '-' tokens in a code snippet.

    The snippet is split on runs of whitespace, so only operators that are
    surrounded by whitespace are recognised as tokens. Every whitespace run
    (including newlines and indentation) is collapsed, and each token in the
    result is followed by exactly one space, so the returned string always
    ends with a space.

    Args:
        code_snippet: Source text whose tokens are separated by whitespace.
        p: Probability with which each '+' / '-' token is flipped.

    Returns:
        The re-joined snippet with some '+' tokens replaced by '-' and
        vice versa.
    """
    flipped = {'+': '-', '-': '+'}
    pieces = []
    for token in re.split(r'\s+', code_snippet):
        # Draw once per token (not only per operator) so the amount of
        # randomness consumed does not depend on the snippet's content.
        q = random.random()
        if token in flipped and q <= p:
            token = flipped[token]
        pieces.append(token + ' ')
    return ''.join(pieces)

In [8]:
text = r"""
def add_sub(a, b):
    c = a + b
    d = c - a
    b = a * c + b
    e = b + a - c * d - q
    return e - b + d + a - c
"""

In [11]:
replace_add_sub(text, 0.1)

' def add_sub(a, b): c = a + b d = c - a b = a * c + b e = b + a - c * d - q return e - b - d + a - c  '

## 2. 针对逻辑运算的数据增强

### 2.1 以概率 p 删去某个条件分量

In [31]:
def remove_cond_component_cpp(code_snippet, p):
    """Randomly drop a component of a C/C++ boolean condition.

    The snippet is split on runs of whitespace. For every '&&' or '||'
    token, with probability ``p`` the operator and all following tokens are
    removed up to (but not including) the ')' that closes the group the
    operator belongs to. Parentheses opened inside the removed span are
    matched, so a nested sub-expression is removed as a whole and the
    remaining parentheses stay balanced. If no closing ')' follows, the
    rest of the snippet is dropped.

    Every whitespace run is collapsed, and each kept token is followed by
    exactly one space.

    Args:
        code_snippet: Source text whose tokens (including parentheses) are
            separated by whitespace.
        p: Probability with which each visited '&&' / '||' is removed
            together with its right-hand component.

    Returns:
        The re-joined snippet without the removed components.
    """
    tokens = re.split(r'\s+', code_snippet)
    n_tokens = len(tokens)
    kept = []
    i = 0
    while i < n_tokens:
        # One draw per visited token, independent of what the token is.
        q = random.random()
        if tokens[i] in ('&&', '||') and q <= p:
            # Skip forward to the ')' that closes the current group,
            # stepping over any parenthesised groups opened on the way.
            depth = 0
            while i < n_tokens:
                token = tokens[i]
                if token == '(':
                    depth += 1
                elif token == ')':
                    if depth == 0:
                        break
                    depth -= 1
                i += 1
            if i == n_tokens:
                # Ran off the end without finding a closing ')': nothing is
                # left to keep (indexing tokens[i] here would raise).
                break
        kept.append(tokens[i])
        i += 1
    return ''.join(token + ' ' for token in kept)

In [56]:
text = r"""
if ( rank_ == 0 || ( commType != NCCLCommType::COLL && p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }
"""

In [57]:
remove_cond_component_cpp(text, 0.2)

' if ( rank_ == 0 || ( commType != NCCLCommType::COLL ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }  '

### 2.2 以概率 p 把每个 || 和 && 互换

In [41]:
def swap_cond_cpp(code_snippet, p):
    """Randomly exchange the C/C++ logical operators '&&' and '||'.

    The snippet is split on runs of whitespace, so only operators that
    appear as stand-alone tokens are affected. Every whitespace run is
    collapsed, and each token in the result is followed by exactly one
    space.

    Args:
        code_snippet: Source text whose tokens are separated by whitespace.
        p: Probability with which each '&&' / '||' token is exchanged.

    Returns:
        The re-joined snippet with some '&&' tokens replaced by '||' and
        vice versa.
    """
    opposite = {'&&': '||', '||': '&&'}
    pieces = []
    for token in re.split(r'\s+', code_snippet):
        # Draw once per token so randomness consumption is content-agnostic.
        q = random.random()
        if token in opposite and q <= p:
            token = opposite[token]
        pieces.append(token + ' ')
    return ''.join(pieces)
    

In [60]:
def swap_cond_py(code_snippet, p):
    """Randomly exchange the Python logical operators 'and' and 'or'.

    The snippet is split on runs of whitespace, so only the exact tokens
    'and' / 'or' are affected (identifiers that merely contain them are
    not). Every whitespace run, including newlines and indentation, is
    collapsed, and each token in the result is followed by exactly one
    space.

    Args:
        code_snippet: Source text whose tokens are separated by whitespace.
        p: Probability with which each 'and' / 'or' token is exchanged.

    Returns:
        The re-joined snippet with some 'and' tokens replaced by 'or' and
        vice versa.
    """
    opposite = {'and': 'or', 'or': 'and'}
    pieces = []
    for token in re.split(r'\s+', code_snippet):
        # Draw once per token so randomness consumption is content-agnostic.
        q = random.random()
        if token in opposite and q <= p:
            token = opposite[token]
        pieces.append(token + ' ')
    return ''.join(pieces)
    

In [42]:
text = r"""
if ( rank_ == 0 || ( commType != NCCLCommType::COLL && p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }
"""

In [43]:
swap_cond_cpp(text, 0.2)

' if ( rank_ == 0 || ( commType != NCCLCommType::COLL || p2pRank == 0 ) ) { C10D_NCCL_CHECK ( ncclGetUniqueId ( & ncclID ) ) ; }  '

In [65]:
text = r"""
if a and b or c or d and e or f:
    i += 1
"""

In [66]:
swap_cond_py(text, 0.3)

' if a and b and c or d and e or f: i += 1  '

## 3. 针对对象访问的增强

对于每个以空白分隔的 `.` 和 `->`，以概率 p 将其互换：`.` 换成 `->`，`->` 换成 `.`。

In [67]:
def swap_access(code_snippet, p):
    """Randomly exchange the member-access operators '.' and '->'.

    The snippet is split on runs of whitespace, so only operators that
    appear as stand-alone tokens are affected (``a.b`` written without
    spaces is a single token and is left alone). Every whitespace run is
    collapsed, and each token in the result is followed by exactly one
    space.

    Args:
        code_snippet: Source text whose tokens are separated by whitespace.
        p: Probability with which each '.' / '->' token is exchanged.

    Returns:
        The re-joined snippet with some '->' tokens replaced by '.' and
        vice versa.
    """
    opposite = {'->': '.', '.': '->'}
    pieces = []
    for token in re.split(r'\s+', code_snippet):
        # Draw once per token so randomness consumption is content-agnostic.
        q = random.random()
        if token in opposite and q <= p:
            token = opposite[token]
        pieces.append(token + ' ')
    return ''.join(pieces)

In [81]:
text = r"""
if ( check_has_torch_function ( self ) ) { return handle_torch_function_setter ( ( THPVariable * ) self , ""names"" , names ) ; } auto & var = ( ( THPVariable * ) self ) -> cdata ;
"""

In [82]:
swap_access(text, 0.5)

' if ( check_has_torch_function ( self ) ) { return handle_torch_function_setter ( ( THPVariable * ) self , ""names"" , names ) ; } auto & var = ( ( THPVariable * ) self ) . cdata ;  '

## 4. 针对比较运算的数据增强

以概率 p 把 `>=` 替换为 `>`，或者把 `<=` 替换为 `<`（当前实现只做这一方向的替换，单独的 `>` 和 `<` 保持不变）。

In [83]:
def swap_comp(code_snippet, p):
    """Randomly turn non-strict comparisons into strict ones.

    The snippet is split on runs of whitespace. Each stand-alone '>=' token
    becomes '>' and each stand-alone '<=' token becomes '<' with
    probability ``p``. The replacement is one-directional: bare '>' and '<'
    tokens are never changed. Every whitespace run is collapsed, and each
    token in the result is followed by exactly one space.

    Args:
        code_snippet: Source text whose tokens are separated by whitespace.
        p: Probability with which each '>=' / '<=' token is replaced.

    Returns:
        The re-joined snippet with some '>=' / '<=' tokens replaced by
        '>' / '<'.
    """
    strict = {'>=': '>', '<=': '<'}
    pieces = []
    for token in re.split(r'\s+', code_snippet):
        # Draw once per token so randomness consumption is content-agnostic.
        q = random.random()
        if token in strict and q <= p:
            token = strict[token]
        pieces.append(token + ' ')
    return ''.join(pieces)

In [85]:
text = r"""
if a >= b {}
else if a <= b {}
else if a <= c {}
else if a >= d {}
"""

In [92]:
swap_comp(text, 0.3)

' if a >= b {} else if a <= b {} else if a < c {} else if a > d {}  '

由于程序片段的破碎性，无法对程序片段进行语法分析，因此损失了很多可用的手段。以上的数据增强方法仅为在这种限制场景下的一些探索。