In [1]:
import pickle
import torch
from transformers import BertTokenizer, BertModel
import torch.nn as nn

In [2]:
# Question-word vocabulary: each question word/phrase -> class index.
# The model head is sized with len(qw_to_idx) and predictions are mapped back
# through idx_to_qw, so the indices are expected to be exactly 0..N-1.
qw_to_idx = {
    'how long': 0,
    'how': 1,
    'whom': 2,
    'how many': 3,
    'when': 4,
    'whose': 5,
    'what': 6,
    'who': 7,
    'where': 8,
    'why': 9,
    'how much': 10,
    'which': 11,
}

# Reverse lookup: class index -> question word (used to decode argmax predictions).
idx_to_qw = dict(zip(qw_to_idx.values(), qw_to_idx.keys()))
inverted_dict = idx_to_qw  # legacy alias of idx_to_qw; kept in case other cells reference it

qw_to_idx


{'how long': 0,
 'how': 1,
 'whom': 2,
 'how many': 3,
 'when': 4,
 'whose': 5,
 'what': 6,
 'who': 7,
 'where': 8,
 'why': 9,
 'how much': 10,
 'which': 11}

In [3]:
# Model
class QWPModel(nn.Module):
    """BERT encoder plus a linear head that predicts the question word for ``<qw>``.

    The model is a pretrained BERT encoder, dropout, and a single linear layer.
    It produces one logit per question-word class.
    """

    def __init__(
        self,
        num_classes: int,
        tokenizer,
        pretrained_model_name: str = "bert-base-uncased",
        dropout_prob: float = 0.1,
    ):
        """Build the encoder and the classification head.

        Args:
            num_classes: number of output classes (question words).
            tokenizer: tokenizer whose vocabulary size (``len(tokenizer)``) the
                embedding table is resized to, so added tokens such as ``<qw>``
                have embedding rows.
            pretrained_model_name: Hugging Face checkpoint used to initialise BERT.
                Should match the checkpoint the tokenizer was loaded from.
            dropout_prob: dropout probability applied before the classifier.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        # Grow the embedding matrix so ids of tokens added to the tokenizer
        # (e.g. "<qw>") are valid indices.
        self.bert.resize_token_embeddings(len(tokenizer))
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits.

        Args:
            input_ids: token-id tensor, presumably shaped (batch, seq_len).
            attention_mask: mask tensor with the same shape as ``input_ids``.

        Returns:
            Logits with last dimension ``num_classes`` (presumably (batch, num_classes)).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # pooler_output is the final [CLS] hidden state passed through BERT's
        # pooler (dense layer + tanh), not the raw [CLS] vector.
        pooled = outputs.pooler_output
        return self.classifier(self.dropout(pooled))


In [None]:
def predict_qw_index(model, tokenizer, input_text: str, device, max_length: int = 128) -> int:
    """Predict the class index of the question word for one encoded input.

    Args:
        model: callable ``model(input_ids, attention_mask)`` returning logits
            whose second dimension indexes the classes.
        tokenizer: Hugging Face-style tokenizer called with ``return_tensors="pt"``.
        input_text: the text to classify, e.g. ``"<qw> is ...? [SEP] answer"``.
        device: torch device the tensors are moved to.
        max_length: pad/truncate length passed to the tokenizer.

    Returns:
        The argmax class index of the first (only) item in the batch, as an int.
    """
    encoding = tokenizer(
        input_text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # return_tensors="pt" on a single string already yields a batch of size 1.
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
    return int(torch.argmax(logits, dim=1)[0].item())


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokenizer.add_tokens(["<qw>"])
    model = QWPModel(num_classes=len(qw_to_idx), tokenizer=tokenizer)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine;
    # weights_only restricts unpickling to tensors/containers, which is all a
    # state_dict needs.
    state_dict = torch.load("best_qwp_model.pth", map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    while True:
        try:
            sample_question = input("Masked question: ")
            if sample_question.strip() == "exit":
                break
            sample_answer = input("Answer: ")
        except EOFError:  # stdin closed (Ctrl-D / piped input exhausted)
            break
        input_text = f"{sample_question} [SEP] {sample_answer}"
        pred_idx = predict_qw_index(model, tokenizer, input_text, device)

        print(f"Input: {input_text}\n QW: {idx_to_qw[pred_idx]}")

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Input: <qw> is the capital of Sweden? [SEP] Stockholm
 QW: what
Input: <qw> are you? [SEP] fine
 QW: what
