In [132]:
from collections import Counter
from typing import List, Dict, Union, Tuple
from functools import reduce
import torch


In [2]:
### Small test fixtures - borrowed from CS224N 2018-19 Homework 5
# Two tiny corpora used by the demo cells: one with single-character style
# tokens, one with a short human/computer dialogue.
sentences_small = [
    ["a", "b", "c?"],
    ["~d~", "c", "b", "a"],
]
sentences_words = [
    ["Human:", "What", "do", "we", "want?"],
    ["Computer:", "Natural", "language", "processing!"],
    ["Human:", "When", "do", "we", "want", "it?"],
    ["Computer:", "When", "do", "we", "want", "what?"],
]

*Create a class that stores either the source or the target vocabulary, and a second class that simply holds both of those vocabularies together.*

In [3]:
## Borrowed from d2l.ai http://d2l.ai/chapter_recurrent-neural-networks/text-preprocessing.html#vocabulary

def count_corpus(tokens: Union[List[str], List[List[str]]],
                 sort: bool = True) -> List[tuple]:
    """Count how often each token occurs in a corpus.

    Args:
        tokens: Either a flat list of tokens or a list of tokenized lines
            (a list of lists of tokens). An empty list is accepted.
        sort: When True (the default) the pairs are ordered by descending
            count, with ties broken by the token's ascending string order.
            When False the pairs are returned in first-occurrence order.

    Returns:
        A list of ``(token, count)`` tuples.
    """
    if len(tokens) == 0 or isinstance(tokens[0], list):
        tokens = [token for line in tokens for token in line]  ## Flatten 2D list into 1D list
    counts = Counter(tokens)
    if not sort:
        # Counter keeps insertion order, i.e. the order tokens were first seen.
        return list(counts.items())
    # One composite key gives the same order as sorting by token and then
    # stably re-sorting by count descending.
    return sorted(counts.items(), key=lambda item: (-item[1], item[0]))

In [4]:
print(count_corpus(sentences_words))

[('do', 3), ('we', 3), ('Computer:', 2), ('Human:', 2), ('When', 2), ('want', 2), ('Natural', 1), ('What', 1), ('it?', 1), ('language', 1), ('processing!', 1), ('want?', 1), ('what?', 1)]


In [38]:
def pad_sents(sents: List[List[int]], pad_token: int) -> List[List[int]]:
    """Right-pad every sentence to the length of the longest one.

    Args:
        sents: Batch of sentences, each a list of token ids.
        pad_token: Id appended to the end of shorter sentences.

    Returns:
        A new list of new lists, all of equal length; the inputs are not
        modified. An empty batch yields an empty list.
    """
    # default=0 keeps an empty batch from raising.
    max_sent_length = max(map(len, sents), default=0)
    return [sent + [pad_token] * (max_sent_length - len(sent)) for sent in sents]


In [117]:
def pad_sents_char(sents: List[List[List[int]]], pad_token: int, max_word_length: int = 21) -> List[List[List[int]]]:
    """Pad a batch of character-id sentences to a rectangular shape.

    Every word is truncated or right-padded to exactly ``max_word_length``
    character ids, and every sentence is right-padded with all-pad words to
    the length of the longest sentence in the batch.

    Args:
        sents: Batch of sentences; each sentence is a list of words and each
            word is a list of character ids.
        pad_token: Character id used for padding.
        max_word_length: Fixed number of character ids per word.

    Returns:
        A new nested list of shape (batch, max sentence length,
        ``max_word_length``). An empty batch yields an empty list.
    """
    # default=0 keeps an empty batch from raising.
    max_sent_length = max(map(len, sents), default=0)
    sents_padded = []
    for sent in sents:
        # A negative multiplier yields [], so over-long words are only truncated.
        sent_padded = [
            word[:max_word_length] + [pad_token] * (max_word_length - len(word))
            for word in sent
        ]
        # Build a fresh filler word each time so no two padded words share
        # the same list object.
        sent_padded.extend(
            [pad_token] * max_word_length
            for _ in range(max_sent_length - len(sent))
        )
        sents_padded.append(sent_padded)
    return sents_padded


