# In [None]:
# =============================================================================
# Separate Models for Temporal Image Segmentation and Captioning (Improved)
# This script implements improved models and training procedures
# for the LEVIR-MCI dataset, leveraging pre-trained backbones, U-Net,
# and Transformer architectures.
#
# Improvements:
# - Replaced simple CNN encoder with a pre-trained ResNet backbone.
# - Implemented a U-Net decoder for segmentation.
# - Implemented a Transformer decoder for captioning.
# - Added Learning Rate Scheduling (ReduceLROnPlateau).
# - Implemented Beam Search decoding for captioning validation.
# - Adjusted for Kaggle file paths and uses torch.nn.DataParallel if available.
# - Automatically runs preprocessing if output files are not found.
# - Includes weighted loss for segmentation.
# =============================================================================

import sys
import os
import json
import numpy as np
from collections import defaultdict, Counter
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.nn.functional as F
import math
from imageio.v2 import imread # Make sure imageio is installed
from random import randint
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from tqdm import tqdm
import time # Import time for timing epochs
import torchvision.models as models # For pre-trained models
from torch.optim.lr_scheduler import ReduceLROnPlateau # For LR scheduling
import heapq # For Beam Search

# =============================================================================
# --- Hardcoded Paths for Kaggle ---
# IMPORTANT: Update DATASET_ROOT if necessary.
# =============================================================================
# Root folder of the LEVIR-MCI dataset. *** VERIFY THIS PATH *** before running.
DATASET_ROOT = '/kaggle/input/levir-mci-dataset/LEVIR-MCI-dataset'
# Standard Kaggle writable output directory.
SAVE_OUTPUT_DIR = '/kaggle/working/'

# =============================================================================
# --- Utility Functions ---
# =============================================================================

# Define the mapping from RGB colors to class IDs
# Background: 0 (black), Road: 1 (grey), Building: 2 (white)
COLOR_TO_ID_MAPPING = {
    (0, 0, 0): 0,          # Background (Black)
    (128, 128, 128): 1,    # Road (Grey)
    (255, 255, 255): 2,    # Building (White)
}
NUM_CLASSES = 3 # Background, Road, Building

def rgb_to_class_id_mask(rgb_mask_np):
    """Converts a color-coded mask to a class ID mask (H, W).

    Accepted inputs:
        * (H, W, 3) RGB mask, matched against COLOR_TO_ID_MAPPING.
        * (H, W) or (H, W, 1) greyscale mask, matched against the grey
          entries (R == G == B) of COLOR_TO_ID_MAPPING.

    Pixels whose color is not in the mapping are assigned class 0.

    Args:
        rgb_mask_np (np.ndarray): The mask image.

    Returns:
        np.ndarray: int64 array of shape (H, W) holding class IDs.

    Raises:
        ValueError: If the array is not 2-D/3-D, or is 3-D with a channel
            count other than 1 or 3.
    """
    ndim = rgb_mask_np.ndim
    if ndim not in (2, 3):
        raise ValueError(f"Input mask must be 2-D (greyscale) or 3-D (H, W, C), but got shape {rgb_mask_np.shape}")

    # The dimensionality is checked BEFORE unpacking the shape: unpacking a
    # 2-D shape into (h, w, c) raises, which previously made the greyscale
    # branch unreachable for (H, W) inputs.
    if ndim == 2 or rgb_mask_np.shape[-1] == 1:
        grey_mask = rgb_mask_np if ndim == 2 else rgb_mask_np[..., 0]
        class_id_mask = np.full(grey_mask.shape, 0, dtype=np.int64)
        for color, class_id in COLOR_TO_ID_MAPPING.items():
            # Only pure-grey colors can be represented by a single channel.
            if color[0] == color[1] == color[2]:
                class_id_mask[grey_mask == color[0]] = class_id
        return class_id_mask

    h, w, c = rgb_mask_np.shape
    if c != 3:
        raise ValueError(f"Input mask must have 3 channels (RGB), but got {c} with shape {rgb_mask_np.shape}")

    class_id_mask = np.full((h, w), 0, dtype=np.int64)
    for color, class_id in COLOR_TO_ID_MAPPING.items():
        # A pixel belongs to the class only if all three channels match.
        matches = np.all(rgb_mask_np == np.array(color, dtype=rgb_mask_np.dtype), axis=-1)
        class_id_mask[matches] = class_id
    return class_id_mask

class AverageMeter(object):
    """Tracks the latest value plus a running weighted sum, count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clears every tracked statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Records `val` as the latest value, weighted by `n` occurrences."""
        self.val = val
        self.sum += val * n
        self.count += n
        # The mean is only defined once a positive count has accumulated.
        if self.count > 0:
            self.avg = self.sum / self.count
        else:
            self.avg = 0

def calculate_iou(predicted_mask, true_mask, model_or_num_classes):
    """
    Computes per-class Intersection over Union (IoU) over a batch of masks.

    The third argument is either the number of classes, or a model (possibly
    wrapped in DataParallel) from which `num_classes` is read, first from the
    model itself and then from its `decoder` attribute.

    Returns a 1-D tensor of length num_classes on the device of
    `predicted_mask`. Classes absent from both masks are reported as NaN.
    """
    if not isinstance(model_or_num_classes, nn.Module):
        # Assume the caller passed the integer class count directly.
        num_classes = model_or_num_classes
    else:
        net = model_or_num_classes
        if isinstance(net, nn.DataParallel):
            net = net.module  # Unwrap to reach the underlying model
        if hasattr(net, 'num_classes'):
            num_classes = net.num_classes
        elif hasattr(getattr(net, 'decoder', None), 'num_classes'):
            # The class count lives on the nested decoder.
            num_classes = net.decoder.num_classes
        else:
            print("Warning: Could not automatically determine num_classes from model. Falling back to NUM_CLASSES global.")
            num_classes = NUM_CLASSES

    scores = torch.zeros(num_classes, device=predicted_mask.device)
    for cls in range(num_classes):
        pred_hits = predicted_mask == cls
        true_hits = true_mask == cls

        overlap = (pred_hits & true_hits).sum().float()
        combined = (pred_hits | true_hits).sum().float()

        # An empty union means the class is absent everywhere: mark as NaN so
        # the caller can skip it (e.g. with nanmean) instead of dividing by zero.
        scores[cls] = overlap / combined if combined != 0 else float('nan')

    return scores

# =============================================================================
# --- Dataset Class ---
# (Mostly reused, ensure paths and normalization are correct)
# =============================================================================

