In [1]:
import pandas as pd
import numpy as np
import random

from deap import base
from deap import creator
from deap import tools

In [3]:
def read_msg_and_embs_and_lbls(path):
    """Load one dataset directory and return its rows ordered by label.

    Three files are read from ``path``: ``seq.in`` (one message per line),
    ``label`` (one label per line) and ``embs.npy`` (loaded with ``np.load``
    and indexed by row). Row ``i`` of each file is treated as one record.

    Returns:
        A ``(messages, embeddings, labels)`` tuple of three parallel lists
        sorted by label. Records sharing a label keep their file order,
        because the sort is stable.
    """
    with open(f"{path}/seq.in", 'r') as seq_file:
        msgs = [raw_line.strip() for raw_line in seq_file]

    with open(f"{path}/label", 'r') as label_file:
        labels = [raw_line.strip() for raw_line in label_file]

    embs = np.load(f"{path}/embs.npy")

    # The number of messages drives the pairing: one record per message.
    records = [(msgs[idx], embs[idx], labels[idx]) for idx in range(len(msgs))]
    records.sort(key=lambda record: record[2])

    new_msgs = [record[0] for record in records]
    new_embs = [record[1] for record in records]
    new_labels = [record[2] for record in records]
    return new_msgs, new_embs, new_labels

In [4]:
def read_augs_to_pandas():
    """Load the original data and every augmentation into one DataFrame.

    Each source directory is read with ``read_msg_and_embs_and_lbls``. The
    frame has a ``lbl`` column (labels of the original data), then one
    ``<source>_msg`` column per source, then one ``<source>_embs`` column per
    source, in the order the sources are listed below.
    """
    # Column prefix -> directory; insertion order defines the column order.
    source_dirs = {
        "true": './aug_data/data/banking77/train_10/',
        "back_trans": './aug_data/data/back_translation/',
        "context_replacement": './aug_data/data/contextual_replacement/',
        "glove_word_replace": './aug_data/data/glove_word_replace/',
        "gpt2_msgs": './aug_data/data/gpt2_msgs/',
        "random_delete": './aug_data/data/random_delete/',
        "random_swap": './aug_data/data/random_swap/',
        "synonym_replace": './aug_data/data/synonym_replace/',
    }
    loaded = {
        prefix: read_msg_and_embs_and_lbls(directory)
        for prefix, directory in source_dirs.items()
    }

    # Each loaded entry is (messages, embeddings, labels).
    columns = {"lbl": loaded["true"][2]}
    for prefix, triple in loaded.items():
        columns[f"{prefix}_msg"] = triple[0]
    for prefix, triple in loaded.items():
        columns[f"{prefix}_embs"] = triple[1]

    return pd.DataFrame(columns)

In [5]:
# Load every dataset once; `evaluate` reads this module-level `df` directly.
df = read_augs_to_pandas()

In [6]:
# df.head()

In [7]:
# List each column with its position; `evaluate` addresses columns by these positions.
for position, column_name in enumerate(df.columns):
    print(position, column_name)

0 lbl
1 true_msg
2 back_trans_msg
3 context_replacement_msg
4 glove_word_replace_msg
5 gpt2_msgs_msg
6 random_delete_msg
7 random_swap_msg
8 synonym_replace_msg
9 true_embs
10 back_trans_embs
11 context_replacement_embs
12 glove_word_replace_embs
13 gpt2_msgs_embs
14 random_delete_embs
15 random_swap_embs
16 synonym_replace_embs


In [18]:
def n_per_aug_method():
    """Return a random list of 7 integers, each drawn from 1..9 with replacement."""
    allowed_values = range(1, 10)
    gene_count = 7
    return random.choices(allowed_values, k=gene_count)

In [19]:
from sklearn.metrics import silhouette_score

In [20]:
# Scratch check: random.randint(0, 9) is inclusive on both ends, i.e. 10 possible values.
print(random.randint(0, 9))

7


