<a href="https://colab.research.google.com/github/amirpaia/blenderbot/blob/main/french_reddit_data_preparation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0. Installation & Download the dataset

Upload the kaggle.json file

In [1]:
# download from kaggle
! pip install kaggle
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
! kaggle datasets download breandan/french-reddit-discussion
! unzip french-reddit-discussion.zip

Downloading french-reddit-discussion.zip to /content
 97% 415M/426M [00:03<00:00, 108MB/s]
100% 426M/426M [00:04<00:00, 110MB/s]
Archive:  french-reddit-discussion.zip
  inflating: final_SPF_2.xml         
  inflating: spf.tar.gz              


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# 1. Loading the dataset

In [4]:
import time 
import lxml.etree as ET
import pandas as pd
import numpy as np

file_path = 'final_SPF_2.xml'
start=time.time()
#Initializes the parser
parser = ET.XMLParser(recover=True)
#Parses the file
tree = ET.parse(file_path, parser=parser)
xroot = tree.getroot()

#One conversation -> one line in the data array
dfcols = ['link_id', 'subreddit_id', 'uid',"comment_id",'score', 'parent_id', 'create_utc', 'text']
data =np.array(
    (
        [
            [
                [
                    node.attrib.get('link_id'),
                    node.attrib.get('subreddit_id'), 
                    node.getchildren()[j].get('uid'), 
                    node.getchildren()[j].get('comment_id'), 
                    node.getchildren()[j].get('score'), 
                    node.getchildren()[j].get('parent_id'), 
                    node.getchildren()[j].get('create_utc'),
                    node.getchildren()[j].text
                ] 
                for j in range(len(node.getchildren()))
            ] 
            for node in xroot
        ]
     ), 
    dtype=object)

print('number of conversations: ',data.shape[0])

#one comments -> one line in the data array
data=np.array([liste for conversation in data for liste in conversation], dtype=object)
print('number of comments: ',data.shape[0])

df_xml = pd.DataFrame(data=data, columns=dfcols)
print('All done in :',time.time()-start,' seconds')

df_xml.head()

number of conversations:  556622
number of comments:  1583083
All done in : 26.908819913864136  seconds


Unnamed: 0,link_id,subreddit_id,uid,comment_id,score,parent_id,create_utc,text
0,8r1kz,2qhjz,1688932,c0a62uj,3,8r1kz,1244576002,Ironie : l'article disant qu'on est plus capab...
1,8r1kz,2qhjz,786883,c0a6lmb,1,c0a62uj,1244621120,"Moi-même, j'ai dû me forcer pour arriver jusqu..."
2,8sncs,2qhjz,390497,c0aawpk,1,8sncs,1245076061,Service qui sera rendu au contribuable pour la...
3,8sncs,2qhjz,32884,c0aaxba,3,c0aawpk,1245077396,Eeeeh oui ! 70 millions pour une loi qui aura ...
4,8v13c,2qhjz,796919,c0aj3ov,2,8v13c,1245830384,Est-ce qu'elle a vraiment commis des crimes qu...


We assumed every thread in Reddit as a tree and this recursive functions tries to find all paths in a tree and consider each one as a dialog

In [None]:
# a recusive function to get all children
def get_children(i, comment_id, comment_ids = []):
    children = df_xml[df_xml['parent_id'] == comment_id]
    if (len(children) == 0): 
        return comment_ids
    else:
        for child in children.values:
            temp_comment_ids = []
            temp_comment_ids.extend(comment_ids)
            temp_comment_ids.append(child[3])
            
            path = get_children(i, child[3],temp_comment_ids)
            if (path != None):
                all_dialogs.append(f"{i}\t{path}\n")
        

all_dialogs = []
get_children(0, '9gw99')
print(all_dialogs)
# print(get_children('c0imipv')) # root: c0imipv  | c0impi6

["0\t['c0cq7vy', 'c0cqynp', 'c0cqyze', 'c0cr17m', 'c0cr1hs']\n"]


Finding all paths and save Ids in the output file

In [None]:
data_path = '/content/drive/MyDrive/colabs/aliae-workspace/datasets/'

import os.path
import time
print(len(df_xml[df_xml['link_id'] == df_xml['parent_id']].values))

