In [1]:
# -------- RE-RUN CONFIGURATION FOR THE MISSING PARTS --------

# Number of parts the full dataset was originally split into
ORIGINAL_NUM_PARTS = 53

# Original parts whose runs were interrupted the first time
MISSING_PARTS = [8, 15, 17, 39, 45, 46]

# Number of new parts the combined "missing" subset is re-split into
NUM_PARTS = 12

# Which new part this run processes (1..NUM_PARTS); change it for each Kaggle submission
NUM_ID = 8                  # <-- edit this for each part

BATCH_SIZE = 32             # images handed to each GPU per round (captioned one at a time)

# Path to the keyframe index file (keep unchanged)
INDEX_FILE = "/kaggle/input/irkeyframeindex/keyframe_index_kaggle.json"

# Fail fast on a mistyped setting instead of silently processing the wrong (or no) images.
if not 1 <= NUM_ID <= NUM_PARTS:
    raise ValueError(f"NUM_ID must be in 1..{NUM_PARTS}, got {NUM_ID}")
if len(set(MISSING_PARTS)) != len(MISSING_PARTS):
    raise ValueError(f"MISSING_PARTS contains duplicates: {MISSING_PARTS}")
if any(not 1 <= part <= ORIGINAL_NUM_PARTS for part in MISSING_PARTS):
    raise ValueError(f"MISSING_PARTS must lie in 1..{ORIGINAL_NUM_PARTS}, got {MISSING_PARTS}")
if BATCH_SIZE < 1:
    # A zero batch size would never advance through the image list.
    raise ValueError(f"BATCH_SIZE must be >= 1, got {BATCH_SIZE}")

In [2]:
import os
import json
import gc
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

# Hugging Face checkpoint used for captioning, pinned to a fixed revision
MODEL_NAME, REVISION = "vikhyatk/moondream2", "2025-06-21"

# Folder that receives one captions JSON file per processed part
OUTPUT_DIR = "/kaggle/working/captions_per_part"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Lock that serialises merging results into the shared captions dict
caption_lock = Lock()


def load_model(device):
    """Load the Moondream2 captioning model with all weights on one CUDA device.

    Args:
        device: a ``torch.device`` (e.g. ``torch.device("cuda:1")``), a device
            string such as ``"cuda:1"``, or a bare integer GPU index.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained`` for
        ``MODEL_NAME`` at ``REVISION``, mapped entirely onto that GPU.
    """
    print(f"üì• Loading Moondream2 model on {device}...")
    if isinstance(device, int):
        device_id = device
    else:
        # torch.device() normalises both device objects and strings. The old
        # hasattr(device, "index") test was wrongly True for str (str.index).
        device_id = torch.device(device).index
        if device_id is None:
            # A bare "cuda" means "the current CUDA device".
            device_id = torch.cuda.current_device()
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        revision=REVISION,
        trust_remote_code=True,
        device_map={"": device_id},  # place every module on this single GPU
    )
    print(f"‚úÖ Model loaded on {device}")
    return model


def caption_image(model, image_path, length="normal"):
    """Caption a single image file, returning "" on any failure.

    Args:
        model: a loaded captioning model exposing ``caption(image, length=...)``
            that returns a dict with a ``"caption"`` key.
        image_path: path of the image file to caption.
        length: caption length passed through to ``model.caption``
            (defaults to ``"normal"``, as before).

    Returns:
        The caption string, or ``""`` if opening or captioning the image failed.
    """
    try:
        # The context manager guarantees the file handle is closed; convert()
        # returns a fully loaded copy that outlives the `with` block.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        result = model.caption(image, length=length)
        return result["caption"]
    except Exception as e:
        # Deliberately best-effort: one bad image must not abort a long run.
        # Callers treat "" as "no caption" and skip it.
        print(f"‚ö†Ô∏è Error captioning {image_path}: {e}")
        return ""


def load_index_file(index_path):
    """Load all images from the index file as a flat, index-sorted list.

    Args:
        index_path: path of a JSON file whose top level is an object mapping
            integer-like string keys to image paths.

    Returns:
        A list of ``(int(idx), img_path)`` tuples sorted by ``idx``.

    Raises:
        TypeError: if the JSON top level is not an object.
    """
    # Explicit UTF-8: the platform default encoding can mis-decode non-ASCII paths.
    with open(index_path, "r", encoding="utf-8") as f:
        index_data = json.load(f)

    if not isinstance(index_data, dict):
        raise TypeError(
            f"Index file {index_path} must contain a JSON object, got {type(index_data).__name__}"
        )

    return sorted(
        ((int(idx), img_path) for idx, img_path in index_data.items()),
        key=lambda item: item[0],
    )


