In [None]:
!pip install bitsandbytes

In [None]:
!pip install huggingface_hub

In [None]:
from huggingface_hub import notebook_login
notebook_login()

In [None]:
import os
# Configure PyTorch's CUDA caching allocator via its environment variable,
# selecting the 'expandable_segments' option (used here as an OOM mitigation).
# NOTE(review): presumably this has to be set before torch first initialises
# CUDA, so keep this cell ahead of any GPU work — confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [None]:
import warnings
from transformers import utils as hf_utils

# Filter the specific warning from PaliGemmaProcessor.
# The filter is narrowed by message regex, category and originating module so
# that unrelated warnings remain visible.
# NOTE(review): this only affects messages issued through the `warnings`
# module; if the processor reports this text through `logging` instead, the
# filter has no effect — confirm against the installed transformers version.
warnings.filterwarnings(
    "ignore",
    message="You are passing both `text` and `images` to `PaliGemmaProcessor`.*",
    category=UserWarning,
    module='transformers.models.paligemma.processing_paligemma'
)

In [None]:
# =============================================================================
# Combined Script for VLM-based Change Detection VQA
# Integrates PaliGemma into the existing pipeline, replacing custom
# encoder/decoder. Uses Hugging Face Transformers and Accelerate.
# Includes segmentation and VQA/captioning tasks.
#
# Addresses OutOfMemoryError by suggesting/implementing memory-saving techniques:
# - Gradient Accumulation (via args.gradient_accumulation_steps)
# - Quantization (via args.quantization)
# - Mixed Precision (via args.mixed_precision)
# - Gradient Checkpointing (New option)
# - Reducing Batch Size (User adjustable via args.batch_size_per_gpu)
# - Reducing Image Size (User adjustable via args.image_size)
#
# Fixes:
# - RuntimeError due to shape mismatch in VLM forward pass when handling labels/padding.
# - Addressed OutOfMemoryError by setting more memory-friendly default arguments and
#   highlighting key arguments for memory management.
# - FIXED IndexError: Correctly prepare input_ids and labels for VLM by encoding
#   the full prompt+caption sequence and masking labels appropriately.
# - FIXED AttributeError: Moved collate_fn inside the Trainer class.
# - FIXED SyntaxError: Removed misplaced backslash in dataset __init__.
# - FIXED NameError: Added missing SegmentationHead class definition.
# - FIXED ValueError: Ensured PaliGemmaProcessor is called with images argument
#   even when processing prompt only in __getitem__.
# - FIXED AttributeError: Corrected `parser.add_clip` to `parser.add_argument`
#   for the `--grad_clip` argument.
# - FIXED TypeError: Added the missing argument name ('--grad_clip') to `parser.add_argument`.
# - FIXED Warning: Added explicit '<image>' token to the text input for PaliGemmaProcessor.
# - FIXED NameError: Added the definition for the `ChangeDetectionVLM` class.
# - FIXED TypeError: Enabled gradient checkpointing using the model's method after loading.
# - **FIXED RuntimeError:** Corrected the input channel size for the SegmentationHead.
# =============================================================================

import sys
import os
import json
import argparse
import numpy as np
from collections import defaultdict
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.nn.functional as F
import math
from PIL import Image # Use PIL for image loading compatible with transformers processor
# from imageio.v2 import imread # Keep if needed, but PIL is standard for HF
from random import randint
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from tqdm.auto import tqdm # Use tqdm.auto for notebook compatibility
from accelerate import Accelerator, DistributedDataParallelKwargs # For multi-GPU, mixed precision
from accelerate.utils import set_seed
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, BitsAndBytesConfig # Import VLM
import bitsandbytes # Required for quantization
from transformers.utils import is_accelerate_available # Check if accelerate is installed

# =============================================================================
# --- Hardcoded Paths for Kaggle ---
# Update DATASET_ROOT to the correct path of your dataset input.
# =============================================================================
# Make sure your dataset is added to the Kaggle notebook input
# Example: '/kaggle/input/levir-mci-dataset/LEVIR-MCI-dataset'
# Root folder of the LEVIR-MCI dataset. Code in this file joins it with
# 'LevirCCcaptions.json' (captions) and 'images' (image splits).
DATASET_ROOT = '/kaggle/input/levir-mci-dataset/LEVIR-MCI-dataset' # <-- *** UPDATE THIS PATH ***
# Directory that receives generated outputs (e.g. the split .txt files).
SAVE_OUTPUT_DIR = '/kaggle/working/' # Standard Kaggle writable output directory

# =============================================================================
# --- Constants ---
# =============================================================================
# Define the mapping from RGB colors to class IDs
# Background: 0 (black), Road: 1 (grey), Building: 2 (white)
# Exact RGB colour -> class ID lookup for segmentation label images.
# Keys are (R, G, B) tuples; values are the integer class IDs.
COLOR_TO_ID_MAPPING = {
    (0, 0, 0): 0,          # Background
    (128, 128, 128): 1,    # Road (Grey)
    (255, 255, 255): 2,    # Building (White)
}
# Number of segmentation classes; keep in sync with the IDs listed above.
NUM_CLASSES = 3 # Background, Road, Building
# Hugging Face model identifier used when no other VLM is specified.
DEFAULT_VLM = "google/paligemma-3b-pt-224" # Default PaliGemma model

# =============================================================================
# --- Utility Functions ---
# =============================================================================

def rgb_to_class_id_mask(rgb_mask_pil):
    """Translate an RGB label image into a 2-D array of integer class IDs.

    Args:
        rgb_mask_pil: Image-like object convertible by ``np.array`` into an
            (H, W, 3) uint8 array.

    Returns:
        np.ndarray: int64 array of shape (H, W). Pixels whose colour is not a
        key of ``COLOR_TO_ID_MAPPING`` keep the background ID 0.

    Raises:
        ValueError: If the last axis of the converted array is not of size 3.
    """
    pixels = np.array(rgb_mask_pil, dtype=np.uint8)
    height, width, channels = pixels.shape
    if channels != 3:
        raise ValueError(f"Input mask must have 3 channels (RGB), but got {channels}")

    # Everything starts as background; each known colour then claims its pixels.
    id_mask = np.full((height, width), 0, dtype=np.int64)
    for rgb, label in COLOR_TO_ID_MAPPING.items():
        wanted = np.array(rgb, dtype=pixels.dtype)
        hit = (pixels == wanted).all(axis=-1)
        id_mask[hit] = label
    return id_mask

# =============================================================================
# --- dataset.py content adapted for VLM ---
# =============================================================================

