# Word AutoCorrect using Twitter posts
Word Autocorrect based on words from Sentiment140 twitter posts.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from utils import train, top_n, top_n_accuracy, count_parameters
import pandas as pd
from pathlib import Path
import csv
import re
from tqdm.notebook import tqdm
from collections import defaultdict
import random
from prettytable import PrettyTable

In [3]:
FILE = Path("data/Sentiment140/train.csv")
with open(FILE, "r", encoding='latin-1') as f:
    texts = []
    reader = csv.reader(f)
    for row in reader:
        texts.append(row[5])

In [4]:
texts

["@switchfoot http://twitpic.com/2y1zl - Awww, that's a bummer.  You shoulda got David Carr of Third Day to do it. ;D",
 "is upset that he can't update his Facebook by texting it... and might cry as a result  School today also. Blah!",
 '@Kenichan I dived many times for the ball. Managed to save 50%  The rest go out of bounds',
 'my whole body feels itchy and like its on fire ',
 "@nationwideclass no, it's not behaving at all. i'm mad. why am i here? because I can't see you all over there. ",
 '@Kwesidei not the whole crew ',
 'Need a hug ',
 "@LOLTrish hey  long time no see! Yes.. Rains a bit ,only a bit  LOL , I'm fine thanks , how's you ?",
 "@Tatiana_K nope they didn't have it ",
 '@twittera que me muera ? ',
 "spring break in plain city... it's snowing ",
 'I just re-pierced my ears ',
 "@caregiving I couldn't bear to watch it.  And I thought the UA loss was embarrassing . . . . .",
 '@octolinz16 It it counts, idk why I did either. you never talk to me anymore ',
 "@smarrison i wo

In [5]:
def process_tweet(tweet):
    tweet = tweet.lower()
    tweet = re.sub("\@\w*", "", tweet)
    tweet = re.sub("http\S*", "", tweet)
    tweet = re.sub("www.\S*", "", tweet)
    tweet = re.sub('[^A-Za-z ]+', '', tweet)
    return tweet.split()

In [6]:
def get_word_counts(tweets):
    all_words = defaultdict(int)
    for tweet in tqdm(tweets):
        for word in process_tweet(tweet):
            all_words[word] = all_words[word] + 1
    return all_words

In [7]:
word_freq = get_word_counts(texts)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1600000.0), HTML(value='')))




In [8]:
len(word_freq)

417929

In [9]:
def sort_words(word_dict):
    occurance = np.array([v for v in word_dict.values()])
    sorted_idx = np.argsort(occurance)[::-1]
    words = [k for k in word_dict.keys()]
    sorted_words = [(words[idx], occurance[idx]) for idx in sorted_idx]
    return sorted_words

In [10]:
sw = sort_words(word_freq)

In [11]:
VOCAB_SIZE = 10000
top_words = [word for word, count in sw[:VOCAB_SIZE]]
vocab = {word: i for i, word in enumerate(top_words)}
vocab["a"], vocab["the"]

(3, 2)

In [12]:
class WordTransform:
    """Returns the first n randomly chossen letters of the word. 20 % of the time also randomly jumbles words
    by switching letters from nearby keys on the keyboard."""
    
    nearest_keys = {
        "a": list("qwsxz"),
        "b": list("vfghn"),
        "c": list("xsdfv"),
        "d": list("swerfvcx"),
        "e": list("wsdfr"),
        "f": list("dertgbvc"),
        "g": list("frtzhnbv"),
        "h": list("gtzujmnb"),
        "i": list("ujklo"),
        "j": list("hyuikmn"),
        "k": list("juiolmn"),
        "l": list("iopkm"),
        "m": list("njkl"),
        "n": list("bghjm"),
        "o": list("iklp"),
        "p": list("ol"),
        "q": list("aws"),
        "r": list("edfgt"),
        "s": list("aqwedcxz"),
        "t": list("rfghz"),
        "u": list("yhjki"),
        "v": list("cdfgb"),
        "w": list("qasde"),
        "x": list("zasdc"),
        "y": list("tghju"),
        "z": list("asx")
    }
        
    def __call__(self, word):
        size = len(word)
        jumbled_word = word
        if size > 1:
            start = 1 + int(size*0.2)
            size = np.random.choice(np.arange(start, size))
            if np.random.choice([True, False], p=[0.2, 0.8]):
                switch_idxs = np.random.choice(size - 1, int(0.25 * size), replace=False) + 1
                letters = list(word)
                for idx in switch_idxs:
                    letters[idx] = np.random.choice(self.nearest_keys[letters[idx]] + ["", "", "", ""])
                jumbled_word = "".join(letters)
        return jumbled_word[:size] if size else ""

