In [20]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig, Gemma3ForCausalLM

In [21]:
import math
import json

In [22]:
def load_model(
    model_id="google/gemma-3-1b-it",
    dtype=torch.float16,
    load_in_8bit=True,
):
    """Load a Gemma 3 causal LM together with its tokenizer.

    Args:
        model_id: Model identifier passed to both ``from_pretrained`` calls.
        dtype: Value forwarded as ``torch_dtype`` when loading the model.
        load_in_8bit: When true, a ``BitsAndBytesConfig(load_in_8bit=True)`` is
            passed as the quantization config; otherwise ``None`` is passed.

    Returns:
        A ``(tokenizer, model)`` tuple. The tokenizer's EOS token is set to
        ``<end_of_turn>`` and ``model.eval()`` has been called.
    """
    end_of_turn = '<end_of_turn>'

    # Built unconditionally; only handed to from_pretrained when 8-bit loading is requested.
    bnb_config = BitsAndBytesConfig(load_in_8bit=load_in_8bit)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = Gemma3ForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=dtype,
        quantization_config=bnb_config if load_in_8bit else None,
    )

    if end_of_turn in tokenizer.get_vocab():
        # The token is already in the vocabulary: just designate it as EOS.
        tokenizer.eos_token = end_of_turn
    else:
        # The token is missing: register it as EOS and resize the embeddings to match.
        tokenizer.add_special_tokens({'eos_token': end_of_turn})
        model.resize_token_embeddings(len(tokenizer))

    model.eval()
    return tokenizer, model

In [23]:
# Load the tokenizer/model pair once with load_model's default arguments.
tokenizer, model = load_model()
model.tokenizer = tokenizer  # attach the tokenizer so helpers can decode via model.tokenizer

In [24]:
def rgb_ansi(r, g, b):
    """Return the 24-bit ANSI escape sequence selecting (r, g, b) as the foreground colour."""
    foreground_template = "\033[38;2;{};{};{}m"
    return foreground_template.format(r, g, b)

def bg_rgb_ansi(r, g, b):
    """Return the 24-bit ANSI escape sequence selecting (r, g, b) as the background colour."""
    background_template = "\033[48;2;{};{};{}m"
    return background_template.format(r, g, b)

def color_by_prob(text, prob):
    """Wrap ``text`` in an ANSI foreground colour chosen from ``prob``.

    ``prob >= 0.5`` gives bright green, ``prob >= 0.2`` gives bright yellow,
    and anything else gives bright red. The returned string ends with the
    ANSI reset code.
    """
    # (lower bound, colour code) pairs, checked from the highest bound down.
    tiers = (
        (0.5, "\033[92m"),  # bright green (high prob)
        (0.2, "\033[93m"),  # bright yellow (medium prob)
    )
    bright_red = "\033[91m"  # fallback (low prob)

    chosen = next((code for floor, code in tiers if prob >= floor), bright_red)
    return "{}{}{}".format(chosen, text, "\033[0m")

def get_yellow_to_green_ansi(prob):
    """Return a foreground ANSI colour code on a yellow-to-green gradient.

    ``prob`` is bucketed into 10 steps. Bucket 0 maps to yellow (255, 255, 0)
    and the last bucket to green (0, 255, 0); only the red channel varies.
    """
    buckets = 10
    bucket = min(int(prob * buckets), buckets - 1)

    red_at_yellow, red_at_green = 255, 0

    # Position along the gradient, then a linear fade of the red channel.
    fraction = bucket / (buckets - 1)
    red = int(red_at_yellow - (red_at_yellow - red_at_green) * fraction)

    return rgb_ansi(red, 255, 0)