class LEVIRCCDataset(Dataset):
    """Image pairs (A/B) with optional segmentation masks and token captions.

    Every item is assembled from an image A, an image B, a token JSON file
    and, when `load_segmentation` is true, a segmentation mask. Entries with
    any of these files missing on disk are skipped at construction time.

    Each item is a dict with keys:
        'imgA', 'imgB': float tensors (3, H, W), normalized with the
            ImageNet mean/std.
        'name': base file name of the sample.
        'all_caption_tokens': list of token-index lists (a single randomly
            selected caption for 'train', all captions otherwise).
        'seg_label': long tensor (H, W) of class IDs (only when
            `load_segmentation` is true).
        'caption_tokens': long tensor (max_length,) (only for 'train').
    """

    def __init__(self, data_folder, processed_data_dir, split,
                 load_segmentation=True,
                 max_length=41,
                 vocab_file='vocab.json',
                 allow_unk=True,
                 max_iters=None):
        """
        Args:
            data_folder (str): Path to the root LEVIR-MCI dataset folder.
            processed_data_dir (str): Path to the folder with preprocessed data.
            split (str): 'train', 'val', or 'test'.
            load_segmentation (bool): Load segmentation maps.
            max_length (int): Max caption sequence length.
            vocab_file (str): Vocabulary JSON file name.
            allow_unk (bool): Allow unknown tokens.
            max_iters (int, optional): Repeat dataset for this many items per epoch.

        Raises:
            FileNotFoundError: If the vocabulary or split file is missing.
            ValueError: If the split file is empty or no entry has all of
                its files on disk.
        """
        self.data_folder = data_folder
        self.processed_data_dir = processed_data_dir
        self.split = split
        self.load_segmentation = load_segmentation
        self.max_length = max_length
        self.allow_unk = allow_unk

        # Image normalization parameters (ImageNet defaults)
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

        assert self.split in {'train', 'val', 'test'}

        # ---- Load Vocabulary ----
        vocab_path = os.path.join(self.processed_data_dir, vocab_file)
        try:
            with open(vocab_path, 'r') as f:
                self.word_vocab = json.load(f)
            self.idx_to_word = {v: k for k, v in self.word_vocab.items()}
            self.vocab_size = len(self.word_vocab)
            self.pad_idx = self.word_vocab.get('<NULL>', 0)
            self.start_idx = self.word_vocab.get('<START>', 2)
            self.end_idx = self.word_vocab.get('<END>', 3)
            self.unk_idx = self.word_vocab.get('<UNK>', 1)
        except FileNotFoundError:
            raise FileNotFoundError(f"Vocabulary file not found at {vocab_path}. Run preprocessing.")

        # --- Load Image Filenames/Base Names ---
        split_file_path = os.path.join(self.processed_data_dir, f'{split}.txt')
        try:
            with open(split_file_path, 'r') as f:
                self.img_ids = [line.strip() for line in f if line.strip()]
        except FileNotFoundError:
            raise FileNotFoundError(f"Split file not found at {split_file_path}. Run preprocessing.")

        if not self.img_ids:
            raise ValueError(f"No image IDs found in split file: {split_file_path}")

        # ---- Prepare file paths ----
        self.files = []
        image_base_path = os.path.join(self.data_folder, 'images', self.split)
        token_base_path = os.path.join(self.processed_data_dir, 'tokens')
        label_folder_name = 'label' # Assuming 'label' contains the segmentation masks

        missing_files_count = 0
        for img_base_name in self.img_ids:
            img_file_name = f"{img_base_name}.png"
            token_file_name = f"{img_base_name}.json"

            file_paths = {
                "name": img_base_name,
                "imgA": os.path.join(image_base_path, 'A', img_file_name),
                "imgB": os.path.join(image_base_path, 'B', img_file_name),
                "token": os.path.join(token_base_path, token_file_name)
            }
            paths_to_check = [file_paths["imgA"], file_paths["imgB"], file_paths["token"]]
            if self.load_segmentation:
                file_paths["seg_label"] = os.path.join(image_base_path, label_folder_name, img_file_name)
                paths_to_check.append(file_paths["seg_label"])

            # Keep only entries for which every required file is present.
            if not all(os.path.exists(p) for p in paths_to_check):
                missing_files_count += 1
                continue

            self.files.append(file_paths)

        if missing_files_count > 0:
            print(f"Warning: Skipped {missing_files_count} entries due to missing files in split '{self.split}'.")

        if not self.files:
            raise ValueError(f"No valid file sets found for split '{self.split}'. Check paths and preprocessing.")

        # --- Handle max_iters ---
        # The file list is repeated so that it holds at least max_iters
        # entries; __len__ then reports exactly max_iters.
        self.max_iters = max_iters
        if max_iters is not None and max_iters > 0:
            n_repeat = int(np.ceil(max_iters / len(self.files)))
            self.files = self.files * n_repeat

        print(f"Initialized LEVIRCCDataset for split '{self.split}' with {len(self.files)} items (after potential repeat).")

    def __len__(self):
        """Returns max_iters when it is set and positive, else the file count."""
        if self.max_iters is not None and self.max_iters > 0:
            return self.max_iters
        return len(self.files)

    def __getitem__(self, index):
        """Loads and returns the sample dict at `index`.

        Raises:
            IndexError: If the dataset is empty or `index` is outside
                [-len(self), len(self)). Raising here (instead of wrapping
                the index) is what lets plain iteration over the dataset
                terminate.
        """
        if not self.files:
            raise IndexError("Dataset is empty.")
        n_items = len(self)
        if not -n_items <= index < n_items:
            raise IndexError(f"Index {index} is out of range for dataset of length {n_items}.")
        if index < 0:
            index += n_items  # Standard negative indexing relative to len(self)
        datafiles = self.files[index]

        imgA, imgB = self._load_image_pair(datafiles)

        seg_mask_class_ids = None
        if self.load_segmentation:
            seg_mask_class_ids = self._load_segmentation_mask(datafiles, imgA.shape[1:])

        token_sequence, all_caption_tokens = self._load_captions(datafiles)

        # --- Prepare output dictionary ---
        item = {
            'imgA': imgA,
            'imgB': imgB,
            'name': datafiles['name'],
            'all_caption_tokens': all_caption_tokens # Always include for eval
        }

        if self.load_segmentation:
            item['seg_label'] = seg_mask_class_ids

        if self.split == 'train' and token_sequence is not None:
            item['caption_tokens'] = token_sequence # Include single sequence for training

        return item

    def _normalize_image(self, img_np):
        """Scales a uint8 HWC image to [0, 1], normalizes it, returns a CHW tensor."""
        img = (img_np.astype(np.float32) / 255.0 - self.mean) / self.std
        return torch.from_numpy(img.transpose(2, 0, 1))

    def _load_image_pair(self, datafiles):
        """Reads images A and B, validates their (H, W, 3) layout and normalizes them."""
        try:
            imgA_np = np.array(imread(datafiles["imgA"]), dtype=np.uint8)
            imgB_np = np.array(imread(datafiles["imgB"]), dtype=np.uint8)
            if imgA_np.ndim != 3 or imgA_np.shape[-1] != 3 or imgB_np.ndim != 3 or imgB_np.shape[-1] != 3:
                raise ValueError(f"Image dimensions incorrect for {datafiles['name']}. Expected (H, W, 3), got {imgA_np.shape} and {imgB_np.shape}")
        except Exception as e:
            print(f"Error loading images for {datafiles['name']}: {e}")
            raise
        return self._normalize_image(imgA_np), self._normalize_image(imgB_np)

    def _load_segmentation_mask(self, datafiles, fallback_shape):
        """Reads the mask as class IDs; on failure logs and returns an all-zero mask."""
        try:
            seg_label_np = np.array(imread(datafiles["seg_label"]), dtype=np.uint8)
            return torch.from_numpy(rgb_to_class_id_mask(seg_label_np)).long()
        except FileNotFoundError:
            print(f"Error loading segmentation label for {datafiles['name']}. Path: {datafiles['seg_label']}")
        except Exception as e:
            print(f"Error processing segmentation label for {datafiles['name']}: {e}")
        # Best-effort: keep the default empty mask instead of failing the item.
        return torch.zeros(fallback_shape, dtype=torch.long)

    def _fit_caption_length(self, cap, name):
        """Returns `cap` truncated/padded to max_length (unchanged if already exact)."""
        if len(cap) == self.max_length:
            return cap
        # Should ideally be handled in preprocessing.
        print(f"Warning: Caption length mismatch for {name}. Expected {self.max_length}, got {len(cap)}. Adjusting.")
        fitted = cap[:self.max_length]
        fitted.extend([self.pad_idx] * (self.max_length - len(fitted)))
        return fitted

    def _load_captions(self, datafiles):
        """Reads the token file and returns (token_sequence, all_caption_tokens).

        `token_sequence` is a long tensor for the 'train' split (one randomly
        selected caption) or for an empty caption file (all padding); it is
        None for non-train splits with captions.
        """
        try:
            with open(datafiles["token"], 'r') as f:
                caption_list = json.load(f) # List of lists of token indices

            if not caption_list:
                print(f"Warning: Empty caption list found for {datafiles['name']}. Using default padding.")
                token_sequence = torch.full((self.max_length,), self.pad_idx, dtype=torch.long)
                dummy_captions = [[self.start_idx] + [self.end_idx] * (self.max_length - 1)]
                return token_sequence, dummy_captions

            caption_list = [self._fit_caption_length(cap, datafiles['name']) for cap in caption_list]

            if self.split != 'train':
                # For validation/test, expose all captions; no single sequence needed.
                return None, caption_list

            # Select one random caption for training
            selected_caption = caption_list[randint(0, len(caption_list) - 1)]
            return torch.tensor(selected_caption, dtype=torch.long), [selected_caption]

        except FileNotFoundError:
            print(f"Error loading token file for {datafiles['name']}. Path: {datafiles['token']}")
            raise
        except Exception as e:
            print(f"Error processing token file {datafiles['token']}: {e}")
            raise

