In [1]:
import numpy as np
from tqdm.notebook import trange, tqdm
from transformers import BertTokenizer
from torch.nn import CrossEntropyLoss

import torch
from torch.utils.tensorboard import SummaryWriter

# Select the compute device: use CUDA when PyTorch can see a GPU, otherwise
# fall back to the CPU.
if torch.cuda.is_available():

    # Tell PyTorch to use the GPU.
    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))

    # Index of the currently selected CUDA device.  This query is only valid
    # when CUDA is usable, so it lives inside the GPU branch; calling it on a
    # CPU-only machine raises instead of printing anything.
    print(torch.cuda.current_device())

# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

# Load the tokenizer of the 'bert-base-uncased' checkpoint with lower-casing
# enabled; later cells use it for the special-token ids and for tokenizing
# any pre-context text.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,
)

There are 1 GPU(s) available.
We will use the GPU: GeForce GTX 1080
0
Loading BERT tokenizer...


In [2]:
import random
import string

def char2finger(c):
    """Return the input id that stands for the finger typing character ``c``.

    The space maps to 1, the lowercase letters map to 2-9 according to their
    finger group, and every other value maps to 10.  The trailing "unused N"
    comments presumably name the BERT vocabulary slot each id falls on
    (TODO confirm against the vocabulary file).
    """
    if c == ' ':
        return 1 #unused 0

    # Letters grouped by finger; the position in this tuple (counted from 1)
    # is the finger number.
    keys_by_finger = ('qaz', 'wsx', 'edc', 'rfvtgb', 'yhnujm', 'ik', 'ol', 'p')
    finger_of = {
        key: finger
        for finger, keys in enumerate(keys_by_finger, start=1)
        for key in keys
    }

    finger = finger_of.get(c)
    if finger is None:
        return 10 #unused 9
    return finger + 1 #unused 1 - 8

def char2label(ch):
  """Return the class index of character ``ch``.

  'a'..'z' map to 0..25 and the space maps to 26.  Anything else maps to 27,
  the size of that alphabet, which acts as the "unknown" class.
  """
  alphabet = string.ascii_lowercase+' '
  index_of = dict(zip(alphabet, range(len(alphabet))))
  return index_of.get(ch, len(index_of))

def label2char(ii):
  """Return the character for class index ``ii`` (inverse of char2label).

  0..25 give 'a'..'z' and 26 gives the space; any other index gives '*'.
  """
  alphabet = string.ascii_lowercase+' '
  char_of = dict(enumerate(alphabet))
  return char_of.get(ii, '*')

# Number of sequence positions reserved for special tokens in every sample:
# whatever `tokenizer.num_added_tokens()` reports, plus one.
# NOTE(review): presumably the call counts the [CLS] and [SEP] of a single
# sequence and the +1 covers the second [SEP] that strToSample appends -
# confirm. Newer transformers releases expose this count as
# `num_special_tokens_to_add`; verify the name against the installed version.
special_tokens_count = tokenizer.num_added_tokens()+1
# Fixed length every sample is truncated / padded to.
max_seq_length = 256
# Label written at positions that carry no target (special tokens,
# pre-context and padding).  -100 is the default `ignore_index` of
# torch.nn.CrossEntropyLoss.
ignore_label_id = -100
# Value used by strToSample when padding sequences up to max_seq_length.
pad_token = 0

