In [1]:
import os
import glob
import logging
import base64
import time
import random
import json
import math
import threading
import argparse
import re
import fcntl
from datetime import datetime
from tqdm import tqdm
from openai import OpenAI
from PIL import Image
from concurrent.futures import ThreadPoolExecutor, as_completed, Future

In [None]:
# Base URL handed to the OpenAI client as `base_url` when requests are sent.
BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"

In [None]:
# Model Configurations
# Models are used in list order: when the active entry reaches its token limit
# the code advances to the next one. Keys read by the code in this file:
#   name              - model identifier sent with each request
#   token_limit       - token budget after which the model counts as exhausted
#   max_workers       - thread-pool size used while this model is active
#   request_interval  - minimum spacing between requests, in seconds
#   input_rate_per_k / output_rate_per_k - price per 1k tokens used for cost logging
MODEL_CONFIGS = [
    # High concurrency
    {"name": "qwen-vl-plus-latest", "token_limit": 20000000, "max_workers": 13, "request_interval": 0.09, "input_rate_per_k": 0.0015, "output_rate_per_k": 0.0045},
    # Add more models as needed, including request_interval, input_rate_per_k, output_rate_per_k

    # Low concurrency
    # Add more models as needed, including request_interval, input_rate_per_k, output_rate_per_k

    # Used up (entries below are disabled)
    # {'name': "qwen2.5-vl-3b-instruct", "token_limit": 900000, "max_workers": 10, "request_interval": 0.2, "input_rate_per_k": 0.0012, "output_rate_per_k": 0.0036},
    # {'name': "qwen2.5-vl-7b-instruct", "token_limit": 270000, "max_workers": 10, "request_interval": 0.2, "input_rate_per_k": 0.002, "output_rate_per_k": 0.005},
    # {"name": "qwen-vl-max", "token_limit": 600000, "max_workers": 20, "request_interval": 0.2, "input_rate_per_k": 0.003, "output_rate_per_k": 0.009},
    # {'name': "qwen2.5-vl-32b-instruct", "token_limit": 110000, "max_workers": 4, "request_interval": 0.8, "input_rate_per_k": 0.008, "output_rate_per_k": 0.024},
    # {'name': "qwen2.5-vl-72b-instruct", "token_limit": 650000, "max_workers": 4, "request_interval": 0.8, "input_rate_per_k": 0.016, "output_rate_per_k": 0.048},
    # {'name': "qwen-vl-plus-latest", "token_limit": 800000, "max_workers": 20, "request_interval": 0.2, "input_rate_per_k": 0.0015, "output_rate_per_k": 0.0045},
    # {'name': "qwen-vl-max-latest", "token_limit": 450000, "max_workers": 20, "request_interval": 0.2, "input_rate_per_k": 0.003, "output_rate_per_k": 0.009},
    # {"name": "qwen-vl-plus-2025-01-25", "token_limit": 70000, "max_workers": 3, "request_interval": 0.8, "input_rate_per_k": 0.0015, "output_rate_per_k": 0.0045},
]

In [None]:
# Concurrency settings (these are now defaults/fallbacks, primary control is per-model)
# MAX_WORKERS = 10 # Dynamically set
MAX_RETRIES = 3  # maximum number of attempts per request
RETRY_DELAY = 2  # base delay between retries (seconds)
# REQUEST_INTERVAL = 0.2  # Now defined per model
MIN_IMAGES_REQUIRED = 4  # minimum number of images a subfolder must contain

# Token/Model State Management
current_model_index = 0
model_token_usage = {config["name"]: 0 for config in MODEL_CONFIGS}
all_models_exhausted = False
current_max_workers = MODEL_CONFIGS[current_model_index]["max_workers"] if MODEL_CONFIGS else 10 # Default if no configs

# Global request pacing
request_lock = threading.Lock()
last_request_time = time.time()
error_count = 0
consecutive_errors_threshold = 5

# Token statistics
total_tokens_used = 0
total_requests = 0
# Re-entrant on purpose: switch_to_next_model() takes this lock itself and is
# also called from code paths that already hold it. With a plain Lock those
# paths would block forever the first time a model reaches its token limit.
token_lock = threading.RLock()

# Cost statistics (rates) - Now defined per model
# INPUT_RATE_PER_K = 0.0015
# OUTPUT_RATE_PER_K = 0.0045

# Lock serialising writes of the dataset JSON file
json_file_lock = threading.Lock()

# Dataset collected so far - in-memory cache
dataset = []
processed_subfolders = set()  # subfolders that already have a dataset entry
dataset_lock = threading.Lock()

# Logging setup: one timestamped log file per run under ./logs
os.makedirs("./logs", exist_ok=True)
log_file = f"./logs/scene_change_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
logging.basicConfig(
    filename=log_file,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Also log to the console, but keep the output short
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.WARNING)  # warnings and errors only
logging.getLogger().addHandler(console_handler)

# Preset category lists (alternative batches are kept commented out)
# CATEGORIES = ["lake_or_pond", "waste_disposal", "lighthouse", "road_bridge", "tower", "swimming_pool", "smokestack", "educational_institution", "amusement_park", "park", "stadium", "construction_site", "tunnel_opening", "space_facility", "railway_bridge"]

