In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

In [4]:
def load_model(model_id="Qwen/Qwen3-0.6B", dtype=torch.float16):
    """Load a causal language model and its tokenizer from one checkpoint id.

    Args:
        model_id: Identifier handed to both ``from_pretrained`` calls.
        dtype: Value forwarded as ``torch_dtype`` when the model is loaded.

    Returns:
        A ``(tokenizer, model)`` tuple; ``eval()`` has already been called on
        the model.
    """
    tok = AutoTokenizer.from_pretrained(model_id)

    load_kwargs = {"device_map": "auto", "torch_dtype": dtype}
    lm = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    lm.eval()

    return tok, lm

In [5]:
tokenizer, model = load_model()
model.tokenizer = tokenizer  # attach the tokenizer to the model so generate_next_token can decode with it

In [6]:
def rgb_ansi(r, g, b):
    """Build the ``ESC[38;2;r;g;bm`` escape sequence (24-bit foreground colour)."""
    return "\033[38;2;{};{};{}m".format(r, g, b)

def bg_rgb_ansi(r, g, b):
    """Build the ``ESC[48;2;r;g;bm`` escape sequence (24-bit background colour)."""
    return "\033[48;2;{};{};{}m".format(r, g, b)

def color_by_prob(text, prob):
    """Wrap ``text`` in an ANSI colour chosen from its probability.

    ``prob >= 0.5`` is bright green, ``prob >= 0.2`` is bright yellow, and
    anything else is bright red. The returned string ends with a reset code.
    """
    # Thresholds are checked from highest to lowest; the first match wins.
    bands = (
        (0.5, "\033[92m"),  # bright green (high probability)
        (0.2, "\033[93m"),  # bright yellow (medium probability)
    )
    prefix = "\033[91m"     # bright red (low probability) is the fallback
    for threshold, code in bands:
        if prob >= threshold:
            prefix = code
            break

    return "{}{}{}".format(prefix, text, "\033[0m")

def get_yellow_to_green_ansi(prob):
    """Map a probability onto a 10-step foreground ramp from yellow to green.

    The ramp runs from (255, 255, 0) to (0, 255, 0): only the red channel
    changes. ``prob`` is bucketed with ``int(prob * 10)`` and the bucket is
    capped at 9.
    """
    steps = 10
    last = steps - 1

    bucket = int(prob * steps)
    if bucket > last:
        bucket = last

    red_at_yellow, red_at_green = 255, 0
    # Red falls linearly with the bucket; green and blue stay fixed.
    red = int(red_at_yellow - (red_at_yellow - red_at_green) * (bucket / last))

    return rgb_ansi(red, 255, 0)

def get_yellow_to_green_bg_ansi(prob):
    """Map a probability onto a 10-step background ramp from yellow to green.

    The ramp runs from (255, 255, 0) to (0, 255, 0): only the red channel
    changes. ``prob`` is bucketed with ``int(prob * 10)`` and the bucket is
    capped at 9, so any ``prob >= 0.9`` yields pure green.

    Returns:
        The background-colour escape sequence produced by ``bg_rgb_ansi``.
    """
    steps = 10
    index = min(int(prob * steps), steps - 1)
    r_start, g, b = 255, 255, 0
    # Red falls linearly with the bucket index; green and blue stay fixed.
    r = int(r_start - (r_start * index / (steps - 1)))
    return bg_rgb_ansi(r, g, b)

def get_yellow_to_brightblue_bg_ansi(prob):
    """Map a probability onto a 10-step background ramp from yellow to bright blue.

    The ramp starts at yellow #FFFF00 (255, 255, 0) and ends at bright blue
    #6699FF (102, 153, 255). ``prob`` is bucketed with ``int(prob * 10)`` and
    the bucket is capped at 9.
    """
    steps = 10
    bucket = min(int(prob * steps), steps - 1)

    yellow = (255, 255, 0)
    bright_blue = (102, 153, 255)

    # The same linear blend is applied to each channel, then truncated to int.
    channels = [
        int(lo + (hi - lo) * bucket / (steps - 1))
        for lo, hi in zip(yellow, bright_blue)
    ]

    return bg_rgb_ansi(*channels)

def get_yellow_to_lightblue_bg_ansi(prob):
    """Map a probability onto a 10-step background ramp from yellow to blue.

    The ramp starts at #E5E200 (229, 226, 0) and ends at #008FE6 (0, 143, 230),
    interpolated linearly per channel. ``prob`` is bucketed with
    ``int(prob * 10)`` and the bucket is capped at 9.
    """
    steps = 10
    bucket = min(int(prob * steps), steps - 1)

    def blend(start, end):
        # Linear interpolation for one channel, truncated to an int.
        return int(start + (end - start) * bucket / (steps - 1))

    red = blend(229, 0)      # #E5 -> #00
    green = blend(226, 143)  # #E2 -> #8F
    blue = blend(0, 230)     # #00 -> #E6

    return bg_rgb_ansi(red, green, blue)


# NOTE(review): bg_rgb_ansi is already defined earlier in this cell with the
# same behaviour; this second definition is redundant and can be deleted.
def bg_rgb_ansi(r, g, b):
    """Build the ``ESC[48;2;r;g;bm`` escape sequence (24-bit background colour)."""
    return "\033[48;2;{};{};{}m".format(r, g, b)

