In [None]:
import pickle
import random
from tqdm import tqdm
import pandas as pd
import re
from nltk import sent_tokenize

# Seed to replicate the splitting process. Do not change it!
SEED = 42
# Seed the global RNG from SEED rather than a duplicated literal, so that the
# two values can never drift apart.
random.seed(SEED)

# Javadoc tags. Entries starting with '@' act as sentence delimiters in
# splitJavaDocByTag; any occurrence of an entry makes checkForFakeJavadoc
# reject an instance. Commented-out tags are deliberately excluded.
stop_word_java_tag = [#'@author',
                      '@param',
                      '@deprecated',
                      '@return',
                      #'@see',
                      '{@link}',
                      #'@since',
                      '@throws',
                      '@Override',
                      '{@docRoot}',
                      '@exception',
                      '{@inheritDoc}',
                      '{@linkplain}',
                      '{@literal}',
                      '@serial',
                      '@serialData',
                      #'@version',
                      '{@value}',  # was '@{value}', a misspelling that never occurs in Javadoc
                      '@argfiles'
                ]

###### PLACEHOLDERS ######
PLACEHOLDER_MASK = '<extra_id_0>'
FAKE_LINK = '|__link__|'
FAKE_REF =  '|__ref__|'
#########################

TOKEN_LEN = 256

###### T5 SETTINGS ######
T5_EOS = ' </s>'
JAVADOC_PREFIX = 'complete javadoc comment: '
BLOCK_COMMENT_PREFIX =  'complete block/inline comment: '
#########################

In [None]:
def splitJavaDocByTag(javadoc, tags=None):
    """Split a Javadoc comment into one chunk per block tag.

    Every whitespace-separated word that starts with '@' and is listed in
    ``tags`` acts as a delimiter. Each returned chunk starts with its tag and
    runs up to (not including) the next delimiter; the last chunk runs to the
    end of the comment, minus a trailing '<sep>' marker. Text before the first
    tag is not returned.

    Args:
        javadoc: the comment text.
        tags: delimiter tags to recognise; defaults to ``stop_word_java_tag``.

    Returns:
        list[str]: the tag chunks in order of appearance (empty if no tag).
    """
    if tags is None:
        tags = stop_word_java_tag

    sep_token = '<sep>'
    javadoc_sentences = []
    delimiters = [word for word in javadoc.split() if word.startswith('@') and word in tags]

    for idx, delimiter in enumerate(delimiters):

        if idx == len(delimiters) - 1:
            st = javadoc[javadoc.rfind(delimiter):]
            # Remove a trailing '<sep>' marker as a whole token. str.strip('<sep>')
            # strips any of the characters '<', 's', 'e', 'p', '>' and therefore
            # truncated words such as "type" -> "ty" or "values" -> "valu".
            if st.endswith(sep_token):
                st = st[:-len(sep_token)]
            javadoc_sentences.append(st)
            break

        # re.escape: tags are literal text, not patterns.
        # re.DOTALL: a chunk may span a line break; without it there was no
        # match and the [0] below raised IndexError.
        pattern = '%s(.*?)%s' % (re.escape(delimiter), re.escape(delimiters[idx + 1]))
        sentence = re.findall(pattern, javadoc, flags=re.DOTALL)[0]
        javadoc_sentences.append(delimiter + sentence)

        # Drop the consumed tag so the next search (possibly for the same tag,
        # e.g. a second @param) starts from its next occurrence.
        javadoc = javadoc.replace(delimiter, '', 1)

    return javadoc_sentences


In [None]:
# Quick fix: when a '<sep>' comment line is preceded by a blank line, drop all the context above it
def check4Context(item):
    """Trim leading context that is separated from a comment by a blank line.

    Looks for the first line (other than the first one) whose stripped text
    starts with '<sep>' and whose preceding line is empty. Every line before
    it is dropped. If there is no such line, the item is kept whole.

    Returns:
        The (possibly trimmed) item with its lines re-joined by '\\n', or None
        when the result starts and ends with '<sep>', i.e. only the comment
        is left and there is no surrounding code context.
    """
    lines = item.splitlines()

    first_kept = 0
    for id_line in range(1, len(lines)):
        if lines[id_line].strip().startswith('<sep>') and len(lines[id_line - 1]) == 0:
            # We have to tighten the context: keep only the comment onwards.
            first_kept = id_line
            break

    # Re-join with '\n': splitlines() removes the line breaks, and concatenating
    # the lines directly fused tokens across lines ("public static" + "void f()").
    refined_item = '\n'.join(lines[first_kept:])

    if refined_item.startswith('<sep>') and refined_item.endswith('<sep>'):
        return None
    return refined_item

