In [None]:
%pip install "numpy<2"
%pip install transformers
%pip install bitsandbytes
%pip install accelerate
%pip install tqdm

In [None]:
import cv2
import os
from pathlib import Path
from tqdm import tqdm
import json
import gc
from PIL import Image
import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig

from huggingface_hub import login

# Authenticate with the Hugging Face Hub before the model is downloaded.
# huggingface_hub reads the HF_TOKEN environment variable by itself, so the
# interactive login prompt is only needed when that variable is not set. This
# keeps "Restart & Run All" non-interactive on machines that export the token.
if not os.environ.get("HF_TOKEN"):
    login()

In [None]:
# --- Settings ---

# Dataset root, task name, and split to process.
ROOT_DIR = "/workspace/wm/datasets"
TASK = "pusht_noise"
TRAIN_VAL = "val"  # "train" or "val"

# When False, generate_captions() leaves an already existing captions file untouched.
OVERWRITE_JSONL = False

# Vision-language model used for captioning, and the instruction sent with every image.
MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
PROMPT = " ".join([
    "Given this image as a goal state for a PushT robot task,",
    "provide a concise caption between 3 to 8 words",
    "that implies the achieved state.",
])

# Per-split locations: input videos, extracted last frames, and the captions output
# (one JSON object per line).
VIDEO_DIR = "/".join([ROOT_DIR, TASK, TRAIN_VAL, "obses"])
LAST_FRAME_DIR = "/".join([ROOT_DIR, TASK, TRAIN_VAL, "last_frames"])
OUTPUT_JSONL = "/".join([ROOT_DIR, TASK, TRAIN_VAL, "captions.jsonl"])

In [None]:
# --- Extract the last frame of every video (mp4) in the dataset ---