# CATEGORIES = ["barn", "race_track", "nuclear_powerplant", "place_of_worship", "shopping_mall", "runway", "ground_transportation_station", "fire_station", "solar_farm", "oil_or_gas_facility", "police_station", "surface_mine", "archaeological_site", "office_building", "recreational_facility", "border_checkpoint", "interchange", "hospital", "multi-unit_residential", "debris_or_rubble", "factory_or_powerplant", "parking_lot_or_garage"]

CATEGORIES = ["fountain", "electric_substation", "water_treatment_facility", "car_dealership", "crop_field", "wind_farm", "helipad", "gas_station", "impoverished_settlement", "airport_terminal", "golf_course", "single-unit_residential", "dam", "burial_site", "airport", "flooded_road", "shipyard", "zoo", "airport_hangar", "military_facility", "toll_booth", "aquaculture", "port", "prison", "storage_tank"]

In [None]:
def get_current_model_config():
    """Return the config dict of the model currently in use.

    Returns:
        The ``MODEL_CONFIGS`` entry at ``current_model_index``, or ``None``
        when all models are flagged as exhausted or the index has moved past
        the end of the list.
    """
    has_active_model = (not all_models_exhausted
                        and current_model_index < len(MODEL_CONFIGS))
    return MODEL_CONFIGS[current_model_index] if has_active_model else None

def switch_to_next_model():
    """Advance to the next entry of ``MODEL_CONFIGS``.

    Increments ``current_model_index`` under ``token_lock``. If another model
    is available, ``current_max_workers`` is taken from it; otherwise
    ``all_models_exhausted`` is set and ``current_max_workers`` drops to 1.
    Either outcome is both logged and printed.

    Returns:
        ``True`` while a model is still available, ``False`` once every
        model has been used up.
    """
    global current_model_index, all_models_exhausted, current_max_workers
    with token_lock:  # guards the shared model-selection state
        current_model_index += 1
        if current_model_index < len(MODEL_CONFIGS):
            new_model_config = MODEL_CONFIGS[current_model_index]
            current_max_workers = new_model_config["max_workers"]
            switch_message = f"Switching to model: {new_model_config['name']} (Limit: {new_model_config['token_limit']}, Workers: {current_max_workers})"
            logging.info(switch_message)
            print(switch_message)
        else:
            all_models_exhausted = True
            exhausted_message = "All models have reached their token limits. Stopping processing."
            logging.warning(exhausted_message)
            print(exhausted_message)
            current_max_workers = 1  # nothing left to run, shrink the pool size
    return not all_models_exhausted

def calculate_image_tokens(image_path, patch_size=28, min_tokens=4, max_tokens=1280):
    """Estimate how many tokens an image costs.

    The image dimensions are snapped to multiples of ``patch_size``; if the
    resulting pixel count falls outside ``[min_tokens, max_tokens]`` patches
    the image is scaled (keeping its aspect ratio) back into that range. The
    estimate is one token per patch plus 2 marker tokens.

    Args:
        image_path: Path of the image file to inspect.
        patch_size: Edge length, in pixels, of one token patch.
        min_tokens: Lower bound on the number of patches per image.
        max_tokens: Upper bound on the number of patches per image.

    Returns:
        The estimated token count. If the image cannot be read or measured,
        ``max_tokens + 2`` is returned as a conservative estimate and the
        error is logged.
    """
    try:
        # The context manager releases the file handle right away; only the
        # dimensions are needed.
        with Image.open(image_path) as image:
            height = image.height
            width = image.width

        # Snap both sides to a multiple of the patch size
        h_bar = round(height / patch_size) * patch_size
        w_bar = round(width / patch_size) * patch_size

        patch_area = patch_size * patch_size
        min_pixels = patch_area * min_tokens
        max_pixels = patch_area * max_tokens

        if h_bar * w_bar > max_pixels:
            # Shrink so the total pixel count does not exceed max_pixels
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = math.floor(height / beta / patch_size) * patch_size
            w_bar = math.floor(width / beta / patch_size) * patch_size
        elif h_bar * w_bar < min_pixels:
            # Enlarge so the total pixel count is at least min_pixels
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / patch_size) * patch_size
            w_bar = math.ceil(width * beta / patch_size) * patch_size

        # One token per patch, plus 2 marker tokens
        return (h_bar * w_bar) // patch_area + 2
    except Exception as e:
        logging.error(f"计算图片token失败: {str(e)}")
        return max_tokens + 2  # largest possible value as a conservative estimate

def encode_image_to_base64(image_path):
    """Read a file and return its contents as a base64 text string.

    Returns:
        The base64-encoded contents decoded as UTF-8, or ``None`` when the
        file cannot be read (the error is logged).
    """
    try:
        with open(image_path, "rb") as image_file:
            raw_bytes = image_file.read()
        encoded_bytes = base64.b64encode(raw_bytes)
        return encoded_bytes.decode('utf-8')
    except Exception as e:
        logging.error(f"编码图片 {image_path} 失败: {str(e)}")
        return None