In [None]:
# Load finetuning dataset


# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
with open('finetuning_single_comment.pickle', 'rb') as single_comment_file:
    data_finetune = pickle.load(single_comment_file)

# Keep only the single-comment instances that check4Context did not reject
# (it returns None when nothing but the comment itself is left).
ft_not_keep_comment = [
    refined for refined in map(check4Context, data_finetune) if refined is not None
]

with open('finetuning_javadoc.pickle', 'rb') as javadoc_file:
    ft_javadoc = pickle.load(javadoc_file)
ft_javadoc_list = list(ft_javadoc)

In [None]:
def flatten(string):
    """Collapse a (possibly multi-line) string onto a single line.

    Leading/trailing whitespace is removed and every internal run of
    whitespace, newlines included, becomes a single space.
    """
    # Raw string: '\s' inside a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12). \s+ already
    # matches '\n', so a separate newline replacement is unnecessary.
    return re.sub(r'\s+', ' ', string.strip())

In [None]:
def _split_dedup(instances, train_frac=0.8, test_frac=0.1):
    """Shuffle ``instances`` in place and split it into train/test/eval lists.

    Duplicates are removed inside each split. Cross-split overlaps are removed
    from the earlier split: train loses anything also in test or eval, and test
    loses anything also in eval, so no instance ends up in two splits.
    Order is preserved instead of using list(set(...)), whose order depends on
    string-hash randomisation and made the dataset non-reproducible for a
    fixed SEED.
    """
    random.shuffle(instances)

    train_len = round(len(instances) * train_frac)
    test_len = round(len(instances) * test_frac)

    # dict.fromkeys is an order-preserving de-duplication.
    train = list(dict.fromkeys(instances[:train_len]))
    test = list(dict.fromkeys(instances[train_len:train_len + test_len]))
    eval_ = list(dict.fromkeys(instances[train_len + test_len:]))

    test_set, eval_set = set(test), set(eval_)
    if test_set.intersection(train):
        print('Find overlapped in train and test, we are going to remove them')
    if eval_set.intersection(train):
        print('Find overlapped in train and eval, we are going to remove them')
    if eval_set.intersection(test):
        print('Find overlapped in test and eval, we are going to remove them')

    # The previous code called set.difference(), which returns a NEW set that
    # was discarded, so overlaps were never actually removed.
    train = [item for item in train if item not in test_set and item not in eval_set]
    test = [item for item in test if item not in eval_set]

    return train, test, eval_


def _dump_pickle(obj, path):
    """Write ``obj`` to ``path`` using the highest pickle protocol."""
    with open(path, 'wb') as fp:
        pickle.dump(obj, fp, protocol=pickle.HIGHEST_PROTOCOL)


random.seed(SEED)

####### JAVADOC SPLITTING ######
# The shuffle order (javadoc first, then single comment) is part of the
# reproducible RNG stream - keep it.
train_javadoc_instances, test_javadoc_instances, eval_javadoc_instances = _split_dedup(ft_javadoc_list)

_dump_pickle(train_javadoc_instances, 'javadoc_train_instances.pickle')
_dump_pickle(test_javadoc_instances, 'javadoc_test_instances.pickle')
_dump_pickle(eval_javadoc_instances, 'javadoc_eval_instances.pickle')

####### SINGLE COMMENT SPLITTING ######
train_single_comment_instances, test_single_comment_instances, eval_single_comment_instances = \
    _split_dedup(ft_not_keep_comment)

_dump_pickle(train_single_comment_instances, 'single_comment_train_instances.pickle')
_dump_pickle(test_single_comment_instances, 'single_comment_test_instances.pickle')
_dump_pickle(eval_single_comment_instances, 'single_comment_eval_instances.pickle')