class LEVIRCCDataset(Dataset):
    """Before/after image pairs with change captions, encoded for PaliGemma.

    Every item is a dict holding the processed pixel values of image A and
    image B, the tokenised ``prompt + caption`` sequence for image B together
    with its attention mask and loss labels, the sample name, all raw
    reference captions and (when ``load_segmentation`` is set) a class-ID
    segmentation mask.
    """

    _CAPTIONS_FILENAME = 'LevirCCcaptions.json'
    _VALID_SPLITS = ('train', 'val', 'test')
    # Label value skipped by the language-modelling loss (CrossEntropyLoss default).
    _IGNORE_INDEX = -100

    def __init__(self, data_folder, processed_data_dir, split, processor,
                 load_segmentation=True,
                 max_length=64, # Increased default max length for prompt + caption
                 vqa_prompt="Describe the changes between the two images.", # Default prompt
                 max_iters=None):
        """
        Args:
            data_folder (str): Path to the root LEVIR-MCI dataset folder.
            processed_data_dir (str): Path to the folder with splits (train.txt etc.).
            split (str): 'train', 'val', or 'test'.
            processor (PaliGemmaProcessor): The processor for the VLM.
            load_segmentation (bool): If True, loads segmentation maps.
            max_length (int): Max sequence length for VLM tokenizer (prompt + caption).
                Stored only; no truncation or padding is applied in this class.
            vqa_prompt (str): Prompt used when framing captioning as VQA.
            max_iters (int, optional): Repeats dataset for this many items per epoch.

        Raises:
            ValueError: Unknown split, empty split file, or no usable samples.
            FileNotFoundError: The split file does not exist.
        """
        self.data_folder = data_folder
        self.processed_data_dir = processed_data_dir
        self.split = split
        self.processor = processor
        self.load_segmentation = load_segmentation
        self.max_length = max_length
        self.vqa_prompt = vqa_prompt
        self.max_iters = max_iters

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.split not in self._VALID_SPLITS:
            raise ValueError(
                f"split must be one of {self._VALID_SPLITS}, got {self.split!r}"
            )

        # --- Load Image Filenames/Base Names ---
        split_file_path = os.path.join(self.processed_data_dir, f'{split}.txt')
        try:
            with open(split_file_path, 'r', encoding='utf-8') as f:
                self.img_ids = [line.strip() for line in f if line.strip()]
        except FileNotFoundError:
            raise FileNotFoundError(f"Split file not found at {split_file_path}. Run preprocessing first.")

        if not self.img_ids:
            raise ValueError(f"No image IDs found in split file: {split_file_path}")

        # --- Load raw captions (tokenised later by the VLM processor) ---
        self.captions_data = self._load_captions()

        # ---- Prepare file paths ----
        self.files, missing_files_count = self._collect_files()

        if missing_files_count > 0:
            print(f"Warning: Skipped {missing_files_count} entries due to missing files or captions in split '{self.split}'.")

        if not self.files:
            raise ValueError(f"No valid file sets found for split '{self.split}'. Check paths and preprocessing output.")

        # --- Handle max_iters ---
        if max_iters is not None and max_iters > 0:
            n_repeat = int(np.ceil(max_iters / len(self.files)))
            self.files = self.files * n_repeat

        print(f"Initialized LEVIRCCDataset for split '{self.split}' with {len(self.files)} items.")

    def _captions_json_path(self):
        """Return the captions JSON path, preferring the ``data_folder`` argument.

        The constructor argument is honoured first; the module-level
        ``DATASET_ROOT`` is kept as a fallback so existing setups still work.
        """
        candidate = os.path.join(self.data_folder, self._CAPTIONS_FILENAME)
        if os.path.exists(candidate):
            return candidate
        return os.path.join(DATASET_ROOT, self._CAPTIONS_FILENAME)

    def _load_captions(self):
        """Read ``{image base name: [raw caption, ...]}`` from the captions JSON.

        Returns an empty dict (after printing a warning) when the file is
        missing or lacks the expected keys.
        """
        captions_json_path = self._captions_json_path()
        captions = {}
        try:
            with open(captions_json_path, 'r', encoding='utf-8') as f:
                raw_data = json.load(f)['images']
            for img_info in raw_data:
                base_name = os.path.splitext(img_info['filename'])[0]
                # Store list of raw sentences for each image ID
                captions[base_name] = [s['raw'] for s in img_info['sentences']]
        except FileNotFoundError:
            print(f"Warning: Caption file {captions_json_path} not found. Cannot load ground truth captions.")
            return {}
        except KeyError:
            print(f"Warning: Caption file {captions_json_path} has unexpected format. Cannot load ground truth captions.")
            return {}
        return captions

    def _collect_files(self):
        """Build the per-sample path dicts, skipping incomplete samples.

        A sample is kept only if image A, image B (and the label image when
        segmentation is enabled) exist on disk and at least one caption is
        available.

        Returns:
            tuple: (list of path dicts, number of skipped entries).
        """
        image_base_path = os.path.join(self.data_folder, 'images', self.split)
        label_folder_name = 'label' # Assuming label folder name is 'label'

        files = []
        skipped = 0
        for img_base_name in self.img_ids:
            img_file_name = f"{img_base_name}.png"
            file_paths = {
                "name": img_base_name,
                "imgA": os.path.join(image_base_path, 'A', img_file_name),
                "imgB": os.path.join(image_base_path, 'B', img_file_name),
            }
            required = [file_paths["imgA"], file_paths["imgB"]]
            if self.load_segmentation:
                file_paths["seg_label"] = os.path.join(image_base_path, label_folder_name, img_file_name)
                required.append(file_paths["seg_label"])

            has_files = all(os.path.exists(p) for p in required)
            has_captions = bool(self.captions_data.get(img_base_name))

            if has_files and has_captions:
                files.append(file_paths)
            else:
                skipped += 1
        return files, skipped

    def _load_seg_mask(self, datafiles, fallback_size):
        """Load the label image as a (H, W) long tensor of class IDs.

        Best effort: on any failure a message is printed and an all-background
        mask of ``fallback_size`` (width, height) is returned instead.
        """
        try:
            with Image.open(datafiles["seg_label"]) as seg_img:
                # Ensure it's RGB before converting
                seg_rgb = seg_img if seg_img.mode == 'RGB' else seg_img.convert('RGB')
                mask_np = rgb_to_class_id_mask(seg_rgb)
            return torch.from_numpy(mask_np).long()
        except Exception as e:
            print(f"Error loading/processing segmentation label for {datafiles['name']}: {e}")
            width, height = fallback_size
            return torch.zeros((height, width), dtype=torch.long)

    def __len__(self):
        if self.max_iters is not None and self.max_iters > 0:
            return self.max_iters
        return len(self.files)

    def __getitem__(self, index):
        """Return the encoded sample at ``index`` (wrapped modulo the file list)."""
        datafiles = self.files[index % len(self.files)]
        img_base_name = datafiles['name']

        try:
            # convert() returns an independent, fully loaded image, so the
            # file handle can be released right away.
            with Image.open(datafiles["imgA"]) as img:
                imgA_pil = img.convert("RGB")
            with Image.open(datafiles["imgB"]) as img:
                imgB_pil = img.convert("RGB")
        except Exception as e:
            print(f"Error loading images for {datafiles['name']}: {e}")
            raise

        # --- Load and Process Segmentation ---
        seg_mask_class_ids = None
        if self.load_segmentation:
            seg_mask_class_ids = self._load_seg_mask(datafiles, fallback_size=imgA_pil.size)

        # --- Select one caption (random among the references) ---
        available_captions = self.captions_data.get(img_base_name, [])
        if not available_captions:
            # Should not happen after the constructor's filtering; keep a default.
            chosen_caption = "No changes detected." # Default caption
        else:
            chosen_caption = available_captions[randint(0, len(available_captions) - 1)]

        # --- Prepare VLM inputs (image B + prompt + caption) ---
        # The caption is handed to the processor as `suffix`. Per the
        # Transformers PaliGemma documentation the processor then builds
        # prefix + separator + suffix + EOS in a single pass and returns
        # `token_type_ids` that are 0 for image/prompt tokens and 1 for suffix
        # tokens. Deriving the label mask from those IDs avoids measuring the
        # prompt length with a second, separately encoded prompt (whose
        # processor-added trailing newline previously shifted the caption
        # start by one token) and gives the model an EOS target to learn.
        try:
            vlm_inputs = self.processor(
                text="<image>" + self.vqa_prompt,
                images=imgB_pil, # Use image B for the VLM input
                suffix=chosen_caption,
                return_tensors="pt",
            )

            input_ids = vlm_inputs['input_ids'].squeeze(0) # Remove batch dim
            attention_mask = vlm_inputs['attention_mask'].squeeze(0) # Remove batch dim
            pixel_values_b = vlm_inputs['pixel_values'].squeeze(0) # Remove batch dim
            token_type_ids = vlm_inputs['token_type_ids'].squeeze(0) # Remove batch dim

            # Only caption (suffix) tokens contribute to the loss.
            # NOTE: positions padded later by the collate function must also be
            # given -100 in `labels`; a pad token ID is not ignored by the loss.
            labels = input_ids.masked_fill(token_type_ids == 0, self._IGNORE_INDEX)
        except Exception as e:
            print(f"Error processing data with VLM processor for {img_base_name}: {e}")
            raise

        # Image A only needs pixel values (it feeds the segmentation branch).
        try:
            inputs_a = self.processor(
                images=imgA_pil,
                return_tensors="pt"
            )
            pixel_values_a = inputs_a['pixel_values'].squeeze(0) # Remove batch dim
        except Exception as e:
            print(f"Error processing image A with VLM processor for {img_base_name}: {e}")
            raise

        batch = {
            'pixel_values_a': pixel_values_a,
            'pixel_values_b': pixel_values_b,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'name': datafiles['name'],
        }

        if self.load_segmentation and seg_mask_class_ids is not None:
            batch['seg_mask'] = seg_mask_class_ids

        # Store all reference captions (raw text) for BLEU calculation during validation
        batch['reference_captions'] = self.captions_data.get(img_base_name, [])

        return batch