def get_yellow_to_green_bg_ansi(prob):
    """Return a background ANSI colour code on a yellow-to-green gradient.

    ``prob`` is bucketed into 10 steps. Bucket 0 maps to yellow (255, 255, 0)
    and the last bucket to green (0, 255, 0); only the red channel varies.

    This function was previously defined twice in a row with identical
    bodies; a single definition is kept.
    """
    steps = 10
    index = min(int(prob * steps), steps - 1)
    r_start, g, b = 255, 255, 0
    # Linear fade of the red channel from 255 down to 0 across the buckets.
    r = int(r_start - (r_start * index / (steps - 1)))
    return bg_rgb_ansi(r, g, b)

def get_yellow_to_brightblue_bg_ansi(prob):
    """Return a background ANSI colour code on a yellow-to-bright-blue gradient.

    ``prob`` is bucketed into 10 steps. Bucket 0 maps to yellow
    #FFFF00 (255, 255, 0) and the last bucket to bright blue
    #6699FF (102, 153, 255), with each channel interpolated linearly.
    """
    buckets = 10
    bucket = min(int(prob * buckets), buckets - 1)

    yellow = (255, 255, 0)
    bright_blue = (102, 153, 255)

    # Interpolate every channel with the same formula, truncating to an int.
    red, green, blue = (
        int(lo + (hi - lo) * bucket / (buckets - 1))
        for lo, hi in zip(yellow, bright_blue)
    )

    return bg_rgb_ansi(red, green, blue)

def get_yellow_to_lightblue_bg_ansi(prob):
    """Return a background ANSI colour code on a yellow-to-blue gradient.

    ``prob`` is bucketed into 10 steps. Bucket 0 maps to #E5E200
    (229, 226, 0) and the last bucket to #008FE6 (0, 143, 230), with each
    channel interpolated linearly.
    """
    total_steps = 10
    step = min(int(prob * total_steps), total_steps - 1)

    def channel(start, end):
        # Linear interpolation for one colour channel, truncated to an int.
        return int(start + (end - start) * step / (total_steps - 1))

    return bg_rgb_ansi(
        channel(229, 0),    # red:   0xE5 -> 0x00
        channel(226, 143),  # green: 0xE2 -> 0x8F
        channel(0, 230),    # blue:  0x00 -> 0xE6
    )


def bg_rgb_ansi(r, g, b):
    """Return the 24-bit ANSI background colour escape for (r, g, b).

    Re-declares bg_rgb_ansi, which is already defined near the top of this
    cell with the same output; this later definition is the one left bound
    once the cell has run.
    """
    return "\033[48;2;{r};{g};{b}m".format(r=r, g=g, b=b)

In [25]:
def print_top_tokens(top_tokens, field_width=22):
    """Print the candidate tokens on one line, each cell centred and shaded by probability.

    Args:
        top_tokens: Iterable of ``(token_id, token_str, prob)`` tuples; the id
            is ignored here.
        field_width: Width in which each ``repr(token) (prob)`` label is centred.

    The background colour of each cell comes from
    get_yellow_to_lightblue_bg_ansi(prob); the foreground colour is not changed.
    """
    centre_spec = "^{}".format(field_width)

    def render(token_str, prob):
        label = "{!r} ({:.4f})".format(token_str, prob)
        # Centre the plain text first so the escape codes do not count toward the width.
        centred = format(label, centre_spec)
        return get_yellow_to_lightblue_bg_ansi(prob) + centred + "\033[0m"

    cells = [render(token_str, prob) for _, token_str, prob in top_tokens]

    print("== Top-5 Tokens ==")
    print(" | ".join(cells))
    print()

In [26]:
def generate_next_token(model, input_ids):
    """Return the five most probable next tokens for ``input_ids``.

    Runs one forward pass without gradients, applies softmax to the logits at
    the last position of the first sequence, and takes the top 5 entries.

    Args:
        model: Callable returning an object with ``.logits``; it must also
            carry a ``tokenizer`` attribute whose ``decode`` is used here.
        input_ids: Token id tensor passed straight to ``model``.

    Returns:
        A list of ``(token_id, token_str, prob)`` tuples in descending
        probability order.
    """
    with torch.no_grad():
        last_position_logits = model(input_ids).logits[0, -1, :]

    distribution = F.softmax(last_position_logits, dim=-1)
    best_probs, best_ids = torch.topk(distribution, k=5)

    candidates = []
    for token_id, token_prob in zip(best_ids, best_probs):
        decoded = model.tokenizer.decode(token_id)
        candidates.append((token_id.item(), decoded, token_prob.item()))

    return candidates