In [None]:
def maskTokens(text, n_instances=5, mask_token=None):
    """Build up to ``n_instances`` (masked input, expected output) pairs from one sentence.

    The sentence is split on single spaces. For a randomly chosen cut position
    p, the input is the first p tokens followed by the mask token and the
    output is the remaining tokens, so the model must complete the sentence.
    At least one leading token is always kept as a hint. For '@param',
    '@throws' and '@exception' the tag and the following token (two tokens)
    are kept.

    Args:
        text: the sentence to mask.
        n_instances: maximum number of pairs to build. Values < 1 mean "no
            limit", matching the original loop, which then only stopped when
            positions ran out.
        mask_token: placeholder appended to the input; defaults to PLACEHOLDER_MASK.

    Returns:
        (masked_inputs, outputs) as two parallel lists, or None when the
        sentence has no position that can be masked.
    """
    if mask_token is None:
        mask_token = PLACEHOLDER_MASK

    tokens = text.split(' ')

    # Handling specific Javadoc tags: keep the tag and the token after it.
    first_position = 2 if tokens[0].strip() in ('@param', '@throws', '@exception') else 1
    choices = list(range(first_position, len(tokens)))
    if not choices:
        return None

    # Same shuffle + pop-from-the-end order as before, so the RNG stream and the
    # chosen positions are unchanged. Shuffling a one-element list draws no
    # random numbers. Two-token sentences now keep their first token as the
    # hint; previously the input was the bare mask and that token was lost.
    random.shuffle(choices)
    limit = n_instances if n_instances >= 1 else len(choices)
    positions = choices[::-1][:limit]

    masked_input = [' '.join(tokens[:pos]) + ' ' + mask_token for pos in positions]
    output = [' '.join(tokens[pos:]) for pos in positions]
    return (masked_input, output)

In [None]:
def getInstaces4Sample(comment, sample, javadoc=True):
    """Build masked (input, target) pairs for every sentence of one comment.

    The comment is split into sentences. For Javadoc, the free-text sentences
    come first, followed by one chunk per block tag from splitJavaDocByTag.
    For each sentence, maskTokens() proposes cut points. Every proposal yields:
    - one input: ``sample`` with the comment replaced by the preceding
      sentences plus the masked sentence;
    - one target: the masked-out tail of the sentence.

    Args:
        comment: the comment text exactly as it appears inside ``sample``.
        sample: the full instance containing the comment.
        javadoc: whether ``comment`` is a Javadoc comment.

    Returns:
        (inputs, targets) as two parallel lists.
    """
    all_samples_X = []
    all_samples_Y = []
    # Keep the text exactly as it occurs in the sample. The Javadoc branch
    # rewrites ``comment`` (whitespace collapsed, tags moved last), and that
    # rewritten text is often not a substring of ``sample`` (e.g. for multi-line
    # comments). Replacing it was then a no-op, and the unmasked comment - i.e.
    # the target - stayed in the input.
    original_comment = comment

    if javadoc:
        comment = re.sub(r"\s\s+", " ", comment)

        tag_chunks = splitJavaDocByTag(comment)
        for tag in tag_chunks:
            comment = comment.replace(tag, '')

        # Free-text sentences first, then the tag chunks.
        sentences = [item.strip() for item in sent_tokenize(comment)] + list(tag_chunks)
    else:
        sentences = sent_tokenize(comment)

    make_comment = ''
    for sent in sentences:
        sent = sent.strip()
        result = maskTokens(sent)

        if result is not None:
            for item_x, item_y in zip(result[0], result[1]):
                to_replace = (' ' + make_comment + ' ' + item_x + ' ').replace('\n', '')
                to_replace = re.sub(r"\s\s+", " ", to_replace)
                all_samples_X.append(sample.replace(original_comment, to_replace))
                all_samples_Y.append(item_y.replace('\n', ''))

        # A sentence too short to be masked is still context for the following
        # sentences. Previously it was skipped and silently vanished from the input.
        make_comment += sent + ' '

    return (all_samples_X, all_samples_Y)

In [None]:
# Since the cleaning process had to account for many different scenarios,
# we cannot be completely sure that every resulting sample is valid (handling corner cases is extremely challenging).
# Therefore, this function is the last sanity check before producing the final version of the dataset.

def checkSampleSanity(sample):
    """Return 1 if ``sample`` looks well-formed, otherwise 0.

    Well-formed means the '<sep>' markers come in pairs (even count) and the
    sample is at least 3 characters long.
    """
    has_paired_markers = sample.count('<sep>') % 2 == 0
    is_long_enough = len(sample) >= 3
    return 1 if (has_paired_markers and is_long_enough) else 0

In [None]:
#threshold  = 10000