# --- Segmentation Head Module ---
class SegmentationHead(nn.Module):
    """Convolutional decoder turning patch embeddings into per-pixel class logits.

    The token sequence is folded back into a square feature map, passed through
    a small conv stack and bilinearly resized to ``target_size``.

    Args:
        in_channels (int): Feature size of each incoming patch embedding.
        num_classes (int): Number of output channels (one per class).
        target_size: Spatial size handed to ``F.interpolate``.
        patch_size: Stored on the module; not used in the forward pass.
    """

    def __init__(self, in_channels, num_classes, target_size, patch_size):
        super().__init__()
        self.target_size = target_size
        self.patch_size = patch_size

        # Layer order (and therefore state-dict keys) is conv/relu/conv/relu/conv.
        stack = [
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, num_classes, kernel_size=1),  # one channel per class
        ]
        self.conv_layers = nn.Sequential(*stack)

    def forward(self, vision_outputs):
        """Compute segmentation logits.

        Args:
            vision_outputs: Object exposing ``last_hidden_state`` of shape
                (B, num_patches, in_channels).

        Returns:
            torch.Tensor: Logits of shape (B, num_classes, *target_size).

        Raises:
            ValueError: If ``num_patches`` is neither a perfect square nor a
                perfect square plus one.
        """
        tokens = vision_outputs.last_hidden_state
        n_batch, n_tokens, n_feat = tokens.shape

        side = math.isqrt(n_tokens)
        if side * side != n_tokens:
            # Only tolerated mismatch: one extra leading token (treated as a
            # class token and dropped).
            if n_tokens != side * side + 1:
                raise ValueError(f"Number of patches ({n_tokens}) is not a perfect square or square + 1.")
            tokens = tokens[:, 1:, :]
            side = math.isqrt(tokens.shape[1])

        # (B, N, C) -> (B, C, N) -> (B, C, side, side)
        feature_map = tokens.transpose(1, 2).reshape(n_batch, n_feat, side, side)

        logits_lowres = self.conv_layers(feature_map)

        # Bilinear resize of the logits up to the requested output resolution.
        return F.interpolate(
            logits_lowres,
            size=self.target_size,
            mode='bilinear',
            align_corners=False,
        )


# Helper class to pass combined features to segmentation head
class CombinedVisionOutputs:
    """Minimal holder exposing a ``last_hidden_state`` attribute.

    Lets a plain tensor be handed to code that reads ``.last_hidden_state``
    from its argument.
    """

    def __init__(self, last_hidden_state):
        # Stored as given; no copy or validation is performed.
        self.last_hidden_state = last_hidden_state


# =============================================================================
# --- ChangeDetectionVLM Model ---
# This class wraps the PaliGemma model and adds a segmentation head.
# =============================================================================
class ChangeDetectionVLM(nn.Module):
    """PaliGemma wrapper with an additional change-segmentation head.

    Captioning uses image B together with the text sequence through the VLM.
    Segmentation uses the VLM's vision tower on image A and image B, with the
    two feature sets concatenated per patch and decoded by ``SegmentationHead``.
    """

    def __init__(self, args, model_name_or_path, processor, num_classes, image_size, patch_size, freeze_vlm_base=False, quantization_config=None):
        """
        Args:
            args: Namespace-like object; ``gradient_checkpointing`` and
                ``image_size`` are read here.
            model_name_or_path (str): Identifier/path passed to ``from_pretrained``.
            processor: VLM processor, stored on the module.
            num_classes (int): Number of segmentation classes.
            image_size: Stored on the module (the head's target size is taken
                from ``args.image_size``).
            patch_size: Stored and forwarded to the segmentation head.
            freeze_vlm_base (bool): Freeze vision tower, projector and the
                language model parameters whose name lacks 'lm_head'.
            quantization_config: Optional config forwarded to ``from_pretrained``.
        """
        super().__init__()
        self.args = args
        self.processor = processor
        self.num_classes = num_classes
        self.image_size = image_size
        self.patch_size = patch_size

        # --- Load VLM ---
        self.vlm = PaliGemmaForConditionalGeneration.from_pretrained(
            model_name_or_path,
            quantization_config=quantization_config,
        )

        # Gradient checkpointing is switched on through the model's own method
        # after loading rather than via a from_pretrained argument.
        if args.gradient_checkpointing:
            self.vlm.gradient_checkpointing_enable()
            print("Gradient checkpointing enabled.")

        # --- Freeze VLM Base (Optional) ---
        if freeze_vlm_base:
            self.vlm.vision_tower.requires_grad_(False)
            self.vlm.multi_modal_projector.requires_grad_(False)
            # NOTE(review): if lm_head shares its weight with the input
            # embeddings, named_parameters() may report that tensor only under
            # the embedding name, in which case it is frozen here too — confirm.
            for name, param in self.vlm.language_model.named_parameters():
                if 'lm_head' not in name:
                    param.requires_grad_(False)
            print("Frozen VLM vision tower, projector, and base language model layers.")

        # --- Segmentation Head ---
        vision_hidden_size = self.vlm.config.vision_config.hidden_size

        # Features of image A and image B are concatenated along the channel
        # axis in forward(), hence twice the vision hidden size.
        segmentation_input_channels = 2 * vision_hidden_size

        # Assuming square images
        segmentation_target_size = (args.image_size, args.image_size)

        self.segmentation_head = SegmentationHead(
            in_channels=segmentation_input_channels,
            num_classes=num_classes,
            target_size=segmentation_target_size,
            patch_size=patch_size
        )

    def forward(self, pixel_values_a, pixel_values_b, input_ids, attention_mask, labels=None, mode='train'):
        """
        Forward pass for combined VQA and Segmentation.

        Args:
            pixel_values_a (torch.Tensor): Pixel values for image A (B, C, H, W).
            pixel_values_b (torch.Tensor): Pixel values for image B (B, C, H, W).
            input_ids (torch.Tensor): Tokenized input text (prompt + caption) (B, L).
            attention_mask (torch.Tensor): Attention mask for input_ids (B, L).
            labels (torch.Tensor, optional): Tokenized target labels (masked) (B, L).
                                              Required for caption loss calculation.
            mode (str): 'train' or 'eval'. The caption loss is only computed
                in 'train' mode.

        Returns:
            dict: 'caption_loss' (``None`` unless ``mode == 'train'`` and
            ``labels`` is given) and 'seg_logits'.
        """
        # --- Captioning branch (image B + text) ---
        # The only quantity consumed from the VLM call is its loss, so the
        # (expensive) language-model pass is executed only when a loss can be
        # produced. Hidden states are not requested: nothing reads them and
        # keeping them alive for every layer costs a lot of memory.
        caption_loss = None
        if mode == 'train' and labels is not None:
            vlm_outputs = self.vlm(
                pixel_values=pixel_values_b,
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                return_dict=True,
            )
            caption_loss = vlm_outputs.loss

        # --- Segmentation branch ---
        # Both images go through the vision tower independently so that the
        # two feature sets are produced the same way.
        vision_outputs_a = self.vlm.vision_tower(pixel_values_a)
        vision_outputs_b = self.vlm.vision_tower(pixel_values_b)

        # Per-patch concatenation along the feature axis:
        # two (B, num_patches, hidden) tensors -> (B, num_patches, 2 * hidden).
        combined_vision_features = torch.cat(
            [vision_outputs_a.last_hidden_state, vision_outputs_b.last_hidden_state],
            dim=-1
        )

        # The head reads `.last_hidden_state` from its argument.
        seg_logits = self.segmentation_head(CombinedVisionOutputs(combined_vision_features))

        return {
            'caption_loss': caption_loss,
            'seg_logits': seg_logits,
        }

    def generate(self, pixel_values, input_ids, attention_mask, **kwargs):
        """
        Generates captions using the VLM.

        Note: this switches the whole module to eval mode and does not switch
        it back; call ``.train()`` afterwards when training continues.

        Args:
            pixel_values (torch.Tensor): Pixel values for the image (B, C, H, W).
            input_ids (torch.Tensor): Tokenized input text (prompt) (B, L).
            attention_mask (torch.Tensor): Attention mask for input_ids (B, L).
            **kwargs: Additional generation arguments (e.g., max_length, num_beams).

        Returns:
            torch.Tensor: Generated token IDs (B, generated_seq_len).
        """
        # Ensure the model is in evaluation mode for generation
        self.eval()

        # Delegate to the VLM's built-in generate method
        return self.vlm.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )


# =============================================================================
# --- preprocessing.py content (minimal changes needed) ---
# =============================================================================
# The preprocessing step is mainly needed to create the train/val/test splits (txt files)
# and potentially organize data. Vocabulary generation is less critical now
# as the VLM has its own tokenizer via the processor.

# --- Special Tokens (Less relevant for VLM tokenizer) ---
# SPECIAL_TOKENS = { ... } # Keep if needed for compatibility, but VLM handles its own

# --- Helper Functions (build_vocab, encode, pad_sequence) ---
# These might not be directly used by the VLM dataset, but keep if other parts rely on them.
# Ensure they don't conflict with VLM processor logic.

def build_vocab(sequences, min_token_count=1):
    """Placeholder vocabulary builder kept for API compatibility.

    The arguments are ignored: a note is printed and a fixed single-entry
    mapping is returned, since the VLM brings its own tokenizer.
    """
    # ... (restore the original implementation here if it is needed elsewhere) ...
    placeholder_vocab = dict.fromkeys(['<VLM_TOKENIZER>'], 0)
    print("Note: build_vocab is less relevant when using a VLM's pre-trained tokenizer.")
    return placeholder_vocab

def encode(seq_tokens, token_to_idx, allow_unk=False):
    """Return a dummy encoding: one zero per input token.

    ``token_to_idx`` and ``allow_unk`` are kept for signature compatibility and
    are not consulted. A notice is printed on every call.
    """
    print("Note: encode function is less relevant when using a VLM's pre-trained tokenizer.")
    token_count = len(seq_tokens)
    dummy_ids = [0] * token_count
    return dummy_ids