def run_inference_loop(tokenizer, model, input_text, max_new_tokens=20):
    """Greedy-decode from ``input_text``, printing the top candidates at every step.

    Each iteration asks generate_next_token for the candidates, prints them
    with print_top_tokens, and appends the most probable one to the sequence.
    Generation stops after ``max_new_tokens`` tokens or as soon as the chosen
    token equals ``tokenizer.eos_token_id``. The decoded sequence (special
    tokens skipped) is printed at the end; nothing is returned.
    """
    encoded = tokenizer(input_text, return_tensors="pt")
    generated = encoded.input_ids.to(model.device)

    print(f"Input: {input_text}\n")

    for _ in range(max_new_tokens):
        candidates = generate_next_token(model, generated)
        print_top_tokens(candidates)

        # Greedy choice: candidates are sorted by probability, so take the first id.
        best_id = candidates[0][0]
        best_column = torch.tensor([[best_id]], device=model.device)
        generated = torch.cat([generated, best_column], dim=1)

        if best_id == tokenizer.eos_token_id:
            break

    output_text = tokenizer.decode(generated[0], skip_special_tokens=True)
    print("== Final Output ==")
    print(output_text)


In [27]:
def compute_perplexity_from_logits(tokenizer, model, input_text):
    """Compute the perplexity of ``input_text`` directly from the model's logits.

    The logits at position ``t`` are scored against the token at ``t + 1``;
    the mean negative log-likelihood over those positions is exponentiated.
    Both the NLL and the perplexity are printed.

    Returns:
        The perplexity as a Python float.
    """
    model.eval()
    with torch.no_grad():
        encoded = tokenizer(input_text, return_tensors="pt")
        token_ids = encoded.input_ids.to(model.device)

        all_logits = model(token_ids).logits  # [batch_size, seq_len, vocab_size]

        # Drop the last prediction and the first token so predictions line up with targets.
        targets = token_ids[:, 1:].unsqueeze(-1)                      # [batch, seq-1, 1]
        predicted = F.log_softmax(all_logits[:, :-1, :], dim=-1)     # [batch, seq-1, vocab]

        # Pick out the log-probability assigned to each actual next token.
        target_log_probs = predicted.gather(dim=-1, index=targets).squeeze(-1)  # [batch, seq-1]

        nll = -target_log_probs.mean()
        perplexity = torch.exp(nll)

    print(f"NLL: {nll.item():.4f}")
    print(f"Perplexity: {perplexity.item():.4f}")
    return perplexity.item()

In [28]:
# run_inference_loop(tokenizer, model, "What is 2 + 2? Answer briefly.")