def split_dataset(all_images, num_parts, part_id):
    """Return the ``part_id``-th of ``num_parts`` contiguous, near-equal slices.

    The first ``len(all_images) % num_parts`` parts each receive one extra
    item, so part sizes differ by at most one and the parts tile the input.

    Args:
        all_images: the sequence to split (order is preserved).
        num_parts: total number of parts, >= 1.
        part_id: 1-based index of the part to return, in ``1..num_parts``.

    Returns:
        The slice of ``all_images`` belonging to ``part_id``.

    Raises:
        ValueError: if ``num_parts < 1`` or ``part_id`` is outside ``1..num_parts``.
            Previously such inputs silently produced an empty or wrong slice.
    """
    if num_parts < 1:
        raise ValueError(f"num_parts must be >= 1, got {num_parts}")
    if not 1 <= part_id <= num_parts:
        raise ValueError(f"part_id must be in 1..{num_parts}, got {part_id}")

    total_images = len(all_images)
    images_per_part, remainder = divmod(total_images, num_parts)

    # Each preceding part holds images_per_part items, plus one more if it is
    # among the first `remainder` parts.
    start_idx = (part_id - 1) * images_per_part + min(part_id - 1, remainder)
    end_idx = start_idx + images_per_part + (1 if part_id <= remainder else 0)

    selected_images = all_images[start_idx:end_idx]

    print(f"üìä Total images: {total_images}")
    print(f"üì¶ Part {part_id}/{num_parts}: Processing images {start_idx+1} to {end_idx}")
    print(f"üñºÔ∏è  Number of images in this part: {len(selected_images)}")

    return selected_images


def process_batch_on_gpu(model, image_batch, gpu_id, all_captions):
    """Caption one batch with ``model`` and merge the successful captions into ``all_captions``.

    Args:
        model: the model replica assigned to this worker.
        image_batch: iterable of ``(idx, img_path)`` pairs.
        gpu_id: identifier echoed back so the caller can report per-GPU progress.
        all_captions: dict shared across worker threads; receives ``str(idx) -> caption``.

    Returns:
        ``(number of non-empty captions produced for this batch, gpu_id)``.
    """
    # Pair each string key with its caption; empty captions signal failure and are dropped.
    captioned = ((str(idx), caption_image(model, img_path)) for idx, img_path in image_batch)
    fresh = {key: text for key, text in captioned if text}

    # Several worker threads write into the same dict, so merge under the module-level lock.
    with caption_lock:
        all_captions.update(fresh)

    return len(fresh), gpu_id


def _collect_missing_images(all_images, original_num_parts, missing_parts):
    """Concatenate, in ``missing_parts`` order, the slices that formed each interrupted original part."""
    print("\nüß© Collecting images from missing parts...")
    missing_images = []
    for missing_id in missing_parts:
        print(f"\nüîç Collecting images from original part {missing_id}/{original_num_parts}...")
        part_images = split_dataset(all_images, original_num_parts, missing_id)
        print(f"   ‚ûï Added {len(part_images)} images")
        missing_images.extend(part_images)
    return missing_images


def _load_models(num_gpus):
    """Load one model replica per GPU; element ``i`` of the result lives on ``cuda:i``."""
    return [load_model(torch.device(f"cuda:{gpu_id}")) for gpu_id in range(num_gpus)]


def _caption_in_rounds(models, selected_images, batch_size):
    """Caption ``selected_images`` in lock-step rounds of one ``batch_size`` chunk per GPU.

    Returns:
        dict mapping ``str(idx)`` to caption for every image that captioned successfully.
    """
    num_gpus = len(models)
    total = len(selected_images)
    round_size = num_gpus * batch_size
    all_captions = {}

    with ThreadPoolExecutor(max_workers=num_gpus) as executor:
        for batch_num, round_start in enumerate(range(0, total, round_size), start=1):
            print(f"\nüîÑ Batch {batch_num}")
            futures = []

            # Submit one contiguous chunk per GPU for this round.
            for gpu_id, model in enumerate(models):
                start_idx = round_start + gpu_id * batch_size
                if start_idx >= total:
                    break
                end_idx = min(start_idx + batch_size, total)
                image_batch = selected_images[start_idx:end_idx]

                print(
                    f"   GPU {gpu_id}: Submitting {len(image_batch)} images "
                    f"(indices {start_idx+1}-{end_idx})"
                )
                futures.append(
                    executor.submit(process_batch_on_gpu, model, image_batch, gpu_id, all_captions)
                )

            # Wait for every GPU before the next round (a slower GPU briefly idles the others).
            for future in as_completed(futures):
                num_captions, gpu_id = future.result()
                print(f"   ‚úÖ GPU {gpu_id}: Generated {num_captions} captions")

    return all_captions


