In [1]:
import math
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

# Ollama

In [2]:
import ollama

In [27]:
model_name_ollama = 'deepseek-r1:1.5b'
context = 'You have to guess the next word of the following text written by a human, please just anwser the missing word :'
text = 'I am the last US president, my name is'
prompt = f'{context}\n\n{text}'

response = ollama.chat(model=model_name_ollama, messages=[{"role": "user", "content": prompt}])

print(response.get('message', {}).get('content', ''))

<think>
Okay, so I need to figure out the next word in this sentence: "I am the last US president, my name is __." Let's break this down step by step.

First, let me look at the beginning of the sentence: "I am the last US president." The speaker is talking about their current status as a U.S. President. They've just become the last one in the presidency, which is pretty cool because that means they're done with their term. So, the key here is to continue that phrase.

Now, after "president," we can think of other words related to the role of president or possibly moving forward into the next chapter. The common word that comes to mind immediately is "life." Because being a president typically involves serving long-term, and eventually, one might consider handing over the presidency, so perhaps they're about to finish their life as a president.

Let me check if there are other possible words. Words like "legacy" come to mind because it refers to the lasting impact or succession of some

# Inference

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

In [6]:
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

You are a text reviewing tool utilizing a LLM to correct bad tokens in a text. We input you the beginning of the text and you will only compute the logits of the next token. We then represent the logit given to the actual word for the reviewer. The logit is converted to a probability using softmax and is represented with a color in logarithmic scale. The user can then understand if it made a mistake if the probability is too low. Make extra caution to avoid predicting any factual or grammatical errors by giving them high logits. You have no choice but to predict correct tokens, and you cannot assign high logits to wrong tokens.

Exemple of input prompt :
Water boils at 100°C at standard atmospheric

You will try to predict the next token, a good prediction would be "pressure". If the actual token written by the user was "pressure" the user would see that the prediction is good. Otherwise, if the user writes something bad, for instance "temperature", which doesn't make sense, the logit of this token should be low and the user would see its mistake.

Exemple of input prompt :
I love eleph

You will try to predict the next token, a good prediction would be "ants". If the actual token written by the user was "ants" the user would see that the prediction is good. Otherwise, if the user writes something bad, for instance "ent", which is a spelling mistake, the logit of this token should be low and the user would see its mistake.

This is the input prompt :

In [30]:
text = "Football is a fruit"

base_prompt = """You are a text reviewing tool utilizing a LLM to correct bad tokens in a text. We input you the beginning of the text and you will only compute the logits of the next token. We then represent the logit given to the actual word for the reviewer. The logit is converted to a probability using softmax and is represented with a color in logarithmic scale. The user can then understand if it made a mistake if the probability is too low. Make extra caution to avoid predicting any factual or grammatical errors by giving them high logits. You have no choice but to predict correct tokens, and you cannot assign high logits to wrong tokens.

Exemple of input prompt :
Water boils at 100°C at standard atmospheric

You will try to predict the next token, a good prediction would be "pressure". If the actual token written by the user was "pressure" the user would see that the prediction is good. Otherwise, if the user writes something bad, for instance "temperature", which doesn't make sense, the logit of this token should be low and the user would see its mistake.

Exemple of input prompt :
I love eleph

You will try to predict the next token, a good prediction would be "ants". If the actual token written by the user was "ants" the user would see that the prediction is good. Otherwise, if the user writes something bad, for instance "ent", which is a spelling mistake, the logit of this token should be low and the user would see its mistake.

This is the input prompt :"""

base_ids = tokenizer.encode(base_prompt, return_tensors="pt")

input_ids = tokenizer.encode(text, return_tensors="pt")

words = []
probabilities = []
preferreds = []
ranks = []
all_prob_pairs = []

with torch.no_grad(): # inference
    for i in range(len(input_ids[0]) - 1):
        
        # merge the pre-prompt (base prompt) with the actual prompt
        # TODO fix base prompt often giving worse results
        full_input_ids = torch.cat([base_ids, input_ids[:, :i + 1][:, 1:]], dim=-1) # input of the model : all of the text before the token it has to guess
        #context = base_prompt_ids + input_tokens[:, :i + 1]
        
        # TODO maybe find better way instead of reinputing new context everytime
        outputs = model(full_input_ids)
        logits = outputs.logits[:, -1, :] # between -inf and +inf
        probs = torch.softmax(logits, dim=-1) # between 0 and 1

        vocab_size = probs.shape[-1]
        token_ids = torch.arange(vocab_size)
        tokens = [tokenizer.decode([token_id]) for token_id in token_ids] # TODO optimize?
        token_probs = probs[0].tolist() # shape (1, vocab_size) to vocab_size
        
        token_prob_pairs = list(zip(tokens, token_probs))
        sorted_token_prob_pairs = sorted(token_prob_pairs, key=lambda x: x[1], reverse=True)
        dict_token_prob_pairs = dict(token_prob_pairs) # TODO useless if tried to be ranked, optimize?

        current_token = tokenizer.decode([input_ids[0, i].item()])
        
        next_token = tokenizer.decode([input_ids[0, i+1].item()]) # next word (to be predicted)
        prob = dict_token_prob_pairs[next_token] # probability of word predicted at best by LLM
        pref = sorted_token_prob_pairs[0][0] # best predicted word by LLM
        
        # TODO optimize?
        for index, (word, prob) in enumerate(sorted_token_prob_pairs):
            if word == next_token:
                rank = index + 1
                break
        
        words.append(next_token)
        probabilities.append(prob)
        preferreds.append(pref)
        all_prob_pairs.append(sorted_token_prob_pairs[:10])
        ranks.append(rank)
        
        guess = [repr(e[0]) for e in sorted_token_prob_pairs[:5]]
        print(f"{i/len(input_ids[0])*100:.1f}%, current_token : {repr(current_token)}, next_token : {repr(next_token)}, prob : {prob:.4f}, rank : {rank}, predicted :")
        print(f"{'\n'.join(guess)}")