all_dialogs = []
starttime = time.time()
last_index_in_file = 0
if os.path.exists(f"{data_path}french_reddit_all_dialog_turns_ids.txt"):
  with open(f"{data_path}french_reddit_all_dialog_turns_ids.txt") as f:
    last_line = f.readlines()[-1]

  if (last_line != "" or last_line != None):
    last_index_in_file = int(last_line.split("\t")[0]) + 1
    print("last_index_in_file:" , last_index_in_file)

for i, row in enumerate(df_xml[df_xml['link_id'] == df_xml['parent_id']].values[last_index_in_file:]):
    get_children(last_index_in_file + i, row[5])

    if (last_index_in_file + i) % 20 == 0 : 
        with open(f"{data_path}french_reddit_all_dialog_turns_ids.txt", 'a') as f:
          f.writelines(all_dialogs)
          print(f"i:{i}", f"last index:{(last_index_in_file + i)}", f"dialogs: {len(all_dialogs)}", f"time: {round(time.time() - starttime, 2)}")
          all_dialogs = []
        # starttime = time.time()
    
    if (last_index_in_file + i) % 50000 == 0: break

print(len(all_dialogs))

# 2. Reddit Ids > Dialog > ParlAI format > Train/Valid/Test

## General Functions

In [5]:
def transfer_dialog(d):
    if len(d)%2 !=0: d = d[:-1]
    t = ""
    for i in range(0,len(d),2):
        u1 = d[i]
        u2 = d[i+1]

        if (i+2) != len(d):
            t += "text:"+u1+"\t"+"labels:"+u2+"\n"
        else:
            t += "text:"+u1+"\t"+"labels:"+u2+"\t"+"episode_done:True"+"\n"
    return t
    
t = ['hello','how are you','good','bye', 'to be removed']
print(transfer_dialog(t))

def colored(r, g, b, text):
    return "\033[38;2;{};{};{}m{} \033[38;2;255;255;255m".format(r, g, b, text)
  
text = 'Hello, World'
colored_text = colored(0, 255, 0, text)
print(colored_text)

def convert_parlai_format_to_list_of_turns(lines):
    result = []
    for line in lines:
        text_label = line.split("\t")
        if len(text_label) >=2:
            result.append(text_label[0].replace("text:", ""))
            result.append(text_label[1].replace("labels:", "").replace("\n",""))

            print(len(text_label[0].split()), colored(200,200,200, text_label[0]))
            print("    ",len(text_label[1].split()) , colored(0,150,0, text_label[1]))
    return result

text:hello	labels:how are you
text:good	labels:bye	episode_done:True