def pad_sequence(sequence, max_length, pad_value=0):
    """Truncate ``sequence`` to ``max_length`` and right-fill it with ``pad_value``.

    Returns a new list: the first ``max_length`` items followed by as many
    ``pad_value`` entries as are needed to reach ``max_length``. A notice is
    printed on every call.
    """
    print("Note: pad_sequence function may be less relevant if VLM processor/collation handles padding.")

    head = sequence[:max_length]
    shortfall = max(0, max_length - len(sequence))
    filler = [pad_value] * shortfall
    return head + filler


# --- Main Preprocessing Logic (Focus on creating split files) ---
def run_preprocessing(args):
    """Validate the LEVIR_MCI caption index and write the train/val/test split files.

    Reads ``LevirCCcaptions.json`` from ``DATASET_ROOT`` and, for every entry,
    checks that the 'A' image, the 'B' image and the 'label' mask all exist under
    ``DATASET_ROOT/images/<split>/``. The extension-less names of the complete
    entries are written, one per line, to ``<split>.txt`` in ``SAVE_OUTPUT_DIR``.
    A placeholder ``vocab.json`` is written next to them.

    Args:
        args: Object exposing ``dataset``; only ``'LEVIR_MCI'`` is accepted.

    Raises:
        ValueError: If ``args.dataset`` is not ``'LEVIR_MCI'``.
        SystemExit: If the caption file is missing or has no ``'images'`` key.
    """
    if args.dataset != 'LEVIR_MCI':
        raise ValueError(f"Dataset '{args.dataset}' not supported by this script.")

    input_captions_json = os.path.join(DATASET_ROOT, 'LevirCCcaptions.json')
    input_image_dir = os.path.join(DATASET_ROOT, 'images')
    # Split files go to the hardcoded save directory
    output_splits_dir = SAVE_OUTPUT_DIR

    print(f"Using LEVIR_MCI dataset from: {DATASET_ROOT}")
    print(f"Saving split files to: {output_splits_dir}")

    os.makedirs(output_splits_dir, exist_ok=True)

    print('Loading captions from JSON to determine splits...')
    try:
        # Explicit encoding so the result does not depend on the platform default
        with open(input_captions_json, 'r', encoding='utf-8') as f:
            data = json.load(f)['images']
    except FileNotFoundError:
        print(f"Error: Caption file not found at {input_captions_json}")
        sys.exit(1)
    except KeyError:
        print(f"Error: JSON file {input_captions_json} does not have the expected 'images' key.")
        sys.exit(1)

    split_filenames = {'train': [], 'val': [], 'test': []}
    image_check_counts = {'train': 0, 'val': 0, 'test': 0}
    missing_image_counts = {'train': 0, 'val': 0, 'test': 0}

    for img_info in data:
        filename = img_info['filename']
        filepath = img_info['filepath']  # train, val, or test
        # 'imgid' is intentionally not read: it is not needed for the split
        # files, and requiring it would reject otherwise valid entries.

        if filepath not in split_filenames:
            print(f"Warning: Unknown filepath '{filepath}' found for {filename}. Skipping.")
            continue

        image_check_counts[filepath] += 1
        # An entry is usable only when both images and the segmentation label exist
        required_paths = [
            os.path.join(input_image_dir, filepath, subdir, filename)
            for subdir in ('A', 'B', 'label')
        ]
        if not all(os.path.exists(path) for path in required_paths):
            missing_image_counts[filepath] += 1
            continue

        # Store base name without extension in split file list
        split_filenames[filepath].append(os.path.splitext(filename)[0])

    print("Image existence check complete.")
    for split in ['train', 'val', 'test']:
        print(f"Split '{split}': Found {len(split_filenames[split])} valid entries out of {image_check_counts[split]} checked ({missing_image_counts[split]} missing files).\n")

    # Save train/val/test split BASE filenames (without extension)
    for split, filenames_base in split_filenames.items():
        if not filenames_base:
            print(f"Warning: No valid entries found for split '{split}'. Split file will be empty.")

        split_file_path = os.path.join(output_splits_dir, f'{split}.txt')
        print(f"Saving {split} image list to {split_file_path} ({len(filenames_base)} images)")
        with open(split_file_path, 'w', encoding='utf-8') as f:
            f.writelines(f"{fname_base}\n" for fname_base in filenames_base)

    # --- Skip Vocab Generation and Token Saving ---
    print("Skipping vocabulary building and individual token saving (using VLM processor instead).\n")
    # Dummy vocab file, kept in case a later stage still checks for it
    dummy_vocab_path = os.path.join(output_splits_dir, 'vocab.json')
    with open(dummy_vocab_path, 'w', encoding='utf-8') as f:
        json.dump({'<VLM_TOKENIZER>': 0}, f)
    print(f"Created dummy vocab file at {dummy_vocab_path}\n")

    print('Preprocessing (split file generation) finished.\n')


# =============================================================================
# --- training.py content adapted for VLM and Accelerate ---
# =============================================================================

# Helper function for mIoU (remains the same)
def calculate_iou(pred, target, num_classes):
    """Compute per-class Intersection over Union (IoU) over a whole batch.

    Args:
        pred (torch.Tensor): Either logits of shape (B, C, H, W) with
            C == num_classes, or class indices of shape (B, H, W).
        target (torch.Tensor): Ground-truth class indices with as many elements
            as the class-index form of ``pred``.
        num_classes (int): Number of classes to score.

    Returns:
        np.ndarray: Array of length ``num_classes`` holding the IoU per class.
        A class absent from both prediction and target scores 0.0.

    Raises:
        ValueError: If ``pred`` is neither (B, num_classes, H, W) nor 3-D.
    """
    # Decide "logits vs. indices" by rank, not by shape[1]: a (B, H, W) index
    # map whose height happens to equal num_classes must not be arg-maxed.
    if pred.ndim == 4:
        if pred.shape[1] != num_classes:
            raise ValueError(f"Unexpected prediction shape for IoU: {pred.shape}")
        pred = torch.argmax(pred, dim=1)  # Convert logits to class IDs (B, H, W)
    elif pred.ndim != 3:  # Should be (B, H, W)
        raise ValueError(f"Unexpected prediction shape for IoU: {pred.shape}")

    # reshape (unlike view) also accepts non-contiguous tensors
    pred = pred.reshape(-1)
    target = target.reshape(-1)

    ious = []
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = pred_inds.sum().item() + target_inds.sum().item() - intersection
        # Class missing everywhere: report 0.0 rather than dividing by zero
        ious.append(intersection / union if union > 0 else 0.0)
    return np.array(ious)