def strToSample(content, debug = False):
    """Encode a typed phrase as one fixed-length model input.

    The phrase is split on whitespace and re-joined with single spaces.  Every
    character of the typed part becomes a finger id (``char2finger``) on the
    input side and a character class (``char2label``) on the label side.
    Layout before padding::

        [CLS] <pre-context token ids> [SEP] <finger ids> [SEP]

    Typing always starts at the first word, so no words are used as
    pre-context.  Text that does not fit into the sequence is cut at the
    length budget and the last (possibly partial) word is dropped.

    Args:
        content: The text to encode.
        debug: When true, print every intermediate sequence.

    Returns:
        A tuple ``(input_ids, input_mask, segment_ids, label_ids)`` of four
        lists, each exactly ``max_seq_length`` long.  ``label_ids`` holds
        ``ignore_label_id`` everywhere except at the typed characters.
    """
    tokens = content.split()

    # Typing starts at the first word.  No random start index is drawn: it
    # would be discarded anyway and would only advance the global RNG state.
    typing_start = 0

    #the pre context of a sample
    pre_tokens = tokenizer.tokenize(' '.join(tokens[:typing_start]))

    typing_text = ' '.join(tokens[typing_start:])
    typing_seq = [char2finger(c) for c in typing_text]

    # Positions left for content once the special tokens are accounted for.
    budget = max_seq_length - special_tokens_count

    #if typing seq is longer than max seq
    if len(typing_seq) > budget:
        typing_text = typing_text[:budget]
        # The cut may land inside a word, so drop the last word entirely.
        typing_text = ' '.join(typing_text.split()[:-1])
        typing_seq = [char2finger(c) for c in typing_text]
        pre_tokens = []

    #else if typing+token is longer than max seq
    extra = len(pre_tokens) + len(typing_seq) - budget
    if extra > 0:
        # Drop the oldest pre-context tokens first.
        pre_tokens = pre_tokens[extra:]

    # The sample format:
    # [precontext] what is your [typing] k e y
    # [CLS] token_id token_id token_id [SEP] finger_id finger_id finger_id [SEP]

    pre_ids = tokenizer.convert_tokens_to_ids(['[CLS]']+pre_tokens+['[SEP]'])
    input_ids = pre_ids+typing_seq+tokenizer.convert_tokens_to_ids(['[SEP]'])

    # Only the typed characters carry a target; everything else is ignored.
    label_ids = len(pre_ids)*[ignore_label_id]+\
            [char2label(c) for c in typing_text]+\
            [ignore_label_id]

    # Segment 0 = pre-context part, segment 1 = typed part and its closing [SEP].
    segment_ids = len(pre_ids)*[0]+(len(typing_text)+1)*[1]
    input_mask = len(label_ids)*[1]

    padding_len = max_seq_length - len(input_ids)
    input_ids += [pad_token]*padding_len
    # Mask and segment padding is always 0, independent of the pad token id.
    input_mask += [0]*padding_len
    segment_ids += [0]*padding_len
    label_ids += [ignore_label_id]*padding_len

    if debug:
        print('typing text: %s' % typing_text)
        print("tokens: ", " ".join([str(x) for x in pre_tokens]))
        print("pre_ids: ", " ".join([str(x) for x in pre_ids]))
        print("input_ids: ", " ".join([str(x) for x in input_ids]))
        print("input_mask: ", " ".join([str(x) for x in input_mask]))
        print("segment_ids: ", " ".join([str(x) for x in segment_ids]))
        print("label_ids: ", " ".join([str(x) for x in label_ids]))

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(label_ids) == max_seq_length

    return input_ids, input_mask, segment_ids, label_ids

In [3]:
# Directory of the checkpoint to load; the relative path is resolved against
# the notebook's working directory.
pretrained_path = 'Models/local/Bestcheckpoint-1370000/'
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    BertForTokenClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

# One output class per lowercase letter, one for the space, and one extra
# class - the same count as char2label's value range (0..27).
config = AutoConfig.from_pretrained(
        pretrained_path,
        num_labels=len(string.ascii_lowercase+' ')+1,
        cache_dir=None,
    )

# Token-level classifier restored from the checkpoint directory.
model = BertForTokenClassification.from_pretrained(
    pretrained_path,
    config=config,)

# Move the weights to the selected device and switch to evaluation mode.
# `model.eval()` is the last expression of the cell, so the notebook displays
# the module's repr as the cell output.
model.to(device)
model.eval()

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

In [4]:
# Encode one example phrase and turn each of its four lists into a LongTensor
# with a leading batch dimension of 1, placed on the target device.
testcase = tuple(
    torch.LongTensor(field).unsqueeze(0).to(device)
    for field in strToSample("elephants are afr")
)
# Keyword arguments for the model's forward pass, in strToSample's field order.
inputs = dict(zip(
    ("input_ids", "attention_mask", "token_type_ids", "labels"),
    testcase,
))

