In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

In [16]:
import math

In [2]:
def load_model(model_id="Qwen/Qwen3-0.6B", dtype=torch.float16):
    """Load a tokenizer and a causal language model via ``from_pretrained``.

    Args:
        model_id: Identifier passed to both ``from_pretrained`` calls.
        dtype: Value forwarded as ``torch_dtype`` when loading the model.

    Returns:
        A ``(tokenizer, model)`` tuple; ``model.eval()`` has already been called.
    """
    tok = AutoTokenizer.from_pretrained(model_id)

    # Keyword arguments for the model load are gathered in one place.
    load_kwargs = {"device_map": "auto", "torch_dtype": dtype}
    lm = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

    lm.eval()
    return tok, lm

In [3]:
tokenizer, model = load_model()
model.tokenizer = tokenizer  # attach the tokenizer to the model so it can be used for decoding (generate_next_token reads model.tokenizer)

In [4]:
def rgb_ansi(r, g, b):
    """Build the ANSI escape sequence that selects the 24-bit foreground color (r, g, b)."""
    escape_code = f"\033[38;2;{r};{g};{b}m"
    return escape_code

def bg_rgb_ansi(r, g, b):
    """Build the ANSI escape sequence that selects the 24-bit background color (r, g, b)."""
    escape_code = f"\033[48;2;{r};{g};{b}m"  # 48;2 = background color
    return escape_code

def color_by_prob(text, prob):
    """
    Return ``text`` wrapped in an ANSI foreground color chosen from ``prob``.

    Bright green when prob >= 0.5, bright yellow when prob >= 0.2, bright red
    otherwise. A reset code is appended after the text.
    """
    # (lower bound, color code), checked from the highest bound down.
    thresholds = (
        (0.5, "\033[92m"),  # bright green (high prob)
        (0.2, "\033[93m"),  # bright yellow (medium prob)
    )

    color = "\033[91m"  # bright red (low prob) unless a threshold matches
    for lower_bound, code in thresholds:
        if prob >= lower_bound:
            color = code
            break

    reset = "\033[0m"
    return f"{color}{text}{reset}"

def get_yellow_to_green_ansi(prob):
    """
    Map ``prob`` onto a 10-step foreground ramp from yellow (255,255,0) to green (0,255,0).
    """
    steps = 10

    # Bucket number for this probability, capped at the last bucket.
    bucket = int(prob * steps)
    if bucket > steps - 1:
        bucket = steps - 1

    red_from, red_to = 255, 0  # only the red channel changes along the ramp
    fraction = bucket / (steps - 1)
    red = int(red_from - (red_from - red_to) * fraction)

    return rgb_ansi(red, 255, 0)

def get_yellow_to_green_bg_ansi(prob):
    """
    Map ``prob`` onto a 10-step background ramp from yellow (255,255,0) to green (0,255,0).

    This function used to be defined twice in a row with identical bodies; the
    second definition merely shadowed the first, so a single copy is kept.

    Args:
        prob: Value scaled by 10 and truncated to pick a step; values that
            reach the last step or beyond are capped there.

    Returns:
        The escape sequence produced by ``bg_rgb_ansi`` for the chosen color.
    """
    steps = 10
    index = min(int(prob * steps), steps - 1)
    r_start, g, b = 255, 255, 0
    # Only red changes: it falls linearly from 255 at step 0 to 0 at the last step.
    r = int(r_start - (r_start * index / (steps - 1)))
    return bg_rgb_ansi(r, g, b)

def get_yellow_to_brightblue_bg_ansi(prob):
    """
    Background color for ``prob`` on a 10-step ramp: 0.0 (yellow) -> 1.0 (bright blue).

    Yellow:      #FFFF00 (R=255, G=255, B=0)
    Bright blue: #6699FF (R=102, G=153, B=255)
    """
    steps = 10
    index = min(int(prob * steps), steps - 1)

    yellow = (255, 255, 0)
    bright_blue = (102, 153, 255)

    # Interpolate each channel with the same step index, truncating to int.
    channels = [
        int(start + (end - start) * index / (steps - 1))
        for start, end in zip(yellow, bright_blue)
    ]

    return bg_rgb_ansi(*channels)