In [167]:
class VocabStore(object):
    """
    Will store source or target vocabulary

    Holds two lookup tables: a token-level one (``token2id`` / ``id2word``)
    built from a corpus and/or a ready-made mapping, and a fixed
    character-level one (``char2id`` / ``id2char``). Lookups of anything
    that is not in a table return the id of the unknown marker.
    """
    def __init__(self, tokens: List[List[str]] = None, token2id: Dict[str, int] = None,
                 min_freq: int = 0, reserved_tokens: Dict[str, str] = None) -> None:
        """Build the token and character tables.

        Args:
            tokens: Tokenized corpus (list of sentences) whose tokens are
                added to the vocabulary, most frequent first.
            token2id: Optional existing token-to-id mapping to start from.
                It is copied, never modified. Reserved tokens missing from
                it are appended with the next free ids.
            min_freq: Corpus tokens occurring fewer times than this are
                left out.
            reserved_tokens: Optional overrides for the special token
                strings, keyed by ``"unk"``, ``"pad"``, ``"start"`` and
                ``"end"``. It is not modified.
        """

        # For handling tokens

        # Resolve the special token strings up front, for both construction
        # paths, without touching the caller's dict.
        reserved = dict(reserved_tokens) if reserved_tokens else {}
        self.start_token = reserved.get("start", "<s>")
        self.end_token = reserved.get("end", "</s>")
        self.unk = reserved.get("unk", "<unk>")
        self.pad = reserved.get("pad", "<pad>")

        if token2id:
            # Copy so that later insertions do not leak into the caller's dict.
            self.token2id = dict(token2id)
        else:
            # Insertion order (start, end, unk, pad) is kept on purpose so
            # that iterating over token2id yields the same sequence as before.
            self.token2id = {}
            self.token2id[self.start_token] = 1
            self.token2id[self.end_token] = 2
            self.token2id[self.unk] = 3
            self.token2id[self.pad] = 0

        # Next unused id; equals len(token2id) whenever the ids are contiguous.
        next_id = max(self.token2id.values(), default=-1) + 1

        # A supplied mapping may lack some special tokens; lookups need them.
        for special in (self.start_token, self.end_token, self.unk, self.pad):
            if special not in self.token2id:
                self.token2id[special] = next_id
                next_id += 1

        self.id2word = {}
        uniq_tokens = list(self.token2id.keys())
        # Skip counting entirely when there is no corpus.
        token_freqs = count_corpus(tokens) if tokens else []
        seen = set(uniq_tokens)  # set membership instead of scanning the list
        for token, freq in token_freqs:
            if freq >= min_freq and token not in seen:
                uniq_tokens.append(token)
                seen.add(token)

        for token in uniq_tokens:
            if token not in self.token2id:
                self.token2id[token] = next_id
                next_id += 1
            self.id2word[self.token2id[token]] = token

        # For handling chars

        self.char_list = list(
            """ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]"""
        )
        self.char2id = {}
        self.char2id[self.pad] = 0
        self.start_char, self.char2id["{"] = "{", 1
        self.end_char, self.char2id["}"] = "}", 2
        self.char2id[self.unk] = 3

        for c in self.char_list:
            self.char2id[c] = self.char2id.get(c, len(self.char2id))

        self.id2char = {v: k for k, v in self.char2id.items()}

    def __len__(self) -> int:
        """Return the number of entries in the token table."""
        return len(self.token2id)

    def __getitem__(self, tokens: Union[List[str], Tuple[str], str]) -> Union[List[int], int]:
        """Map a token, or a (possibly nested) list/tuple of tokens, to ids.

        Tokens that are not in the vocabulary map to the id of the unknown
        token.
        """
        if not isinstance(tokens, (list, tuple)):
            # Fall back to the unknown token's *id*, not its string.
            return self.token2id.get(tokens, self.token2id[self.unk])
        return [self.__getitem__(token) for token in tokens]

    def __contains__(self, token) -> bool:
        """Return True when ``token`` is in the token table."""
        return token in self.token2id

    def __setitem__(self, key, value):
        """Reject item assignment.

        Raises:
            ValueError: always.
        """
        raise ValueError("Vocabulary store is read only")

    def __repr__(self) -> str:
        return f"Vocab Store: Tokens [size={len(self)}], Characters [size={len(self.char2id)}]"

    def to_tokens(self, indices: Union[List[int], Tuple[int], int]) -> Union[List[str], str]:
        """Map an id, or a (possibly nested) list/tuple of ids, back to tokens.

        Ids with no entry in ``id2word`` map to ``None``.
        """
        if not isinstance(indices, (list, tuple)):
            return self.id2word.get(indices, None)
        return [self.to_tokens(index) for index in indices]

    def len(self, tokens: bool = True) -> int:
        """Return the token table size, or the character table size when ``tokens`` is False."""
        return len(self.token2id) if tokens else len(self.char2id)

    def sent2id(self, sents: List[List[str]]) -> List[List[int]]:
        """Convert each sentence (a list of tokens) into a list of token ids."""
        return [self[sent] for sent in sents]

    def to_charid(self, char: Union[List[str], str]) -> int:
        """Map a character-table key, or a list/tuple of them, to ids.

        Keys that are not in the character table map to the unknown id.
        """
        if not isinstance(char, (list, tuple)):
            # Fall back to the unknown marker's *id*, not its string.
            return self.char2id.get(char, self.char2id[self.unk])
        return [self.to_charid(c) for c in char]

    def word2char(self, tokens: Union[List[str], str]) -> Union[List[List[int]], List[int]]:
        """Spell a word, or a (possibly nested) list/tuple of words, as character ids.

        Each word is wrapped in the start and end characters before being
        looked up; characters outside the table map to the unknown id.
        """
        if not isinstance(tokens, (list, tuple)):
            unk_id = self.char2id[self.unk]
            return [self.char2id.get(char, unk_id) for char in self.start_char + tokens + self.end_char]
        return [self.word2char(token) for token in tokens]

    def to_char(self, indices: Union[int, List[int]]) -> Union[str, List[str]]:
        """Map a character id, or a (possibly nested) list/tuple of ids, back to characters.

        Ids with no entry in ``id2char`` map to ``None``.
        """
        if not isinstance(indices, (list, tuple)):
            return self.id2char.get(indices, None)
        return [self.to_char(index) for index in indices]

    def sent2charid(self, sents: List[List[str]]) -> List[List[List[int]]]:
        """Convert each sentence into a list of per-word character-id lists."""
        return [self.word2char(sent) for sent in sents]

    def to_tensor(self, sents: List[List[str]], tokens: bool, device: Union[torch.device, str]) -> torch.Tensor:
        """Convert a batch of sentences into a padded ``torch.long`` tensor.

        Args:
            sents: Batch of tokenized sentences.
            tokens: True for token ids, False for character ids.
            device: Device the tensor is created on.

        Returns:
            For ``tokens=True`` a tensor of shape (max sentence length,
            batch). For ``tokens=False`` a tensor of shape (max sentence
            length, batch, word length), where word length is the default
            ``max_word_length`` of ``pad_sents_char``.
        """
        ids = self.sent2id(sents) if tokens else self.sent2charid(sents)
        pad_ids = pad_sents(ids, self[self.pad]) if tokens else pad_sents_char(ids, self.to_charid(self.pad))
        tensor_sents = torch.tensor(pad_ids, dtype=torch.long, device=device)
        # Move the batch dimension from position 0 to position 1.
        return torch.t(tensor_sents) if tokens else tensor_sents.permute([1, 0, 2])

    

