In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

def load_model_and_tokenizer(model_name, device):
    """Load a causal language model and its tokenizer by name.

    Args:
        model_name: Identifier passed unchanged to both ``from_pretrained`` calls.
        device: Target passed to the model's ``.to(...)`` call.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been moved to ``device``.
    """
    # Model first, tokenizer second: keeps the same call failing first on a bad name.
    causal_lm = AutoModelForCausalLM.from_pretrained(model_name)
    causal_lm = causal_lm.to(device)

    text_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return causal_lm, text_tokenizer

def predict_next_word(text, model, tokenizer, top_k, device):
    """Predict the most likely next token(s) following ``text``.

    Runs one forward pass with gradient tracking disabled, takes the logits at
    the last input position, and reports the highest-scoring candidates with
    their softmax probabilities. The unit of prediction is a tokenizer token,
    so a "word" here may be a word piece, punctuation, or carry leading
    whitespace.

    Args:
        text: Prompt string to continue.
        model: Callable that takes the input-id tensor and returns an object
            whose ``logits`` attribute is indexed as ``[batch, position, vocab]``.
        tokenizer: Object providing ``encode(text, return_tensors='pt')`` and
            ``decode(token_id)``.
        top_k: Maximum number of candidates to return. Values larger than the
            vocabulary size are clamped to it instead of making ``torch.topk``
            raise.
        device: Device the input ids are moved to before the forward pass.

    Returns:
        A dict with:
            ``input_text``: the prompt as given,
            ``input_ids``: list of token ids for the prompt,
            ``predicted_next_word``: decoded highest-scoring token,
            ``top_predictions``: list of ``(decoded_token, probability)`` pairs
            ordered from most to least likely.

    Raises:
        ValueError: If ``text`` tokenizes to zero tokens, since there is no
            last position to read a prediction from.
    """
    input_ids = tokenizer.encode(text, return_tensors='pt').to(device)
    if input_ids.numel() == 0:
        raise ValueError(
            "Input text produced no tokens; provide a non-empty prompt."
        )

    with torch.no_grad():
        logits = model(input_ids).logits

    # Slice the last position once; everything below is derived from it.
    last_logits = logits[0, -1, :]
    probabilities = torch.softmax(last_logits, dim=-1)

    predicted_index = torch.argmax(last_logits).item()
    predicted_token = tokenizer.decode(predicted_index)

    # torch.topk rejects k larger than the dimension size, so cap the request.
    k = min(top_k, last_logits.size(-1))
    top_indices = torch.topk(last_logits, k).indices

    predicted_tokens = [tokenizer.decode(idx) for idx in top_indices.tolist()]
    predicted_probabilities = probabilities[top_indices].tolist()

    return {
        'input_text': text,
        'input_ids': input_ids.tolist()[0],
        'predicted_next_word': predicted_token,
        'top_predictions': list(zip(predicted_tokens, predicted_probabilities))
    }

def display_predictions(predictions):
    """Print a prediction result to standard output.

    Args:
        predictions: Mapping with the keys ``'input_text'``,
            ``'top_predictions'`` (an iterable of ``(token, score)`` pairs) and
            ``'predicted_next_word'``.

    Output is the echoed prompt, one ``token : score`` row per candidate with
    the score shown to two decimals, then the single predicted token.
    """
    prompt = predictions['input_text']
    # Two positional arguments: print's default separator adds the second space.
    print("\nInput Text: ", prompt)

    print("\nToken Predictions:")
    ranked = predictions['top_predictions']
    for pair in ranked:
        token, score = pair
        print(f"{token} : {score:.2f}")

    best_guess = predictions['predicted_next_word']
    print("\nPredicted Next Word:", best_guess)

if __name__ == "__main__":
    # Interactive driver: load the model once, then predict the next token for
    # each prompt the user types until they enter 'exit'.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model_name = "microsoft/phi-1"
    model, tokenizer = load_model_and_tokenizer(model_name, device)

    while True:
        try:
            text = input("Enter a sentence (or type 'exit' to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or an interrupted prompt ends the session without a
            # traceback.
            print()
            break
        if text.lower() == "exit":
            break
        if not text:
            # A blank line gives the model nothing to condition on and may
            # tokenize to zero tokens; ask again instead of crashing the loop.
            continue
        predictions = predict_next_word(text, model, tokenizer, top_k=5, device=device)
        display_predictions(predictions)

Using device: cuda


Enter a sentence (or type 'exit' to quit):  the quick brown


  attn_output = torch.nn.functional.scaled_dot_product_attention(
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)



Input Text:  the quick brown

Token Predictions:
 fox : 0.95
" : 0.01
 jumps : 0.01
", : 0.01
', : 0.00

Predicted Next Word:  fox


Enter a sentence (or type 'exit' to quit):  hello



Input Text:  hello

Token Predictions:
 world : 0.34
") : 0.14
", : 0.13
" : 0.10
_ : 0.08

Predicted Next Word:  world


Enter a sentence (or type 'exit' to quit):  exit
