In [1]:
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel
from nltk.corpus import wordnet
import nltk
nltk.download('wordnet')
import torch.nn.functional as F
import torch
import argparse
import re
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package wordnet to /home/amv458/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'[device {device} is ready]')

[device cuda:0 is ready]


In [3]:
def build_dictionaries(train_file="words_250000_train.txt"):
    text_file = open(train_file, "r")
    train_dictionary = text_file.read().splitlines()
    text_file.close()
    train_set = set(train_dictionary)
    test_dictionary = []
    for synset in wordnet.all_synsets():
        for lemma in synset.lemmas():
            word = lemma.name().lower()
            if word not in train_set and word.isalpha() :
                test_dictionary.append(word)
    return train_dictionary, test_dictionary


def mean_pooling(token_embeddings, attention_mask):
    """ Average pooling layer for the Transformer.
    """
    token_embeddings = token_embeddings.to(attention_mask.device)
    input_mask_expanded=(attention_mask.unsqueeze(-1).float())
    sum_embeddings=torch.sum(token_embeddings * input_mask_expanded, -2)
    sum_mask=torch.clamp(input_mask_expanded.sum(-2), min=1e-9)
    return sum_embeddings / sum_mask

class HangmanTransformer(torch.nn.Module):
    def __init__(self,
            max_seq_len,
            model_dim,
            latent_dim,
            tokenizer,
            device=torch.device('cpu'),
            num_heads=1,
            num_layers=1,
            vocab_size=28):
        super().__init__()
        bert = BertModel.from_pretrained('bert-base-uncased')
        bert.config.hidden_size=model_dim
        bert.config.vocab_size=vocab_size
        bert.config.num_attention_heads=num_heads
        bert.config.num_hidden_layers=num_layers
        bert.config.pad_token_id=27
        bert.config.intermediate_size=latent_dim
        bert.config.max_position_embeddings=max_seq_len
        self.bert_model = BertModel(bert.config)
        self.classifier = torch.nn.Linear(model_dim, 26)
        self.config = bert.config
        self.guessed_letters = []
        self.alphabet = np.array([chr(97+i) for i in range(26)])
        self.tokenizer = tokenizer
        self.to(device)
        self.device = device
        
    def flush(self):
        """ Resets the Hangman game.
        """
        self.guessed_letters = []
        
    def guess(self, pattern):
        if len(pattern)>self.tokenizer.max_seq_len:
            """ Less than 0.1% of all training words have
                length larger than 20, so we "ignore" those.
            """
            letters_argsort = range(26)
        else:
            inp = self.tokenizer(pattern)
            inp['input_ids'] = inp['input_ids'].unsqueeze(0)
            inp['attention_mask'] = inp['attention_mask'].unsqueeze(0)
            out = F.softmax(self(inp), dim=-1).squeeze()
            letters_argsort = sorted(range(26), key=lambda i: out[i], reverse=True)
        letters = self.alphabet[letters_argsort]
        for letter in letters:
            if letter not in self.guessed_letters:
                self.guessed_letters.append(letter)
                return letter
    
    def forward(self, inp):
        attention_mask = inp['attention_mask'][:, None, None, :].to(self.device)
        input_ids = inp['input_ids'].to(self.device)
        x = self.bert_model.embeddings(input_ids)
        x = self.bert_model.encoder(x, attention_mask=attention_mask)
        x = mean_pooling(x['last_hidden_state'], inp['attention_mask'])
        x = self.classifier(x)
        return x