# Loop over these 3 lists to create the final datasets version
#ft_javadoc_list
#ft_keep_comment
#ft_not_keep_comment

def _restore_link_placeholders(text):
    """Undo the FAKE_LINK / FAKE_REF substitution applied in createSetBySample."""
    return text.replace(FAKE_LINK, '{@link _LINK_}').replace(FAKE_REF, '{@link _REF_}')


def createSetBySample(task, split='train'):
    """Build the masked instances of one task/split, grouped by source sample.

    Args:
        task: 'javadoc', 'multi_comment', or anything else for single comments.
        split: 'train', 'eval', or anything else for test.

    Returns:
        dict mapping '<task>_<idx>' (idx = position of the sample in the split
        list) to {'input': [...], 'output': [...]}. Only samples containing at
        least one '<sep>...<sep>' comment get an entry.
    """
    # Select task list
    if task == 'javadoc':
        key_prefix, flag_javadoc = 'javadoc', True
        if split == 'train':
            task_list = train_javadoc_instances
        elif split == 'eval':
            task_list = eval_javadoc_instances
        else:
            task_list = test_javadoc_instances

    elif task == 'multi_comment':
        # NOTE(review): the *_multi_comment_instances lists are not built by the
        # splitting cell above - TODO confirm they exist before using this task.
        key_prefix, flag_javadoc = 'multi_comment', False
        if split == 'train':
            task_list = train_multi_comment_instances
        elif split == 'eval':
            task_list = eval_multi_comment_instances
        else:
            task_list = test_multi_comment_instances

    else:
        key_prefix, flag_javadoc = 'single_comment', False
        if split == 'train':
            task_list = train_single_comment_instances
        elif split == 'eval':
            task_list = eval_single_comment_instances
        else:
            task_list = test_single_comment_instances

    task_dictionary = {}

    for idx, sample in enumerate(tqdm(task_list)):

        # Protect the link placeholders while the comment is split and masked.
        sample = sample.replace('{@link _REF_}', FAKE_REF)
        sample = sample.replace('{@link _LINK_}', FAKE_LINK)

        comments = re.findall(r"<sep>([\s\S]*?)<sep>", sample)
        if task == 'javadoc':
            # Only the last comment is used for Javadoc. Slicing (instead of
            # [comments[-1]]) avoids an IndexError when no '<sep>' pair exists.
            comments = comments[-1:]

        if not comments:
            continue

        all_samples_X = []
        all_samples_Y = []
        for comment in comments:
            x, y = getInstaces4Sample(comment, sample, javadoc=flag_javadoc)
            for masked_input, label in zip(x, y):
                all_samples_X.append(_restore_link_placeholders(masked_input))
                all_samples_Y.append(_restore_link_placeholders(label))

        # One entry per sample, collecting the instances of ALL its comments.
        # Previously the entry was overwritten for each comment, so only the
        # last comment's instances survived.
        task_dictionary['%s_%s' % (key_prefix, idx)] = {'input': all_samples_X, 'output': all_samples_Y}

    return task_dictionary

In [None]:
SPLIT_NAMES = ('train', 'test', 'eval')

# Build every task/split in a fixed order: javadoc train/test/eval, then single
# comment train/test/eval. maskTokens draws from the global RNG, so the call
# order determines which instances are generated.
_javadoc_by_split = {}
for _split_name in SPLIT_NAMES:
    _javadoc_by_split[_split_name] = createSetBySample('javadoc', split=_split_name)

_single_comment_by_split = {}
for _split_name in SPLIT_NAMES:
    _single_comment_by_split[_split_name] = createSetBySample('single_comment', split=_split_name)

ft_javadoc_list_dict_train = _javadoc_by_split['train']
ft_javadoc_list_dict_test = _javadoc_by_split['test']
ft_javadoc_list_dict_eval = _javadoc_by_split['eval']

ft_single_comment_list_dict_train = _single_comment_by_split['train']
ft_single_comment_list_dict_test = _single_comment_by_split['test']
ft_single_comment_list_dict_eval = _single_comment_by_split['eval']


def _merge_task_dicts(javadoc_part, single_comment_part):
    """Merge two task dictionaries, asserting that no key is shared between them."""
    merged = dict(javadoc_part)
    merged.update(single_comment_part)
    assert len(merged) == len(javadoc_part) + len(single_comment_part)
    return merged