In [7]:
def print_top_tokens(top_tokens, field_width=20):
    """Print the candidate tokens on one line, each in a coloured, centred cell.

    Args:
        top_tokens: Iterable of 3-tuples; the first element is ignored, the
            second is the token text and the third its probability.
        field_width: Width each ``repr(token) (prob)`` label is centred to.

    The background colour comes from ``get_yellow_to_lightblue_bg_ansi``; the
    text colour is left at the terminal default.
    """
    def render(piece, p):
        label = "{!r} ({:.4f})".format(piece, p)
        # Centre the plain text first so escape codes do not count as width.
        centred = "{:^{width}}".format(label, width=field_width)
        return get_yellow_to_lightblue_bg_ansi(p) + centred + "\033[0m"

    cells = [render(piece, p) for _token_id, piece, p in top_tokens]

    print("== Top-5 Tokens ==", " | ".join(cells), "", sep="\n")

In [8]:
def generate_next_token(model, input_ids, top_k=5, tokenizer=None):
    """Run one forward pass and return the most probable next-token candidates.

    Args:
        model: Callable whose result has a ``logits`` attribute. The logits
            are indexed as ``[0, -1, :]``, so only the last position of the
            first sequence is used.
        input_ids: Token ids passed unchanged to ``model``.
        top_k: Number of candidates to return. Defaults to 5 and is clamped
            to the size of the last logits dimension.
        tokenizer: Object with a ``decode`` method used to turn each token id
            into text. When omitted, ``model.tokenizer`` is used.

    Returns:
        A list of ``(token_id, token_text, probability)`` tuples ordered from
        most to least probable, with ``token_id`` an ``int`` and
        ``probability`` a ``float``.
    """
    if tokenizer is None:
        # Backward-compatible fallback: the notebook attaches the tokenizer
        # to the model object after loading.
        tokenizer = model.tokenizer

    with torch.no_grad():
        logits = model(input_ids).logits

    next_token_logits = logits[0, -1, :]
    # Upcast before softmax so a half-precision model does not round the
    # probabilities to fp16.
    probs = F.softmax(next_token_logits.float(), dim=-1)

    # torch.topk raises if k exceeds the dimension size, so clamp it.
    k = min(top_k, probs.shape[-1])
    top_probs, top_ids = torch.topk(probs, k=k)

    return [
        (token_id, tokenizer.decode(token_id), prob)
        for token_id, prob in zip(top_ids.tolist(), top_probs.tolist())
    ]

def run_inference_loop(tokenizer, model, input_text, max_new_tokens=20):
    """Greedily generate up to ``max_new_tokens`` tokens, printing each step.

    At every step the candidates from ``generate_next_token`` are printed with
    ``print_top_tokens`` and the first (most probable) one is appended to the
    sequence. Generation stops early once the appended token equals
    ``tokenizer.eos_token_id``. The decoded sequence is printed at the end;
    nothing is returned.
    """
    encoded = tokenizer(input_text, return_tensors="pt")
    sequence = encoded.input_ids.to(model.device)

    print(f"Input: {input_text}\n")

    for _ in range(max_new_tokens):
        candidates = generate_next_token(model, sequence)
        print_top_tokens(candidates)

        # Greedy choice: take the id of the top-1 candidate and append it.
        best_id = candidates[0][0]
        appended = torch.tensor([[best_id]], device=model.device)
        sequence = torch.cat([sequence, appended], dim=1)

        if best_id == tokenizer.eos_token_id:
            break

    final_text = tokenizer.decode(sequence[0], skip_special_tokens=True)
    print("== Final Output ==")
    print(final_text)


In [9]:
# Run greedy generation on a sample prompt, printing the top-5 candidates at every step.
run_inference_loop(tokenizer, model, "Is the Moon Earth's satellite?")

Input: Is the Moon Earth's satellite?

== Top-5 Tokens ==
[48;2;152;198;76m  ' Yes' (0.3269)   [0m | [48;2;203;216;25m  ' What' (0.1384)  [0m | [48;2;229;226;0m   ' If' (0.0664)   [0m | [48;2;229;226;0m  ' And' (0.0449)   [0m | [48;2;229;226;0m   ' Is' (0.0432)   [0m

== Top-5 Tokens ==
[48;2;101;179;127m    ',' (0.5190)    [0m | [48;2;203;216;25m    '.' (0.1880)    [0m | [48;2;229;226;0m  '.\n\n' (0.0725)  [0m | [48;2;229;226;0m  '\n\n' (0.0630)   [0m | [48;2;229;226;0m   '\n' (0.0530)    [0m

== Top-5 Tokens ==
[48;2;76;170;153m  ' the' (0.6934)   [0m | [48;2;229;226;0m   ' it' (0.0828)   [0m | [48;2;229;226;0m  ' and' (0.0472)   [0m | [48;2;229;226;0m  ' but' (0.0345)   [0m | [48;2;229;226;0m' because' (0.0234) [0m

== Top-5 Tokens ==
[48;2;0;143;230m  ' Moon' (0.9229)  [0m | [48;2;229;226;0m ' Earth' (0.0505)  [0m | [48;2;229;226;0m  ' moon' (0.0164)  [0m | [48;2;229;226;0m  ' Sun' (0.0016)   [0m | [48;2;229;226;0m  'Moon' (0.0003)   [0m

== 