In [4]:
class HangmanTokenizer():
    def __init__(self, max_seq_len, device):
        self.max_seq_len = max_seq_len
        self.special_idxs = {
            'pad_idx': 27,
            'dot_idx': 26}
        self.device=device
        
    def __call__(self, pattern, word=None):
        """ Tokenizes a given pattern (word).
        """
        inp = torch.zeros(self.max_seq_len)
        attn_mask = torch.zeros(self.max_seq_len)
        fill_mask = torch.zeros(26)
        label = torch.zeros(26)
        for j in range(len(pattern)):
            if pattern[j].isalpha():
                inp[j] = ord(pattern[j])-97
                fill_mask[ord(pattern[j])-97] = -1e9
            else:
                inp[j] = self.special_idxs['dot_idx']
                if word is not None:
                    label[ord(word[j])-97] = 1
                    y = torch.LongTensor([ord(word[j])-97])
                else:
                    y = torch.LongTensor([0])
            attn_mask[j] = 1
        for j in range(len(pattern), self.max_seq_len):
            inp[j] = self.special_idxs['pad_idx']
        out = {'input_ids': inp.long().to(self.device),
               'attention_mask': attn_mask.long().to(self.device),
               'fill_mask': fill_mask.float().to(self.device),
               'label': label.float().to(self.device),
               'y': y.to(self.device)}
        return out
        

class HangmanDatasetStage1(Dataset):
    """ This dataset realizes masking of a
        single letter in a word (for Stage 1).
    """
    def __init__(self, dictionary, tokenizer):
        self.dictionary = []
        for word in dictionary:
            if len(word) <= tokenizer.max_seq_len:
                self.dictionary.append(word)
        self.tokenizer = tokenizer
        
    def __len__(self):
        return len(self.dictionary)
        
    def __getitem__(self, i):
        word = self.dictionary[i]
        idx = np.random.choice(range(len(word)))
        pattern = ""
        for i,letter in enumerate(word):
            if i!=idx:
                pattern += letter
            else:
                pattern += '.'
        return self.tokenizer(pattern, word=word)


class HangmanDatasetStage2(Dataset):
    """ This dataset realizes masking of all instances of
        letters in a random subcollection of
        letters in a word (for Stage 2).
    """
    def __init__(self, dictionary, tokenizer):
        self.dictionary = []
        for word in dictionary:
            if len(word) <= tokenizer.max_seq_len:
                self.dictionary.append(word)
        self.tokenizer = tokenizer
        
    def __len__(self):
        return len(self.dictionary)
        
    def __getitem__(self, i):
        word = self.dictionary[i]
        word_letters = set(list(word))
        if len(word_letters)>1:
            num_open = np.random.choice(range(1,len(word_letters)))
            open_letters = np.random.choice(
                list(word_letters),
                size=num_open,
                replace=False)
            open_letters = set(open_letters.tolist())
        else:
            open_letters = set([])
        pattern = ""
        for letter in word:
            if letter in open_letters:
                pattern += letter
            else:
                pattern += '.'
        return self.tokenizer(pattern, word=word)
        

def collate_fn(batch):
    return {'input_ids': torch.vstack([inp['input_ids'] for inp in batch]),
        'attention_mask': torch.vstack([inp['attention_mask'] for inp in batch]),
        'fill_mask': torch.vstack([inp['fill_mask'] for inp in batch]),
        'label': torch.vstack([inp['label'] for inp in batch]),
        'y': torch.cat([inp['y'] for inp in batch])}

In [5]:
train_dictionary, test_dictionary = build_dictionaries()
tokenizer = HangmanTokenizer(max_seq_len=20, device=device)
dataset = HangmanDatasetStage1(train_dictionary, tokenizer)
dataloader = DataLoader(
    dataset,
    batch_size=512,
    shuffle=True,
    collate_fn=collate_fn)
model = HangmanTransformer(
    max_seq_len=20,
    model_dim=512,
    latent_dim=1024,
    tokenizer=tokenizer,
    device=device,
    num_heads=8,
    num_layers=12,
    vocab_size=26+len(tokenizer.special_idxs))

### Stage 1: Pre-training
In this stage, we pre-train ```HangmanTransformer``` using "masked word modeling". We use a learning rate linear warmup for 10 epochs followed by exponential decay. Gamma is computed such that the final learning rate is 50 times smaller than the maximum one (reached at the end of warmup).

In [6]:
gamma = (1./50)**(1./75)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=2e-4,
    weight_decay=1e-5)
scheduler_warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,
    end_factor=1.0,
    total_iters=10)