In [168]:
vocab_store = VocabStore(sentences_words)
print(vocab_store)
print(list(vocab_store.token2id.items())[:10])

Vocab Store: Tokens [size=17], Characters [size=96]
[('<s>', 1), ('</s>', 2), ('<unk>', 3), ('<pad>', 0), ('do', 4), ('we', 5), ('Computer:', 6), ('Human:', 7), ('When', 8), ('want', 9)]


In [169]:
print("we" in vocab_store)
print("NOTWE" in vocab_store)

True
False


In [170]:
print(vocab_store[sentences_words[0]])

[7, 11, 4, 5, 15]


In [171]:
print(vocab_store.to_tokens([7, 11, 4, 5, 15]))

['Human:', 'What', 'do', 'we', 'want?']


In [172]:
vocab_store.sent2id(sentences_words)

[[7, 11, 4, 5, 15], [6, 10, 13, 14], [7, 8, 4, 5, 9, 12], [6, 8, 4, 5, 9, 16]]

In [173]:
pad_sents(vocab_store.sent2id(sentences_words), vocab_store[vocab_store.pad])

[[7, 11, 4, 5, 15, 0],
 [6, 10, 13, 14, 0, 0],
 [7, 8, 4, 5, 9, 12],
 [6, 8, 4, 5, 9, 16]]