In [6]:
# NOTE(review): `testcases` is not created by any cell above this one; this
# cell appears to depend on state left in the kernel by another cell and
# expects a (batch, 4, seq_len) collection of strToSample outputs - confirm.
# The transpose turns (batch, field, position) into (field, batch, position).
tn = torch.LongTensor(
    np.array(testcases).transpose((1,0,2))
).to(device)
inputs = {
    "input_ids": tn[0],
    "attention_mask": tn[1],
    "token_type_ids": tn[2],
    "labels": tn[3],
}

In [57]:
# Candidate search for correcting omission and transposition typing errors.
import torch.nn.functional as F

# Letters inserted when generating "omitted key" candidates.  According to
# char2finger's table these are one letter from each of the eight finger
# groups (a, s, d, f, j, k, l, p -> fingers 1..8).
fingerlist = 'asdfjklp'

def searchInsertionNTranspositions(searchlists):
    """Run the model on a batch of candidate strings.

    Args:
        searchlists: Candidate strings; each one is encoded with
            ``strToSample``.

    Returns:
        ``(probs, out_label_ids)`` as NumPy arrays.  ``probs`` is the softmax
        of the model's logits over the last (label) dimension;
        ``out_label_ids`` are the label ids built by ``strToSample``, where
        ``ignore_label_id`` marks positions that carry no target.
    """
    testcases = np.array([strToSample(seq) for seq in searchlists])
    testcases = testcases.transpose((1,0,2)) # (batch, item, input) to (item, batch, input)
    testcases = torch.LongTensor(testcases).to(device)
    inputs = {"input_ids": testcases[0], "attention_mask": testcases[1],
              "token_type_ids": testcases[2], "labels": testcases[3]}

    # Inference only: without no_grad() every call would record an autograd
    # graph and keep its activations alive for no benefit.
    with torch.no_grad():
        outputs = model(**inputs)

    # With labels supplied the first output is the loss (not needed here) and
    # the second one holds the logits.
    logits = outputs[1]
    probs = F.softmax(logits, dim=-1).detach().cpu().numpy()
    out_label_ids = inputs["labels"].detach().cpu().numpy()
    return probs, out_label_ids

def searchAlternatives(inputseqs, batch_size=20, top_k=10, alt_penalty=0.03):
    """Score typed sequences together with single-edit alternatives.

    For every sequence in ``inputseqs`` the candidate list contains:

    * the sequence itself;
    * if it is as long as the first sequence: one variant per letter of
      ``fingerlist``, inserted just before the last character;
    * if it has more than one character and no space among its last three:
      the variant with its last two characters swapped.

    Duplicates are removed (first occurrence wins) and all candidates are run
    through the model in batches.  A candidate's score is the mean, over its
    typed positions, of the highest class probability at each position; every
    candidate except the first is then penalised by ``alt_penalty``.

    Args:
        inputseqs: Non-empty list of strings; the first one is treated as the
            originally typed sequence.
        batch_size: Number of candidates per model call.
        top_k: Number of best-scoring candidates to keep.
        alt_penalty: Score deduction applied to every candidate but the first.

    Returns:
        ``(preds, labels)``: the per-position argmax predictions and the label
        ids of the selected candidates.  Candidate 0 always comes first,
        followed by the other members of the top ``top_k`` in score order, so
        up to ``top_k + 1`` rows are returned.

    Raises:
        ValueError: If ``inputseqs`` is empty.
    """
    if len(inputseqs) == 0:
        raise ValueError("searchAlternatives needs at least one input sequence")

    searchlists = []
    for seq in inputseqs:
        searchlists.append(seq)
        if len(seq) == len(inputseqs[0]):
            # "Omitted key" candidates: insert one letter before the last char.
            searchlists += [seq[:-1]+c+seq[-1] for c in fingerlist]
        if len(seq) > 1 and ' ' not in seq[-3:]:
            # Transposition candidate: swap the last two characters.
            searchlists.append(seq[:-2]+seq[-1]+seq[-2])
    # Order-preserving de-duplication, done once after all candidates exist.
    searchlists = list(dict.fromkeys(searchlists))

    prob_chunks = []
    label_chunks = []
    for i in range(0, len(searchlists), batch_size):
        probs, out_label_ids = searchInsertionNTranspositions(searchlists[i:i+batch_size])
        prob_chunks.append(probs)
        label_chunks.append(out_label_ids)
    problists = np.concatenate(prob_chunks, axis=0)
    labellists = np.concatenate(label_chunks, axis=0)

    probsmax = np.max(problists, axis=-1)
    preds = np.argmax(problists, axis=-1)
    # Zero out the positions that carry no label so they do not count.
    valid_probs = probsmax * (labellists != ignore_label_id)
    avg_probs = np.array([row.sum() / (row > 0).sum() for row in valid_probs])
    # Bias the ranking towards the sequence as typed.
    avg_probs[1:] -= alt_penalty
    top_indices = (-avg_probs).argsort()[:top_k]
    #always the original sequence (the 0-index)
    top_indices = [0] + list(top_indices[top_indices!=0])
    # Progress / debug trace of the current search step.
    print(valid_probs.shape, avg_probs.shape, searchlists[0], top_indices, (valid_probs[0] > 0).sum())
    return preds[top_indices], labellists[top_indices]