scheduler_decay = torch.optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma=gamma)
schedulers = [scheduler_warmup, scheduler_decay]
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers,
    milestones=[10])
criterion = torch.nn.CrossEntropyLoss(reduction='mean')

for epoch in range(75):
    losses = []
    acc = 0.
    for batch in dataloader:
        optimizer.zero_grad()
        y = batch['y']
        out = model(batch)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        acc += sum(torch.argmax(out, dim=-1) == y)
    acc = acc/len(dataset)
    lr = optimizer.param_groups[0]['lr']
    print(f'[epoch: {epoch}][loss/acc: {np.mean(losses):.4f}/{acc:.4f}][lr: {lr:.5f}]')
    scheduler.step()
    
state_dict = model.state_dict()
#torch.save(state_dict, 'model-checkpoint-stage-1.pt')

[epoch: 0][loss/acc: 2.7703/0.1515][lr: 0.00002]
[epoch: 1][loss/acc: 2.5372/0.2112][lr: 0.00004]
[epoch: 2][loss/acc: 2.1394/0.3147][lr: 0.00006]
[epoch: 3][loss/acc: 1.9439/0.3678][lr: 0.00007]
[epoch: 4][loss/acc: 1.8411/0.3988][lr: 0.00009]
[epoch: 5][loss/acc: 1.7687/0.4224][lr: 0.00011]
[epoch: 6][loss/acc: 1.7089/0.4420][lr: 0.00013]
[epoch: 7][loss/acc: 1.6662/0.4560][lr: 0.00015]
[epoch: 8][loss/acc: 1.6250/0.4692][lr: 0.00016]
[epoch: 9][loss/acc: 1.5897/0.4804][lr: 0.00018]




[epoch: 10][loss/acc: 1.5631/0.4893][lr: 0.00020]
[epoch: 11][loss/acc: 1.5200/0.5028][lr: 0.00019]
[epoch: 12][loss/acc: 1.4760/0.5167][lr: 0.00018]
[epoch: 13][loss/acc: 1.4392/0.5287][lr: 0.00017]
[epoch: 14][loss/acc: 1.4043/0.5396][lr: 0.00016]
[epoch: 15][loss/acc: 1.3779/0.5496][lr: 0.00015]
[epoch: 16][loss/acc: 1.3494/0.5572][lr: 0.00015]
[epoch: 17][loss/acc: 1.3220/0.5665][lr: 0.00014]
[epoch: 18][loss/acc: 1.2979/0.5751][lr: 0.00013]
[epoch: 19][loss/acc: 1.2754/0.5818][lr: 0.00013]
[epoch: 20][loss/acc: 1.2505/0.5903][lr: 0.00012]
[epoch: 21][loss/acc: 1.2280/0.5970][lr: 0.00011]
[epoch: 22][loss/acc: 1.2083/0.6036][lr: 0.00011]
[epoch: 23][loss/acc: 1.1871/0.6107][lr: 0.00010]
[epoch: 24][loss/acc: 1.1705/0.6169][lr: 0.00010]
[epoch: 25][loss/acc: 1.1582/0.6208][lr: 0.00009]
[epoch: 26][loss/acc: 1.1375/0.6271][lr: 0.00009]
[epoch: 27][loss/acc: 1.1179/0.6323][lr: 0.00008]
[epoch: 28][loss/acc: 1.1092/0.6359][lr: 0.00008]
[epoch: 29][loss/acc: 1.0948/0.6408][lr: 0.00007]


### Stage 2: Finetuning
Now that we pre-trained a model to recognize word structures and grammar, we finetune it on a more realistic data distribution (with more than one blanks in words) as generated by ```HangmanDatasetStage2```. Now, the objective function is to guess all of the letters that are masked out in the input, so we use a multi-label BCE loss. In addition, we utilize the fact that none of the revealed letters can be masked and effectively "mask" them out in the output by using ```fill_mask``` returned by the tokenizer. To pick the best model during training, we perform validation using a hold-out dictionary.