def wait_for_request_interval():
    """Block until the active model's minimum request spacing has passed.

    The spacing is the current model's ``request_interval`` (0.2 s when no
    model config is available). The wait happens while holding
    ``request_lock``, and ``last_request_time`` is refreshed before the lock
    is released.
    """
    global last_request_time
    model_config = get_current_model_config()
    # Fallback spacing for the case where no model is configured
    current_interval = 0.2 if not model_config else model_config["request_interval"]

    with request_lock:
        remaining = current_interval - (time.time() - last_request_time)
        if remaining > 0:
            time.sleep(remaining)
        last_request_time = time.time()

def update_token_stats(image_tokens, completion_tokens):
    """Record the token usage of one finished request.

    Updates the overall counters and the per-model usage under
    ``token_lock``, logs the per-request cost using the active model's rates,
    and moves on to the next model when this request pushed the active model
    across its token limit.

    Args:
        image_tokens: Token count attributed to the request's images.
        completion_tokens: Token count of the model's reply.

    Returns:
        ``True`` when processing should stop (no model is available),
        otherwise ``False``.
    """
    global total_tokens_used, total_requests, model_token_usage

    model_config = get_current_model_config()
    if not model_config:
        return True  # Stop if no model is available

    current_model_name = model_config["name"]
    current_model_limit = model_config["token_limit"]
    # Rates come from the model config; defaults apply if a key is missing
    input_rate = model_config.get("input_rate_per_k", 0.0015)
    output_rate = model_config.get("output_rate_per_k", 0.0045)

    request_tokens = image_tokens + completion_tokens

    with token_lock:
        # Overall statistics
        total_tokens_used += request_tokens
        total_requests += 1
        avg_tokens = total_tokens_used / total_requests if total_requests > 0 else 0

        # Usage of the model that served this request
        previous_model_usage = model_token_usage[current_model_name]
        model_token_usage[current_model_name] += request_tokens
        new_model_usage = model_token_usage[current_model_name]

        # Per-request cost with model-specific rates
        input_cost = (image_tokens / 1000) * input_rate
        output_cost = (completion_tokens / 1000) * output_rate
        request_cost = input_cost + output_cost

        logging.info(f"Model: {current_model_name} - Token统计: Img {image_tokens}, Comp {completion_tokens}, "
                     f"Model Total {new_model_usage}/{current_model_limit}, "
                     f"Overall Total {total_tokens_used}, Avg Req {avg_tokens:.1f}, "
                     f"Cost: Req ¥{request_cost:.4f} (Rates: In {input_rate}, Out {output_rate})")

        # True only for the single request that crosses the limit
        limit_just_reached = previous_model_usage < current_model_limit <= new_model_usage

    # The switch happens after token_lock is released: switch_to_next_model()
    # acquires token_lock itself, so calling it while holding the lock would
    # block forever when the lock is not re-entrant.
    if limit_just_reached:
        logging.warning(f"Model {current_model_name} reached token limit ({current_model_limit}). Usage: {new_model_usage}.")
        print(f"Model {current_model_name} reached token limit ({current_model_limit}).")
        # Only advance if the exhausted model is still the active one, so a
        # switch already performed by another thread is not repeated.
        active_config = get_current_model_config()
        if active_config and active_config["name"] == current_model_name:
            if not switch_to_next_model():
                return True  # Signal that all models are exhausted

    return all_models_exhausted  # Whether processing should stop overall

def should_stop_processing():
    """Report whether the global ``all_models_exhausted`` flag is set."""
    stop_requested = all_models_exhausted
    return stop_requested

def get_date_from_filename(filename):
    """Extract the first ``YYYY-M-D`` style date found in a file name.

    Args:
        filename: File name (or any string) to search.

    Returns:
        A ``datetime`` for the matched date, or ``None`` when the name holds
        no date or the matched digits do not form a real calendar date
        (for example month 13 or day 45).
    """
    match = re.search(r'(\d{4})-(\d{1,2})-(\d{1,2})', filename)
    if not match:
        return None
    year, month, day = map(int, match.groups())
    try:
        return datetime(year, month, day)
    except ValueError:
        # The pattern only checks digit counts, so impossible dates can match.
        return None

def sort_images_by_date(image_paths):
    """Return the paths ordered by the date embedded in each file name.

    Paths whose file name yields no date sort first (they are keyed with
    ``datetime.min``).
    """
    def _date_key(path):
        parsed = get_date_from_filename(os.path.basename(path))
        return parsed or datetime.min

    ordered = list(image_paths)
    ordered.sort(key=_date_key)
    return ordered

def read_scene_description(subfolder):
    """Return the stripped text of ``scene_description.txt`` in ``subfolder``.

    Returns an empty string when the file does not exist or cannot be read
    (a read failure is logged as a warning).
    """
    desc_path = os.path.join(subfolder, "scene_description.txt")
    if not os.path.exists(desc_path):
        return ""

    try:
        with open(desc_path, "r", encoding="utf-8") as f:
            raw_text = f.read()
        return raw_text.strip()
    except Exception as e:
        logging.warning(f"无法读取场景描述文件 {desc_path}: {str(e)}")
        return ""