In [15]:
# Rank a handful of hand-written variants of a misspelled word; the first
# entry is treated as the sequence that was actually typed.
preds, label_ids = searchAlternatives([
    'informaion',
    'informalialn',
    'informailn',
    'informacialn',
    'information',
])

In [16]:
# Decode each returned candidate: keep only the labelled positions and map the
# predicted class indices back to characters.
res = [
    ''.join(label2char(c) for c in pred_row[label_row != ignore_label_id])
    for pred_row, label_row in zip(preds, label_ids)
]
res

['informaion',
 'information',
 'informaion',
 'informacialy',
 'informailly',
 'informailly',
 'informailty',
 'informailty',
 'informaliano',
 'informailin']

In [56]:
# Simulate typing the phrase one key at a time, carrying the best corrected
# hypotheses from step to step.
typing = 'i enjoy playng infomation games in my sare tiem'
alter = ['']
for i, typed_char in enumerate(typing):
    # Candidates for this step: the literal prefix typed so far plus the
    # hypotheses kept from the previous step, each extended by the new key
    # (duplicates dropped, order preserved).
    alter = list(dict.fromkeys(
        phr + typed_char for phr in [typing[:i]] + alter
    ))
    # A space only extends the hypotheses; the model is not queried for it.
    if typed_char == ' ':
        continue
    preds, label_ids = searchAlternatives(alter)
    alter = []
    for pred_row, label_row in zip(preds, label_ids):
        decoded = ''.join(
            label2char(c) for c in pred_row[label_row != ignore_label_id]
        )
        # Skip hypotheses more than one character longer than what was typed.
        if len(decoded) > i+2:
            continue
        alter.append(decoded)
        # Keep at most five hypotheses per step.
        if len(alter) >= 5:
            break
    print(f'step {i}: {typing[:i+1]} res {alter}')