In [6]:
def validate(model, dictionary, verbose=False, n=100):
    """ Hangman game simulator
    """
    successes = 0.
    words = np.random.choice(dictionary, size=n, replace=False)
    for word in tqdm(words):
        model.flush()
        pattern = ['.']*len(word)
        if verbose:
            print(f'[GAME]: {"".join(pattern)}')
        missed = 0
        while True:
            letter = model.guess("".join(pattern))
            guess_correct = False
            for i in range(len(word)):
                if word[i] == letter:
                    pattern[i] = letter
                    guess_correct = True
            if not guess_correct:
                missed += 1
            if missed == 6:
                if verbose:
                    print(f'[FAIL][word: {word}]')
                break
            elif '.' not in pattern:
                successes += 1
                if verbose:
                    print(f'[OK!][word: {word}]')
                break
            else:
                if verbose:
                    print(f'-[{6-missed}]: {"".join(pattern)} ({letter})')
    return successes/n

In [10]:
state_dict = torch.load('model-checkpoint-stage-1.pt')
model.load_state_dict(state_dict)
gamma = (1./50)**(1./50)
optimizer = torch.optim.Adam(
    [{'params': model.bert_model.parameters(), 'lr': 1e-5},
    {'params': model.classifier.parameters(), 'lr': 1e-4}],
    lr=1e-4)
scheduler_warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,
    end_factor=1.0,
    total_iters=10)
scheduler_decay = torch.optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma=gamma)
schedulers = [scheduler_warmup, scheduler_decay]
criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers,
    milestones=[10])
dataset = HangmanDatasetStage2(train_dictionary, tokenizer)
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    collate_fn=collate_fn)

best_model = {
    'accuracy': 0,
    'model': None}
for epoch in range(50):
    losses = []
    if epoch % 5 == 0:
        model.eval()
        acc = validate(model, test_dictionary, n=100)
        if best_model['accuracy'] < acc:
            best_model['model'] = model.state_dict()
        torch.save(best_model['model'], 'model-checkpoint-stage-2.pt')
        model.train()
    for batch in dataloader:
        optimizer.zero_grad()
        out = model(batch)
        out = out + batch['fill_mask']
        loss = criterion(out, batch['label'])
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    scheduler.step()
    print(f'[epoch: {epoch}][loss/acc: {np.mean(losses):.4f}/{acc:.4f}]')
    
model.eval()
acc = validate(model, test_dictionary, n=200)
if best_model['accuracy'] < acc:
    best_model['model'] = model.state_dict()
#torch.save(best_model['model'], 'model-checkpoint-stage-2.pt')

100%|██████████| 100/100 [00:12<00:00,  8.10it/s]


[epoch: 0][loss/acc: 0.3371/0.3100]
[epoch: 1][loss/acc: 0.2823/0.3100]
[epoch: 2][loss/acc: 0.2726/0.3100]
[epoch: 3][loss/acc: 0.2676/0.3100]
[epoch: 4][loss/acc: 0.2651/0.3100]


100%|██████████| 100/100 [00:13<00:00,  7.61it/s]


[epoch: 5][loss/acc: 0.2617/0.6000]
[epoch: 6][loss/acc: 0.2598/0.6000]
[epoch: 7][loss/acc: 0.2583/0.6000]
[epoch: 8][loss/acc: 0.2574/0.6000]
[epoch: 9][loss/acc: 0.2554/0.6000]


100%|██████████| 100/100 [00:12<00:00,  7.78it/s]


[epoch: 10][loss/acc: 0.2546/0.5500]
[epoch: 11][loss/acc: 0.2532/0.5500]
[epoch: 12][loss/acc: 0.2525/0.5500]
[epoch: 13][loss/acc: 0.2513/0.5500]
[epoch: 14][loss/acc: 0.2513/0.5500]


100%|██████████| 100/100 [00:12<00:00,  7.79it/s]


[epoch: 15][loss/acc: 0.2506/0.5200]
[epoch: 16][loss/acc: 0.2490/0.5200]
[epoch: 17][loss/acc: 0.2493/0.5200]
[epoch: 18][loss/acc: 0.2487/0.5200]
[epoch: 19][loss/acc: 0.2485/0.5200]