# =============================================================================
# --- Preprocessing Function ---
# (Reused - ensure it runs correctly before training)
# =============================================================================
def run_preprocessing_direct(data_folder, processed_data_dir, caption_file, max_length, word_count_threshold):
    """
    Builds all preprocessed artifacts under `processed_data_dir`:

    1. Tokenized captions, wrapped in <START>/<END> and padded with <NULL>
       (or truncated) to exactly `max_length` tokens.
    2. `vocab.json`, containing the special tokens plus every word seen at
       least `word_count_threshold` times.
    3. One `tokens/<image>.json` file per image with index-encoded captions.
    4. `train.txt` / `val.txt` / `test.txt` listing the usable image names.

    Raises FileNotFoundError if the caption file is missing and ValueError
    if it is not valid JSON.
    """
    print("Starting preprocessing...")
    tokens_dir = os.path.join(processed_data_dir, 'tokens')
    os.makedirs(processed_data_dir, exist_ok=True)
    os.makedirs(tokens_dir, exist_ok=True)

    # --- Read the raw caption annotations ---
    captions_json_path = os.path.join(data_folder, caption_file)
    try:
        with open(captions_json_path, 'r') as fh:
            caption_data = json.load(fh)
    except FileNotFoundError:
        raise FileNotFoundError(f"Caption file not found at {captions_json_path}")
    except json.JSONDecodeError:
        raise ValueError(f"Could not decode JSON from caption file at {captions_json_path}")

    frequency = defaultdict(int)
    captions_by_image = defaultdict(list)  # {image base name: [padded token list, ...]}

    print("Tokenizing captions and counting words...")
    longest_raw = 0
    for entry in tqdm(caption_data['images']):
        stem = os.path.splitext(entry['filename'])[0]
        for sentence in entry['sentences']:
            raw_tokens = sentence['tokens']
            if len(raw_tokens) > longest_raw:
                longest_raw = len(raw_tokens)

            framed = ['<START>'] + raw_tokens + ['<END>']
            if len(framed) > max_length:
                # Truncate but keep <END> as the final token.
                framed = framed[:max_length-1] + ['<END>']
            # Right-pad with <NULL>; a non-positive repeat count adds nothing.
            framed.extend(['<NULL>'] * (max_length - len(framed)))

            captions_by_image[stem].append(framed)

            # Only the original words count towards the vocabulary.
            for word in raw_tokens:
                frequency[word] += 1

    print(f"Max actual caption length (before special tokens/padding): {longest_raw}")
    print(f"Target max_length (with special tokens/padding): {max_length}")
    if longest_raw + 2 > max_length:
        print(f"Warning: MAX_LENGTH ({max_length}) might be too short for some captions (max actual length + 2 = {longest_raw + 2}). Consider increasing MAX_LENGTH.")

    # --- Vocabulary: special tokens first, then frequent words ---
    print("Building vocabulary...")
    vocab = {'<NULL>': 0, '<UNK>': 1, '<START>': 2, '<END>': 3}
    frequent_words = [w for w, seen in frequency.items() if seen >= word_count_threshold]
    for index, word in enumerate(frequent_words, start=4):
        vocab[word] = index
    print(f"Vocabulary size: {len(vocab)} (using threshold {word_count_threshold})")

    vocab_path = os.path.join(processed_data_dir, 'vocab.json')
    with open(vocab_path, 'w') as fh:
        json.dump(vocab, fh)
    print(f"Vocabulary saved to {vocab_path}")

    # --- Encode captions as vocabulary indices, one JSON file per image ---
    print("Converting tokens to indices and saving...")
    for stem, captions in tqdm(captions_by_image.items()):
        encoded = []
        for caption in captions:
            # Words below the threshold fall back to <UNK>.
            ids = [vocab.get(tok, vocab['<UNK>']) for tok in caption]
            if len(ids) == max_length:
                encoded.append(ids)
            else:
                print(f"Error: Indexed caption length mismatch for {stem}. Expected {max_length}, got {len(ids)}. Skipping this caption.")

        if not encoded:
            print(f"Warning: No valid captions saved for {stem}.")
            continue
        with open(os.path.join(tokens_dir, f'{stem}.json'), 'w') as fh:
            json.dump(encoded, fh)

    print("Tokenization and saving complete.")

    # --- Split files: names present in A that also have B, tokens and label ---
    print("Creating train/val/test split files...")
    for split in ('train', 'val', 'test'):
        split_root = os.path.join(data_folder, 'images', split)
        dir_a = os.path.join(split_root, 'A')
        dir_b = os.path.join(split_root, 'B')
        dir_label = os.path.join(split_root, 'label')

        if not os.path.exists(dir_a):
            print(f"Warning: Image folder 'A' for split '{split}' not found at {dir_a}. Skipping split file creation.")
            continue

        candidates = sorted(os.path.splitext(fname)[0] for fname in os.listdir(dir_a) if fname.endswith('.png'))

        # A label is only required when the label folder exists at all.
        label_required = os.path.exists(dir_label)
        usable = []
        for name in candidates:
            has_b = os.path.exists(os.path.join(dir_b, f"{name}.png"))
            has_label = os.path.exists(os.path.join(dir_label, f"{name}.png"))
            has_tokens = os.path.exists(os.path.join(tokens_dir, f"{name}.json"))
            if has_b and has_tokens and (has_label or not label_required):
                usable.append(name)

        split_file_path = os.path.join(processed_data_dir, f'{split}.txt')
        with open(split_file_path, 'w') as fh:
            fh.writelines(f"{name}\n" for name in usable)

        print(f"Split file for '{split}' created with {len(usable)} images at {split_file_path}")

    print("Preprocessing finished.")


# =============================================================================
# --- Class Weight Calculation Function ---
# (Reused - calculates weights for segmentation loss)
# =============================================================================
def calculate_class_weights(data_folder, processed_data_dir, split='train', num_classes=NUM_CLASSES):
    """Derives inverse-frequency class weights from the masks of one split.

    Every mask listed in `<split>.txt` is read and its pixels are counted per
    class; masks that are missing or fail to load are skipped. The weight of
    class c is total_pixels / (num_classes * pixels_in_class_c).

    Returns a float tensor of length `num_classes` (all ones if no pixel
    could be counted). Raises FileNotFoundError if the split file is missing
    and ValueError if it lists no image IDs.
    """
    print(f"Calculating class weights for '{split}' split...")
    split_file_path = os.path.join(processed_data_dir, f'{split}.txt')
    if not os.path.exists(split_file_path):
        raise FileNotFoundError(f"Split file not found at {split_file_path}. Cannot calculate weights.")

    with open(split_file_path, 'r') as fh:
        img_ids = [stripped for stripped in (raw.strip() for raw in fh) if stripped]
    if not img_ids:
        raise ValueError(f"No image IDs found in split file: {split_file_path}.")

    label_dir = os.path.join(data_folder, 'images', split, 'label')
    class_pixel_counts = np.zeros(num_classes, dtype=np.int64)
    total_pixels = 0

    for img_base_name in tqdm(img_ids, desc="Counting pixels for class weights"):
        seg_path = os.path.join(label_dir, f"{img_base_name}.png")
        if not os.path.exists(seg_path):
            continue
        try:
            mask_ids = rgb_to_class_id_mask(np.array(imread(seg_path), dtype=np.uint8))
            present_ids, frequencies = np.unique(mask_ids, return_counts=True)
            # Ignore any ID outside [0, num_classes); np.unique yields each ID
            # once, so the indexed in-place add cannot collide.
            valid = (present_ids >= 0) & (present_ids < num_classes)
            class_pixel_counts[present_ids[valid]] += frequencies[valid]
            total_pixels += mask_ids.size
        except Exception as e:
            print(f"Error processing mask {img_base_name} for pixel count: {e}. Skipping.")
            continue

    print(f"Class pixel counts: {class_pixel_counts}")
    print(f"Total pixels counted: {total_pixels}")

    if total_pixels == 0:
        print("Warning: No pixels counted. Returning equal weights.")
        return torch.ones(num_classes, dtype=torch.float)

    # The epsilon keeps the division finite for classes that never occur.
    class_weights = total_pixels / (num_classes * (class_pixel_counts + 1e-6))

    class_weights_tensor = torch.from_numpy(class_weights).float()
    print(f"Calculated class weights: {class_weights_tensor.tolist()}")
    return class_weights_tensor


# =============================================================================
# --- Improved Model Architectures ---
# =============================================================================

# --- ResNet Encoder (Corrected Weight Init) ---
class ResNetEncoder(nn.Module):
    """
    Encoder built from a torchvision ResNet.

    The first convolution is replaced by one accepting 6 input channels
    (imgA and imgB concatenated along the channel dimension). When
    `pretrained` is true its weights are initialized by duplicating the
    original 3-channel kernels and halving them.

    Args:
        arch (str): 'resnet18', 'resnet34', 'resnet50' or 'resnet101'.
        pretrained (bool): Load the DEFAULT torchvision weights.
        freeze_layers (int): 0 leaves everything trainable. Any value > 0
            freezes the stem (conv1/bn1); values >= 1, >= 2 and >= 3
            additionally freeze layer1, layer2 and layer3 respectively.
            Larger values (including the default) behave like 3; layer4 is
            never frozen.

    Attributes:
        out_channels (list[int]): Channel counts of the stem output and of
            layer1..layer4, shallowest first. Note that forward() returns
            the feature maps in the opposite (deepest first) order.

    Raises:
        ValueError: If `arch` is not one of the supported names.
    """
    def __init__(self, arch='resnet50', pretrained=True, freeze_layers=25):
        super().__init__()
        print(f"Initializing ResNetEncoder with arch={arch}, pretrained={pretrained}")

        # arch -> (name of the torchvision weights enum, channels after layer4)
        arch_specs = {
            'resnet18': ('ResNet18_Weights', 512),
            'resnet34': ('ResNet34_Weights', 512),
            'resnet50': ('ResNet50_Weights', 2048),
            'resnet101': ('ResNet101_Weights', 2048),
        }
        if arch not in arch_specs:
            raise ValueError(f"Unsupported ResNet architecture: {arch}")
        weights_enum_name, self.out_features_base = arch_specs[arch]
        weights = getattr(models, weights_enum_name).DEFAULT if pretrained else None
        resnet = getattr(models, arch)(weights=weights)

        # Adapt the first convolutional layer for 6 input channels.
        # `bias` must be a bool: passing the bias Parameter itself (as was
        # done before) is only safe while that Parameter happens to be None.
        original_conv1 = resnet.conv1
        self.conv1 = nn.Conv2d(6, original_conv1.out_channels,
                               kernel_size=original_conv1.kernel_size,
                               stride=original_conv1.stride,
                               padding=original_conv1.padding,
                               bias=original_conv1.bias is not None)

        # Initialize weights for the new conv1
        if pretrained:
            print("Adapting pre-trained weights for the first layer (6 channels)...")
            original_weights = original_conv1.weight.data # Shape: [out_channels, 3, k, k]

            # Duplicate the 3-channel kernels along the input-channel dim to
            # get [out_channels, 6, k, k]; halve them to keep the initial
            # activation scale similar to the original pre-trained model.
            self.conv1.weight.data = torch.cat((original_weights, original_weights), dim=1) / 2.0

            # Copy bias if it exists
            if original_conv1.bias is not None:
                self.conv1.bias.data = original_conv1.bias.data

        # Keep other layers from ResNet
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # --- Feature dimensions at different stages ---
        # Bottleneck blocks (ResNet50+) end in conv3, basic blocks
        # (ResNet18/34) end in conv2.
        last_conv_name = 'conv3' if hasattr(resnet.layer1[-1], 'conv3') else 'conv2'
        layer1_out_ch, layer2_out_ch, layer3_out_ch = (
            getattr(stage[-1], last_conv_name).out_channels
            for stage in (resnet.layer1, resnet.layer2, resnet.layer3)
        )

        self.out_channels = [
            resnet.conv1.out_channels, # After conv1/bn1/relu (x0, taken before maxpool)
            layer1_out_ch,             # After layer1
            layer2_out_ch,             # After layer2
            layer3_out_ch,             # After layer3
            self.out_features_base     # After layer4
        ]
        print(f"Encoder output channels per stage: {self.out_channels}")

        # Freeze early layers if requested
        if freeze_layers > 0:
            print(f"Freezing first {freeze_layers} layers of the encoder.")
            layers_to_freeze = [self.conv1, self.bn1, self.relu, self.maxpool]
            if freeze_layers >= 1: layers_to_freeze.append(self.layer1)
            if freeze_layers >= 2: layers_to_freeze.append(self.layer2)
            if freeze_layers >= 3: layers_to_freeze.append(self.layer3)
            # Layer 4 is typically kept trainable for fine-tuning
            for layer in layers_to_freeze:
                for param in layer.parameters():
                    param.requires_grad = False

    def forward(self, imgA, imgB):
        """Encodes the image pair and returns [x4, x3, x2, x1, x0] (deepest first)."""
        # Concatenate images along the channel dimension
        x = torch.cat((imgA, imgB), dim=1) # Shape: (N, 6, H, W)

        # --- Forward through ResNet layers ---
        x = self.conv1(x)
        x = self.bn1(x)
        x0 = self.relu(x) # Output after initial conv/relu (used for skip connection)
        x = self.maxpool(x0)

        x1 = self.layer1(x)  # Output after layer1
        x2 = self.layer2(x1) # Output after layer2
        x3 = self.layer3(x2) # Output after layer3
        x4 = self.layer4(x3) # Output after layer4 (final feature map)

        # Return features from multiple stages for U-Net decoder
        return [x4, x3, x2, x1, x0] # From deepest to shallowest features