def get_yellow_to_lightblue_bg_ansi(prob):
    """
    Background color for ``prob`` on a 10-step linear ramp from yellow to blue.

    prob 0.0 maps to #E5E200 (yellow) and prob 1.0 maps to #008FE6 (blue).

    Args:
        prob: Value scaled by 10 and truncated to pick a step. The step index
            is clamped to the range 0..9, so inputs below 0 use the first step
            and inputs at or above 1 use the last step.

    Returns:
        The escape sequence produced by ``bg_rgb_ansi`` (defined earlier in
        this cell; the identical redefinition that used to follow this
        function was redundant and has been dropped).
    """
    steps = 10
    # Clamp on both sides: without the lower bound a negative input yields
    # channel values outside 0..255.
    index = min(max(int(prob * steps), 0), steps - 1)

    r_start, g_start, b_start = 229, 226, 0     # #E5E200 yellow
    r_end, g_end, b_end = 0, 143, 230           # #008FE6 blue

    r = int(r_start + (r_end - r_start) * index / (steps - 1))
    g = int(g_start + (g_end - g_start) * index / (steps - 1))
    b = int(b_start + (b_end - b_start) * index / (steps - 1))

    return bg_rgb_ansi(r, g, b)

In [5]:
def print_top_tokens(top_tokens, field_width=20):
    """Print candidate tokens on one line, each centered on a probability-colored background.

    The background comes from ``get_yellow_to_lightblue_bg_ansi(prob)``; the
    text color is left at the terminal default.

    Args:
        top_tokens: Iterable of ``(token_id, token_str, prob)`` tuples. The id
            is not used for display.
        field_width: Width in which each ``'token' (prob)`` label is centered.
    """
    top_tokens = list(top_tokens)  # materialize so the count can go in the header

    cells = []
    for _, token_str, prob in top_tokens:
        label = f"{repr(token_str)} ({prob:.4f})"
        # Center the plain text first; the escape codes are added afterwards
        # so they do not count toward the field width.
        padded = f"{label:^{field_width}}"
        bg_color = get_yellow_to_lightblue_bg_ansi(prob)
        cells.append(f"{bg_color}{padded}\033[0m")

    # Report the actual number of candidates instead of a hard-coded "5".
    print(f"== Top-{len(top_tokens)} Tokens ==")
    print(" | ".join(cells))
    print()

In [6]:
def generate_next_token(model, input_ids, tokenizer=None, k=5):
    """Return the ``k`` most probable next tokens for the first sequence in ``input_ids``.

    Args:
        model: Callable returning an object with a ``logits`` attribute that is
            indexed as ``logits[0, -1, :]``.
        input_ids: Token ids passed straight to ``model``.
        tokenizer: Object with ``decode``. When omitted, ``model.tokenizer`` is
            used, which keeps existing two-argument calls working.
        k: Number of candidates to return; capped at the vocabulary size.

    Returns:
        A list of ``(token_id, token_str, prob)`` tuples ordered from most to
        least probable, with ``token_id`` an ``int`` and ``prob`` a ``float``.
    """
    if tokenizer is None:
        tokenizer = model.tokenizer

    with torch.no_grad():
        logits = model(input_ids).logits

    # Upcast before softmax so half-precision logits do not lose precision.
    next_token_logits = logits[0, -1, :].float()
    probs = F.softmax(next_token_logits, dim=-1)

    # topk raises if k exceeds the number of entries, so cap it.
    k = min(k, probs.size(-1))
    top_probs, top_ids = torch.topk(probs, k=k)

    return [
        (token_id, tokenizer.decode([token_id]), prob)
        for token_id, prob in zip(top_ids.tolist(), top_probs.tolist())
    ]

def run_inference_loop(tokenizer, model, input_text, max_new_tokens=20):
    """Greedily extend ``input_text`` token by token, printing the candidates at every step.

    Stops after ``max_new_tokens`` steps or as soon as the chosen token equals
    ``tokenizer.eos_token_id``, then prints the decoded sequence.
    """
    encoded = tokenizer(input_text, return_tensors="pt")
    sequence = encoded.input_ids.to(model.device)

    print(f"Input: {input_text}\n")

    for _ in range(max_new_tokens):
        candidates = generate_next_token(model, sequence)
        print_top_tokens(candidates)

        # Greedy choice: take the id of the first (top-1) candidate.
        best_id = candidates[0][0]
        appended = torch.tensor([[best_id]], device=model.device)
        sequence = torch.cat([sequence, appended], dim=1)

        if appended.item() == tokenizer.eos_token_id:
            break

    decoded = tokenizer.decode(sequence[0], skip_special_tokens=True)
    print("== Final Output ==")
    print(decoded)