class Trainer:
    def __init__(self, args):
        """Build every component needed for training and restore the latest checkpoint.

        Creates the Accelerator, the PaliGemma processor, the train/val datasets
        and loaders, the ChangeDetectionVLM model (optionally quantized), the
        AdamW optimizer and the segmentation loss, passes model/optimizer/loaders
        through ``accelerator.prepare`` and finally tries to resume from
        ``checkpoint_latest`` inside the run directory.

        Args:
            args: Run configuration. Attributes read here:
                gradient_accumulation_steps, mixed_precision, seed, run_name,
                model_name_or_path, max_length, vqa_prompt, batch_size_per_gpu,
                num_workers, quantization, image_size, freeze_encoder,
                learning_rate.

        Raises:
            ValueError: If the training or validation dataset is empty.
        """
        self.args = args
        self.config = args # Use args directly as config

        # ---- Accelerator Setup ----
        # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) # May need if parts of VLM aren't used
        self.accelerator = Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            mixed_precision=args.mixed_precision, # 'fp16', 'bf16', or 'no'
            log_with="tensorboard", # Or "wandb", "none"
            project_dir=os.path.join(SAVE_OUTPUT_DIR, 'accelerate_logs')
            # kwargs_handlers=[ddp_kwargs]
        )
        set_seed(args.seed)

        # Log args
        self.accelerator.print(f"Accelerator config: {self.accelerator.state}\n")
        self.accelerator.print(f"Running with args: {args}\n")

        # ---- Create Directories ----
        self.run_save_dir = os.path.join(SAVE_OUTPUT_DIR, 'training_output')
        if self.accelerator.is_main_process:
            os.makedirs(self.run_save_dir, exist_ok=True)
            self.accelerator.init_trackers(args.run_name or "vlm_change_detection") # Init TensorBoard/WandB
        self.accelerator.wait_for_everyone() # Ensure dir exists before anyone tries to write
        self.latest_checkpoint_path = os.path.join(self.run_save_dir, 'checkpoint_latest') # Accelerate saves folders

        # ---- VLM Processor ----
        self.accelerator.print("Loading VLM Processor...\n")
        self.processor = PaliGemmaProcessor.from_pretrained(args.model_name_or_path)

        # ---- Dataset & Dataloaders ----
        self.accelerator.print("Loading Datasets...\n")
        self.train_dataset = LEVIRCCDataset(
            data_folder=DATASET_ROOT,
            processed_data_dir=SAVE_OUTPUT_DIR, # Where train.txt etc. are
            split='train',
            processor=self.processor,
            load_segmentation=True, # Ensure this matches your intended use
            max_length=args.max_length,
            vqa_prompt=args.vqa_prompt
        )
        self.val_dataset = LEVIRCCDataset(
            data_folder=DATASET_ROOT,
            processed_data_dir=SAVE_OUTPUT_DIR,
            split='val',
            processor=self.processor,
            load_segmentation=True, # Ensure this matches your intended use
            max_length=args.max_length,
            vqa_prompt=args.vqa_prompt
        )

        if len(self.train_dataset) == 0 or len(self.val_dataset) == 0:
            raise ValueError("One or both datasets are empty. Check paths and preprocessing.")

        # Both loaders use the class-level collate_fn, which needs self.processor
        # to pad input_ids/attention_mask and pads labels itself.
        self.train_loader = DataLoader(
            self.train_dataset, batch_size=args.batch_size_per_gpu, shuffle=True,
            num_workers=args.num_workers, pin_memory=False,
            collate_fn=self.collate_fn
        )
        self.val_loader = DataLoader(
            self.val_dataset, batch_size=args.batch_size_per_gpu, shuffle=False,
            num_workers=args.num_workers, pin_memory=False,
            collate_fn=self.collate_fn
        )

        # ---- Model Initialization ----
        self.accelerator.print("Initializing Model...\n")
        # Quantization Config
        quantization_config = None
        if args.quantization == '4bit':
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16 # Or torch.float16
            )
        elif args.quantization == '8bit':
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        # Instantiate the VLM-based model
        self.model = ChangeDetectionVLM(
            args,
            args.model_name_or_path,
            self.processor,
            num_classes=NUM_CLASSES,
            image_size=args.image_size, # Use VLM default (e.g., 224) or resize in dataset
            patch_size=self.get_model_patch_size(args.model_name_or_path), # Get patch size from config
            freeze_vlm_base=args.freeze_encoder, # Use freeze_encoder flag
            quantization_config=quantization_config
        )

        # ---- Optimizer ----
        # Only parameters that still require gradients are optimized
        params_to_optimize = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.AdamW(params_to_optimize, lr=args.learning_rate) # Use AdamW

        # ---- Loss Functions ----
        # Captioning loss is handled internally by the VLM when labels are provided
        # Segmentation Loss: Standard CE
        self.segmentation_criterion = nn.CrossEntropyLoss().to(self.accelerator.device)

        # ---- Prepare with Accelerator ----
        self.model, self.optimizer, self.train_loader, self.val_loader = self.accelerator.prepare(
            self.model, self.optimizer, self.train_loader, self.val_loader
        )

        # ---- BLEU Score Smoothing ----
        self.smooth_fn = SmoothingFunction().method1

        # ---- Track Best Metrics & State ----
        self.best_bleu = 0.0
        self.best_miou = 0.0
        self.start_epoch = 0
        self.global_step = 0

        # ---- Resume from Checkpoint (Using Accelerate) ----
        if os.path.exists(self.latest_checkpoint_path):
            self.accelerator.print(f"Resuming from checkpoint: {self.latest_checkpoint_path}\n")
            try:
                # Restores everything registered through accelerator.prepare
                self.accelerator.load_state(self.latest_checkpoint_path)
                # Epoch/step/best-metric bookkeeping lives in a separate file
                custom_state_path = os.path.join(self.latest_checkpoint_path, "custom_state.pth")
                if os.path.exists(custom_state_path):
                    # The file stores the args object next to plain numbers, so it
                    # cannot be read by the restricted weights-only unpickler that
                    # newer PyTorch releases use by default. It is written by this
                    # script itself; do not point this at untrusted files.
                    custom_state = torch.load(custom_state_path, map_location='cpu', weights_only=False)
                    self.start_epoch = custom_state.get('epoch', 0)
                    self.global_step = custom_state.get('global_step', 0)
                    self.best_bleu = custom_state.get('best_bleu', 0.0)
                    self.best_miou = custom_state.get('best_miou', 0.0)
                    self.accelerator.print(f"Resumed custom state: Epoch {self.start_epoch}, Step {self.global_step}, Best BLEU {self.best_bleu:.4f}, Best mIoU {self.best_miou:.4f}\n")
                else:
                    self.accelerator.print("Warning: Custom state file not found in checkpoint. Starting from epoch 0.\n")

            except Exception as e:
                # Best effort: an unreadable checkpoint must not block a fresh run
                self.accelerator.print(f"Error loading checkpoint with Accelerator: {e}\n")
                self.accelerator.print("Starting training from scratch.\n")
                self.start_epoch = 0
                self.global_step = 0
                self.best_bleu = 0.0
                self.best_miou = 0.0
        else:
            self.accelerator.print("No checkpoint found. Starting training from scratch.\n")

    def get_model_patch_size(self, model_name_or_path):
        """ Helper to get patch size from model config. """
        try:
            # Load only the config first
            config = PaliGemmaForConditionalGeneration.from_pretrained(model_name_or_path).config
            # Access patch size (path might vary depending on model)
            # For SigLIP (used in PaliGemma):
            patch_size = config.vision_config.patch_size
            self.accelerator.print(f"Detected patch size: {patch_size}\n")
            return patch_size
        except Exception as e:
            self.accelerator.print(f"Warning: Could not auto-detect patch size for {model_name_or_path}: {e}. Using default 14.\n")
            return 14 # Default patch size

    # Custom collate function for padding input_ids, attention_mask, and labels
    def collate_fn(self, batch):
        # Separate elements that need padding (input_ids, attention_mask, labels)
        input_ids = [item['input_ids'] for item in batch]
        attention_mask = [item['attention_mask'] for item in batch]
        labels = [item['labels'] for item in batch] # These are already masked

        # Pad input_ids, attention_mask, and labels simultaneously using the VLM tokenizer's pad method
        # This ensures they all have the same length after padding.
        # We need to handle labels carefully: pad with ignore_index or pad_token_id?
        # The VLM processor's pad method pads with pad_token_id (usually 0).
        # The VLM's loss function ignores ignore_index (-100).
        # So, we should pad labels with -100 to ensure padded parts are ignored by loss.
        # However, the processor's pad method doesn't directly support padding labels with a different value.
        # A common approach is to pad input_ids/attention_mask normally, and then pad labels separately
        # with the ignore_index.

        # Pad input_ids and attention_mask using the processor's tokenizer
        padded_inputs = self.processor.tokenizer.pad(
            {'input_ids': input_ids, 'attention_mask': attention_mask},
            padding='longest',
            return_tensors='pt',
            # pad_to_multiple_of=8 # Optional
        )

        padded_input_ids = padded_inputs['input_ids']
        padded_attention_mask = padded_inputs['attention_mask']

        # Pad labels manually with the ignore_index
        max_len = padded_input_ids.shape[1] # Get the padded length
        ignore_index_for_loss = -100 # Standard ignore index

        padded_labels = []
        for label_seq in labels:
            # Calculate padding needed
            padding_len = max_len - len(label_seq)
            if padding_len > 0:
                # Pad with ignore_index
                padded_label_seq = torch.cat([label_seq, torch.full((padding_len,), ignore_index_for_loss, dtype=label_seq.dtype)])
            else:
                padded_label_seq = label_seq[:max_len] # Truncate if somehow longer

            padded_labels.append(padded_label_seq)

        padded_labels = torch.stack(padded_labels)


        # Manually collect other items in the batch
        collated_batch = {
            'pixel_values_a': torch.stack([item['pixel_values_a'] for item in batch]),
            'pixel_values_b': torch.stack([item['pixel_values_b'] for item in batch]),
            'input_ids': padded_input_ids,
            'attention_mask': padded_attention_mask,
            'labels': padded_labels, # Use the correctly padded labels
            'name': [item['name'] for item in batch], # Collect names as a list
            'reference_captions': [item['reference_captions'] for item in batch], # Collect reference captions as a list of lists
        }

        # Include seg_mask if loading segmentation - Access from self.args
        if self.args.load_segmentation:
             # Assuming seg_masks are already consistent size or can be stacked
             # If not, they would also need padding/resizing here
             collated_batch['seg_mask'] = torch.stack([item['seg_mask'] for item in batch])


        return collated_batch


    def train_epoch(self, epoch):
        """Train for one epoch and return the mean combined loss per batch.

        Every batch is run through the model in 'train' mode. The caption loss
        returned by the model is added to the weighted segmentation
        cross-entropy and the sum is back-propagated inside Accelerate's
        gradient-accumulation context. Per-step and per-epoch losses are logged
        from the main process.

        Args:
            epoch (int): Zero-based epoch index, used for display and for the
                epoch-level log step.

        Returns:
            float: Sum of the per-batch combined losses divided by the number
            of batches in the loader.
        """
        self.model.train()
        batches_in_epoch = len(self.train_loader)
        running_combined = 0
        running_caption = 0
        running_seg = 0

        bar = tqdm(
            total=batches_in_epoch,
            desc=f"Epoch {epoch+1}/{self.args.epochs} [Train]",
            disable=not self.accelerator.is_main_process,
        )

        for batch in self.train_loader:
            # accumulate() defers the real optimizer update until enough
            # micro-batches have been seen
            with self.accelerator.accumulate(self.model):
                outputs = self.model(
                    pixel_values_a=batch['pixel_values_a'],
                    pixel_values_b=batch['pixel_values_b'],
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels'],  # Enables the caption loss
                    mode='train'
                )

                # ---- Losses ----
                cap_loss = outputs['caption_loss']
                seg_logits = outputs['seg_logits']
                seg_target = batch['seg_mask']

                # Bring the GT mask to the logits' spatial size when they differ
                logits_hw = seg_logits.shape[-2:]
                if logits_hw != seg_target.shape[-2:]:
                    seg_target = F.interpolate(
                        seg_target.unsqueeze(1).float(), size=logits_hw, mode='nearest'
                    ).squeeze(1).long()

                seg_loss = self.segmentation_criterion(seg_logits, seg_target)
                combined_loss = cap_loss + self.args.seg_loss_weight * seg_loss

                # ---- Bookkeeping: average each loss over all processes ----
                step_combined = self.accelerator.gather(combined_loss).mean().item()
                step_caption = self.accelerator.gather(cap_loss).mean().item()
                step_seg = self.accelerator.gather(seg_loss).mean().item()

                running_combined += step_combined
                running_caption += step_caption
                running_seg += step_seg

                if self.accelerator.is_main_process:
                    bar.set_postfix({
                        'CapL': f'{step_caption:.3f}',
                        'SegL': f'{step_seg:.3f}',
                        'CombL': f'{step_combined:.3f}'
                    })
                    self.accelerator.log(
                        {
                            "train/loss_combined": step_combined,
                            "train/loss_caption": step_caption,
                            "train/loss_segmentation": step_seg,
                        },
                        step=self.global_step,
                    )

                # ---- Backward pass & parameter update ----
                self.accelerator.backward(combined_loss)

                # Clip only on steps where gradients are actually synchronized
                should_clip = self.accelerator.sync_gradients and self.args.grad_clip > 0
                if should_clip:
                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)

                self.optimizer.step()
                self.optimizer.zero_grad()

                self.global_step += 1
                bar.update(1)

        bar.close()

        # ---- Epoch summary ----
        epoch_combined = running_combined / batches_in_epoch
        epoch_caption = running_caption / batches_in_epoch
        epoch_seg = running_seg / batches_in_epoch
        self.accelerator.print(f"Epoch {epoch+1} Train Avg Loss -> Combined: {epoch_combined:.4f} (Cap: {epoch_caption:.4f}, Seg: {epoch_seg:.4f})\n")
        if self.accelerator.is_main_process:
            # Epoch averages are logged against the 1-based epoch number
            self.accelerator.log(
                {
                    "train/epoch_loss_combined": epoch_combined,
                    "train/epoch_loss_caption": epoch_caption,
                    "train/epoch_loss_segmentation": epoch_seg,
                },
                step=epoch+1,
            )

        return epoch_combined


    def validate_epoch(self, epoch):
        """Evaluate segmentation mIoU and caption BLEU on the validation loader.

        For each batch the model is run in 'eval' mode to obtain segmentation
        logits, and captions are generated from the VQA prompt and image B.
        Only the newly generated tokens are decoded, so the prompt text does not
        leak into the BLEU hypotheses.

        Args:
            epoch (int): Zero-based epoch index, used for display and logging.

        Returns:
            tuple: ``(bleu_score, mean_iou)``. Both are computed on the main
            process only; every other process returns ``(0.0, 0.0)``.
        """
        self.model.eval()
        references_corpus = [] # List of lists of reference tokens (strings) for BLEU
        hypotheses_corpus = [] # List of hypothesis tokens (strings) for BLEU
        total_iou = np.zeros(NUM_CLASSES)
        num_val_batches = 0

        val_progress_bar = tqdm(total=len(self.val_loader), desc=f"Epoch {epoch+1}/{self.args.epochs} [Validate]", disable=not self.accelerator.is_main_process)

        with torch.no_grad():
            for step, batch in enumerate(self.val_loader):
                # Forward pass for segmentation logits (no labels in eval mode)
                outputs = self.model(
                    pixel_values_a=batch['pixel_values_a'],
                    pixel_values_b=batch['pixel_values_b'],
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    mode='eval'
                )
                seg_logits = outputs['seg_logits']
                seg_mask_gt = batch['seg_mask']

                # --- Perform Generation for Captions ---
                # generate() is defined on ChangeDetectionVLM, so call it on the
                # unwrapped model rather than on the Accelerate wrapper.
                unwrapped_model = self.accelerator.unwrap_model(self.model)
                # NOTE(review): batch['pixel_values_b'] looks like it is already
                # processor output; feeding it through the processor again may
                # re-normalize it, and a single prompt string is paired with a
                # whole batch of images. Confirm against the dataset code.
                prompt_only_inputs = self.processor(
                    text=self.args.vqa_prompt,
                    images=batch['pixel_values_b'], # Use image B's pixel values
                    return_tensors="pt",
                    # No padding needed for generation input
                ).to(self.accelerator.device) # Move to device

                generated_ids = unwrapped_model.generate(
                    pixel_values=prompt_only_inputs['pixel_values'],
                    input_ids=prompt_only_inputs['input_ids'],
                    attention_mask=prompt_only_inputs['attention_mask'],
                    max_length=self.args.max_length, # Use max_length arg for generation
                    num_beams=3, # Example beam search
                    early_stopping=True,
                ) # Shape (B, prompt_len + generated_seq_len)

                # generate() returns the prompt followed by the continuation;
                # keep only the continuation so hypotheses hold the caption alone.
                prompt_len = prompt_only_inputs['input_ids'].shape[1]
                generated_ids = generated_ids[:, prompt_len:]

                # --- Gather tensors across devices ---
                all_seg_logits = self.accelerator.gather_for_metrics(seg_logits)
                all_seg_mask_gt = self.accelerator.gather_for_metrics(seg_mask_gt)
                # Pad generated_ids before gathering to ensure consistent shape
                all_generated_ids = self.accelerator.pad_across_processes(generated_ids, dim=1, pad_index=self.processor.tokenizer.pad_token_id)
                all_generated_ids = self.accelerator.gather_for_metrics(all_generated_ids)

                # Decode generated captions
                decoded_preds = self.processor.batch_decode(all_generated_ids, skip_special_tokens=True)

                # References are plain Python lists and are NOT gathered: only
                # this process's slice is available. The zip below therefore
                # stops after the local batch size.
                # NOTE(review): presumably gathered predictions are ordered by
                # rank so the first ones belong to this process; verify before
                # relying on BLEU in multi-process runs.
                raw_references_batch = batch['reference_captions']

                # --- Accumulate metrics (main process only) ---
                if self.accelerator.is_main_process:
                    # mIoU: resize GT mask if its spatial size differs from the logits
                    if all_seg_logits.shape[-2:] != all_seg_mask_gt.shape[-2:]:
                        all_seg_mask_gt_resized = F.interpolate(all_seg_mask_gt.unsqueeze(1).float(), size=all_seg_logits.shape[-2:], mode='nearest').squeeze(1).long()
                    else:
                        all_seg_mask_gt_resized = all_seg_mask_gt

                    batch_iou = calculate_iou(all_seg_logits.cpu(), all_seg_mask_gt_resized.cpu(), NUM_CLASSES)
                    total_iou += batch_iou
                    num_val_batches += 1 # Count batches on main process

                    # BLEU: each hypothesis is a token list, each reference entry
                    # a list of token lists (one per reference caption).
                    batch_references_for_bleu = []
                    batch_hypotheses_for_bleu = []

                    for pred_text, list_of_ref_texts in zip(decoded_preds, raw_references_batch):
                        # Whitespace tokenization is enough for BLEU here
                        hyp_tokens = pred_text.split()
                        ref_tokens_list = [ref.split() for ref in list_of_ref_texts]

                        if hyp_tokens and any(ref_tokens_list): # Ensure not empty
                            batch_hypotheses_for_bleu.append(hyp_tokens)
                            batch_references_for_bleu.append(ref_tokens_list)

                    hypotheses_corpus.extend(batch_hypotheses_for_bleu)
                    references_corpus.extend(batch_references_for_bleu)

                val_progress_bar.update(1)

        val_progress_bar.close()

        # ---- Calculate Final Metrics (on main process) ----
        mean_iou = 0.0
        bleu_score = 0.0
        if self.accelerator.is_main_process:
            # Final mIoU: mean over classes of the per-batch IoU average
            mean_iou = np.mean(total_iou / num_val_batches) if num_val_batches > 0 else 0.0

            # Final BLEU Score
            if not references_corpus or not hypotheses_corpus:
                print("Warning: No valid references or hypotheses collected for BLEU score.\n")
            elif len(references_corpus) != len(hypotheses_corpus):
                print(f"Warning: Mismatch in reference ({len(references_corpus)}) and hypothesis ({len(hypotheses_corpus)}) counts for BLEU.\n")
            else:
                try:
                    bleu_score = corpus_bleu(references_corpus, hypotheses_corpus, smoothing_function=self.smooth_fn)
                except Exception as e:
                    # Best effort: a BLEU failure must not abort validation
                    print(f"Error calculating BLEU score: {e}\n")

            self.accelerator.print(f"Epoch {epoch+1} Validation -> BLEU-4: {bleu_score:.4f}, mIoU: {mean_iou:.4f}\n")
            # Log validation metrics against the 1-based epoch number
            self.accelerator.log({
                "eval/bleu4": bleu_score,
                "eval/mIoU": mean_iou,
            }, step=epoch+1)

        return bleu_score, mean_iou


    def run_training(self):
        self.accelerator.print("Starting Training...\n")
        self.accelerator.print(f"Total training steps: {len(self.train_loader) // self.accelerator.gradient_accumulation_steps * self.args.epochs}\n")

        for epoch in range(self.start_epoch, self.args.epochs):
            self.accelerator.print(f"--- Starting Epoch {epoch+1}/{self.args.epochs} ---\n")
            train_loss = self.train_epoch(epoch)
            val_bleu, val_miou = self.validate_epoch(epoch)

            # ---- Save Checkpoint (using Accelerator) ----
            if self.accelerator.is_main_process:
                is_best_bleu = val_bleu > self.best_bleu
                if is_best_bleu:
                    self.best_bleu = val_bleu
                    print(f"*** New best BLEU score: {self.best_bleu:.4f} ***\n")

                is_best_miou = val_miou > self.best_miou
                if is_best_miou:
                     self.best_miou = val_miou
                     print(f"*** New best mIoU score: {self.best_miou:.4f} ***\n")

                # Save latest checkpoint using Accelerator
                self.accelerator.save_state(self.latest_checkpoint_path)

                # Save custom state separately within the checkpoint folder
                custom_state = {
                     'epoch': epoch + 1,
                     'global_step': self.global_step,
                     'args': self.args,
                     'best_bleu': self.best_bleu,
                     'best_miou': self.best_miou,
                 }
                custom_state_path = os.path.join(self.latest_checkpoint_path, "custom_state.pth")
                self.accelerator.save(custom_state, custom_state_path)
                print(f"Saved latest checkpoint state to {self.latest_checkpoint_path}\n")


                # Save best BLEU checkpoint separately if desired
                if is_best_bleu:
                    best_bleu_path = os.path.join(self.run_save_dir, 'checkpoint_best_bleu')
                    self.accelerator.save_state(best_bleu_path)
                    # Also save custom state in best checkpoint dir
                    best_custom_state_path = os.path.join(best_bleu_path, "custom_state.pth")
                    self.accelerator.save(custom_state, best_custom_state_path)
                    print(f"Saved best BLEU checkpoint state to {best_bleu_path}\n")

            # Wait for all processes to finish saving/loading before next epoch
            self.accelerator.wait_for_everyone()


        self.accelerator.print("Training finished.\n")
        if self.accelerator.is_main_process:
             self.accelerator.print(f"Best Validation BLEU-4 achieved: {self.best_bleu:.4f}\n")
             self.accelerator.print(f"Best Validation mIoU achieved: {self.best_miou:.4f}\n")
             self.accelerator.end_training() # Clean up trackers