# --- U-Net Decoder ---
class UNetDecoder(nn.Module):
    """
    U-Net style decoder.

    Consumes five encoder feature maps ordered deepest-first
    ``[x4, x3, x2, x1, x0]``. Each of the four decoding stages upsamples by a
    factor of 2 with a transposed convolution, concatenates the matching skip
    feature, and refines the result with a double 3x3 conv block. A final 1x1
    convolution produces per-class logits, which are resized to ``target_size``.

    Args:
        encoder_channels: Channel counts of the five encoder features, deepest first
            (e.g. [2048, 1024, 512, 256, 64] for ResNet50).
        decoder_channels: Output channel counts of the four decoding stages,
            deepest first (e.g. [512, 256, 128, 64]).
        num_classes: Number of output channels of the final 1x1 convolution.
        final_upsample_mode: ``mode`` passed to ``F.interpolate`` for the final resize.

    Raises:
        ValueError: If ``encoder_channels`` does not have 5 entries or
            ``decoder_channels`` does not have 4 entries.
    """
    def __init__(self, encoder_channels, decoder_channels, num_classes, final_upsample_mode='bilinear'):
        super().__init__()
        self.num_classes = num_classes
        self.final_upsample_mode = final_upsample_mode
        if len(encoder_channels) != 5:
             raise ValueError("UNetDecoder expects 5 feature maps from the encoder (x4, x3, x2, x1, x0)")
        if len(decoder_channels) != 4:
             raise ValueError("UNetDecoder expects 4 decoder channel sizes")

        enc_c4, enc_c3, enc_c2, enc_c1, enc_c0 = encoder_channels
        dec_c3, dec_c2, dec_c1, dec_c0 = decoder_channels

        # --- Upsampling Blocks ---
        # Block 1 (Upsample x4, combine with x3)
        self.upconv3 = nn.ConvTranspose2d(enc_c4, dec_c3, kernel_size=2, stride=2)
        self.dec_conv3 = self._conv_block(dec_c3 + enc_c3, dec_c3)

        # Block 2 (Upsample previous, combine with x2)
        self.upconv2 = nn.ConvTranspose2d(dec_c3, dec_c2, kernel_size=2, stride=2)
        self.dec_conv2 = self._conv_block(dec_c2 + enc_c2, dec_c2)

        # Block 3 (Upsample previous, combine with x1)
        self.upconv1 = nn.ConvTranspose2d(dec_c2, dec_c1, kernel_size=2, stride=2)
        self.dec_conv1 = self._conv_block(dec_c1 + enc_c1, dec_c1)

        # Block 4 (Upsample previous, combine with x0)
        self.upconv0 = nn.ConvTranspose2d(dec_c1, dec_c0, kernel_size=2, stride=2)
        self.dec_conv0 = self._conv_block(dec_c0 + enc_c0, dec_c0)

        # Final 1x1 convolution mapping to class logits
        self.final_conv = nn.Conv2d(dec_c0, num_classes, kernel_size=1)

    def _conv_block(self, in_channels, out_channels):
        """Return a (Conv3x3 -> BatchNorm -> ReLU) x 2 block that preserves spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _center_crop(tensor, height, width):
        """Center-crop the last two dims of ``tensor`` to (height, width); no-op if already that size."""
        top = (tensor.size(2) - height) // 2
        left = (tensor.size(3) - width) // 2
        return tensor[:, :, top:top + height, left:left + width]

    def _center_crop_and_concat(self, upsampled, bypass):
        """Center-crop both tensors to their common spatial size and concatenate on channels.

        When the skip connection is at least as large as the upsampled map this
        crops only the skip connection. When the upsampled map is the larger one
        (odd-sized encoder features, e.g. 63 vs. 64), it is cropped instead so
        that the concatenation is always well-defined.
        """
        height = min(upsampled.size(2), bypass.size(2))
        width = min(upsampled.size(3), bypass.size(3))
        return torch.cat(
            [self._center_crop(upsampled, height, width),
             self._center_crop(bypass, height, width)],
            dim=1,
        )

    def forward(self, encoder_features, target_size):
        """Decode encoder features into class logits of spatial size ``target_size``.

        Args:
            encoder_features: Sequence ``[x4, x3, x2, x1, x0]`` (deepest to shallowest).
            target_size: ``(H, W)`` of the desired output (tuple, list or torch.Size).

        Returns:
            Tensor of logits with ``num_classes`` channels and spatial size ``target_size``.
        """
        x4, x3, x2, x1, x0 = encoder_features

        # Decode block 3
        d3 = self.upconv3(x4)
        d3 = self._center_crop_and_concat(d3, x3)
        d3 = self.dec_conv3(d3)

        # Decode block 2
        d2 = self.upconv2(d3)
        d2 = self._center_crop_and_concat(d2, x2)
        d2 = self.dec_conv2(d2)

        # Decode block 1
        d1 = self.upconv1(d2)
        d1 = self._center_crop_and_concat(d1, x1)
        d1 = self.dec_conv1(d1)

        # Decode block 0
        d0 = self.upconv0(d1)
        d0 = self._center_crop_and_concat(d0, x0)
        d0 = self.dec_conv0(d0)

        # Final convolution
        logits = self.final_conv(d0)

        # Resize to the requested size. Comparing plain tuples makes the check
        # independent of whether the caller passed a list, tuple or torch.Size.
        target_size = tuple(target_size)
        if tuple(logits.shape[-2:]) != target_size:
            logits = F.interpolate(logits, size=target_size, mode=self.final_upsample_mode, align_corners=False if self.final_upsample_mode=='bilinear' else None)

        return logits


# --- Segmentation Model (ResNet encoder + U-Net decoder) ---
class SegmentationModel(nn.Module):
    """Segmentation network built from a ResNetEncoder and a UNetDecoder.

    Args:
        num_classes: Number of segmentation classes (decoder output channels).
        encoder_arch: Architecture name forwarded to ``ResNetEncoder``.
        pretrained: Forwarded to ``ResNetEncoder``.
        freeze_encoder_layers: Forwarded to ``ResNetEncoder`` as ``freeze_layers``.
        decoder_channels: Channel sizes of the four decoder stages.
        final_upsample_mode: Interpolation mode for the decoder's final resize.
    """

    def __init__(self, num_classes, encoder_arch='resnet34', pretrained=True, freeze_encoder_layers=0,
                 decoder_channels=(256, 128, 64, 32), final_upsample_mode='bilinear'):
        super().__init__()
        self.encoder = ResNetEncoder(
            arch=encoder_arch,
            pretrained=pretrained,
            freeze_layers=freeze_encoder_layers,
        )

        # encoder.out_channels is listed shallowest-first, whereas UNetDecoder
        # unpacks its encoder_channels deepest-first, so flip the order here.
        # Example for ResNet34: [64, 64, 128, 256, 512] -> [512, 256, 128, 64, 64]
        deepest_first_channels = list(reversed(self.encoder.out_channels))

        self.decoder = UNetDecoder(
            encoder_channels=deepest_first_channels,
            decoder_channels=decoder_channels,
            num_classes=num_classes,
            final_upsample_mode=final_upsample_mode,
        )
        # Exposed so external code can read the class count from the model.
        self.num_classes = num_classes

    def forward(self, imgA, imgB):
        """Return segmentation logits resized to the spatial size of ``imgA``."""
        # The encoder yields the feature list [x4, x3, x2, x1, x0]; the decoder
        # turns it into logits at the input's (H, W).
        features = self.encoder(imgA, imgB)
        return self.decoder(features, imgA.shape[-2:])

# --- Transformer Components ---
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional information to input embeddings.

    Args:
        d_model: Embedding dimension (even or odd).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed; inputs may not be longer than this.
    """
    def __init__(self, d_model, dropout=0.1, max_len=50): # max_len should accommodate max_length
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term # Shape: (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd-indexed (cosine) column than
        # even-indexed (sine) columns, so trim the angles accordingly.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(0) # Shape: (1, max_len, d_model)
        self.register_buffer('pe', pe) # Not a model parameter

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        Returns:
            Tensor of the same shape with positional encoding added and dropout applied.
        Raises:
            ValueError: If seq_len exceeds the number of precomputed positions.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            # Fail loudly: slicing would otherwise yield a shorter table and an
            # opaque broadcasting error (or silent broadcasting when max_len == 1).
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum supported length {self.pe.size(1)}"
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

class TransformerDecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer that also returns its attention weights.

    Sub-blocks, in order: masked self-attention, cross-attention over the
    encoder memory, and a position-wise feed-forward network. Each sub-block is
    wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", layer_norm_eps=1e-5):
        super().__init__()
        # Attention modules (inputs are batch-first: (N, L, d_model)).
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)

        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One LayerNorm and one residual dropout per sub-block.
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # "relu" selects ReLU; any other value selects GELU.
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """
        Args:
            tgt: Target sequence (batch_size, tgt_len, d_model)
            memory: Encoder memory (batch_size, src_len, d_model)
            tgt_mask: Attention mask for self-attention (tgt_len, tgt_len)
            memory_mask: Attention mask for cross-attention
            tgt_key_padding_mask: Padding mask for the target (batch_size, tgt_len)
            memory_key_padding_mask: Padding mask for the memory (batch_size, src_len)
        Returns:
            Tuple ``(output, self_attn_weights, cross_attn_weights)``.
        """
        # 1) Self-attention over the target sequence.
        attended, self_attn_weights = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )
        hidden = self.norm1(tgt + self.dropout1(attended))

        # 2) Cross-attention: target queries attend to the encoder memory.
        attended, cross_attn_weights = self.multihead_attn(
            hidden, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        hidden = self.norm2(hidden + self.dropout2(attended))

        # 3) Position-wise feed-forward.
        expanded = self.activation(self.linear1(hidden))
        projected = self.linear2(self.dropout(expanded))
        hidden = self.norm3(hidden + self.dropout3(projected))

        return hidden, self_attn_weights, cross_attn_weights

class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` independent copies of a decoder layer.

    Args:
        decoder_layer: Template layer; it is deep-copied ``num_layers`` times so
            that every layer in the stack owns its own parameters.
        num_layers: Number of layers in the stack.
        norm: Optional module applied to the output of the last layer.
    """
    def __init__(self, decoder_layer, num_layers, norm=None):
        # Imported locally so this class does not rely on a module-level `import copy`.
        import copy

        super().__init__()
        # Reusing the same instance N times would tie the weights of all layers
        # together; clone the template instead.
        self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm # Optional final layer norm

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Run ``tgt`` through every layer in order.

        Each layer is called as ``layer(output, memory, tgt_mask=..., memory_mask=...,
        tgt_key_padding_mask=..., memory_key_padding_mask=...)`` and must return
        ``(output, self_attn_weights, cross_attn_weights)``.

        Returns:
            Tuple ``(output, all_self_attn_weights, all_cross_attn_weights)`` where
            the two lists hold one entry per layer, in layer order.
        """
        output = tgt
        all_self_attn_weights = []
        all_cross_attn_weights = []

        for mod in self.layers:
            output, self_attn, cross_attn = mod(output, memory, tgt_mask=tgt_mask,
                                                 memory_mask=memory_mask,
                                                 tgt_key_padding_mask=tgt_key_padding_mask,
                                                 memory_key_padding_mask=memory_key_padding_mask)
            all_self_attn_weights.append(self_attn)
            all_cross_attn_weights.append(cross_attn)

        if self.norm is not None:
            output = self.norm(output)

        # Attention weights of all layers are returned, not just the last one.
        return output, all_self_attn_weights, all_cross_attn_weights


# --- Captioning Model (ResNet Encoder + Transformer Decoder) ---
class CaptioningModel(nn.Module):
    """Captioning model: ResNetEncoder image features decoded by a Transformer decoder.

    The deepest encoder feature map is projected to ``d_model`` channels with a
    1x1 convolution and flattened into a sequence that serves as the decoder
    memory. Captions are produced with teacher forcing in ``forward`` and with
    beam search in ``beam_search_decode``.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Transformer hidden size.
        nhead: Number of attention heads.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Hidden size of the feed-forward sub-block.
        dropout: Dropout probability.
        activation: Activation name forwarded to ``TransformerDecoderLayer``.
        encoder_arch, pretrained, freeze_encoder_layers: Forwarded to ``ResNetEncoder``.
        max_length: Number of positions in the positional-encoding table; also the
            default maximum generation length.
    """
    def __init__(self, vocab_size, d_model=512, nhead=8, num_decoder_layers=6,
                 dim_feedforward=2048, dropout=0.1, activation="relu",
                 encoder_arch='resnet34', pretrained=True, freeze_encoder_layers=0,
                 max_length=41): # max_length needed for positional encoding
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_length = max_length

        # --- Encoder ---
        self.encoder = ResNetEncoder(arch=encoder_arch, pretrained=pretrained, freeze_layers=freeze_encoder_layers)
        encoder_output_dim = self.encoder.out_features_base # e.g., 512 for ResNet34, 2048 for ResNet50

        # --- Input Projection ---
        # Project encoder output features to the decoder's expected dimension (d_model)
        self.input_proj = nn.Conv2d(encoder_output_dim, d_model, kernel_size=1)

        # --- Decoder Components ---
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=max_length)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, norm=decoder_norm)

        # --- Output Layer ---
        self.output_layer = nn.Linear(d_model, vocab_size)

        # --- Initialize Weights ---
        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise the projection, embedding, decoder and output weights.

        Parameters under ``self.encoder`` are deliberately left untouched:
        re-initialising them would overwrite whatever weights the encoder was
        constructed with (including pre-trained and frozen ones).
        """
        for name, param in self.named_parameters():
            if name.startswith('encoder.'):
                continue
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
        # Embedding weights get their own scaled normal initialisation.
        nn.init.normal_(self.embedding.weight, mean=0, std=self.d_model**-0.5)
        # Keep the padding embedding at zero when the embedding declares one.
        padding_idx = getattr(self.embedding, 'padding_idx', None)
        if padding_idx is not None:
            with torch.no_grad():
                self.embedding.weight[padding_idx].fill_(0)

    def _generate_square_subsequent_mask(self, sz, device):
        """Return a (sz, sz) float mask: 0.0 on/below the diagonal, -inf above it."""
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    def _create_padding_mask(self, sequence, pad_idx):
        """Return a boolean mask that is True where ``sequence`` equals ``pad_idx``."""
        return (sequence == pad_idx)

    def _encode_memory(self, imgA, imgB):
        """Encode an image pair into the decoder memory of shape (N, H'*W', d_model)."""
        deepest = self.encoder(imgA, imgB)[0]      # first entry of the feature list: (N, C_enc, H', W')
        projected = self.input_proj(deepest)       # (N, d_model, H', W')
        # Flatten the spatial grid row-major and move channels last.
        return projected.flatten(2).transpose(1, 2)

    def forward(self, imgA, imgB, caption_tokens, pad_idx):
        """
        Forward pass for training (teacher forcing).
        Args:
            imgA, imgB: Input images (N, 3, H, W)
            caption_tokens: Caption tokens (N, T) starting with <START>.
                            Input to decoder: caption_tokens[:, :-1]
                            Target for loss (computed by the caller): caption_tokens[:, 1:]
            pad_idx: Index of the padding token.
        Returns:
            Output logits (N, T-1, vocab_size)
        """
        memory = self._encode_memory(imgA, imgB) # (N, S, d_model) where S = H'*W'

        # Decoder input is the caption without its last token.
        tgt_input = caption_tokens[:, :-1] # (N, T-1)

        # Scaled embeddings plus positional encoding.
        tgt_emb = self.embedding(tgt_input) * math.sqrt(self.d_model)
        tgt_pos = self.pos_encoder(tgt_emb) # (N, T-1, d_model)

        # Causal mask (no attention to future tokens) and padding mask.
        tgt_mask = self._generate_square_subsequent_mask(tgt_input.size(1), tgt_input.device) # (T-1, T-1)
        tgt_padding_mask = self._create_padding_mask(tgt_input, pad_idx) # (N, T-1)

        decoder_output, _, _ = self.decoder(tgt_pos, memory,
                                            tgt_mask=tgt_mask,
                                            tgt_key_padding_mask=tgt_padding_mask,
                                            memory_key_padding_mask=None) # memory has no padding

        return self.output_layer(decoder_output) # (N, T-1, vocab_size)

    # --- Beam Search Decoding (for validation/inference) ---
    @torch.no_grad()
    def beam_search_decode(self, imgA, imgB, beam_size, start_idx, end_idx, pad_idx, max_len=None):
        """
        Generates a caption using beam search. Runs entirely without gradient tracking.
        Args:
            imgA, imgB: Input images (1, 3, H, W) - one image pair at a time.
            beam_size (int): Number of beams to keep.
            start_idx, end_idx, pad_idx: Special token indices.
            max_len (int, optional): Maximum generation length. Defaults to self.max_length;
                values beyond what the positional-encoding table supports are clamped.
        Returns:
            List[int]: Token indices of the best caption (starting with start_idx).
        """
        if max_len is None:
            max_len = self.max_length
        # Decoder inputs grow up to max_len - 1 tokens, while the positional
        # table was built with self.max_length positions.
        max_len = min(max_len, self.max_length + 1)
        device = imgA.device

        # Only the first image pair is decoded; hypotheses are scored one at a time.
        memory = self._encode_memory(imgA, imgB)[0:1] # (1, S, d_model)

        # Hypotheses are (negative log-probability, token tensor of shape (1, len));
        # lower scores are better.
        start_seq = torch.full((1, 1), start_idx, dtype=torch.long, device=device)
        active = [(0.0, start_seq)]
        completed = []

        for _ in range(max_len - 1):
            if not active: # Stop if no active beams
                break

            candidates = []
            for neg_log_prob, seq in active:
                tgt_pos = self.pos_encoder(self.embedding(seq) * math.sqrt(self.d_model))
                tgt_mask = self._generate_square_subsequent_mask(seq.size(1), device)
                tgt_padding_mask = self._create_padding_mask(seq, pad_idx)

                decoder_output, _, _ = self.decoder(tgt_pos, memory,
                                                    tgt_mask=tgt_mask,
                                                    tgt_key_padding_mask=tgt_padding_mask,
                                                    memory_key_padding_mask=None)

                # Distribution over the next token, from the last position only.
                last_token_logits = self.output_layer(decoder_output[:, -1, :]) # (1, vocab_size)
                log_probs = F.log_softmax(last_token_logits, dim=-1).squeeze(0) # (vocab_size,)

                # topk cannot return more entries than the vocabulary holds.
                k = min(beam_size, log_probs.size(-1))
                top_log_probs, top_indices = torch.topk(log_probs, k, dim=-1)

                for token_log_prob, token_idx in zip(top_log_probs.tolist(), top_indices.tolist()):
                    total_neg_log_prob = neg_log_prob - token_log_prob
                    new_seq = torch.cat([seq, seq.new_tensor([[token_idx]])], dim=1) # (1, len + 1)

                    if token_idx == end_idx:
                        # Finished hypotheses are ranked by length-normalised score.
                        completed.append((total_neg_log_prob / (new_seq.size(1)**0.7), new_seq))
                    else:
                        candidates.append((total_neg_log_prob, new_seq))

            # Keep the best beam_size open and finished hypotheses.
            candidates.sort(key=lambda item: item[0])
            active = candidates[:beam_size]
            completed.sort(key=lambda item: item[0])
            completed = completed[:beam_size]

        # If nothing finished, fall back to the best still-open hypothesis.
        if not completed:
            if not active:
                return [start_idx, end_idx] # Minimal sequence
            best_neg_log_prob, best_seq = active[0]
            completed.append((best_neg_log_prob / (best_seq.size(1)**0.7), best_seq))

        # Lowest normalised negative log-probability wins.
        _, best_sequence = min(completed, key=lambda item: item[0])
        return best_sequence.squeeze(0).tolist()


# =============================================================================
# --- Training & Validation Functions (Updated) ---
# =============================================================================

def train_segmentation(train_loader, model, criterion, optimizer, epoch, device):
    """Run one training epoch of the segmentation model.

    Each batch is a mapping with 'imgA', 'imgB' and 'seg_label' entries. Loss
    and mean IoU are tracked as batch-size-weighted running averages and
    printed at the end of the epoch.

    Returns:
        The average training loss of the epoch.
    """
    model.train()
    loss_meter = AverageMeter()
    iou_meter = AverageMeter()
    epoch_start = time.time()

    for batch in tqdm(train_loader, desc=f"Seg Training Epoch {epoch}"):
        imgA = batch['imgA'].to(device, non_blocking=True)
        imgB = batch['imgB'].to(device, non_blocking=True)
        seg_labels = batch['seg_label'].to(device, non_blocking=True)
        batch_size = imgA.size(0)

        # Standard optimisation step on the raw logits (N, num_classes, H, W).
        optimizer.zero_grad()
        seg_outputs = model(imgA, imgB)
        loss = criterion(seg_outputs, seg_labels)
        loss.backward()
        optimizer.step()

        loss_meter.update(loss.item(), batch_size)

        # Monitoring only: per-class IoU of the arg-max prediction.
        with torch.no_grad():
            predicted_masks = seg_outputs.max(dim=1).indices
            batch_iou_per_class = calculate_iou(predicted_masks, seg_labels, model)
            # An all-NaN result counts as 0.0; otherwise average over the non-NaN classes.
            if torch.all(torch.isnan(batch_iou_per_class)):
                mean_batch_iou = 0.0
            else:
                mean_batch_iou = torch.nanmean(batch_iou_per_class).item()
            iou_meter.update(mean_batch_iou, batch_size)

    elapsed = time.time() - epoch_start
    print(f"Epoch {epoch} Seg Training Loss: {loss_meter.avg:.4f}, Mean IoU: {iou_meter.avg:.4f}, Time: {elapsed:.2f}s")
    return loss_meter.avg


def validate_segmentation(val_loader, model, criterion, device):
    """Evaluate the segmentation model on ``val_loader``.

    Computes the batch-size-weighted average loss and, from the per-batch
    per-class IoU vectors, a NaN-aware mean IoU per class and overall. Results
    are printed.

    Returns:
        Tuple ``(average_loss, mean_overall_iou)``.
    """
    model.eval()
    loss_meter = AverageMeter()
    per_batch_ious = [] # One (num_classes,) IoU tensor per batch with at least one non-NaN entry
    eval_start = time.time()

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Seg Validation"):
            imgA = batch['imgA'].to(device, non_blocking=True)
            imgB = batch['imgB'].to(device, non_blocking=True)
            seg_labels = batch['seg_label'].to(device, non_blocking=True)

            seg_outputs = model(imgA, imgB)
            loss_meter.update(criterion(seg_outputs, seg_labels).item(), imgA.size(0))

            predicted_masks = seg_outputs.max(dim=1).indices
            batch_iou_per_class = calculate_iou(predicted_masks, seg_labels, model)
            # Batches whose IoU is NaN for every class carry no information.
            if not torch.all(torch.isnan(batch_iou_per_class)):
                per_batch_ious.append(batch_iou_per_class)

    elapsed = time.time() - eval_start

    # The class count lives on the underlying model when wrapped in DataParallel.
    unwrapped = model.module if isinstance(model, nn.DataParallel) else model
    num_classes = unwrapped.num_classes

    if per_batch_ious:
        # (num_valid_batches, num_classes) -> per-class mean -> overall mean, skipping NaNs.
        mean_iou_per_class = torch.nanmean(torch.stack(per_batch_ious), dim=0)
        mean_overall_iou = torch.nanmean(mean_iou_per_class).item()
    else:
        mean_iou_per_class = torch.full((num_classes,), float('nan'), device=device)
        mean_overall_iou = 0.0

    print(f"Seg Validation Loss: {loss_meter.avg:.4f}, Mean Overall IoU: {mean_overall_iou:.4f}")
    # NaN entries are shown as the literal string "NaN".
    iou_list_print = ["NaN" if torch.isnan(iou) else f"{iou:.4f}" for iou in mean_iou_per_class]
    print(f"IoU per class: {iou_list_print}")
    print(f"Validation Time: {elapsed:.2f}s")

    return loss_meter.avg, mean_overall_iou


def train_captioning(train_loader, model, criterion, optimizer, epoch, device, pad_idx):
    """Run one teacher-forced training epoch of the captioning model.

    Each batch is a mapping with 'imgA', 'imgB' and 'caption_tokens' entries.
    The model receives the full token tensor and returns logits for every
    position except the first; those are scored against the tokens shifted
    left by one.

    Returns:
        The batch-size-weighted average training loss of the epoch.
    """
    model.train()
    loss_meter = AverageMeter()
    epoch_start = time.time()

    for batch in tqdm(train_loader, desc=f"Cap Training Epoch {epoch}"):
        imgA = batch['imgA'].to(device, non_blocking=True)
        imgB = batch['imgB'].to(device, non_blocking=True)
        caption_tokens = batch['caption_tokens'].to(device, non_blocking=True) # (N, T)

        optimizer.zero_grad()

        # Logits: (N, T-1, vocab_size); targets drop the first token: (N, T-1).
        logits = model(imgA, imgB, caption_tokens, pad_idx)
        targets = caption_tokens[:, 1:]

        # Flatten batch and time so the criterion sees (N*(T-1), vocab) vs (N*(T-1),).
        vocab_dim = logits.size(-1)
        loss = criterion(logits.reshape(-1, vocab_dim), targets.reshape(-1))

        loss.backward()
        optimizer.step()

        loss_meter.update(loss.item(), imgA.size(0))

    elapsed = time.time() - epoch_start
    print(f"Epoch {epoch} Cap Training Loss: {loss_meter.avg:.4f}, Time: {elapsed:.2f}s")
    return loss_meter.avg


def _token_ids_to_words(token_ids, idx_to_word, start_idx, end_idx, pad_idx):
    """Convert a token-ID sequence to words, stopping at <END> and skipping <START>/padding.

    Every ID is normalized with ``int()`` before the vocabulary lookup. IDs that
    arrive as 0-d tensors hash by object identity, so using them directly as
    dict keys would silently map every token to '<UNK>'.
    """
    skipped = (start_idx, pad_idx)
    words = []
    for raw_id in token_ids:
        token = int(raw_id)
        if token == end_idx:
            break
        if token not in skipped:
            words.append(idx_to_word.get(token, '<UNK>'))
    return words


def validate_captioning(val_loader, model, criterion, device, idx_to_word, start_idx, end_idx, pad_idx, beam_size=3):
    """Validate the captioning model with beam search and report BLEU-1..4.

    Args:
        val_loader: DataLoader whose batches provide 'imgA', 'imgB',
            'all_caption_tokens' and 'name'; its dataset must expose
            ``max_length``.
        model: Captioning model, optionally wrapped in ``nn.DataParallel``;
            the underlying module must provide ``beam_search_decode``.
        criterion: Unused; kept for signature compatibility with callers.
        device: Device the image tensors are moved to.
        idx_to_word: Mapping from token ID to word.
        start_idx: ID of the <START> token.
        end_idx: ID of the <END> token.
        pad_idx: ID of the padding token.
        beam_size: Beam width passed to ``beam_search_decode``.

    Returns:
        Corpus BLEU-4, or 0.0 when BLEU could not be computed. No validation
        loss is produced because decoding uses beam search, not teacher forcing.
    """
    model.eval()
    references = []  # Per image: list of reference captions (each a list of words)
    hypotheses = []  # Per image: predicted caption (list of words)
    start_time = time.time()

    # beam_search_decode lives on the underlying module, not on the DataParallel wrapper.
    model_ref = model.module if isinstance(model, nn.DataParallel) else model

    with torch.no_grad():
        for data in tqdm(val_loader, desc="Cap Validation (Beam Search)"):
            # Beam search is not batched: decode one image pair at a time.
            batch_size = data['imgA'].size(0)
            for j in range(batch_size):
                imgA_single = data['imgA'][j:j+1].to(device)  # Slice keeps the batch dim
                imgB_single = data['imgB'][j:j+1].to(device)
                all_caption_tokens_single = data['all_caption_tokens'][j]

                generated_ids = model_ref.beam_search_decode(imgA_single, imgB_single,
                                                             beam_size, start_idx, end_idx, pad_idx,
                                                             max_len=val_loader.dataset.max_length)
                hypotheses.append(
                    _token_ids_to_words(generated_ids, idx_to_word, start_idx, end_idx, pad_idx)
                )

                # Keep only non-empty reference captions.
                image_references = [
                    gt_words
                    for gt_words in (
                        _token_ids_to_words(gt_tokens, idx_to_word, start_idx, end_idx, pad_idx)
                        for gt_tokens in all_caption_tokens_single
                    )
                    if gt_words
                ]
                if image_references:
                    references.append(image_references)
                else:
                    # Keep references aligned one-to-one with hypotheses.
                    references.append([['<UNK>']])
                    print(f"Warning: No valid reference captions found for image {data['name'][j]}. Added dummy reference.")

    epoch_time = time.time() - start_time

    # --- Calculate BLEU scores ---
    bleu_1, bleu_2, bleu_3, bleu_4 = 0.0, 0.0, 0.0, 0.0
    if not references or not hypotheses or len(references) != len(hypotheses):
        print(f"Warning: Mismatch in number of references ({len(references)}) and hypotheses ({len(hypotheses)}) or empty lists. Cannot calculate BLEU.")
    else:
        try:
            smooth = SmoothingFunction().method1
            bleu_1 = corpus_bleu(references, hypotheses, weights=(1, 0, 0, 0), smoothing_function=smooth)
            bleu_2 = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0), smoothing_function=smooth)
            # Exact thirds: weights of 0.33 sum to 0.99 and slightly inflate the score.
            bleu_3 = corpus_bleu(references, hypotheses, weights=(1 / 3, 1 / 3, 1 / 3, 0), smoothing_function=smooth)
            bleu_4 = corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smooth)
            print(f"Cap Validation BLEU-1: {bleu_1:.4f}, BLEU-2: {bleu_2:.4f}, BLEU-3: {bleu_3:.4f}, BLEU-4: {bleu_4:.4f}")
        except ZeroDivisionError:
            print("Could not calculate BLEU score (division by zero). Check generated captions.")
        except Exception as e:
            # Best-effort metric: report the failure and fall through with the scores computed so far.
            print(f"Error calculating BLEU score: {e}")

    print(f"Validation Time: {epoch_time:.2f}s")

    # BLEU-4 is the primary metric used by the caller for scheduling and checkpointing.
    return bleu_4


# =============================================================================
# --- Main Execution ---
# =============================================================================

def _align_state_dict_keys(state_dict, model):
    """Return `state_dict` with keys matching `model`'s DataParallel wrapping.

    A model wrapped in ``nn.DataParallel`` expects every key to start with
    'module.'; an unwrapped model expects no such prefix. Whether the saved
    keys carry the prefix is decided from the first key only.
    """
    prefix = 'module.'
    if not state_dict:
        return state_dict
    saved_wrapped = next(iter(state_dict)).startswith(prefix)
    model_wrapped = isinstance(model, nn.DataParallel)
    if model_wrapped and not saved_wrapped:
        return {prefix + key: value for key, value in state_dict.items()}
    if saved_wrapped and not model_wrapped:
        # Strip only the leading prefix: a blanket str.replace would also
        # corrupt keys such as 'submodule.weight'.
        return {(key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()}
    return state_dict


def _resume_training(checkpoint_path, model, optimizer, scheduler, device, metric_name):
    """Restore training state from `checkpoint_path` when it is set and exists.

    Returns:
        ``(start_epoch, best_metric)``; ``(0, 0.0)`` when nothing was loaded.
    """
    if not checkpoint_path:
        return 0, 0.0
    if not os.path.isfile(checkpoint_path):
        print(f"Checkpoint not found at '{checkpoint_path}'")
        return 0, 0.0

    print(f"Loading checkpoint '{checkpoint_path}'")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    start_epoch = checkpoint['epoch'] + 1
    model.load_state_dict(_align_state_dict_keys(checkpoint['state_dict'], model))
    optimizer.load_state_dict(checkpoint['optimizer'])
    best_metric = checkpoint.get('best_metric', 0.0)
    if 'scheduler_state_dict' in checkpoint and hasattr(scheduler, 'load_state_dict'):
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    print(f"Loaded checkpoint (epoch {checkpoint['epoch']}), Best {metric_name}: {best_metric:.4f}")
    return start_epoch, best_metric


def _save_epoch_checkpoint(save_dict, checkpoint_dir, task, arch, epoch, is_best, metric_name):
    """Write the per-epoch checkpoint and, when `is_best`, the BEST_* copy."""
    checkpoint_path = os.path.join(checkpoint_dir, f'{task}_{arch}_ep{epoch}.pth.tar')
    torch.save(save_dict, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")

    if is_best:
        best_checkpoint_path = os.path.join(checkpoint_dir, f'BEST_{task}_{arch}.pth.tar')
        torch.save(save_dict, best_checkpoint_path)
        print(f"*** Best {task} checkpoint saved to {best_checkpoint_path} with {metric_name}: {save_dict['best_metric']:.4f} ***")


if __name__ == '__main__':

    # --- Configuration Parameters ---
    # Choose the mode: 'train_segmentation' or 'train_captioning'
    RUN_MODE = 'train_captioning' # <-- SET MODE HERE ('train_segmentation' or 'train_captioning')

    # General Parameters
    DATA_FOLDER = DATASET_ROOT
    PROCESSED_DATA_DIR = os.path.join(SAVE_OUTPUT_DIR, 'processed_data_v2') # Use new dir for potentially different vocab/splits
    CAPTION_FILE = 'LevirCCcaptions.json'
    CHECKPOINT_DIR = os.path.join(SAVE_OUTPUT_DIR, 'checkpoints_v2') # Use new dir for new models
    EPOCHS = 20 # Increased epochs might be needed for larger models
    BATCH_SIZE = 8 # Reduced batch size might be needed due to larger models
    NUM_WORKERS = 4 # Adjust based on Kaggle instance resources
    LR = 5e-5 # Potentially lower LR for fine-tuning pre-trained models
    WEIGHT_DECAY = 1e-4 # Added weight decay
    RESUME_CHECKPOINT = '' # Set path to resume

    # Preprocessing Parameters
    MAX_LENGTH = 45 # Increased slightly, check preprocessing output warning
    WORD_COUNT_THRESHOLD = 3 # Slightly lower threshold?

    # --- Model Specific Parameters ---
    # Shared Encoder (ResNet)
    ENCODER_ARCH = 'resnet34' # Options: 'resnet18', 'resnet34', 'resnet50', 'resnet101'
    ENCODER_PRETRAINED = True
    # 0: train all, 1: freeze conv1/layer1, 2: freeze up to layer2, etc.
    # NOTE(review): 25 is outside the small-integer scale described above —
    # confirm how the encoder interprets this value.
    ENCODER_FREEZE_LAYERS = 25

    # Segmentation Model (U-Net Decoder)
    SEG_DECODER_CHANNELS = (256, 128, 64, 32) # Adjust based on encoder size if needed
    SEG_FINAL_UPSAMPLE_MODE = 'bilinear' # 'bilinear' or 'nearest'

    # Captioning Model (Transformer Decoder)
    CAP_D_MODEL = 512 # Must match embedding dim, proj dim
    CAP_NHEAD = 8
    CAP_NUM_DECODER_LAYERS = 4 # Fewer layers might be faster to train initially
    CAP_DIM_FEEDFORWARD = 1024 # Adjust based on d_model and layers
    CAP_DROPOUT = 0.1
    # Beam Search for Validation
    VAL_BEAM_SIZE = 3

    # --- Setup ---
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {DEVICE}")
    if DEVICE == 'cuda':
        print(f"CUDA Devices: {torch.cuda.device_count()}")
        # Set a specific device if needed, e.g., torch.cuda.set_device(0)

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True) # Ensure processed data dir exists

    # --- Check and Run Preprocessing ---
    # Preprocessing is considered done only when all three artifacts exist.
    vocab_exists = os.path.exists(os.path.join(PROCESSED_DATA_DIR, 'vocab.json'))
    split_exists = os.path.exists(os.path.join(PROCESSED_DATA_DIR, 'train.txt'))
    tokens_exist = os.path.exists(os.path.join(PROCESSED_DATA_DIR, 'tokens'))
    preprocessing_needed = not (vocab_exists and split_exists and tokens_exist)

    if preprocessing_needed:
        print(f"Preprocessing output not found or incomplete in {PROCESSED_DATA_DIR}. Running preprocessing.")
        try:
            run_preprocessing_direct(DATA_FOLDER, PROCESSED_DATA_DIR, CAPTION_FILE, MAX_LENGTH, WORD_COUNT_THRESHOLD)
            print("Preprocessing complete.")
        except Exception as e:
            print(f"Error during preprocessing: {e}")
            print("Please ensure the dataset path and structure are correct.")
            sys.exit(1) # Training cannot proceed without preprocessed data
    else:
        print(f"Preprocessing output found in {PROCESSED_DATA_DIR}. Skipping preprocessing.")


    # --- Main Logic ---
    try:
        if RUN_MODE == 'train_segmentation':
            print(f"\n--- Starting Segmentation Training ({ENCODER_ARCH} + U-Net) ---")

            # Calculate class weights
            class_weights = calculate_class_weights(DATA_FOLDER, PROCESSED_DATA_DIR, split='train', num_classes=NUM_CLASSES).to(DEVICE)

            # Load Data
            print("Loading segmentation datasets...")
            train_dataset = LEVIRCCDataset(DATA_FOLDER, PROCESSED_DATA_DIR, 'train', load_segmentation=True, max_length=MAX_LENGTH)
            val_dataset = LEVIRCCDataset(DATA_FOLDER, PROCESSED_DATA_DIR, 'val', load_segmentation=True, max_length=MAX_LENGTH)
            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
            val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

            # Initialize Model, Loss, Optimizer, Scheduler
            print("Initializing segmentation model...")
            model = SegmentationModel(num_classes=NUM_CLASSES,
                                      encoder_arch=ENCODER_ARCH,
                                      pretrained=ENCODER_PRETRAINED,
                                      freeze_encoder_layers=ENCODER_FREEZE_LAYERS,
                                      decoder_channels=SEG_DECODER_CHANNELS,
                                      final_upsample_mode=SEG_FINAL_UPSAMPLE_MODE).to(DEVICE)

            criterion = nn.CrossEntropyLoss(weight=class_weights).to(DEVICE)
            # Only parameters that still require gradients are handed to the optimizer.
            params_to_optimize = filter(lambda p: p.requires_grad, model.parameters())
            optimizer = optim.AdamW(params_to_optimize, lr=LR, weight_decay=WEIGHT_DECAY)
            # mode='max' because the monitored quantity is validation IoU.
            scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=3, verbose=True)

            # DataParallel
            if torch.cuda.device_count() > 1:
                print(f"Using {torch.cuda.device_count()} GPUs for Segmentation Training.")
                model = nn.DataParallel(model)

            # Resume Checkpoint (best_metric is the best validation IoU so far)
            start_epoch, best_metric = _resume_training(RESUME_CHECKPOINT, model, optimizer, scheduler, DEVICE, 'IoU')


            # Training Loop
            print("Starting training loop...")
            for epoch in range(start_epoch, EPOCHS):
                train_loss = train_segmentation(train_loader, model, criterion, optimizer, epoch, DEVICE)
                val_loss, val_metric = validate_segmentation(val_loader, model, criterion, DEVICE) # val_metric is mIoU

                # LR Scheduler Step (based on validation IoU)
                scheduler.step(val_metric)

                is_best = val_metric > best_metric
                best_metric = max(val_metric, best_metric)

                # Save weights without the 'module.' prefix so they load into an unwrapped model.
                model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()

                save_dict = {
                    'epoch': epoch,
                    'arch': ENCODER_ARCH,
                    'state_dict': model_state,
                    'optimizer': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_metric': best_metric, # Best IoU
                    'current_metric': val_metric, # Current IoU
                    'val_loss': val_loss
                }

                _save_epoch_checkpoint(save_dict, CHECKPOINT_DIR, 'segmentation', ENCODER_ARCH, epoch, is_best, 'IoU')

            print("Segmentation Training Finished.")

        elif RUN_MODE == 'train_captioning':
            print(f"\n--- Starting Captioning Training ({ENCODER_ARCH} + Transformer) ---")

            # Load Data (load_segmentation=False)
            print("Loading captioning datasets...")
            train_dataset = LEVIRCCDataset(DATA_FOLDER, PROCESSED_DATA_DIR, 'train', load_segmentation=False, max_length=MAX_LENGTH)
            val_dataset = LEVIRCCDataset(DATA_FOLDER, PROCESSED_DATA_DIR, 'val', load_segmentation=False, max_length=MAX_LENGTH)

            # Get vocab info from dataset
            vocab_size = train_dataset.vocab_size
            pad_idx = train_dataset.pad_idx
            start_idx = train_dataset.start_idx
            end_idx = train_dataset.end_idx
            idx_to_word = train_dataset.idx_to_word
            print(f"Vocab Size: {vocab_size}, Pad Index: {pad_idx}")


            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
            # Validation decodes one sample at a time because beam search is not batched.
            val_batch_size = 1 # Adjust if beam search implementation handles batches
            val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)


            # Initialize Model, Loss, Optimizer, Scheduler
            print("Initializing captioning model...")
            model = CaptioningModel(vocab_size=vocab_size,
                                    d_model=CAP_D_MODEL,
                                    nhead=CAP_NHEAD,
                                    num_decoder_layers=CAP_NUM_DECODER_LAYERS,
                                    dim_feedforward=CAP_DIM_FEEDFORWARD,
                                    dropout=CAP_DROPOUT,
                                    encoder_arch=ENCODER_ARCH,
                                    pretrained=ENCODER_PRETRAINED,
                                    freeze_encoder_layers=ENCODER_FREEZE_LAYERS,
                                    max_length=MAX_LENGTH).to(DEVICE)

            # Must happen before the DataParallel wrap, while `model.embedding` is directly reachable.
            model.embedding.padding_idx = pad_idx

            # Cross-Entropy loss, ignore padding
            criterion = nn.CrossEntropyLoss(ignore_index=pad_idx).to(DEVICE)
            # Only parameters that still require gradients are handed to the optimizer.
            params_to_optimize = filter(lambda p: p.requires_grad, model.parameters())
            optimizer = optim.AdamW(params_to_optimize, lr=LR, weight_decay=WEIGHT_DECAY)
            # mode='max' because the monitored quantity is validation BLEU-4.
            scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=3, verbose=True)

            # DataParallel
            if torch.cuda.device_count() > 1:
                print(f"Using {torch.cuda.device_count()} GPUs for Captioning Training.")
                model = nn.DataParallel(model)

            # Resume Checkpoint (best_metric is the best validation BLEU-4 so far)
            start_epoch, best_metric = _resume_training(RESUME_CHECKPOINT, model, optimizer, scheduler, DEVICE, 'BLEU-4')


            # Training Loop
            print("Starting training loop...")
            for epoch in range(start_epoch, EPOCHS):
                train_loss = train_captioning(train_loader, model, criterion, optimizer, epoch, DEVICE, pad_idx)
                val_metric = validate_captioning(val_loader, model, criterion, DEVICE, idx_to_word, start_idx, end_idx, pad_idx, beam_size=VAL_BEAM_SIZE) # val_metric is BLEU-4

                # LR Scheduler Step (based on validation BLEU-4)
                scheduler.step(val_metric)

                is_best = val_metric > best_metric
                best_metric = max(val_metric, best_metric)

                # Save weights without the 'module.' prefix so they load into an unwrapped model.
                model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()

                save_dict = {
                    'epoch': epoch,
                    'arch': ENCODER_ARCH,
                    'state_dict': model_state,
                    'optimizer': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_metric': best_metric, # Best BLEU-4
                    'current_metric': val_metric, # Current BLEU-4
                    'train_loss': train_loss,
                    'vocab_size': vocab_size, # Save vocab info if needed
                    'pad_idx': pad_idx,
                    'start_idx': start_idx,
                    'end_idx': end_idx
                }

                _save_epoch_checkpoint(save_dict, CHECKPOINT_DIR, 'captioning', ENCODER_ARCH, epoch, is_best, 'BLEU-4')


            print("Captioning Training Finished.")

        else:
            print(f"Invalid RUN_MODE: {RUN_MODE}. Choose 'train_segmentation' or 'train_captioning'.")

    except FileNotFoundError as e:
        print(f"\nError: A required file or directory was not found.")
        print(f"Details: {e}")
        print("Please ensure the dataset path (DATASET_ROOT) is correct and that preprocessing ran successfully.")
    except Exception as e:
        # Top-level boundary: report the failure with a traceback instead of propagating.
        print(f"\nAn unexpected error occurred: {e}")
        import traceback
        traceback.print_exc()