0.0%, current_token : '<｜begin▁of▁sentence｜>', next_token : 'Football', prob : 0.0000, rank : 129442, predicted :
' "'
' \n\n'
' \n'
' This'
' ['
20.0%, current_token : 'Football', next_token : ' is', prob : 0.2653, rank : 1, predicted :
' is'
' players'
' teams'
','
'\n'
40.0%, current_token : ' is', next_token : ' a', prob : 0.3835, rank : 1, predicted :
' a'
' the'
' an'
' one'
' being'
60.0%, current_token : ' a', next_token : ' fruit', prob : 0.0000, rank : 895, predicted :
' game'
' sport'
' very'
' team'
' popular'


In [None]:
words, probabilities

In [31]:
def create_custom_cmap():
    colors = [
        (0.8, 0.1, 0.1),
        (1.0, 1.0, 1.0),
        (0.0, 0.5, 0.0),
    ]
    p_tran = 0.00001
    positions = [0, p_tran, 1] # transition at p_tran
    return LinearSegmentedColormap.from_list("custom_red_white_green", list(zip(positions, colors)))

def probability_to_color(prob, colormap):
    rgba = colormap(prob)
    return "#{:02x}{:02x}{:02x}".format(int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255))

def probability_to_color_plt(prob, colormap_name="Greens"):
    colormap = plt.get_cmap(colormap_name)
    rgba = colormap(prob)
    return "#{:02x}{:02x}{:02x}".format(int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255))

custom_cmap = create_custom_cmap()

html = []
html.append("<html><body style='padding: 20px; font-family: Arial; line-height: 2.0;'>")
html.append("""
<html>
<head>
<style>
  body { padding: 20px; font-family: Arial, sans-serif; line-height: 2.0; }
  h1 { text-align: center; color: #333; font-size: 2.5em; margin-bottom: 20px; }
  .word { background-color: #f0f0f0; padding: 2px 5px; border-radius: 3px; outline: 1px solid rgba(0, 0, 0, 0.2); font-weight: bold; }
  h2 { color: #333; font-size: 1.8em; margin-bottom: 10px; }
  .word { background-color: #f0f0f0; padding: 2px 5px; border-radius: 3px; outline: 1px solid rgba(0, 0, 0, 0.2); font-weight: bold; }
</style>
</head>
<body>
    """)
html.append("<h1>Text review using LLMs</h1>")
html.append("<h2>Example text with no errors, factual statements</h2>")
html.append("<h2>Example text altered</h2>")

body = []
for word, prob, pref, rank in zip(words, probabilities, preferreds, ranks):

    cmap = 'Greens_r'
    alpha = 100
    if rank > 100:
        cmap = 'Reds'
        alpha = 200
    #alpha = 255
    #cmap = 'jet' # jet gives sharp insight but not nice to see
    color = probability_to_color_plt(prob, cmap)
    color = probability_to_color(prob, custom_cmap)
    
    colors = [(0.0, 0.5, 0.0),
              (1.0, 1.0, 1.0),
              (0.8, 0.1, 0.1)]

    cmap = LinearSegmentedColormap.from_list('red_white_green', colors)
    color = probability_to_color_plt(np.log10(prob)/np.log10(1e-8), cmap)

    bonus = "font-weight: bold;" if pref == word else ""
    bonus = ''
    
    style = f"background-color: {color}{alpha:02x}; padding: 2px 5px; border-radius: 3px; outline: 1px solid rgba(0, 0, 0, .2); {bonus}"
    body.append(f"<span style='{style}' title='p={prob:.4e}, rank={rank}, pref={pref}'>{word}</span> ")

html = html + body
html.append("</body></html>")

with open("index.html", "w") as f:
    f.write("\n".join(html))

In [32]:
print("<h2>Example text altered with base prompt</h2>\n"+"\n".join(body))

<h2>Example text altered with base prompt</h2>
<span style='background-color: #cc1919c8; padding: 2px 5px; border-radius: 3px; outline: 1px solid rgba(0, 0, 0, .2); ' title='p=3.7340e-09, rank=129442, pref= "'>Football</span> 
<span style='background-color: #24912464; padding: 2px 5px; border-radius: 3px; outline: 1px solid rgba(0, 0, 0, .2); ' title='p=2.6534e-01, rank=1, pref= is'> is</span> 
<span style='background-color: #1a8c1a64; padding: 2px 5px; border-radius: 3px; outline: 1px solid rgba(0, 0, 0, .2); ' title='p=3.8352e-01, rank=1, pref= a'> a</span> 
<span style='background-color: #f9e6e6c8; padding: 2px 5px; border-radius: 3px; outline: 1px solid rgba(0, 0, 0, .2); ' title='p=3.8588e-05, rank=895, pref= game'> fruit</span> 


In [33]:
from IPython.display import display, HTML
display(HTML("\n".join(html)))