In [13]:
word_transform = WordTransform()

In [14]:
class Letters(Dataset):
    
    letter_encode = {letter: num for num, letter in enumerate("abcdefghijklmnopqrstuvwxyz")}
    
    def __init__(self, words, vocab, transform, size=15):
        super().__init__()
        self.words = words
        self.vocab = vocab
        self.transform = transform
        self.size = size
        self.pad_char = len(self.letter_encode)
        self.chars = list(self.letter_encode.keys())
        
    def __len__(self):
        return len(self.words)
        
    def __getitem__(self, idx):
        word = self.words[idx]
        jumbled_word = self.transform(word)
        return self.encode(jumbled_word), self.vocab[word]
        
    def encode(self, word):
        X = torch.ones(self.size, dtype=torch.long) * self.pad_char
        end = min(self.size, len(word))
        for i, letter in enumerate(word[:end]):
            X[i] = self.letter_encode[letter]
        return X
    
    def decode(self, vector):
        letters = []
        for v in vector:
            if int(v) < len(self.chars):
                letters.append(self.chars[int(v)])
        return "".join(letters)
    
    def lookup(self, idx):
        return self.words[idx]

In [15]:
letters_top = Letters(top_words[:2000], vocab, word_transform)
letters_all = Letters(top_words, vocab, word_transform)

In [16]:
j, w = letters_top[1000]

In [17]:
letters_top.decode(j), letters_top.lookup(w)

('wha', 'whatever')

In [18]:
train_small = DataLoader(letters_top, batch_size=32, shuffle=True)
train_all = DataLoader(letters_all, batch_size=32, shuffle=True)

In [19]:
class CharModel(nn.Module):
    def __init__(self, vocab_size, embedding_size, seq_len):
        super().__init__()
        self.em_size = embedding_size * seq_len
        self.em = nn.Embedding(vocab_size, embedding_size)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(self.em_size, vocab_size)
        
    def forward(self, x):
        x = self.relu(self.em(x))
        return self.fc(x.view(-1, self.em_size))

In [20]:
model = CharModel(VOCAB_SIZE + 1, 8, 15)
crit = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0)
count_parameters(model, True)

+-----------+------------+
|  Modules  | Parameters |
+-----------+------------+
| em.weight |   80008    |
| fc.weight |  1200120   |
|  fc.bias  |   10001    |
+-----------+------------+
Total Trainable Params: 1290129


1290129

In [21]:
train_loss, train_acc, test_loss, test_acc = train(model, crit, opt, train_small, metric=top_n_accuracy, n_epochs=40)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:1, T Loss:9.778, T Met:0.000


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:2, T Loss:8.028, T Met:0.005


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:3, T Loss:7.450, T Met:0.027


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:4, T Loss:7.029, T Met:0.068


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:5, T Loss:6.638, T Met:0.123


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:6, T Loss:6.287, T Met:0.165


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:7, T Loss:5.948, T Met:0.194


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:8, T Loss:5.569, T Met:0.245


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:9, T Loss:5.284, T Met:0.260


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:10, T Loss:4.994, T Met:0.267


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:11, T Loss:4.743, T Met:0.288


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:12, T Loss:4.505, T Met:0.315


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:13, T Loss:4.310, T Met:0.312


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:14, T Loss:4.088, T Met:0.347


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:15, T Loss:3.976, T Met:0.350


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:16, T Loss:3.819, T Met:0.359


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:17, T Loss:3.701, T Met:0.376


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:18, T Loss:3.576, T Met:0.387


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:19, T Loss:3.514, T Met:0.384


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:20, T Loss:3.412, T Met:0.397


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:21, T Loss:3.328, T Met:0.397


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:22, T Loss:3.256, T Met:0.412


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:23, T Loss:3.143, T Met:0.427


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:24, T Loss:3.149, T Met:0.415


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:25, T Loss:3.085, T Met:0.420


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:26, T Loss:3.025, T Met:0.437


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:27, T Loss:2.990, T Met:0.435


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:28, T Loss:2.965, T Met:0.435


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:29, T Loss:2.969, T Met:0.438


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:30, T Loss:2.920, T Met:0.433


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:31, T Loss:2.855, T Met:0.455


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:32, T Loss:2.806, T Met:0.468


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:33, T Loss:2.831, T Met:0.441


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:34, T Loss:2.724, T Met:0.465


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:35, T Loss:2.773, T Met:0.444


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:36, T Loss:2.727, T Met:0.449


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:37, T Loss:2.655, T Met:0.465


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:38, T Loss:2.731, T Met:0.465


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:39, T Loss:2.689, T Met:0.464


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


