In [1]:
import pandas as pd
import os
import re
import numpy as np
import random
from collections import Counter

In [5]:
class data_loader:
    '''
    Load a tab-separated "<tag><TAB><message>" corpus and embed it with
    pre-trained word vectors read from a GloVe-style text file.

    data members:
        self.max_len            number of word slots kept per message
        self.tags               raw tag strings, one per message
        self.messages           cleaned message strings
        self.word_voca          distinct words, in first-seen order
        self.char_voca          distinct characters of the words
        self.word_dic           word -> row index into self.embeddings
        self.embeddings         array of shape (len(word_voca), dimension)
        self.embedded_messages  array of shape (n_messages, max_len, dimension)
        self.embedded_tags      list of ints, 1 for "spam" and 0 otherwise
    '''

    def __init__(self,path,embedding_path = os.getcwd() + "/data/glove.840B.300d.txt",lower = False,max_len = 128):
        '''
        Read the corpus at `path`, build the vocabularies, load the word
        vectors from `embedding_path` and embed every message and tag.

        path:           text file with one "<tag><TAB><message>" record per line
        embedding_path: text file with one "<word> <v1> ... <vn>" record per line
        lower:          lowercase the messages after cleaning when True
        max_len:        number of words kept per message (longer messages are
                        truncated, shorter ones are zero padded)

        Raises ValueError when a non-blank line has no tab separator.
        '''
        self.max_len = max_len
        self.tags = []
        self.messages = []
        with open(path, encoding="utf-8") as f:
            for line_number, line in enumerate(f, start=1):
                line = line.rstrip("\r\n")
                if not line.strip():
                    # blank lines (e.g. a trailing newline) carry no record
                    continue
                # split on the FIRST tab only, so a tab inside the message
                # text does not break the two-value unpacking
                tag, separator, message = line.partition("\t")
                if not separator:
                    raise ValueError(
                        "line %d of %s has no tab separator" % (line_number, path))
                self.tags.append(tag.strip())
                self.messages.append(self.__clean_line(message,lower))
        self.__build_voc()
        self.__build_vector_dict(embedding_path)
        self.__convert_message()
        self.__convert_tag()
        print("init finished")

    def __build_voc(self):
        '''
        Extract the word vocabulary from the cleaned messages and the
        character vocabulary from those words (both in first-seen order).
        '''
        counter_words = Counter()
        for message in self.messages:
            counter_words.update(message.split())
        self.word_voca = list(counter_words)

        counter_chars = Counter()
        for word in self.word_voca:
            counter_chars.update(word)
        self.char_voca = list(counter_chars)

    def __clean_line(self, line,lower):
        '''
        Normalise one message: drop non-ASCII characters, surround ? . ! ,
        with spaces so they become separate tokens, and replace every other
        non-letter run with a single space. Lowercase the result if `lower`.
        '''
        line = line.encode('ascii',errors='ignore').decode()
        # NOTE: the inverted question mark in the patterns can never match,
        # because non-ASCII characters were already stripped above.
        line = re.sub(r"([?.!,¿])", r" \1 ", line)
        # collapses runs of spaces and double quotes into one space
        line = re.sub(r'[" "]+', " ", line)
        line = re.sub(r"[^a-zA-Z?.!,¿]+", " ", line)
        if lower:
            line = line.lower()
        return line

    def __build_vector_dict(self, embedding_path):
        '''
        Build self.word_dic (word -> row index) and self.embeddings, whose
        row i holds the vector of self.word_voca[i]; words that are absent
        from the embedding file keep an all-zero row.

        The vector dimension is taken from the first non-blank line of the
        file, and lines with a different number of fields are skipped.

        Raises ValueError when the file contains no usable vector line.
        '''
        self.word_dic = {word: index for index, word in enumerate(self.word_voca)}

        print("Start to build word embedding vectors, this may take a while...")
        dimension = None
        embeddings = None
        with open(embedding_path, encoding="utf-8") as f:
            # stream the file line by line: embedding files can be several
            # gigabytes, so they must not be materialised with readlines()
            for line in f:
                fields = line.strip().split()
                if not fields:
                    continue
                if dimension is None:
                    dimension = len(fields) - 1
                    if dimension < 1:
                        raise ValueError(
                            "cannot infer the vector dimension from " + embedding_path)
                    embeddings = np.zeros((len(self.word_voca),dimension))
                if len(fields) != dimension + 1:
                    continue
                index = self.word_dic.get(fields[0])
                if index is not None:
                    embeddings[index] = np.asarray(fields[1:], dtype=float)
        if embeddings is None:
            raise ValueError("no word vectors found in " + embedding_path)
        print("Embedding file loaded")
        self.embeddings = embeddings

    def data_split(self, ration = 0.9, random = False, embedding = True):
        '''
        Split the data set into a leading train part and a trailing test part.

        ration:    fraction of the samples that goes into the train part
                   (parameter name kept as-is for backward compatibility)
        random:    shuffle tags and messages together before splitting
        embedding: split the embedded tags/messages when True, the raw
                   strings otherwise

        Returns (train_message, train_tag, test_message, test_tag) as numpy
        arrays, or None (after printing a notice) when the number of tags
        and messages differ.

        Raises ValueError when `ration` is outside [0, 1].
        '''
        if not 0 <= ration <= 1:
            raise ValueError("ration must be between 0 and 1")
        if embedding:
            tags = self.embedded_tags
            messages = self.embedded_messages
        else:
            tags = self.tags
            messages = self.messages

        if len(tags) != len(messages):
            print("The number of tags doesn't equal to the number of messages, please check the file")
            return
        # `random` is the boolean flag here; the module of the same name is
        # only used inside __shuffle_data
        if random:
            tags,messages = self.__shuffle_data(tags,messages)
        split_at = int(len(messages)*ration)
        train_message = np.array(messages[:split_at])
        train_tag = np.array(tags[:split_at])
        test_message = np.array(messages[split_at:])
        test_tag = np.array(tags[split_at:])

        return train_message, train_tag, test_message, test_tag

    def __shuffle_data(self,tags,messages):
        '''
        Return shuffled copies of `tags` and `messages`, applying the same
        random permutation to both so that pairs stay aligned. Each input
        may independently be a list or a numpy array.
        '''
        order = list(range(len(tags)))
        random.shuffle(order)

        def reorder(sequence):
            if isinstance(sequence, np.ndarray):
                return sequence[order]
            return [sequence[i] for i in order]

        return reorder(tags), reorder(messages)

    def __convert_message(self):
        '''
        Build self.embedded_messages: for every message, the vectors of its
        first self.max_len words, zero padded up to self.max_len rows.
        '''
        dimension = self.embeddings.shape[1]
        converted_message = np.zeros((len(self.messages),self.max_len,dimension))
        for mess_index,message in enumerate(self.messages):
            words = message.split()[:self.max_len]
            if not words:
                continue
            rows = [self.word_dic[word] for word in words]
            converted_message[mess_index,:len(rows)] = self.embeddings[rows]
        self.embedded_messages = converted_message

    def __convert_tag(self):
        '''
        Build self.embedded_tags: 1 for the tag "spam", 0 for anything else.
        '''
        self.embedded_tags = [int(tag == "spam") for tag in self.tags]

In [3]:
#test = data_loader(os.getcwd() + "/data/SMSSpamCollection",lower=True)
#print(test.tags)
#print(test.char_voca)
#test.__build_voc()