In [8]:
def compute_perplexity_from_logits(tokenizer, model, input_text):
    """Compute the perplexity of ``input_text`` directly from the model's logits.

    Each position's logits are scored against the token that follows it, the
    negative log-probabilities are averaged, and the result is exponentiated.
    The computation runs in float32 regardless of the dtype of the logits.
    The NLL and perplexity are also printed.

    Args:
        tokenizer: Callable as ``tokenizer(text, return_tensors="pt")`` whose
            result exposes ``input_ids``.
        model: Callable returning an object with ``logits``; must also provide
            ``eval()`` and ``device``.
        input_text: Text to score.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: If the text tokenizes to fewer than two tokens, in which
            case there is no next-token prediction to score.
    """
    model.eval()
    with torch.no_grad():
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)

        if input_ids.size(1) < 2:
            # Previously this case averaged an empty tensor and returned NaN.
            raise ValueError(
                "input_text must tokenize to at least 2 tokens to compute perplexity"
            )

        logits = model(input_ids).logits  # [batch_size, seq_len, vocab_size]

        # Labels are the input ids shifted left by one: position i predicts token i + 1.
        # Upcast so log_softmax / mean / exp are not done in half precision.
        shift_logits = logits[:, :-1, :].float()  # [batch, seq-1, vocab]
        shift_labels = input_ids[:, 1:]           # [batch, seq-1]

        log_probs = F.log_softmax(shift_logits, dim=-1)

        # Pick out the log-probability of each actual next token.
        token_log_probs = log_probs.gather(
            dim=-1, index=shift_labels.unsqueeze(-1)
        ).squeeze(-1)  # [batch, seq-1]

        # Mean negative log-likelihood, then exponentiate.
        nll = -token_log_probs.mean()
        perplexity = torch.exp(nll)

    print(f"NLL: {nll.item():.4f}")
    print(f"Perplexity: {perplexity.item():.4f}")
    return perplexity.item()

In [44]:
# run_inference_loop(tokenizer, model, "What is 2 + 2? Answer briefly.")

In [41]:
def generate_with_self_confidence(tokenizer, model, input_text, max_new_tokens=20, field_width=20):
    """Greedy-decode while tracking how confident the model is in its own output.

    At every step the most probable next token is chosen and one line is
    printed with its log-probability, its probability, the running average
    negative log-likelihood (AvgNLL) and ``PPL_Prob = exp(-AvgNLL)``. The
    token is shown on a background colored by ``PPL_Prob``.

    Generation stops after ``max_new_tokens`` steps or when the chosen token
    equals ``tokenizer.eos_token_id``. Only the first sequence in the batch is
    scored.

    Args:
        tokenizer: Callable as ``tokenizer(text, return_tensors="pt")`` and
            providing ``decode`` and ``eos_token_id``.
        model: Callable returning an object with ``logits``; must provide ``device``.
        input_text: Prompt text.
        max_new_tokens: Upper bound on the number of generated tokens.
        field_width: Width in which the repr of each token is centered.

    Returns:
        ``exp(-AvgNLL)`` over the generated tokens, i.e. the reciprocal of the
        self-perplexity. ``nan`` when no token was generated.
    """
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
    generated = input_ids

    total_log_prob = 0.0
    # Defined up front so the summary still works when the loop never runs.
    avg_nll = float("nan")
    avg_prob = float("nan")

    print(f"Input: {input_text}\n")

    # Inference only: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        for step in range(max_new_tokens):
            # Logits at the last position, upcast so half-precision models
            # do not lose precision in log_softmax.
            logits = model(generated).logits[:, -1, :].float()
            log_softmax = F.log_softmax(logits, dim=-1)

            token_id = int(torch.argmax(log_softmax[0]))
            next_log_prob = log_softmax[0, token_id].item()
            prob = math.exp(next_log_prob)

            total_log_prob += next_log_prob
            avg_nll = -total_log_prob / (step + 1)
            avg_prob = math.exp(-avg_nll)

            # repr() keeps whitespace and escape characters visible.
            token_str = repr(tokenizer.decode([token_id]))

            # Background color follows the running PPL_Prob (avg_prob).
            bg_color = get_yellow_to_lightblue_bg_ansi(avg_prob)

            # Center the plain text first, then wrap it in the color codes.
            padded = f"{token_str:^{field_width}}"
            colored_token = f"{bg_color}{padded}\033[0m"

            print(
                f"[{step:02d}] Token: {colored_token} | "
                f"LogProb: {next_log_prob:>8.4f}  Prob: {prob:>7.4f}  | "
                f"AvgNLL: {avg_nll:.4f}  PPL_Prob: {avg_prob:.4f}"
            )

            next_token = torch.tensor([[token_id]], device=generated.device)
            generated = torch.cat([generated, next_token], dim=1)
            if token_id == tokenizer.eos_token_id:
                break

    print("\n== Self Perplexity ==")
    print(f"Avg NLL: {avg_nll:.4f}")
    # Perplexity is exp(+AvgNLL); the old summary printed exp(-AvgNLL) under this label.
    print(f"Perplexity (self): {math.exp(avg_nll):.4f}")
    print(f"PPL_Prob (1 / perplexity): {avg_prob:.4f}")

    return avg_prob