def _read_last_frame(video_path):
    """
    Return the last decodable frame of ``video_path`` as returned by
    ``cv2.VideoCapture.read`` (or None if no frame could be read).

    First tries to seek directly to the final frame. If the frame count is
    unavailable or the seek/read fails, the video is reopened and decoded
    sequentially, keeping the last frame that decodes successfully.
    The capture is always released, even if an exception is raised.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames > 0:
            # Fast path: seek straight to the final frame.
            cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)
            success, frame = cap.read()
            if success:
                return frame

        # CAP_PROP_FRAME_COUNT can be an estimate (or 0) and seeking is not
        # reliable for every container/codec, so fall back to decoding the whole
        # video from a fresh capture and keeping the last good frame.
        cap.release()
        cap = cv2.VideoCapture(video_path)
        last_frame = None
        while True:
            success, frame = cap.read()
            if not success:
                break
            last_frame = frame
        return last_frame
    finally:
        cap.release()


def save_last_frames(video_dir, output_dir, jpeg_quality=95):
    """
    Extract the final frame of every ``.mp4`` video in ``video_dir`` and save it
    as ``<video stem>_last.jpg`` in ``output_dir``.

    Videos whose output image already exists are skipped, so re-running the
    function resumes an interrupted extraction.

    Args:
        video_dir: Directory containing the source videos (not searched recursively).
        output_dir: Directory that receives the JPEG frames; created if missing.
        jpeg_quality: JPEG quality passed to ``cv2.imwrite`` (default 95).

    Raises:
        FileNotFoundError: if ``video_dir`` does not exist (from ``os.listdir``).
    """
    os.makedirs(output_dir, exist_ok=True)

    video_extensions = {".mp4"}
    video_files = sorted(
        name for name in os.listdir(video_dir)
        if Path(name).suffix.lower() in video_extensions
    )

    print(f"Found {len(video_files)} videos in {video_dir}")

    for video_name in tqdm(video_files, desc="Extracting"):
        video_path = os.path.join(video_dir, video_name)
        output_path = os.path.join(output_dir, f"{Path(video_name).stem}_last.jpg")

        if os.path.exists(output_path):
            # Already extracted on a previous run.
            continue

        frame = _read_last_frame(video_path)
        if frame is None:
            print(f"Warning: Could not read the last frame of {video_name}")
            continue

        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(output_path, frame, [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_quality)]):
            print(f"Warning: Could not write {output_path}")

if __name__ == "__main__":
    # Extract the goal frames for the configured split, then report the output location.
    save_last_frames(video_dir=VIDEO_DIR, output_dir=LAST_FRAME_DIR)
    done_message = f"\n✓ Process completed. Frames saved to: {LAST_FRAME_DIR}"
    print(done_message)

In [None]:
# --- Caption each last frame by passing it to the VLM together with the prompt ---

def load_model(use_4bit=False, device_map=None):
    """
    Load the captioning model and its processor for ``MODEL_ID``.

    Args:
        use_4bit: If True, load the weights with bitsandbytes 4-bit NF4
            quantization (bf16 compute dtype, double quantization) to save VRAM.
            Defaults to False, which loads plain bf16 weights.
        device_map: Forwarded to ``from_pretrained``. Defaults to ``{"": 0}``,
            which places the whole model on GPU 0. ``"auto"`` lets the library
            distribute the model over the available devices (GPU/CPU), which may
            make inference slower.

    Returns:
        ``(model, processor)`` tuple.
    """
    if device_map is None:
        # Built here rather than as a default argument to avoid a shared mutable default.
        device_map = {"": 0}

    # Only build the quantization config when it is actually used.
    quantization_config = None
    if use_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    print(f"Loading model: {MODEL_ID}")
    model = MllamaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
        device_map=device_map,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return model, processor

def _list_caption_images(image_dir):
    """Return the sorted names of .jpg/.jpeg/.png files directly inside ``image_dir``."""
    return sorted(
        name for name in os.listdir(image_dir)
        if name.lower().endswith((".jpg", ".jpeg", ".png"))
    )


def _caption_one_image(model, processor, input_text, image_path, max_new_tokens):
    """Run the VLM on one image with the pre-rendered chat prompt and return the stripped caption."""
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = processor(
        image,
        input_text,
        add_special_tokens=False,  # the chat template already contains the special tokens
        return_tensors="pt",
    ).to(model.device)

    output = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns the prompt tokens followed by the new tokens. Decode only the
    # new ones instead of splitting the full text on "assistant", which truncated any
    # caption that itself contained that word.
    prompt_length = inputs["input_ids"].shape[-1]
    return processor.decode(output[0, prompt_length:], skip_special_tokens=True).strip()


def generate_captions(max_new_tokens=30):
    """
    Caption every image in ``LAST_FRAME_DIR`` with the VLM and write the results
    to ``OUTPUT_JSONL``, one ``{"file_name": ..., "caption": ...}`` JSON object per line.

    Does nothing if ``OUTPUT_JSONL`` already exists and ``OVERWRITE_JSONL`` is False.
    Also does nothing (no model load, no output file) if there are no images. This
    prevents an empty captions file from blocking later runs.
    Images that fail are reported and omitted from the output.

    Decoding options other than ``max_new_tokens`` come from the model's generation
    config. NOTE(review): if that config samples, captions are not reproducible
    across runs; consider passing ``do_sample=False`` or seeding.

    Args:
        max_new_tokens: Upper bound on generated tokens per caption (default 30).

    Raises:
        FileNotFoundError: if ``LAST_FRAME_DIR`` does not exist (from ``os.listdir``).
    """
    if os.path.exists(OUTPUT_JSONL) and not OVERWRITE_JSONL:
        print(f"Skipping: {OUTPUT_JSONL} already exists.")
        return

    # List the images before loading the (expensive) model.
    image_files = _list_caption_images(LAST_FRAME_DIR)
    print(f"Total images found: {len(image_files)}")
    if not image_files:
        print(f"No images found in {LAST_FRAME_DIR}; nothing to caption.")
        return

    os.makedirs(os.path.dirname(OUTPUT_JSONL), exist_ok=True)

    model, processor = load_model()

    # The prompt is identical for every image, so render the chat template once.
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": PROMPT},
        ]}
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    failed = []
    with torch.no_grad(), open(OUTPUT_JSONL, "w", encoding="utf-8") as f:
        for image_name in tqdm(image_files, desc="Generating captions"):
            image_path = os.path.join(LAST_FRAME_DIR, image_name)

            caption = None
            try:
                caption = _caption_one_image(model, processor, input_text, image_path, max_new_tokens)
            except Exception as e:
                print(f"\nError processing {image_name}: {e}")
                failed.append(image_name)

            if caption is None:
                # After a failure (e.g. CUDA OOM), drop lingering references and cached
                # blocks before moving on to the next image.
                gc.collect()
                torch.cuda.empty_cache()
                continue

            result = {
                "file_name": image_name,
                "caption": caption,
            }
            f.write(json.dumps(result, ensure_ascii=False) + "\n")
            f.flush()  # keep partial results on disk if the run is interrupted

    if failed:
        print(f"\nWarning: {len(failed)} of {len(image_files)} images failed and have no caption.")

    # Release the model's GPU memory so later notebook cells can use it.
    del model, processor
    gc.collect()
    torch.cuda.empty_cache()

if __name__ == "__main__":
    # Run the captioning pass for the configured split, then point at the output file.
    generate_captions()
    done_message = f"\n✓ 完了！ 結果は {OUTPUT_JSONL} に保存されました。"
    print(done_message)