def load_existing_dataset(json_file_path):
    """Load a previously saved dataset file so a run can resume.

    On success the in-memory ``dataset`` is replaced by the file contents and
    ``processed_subfolders`` is rebuilt from the directory of every entry's
    ``output_image``. Rebuilding (rather than adding to) the set keeps stale
    entries from an earlier call in the same session from leaking through.
    On any failure both are reset to empty.

    Args:
        json_file_path: Path of the dataset JSON file.

    Returns:
        ``True`` if a list was loaded from the file, ``False`` otherwise.
    """
    global dataset, processed_subfolders

    if os.path.exists(json_file_path) and os.path.getsize(json_file_path) > 0:
        try:
            with open(json_file_path, 'r', encoding='utf-8') as f:
                loaded_data = json.load(f)

            if isinstance(loaded_data, list):
                # Build the resume set first so the shared state is swapped
                # in one step; entries that are not dicts are ignored here.
                loaded_subfolders = {
                    os.path.dirname(item['output_image'])
                    for item in loaded_data
                    if isinstance(item, dict) and 'output_image' in item
                }
                with dataset_lock:
                    dataset = loaded_data
                    processed_subfolders = loaded_subfolders
                    record_count = len(dataset)

                print(f"已加载现有数据集，包含 {record_count} 条记录")
                logging.info(f"已加载现有数据集，包含 {record_count} 条记录")
                return True
            else:
                print(f"警告: 数据集文件 {json_file_path} 格式不正确")
                logging.warning(f"数据集文件 {json_file_path} 格式不正确")
        except json.JSONDecodeError as e:
            print(f"警告: 无法解析数据集文件 {json_file_path}: {str(e)}")
            logging.warning(f"无法解析数据集文件 {json_file_path}: {str(e)}")
        except Exception as e:
            print(f"警告: 读取数据集文件 {json_file_path} 时出错: {str(e)}")
            logging.warning(f"读取数据集文件 {json_file_path} 时出错: {str(e)}")

    # File missing, empty, malformed or unreadable: start from scratch
    with dataset_lock:
        dataset = []
        processed_subfolders = set()
    return False

def is_subfolder_processed(subfolder):
    """Tell whether ``subfolder`` is in ``processed_subfolders``.

    The membership test runs while holding ``dataset_lock``.
    """
    dataset_lock.acquire()
    try:
        already_processed = subfolder in processed_subfolders
    finally:
        dataset_lock.release()
    return already_processed

def add_to_dataset(input_images, output_image, input_prompt, output_description, json_file_path):
    """Append one entry to the in-memory dataset and rewrite the JSON file.

    The whole dataset is written to ``<json_file_path>.temp`` and then moved
    over the target with ``os.replace``, so the target file is never left
    half-written.

    Args:
        input_images: Paths of the images used as inputs.
        output_image: Path of the target image; its directory is recorded in
            ``processed_subfolders``.
        input_prompt: Prompt text stored with the entry.
        output_description: Description text stored with the entry.
        json_file_path: Path of the dataset JSON file to rewrite.

    Returns:
        ``True`` when the file was written, ``False`` when saving failed. The
        entry stays in the in-memory dataset either way.
    """
    global dataset, processed_subfolders

    new_entry = {
        "input_images": input_images,
        "output_image": output_image,
        "input_prompt": input_prompt,
        "output_description": output_description
    }

    # Register the entry in memory first
    with dataset_lock:
        dataset.append(new_entry)
        processed_subfolders.add(os.path.dirname(output_image))

    # Defined before the try block so the error handler can always refer to it
    temp_file = f"{json_file_path}.temp"

    with json_file_lock:
        try:
            with open(temp_file, 'w', encoding='utf-8') as f:
                # Exclusive advisory lock on the temp file while it is written
                fcntl.flock(f, fcntl.LOCK_EX)
                try:
                    with dataset_lock:
                        json.dump(dataset, f, ensure_ascii=False, indent=2)
                        record_count = len(dataset)
                    # Push the data to disk before the rename
                    f.flush()
                    os.fsync(f.fileno())
                finally:
                    fcntl.flock(f, fcntl.LOCK_UN)

            # Swap the finished temp file into place
            os.replace(temp_file, json_file_path)
            logging.info(f"已将新条目保存到数据集文件，当前共 {record_count} 条记录")
            return True
        except Exception as e:
            logging.error(f"保存数据集时发生错误: {str(e)}")
            # Best-effort cleanup of the temp file; only file-system errors
            # are tolerated here, and unexpected ones are logged.
            try:
                os.remove(temp_file)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logging.warning(f"Could not remove temporary file {temp_file}: {cleanup_error}")
            return False

