In [None]:
import torch
import sys
import os

current_dir = os.getcwd()
if current_dir not in sys.path:
    sys.path.append(current_dir)
print(f"Project directory added to PATH: {current_dir}")

# 캐시초기화
if 'processing_paligemma' in sys.modules:
    del sys.modules['processing_paligemma']
if 'modeling_gemma' in sys.modules:
    del sys.modules['modeling_gemma']
if 'utils' in sys.modules:
    del sys.modules['utils']
print("Module cache cleared for forced reload.")


from processing_paligemma import PaliGemmaProcessor
from modeling_gemma import KVCache, PaliGemmaForConditionalGeneration
from utils import load_hf_model_mps_optimized 

print("All custom modules loaded successfully.")

MODEL_PATH = "paligemma-weights/paligemma-3b-pt-224" 

Project directory added to PATH: /Users/daehyeon/Documents/Project/Paligemma
Module cache cleared for forced reload.
All custom modules loaded successfully.


In [None]:
MODEL_PATH = "paligemma-weights/paligemma-3b-pt-224" 
PROMPT = "What objects are present? Where are they located? What are they doing? What is their appearance? What is the setting or background?"
IMAGE_FILE_PATH = "test_images/sample.jpg" #이미지 경로
MAX_TOKENS_TO_GENERATE = 100
TEMPERATURE = 0.8
TOP_P = 0.9
DO_SAMPLE = True
ONLY_CPU = False

In [3]:
device = "cpu"
if not ONLY_CPU:
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"

print(f"Device set to: {device}")


Device set to: mps


In [None]:
# 모델 로딩
print("Loading model...")
model, tokenizer = load_hf_model_mps_optimized(MODEL_PATH, device)

Loading model...
Loading model from paligemma-weights/paligemma-3b-pt-224
Target device: mps
Creating model instance and loading state dict manually...
Model state dict loaded successfully.
Error loading tokenizer or setting device: MPS backend out of memory (MPS allocated: 9.06 GiB, other allocations: 384.00 KiB, max allowed: 9.07 GiB). Tried to allocate 8.00 MiB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).


RuntimeError: MPS backend out of memory (MPS allocated: 9.06 GiB, other allocations: 384.00 KiB, max allowed: 9.07 GiB). Tried to allocate 8.00 MiB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
# 프로세서 초기화
num_image_tokens = model.config.vision_config.num_image_tokens
image_size = model.config.vision_config.image_size
processor = PaliGemmaProcessor(tokenizer, num_image_tokens, image_size)

In [None]:
model.eval()
print("Model loaded and ready for inference.")

In [None]:
#추론 실행 함수

def get_model_inputs(
    processor: PaliGemmaProcessor, prompt: str, image_file_path: str, device: str
):
    """이미지 및 프롬프트 전처리 및 디바이스 이동"""
    image = Image.open(image_file_path).convert("RGB")
    images = [image]
    prompts = [prompt]
    model_inputs = processor(text=prompts, images=images)
    
    # 디바이스 이동
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
    return model_inputs



def run_inference():
    """메인 추론 실행"""
    
    model_inputs = get_model_inputs(processor, PROMPT, IMAGE_FILE_PATH, device)
    input_ids = model_inputs["input_ids"]
    attention_mask = model_inputs["attention_mask"]
    pixel_values = model_inputs["pixel_values"]
    kv_cache = KVCache()

    stop_token = processor.tokenizer.eos_token_id
    generated_tokens = []

    print("\n--- Running Inference ---")
    with torch.no_grad():
        for i in range(MAX_TOKENS_TO_GENERATE):
            # 모델 순전파
            outputs = model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                kv_cache=kv_cache,
            )
            kv_cache = outputs["kv_cache"]
            next_token_logits = outputs["logits"][:, -1, :]

            # 샘플링/결정
            if DO_SAMPLE:
                next_token_logits = torch.softmax(next_token_logits / TEMPERATURE, dim=-1)
                next_token = _sample_top_p(next_token_logits, TOP_P) # _sample_top_p 정의 필요
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                
            next_token = next_token.squeeze(0)
            generated_tokens.append(next_token)
            
            if next_token.item() == stop_token:
                break

            # 다음 턴 입력 업데이트
            input_ids = next_token.unsqueeze(-1)
            attention_mask = torch.cat(
                [attention_mask, torch.ones((1, 1), device=input_ids.device)], dim=-1
            )
            
            # 진행 상황 출력 (옵션)
            sys.stdout.write(tokenizer.decode(next_token.cpu().numpy()))
            sys.stdout.flush()

        generated_tokens = torch.cat(generated_tokens, dim=-1)
        decoded = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        
        print("\n\n--- Final Result ---")
        print(PROMPT + decoded)


# run_inference() # 실행