In [28]:
import numpy as np

In [29]:
class WordSequence(object):
    """Map whitespace-separated words to integer ids and back.

    Ids 0-3 are reserved for the special tags declared below; ``fit`` builds
    the vocabulary from a corpus and gives every other kept word an id >= 4.
    Words outside the vocabulary map to ``UNK`` / ``UNK_TAG``.
    """

    PAD_TAG = '<pad>'
    UNK_TAG = '<unk>'
    START_TAG = '<s>'
    END_TAG = '</s>'

    PAD = 0
    UNK = 1
    START = 2
    END = 3

    def __init__(self):
        self.dict = WordSequence._special_tokens()
        # Reverse (id -> word) map; built lazily by ``to_word``, reset by ``fit``.
        self._index_to_word = None
        self.fited = False

    @staticmethod
    def _special_tokens():
        """Return a fresh word -> id dict holding only the four reserved tags."""
        return {
            WordSequence.PAD_TAG: WordSequence.PAD,
            WordSequence.UNK_TAG: WordSequence.UNK,
            WordSequence.START_TAG: WordSequence.START,
            WordSequence.END_TAG: WordSequence.END,
        }

    @staticmethod
    def _tokenize(sentence):
        """Split ``sentence`` on runs of whitespace, never yielding empty tokens."""
        return sentence.split()

    def to_index(self, word):
        """Return the id of ``word``, or ``UNK`` if it is not in the vocabulary."""
        assert self.fited, "Please fit the WordSequence instance first."
        return self.dict.get(word, WordSequence.UNK)

    def to_word(self, index):
        """Return the word whose id is ``index``, or ``UNK_TAG`` if there is none."""
        assert self.fited, "Please fit the WordSequence instance first."
        # O(1) reverse lookup instead of scanning the vocabulary on every call.
        # getattr tolerates instances that lack the attribute (e.g. older pickles).
        reverse = getattr(self, "_index_to_word", None)
        if reverse is None:
            reverse = {idx: word for word, idx in self.dict.items()}
            self._index_to_word = reverse
        return reverse.get(index, WordSequence.UNK_TAG)

    def size(self):
        """Return the reported vocabulary size: number of mapped tokens plus one.

        NOTE(review): the +1 is preserved from the original implementation;
        confirm whether downstream consumers rely on it.
        """
        assert self.fited, "Please fit the WordSequence instance first."
        return len(self.dict) + 1

    def __len__(self):
        return self.size()

    def fit(self, sentences, min_count=5, max_count=None, max_feature=None):
        """Build the vocabulary from ``sentences`` (an iterable of strings).

        Args:
            sentences: iterable of sentences; each is split on whitespace.
            min_count: drop words seen fewer than this many times (None: no limit).
            max_count: drop words seen more than this many times (None: no limit).
            max_feature: if an int, keep only the ``max_feature`` most frequent
                words (ties broken alphabetically) and assign ids most-frequent
                first; otherwise keep every surviving word, ids in sorted order.

        Words equal to a reserved tag keep their fixed id and are never re-added.

        Raises:
            AssertionError: if the instance has already been fitted.

        TODO: consider tf-idf instead of raw counts to decide which words to drop.
        """
        assert not self.fited, "You can only fit the instance once."

        reserved = WordSequence._special_tokens()
        count = {}
        for sentence in sentences:
            for word in self._tokenize(sentence):
                count[word] = count.get(word, 0) + 1

        # Reserved tags must keep ids 0-3; never let the corpus reassign them.
        count = {w: c for w, c in count.items() if w not in reserved}
        if min_count is not None:
            count = {w: c for w, c in count.items() if c >= min_count}
        if max_count is not None:
            count = {w: c for w, c in count.items() if c <= max_count}

        if isinstance(max_feature, int):
            ranked = sorted(count.items(), key=lambda item: (-item[1], item[0]))
            # max(..., 0): a zero/negative limit keeps no words (not all of them).
            vocab = [w for w, _ in ranked[:max(max_feature, 0)]]
        else:
            vocab = sorted(count)

        self.dict = reserved
        for word in vocab:
            self.dict[word] = len(self.dict)
        self._index_to_word = None
        self.fited = True

    def transform(self, sentence, max_len=None):
        """Encode a sentence string as a 1-D integer numpy array of ids.

        If ``max_len`` is given, the result is truncated or right-padded with
        ``PAD`` to exactly ``max_len`` entries (a negative value yields an
        empty array).
        """
        assert self.fited, "Please fit the WordSequence instance first."
        ids = [self.to_index(word) for word in self._tokenize(sentence)]
        if max_len is not None:
            max_len = max(max_len, 0)
            ids = ids[:max_len] + [WordSequence.PAD] * (max_len - len(ids))
        # Explicit dtype so an empty result is still an integer array.
        return np.array(ids, dtype=int)

    def inverse_transform(self, indices,
                          ignore_pad=False, ignore_unk=False,
                          ignore_start=False, igonre_end=False,
                          ignore_end=False):
        """Decode an iterable of ids into a list of words.

        The ``ignore_*`` flags drop the corresponding special tag from the
        output. ``igonre_end`` is the original (misspelled) flag and is kept
        for backward compatibility; ``ignore_end`` is an equivalent alias.
        """
        skipped = set()
        if ignore_pad:
            skipped.add(WordSequence.PAD_TAG)
        if ignore_unk:
            skipped.add(WordSequence.UNK_TAG)
        if ignore_start:
            skipped.add(WordSequence.START_TAG)
        if igonre_end or ignore_end:
            skipped.add(WordSequence.END_TAG)

        words = (self.to_word(i) for i in indices)
        return [word for word in words if word not in skipped]
    
    

In [12]:
# Scratch check: strip surrounding whitespace, then split on single spaces.
string = "Hello World"
a = string.strip().split(sep=" ")
a

['Hello', 'World']

In [36]:
def test():
    """Smoke-test WordSequence: fit a tiny corpus, encode a sentence, decode it.

    Prints the ids and decoded words, then asserts them against the expected
    values (the notebook's recorded output), so a regression fails loudly
    instead of only changing what gets printed.
    """
    ws = WordSequence()
    ws.fit([
        "Hello World! This is a sample sentence",
        "This is a test sentence, yeah!"], min_count=0)

    # Only "Hello" and "is" occur in the fitted corpus; the rest are unseen.
    indice = ws.transform("Hello GWU, the result is not defined")
    print(indice)
    assert indice.tolist() == [4, 1, 1, 1, 8, 1, 1], indice

    back = ws.inverse_transform(indice)
    print(back)
    assert back == ['Hello', '<unk>', '<unk>', '<unk>', 'is', '<unk>', '<unk>'], back

In [37]:
# Run the demo; inside a Jupyter/IPython kernel __name__ is also "__main__".
if __name__ == "__main__":
    test()

[4 1 1 1 8 1 1]
['Hello', '<unk>', '<unk>', '<unk>', 'is', '<unk>', '<unk>']