In [21]:
def evaluate(individual):
    """Score an augmentation mix by the silhouette of the sample it produces.

    ``individual`` is the nested individual ``[[g1, ..., gk]]``; gene ``g_j``
    controls how often the j-th augmentation column is sampled. For every row
    of the module-level ``df`` the true embedding is always kept, and the
    embedding at column position ``10 + j`` is added when
    ``random.randint(0, 9) < g_j`` (probability ``g_j / 10`` for genes in 0..10).

    The result is stochastic: the inclusion draws use the global ``random``
    state, so repeated calls on the same individual can give different values.

    Returns:
        A one-element tuple with the negated ``silhouette_score`` of the
        sampled embeddings against their labels (the tuple form is what gets
        assigned to ``fitness.values``).
    """
    genes = individual[0]

    # Extract the columns once by explicit position (.iloc). Indexing a
    # string-labelled row with a bare int (`row[0]`) relies on a positional
    # fallback that pandas has deprecated.
    row_labels = df.iloc[:, 0].tolist()
    true_embs = df['true_embs'].tolist()
    aug_columns = [df.iloc[:, 10 + offset].tolist() for offset in range(len(genes))]

    embs = []
    labels = []
    for pos, label in enumerate(row_labels):
        embs.append(true_embs[pos])
        labels.append(label)
        # One draw per (row, gene), in the same order as before.
        for gene, column in zip(genes, aug_columns):
            if random.randint(0, 9) < gene:
                embs.append(column[pos])
                labels.append(label)
    return -1 * silhouette_score(embs, labels),

In [26]:
# DEAP setup: declare the fitness/individual types and register the operators on the toolbox.
# weights=(-1.0,) makes this a single-objective minimisation; `evaluate` returns the negated
# silhouette score, so minimising it maximises the silhouette.
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
creator.create("Individual", list, fitness=creator.FitnessMin)

toolbox = base.Toolbox()

# Gene generator: one call yields the full list of 7 per-method values.
toolbox.register("n_per_aug_method", n_per_aug_method)

# n=1 wraps that list once, so an individual is nested as [[g1, ..., g7]];
# `evaluate` and `main` therefore operate on individual[0].
toolbox.register("individual", tools.initRepeat, creator.Individual, toolbox.n_per_aug_method, n=1)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)

toolbox.register("evaluate", evaluate)
toolbox.register("mate", tools.cxTwoPoint)
# NOTE(review): mutUniformInt draws from [low, up] inclusive per the DEAP docs, so mutation can
# produce 10, a value n_per_aug_method (range(1, 10)) never generates; confirm this is intended.
toolbox.register("mutate", tools.mutUniformInt, low=1, up=10, indpb=0.1)
toolbox.register("select", tools.selTournament, tournsize=3)



In [27]:
# Preview 10 freshly generated individuals to check the nested [[g1, ..., g7]] layout.
toolbox.population(n=10)

[[[2, 5, 9, 4, 2, 3, 7]],
 [[3, 8, 1, 5, 4, 4, 9]],
 [[4, 9, 9, 2, 3, 5, 9]],
 [[6, 2, 5, 3, 9, 1, 9]],
 [[2, 6, 8, 6, 2, 7, 7]],
 [[3, 9, 8, 8, 4, 4, 5]],
 [[8, 3, 2, 8, 6, 8, 7]],
 [[9, 7, 1, 2, 3, 1, 8]],
 [[9, 1, 8, 9, 5, 3, 1]],
 [[5, 6, 1, 4, 5, 7, 2]]]

In [28]:
def _best_by_cached_fitness(population):
    """Return the individual with the lowest already-assigned fitness value."""
    return min(population, key=lambda ind: ind.fitness.values[0])


def main(pop_size=50, n_generations=200, cxpb=0.5, mutpb=0.2):
    """Run the genetic search and return the best individual of the last generation.

    Each generation selects ``len(pop)`` parents with ``toolbox.select``,
    clones them, applies crossover to consecutive pairs with probability
    ``cxpb`` and mutation to each clone with probability ``mutpb``, then
    evaluates only the individuals whose fitness was invalidated. The whole
    population is replaced by the offspring (no elitism).

    Args:
        pop_size: Number of individuals in the population.
        n_generations: Number of generations to run.
        cxpb: Per-pair crossover probability.
        mutpb: Per-individual mutation probability.

    Returns:
        The individual with the lowest fitness in the final population.
    """
    pop = toolbox.population(n=pop_size)

    for ind, fit in zip(pop, map(toolbox.evaluate, pop)):
        ind.fitness.values = fit

    for generation in range(1, n_generations + 1):
        print("-- Generation %i --" % generation)

        offspring = toolbox.select(pop, len(pop))
        # Clone so that in-place crossover/mutation never touches the parents.
        offspring = list(map(toolbox.clone, offspring))

        # Operators act on the inner gene list (individual[0]).
        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < cxpb:
                toolbox.mate(child1[0], child2[0])
                del child1.fitness.values
                del child2.fitness.values

        for mutant in offspring:
            if random.random() < mutpb:
                toolbox.mutate(mutant[0])
                del mutant.fitness.values

        invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
        for ind, fit in zip(invalid_ind, map(toolbox.evaluate, invalid_ind)):
            ind.fitness.values = fit

        pop[:] = offspring

        # Rank on the fitness values the GA already holds. Re-evaluating the
        # whole population here would double the cost and, because the
        # evaluation is stochastic, rank on different scores than selection used.
        print(_best_by_cached_fitness(pop))

    return _best_by_cached_fitness(pop)