def get_change_description(images, prompt, subfolder_path, folder_name):
    """从模型获取图像变化描述"""
    global error_count

    # Check if processing should stop overall
    if should_stop_processing():
        logging.info(f"Skipping request for {subfolder_path} (All models exhausted)")
        return None

    model_config = get_current_model_config()
    if not model_config: # Should not happen if should_stop_processing is checked first, but good practice
        logging.error(f"No available model for request {subfolder_path}")
        return None

    current_model_name = model_config["name"]
    current_model_limit = model_config["token_limit"]
    current_model_tokens_used = model_token_usage[current_model_name]

    # 验证所有图片
    valid_images = []
    for img_path in images:
        if os.path.exists(img_path) and os.path.getsize(img_path) > 0:
            base64_image = encode_image_to_base64(img_path)
            if base64_image:
                valid_images.append((img_path, base64_image))
        else:
            logging.warning(f"图片不存在或大小为0: {img_path}")
    
    if len(valid_images) < len(images):
        logging.error(f"子文件夹 {subfolder_path} 中有无效图片，跳过处理")
        return None
    
    # 计算图像token
    total_image_tokens = sum([calculate_image_tokens(img_path) for img_path, _ in valid_images])

    # 检查当前模型单个请求是否可能超过限制
    with token_lock:
        # Estimate completion tokens conservatively (e.g., 100)
        estimated_completion_tokens = 100
        expected_tokens_for_request = total_image_tokens + estimated_completion_tokens
        if current_model_tokens_used + expected_tokens_for_request > current_model_limit:
            logging.warning(f"Skipping request {subfolder_path} for model {current_model_name} (Estimated request would exceed token limit)")
            # Attempt to switch model immediately if this one is likely full
            if not switch_to_next_model():
                # If switching fails (all models exhausted), return None
                 return None
            else:
                 # If switching succeeds, let the next iteration/task try with the new model
                 # Returning None here prevents processing this subfolder *now*
                 # but allows the overall process to continue if a new model is available.
                 # Alternatively, could retry the request immediately with the new model,
                 # but that adds complexity. Let's stick to skipping for now.
                 return None

    # 准备消息内容
    content = []
    
    # 添加所有图片
    for img_path, base64_image in valid_images:
        content.append({
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{base64_image}"
            }
        })
    
    # 添加文本提示
    change_prompt = f"Scene: {folder_name}. {prompt} In under 77 words, describe specific changes between these images (earliest to latest) in English. Focus on precise location details (e.g., 'new building in bottom-left', 'runway extended northward'). Identify concrete changes in structures, landscape, or development. Be specific and concise. Must be within 50 words. Attention!! Focus on changes not description."
    content.append({
        "type": "text",
        "text": change_prompt
    })
    
    # 记录请求信息
    logging.info(f"Preparing change description request for {subfolder_path} using model {current_model_name} ({len(valid_images)} images)")

    # 添加重试机制
    for retry in range(MAX_RETRIES):
        # Check again before each attempt
        if should_stop_processing():
            logging.info(f"Canceling request for {subfolder_path} (All models exhausted)")
            return None

        # Check if the *specific model* changed or became exhausted during wait/retry
        model_config_before_wait = get_current_model_config()
        if not model_config_before_wait or model_config_before_wait["name"] != current_model_name:
             logging.info(f"Model changed during retry/wait for {subfolder_path}. Aborting current attempt.")
             # Let the caller handle retrying with the potentially new model if needed.
             return None # Indicate failure for this attempt with the old model

        try:
            # 等待请求间隔
            wait_for_request_interval()
            
            # 创建OpenAI客户端
            client = OpenAI(
                base_url=BASE_URL,
                api_key=OPENROUTER_API_KEY,
            )
            
            logging.info(f"Sending change description request: {subfolder_path} with model {current_model_name} (Try {retry+1}/{MAX_RETRIES})")
            
            # 发送请求
            response = client.chat.completions.create(
                model=current_model_name, # Use the current model name
                messages=[
                    {
                        "role": "user",
                        "content": content
                    }
                ],
                max_tokens=256,
                temperature=0.7
            )
            
            # 提取描述内容
            description = None
            completion_tokens = 0
            
            # 从响应中提取数据
            if hasattr(response, 'model_dump_json'):
                response_json = json.loads(response.model_dump_json())
                if 'choices' in response_json and response_json['choices'] and 'message' in response_json['choices'][0]:
                    message = response_json['choices'][0]['message']
                    if 'content' in message and message['content']:
                        description = message['content']
                if 'usage' in response_json and 'completion_tokens' in response_json['usage']:
                    completion_tokens = response_json['usage']['completion_tokens']
            
            # 备用方案：直接访问对象属性
            if description is None and hasattr(response, 'choices') and response.choices and len(response.choices) > 0:
                choice = response.choices[0]
                if hasattr(choice, 'message') and hasattr(choice.message, 'content'):
                    description = choice.message.content
            
            # 如果获取到描述
            if description:
                # 更新token统计，检查是否超过限制
                token_exceeded = update_token_stats(total_image_tokens, completion_tokens)
                
                # 重置错误计数
                with request_lock:
                    error_count = 0
                
                return description
            
            # 如果所有方法都无法提取描述
            logging.error(f"无法从响应中提取变化描述内容")
            
            # 增加错误计数
            with request_lock:
                error_count += 1
                current_errors = error_count
                
            # 检查是否达到连续错误阈值
            if current_errors >= consecutive_errors_threshold:
                logging.warning(f"检测到连续 {current_errors} 次错误，暂停5秒后继续...")
                time.sleep(5)
                with request_lock:
                    error_count = 0
            
            raise ValueError("无法从API响应中提取描述内容")
            
        except Exception as e:
            error_msg = f"Model {current_model_name} request failed for {subfolder_path} (Try {retry+1}/{MAX_RETRIES}): {str(e)}"
            logging.error(error_msg)

            if retry < MAX_RETRIES - 1:
                # 添加随机延迟再重试
                delay = RETRY_DELAY + random.uniform(0, 2)
                logging.info(f"等待 {delay:.2f} 秒后重试...")
                time.sleep(delay)
    
    logging.error(f"Failed to get change description for {subfolder_path} with model {current_model_name} after {MAX_RETRIES} tries.")
    return None