In [29]:
def generate_with_self_confidence(tokenizer, model, input_text, max_new_tokens=20, field_width=20):
    """Greedy-decode from ``input_text`` and report the model's confidence in its own output.

    At every step the arg-max token is chosen, its log-probability is
    recorded, and the running average negative log-likelihood (AvgNLL) over
    the tokens generated so far is updated. ``PPL_Prob`` is ``exp(-AvgNLL)``,
    i.e. the geometric mean of the chosen-token probabilities (the inverse of
    perplexity); it selects the background shade of each printed token.

    Args:
        tokenizer: Callable tokenizer that also provides ``decode`` and
            ``eos_token_id``.
        model: Callable returning an object with ``.logits``; must expose
            ``device``.
        input_text: Prompt to continue.
        max_new_tokens: Upper bound on generated tokens; generation stops
            early when the chosen token equals ``tokenizer.eos_token_id``.
        field_width: Width in which the ``repr`` of each token is centred.

    Returns:
        A list with one dict per generated token, with keys ``step``,
        ``token_id``, ``token_str``, ``log_prob``, ``prob``, ``avg_nll``,
        ``ppl_prob`` and ``colored_token``.
    """
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
    generated = input_ids
    log_probs = []
    history = []

    # Defined up front so the summary below still works when no token is generated
    # (e.g. max_new_tokens <= 0).
    avg_nll = float("nan")
    avg_prob = float("nan")

    print(f"Input: {input_text}\n")

    for step in range(max_new_tokens):
        # Inference only: skip autograd bookkeeping, as the other helpers in this notebook do.
        with torch.no_grad():
            outputs = model(generated)
            logits = outputs.logits[:, -1, :]  # logits at the last position

            log_softmax = F.log_softmax(logits, dim=-1)
            next_token_id = torch.argmax(log_softmax, dim=-1)
            next_log_prob = log_softmax[0, next_token_id]
        prob = next_log_prob.exp().item()

        log_probs.append(next_log_prob.item())
        avg_nll = -sum(log_probs) / len(log_probs)
        avg_prob = math.exp(-avg_nll)

        token_str = repr(tokenizer.decode([next_token_id.item()]))

        # Shade the cell by the running confidence; pad first so escape codes
        # do not count toward the width.
        bg_color = get_yellow_to_lightblue_bg_ansi(avg_prob)
        padded = f"{token_str:^{field_width}}"
        colored_token = f"{bg_color}{padded}\033[0m"

        print(
            f"[{step:02d}] Token: {colored_token} | "
            f"LogProb: {next_log_prob.item():>8.4f}  Prob: {prob:>7.4f}  | "
            f"AvgNLL: {avg_nll:.4f}  PPL_Prob: {avg_prob:.4f}"
        )

        history.append({
            "step": step,
            "token_id": next_token_id.item(),
            "token_str": token_str,
            "log_prob": next_log_prob.item(),
            "prob": prob,
            "avg_nll": avg_nll,
            "ppl_prob": avg_prob,
            "colored_token": colored_token,  # keep the coloured cell so it can be re-printed later
        })

        generated = torch.cat([generated, next_token_id.unsqueeze(0)], dim=1)
        if next_token_id.item() == tokenizer.eos_token_id:
            break

    print("\n== Self Perplexity ==")
    print(f"Avg NLL: {avg_nll:.4f}")
    # Perplexity is exp(+AvgNLL); avg_prob (exp(-AvgNLL)) is its inverse.
    print(f"Perplexity (self): {math.exp(avg_nll):.4f}")

    return history


def print_generation_history(history, field_width=20):
    """Re-print a per-token history in the same row format used during generation.

    Args:
        history: Iterable of dicts with keys ``step``, ``token_str``,
            ``log_prob``, ``prob``, ``avg_nll`` and ``ppl_prob``; an optional
            ``colored_token`` holds a pre-coloured cell.
        field_width: Width used to centre the token when the cell has to be
            re-coloured because ``colored_token`` is missing or empty.
    """
    row_template = (
        "[{step:02d}] Token: {token} | "
        "LogProb: {log_prob:8.4f}  Prob: {prob:7.4f}  | "
        "AvgNLL: {avg_nll:.4f}  PPL_Prob: {ppl_prob:.4f}"
    )

    print("\n=== Generation History ===")
    for record in history:
        # Prefer the coloured cell stored at generation time.
        token_cell = record.get("colored_token")
        if not token_cell:
            # Fallback: rebuild the cell from the stored confidence value.
            shade = get_yellow_to_lightblue_bg_ansi(record['ppl_prob'])
            centred = format(record['token_str'], "^{}".format(field_width))
            token_cell = f"{shade}{centred}\033[0m"

        print(row_template.format(
            step=record['step'],
            token=token_cell,
            log_prob=record['log_prob'],
            prob=record['prob'],
            avg_nll=record['avg_nll'],
            ppl_prob=record['ppl_prob'],
        ))