ft_multi_task_dict_train = _merge_task_dicts(ft_javadoc_list_dict_train, ft_single_comment_list_dict_train)
ft_multi_task_dict_test = _merge_task_dicts(ft_javadoc_list_dict_test, ft_single_comment_list_dict_test)
ft_multi_task_dict_eval = _merge_task_dicts(ft_javadoc_list_dict_eval, ft_single_comment_list_dict_eval)

In [None]:
ft_javadoc_list_dict = dict(train=ft_javadoc_list_dict_train,
                            test=ft_javadoc_list_dict_test,
                            eval=ft_javadoc_list_dict_eval)
ft_single_comment_list_dict = dict(train=ft_single_comment_list_dict_train,
                                   test=ft_single_comment_list_dict_test,
                                   eval=ft_single_comment_list_dict_eval)
ft_multi_task_dict = dict(train=ft_multi_task_dict_train,
                          test=ft_multi_task_dict_test,
                          eval=ft_multi_task_dict_eval)

# Saving task-filtered instances (one pickle per task, keyed by split).
for _pickle_path, _task_splits in (('javadoc_task_instances.pickle', ft_javadoc_list_dict),
                                   ('single_comment_task_instances.pickle', ft_single_comment_list_dict)):
    with open(_pickle_path, 'wb') as fp:
        pickle.dump(_task_splits, fp, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
# This function creates the multitask dataset for the finetuning
def _write_flattened_split(task_dict, split_name):
    """Write one split as parallel '<split>.source' / '<split>.target' files.

    Each instance becomes one line per file. The input is prefixed with the task
    prompt (JAVADOC_PREFIX for keys containing 'javadoc', BLOCK_COMMENT_PREFIX
    otherwise). Both sides are flattened and terminated by T5_EOS. Rows are
    shuffled with ``random_state=SEED``.
    """
    flatten_input = []
    flatten_output = []

    for key, instances in task_dict.items():
        target_prefix = JAVADOC_PREFIX if 'javadoc' in key else BLOCK_COMMENT_PREFIX
        for masked_input, expected in zip(instances['input'], instances['output']):
            flatten_input.append(target_prefix + flatten(masked_input) + T5_EOS + '\n')
            flatten_output.append(flatten(expected) + T5_EOS + '\n')

    df_split = pd.DataFrame(list(zip(flatten_input, flatten_output)), columns=['input', 'output'])
    df_split = df_split.sample(frac=1, random_state=SEED)

    # 'w' instead of 'a+': appending duplicated the whole split every time the
    # notebook was re-run. `with` also closes the files if a write fails.
    with open('%s.source' % split_name, 'w') as source_file, \
         open('%s.target' % split_name, 'w') as target_file:
        source_file.writelines(df_split['input'])
        target_file.writelines(df_split['output'])


# This function creates the multitask dataset for the finetuning
def writeFlattenedDataset():
    """Write the multitask train/eval/test .source/.target files used for finetuning."""
    _write_flattened_split(ft_multi_task_dict_train, 'train')
    _write_flattened_split(ft_multi_task_dict_eval, 'eval')
    _write_flattened_split(ft_multi_task_dict_test, 'test')


writeFlattenedDataset()

In [None]:
def checkForFakeJavadoc(item):
    """Return True when ``item`` contains none of the tags in stop_word_java_tag.

    Matching is a plain substring test, so a tag anywhere in ``item`` counts.
    """
    return not any(tag in item for tag in stop_word_java_tag)

In [None]:
#### Final check, since we may still have some duplicates
#### Going to apply set() on the flattened version of the instances

####### DUPLICATES CHECK STARTS HERE #########

def _read_stripped_lines(path):
    """Return every line of ``path`` with surrounding whitespace stripped."""
    with open(path) as fread:
        return [line.strip() for line in fread]


def _join_pairs(inputs, outputs):
    """Join parallel input/target lines into '<input><COLLEGAMENTO><target>' strings.

    The joined form lets whole (input, target) pairs be compared with set
    operations in the duplicate check.
    """
    # zip() would silently drop the tail of the longer list; the files must be parallel.
    assert len(inputs) == len(outputs), 'source/target files have different lengths'
    return ['{}<COLLEGAMENTO>{}'.format(x.strip(), y.strip()) for x, y in zip(inputs, outputs)]


eval_list_input = _read_stripped_lines('eval.source')
eval_list_output = _read_stripped_lines('eval.target')
assert(len(eval_list_output) == len(eval_list_input))

train_list_input = _read_stripped_lines('train.source')
train_list_output = _read_stripped_lines('train.target')

test_list_input = _read_stripped_lines('test.source')
test_list_output = _read_stripped_lines('test.target')

joined_train = _join_pairs(train_list_input, train_list_output)
joined_test = _join_pairs(test_list_input, test_list_output)
joined_eval = _join_pairs(eval_list_input, eval_list_output)

set_train = set(joined_train)
set_test = set(joined_test)
set_eval = set(joined_eval)


In [None]:
def _write_pairs(pairs, source_path, target_path):
    """Split '<input><COLLEGAMENTO><target>' strings and write them as parallel files."""
    # 'w' instead of 'a+' so that re-running the cell does not duplicate data.
    with open(source_path, 'w') as source_file, open(target_path, 'w') as target_file:
        for item in pairs:
            input_x, _, label = item.partition('<COLLEGAMENTO>')
            source_file.write(input_x + '\n')
            target_file.write(label + '\n')


# Drop from train every pair also present in test or eval, and from test every
# pair also present in eval. This is order-preserving: the (SEED-shuffled) file
# order is kept instead of the hash-dependent order of list(set(...)).
no_dup_train = [item for item in dict.fromkeys(joined_train) if item not in set_test and item not in set_eval]
no_dup_test = [item for item in dict.fromkeys(joined_test) if item not in set_eval]

# Always write the de-duplicated splits. Previously they were written only when
# something had just been removed (`if len(no_dup_train) < len(set_train)`), so a
# split with no eval overlap produced EMPTY *_new files.
_write_pairs(no_dup_train, 'train_new.source', 'train_new.target')
_write_pairs(no_dup_test, 'test_new.source', 'test_new.target')

train_set = set(no_dup_train)
test_set = set(no_dup_test)
duplicated = train_set.intersection(test_set)

assert(len(duplicated) == 0)

######### DUPLICATES CHECK ENDS HERE #########

In [None]:
# HERE WE FURTHER DISCARD BROKEN INSTANCES THAT MAY HAVE ENDED UP IN THE BLOCK/INLINE-COMMENT TASK DATASET.
# THIS HAPPENS WHEN A JAVA METHOD DEFINES OTHER JAVA METHODS WITHIN ITSELF.
# SPECIFICALLY, THE CODESEARCHNET DATASET ONLY REPORTS THE TOP-LEVEL DOCSTRING, THEREFORE WE HAVE TO MANUALLY MANAGE SUCH CASES.

def _filter_fake_javadoc(source_in, target_in, source_out, target_out):
    """Copy a parallel source/target pair, dropping block-comment rows that carry Javadoc tags.

    A row whose input contains 'complete block/inline comment' is kept only if
    checkForFakeJavadoc accepts both its input and its target. All other rows
    are copied unchanged. Lines are written stripped and newline-terminated.
    """
    # 'w' instead of 'a+' so that re-running the cell does not duplicate data;
    # `with` closes all four files even if something fails midway.
    with open(source_in, 'r') as src_in, open(target_in, 'r') as tgt_in, \
         open(source_out, 'w') as src_out, open(target_out, 'w') as tgt_out:
        for masked_input, target in zip(src_in.readlines(), tgt_in.readlines()):
            masked_input = masked_input.strip()
            target = target.strip()

            if 'complete block/inline comment' in masked_input:
                if not (checkForFakeJavadoc(masked_input) and checkForFakeJavadoc(target)):
                    continue

            src_out.write(masked_input + '\n')
            tgt_out.write(target + '\n')


_filter_fake_javadoc('train_new.source', 'train_new.target', 'train_final.source', 'train_final.target')
_filter_fake_javadoc('test_new.source', 'test_new.target', 'test_final.source', 'test_final.target')
# NOTE(review): eval is read from eval.*, not from a de-duplicated *_new file - confirm this is intended.
_filter_fake_javadoc('eval.source', 'eval.target', 'eval_final.source', 'eval_final.target')

In [None]:
# Uncomment this cell only when building the pre-training dataset (it expects a `pretrain` iterable to be defined first)

# pretrain_flatten = open('pretrain_dataset.txt','a+')

# #Here we use the flatten function to flatten the pretraining set as well
# for item  in pretrain:
#     pretrain_flatten.write(flatten(item)+'\n')
#
#
# pretrain_flatten.close()