def _save_captions(captions, output_path):
    """Write ``captions`` as indented UTF-8 JSON, atomically.

    The data goes to a temporary sibling file first and is then moved into place
    with ``os.replace``, so an interrupted session never leaves a truncated JSON
    at ``output_path``.
    """
    tmp_path = output_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(captions, f, indent=2, ensure_ascii=False)
    os.replace(tmp_path, output_path)


def _release_models(models):
    """Drop every model held in ``models`` (clearing the list) and return cached CUDA memory."""
    num_models = len(models)
    # Clearing the list drops the last references. The old `del model` on a loop
    # variable left every model alive, so empty_cache() could not free anything.
    models.clear()
    gc.collect()
    if torch.cuda.is_available():
        for gpu_id in range(num_models):
            torch.cuda.synchronize(gpu_id)
            with torch.cuda.device(gpu_id):
                torch.cuda.empty_cache()


def main():
    """Caption part NUM_ID/NUM_PARTS of the subset formed by the interrupted original parts.

    Pipeline:
        1. load the index;
        2. rebuild the subset from ``MISSING_PARTS`` of the original
           ``ORIGINAL_NUM_PARTS``-way split;
        3. re-split it ``NUM_PARTS`` ways and take part ``NUM_ID``;
        4. load one model per GPU;
        5. caption the images;
        6. write the JSON under ``OUTPUT_DIR``;
        7. release GPU memory, even if captioning fails.

    Raises:
        RuntimeError: if no CUDA GPU is visible.
    """
    num_gpus = torch.cuda.device_count()
    print(f"üöÄ Starting captioning on {num_gpus} GPU(s)")
    print(f"üìã Original dataset was split into {ORIGINAL_NUM_PARTS} parts")
    print(f"üìã Missing original parts: {MISSING_PARTS}")
    print(f"üìã New subset partition: Part {NUM_ID}/{NUM_PARTS}")
    print(f"üì¶ Batch size: {BATCH_SIZE} images per GPU")

    # Fail before any expensive work; ThreadPoolExecutor(max_workers=0) would
    # otherwise raise an opaque ValueError much later.
    if num_gpus == 0:
        raise RuntimeError(
            "No CUDA GPU visible: enable a GPU accelerator before running this notebook."
        )

    # 1. Load the full index
    print(f"\nüìÇ Loading index file: {INDEX_FILE}")
    all_images = load_index_file(INDEX_FILE)

    # 2. Gather the images of the interrupted original parts into one subset
    missing_images = _collect_missing_images(all_images, ORIGINAL_NUM_PARTS, MISSING_PARTS)

    print(f"\nüß© Total images in missing subset: {len(missing_images)}")
    print(f"üìã Re-splitting missing subset into {NUM_PARTS} new parts...")

    # 3. Re-split the subset and take the part this run is responsible for
    selected_images = split_dataset(missing_images, NUM_PARTS, NUM_ID)

    if not selected_images:
        print("‚ùå No images found in this partition!")
        return

    print(f"\nüñºÔ∏è  Processing {len(selected_images)} images in this partition")

    # 4. Load model(s)
    print(f"\n{'='*60}")
    print("üì• Loading model(s)...")
    print(f"{'='*60}")
    models = _load_models(num_gpus)

    try:
        print(f"\n{'='*60}")
        print(f"üîÑ Processing images across {num_gpus} GPUs in parallel...")
        print(f"{'='*60}")

        # 5. Process images in batches across GPUs in parallel
        all_captions = _caption_in_rounds(models, selected_images, BATCH_SIZE)

        # 6. Save results
        output_path = f"{OUTPUT_DIR}/captions_missing_part_{NUM_ID}_of_{NUM_PARTS}.json"
        _save_captions(all_captions, output_path)

        print(f"\n{'='*60}")
        print("üìä FINAL SUMMARY (MISSING SUBSET)")
        print(f"New part {NUM_ID}/{NUM_PARTS} (from original parts {MISSING_PARTS})")
        print("="*60)
        print(f"‚úÖ Saved captions: {output_path}")
        print(f"üìä Total images processed: {len(selected_images)}")
        print(f"üìä Total captions generated: {len(all_captions)}")
        print(f"üìä Success rate: {len(all_captions)/len(selected_images)*100:.1f}%")
        print("‚úÖ All done!")
    finally:
        # 7. Cleanup runs even if captioning or saving raised
        print("\nüßπ Cleaning up...")
        _release_models(models)
        print("‚úÖ Cleanup complete")