In [30]:
# generate_with_self_confidence(tokenizer, model, 
#                               "What is 4 + 4? Answer briefly.", 
#                               max_new_tokens=50,
#                               field_width=15)

In [31]:
# history = generate_with_self_confidence(tokenizer, model, 
#                                         "What is 2 + 2?", 
#                                         max_new_tokens=20,
#                                         field_width=15)
# print_generation_history(history)

In [32]:
# Location of the training-challenges JSON; a relative path, so it resolves against the working directory.
json_file = "../arc-prize-2025/arc-agi_training_challenges.json"

In [33]:
# Parse the challenges file. The encoding is given explicitly so the read does
# not depend on the platform's default text encoding.
with open(json_file, "r", encoding="utf-8") as f:
    json_data = json.load(f)

In [34]:
# Peek at the first entry only (the loop exits after one iteration) to inspect the data layout.
for key in json_data:
    print(key, json_data[key])
    break

00576224 {'train': [{'input': [[7, 9], [4, 3]], 'output': [[7, 9, 7, 9, 7, 9], [4, 3, 4, 3, 4, 3], [9, 7, 9, 7, 9, 7], [3, 4, 3, 4, 3, 4], [7, 9, 7, 9, 7, 9], [4, 3, 4, 3, 4, 3]]}, {'input': [[8, 6], [6, 4]], 'output': [[8, 6, 8, 6, 8, 6], [6, 4, 6, 4, 6, 4], [6, 8, 6, 8, 6, 8], [4, 6, 4, 6, 4, 6], [8, 6, 8, 6, 8, 6], [6, 4, 6, 4, 6, 4]]}], 'test': [{'input': [[3, 2], [7, 8]]}]}


In [35]:
def generate_prompt_from_example(key: str, example: dict) -> str:
    """Render one task as a text prompt.

    The prompt starts with a fixed instruction header, followed by one
    "### Example N" section per entry in ``example["train"]`` and one
    "### Test Input N" section per entry in ``example["test"]``. Missing
    ``train``/``test`` keys are treated as empty lists.

    Args:
        key: Task identifier; accepted but not used in the prompt text.
        example: Dict with optional ``train`` (items holding ``input`` and
            ``output``) and ``test`` (items holding ``input``) lists.

    Returns:
        The assembled prompt string.
    """
    sections = [
        "The following examples show 2D input arrays and their corresponding output arrays. Each output is generated by applying a specific pattern or rule to the input.\n\n",
        "- Please study the relationship between each `input` and its corresponding `output` in the `train` examples.\n",
        "- Then, based on these patterns, predict the `output` for the given `test` input.\n\n",
    ]

    # Worked examples, numbered from 1.
    sections.extend(
        f"### Example {number}\nInput:\n{pair['input']}\n\nOutput:\n{pair['output']}\n\n"
        for number, pair in enumerate(example.get("train", []), 1)
    )

    # Test inputs, each followed by the task statement.
    sections.extend(
        f"---\n\n### Test Input {number}:\n{case['input']}\n\n**Task:**\nBased on the `train` examples above, generate the expected `output` array for this test input.\n"
        for number, case in enumerate(example.get("test", []), 1)
    )

    return "".join(sections)

In [36]:
# Build one (key, prompt) pair per task in json_data.
prompt_list = []

for i, (key, example) in enumerate(json_data.items()):
    prompt = generate_prompt_from_example(key, example)
    prompt_list.append((key, prompt))

# Sort in ascending order of prompt length (in characters).
prompt_list.sort(key=lambda x: len(x[1]))

# Print the ten shortest prompts.
for key, prompt in prompt_list[:10]:
    print(f"\n=== Prompt for key {key} (length={len(prompt)}) ===\n")
    print(prompt)


=== Prompt for key 833966f4 (length=675) ===

The following examples show 2D input arrays and their corresponding output arrays. Each output is generated by applying a specific pattern or rule to the input.