Epoch:40, T Loss:2.664, T Met:0.456


In [22]:
opt = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0)
train_loss, train_acc, test_loss, test_acc = train(model, crit, opt, train_all, metric=top_n_accuracy, n_epochs=40)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:1, T Loss:14.927, T Met:0.080


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:2, T Loss:9.320, T Met:0.094


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:3, T Loss:6.812, T Met:0.106


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:4, T Loss:5.684, T Met:0.140


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:5, T Loss:5.100, T Met:0.181


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:6, T Loss:4.683, T Met:0.220


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:7, T Loss:4.429, T Met:0.259


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:8, T Loss:4.220, T Met:0.280


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:9, T Loss:4.088, T Met:0.301


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:10, T Loss:3.934, T Met:0.324


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:11, T Loss:3.826, T Met:0.331


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:12, T Loss:3.750, T Met:0.335


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:13, T Loss:3.705, T Met:0.342


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:14, T Loss:3.614, T Met:0.357


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:15, T Loss:3.569, T Met:0.353


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:16, T Loss:3.532, T Met:0.361


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:17, T Loss:3.490, T Met:0.363


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:18, T Loss:3.447, T Met:0.369


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:19, T Loss:3.427, T Met:0.373


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:20, T Loss:3.413, T Met:0.377


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:21, T Loss:3.406, T Met:0.374


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:22, T Loss:3.353, T Met:0.385


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:23, T Loss:3.354, T Met:0.379


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:24, T Loss:3.353, T Met:0.383


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:25, T Loss:3.325, T Met:0.381


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:26, T Loss:3.281, T Met:0.388


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:27, T Loss:3.279, T Met:0.387


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:28, T Loss:3.280, T Met:0.389


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:29, T Loss:3.268, T Met:0.394


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:30, T Loss:3.283, T Met:0.388


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:31, T Loss:3.191, T Met:0.394


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:32, T Loss:3.248, T Met:0.391


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:33, T Loss:3.199, T Met:0.397


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:34, T Loss:3.217, T Met:0.392


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:35, T Loss:3.163, T Met:0.399


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:36, T Loss:3.138, T Met:0.405


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:37, T Loss:3.206, T Met:0.394


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:38, T Loss:3.181, T Met:0.399


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:39, T Loss:3.136, T Met:0.408


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch:40, T Loss:3.184, T Met:0.402


In [28]:
def show_preds(model, letters, ds):
    x, y = next(iter(ds))
    preds = model(x)
    top_preds = top_n(preds)
    print(f"TOP 3 ACCURACY {top_n_accuracy(y, preds)}")
    table = PrettyTable(["WORD", "TYPED", "IN TOP 3", "TOP 3"])
    for sample, label, pred in zip(x, y, top_preds):
        word = letters.lookup(label)
        given = letters.decode(sample)
        top_words = [letters.lookup(p) for p in pred]
        is_correct = word in top_words
        top_words = " ".join(top_words) 
        table.add_row([word, given, is_correct, top_words])
    print(table)

In [29]:
show_preds(model, letters_all, train_small)

TOP 3 ACCURACY 0.34375
+-----------+----------+----------+-----------------------------+
|    WORD   |  TYPED   | IN TOP 3 |            TOP 3            |
+-----------+----------+----------+-----------------------------+
|    time   |   tim    |  False   |      times timing timee     |
|    itll   |   itl    |   True   |       itll itself itt       |
|    her    |    he    |  False   |       hella hee hello       |
|     ok    |    o     |  False   |           ow oy of          |
|   passed  |    pa    |  False   |       pale para palace      |
|   event   |    ev    |  False   |        eve evans ever       |
|  positive |    po    |  False   |     poland pollen polar     |
|   ready   |    re    |  False   |       relay rely relax      |
|  version  |  versi   |   True   |   version versions verizon  |
|    dad    |    da    |  False   |        dallas dah das       |
|    spam   |   spa    |  False   |      sparks span spare      |
|   listen  |   list   |   True   |     lists listed 

In [30]:
def predict(txt):
    txt = txt.lower()
    txt = re.sub('[^A-Za-z ]+', '', txt)
    x = letters_all.encode(txt)
    o = model(x)
    top3 = top_n(o)[0]
    print(" ".join([letters_all.lookup(top.item()) for top in top3]))

In [31]:
predict("buttr")

butter butterflies butterfly


In [27]:
torch.save(model.state_dict(), Path("autocorrect.pt"))