# =============================================================================
# --- VQA Inference Function ---
# =============================================================================
@torch.no_grad()
def answer_question(model, processor, img_path_a, img_path_b, question, device, max_length=128):
    """
    Performs VQA inference on a pair of images.

    Both image files are opened and decoded (so a bad path for either one is
    reported), but only image B is passed to the processor/model together with
    the question. Failures are not raised: every stage returns a descriptive
    error string instead.

    Args:
        model: The fine-tuned ChangeDetectionVLM model (unwrapped). Must expose
            `eval()` and `generate(...)`.
        processor: The VLM processor (callable for encoding, `decode` for output).
        img_path_a (str): Path to the 'before' image (validated only).
        img_path_b (str): Path to the 'after' image (used as the model input).
        question (str): The natural language question.
        device: The torch device to run inference on.
        max_length (int): Value forwarded to `model.generate` as `max_length`.
            Defaults to 128 (the previously hard-coded value).
            NOTE(review): if `model.generate` follows Hugging Face `generate`
            semantics, `max_length` counts the prompt (including image tokens)
            too -- confirm the default leaves room for the answer.

    Returns:
        str: The generated answer string, or an "Error ..." message describing
        which stage (loading, processing, generation) failed.
    """
    model.eval()

    try:
        # Open inside context managers so the underlying file handles are closed
        # promptly; convert() yields an in-memory copy that outlives the file.
        with Image.open(img_path_a) as raw_a:
            # Image A is not fed to the model; decoding it only validates the path/file.
            raw_a.convert("RGB")
        with Image.open(img_path_b) as raw_b:
            imgB_pil = raw_b.convert("RGB")
    except Exception as e:
        return f"Error loading images: {e}"

    # Process image B and the question for VQA.
    # The <image> token is prepended explicitly to the question text here.
    try:
        vlm_inputs = processor(
            text="<image>" + question,
            images=imgB_pil,  # Use image B for VQA input
            return_tensors="pt",
        ).to(device)  # Move to device
    except Exception as e:
        return f"Error processing inputs: {e}"

    # Generate answer using the model's generate method
    try:
        generated_ids = model.generate(
            pixel_values=vlm_inputs['pixel_values'],
            input_ids=vlm_inputs['input_ids'],
            attention_mask=vlm_inputs['attention_mask'],
            max_length=max_length,
        )
        # Decode the first (only) generated sequence
        answer = processor.decode(generated_ids[0], skip_special_tokens=True)
        return answer
    except Exception as e:
        return f"Error during generation: {e}"