- Please study the relationship between each `input` and its corresponding `output` in the `train` examples.
- Then, based on these patterns, predict the `output` for the given `test` input.

### Example 1
Input:
[[9], [0], [1], [6], [8]]

Output:
[[0], [9], [1], [8], [6]]

### Example 2
Input:
[[4], [3], [6], [2], [8]]

Output:
[[3], [4], [6], [8], [2]]

---

### Test Input 1:
[[4], [5], [6], [7], [2]]

**Task:**
Based on the `train` examples above, generate the expected `output` array for this test input.


=== Prompt for key 6150a2bd (length=715) ===

The following examples show 2D input arrays and their corresponding output arrays. Each output is generated by applying a specific pattern or rule to the input.

- Please study the relationship between each `input` and its corresp

In [37]:
# Per-task generation histories for the ten shortest prompts, keyed by task key.
results = {}

for i, (key, prompt) in enumerate(prompt_list[:10]):
    print(f"\n--- Prompt {i+1} (key: {key}, length: {len(prompt)}) ---\n")
    
    # Model call goes here (generate_with_self_confidence returns the per-token history).
    history = generate_with_self_confidence(
        tokenizer=tokenizer,
        model=model,
        input_text=prompt,
        max_new_tokens=100,
        field_width=15,
    )
    
    # Store the result.
    results[key] = history


--- Prompt 1 (key: 833966f4, length: 675) ---

Input: The following examples show 2D input arrays and their corresponding output arrays. Each output is generated by applying a specific pattern or rule to the input.

- Please study the relationship between each `input` and its corresponding `output` in the `train` examples.
- Then, based on these patterns, predict the `output` for the given `test` input.

### Example 1
Input:
[[9], [0], [1], [6], [8]]

Output:
[[0], [9], [1], [8], [6]]

### Example 2
Input:
[[4], [3], [6], [2], [8]]

Output:
[[3], [4], [6], [8], [2]]

---

### Test Input 1:
[[4], [5], [6], [7], [2]]

**Task:**
Based on the `train` examples above, generate the expected `output` array for this test input.


[00] Token: [48;2;25;152;204m   'Output'    [0m | LogProb:  -0.1804  Prob:  0.8350  | AvgNLL: 0.1804  PPL_Prob: 0.8349
[01] Token: [48;2;0;143;230m      ':'      [0m | LogProb:  -0.0014  Prob:  0.9985  | AvgNLL: 0.0909  PPL_Prob: 0.9131
[02] Token: [48;2;0;143;23

KeyboardInterrupt: 

In [None]:
# Prompts grouped by difficulty label; commented-out entries are disabled questions.
questions = {
    "Easy": [
        "Is the Moon Earth's satellite?",
        "What is 3 + 4?",
        "Is water H2O?",
        "Is the Sun a star?",
        "Are cats mammals?"
    ],
    "Tricky": [
        "If you say “I am lying,” are you telling the truth?",
        # "How many months have 28 days?",
        "If you are your brother's sibling, who are you?",
        # "If a clock is one hour fast after one hour, what time is it now?",
        "In a dark room with a candle, a lantern, and a match, which do you light first?"
    ]
}

# NOTE: rebinding `results` replaces the per-task dict built by the earlier prompt loop.
results = {}

for difficulty, qs in questions.items():
    print(f"\n--- {difficulty} Questions ---\n")
    results[difficulty] = []
    for q in qs:
        history = generate_with_self_confidence(
            tokenizer,
            model,
            q,
            max_new_tokens=100,
            field_width=15,
        )
        # Keep each question with its per-token history; without this the
        # lists created above stay empty and `history` is discarded.
        results[difficulty].append((q, history))

In [None]:
# Greedy-decode one question, printing the top candidate tokens at each step (up to 50 new tokens).
run_inference_loop(tokenizer, model, 
                   "Is the Moon Earth's satellite?", 
                   max_new_tokens=50
                  )