In [24]:
import torch
from transformers import BertTokenizer, BertForMaskedLM

# Load pre-trained BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define a function to predict next words
def predict_next_words(text, top_k=10):
    # Tokenize the input text
    tokenized_text = tokenizer.tokenize(text)
    
    # Add [CLS] and [SEP] tokens
    tokenized_text = ["[CLS]"] + tokenized_text + ["[SEP]"]
    
    # Convert tokens to IDs
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    
    # Mask the last token
    masked_index = len(indexed_tokens) - 1
    indexed_tokens[masked_index] = tokenizer.mask_token_id
    
    # Convert indexed tokens to tensor
    tokens_tensor = torch.tensor([indexed_tokens]).to(device)

    # Generate predictions
    with torch.no_grad():
        outputs = model(tokens_tensor)

    # Get the predicted probabilities for the masked token
    predictions = outputs[0]
    predicted_probabilities = predictions[0, masked_index].cpu()

    # Get top-k predicted tokens
    top_k_probabilities, top_k_indices = predicted_probabilities.topk(top_k)

    # Convert token IDs back to words
    top_k_words = tokenizer.convert_ids_to_tokens(top_k_indices.tolist())

    return top_k_words



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [28]:
# Example usage
text = "machine learning and"
next_words = predict_next_words(text)
print("Next words:", next_words)


Next words: ['science', 'engineering', 'its', 'design', 'the', 'application', '-', 'and', 'analysis', 'development']