100%|██████████| 100/100 [00:12<00:00,  7.87it/s]


[epoch: 20][loss/acc: 0.2482/0.6200]
[epoch: 21][loss/acc: 0.2480/0.6200]
[epoch: 22][loss/acc: 0.2478/0.6200]
[epoch: 23][loss/acc: 0.2473/0.6200]
[epoch: 24][loss/acc: 0.2467/0.6200]


100%|██████████| 100/100 [00:12<00:00,  7.97it/s]


[epoch: 25][loss/acc: 0.2474/0.6200]
[epoch: 26][loss/acc: 0.2470/0.6200]
[epoch: 27][loss/acc: 0.2471/0.6200]
[epoch: 28][loss/acc: 0.2465/0.6200]
[epoch: 29][loss/acc: 0.2465/0.6200]


100%|██████████| 100/100 [00:12<00:00,  8.10it/s]


[epoch: 30][loss/acc: 0.2458/0.6100]
[epoch: 31][loss/acc: 0.2462/0.6100]
[epoch: 32][loss/acc: 0.2457/0.6100]
[epoch: 33][loss/acc: 0.2460/0.6100]
[epoch: 34][loss/acc: 0.2458/0.6100]


100%|██████████| 100/100 [00:12<00:00,  8.09it/s]


[epoch: 35][loss/acc: 0.2459/0.5800]
[epoch: 36][loss/acc: 0.2459/0.5800]
[epoch: 37][loss/acc: 0.2463/0.5800]
[epoch: 38][loss/acc: 0.2461/0.5800]
[epoch: 39][loss/acc: 0.2455/0.5800]


100%|██████████| 100/100 [00:12<00:00,  7.83it/s]


[epoch: 40][loss/acc: 0.2445/0.6900]
[epoch: 41][loss/acc: 0.2455/0.6900]
[epoch: 42][loss/acc: 0.2455/0.6900]
[epoch: 43][loss/acc: 0.2452/0.6900]
[epoch: 44][loss/acc: 0.2451/0.6900]


100%|██████████| 100/100 [00:12<00:00,  7.99it/s]


[epoch: 45][loss/acc: 0.2454/0.5600]
[epoch: 46][loss/acc: 0.2456/0.5600]
[epoch: 47][loss/acc: 0.2451/0.5600]
[epoch: 48][loss/acc: 0.2449/0.5600]
[epoch: 49][loss/acc: 0.2453/0.5600]


100%|██████████| 200/200 [00:25<00:00,  7.83it/s]


## Verbose Evaluation

In [12]:
tokenizer = HangmanTokenizer(max_seq_len=20, device=device)
final_model = HangmanTransformer(
    max_seq_len=20,
    model_dim=512,
    latent_dim=1024,
    tokenizer=tokenizer,
    device=device,
    num_heads=8,
    num_layers=12,
    vocab_size=26+len(tokenizer.special_idxs))
state_dict = torch.load('model-checkpoint-stage-2.pt')
final_model.load_state_dict(state_dict)
final_model.eval()
print(f'[Model ready]')

[Model ready]


In [13]:
# Multiple eval.

N = 500
word = np.random.choice(test_dictionary, size=N, replace=False)
acc = validate(final_model, word, verbose=False, n=N)
print(f'[accuracy: {acc:.3f}]')

100%|██████████| 500/500 [01:03<00:00,  7.84it/s]

[accuracy: 0.624]





In [14]:
# Single eval.

word = np.random.choice(test_dictionary)
acc = validate(final_model, [word], verbose=True, n=1)

100%|██████████| 1/1 [00:00<00:00,  9.09it/s]

[GAME]: .......
-[5]: ....... (e)
-[4]: ....... (a)
-[4]: .....r. (r)
-[4]: ...s.r. (s)
-[4]: .o.sor. (o)
-[4]: .onsor. (n)
-[4]: consor. (c)
[OK!][word: consort]