[38;2;0;255;0mHello, World [38;2;255;255;255m


In [10]:
def convert_line_to_list_of_ids(line):
    return line.split("\t")[1].replace("[","").replace("]","").replace("\n","").replace(" ","").replace("'","").split(",")

In [7]:
import re
def remove_urls (vTEXT):
    vTEXT = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '', vTEXT, flags=re.MULTILINE)
    return(vTEXT)
    
print( remove_urls("this is a test https://sdfs.sdfsdf.com/sdfsdf/sdfsdf/sd/sdfsdfs?bob=%20tree&jef=man lets see this too https://sdfsdf.fdf.com/sdf/f end"))

this is a test  lets see this too  end


In [8]:
data_path = '/content/drive/MyDrive/colabs/aliae-workspace/datasets/'
with open(f"{data_path}french_reddit_all_dialog_turns_ids.txt") as f:
    lines = f.readlines()

In [12]:
pd.set_option('display.max_colwidth', None)
# print(len(lines))
# print(sum([1 for l in lines if len(convert_line_to_list_of_ids(l)) == 10]))
print(len(lines))

number_of_all_turns = 0

flag = False
dict_xml = df_xml.set_index('comment_id').to_dict()['text']
dialogs_in_parlai_format = []
for index, line in enumerate(lines):
    ids = convert_line_to_list_of_ids(line)

    if len(ids) >= 6: 
    #     print(index)
    #     break
        turns = []
        for ii, id in enumerate(ids):
            text = dict_xml[id].replace("\n", "").replace("\t", "")
            text = remove_urls(text)
            if len(text.split()) > 128:
                # flag = True
                turns.append(' '.join(text.split()[:128]))
                break
            else:
                turns.append(text)
        if len(turns) >= 6: 
            # print("\n",line, ii)
            # print("\n",text)
            # print("\n",turns)
            # break
            dialogs_in_parlai_format.append([transfer_dialog(turns),len(turns)])

    if index > 0 and index % 100000 == 0: print(f"index: {index}")

973124

 89	['c0taiz9', 'c0tapmq', 'c0tas0l', 'c0tatby', 'c0taw1u', 'c0tax6a', 'c0tayek']
 2

 "Quant à l'épithète "enculé", elle stigmatise la passivité de l'homme bestialisé" Pratiquer de l'intellectualisme aussi pédant, m'hérisse les poils. Ce genre de sophistication linguistique est pour moi de la poudre aux yeux quand on a rien à dire d'intéressant. Les insultes sont le pain quotidien de beaucoup de jeunes, ce ne sont plus des insultes mais des constructions élaborées à la sémantique nuancée et cachée. J'adore traiter mes amis de gros enculés pour un rien, c'est un acte d'amour. Ce mec n'a rien compris à la beauté de l'argot. Je dis des choses bien plus blessantes avec un registre soutenu qu'avec des insultes. Et puis, il n'y a pas de débat à avoir. Subitement l'opinion francaise s'offusque de gros mots alors que les politiques nous chient dans les bottes à longueur de journées, mais bien sûr avec verve. Encore un cache misère du désert intellectuel Français. 

 ["Cet article pue 

In [22]:
print(len(dialogs_in_parlai_format))
import statistics
# [a for a in dialogs_in_parlai_format][:5]
print(statistics.mean([a[1] for a in dialogs_in_parlai_format]))
print(statistics.mean([len(a[0].split()) for a in dialogs_in_parlai_format]))

121082
7.282841380221668
233.61439355147752


In [20]:
# dialogs_in_parlai_format[1]
statistics.mean([len(a.split()) for a in convert_parlai_format_to_list_of_turns(dialogs_in_parlai_format[1][0].split("\n"))])


100 [38;2;200;200;200mtext:Nice, ou Sophia-Antipolis? C'est pas la même chose! Je pose la question parce qu'on est sur reddit, donc il y a des chances que ce soit la première option :)Pour la banque je ne sais pas trop, pour le mobile ça dépend de ta consomation (il faut nous en dire un peu plus). Enfin pour les assurances, la sécu couvre la plupart des frais, tu peux y rajouter une bonne mutuelle (vérifie que ton employeur ne t'en fournit pas une gratuitement).Ton français n'est pas si mauvais! Si tu ne comprends pas quelque chose, n'hésite pas à demander, je peux traduire. [38;2;255;255;255m
     49 [38;2;0;150;0mlabels:Nice, je suis etudiant a edhec, pour le mobile je besoin quel que chose qui avais le data, qu'est ce que c'est la secu?, (Je dois trouver les accents sur se l'ordinateur)Je besoin becoup de practique avec mon francais, est-ce que il y a de bonne shopping a Nice? [38;2;255;255;255m
8 [38;2;200;200;200mtext:un conseil, n'habite pas à coté de l'Edhec. [38;2;255;255

37.166666666666664

In [23]:
data_path = '/content/drive/MyDrive/colabs/aliae-workspace/datasets/french_reddit/'
# with open(data_path + "train.txt", "w") as f:
#     f.writelines(dialogs_in_parlai_forma

df = pd.DataFrame ([a[0] for a in dialogs_in_parlai_format], columns = ['dialog'])
train, valid, test = np.split(df.sample(frac=1, random_state=42), 
                                 [int(.8*len(df)), 
                                  int(.9*len(df))])
print(f"train set: {len(train)}, validation set: {len(valid)},test set: {len(test)}")

with open(f"{data_path}/data_train.txt","w") as f:
    f.write('\n'.join(a[0] for a in train.values))

with open(f"{data_path}/data_valid.txt","w") as f:
    f.write('\n'.join(a[0] for a in valid.values))

with open(f"{data_path}/data_test.txt","w") as f:
    f.write('\n'.join(a[0] for a in test.values))
print('done!')

train set: 96865, validation set: 12108,test set: 12109
done!
