In [None]:
# Install the notebook's dependencies into the running kernel's environment.
# NumPy is constrained to a pre-2.0 release; the other packages are unpinned.
# A kernel restart may be needed afterwards for updated packages to take effect.
%pip install "numpy<2"
%pip install transformers
%pip install bitsandbytes
%pip install accelerate
%pip install tqdm

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m]

In [None]:
from huggingface_hub import login
# Uncomment the call below to authenticate with the Hugging Face Hub.
# NOTE(review): presumably needed once so the model weights can be downloaded — confirm.
# login()

In [None]:
import os
import json
import gc
from PIL import Image
import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from tqdm import tqdm

# --- Configuration (paths and the prompt are gathered in one place) ---
ROOT_DIR = "/workspace/wm/datasets"
TASK = "pusht_noise"
TRAIN_VAL = "val"  # dataset split: "train" or "val"

# Input frames and output captions both live under <ROOT_DIR>/<TASK>/<TRAIN_VAL>/.
LAST_FRAME_DIR = "/".join((ROOT_DIR, TASK, TRAIN_VAL, "last_frames"))
OUTPUT_JSONL = "/".join((ROOT_DIR, TASK, TRAIN_VAL, "captions.jsonl"))
OVERWRITE_JSONL = False  # when False, an existing OUTPUT_JSONL is left untouched

MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
# Instruction sent to the model together with each image.
PROMPT = " ".join([
    "Given this image as a goal state for a PushT robot task,",
    "provide a concise caption between 3 to 8 words",
    "that implies the achieved state.",
])

def load_model(quantize_4bit=False):
    """
    Load the model and its processor, placing the whole model on GPU 0.

    Args:
        quantize_4bit: When True, load the weights with 4-bit NF4
            quantization (bitsandbytes) to save VRAM. Defaults to False,
            which loads the weights in bfloat16 without quantization.

    Returns:
        A ``(model, processor)`` tuple for ``MODEL_ID``.
    """
    model_kwargs = {
        "torch_dtype": torch.bfloat16,
        # Pin every module to GPU 0. "auto" would instead assign modules to
        # the best available devices (GPU/CPU) automatically, but inference
        # may be slower that way.
        "device_map": {"": 0},
    }

    # Only build the quantization config when it is actually going to be used.
    if quantize_4bit:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    print(f"Loading model: {MODEL_ID}")
    model = MllamaForConditionalGeneration.from_pretrained(MODEL_ID, **model_kwargs)
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return model, processor

def _generate_caption(model, processor, image_path):
    """Run the model on a single image file and return the generated caption."""
    # Close the file handle promptly; convert() returns a separate in-memory image.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Build the chat message: one image placeholder followed by the text prompt.
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": PROMPT},
        ]}
    ]

    # Prepare the model inputs.
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    output = model.generate(**inputs, max_new_tokens=30)

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on the word "assistant" would truncate any caption containing that word.
    prompt_len = inputs["input_ids"].shape[-1]
    return processor.decode(output[0][prompt_len:], skip_special_tokens=True).strip()


def generate_captions():
    """
    Caption every image in LAST_FRAME_DIR and write the results to OUTPUT_JSONL.

    Each output line is a JSON object with the keys "file_name" and "caption".
    Nothing is done when OUTPUT_JSONL already exists and OVERWRITE_JSONL is
    False, or when no .jpg/.jpeg/.png file is found in LAST_FRAME_DIR.

    Results are written to "<OUTPUT_JSONL>.tmp" and moved to OUTPUT_JSONL only
    after the whole loop has finished, so an interrupted run never leaves a
    partial file that a later run would mistake for a finished one.

    Images that raise during processing are reported and skipped; they do not
    appear in the output file.
    """
    if os.path.exists(OUTPUT_JSONL) and not OVERWRITE_JSONL:
        print(f"Skipping: {OUTPUT_JSONL} already exists.")
        return

    # Collect the image list before loading the model so that a missing or
    # empty input directory is detected without paying for the model load.
    image_files = sorted(
        name for name in os.listdir(LAST_FRAME_DIR)
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    print(f"Total images found: {len(image_files)}")
    if not image_files:
        print(f"No images found in {LAST_FRAME_DIR}; nothing to do.")
        return

    # Prepare the output directory.
    os.makedirs(os.path.dirname(OUTPUT_JSONL), exist_ok=True)

    # Prepare the model.
    model, processor = load_model()

    tmp_path = OUTPUT_JSONL + ".tmp"
    failed = []

    # Inference loop.
    with torch.no_grad(), open(tmp_path, "w", encoding="utf-8") as f:
        for image_name in tqdm(image_files, desc="Generating captions"):
            image_path = os.path.join(LAST_FRAME_DIR, image_name)

            try:
                caption = _generate_caption(model, processor, image_path)

                # Save the result; flush so progress is visible in the temp file.
                result = {
                    "file_name": image_name,
                    "caption": caption
                }
                f.write(json.dumps(result, ensure_ascii=False) + "\n")
                f.flush()

            except Exception as e:
                failed.append(image_name)
                print(f"\nError processing {image_name}: {e}")

            finally:
                # Per-image tensors are local to _generate_caption and are
                # dropped when it returns; collect first, then release the
                # cached CUDA memory.
                gc.collect()
                torch.cuda.empty_cache()

    # Publish the finished file in one step (overwrites an existing file).
    os.replace(tmp_path, OUTPUT_JSONL)

    if failed:
        print(f"\n{len(failed)} of {len(image_files)} images failed and were skipped.")

# Entry point: run the captioning pipeline when this code is executed directly.
# NOTE: the completion message is printed unconditionally, including when
# generate_captions() returns early because OUTPUT_JSONL already exists.
if __name__ == "__main__":
    generate_captions()
    print(f"\n✓ 完了！ 結果は {OUTPUT_JSONL} に保存されました。")

In [2]:
# Show the GPU model, driver/CUDA versions, and current memory usage.
!nvidia-smi

Sun Jan 18 04:05:07 2026       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.274.02             Driver Version: 535.274.02   CUDA Version: 13.0     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   66C    P8              18W /  72W |      0MiB / 23034MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    