In [43]:
# Greedy generation with per-token confidence for a simple arithmetic prompt.
generate_with_self_confidence(
    tokenizer,
    model,
    "What is 2 + 2? Answer briefly.",
    max_new_tokens=20,
    field_width=14,
)

Input: What is 2 + 2? Answer briefly.

[00] Token: [48;2;178;207;51m     ' '      [0m | LogProb:  -1.5977  Prob:  0.2024  | AvgNLL: 1.5977  PPL_Prob: 0.2024
[01] Token: [48;2;152;198;76m     '2'      [0m | LogProb:  -0.3132  Prob:  0.7310  | AvgNLL: 0.9554  PPL_Prob: 0.3846
[02] Token: [48;2;127;189;102m     ' +'     [0m | LogProb:  -0.2515  Prob:  0.7778  | AvgNLL: 0.7208  PPL_Prob: 0.4864
[03] Token: [48;2;101;179;127m     ' '      [0m | LogProb:  -0.0073  Prob:  0.9927  | AvgNLL: 0.5424  PPL_Prob: 0.5814
[04] Token: [48;2;76;170;153m     '2'      [0m | LogProb:  -0.0457  Prob:  0.9551  | AvgNLL: 0.4431  PPL_Prob: 0.6421
[05] Token: [48;2;101;179;127m    ' is'     [0m | LogProb:  -0.9570  Prob:  0.3840  | AvgNLL: 0.5287  PPL_Prob: 0.5894
[06] Token: [48;2;76;170;153m     ' '      [0m | LogProb:  -0.1348  Prob:  0.8740  | AvgNLL: 0.4724  PPL_Prob: 0.6235
[07] Token: [48;2;76;170;153m     '4'      [0m | LogProb:  -0.0252  Prob:  0.9751  | AvgNLL: 0.4165  PPL_Prob: 0.659

0.5937385863244323

In [45]:
# Same experiment with different operands, for comparison with the 2 + 2 run.
generate_with_self_confidence(
    tokenizer,
    model,
    "What is 4 + 4? Answer briefly.",
    max_new_tokens=20,
    field_width=14,
)

Input: What is 4 + 4? Answer briefly.

[00] Token: [48;2;178;207;51m     ' '      [0m | LogProb:  -1.3672  Prob:  0.2549  | AvgNLL: 1.3672  PPL_Prob: 0.2548
[01] Token: [48;2;127;189;102m     '4'      [0m | LogProb:  -0.3164  Prob:  0.7290  | AvgNLL: 0.8418  PPL_Prob: 0.4309
[02] Token: [48;2;101;179;127m     ' +'     [0m | LogProb:  -0.2198  Prob:  0.8027  | AvgNLL: 0.6345  PPL_Prob: 0.5302
[03] Token: [48;2;76;170;153m     ' '      [0m | LogProb:  -0.0069  Prob:  0.9932  | AvgNLL: 0.4776  PPL_Prob: 0.6203
[04] Token: [48;2;76;170;153m     '4'      [0m | LogProb:  -0.0273  Prob:  0.9731  | AvgNLL: 0.3875  PPL_Prob: 0.6787
[05] Token: [48;2;76;170;153m    ' is'     [0m | LogProb:  -0.6938  Prob:  0.4998  | AvgNLL: 0.4386  PPL_Prob: 0.6449
[06] Token: [48;2;76;170;153m     ' '      [0m | LogProb:  -0.1329  Prob:  0.8755  | AvgNLL: 0.3949  PPL_Prob: 0.6737
[07] Token: [48;2;50;161;178m     '8'      [0m | LogProb:  -0.0122  Prob:  0.9878  | AvgNLL: 0.3471  PPL_Prob: 0.7067

0.6022058247076545