# =============================================================================
# --- Main Execution Block ---
# =============================================================================

if __name__ == '__main__':
    # Entry point: builds the CLI, then dispatches on --run_mode to one of
    # preprocess / train / evaluate / vqa.

    # Set the environment variable for memory management early
    import os
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    parser = argparse.ArgumentParser(description='Train or Evaluate VLM for Change Detection VQA')

    # --- Paths and Basic Config ---
    parser.add_argument('--dataset', type=str, default='LEVIR_MCI', help='Dataset name')
    parser.add_argument('--run_name', type=str, default=None, help='Optional run name for logging')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')

    # --- Model Config ---
    parser.add_argument('--model_name_or_path', type=str, default=DEFAULT_VLM, help='Pre-trained VLM name/path')
    parser.add_argument('--freeze_encoder', action='store_true', help='Freeze VLM vision tower and base LLM layers')
    # Set quantization default to 4bit for better memory usage
    parser.add_argument('--quantization', type=str, default='4bit', choices=['no', '4bit', '8bit'], help='Apply quantization (4bit or 8bit)')
    # Gradient checkpointing defaults to True for better memory usage.
    # A store_true flag whose default is already True can never be switched off,
    # so a paired --no_* flag writing to the same dest is provided.
    parser.add_argument('--gradient_checkpointing', action='store_true', default=True, help='Enable gradient checkpointing to save memory')
    parser.add_argument('--no_gradient_checkpointing', dest='gradient_checkpointing', action='store_false', help='Disable gradient checkpointing')

    # --- Dataset and Preprocessing ---
    parser.add_argument('--max_length', type=int, default=64, help='Maximum sequence length for VLM tokenizer')
    parser.add_argument('--image_size', type=int, default=224, help='Image size (must match VLM input size)')
    parser.add_argument('--vqa_prompt', type=str, default="Describe the changes between the two images.", help='Prompt for captioning/VQA')
    # Preprocessing args (less critical now but kept for consistency)
    parser.add_argument('--word_count_threshold', default=5, type=int, help='Min word count (less relevant)')
    # load_segmentation is exposed on the parser so it is available in self.args.
    # Same default-True issue as above, hence the paired --no_* flag.
    parser.add_argument('--load_segmentation', action='store_true', default=True, help='Whether to load segmentation masks')
    parser.add_argument('--no_load_segmentation', dest='load_segmentation', action='store_false', help='Do not load segmentation masks')

    # --- Training Config (using Accelerate) ---
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
    # Reduce default batch size and increase accumulation steps for better memory usage
    parser.add_argument('--batch_size_per_gpu', type=int, default=1, help='Batch size per GPU (adjust based on memory)')
    parser.add_argument('--num_workers', type=int, default=2, help='Number of dataloader workers')
    parser.add_argument('--learning_rate', type=float, default=5e-5, help='Learning rate for AdamW optimizer')
    parser.add_argument('--grad_clip', type=float, default=1.0, help='Max gradient norm for clipping (0 for no clipping)')
    parser.add_argument('--seg_loss_weight', type=float, default=1.0, help='Weight for the segmentation loss term')
    # Increase default gradient accumulation steps to compensate for smaller batch size
    parser.add_argument('--gradient_accumulation_steps', type=int, default=8, help='Steps for gradient accumulation')
    parser.add_argument('--mixed_precision', type=str, default='bf16', choices=['no', 'fp16', 'bf16'], help='Mixed precision type')

    # --- Execution Mode ---
    parser.add_argument('--run_mode', type=str, default='train', choices=['preprocess', 'train', 'evaluate', 'vqa'],
                        help='Pipeline mode: preprocess, train, evaluate (on val set), or vqa (single example)')

    # --- VQA Mode Specific Args ---
    parser.add_argument('--vqa_img_a', type=str, default=None, help='Path to image A for VQA mode')
    parser.add_argument('--vqa_img_b', type=str, default=None, help='Path to image B for VQA mode')
    parser.add_argument('--vqa_question', type=str, default="What has changed?", help='Question for VQA mode')
    parser.add_argument('--vqa_checkpoint', type=str, default=None, help='Path to checkpoint folder for VQA/evaluate mode (uses latest if not specified)')

    # In a notebook, parse default args or provide specific ones
    # Example: args = parser.parse_args(['--run_mode', 'preprocess'])
    # Example: args = parser.parse_args(['--run_mode', 'train', '--batch_size_per_gpu', '2', '--gradient_accumulation_steps', '4', '--quantization', '4bit', '--gradient_checkpointing'])
    # Example: args = parser.parse_args(['--run_mode', 'vqa', '--vqa_img_a', 'path/to/imgA.png', '--vqa_img_b', 'path/to/imgB.png', '--vqa_question', 'How many buildings were added?'])
    if 'ipykernel' in sys.modules:
        # Running inside a notebook kernel: sys.argv belongs to the kernel, so
        # parse an empty list and rely on the defaults declared above.
        args = parser.parse_args([])
    else:
        # Parse command-line arguments if not in a notebook
        args = parser.parse_args()

    # --- Execute Based on Mode ---
    trainer = None  # Initialize trainer variable

    if args.run_mode == 'preprocess':
        print("--- Running Preprocessing ---\n")
        run_preprocessing(args)
        print("--- Preprocessing Finished ---\n")

    elif args.run_mode == 'train':
        print("--- Running Training ---\n")
        # Ensure preprocessing ran (check for split files)
        if not os.path.exists(os.path.join(SAVE_OUTPUT_DIR, 'train.txt')):
            print("Split files not found. Running preprocessing first.\n")
            run_preprocessing(args)
            print("-" * 20 + "\n")  # Separator

        trainer = Trainer(args)
        trainer.run_training()
        print("--- Training Finished ---\n")

    elif args.run_mode == 'evaluate':
        print("--- Running Evaluation ---\n")
        # Requires a trained model checkpoint
        trainer = Trainer(args)  # Initialize trainer to load model and data
        checkpoint_to_load = args.vqa_checkpoint or trainer.latest_checkpoint_path
        if not os.path.exists(checkpoint_to_load):
            print(f"Error: Checkpoint not found at {checkpoint_to_load}. Cannot evaluate.\n")
        else:
            print(f"Loading state from: {checkpoint_to_load}\n")
            trainer.accelerator.load_state(checkpoint_to_load)
            print("Evaluating on validation set...\n")
            # Use epoch=-1 or similar to indicate it's a final eval run
            val_bleu, val_miou = trainer.validate_epoch(epoch=-1)
            print(f"--- Evaluation Finished ---\n")
            print(f"Validation BLEU-4: {val_bleu:.4f}\n")
            print(f"Validation mIoU: {val_miou:.4f}\n")

    elif args.run_mode == 'vqa':
        print("--- Running VQA Inference ---\n")
        if not args.vqa_img_a or not args.vqa_img_b:
            print("Error: Please provide paths to both image A and image B using --vqa_img_a and --vqa_img_b\n")
        else:
            # Initialize minimal components needed for inference.
            # A reduced parser supplies inference-oriented defaults for the
            # options the model constructor and Accelerator need.
            minimal_parser = argparse.ArgumentParser()
            minimal_parser.add_argument('--model_name_or_path', type=str, default=DEFAULT_VLM)
            minimal_parser.add_argument('--quantization', type=str, default='no', choices=['no', '4bit', '8bit'])
            minimal_parser.add_argument('--freeze_encoder', action='store_true', default=False)  # Inference doesn't freeze
            minimal_parser.add_argument('--gradient_checkpointing', action='store_true', default=False)  # Inference doesn't need checkpointing
            minimal_parser.add_argument('--max_length', type=int, default=128)  # Add max_length for generation
            # mixed_precision is read below when building the Accelerator; without
            # this option the attribute lookup on minimal_args raised AttributeError.
            minimal_parser.add_argument('--mixed_precision', type=str, default='bf16', choices=['no', 'fp16', 'bf16'])
            # Parse only known args to avoid errors with other args present in sys.argv in notebooks
            minimal_args, _ = minimal_parser.parse_known_args(sys.argv[1:] if 'ipykernel' not in sys.modules else [])

            accelerator = Accelerator(mixed_precision=minimal_args.mixed_precision)  # Use mixed precision if available
            processor = PaliGemmaProcessor.from_pretrained(minimal_args.model_name_or_path)

            # Quantization Config
            quantization_config = None
            if minimal_args.quantization == '4bit':
                quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
            elif minimal_args.quantization == '8bit':
                quantization_config = BitsAndBytesConfig(load_in_8bit=True)

            # Load model (without full Trainer setup)
            model = ChangeDetectionVLM(
                minimal_args,  # Pass minimal args
                minimal_args.model_name_or_path,
                processor,
                quantization_config=quantization_config,
                freeze_vlm_base=False  # No need to freeze for inference
            )
            model = accelerator.prepare(model)  # Prepare model for device

            # Load checkpoint
            checkpoint_to_load = args.vqa_checkpoint or os.path.join(SAVE_OUTPUT_DIR, 'training_output', 'checkpoint_latest')
            if not os.path.exists(checkpoint_to_load):
                print(f"Warning: Checkpoint {checkpoint_to_load} not found. Using pre-trained weights.\n")
            else:
                print(f"Loading model state from: {checkpoint_to_load}\n")
                # The state dict is loaded manually rather than via accelerator.load_state
                try:
                    # Assumes the checkpoint folder holds the weights as 'pytorch_model.bin'.
                    # NOTE(review): some accelerate versions may write 'model.safetensors'
                    # instead -- confirm against the folder produced by save_state.
                    state_dict_path = os.path.join(checkpoint_to_load, "pytorch_model.bin")  # Common path
                    if not os.path.exists(state_dict_path):
                        raise FileNotFoundError("Could not find model state dict file.")

                    # Load state dict, handling potential DDP prefix
                    state_dict = torch.load(state_dict_path, map_location='cpu')
                    unwrapped_model = accelerator.unwrap_model(model)
                    # Remove 'module.' prefix if present
                    new_state_dict = {}
                    for k, v in state_dict.items():
                        if k.startswith('module.'):
                            new_state_dict[k[7:]] = v
                        else:
                            new_state_dict[k] = v
                    unwrapped_model.load_state_dict(new_state_dict)  # Load into unwrapped model
                    print("Model state loaded successfully.\n")
                except Exception as e:
                    # Best-effort: report the failure and continue with whatever weights are loaded
                    print(f"Error loading model state from checkpoint {checkpoint_to_load}: {e}. Using pre-trained weights.\n")

            # Perform VQA
            print(f"\nImage A: {args.vqa_img_a}\n")
            print(f"\nImage B: {args.vqa_img_b}\n")
            print(f"\nQuestion: {args.vqa_question}\n")

            answer = answer_question(
                accelerator.unwrap_model(model),  # Pass unwrapped model
                processor,
                args.vqa_img_a,
                args.vqa_img_b,
                args.vqa_question,
                accelerator.device
            )

            print(f"\nAnswer: {answer}\n")
            print("--- VQA Finished ---\n")

    else:
        print(f"Unknown run_mode: {args.run_mode}\n")