(9, 256) (9,) i [0, 7, 1, 8, 6, 2, 5, 3, 4] 1
step 0: i res ['i', 'ok', 'ai', 'pi', 'ii']
(13, 256) (13,) i e [0, 4, 1, 2, 6, 11, 8, 7, 10, 5] 3
step 2: i e res ['i d', 'i ve', 'i ad', 'i we', 'i id']
(26, 256) (26,) i en [0, 1, 10, 16, 7, 15, 6, 4, 13, 19, 22] 4
step 3: i en res ['i en', 'i can', 'i can', 'i don', 'i don']
(14, 256) (14,) i enj [0, 6, 4, 10, 11, 1, 8, 9, 3, 2] 5
step 4: i enj res ['i dun', 'i chin', 'i envy', 'i damn', 'i damn']
(26, 256) (26,) i enjo [0, 24, 5, 15, 21, 4, 14, 22, 20, 25, 8] 6
step 5: i enjo res ['i duno', 'i canno', 'i dunno', 'i dunno', 'i emily']
(25, 256) (25,) i enjoy [0, 10, 16, 6, 19, 9, 5, 15, 20, 22] 7
step 6: i enjoy res ['i enjoy', 'i enjoy', 'i enjoin', 'i enjoin', 'i dunno']
(19, 256) (19,) i enjoy p [0, 8, 1, 11, 7, 15, 18, 10, 9, 5] 9
step 8: i enjoy p res ['i enjoy p', 'i enjoy pp', 'i enjoy ap', 'i dunno ap', 'i enjoy lp']
(16, 256) (16,) i enjoy pl [0, 9, 4, 5, 8, 11, 1, 13, 7, 15] 10
step 9: i enjoy pl res ['i enjoy po', 'i enjoy po

(21, 256) (21,) i enjoy playng infomation games in my sare t [0, 18, 10, 1, 19, 7, 16, 20, 6, 15] 44
step 43: i enjoy playng infomation games in my sare t res ['i enjoy playug intonation games in my save b', 'i enjoy playug infonation games in my paste b', 'i enjoy playug intonation games in my safe at', 'i enjoy playug intonation games in my safe at', 'i enjoy playug intonation games in my waite b']
(22, 256) (22,) i enjoy playng infomation games in my sare ti [0, 20, 18, 21, 8, 17, 19, 5, 14, 6] 45
step 44: i enjoy playng infomation games in my sare ti res ['i enjoy playug infonation games in my save ti', 'i enjoy playug intonation games in my safe air', 'i enjoy playug infonation games in my paste ti', 'i enjoy playug infonation games in my waite ti', 'i enjoy playug infonation games in my safe gpi']
(28, 256) (28,) i enjoy playng infomation games in my sare tie [0, 16, 6, 17, 7, 13, 3, 5, 15, 8] 46
step 45: i enjoy playng infomation games in my sare tie res ['i enjoy playug infonat

In [72]:
import json
# Load the lookup table stored in dict.json.  The demo code below indexes it
# with a string of concatenated finger ids; judging from the printed output
# the values are lists of [word, count] pairs (TODO confirm the schema).
with open('dict.json', 'r') as dict_file:
    seqdict = json.load(dict_file)
    
def beam_search(data, k):
    """Return the ``k`` most likely label sequences for a run of positions.

    Args:
        data: Iterable of per-position score vectors (unnormalised logits),
            one vector per position.
        k: Beam width, i.e. how many hypotheses survive each step.

    Returns:
        A list of ``[sequence, score]`` pairs sorted best first, where
        ``sequence`` is a list of label indices and ``score`` is the
        accumulated negative log-probability (lower is better).  For empty
        ``data`` the single empty hypothesis ``[[[], 0]]`` is returned.
    """
    sequences = [[[], 0]] #sequence, negative log-probability
    for item in data:
        item = np.asarray(item, dtype=np.float64)
        if item.size == 0:
            # No label to extend with, so no hypothesis can survive.
            return []
        # Log-softmax with the maximum subtracted first: exponentiating raw
        # logits overflows for large values and turns the scores into NaN.
        shifted = item - np.max(item)
        eprob = shifted - np.log(np.sum(np.exp(shifted)))
        # Extend every surviving hypothesis by every label.
        new_seq = [
            [seq+[idx], prob-p]
            for seq, prob in sequences
            for idx, p in enumerate(eprob)
        ]
        new_seq.sort(key=lambda tup: tup[1])
        sequences = new_seq[:k]
    return sequences
# NOTE(review): this demo relies on kernel state that nothing above it
# creates. `logits` and `input_ids` come from the single forward-pass cell
# further down, and `words` (apparently [start, end] index pairs, one per
# typed word) is only built by code that is commented out in that cell.
preds = logits.detach().cpu().numpy().squeeze()
st, ed = words[1][0], words[1][1]
# Beam-search the logits of the second word and compare the result with the
# dictionary entry stored for the same finger-id sequence.
seqs = beam_search(preds[st:ed], 20)
tapseq = ''.join([str(n) for n in input_ids[st:ed]])
print(seqdict[tapseq])

for s in seqs:
    print (''.join([label2char(idx) for idx in s[0]]))

[['are', 106744], ['age', 5186], ['abc', 685], ['ate', 408], ['arc', 352], ['ave', 75], ['abe', 60], ['atc', 58], ['ard', 55], ['qtc', 45], ['afc', 24], ['agc', 21], ['abd', 20]]
are
ate
age
ard
arc
abe
afe
ave
ane
ale
aee
qre
ame
aoe
aue
awe
ahe
ade
ace
axe


In [59]:
# Batch element to decode.  Logits, labels, segment ids and input ids are all
# taken from this same row so that they stay aligned with each other.
sample_idx = 1

with torch.no_grad():
  outputs = model(**inputs)
  print(outputs[1].shape)
  tmp_eval_loss, logits = outputs[:2]

  preds = logits.detach().cpu().numpy()[sample_idx]
  out_label_ids = inputs["labels"].detach().cpu().numpy()[sample_idx]
  segment_ids = inputs["token_type_ids"].detach().cpu().numpy()[sample_idx]
  input_ids = inputs["input_ids"].detach().cpu().numpy()[sample_idx]

# words = []
# wordstart = -1
# for i in range(segment_ids.shape[0]):
#     if segment_ids[i] != 0:
#         if input_ids[i] in [1, 102]:
#             words.append([wordstart, i])
#             wordstart = -1
#         elif wordstart == -1:
#             wordstart = i

# #beam search for each word
# beam_size = 50
# print(preds[words[0][0]])

# Greedy decode: best class per position, keeping only labelled positions.
preds = np.argmax(preds, axis=1)
preds_list = [
  label2char(pred)
  for pred, label in zip(preds, out_label_ids)
  if label != ignore_label_id
]

print(''.join(preds_list))

torch.Size([2, 256, 28])
(256, 28)
best


In [74]:
# Evaluate the model on a phrase set: every line of phrases2.txt is stripped
# and lower-cased, encoded, decoded greedily, and recorded when the decoded
# text differs from the whitespace-normalised phrase.
wrong = []

with open('phrases2.txt') as f:
    testcases = [line.strip().lower() for line in f.readlines()]

for case in testcases:
    # Four (1, max_seq_length) LongTensors on the target device.
    testcase = tuple(
        torch.LongTensor(field).unsqueeze(0).to(device)
        for field in strToSample(case)
    )
    inputs = {
        "input_ids": testcase[0],
        "attention_mask": testcase[1],
        "token_type_ids": testcase[2],
        "labels": testcase[3],
    }
    with torch.no_grad():
        outputs = model(**inputs)
        tmp_eval_loss, logits = outputs[:2]

    out_label_ids = inputs["labels"].detach().cpu().numpy().squeeze()
    preds = np.argmax(logits.detach().cpu().numpy().squeeze(), axis=1)

    # Keep only the positions that carry a label and map classes to characters.
    res = ''.join(
        label2char(pred)
        for pred, label in zip(preds, out_label_ids)
        if label != ignore_label_id
    )
    if res != ' '.join(case.split()):
        wrong.append([case, res])
        

In [68]:
# Show every [expected phrase, decoded phrase] pair collected by the
# evaluation loop, i.e. the phrases the model did not reproduce exactly.
print(wrong)

[['i can see the rings on saturn', 'i can see the rings on satutn'], ['elephants are afraid of mice', 'elephants are afraid of nice'], ['if at first you do not succeed', 'it at first you do not succeed'], ['please provide your date of birth', 'please provide your care or birth'], ['we run the risk of failure', 'we buy the risk of failure'], ['beware the ides of march', 'beware the ices or march'], ['double double toil and trouble', 'double double foil and trouble'], ['play it again sam', 'play it again say'], ['you are not a jedi yet', 'you are not a heck yet'], ['starlight and dewdrop', 'starlight and desdtop'], ['drove my chevy to the levee', 'drove my ehevy to the legee'], ['but the levee was dry', 'but the leved was dry'], ['the quick brown fox jumped', 'the quick brown box jumped'], ['there will be some fog tonight', 'there will be some for tonight'], ['the dow jones index has risen', 'the cow jones index has risen'], ['we are subjects and must obey', 'we are subjects and just obe