def process_subfolder(subfolder, category, prompt, output_json):
    """Generate and store the change description for one subfolder.

    The latest ``MIN_IMAGES_REQUIRED`` ``*.jpg`` images (by date in the file
    name) are sent to the model. The last of them is recorded as the output
    image and the ones before it as input images.

    Args:
        subfolder: Directory holding the images of one scene.
        category: Category name of the subfolder (not used by this function).
        prompt: Prompt text forwarded to the model request.
        output_json: Path of the dataset JSON file to update.

    Returns:
        ``subfolder`` on success; ``"ALREADY_PROCESSED"``,
        ``"INSUFFICIENT_IMAGES"`` or ``"TOKEN_EXCEEDED"`` for the respective
        skip/stop conditions; ``None`` when the description could not be
        obtained or an exception occurred.
    """
    # Stop early when no model is left
    if should_stop_processing():
        return "TOKEN_EXCEEDED"

    try:
        # Already present in the in-memory dataset?
        if is_subfolder_processed(subfolder):
            logging.info(f"跳过已处于数据集中的子文件夹: {subfolder}")
            return "ALREADY_PROCESSED"

        folder_name = os.path.basename(subfolder)

        # Resume support: an existing output file means the work is done
        output_file = os.path.join(subfolder, "scene_change.txt")
        if os.path.exists(output_file):
            logging.info(f"跳过已处理的子文件夹: {subfolder}")
            return "ALREADY_PROCESSED"

        image_paths = glob.glob(os.path.join(subfolder, "*.jpg"))

        if len(image_paths) < MIN_IMAGES_REQUIRED:
            logging.warning(f"子文件夹 {subfolder} 图像数量不足 ({len(image_paths)}/{MIN_IMAGES_REQUIRED})，跳过处理")
            return "INSUFFICIENT_IMAGES"

        # Keep the most recent MIN_IMAGES_REQUIRED images
        sorted_images = sort_images_by_date(image_paths)
        selected_images = sorted_images[-MIN_IMAGES_REQUIRED:]

        # The split follows the selection size instead of fixed indexes, so
        # it stays valid if MIN_IMAGES_REQUIRED is changed.
        input_images = selected_images[:-1]
        output_image = selected_images[-1]

        logging.info(f"处理子文件夹: {subfolder}, 选择最晚的 {MIN_IMAGES_REQUIRED} 张图片进行变化描述")

        # The scene description is stored as the entry's input prompt
        input_prompt = read_scene_description(subfolder)

        change_description = get_change_description(selected_images, prompt, subfolder, folder_name)

        if not change_description:
            # Distinguish "all models exhausted" from an ordinary failure
            if should_stop_processing():
                logging.warning(f"Failed to get description for {subfolder}, likely due to hitting token limits.")
                return "TOKEN_EXCEEDED"
            logging.error(f"无法为 {subfolder} 获取变化描述 (model might be exhausted or other API error occurred)")
            return None

        # A description obtained just before the models ran out is still kept
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(change_description)
        logging.info(f"变化描述已保存到: {output_file}")

        saved = add_to_dataset(input_images, output_image, input_prompt, change_description, output_json)
        if not saved:
            # The text file now marks this subfolder as done, so make the
            # missing JSON write visible instead of dropping it silently.
            logging.error(f"Description for {subfolder} was written to {output_file} but the dataset file {output_json} could not be updated")

        return subfolder
    except Exception as e:
        logging.error(f"处理子文件夹 {subfolder} 时发生异常: {str(e)}")
        return None

def _tally_result(stats, result):
    """Count one ``process_subfolder`` result; return True on a stop signal."""
    if result == "ALREADY_PROCESSED":
        stats["skipped"] += 1
    elif result == "TOKEN_EXCEEDED":
        return True
    elif result == "INSUFFICIENT_IMAGES":
        stats["insufficient"] += 1
    else:
        stats["processed"] += 1
        if result:
            stats["success"] += 1
    return False


