In [1]:
import ast
import pandas as pd
from tqdm import tqdm

# Загружаем данные (пример)
df = pd.read_csv('../data/train.csv', sep=';')
df.columns = ['text', 'spans']
df['spans'] = df['spans'].apply(lambda x: ast.literal_eval(x))
df['spans'] = df['spans'].apply(lambda x: [(s[0], s[1], s[2].replace('0', 'O')) for s in x])
df['spans_d'] = df['spans'].apply(lambda x: [{'start': span[0], 'end': span[1], 'label': span[2]} for span in x])
df = df.iloc[[i for i, s in enumerate(df['spans_d'].values) if s[0]['start'] == 0]]

# словарь меток
label2id = {
    "O": 0, 
    "B-BRAND": 1, "B-TYPE": 2, "B-VOLUME": 3, "B-PERCENT": 4,
    "I-BRAND": 5, "I-TYPE": 6, "I-VOLUME": 7, "I-PERCENT": 8,
}
id2label = {v: k for k, v in label2id.items()}

print(df.shape)
df.head(2)

(27248, 3)


Unnamed: 0,text,spans,spans_d
0,aa,"[(0, 2, O)]","[{'start': 0, 'end': 2, 'label': 'O'}]"
1,aala,"[(0, 4, O)]","[{'start': 0, 'end': 4, 'label': 'O'}]"


In [2]:
import random
import re

def get_random_int_range(num: int | str) -> tuple[int, int]:
    n = '1'
    for _ in range(len(str(num))):
        n += '0'
    n = n[:-1]
    start = int(n)
    start = 0 if start == 1 else start
    end = int(n + '0') - 1
    return start, end

def augment_numbers(texts, spans, n_aug=3):
    r = re.compile(r'[0-9]+')
    new_texts, new_labels = [], []
    
    for text, span in zip(texts, spans):
        new_texts.append(text)
        new_labels.append(span)

        token_spans = [(text[s[0] : s[1]], s[2]) for s in span]

        num_token_ids = [
            i for i, (_, s) in enumerate(token_spans) 
            if s != 'O' and s.split('-')[1] in ('PERCENT', 'VALUE')
        ]

        if not num_token_ids:
            continue
        
        for _ in range(n_aug):
            aug_text = str(text)
            aug_lbls = list(span)

            for num_token_id in num_token_ids:
                token, label = token_spans[num_token_id]

                num = r.search(token)
                if not num:
                    continue
                num = num.group(0)

                if 'PERCENT' in label or "VOLUME" in label:
                    start, end = get_random_int_range(num)
                    new_num = str(random.randint(start, end))
                else:
                    continue

                aug_text = aug_text.replace(num, new_num, 1)

            new_texts.append(aug_text)
            new_labels.append(aug_lbls)

    return new_texts, new_labels

In [3]:
texts = df.values[:, 0]
spans = df.values[:, 1]
aug_text, aug_spans = augment_numbers(texts, spans, 5)
num_aug_df = pd.DataFrame({'text': aug_text, 'spans': aug_spans})
num_aug_df['spans'] = num_aug_df['spans'].apply(lambda x: str(x))
num_aug_df.to_csv('../data/aug_train.csv', sep=';', index=False)

In [407]:
from itertools import product, chain

def aug_group_bio(token_spans):
    span_bios = [s for s in token_spans if s[2].startswith('B-')]
    span_groups = []
    for span in span_bios:
        o = span[2].split('-')[1]
        group_tokens = [s[3] for s in token_spans if s[2].endswith(o)]
        aug_group = []
        for _ in range(6):
            aug_tokens = list(group_tokens)
            random.shuffle(aug_tokens)
            aug_spans = [(f'B-{o}', aug_tokens[0])]
            for token in aug_tokens[1:]:
                aug_spans.append((f'I-{o}', token))
            aug_group.append(tuple(aug_spans))

        aug_group = list(set(aug_group))
        span_groups.append(tuple(aug_group))
    return span_groups

def valid_permutations(entities):
    n = len(entities)
    used = [False] * n
    result = []

    # Преобразуем в список для удобства
    items = list(entities)

    # Словарь: для каждого I знаем его базовый B
    base_for = {}
    for i, (tag, word) in enumerate(items):
        if tag.startswith("I-"):
            _, name = tag.split("-", 1)
            base_for[i] = [j for j, (t2, _) in enumerate(items) if t2 == f"B-{name}"][0]

    def backtrack(path):
        if len(path) == n:
            result.append([items[i] for i in path])
            return

        for i in range(n):
            if used[i]:
                continue

            # Если это I, то его B должен быть уже в path
            if i in base_for and base_for[i] not in path:
                continue

            used[i] = True
            path.append(i)
            backtrack(path)
            path.pop()
            used[i] = False

    backtrack([])
    return result

def shuffle_token_spans(token_spans, k: int = 3):
    aug_spans = []
    for span_tokens in random.choices(token_spans, k=10):
        cur_i = 0
        new_spans = []
        text = ' '.join([token for _, token in span_tokens])
        for span, token in span_tokens:
            new_spans.append((cur_i, cur_i + len(token), span))
            cur_i += len(token) + 1
        aug_spans.append((text, tuple(new_spans)))
    aug_spans = list(set(aug_spans))[:k]
    return aug_spans

def aug_text_spans(spans, text, k = 3):
    token_spans = [(s[0], s[1], s[2], token) for s, token in zip(spans, text.split())]
    span_groups = aug_group_bio(token_spans)

    span_products = [list(chain(*s)) for s in product(*span_groups)]
    token_spans = list(chain(*[valid_permutations(sp) for sp in span_products]))
    aug_spans = shuffle_token_spans(token_spans, k)
    return aug_spans

In [377]:
text, spans = df[[len(r.split()) > 3 for r in df['text'].values]][['text', 'spans']].sample(1).values[0]
print(text)
print(spans)

вода 5л красная цена
[(0, 4, 'B-TYPE'), (5, 7, 'B-VOLUME'), (8, 15, 'B-BRAND'), (16, 20, 'I-BRAND')]


In [416]:
texts = num_aug_df['text'].values
spans = num_aug_df['spans'].values

In [418]:
aug_vals = []
for text, span in zip(texts, spans):
    aug_vals += aug_text_spans(span, text)

In [421]:
aug_df = pd.DataFrame(aug_vals, columns=['text', 'spans'])
aug_df = aug_df[aug_df['text'] != '']
aug_df = pd.concat([num_aug_df, aug_df])
aug_df['spans'] = aug_df['spans'].apply(lambda x: str(x))
aug_df = aug_df.drop_duplicates()
print(aug_df.shape)
aug_df.head(2)

(61534, 2)


Unnamed: 0,text,spans
0,aa,"[(0, 2, 'O')]"
1,aala,"[(0, 4, 'O')]"


In [422]:
aug_df.to_csv('../data/aug_df.csv', sep=';', index=False)