In [None]:
# TODO: try without the true embeddings

In [17]:
# NOTE(review): this cell ran as In [17], before the current definitions (In [18]-In [28]).
# The 0-valued genes in its output cannot come from the current bounds (range(1, 10), low=1),
# so the output below is presumably from an earlier configuration; re-run to refresh.
best_solution = main()

-- Generation 1 --
[[9, 0, 1, 1, 3, 4, 0]]
-- Generation 2 --
[[9, 0, 1, 1, 0, 9, 0]]
-- Generation 3 --
[[4, 0, 1, 1, 0, 9, 1]]
-- Generation 4 --
[[2, 0, 1, 1, 0, 9, 0]]
-- Generation 5 --
[[9, 0, 1, 1, 0, 9, 0]]
-- Generation 6 --
[[2, 0, 1, 1, 0, 9, 0]]
-- Generation 7 --
[[4, 0, 1, 1, 0, 9, 0]]
-- Generation 8 --
[[4, 0, 1, 1, 0, 9, 0]]
-- Generation 9 --
[[2, 0, 1, 1, 0, 9, 0]]
-- Generation 10 --
[[2, 0, 1, 1, 0, 9, 0]]
-- Generation 11 --
[[2, 0, 1, 1, 0, 9, 0]]
-- Generation 12 --
[[4, 0, 1, 1, 0, 9, 0]]
-- Generation 13 --
[[2, 0, 1, 1, 0, 9, 0]]
-- Generation 14 --
[[2, 0, 1, 1, 0, 0, 0]]
-- Generation 15 --
[[2, 0, 1, 1, 0, 9, 0]]
-- Generation 16 --
[[2, 0, 1, 1, 0, 9, 0]]
-- Generation 17 --
[[2, 0, 1, 1, 0, 9, 0]]
-- Generation 18 --
[[2, 0, 1, 1, 0, 9, 0]]
-- Generation 19 --
[[2, 0, 1, 1, 0, 9, 0]]
-- Generation 20 --
[[2, 0, 1, 0, 0, 3, 0]]
-- Generation 21 --
[[2, 0, 1, 0, 0, 3, 0]]
-- Generation 22 --
[[2, 0, 1, 0, 0, 3, 0]]
-- Generation 23 --
[[2, 0, 1, 0, 0, 3, 0

[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 186 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 187 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 188 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 189 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 190 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 191 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 192 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 193 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 194 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 195 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 196 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 197 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 198 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 199 --
[[0, 0, 0, 0, 0, 0, 0]]
-- Generation 200 --
[[0, 0, 0, 0, 0, 0, 0]]


In [29]:
# Run the search with the current setup; the recorded run was interrupted
# (KeyboardInterrupt) during generation 23, so `best_solution` was not assigned by it.
best_solution = main()

-- Generation 1 --
[[3, 2, 1, 1, 5, 5, 1]]
-- Generation 2 --
[[3, 4, 1, 1, 1, 8, 1]]
-- Generation 3 --
[[7, 1, 1, 1, 1, 8, 6]]
-- Generation 4 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 5 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 6 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 7 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 8 --
[[3, 1, 1, 1, 1, 8, 1]]
-- Generation 9 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 10 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 11 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 12 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 13 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 14 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 15 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 16 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 17 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 18 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 19 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 20 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 21 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 22 --
[[7, 1, 1, 1, 1, 8, 1]]
-- Generation 23 --


KeyboardInterrupt: 