def process_dataset(dataset_path, categories=None, prompt="", output_json="dataset.json", use_parallel=True):
    """Process every subfolder of the selected categories and build the dataset.

    Args:
        dataset_path: Root directory containing one folder per category.
        categories: Category folder names to process; every entry of
            ``dataset_path`` is considered when this is empty or ``None``.
        prompt: Prompt text forwarded to each model request.
        output_json: Path of the dataset JSON file (also used for resuming).
        use_parallel: Run subfolders in a thread pool sized by
            ``current_max_workers`` when true, sequentially otherwise.

    Returns:
        ``None``. Progress, statistics, token usage and a cost estimate are
        printed and logged.
    """
    # all_models_exhausted must be declared global: it is assigned below, and
    # without the declaration the assignment would only create a local.
    global current_max_workers, all_models_exhausted

    # Resume from an existing dataset file, if any
    load_existing_dataset(output_json)

    if not MODEL_CONFIGS:
        print("错误: MODEL_CONFIGS is empty. Please define models to use.")
        logging.error("MODEL_CONFIGS is empty.")
        return

    initial_model_config = get_current_model_config()
    if initial_model_config is None:
        # Happens when the exhausted flag is still set from an earlier run
        print("All models have reached their token limits. Nothing to process.")
        logging.warning("No model available at start of processing.")
        return
    print(f"Starting with model: {initial_model_config['name']} (Limit: {initial_model_config['token_limit']}, Workers: {initial_model_config['max_workers']})")
    logging.info(f"Starting processing with model: {initial_model_config['name']}")

    stats = {"processed": 0, "success": 0, "skipped": 0, "insufficient": 0}

    # Use the given categories, or everything found under dataset_path
    all_categories = os.listdir(dataset_path) if not categories else categories

    # Keep only names that are actual directories
    valid_categories = [
        category for category in all_categories
        if os.path.isdir(os.path.join(dataset_path, category))
    ]

    if not valid_categories:
        print(f"错误: 在 {dataset_path} 中未找到有效的类别文件夹")
        return

    print(f"将处理以下类别: {', '.join(valid_categories)}")
    logging.info(f"将处理以下类别: {', '.join(valid_categories)}")

    for category in valid_categories:
        category_path = os.path.join(dataset_path, category)

        print(f"处理类别: {category}")
        logging.info(f"处理类别: {category}")

        subfolders = [os.path.join(category_path, d) for d in os.listdir(category_path)
                      if os.path.isdir(os.path.join(category_path, d))]

        if not subfolders:
            continue

        # Checked before the progress bar exists so that stopping here does
        # not leave an unclosed bar behind.
        if should_stop_processing():
            print(f"All models exhausted. Stopping before processing category {category}.")
            break

        active_config = get_current_model_config()
        pbar = tqdm(total=len(subfolders), desc=f"Processing {category} (Model: {active_config['name'] if active_config else 'N/A'})")

        try:
            if use_parallel:
                # Pool size is fixed for the category at the current value
                with ThreadPoolExecutor(max_workers=current_max_workers) as executor:
                    future_to_subfolder = {}

                    for subfolder in subfolders:
                        # Stop submitting once the models are exhausted
                        if should_stop_processing():
                            logging.warning("Stopping task submission, models exhausted.")
                            for pending in future_to_subfolder:
                                if not pending.running() and not pending.done():
                                    pending.cancel()
                            break

                        future = executor.submit(process_subfolder, subfolder, category, prompt, output_json)
                        future_to_subfolder[future] = subfolder

                    for future in as_completed(list(future_to_subfolder)):
                        subfolder = future_to_subfolder[future]
                        try:
                            if future.cancelled():
                                logging.info(f"Task for {subfolder} was cancelled.")
                                pbar.update(1)
                                continue

                            result = future.result()

                            # Keep the bar's label in step with the active model
                            active_config = get_current_model_config()
                            pbar.set_description(f"Processing {category} (Model: {active_config['name'] if active_config else 'STOPPED'})")

                            if _tally_result(stats, result):
                                logging.warning(f"Received stop signal (TOKEN_EXCEEDED) from {subfolder}.")
                                all_models_exhausted = True  # make sure the stop flag is set
                        except Exception as e:
                            logging.error(f"处理子文件夹 {subfolder} 时发生错误: {str(e)}")
                            stats["processed"] += 1

                        # Advance the bar for successes, failures and skips alike
                        pbar.update(1)

                        if should_stop_processing():
                            logging.warning("Breaking processing loop as all models are exhausted.")
                            for pending in future_to_subfolder:
                                if not pending.done():
                                    pending.cancel()
                            break
            else:
                # Sequential processing
                for subfolder in subfolders:
                    if should_stop_processing():
                        logging.warning("Stopping sequential processing, models exhausted.")
                        break

                    active_config = get_current_model_config()
                    pbar.set_description(f"Processing {category} (Model: {active_config['name'] if active_config else 'STOPPED'})")

                    try:
                        result = process_subfolder(subfolder, category, prompt, output_json)
                        if _tally_result(stats, result):
                            logging.warning(f"Received stop signal (TOKEN_EXCEEDED) from {subfolder}.")
                            all_models_exhausted = True  # make sure the stop flag is set
                            break  # stop processing this category
                    except Exception as e:
                        logging.error(f"处理子文件夹 {subfolder} 时发生错误: {str(e)}")
                        stats["processed"] += 1

                    pbar.update(1)
        finally:
            pbar.close()

        # Stop going through categories once the models are exhausted
        if should_stop_processing():
            print(f"Stopping processing categories as all models are exhausted.")
            break

    # Summary statistics
    stats_msg = (f"处理完成! 总共处理: {stats['processed']} 个子文件夹, 成功: {stats['success']}, "
                f"失败: {stats['processed'] - stats['success']}, 跳过(已处理): {stats['skipped']}, "
                f"图像不足: {stats['insufficient']}")
    print(stats_msg)
    logging.info(stats_msg)

    # Final dataset state
    dataset_msg = f"最终数据集包含 {len(dataset)} 条记录，已保存到: {output_json}"
    print(dataset_msg)
    logging.info(dataset_msg)

    # Token usage
    token_msg = f"Token使用统计: 总计 {total_tokens_used} tokens, 平均每请求 {total_tokens_used/total_requests if total_requests > 0 else 0:.1f} tokens"
    print(token_msg)
    logging.info(token_msg)

    # Cost estimate: averages the first model's input/output rates, i.e. it
    # assumes an input:output token ratio close to 1:1
    avg_input_output_rate = (initial_model_config.get("input_rate_per_k", 0.0015) + initial_model_config.get("output_rate_per_k", 0.0045)) / 2
    cost_estimate = (total_tokens_used / 1000) * avg_input_output_rate
    cost_msg = f"费用估算: ¥{cost_estimate:.4f} (按平均费率计算)"
    print(cost_msg)
    logging.info(cost_msg)