In [174]:
vocab_store.to_tokens([7, 11, 4, 5, 15])

['Human:', 'What', 'do', 'we', 'want?']

In [175]:
vocab_store.word2char(sentences_words[0])

[[1, 11, 50, 42, 30, 43, 71, 2],
 [1, 26, 37, 30, 49, 2],
 [1, 33, 44, 2],
 [1, 52, 34, 2],
 [1, 52, 30, 43, 49, 70, 2]]

In [176]:
# NOTE(review): apart from the start/end markers (1 and 2), every id here is
# one less than the ids word2char returned for 'Human:' above
# ([1, 11, 50, 42, 30, 43, 71, 2]), so this decodes to 'GtlZm?' rather than
# 'Human:'. Confirm whether a round-trip of 'Human:' was intended.
vocab_store.to_char([1, 10, 49, 41, 29, 42, 70, 2])

['{', 'G', 't', 'l', 'Z', 'm', '?', '}']

In [177]:
vocab_store.sent2charid(sentences_words)

[[[1, 11, 50, 42, 30, 43, 71, 2],
  [1, 26, 37, 30, 49, 2],
  [1, 33, 44, 2],
  [1, 52, 34, 2],
  [1, 52, 30, 43, 49, 70, 2]],
 [[1, 6, 44, 42, 45, 50, 49, 34, 47, 71, 2],
  [1, 17, 30, 49, 50, 47, 30, 41, 2],
  [1, 41, 30, 43, 36, 50, 30, 36, 34, 2],
  [1, 45, 47, 44, 32, 34, 48, 48, 38, 43, 36, 69, 2]],
 [[1, 11, 50, 42, 30, 43, 71, 2],
  [1, 26, 37, 34, 43, 2],
  [1, 33, 44, 2],
  [1, 52, 34, 2],
  [1, 52, 30, 43, 49, 2],
  [1, 38, 49, 70, 2]],
 [[1, 6, 44, 42, 45, 50, 49, 34, 47, 71, 2],
  [1, 26, 37, 34, 43, 2],
  [1, 33, 44, 2],
  [1, 52, 34, 2],
  [1, 52, 30, 43, 49, 2],
  [1, 52, 37, 30, 49, 70, 2]]]

In [178]:
pad_sents_char(vocab_store.sent2charid(sentences_words), vocab_store.to_charid(vocab_store.pad))

[[[1, 11, 50, 42, 30, 43, 71, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 26, 37, 30, 49, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 33, 44, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 52, 34, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 52, 30, 43, 49, 70, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
 [[1, 6, 44, 42, 45, 50, 49, 34, 47, 71, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 17, 30, 49, 50, 47, 30, 41, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 41, 30, 43, 36, 50, 30, 36, 34, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 45, 47, 44, 32, 34, 48, 48, 38, 43, 36, 69, 2, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
 [[1, 11, 50, 42, 30, 43, 71, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 26, 37, 34, 43, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [183]:
vocab_store.to_tensor(sentences_words, tokens=True, device="cpu").size()

torch.Size([6, 4])

In [184]:
vocab_store.to_tensor(sentences_words, tokens=False, device="cpu").size()

torch.Size([6, 4, 21])