def main():
    """Entry point: generate scene-change descriptions one category at a time.

    Parses the command-line options, prints the configured models, and then
    calls ``process_dataset`` once per category, writing each category's
    records to ``<output_dir>/<category>_changes.json``.

    Side effects:
        * Creates ``output_dir`` if it does not exist.
        * Resets the module-level token counters and model-rotation state
          (``total_tokens_used``, ``total_requests``, ``current_model_index``,
          ``all_models_exhausted``, ``model_token_usage``) before every category.

    Processing stops early, skipping the remaining categories, as soon as
    ``all_models_exhausted`` is set after a category has been processed.
    Returns ``None``; returns immediately if ``MODEL_CONFIGS`` is empty.
    """
    # Declared once at function scope rather than inside the loop body.
    global total_tokens_used, total_requests, current_model_index, all_models_exhausted, model_token_usage

    # Parse command-line options
    parser = argparse.ArgumentParser(description="生成场景变化描述并构建数据集")
    parser.add_argument("--dataset_path", type=str, default="/data/IceInPot/datasets/fmow-rgb/train-resized",
                        help="数据集根目录路径")
    parser.add_argument("--categories", type=str, nargs="+", default=None,
                        help="要处理的类别，不指定则使用预设的CATEGORIES列表")
    parser.add_argument("--output_dir", type=str, default="/data/IceInPot/gzb_check/output_json",
                        help="输出数据集JSON文件目录路径")
    # Parallel processing is on by default; --no-parallel switches it off.
    parser.add_argument("--parallel", action="store_true", default=True,
                        help="是否使用并行处理 (Concurrency controlled by model config)")
    parser.add_argument("--no-parallel", dest='parallel', action='store_false',
                        help="禁用并行处理")

    # parse_known_args (not parse_args): when this code runs inside a notebook
    # kernel, sys.argv holds the kernel's own launch arguments, which
    # parse_args would reject with SystemExit. Unrecognised arguments are
    # reported instead of being silently dropped.
    args, ignored_args = parser.parse_known_args()
    if ignored_args:
        ignored_msg = f"忽略无法识别的命令行参数: {ignored_args}"
        print(ignored_msg)
        logging.warning(ignored_msg)

    if not MODEL_CONFIGS:
        print("Error: MODEL_CONFIGS list is empty. Cannot run without models defined.")
        return

    # Custom prompt passed to process_dataset for every category
    custom_prompt = "You are looking at a time series of remote sensing images of the same location. "

    print(f"开始处理数据集: {args.dataset_path}")

    # Make sure the output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    # Categories from the command line take precedence over the preset list
    categories_to_process = args.categories if args.categories else CATEGORIES
    print(f"将处理以下类别: {', '.join(categories_to_process)}")

    print("Model Configurations:")
    for config in MODEL_CONFIGS:
        print(f"  - Name: {config['name']}, Token Limit: {config['token_limit']}, Max Workers: {config['max_workers']}, "
              f"Interval: {config.get('request_interval', 'N/A')}s, "
              f"Rates (In/Out): ¥{config.get('input_rate_per_k', 'N/A')}/¥{config.get('output_rate_per_k', 'N/A')} per 1k")

    # current_max_workers is a module-level value that is only read here
    print(f"并发设置: {'启用' if args.parallel else '禁用'}. Initial max workers: {current_max_workers}")
    logging.info(f"开始处理数据集: {args.dataset_path}, Models: {json.dumps(MODEL_CONFIGS)}")

    # Process each category separately, one output JSON file per category
    for category in categories_to_process:
        output_json = os.path.join(args.output_dir, f"{category}_changes.json")
        print(f"\n处理类别: {category}, 输出文件: {output_json}")
        logging.info(f"开始处理类别: {category}, 输出文件: {output_json}")

        # Reset the global token statistics and model state for the new category
        total_tokens_used = 0
        total_requests = 0
        current_model_index = 0
        all_models_exhausted = False
        model_token_usage = {config["name"]: 0 for config in MODEL_CONFIGS}

        # process_dataset takes a list of categories; pass just this one
        process_dataset(
            dataset_path=args.dataset_path,
            categories=[category],
            prompt=custom_prompt,
            output_json=output_json,
            use_parallel=args.parallel
        )

        print(f"完成类别 {category} 的处理")
        logging.info(f"完成类别 {category} 的处理")

        # Stop once every model has hit its token limit
        if all_models_exhausted:
            print("所有模型已达到Token限制，停止处理剩余类别")
            logging.warning("所有模型已达到Token限制，停止处理剩余类别")
            break
    else:
        # Reached only when the loop was not cut short by token exhaustion, so
        # the "all categories done" message is never emitted after an early stop.
        print("所有类别处理完成！")
        logging.info("所有类别处理完成！")

# Run the pipeline only when this file is executed as the top-level program,
# not when it is imported as a module.
if __name__ == "__main__":
    main() # All execution is wrapped in main()