# In [None]:
# =============================================================================
# Combined Script for Kaggle Notebook with Resume Logic and DataParallel
# This script combines dataset.py, model_arch.py, preprocessing.py, and training.py
# Adjusted for Kaggle file paths, includes resume logic, and uses torch.nn.DataParallel
# for multi-GPU training if available.
# Includes fixes for zero BLEU score issue.
# This version does NOT include LLM integration.
# Corrected AttributeError in LEVIRCCDataset.
# Corrected structure of references for BLEU calculation.
# =============================================================================

import sys
import os
import json
import argparse
import numpy as np
from collections import defaultdict
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader # Explicitly import Dataset and DataLoader
from torch import nn
import torch.nn.functional as F
import math
from imageio.v2 import imread # Make sure imageio is installed in your Kaggle environment
from random import randint
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from tqdm import tqdm

# =============================================================================
# --- Hardcoded Paths for Kaggle ---
# IMPORTANT: You MUST update DATASET_ROOT to the correct path of your dataset
# after adding it as an input to your Kaggle Notebook.
# It will typically be something like '/kaggle/input/your-dataset-name/LEVIR-MCI-dataset'
# Check the input path in the Kaggle Notebook's file explorer.
# SAVE_OUTPUT_DIR is set to the standard Kaggle output directory, which is persistent
# within a single notebook session and can be saved with the notebook.
# =============================================================================
# Replace 'your-dataset-name' with the actual name you give your dataset when uploading to Kaggle
# You might need to inspect the input directory structure to get the exact path.
# Root of the LEVIR-MCI dataset inside Kaggle's input area.
# *** UPDATE THIS PATH *** so it matches the dataset attached to the notebook.
DATASET_ROOT = "/kaggle/input/levir-mci-dataset/LEVIR-MCI-dataset"

# Kaggle's standard writable output directory.
SAVE_OUTPUT_DIR = "/kaggle/working/"

# =============================================================================
# --- dataset.py content starts here ---
# =============================================================================

# Lookup table from the RGB colour used in the label images to the integer
# class ID used for segmentation training.
COLOR_TO_ID_MAPPING = {
    (0, 0, 0): 0,        # black -> background
    (128, 128, 128): 1,  # grey  -> road
    (255, 255, 255): 2,  # white -> building
}

# Number of distinct segmentation classes (background, road, building).
NUM_CLASSES = len(set(COLOR_TO_ID_MAPPING.values()))

def rgb_to_class_id_mask(rgb_mask_np):
    """Convert a colour-coded segmentation mask into a class-ID mask.

    Args:
        rgb_mask_np: Mask array of shape (H, W, 3) whose pixel colours are the
            keys of ``COLOR_TO_ID_MAPPING``. Single-channel masks of shape
            (H, W) or (H, W, 1) are also accepted; their intensities 0 / 128 /
            255 are mapped to classes 0 / 1 / 2 and a warning is printed.

    Returns:
        np.ndarray of shape (H, W) and dtype int64 holding one class ID per
        pixel. Pixels whose colour (or intensity) matches no known class are
        silently assigned class 0 (background).

    Raises:
        ValueError: If the input is neither (H, W), (H, W, 1) nor (H, W, 3).
    """
    mask = np.asarray(rgb_mask_np)

    # The rank must be inspected BEFORE unpacking the shape: a 2-D mask has no
    # channel axis, so `h, w, c = mask.shape` would raise and make the
    # grayscale fallback unreachable.
    is_grayscale = mask.ndim == 2 or (mask.ndim == 3 and mask.shape[-1] == 1)
    if is_grayscale:
        print(f"Warning: Expected RGB mask (H, W, 3) but got grayscale shape {mask.shape}. Trying to process.")
        grey_mask = mask if mask.ndim == 2 else mask.squeeze(-1)
        class_id_mask = np.zeros(grey_mask.shape, dtype=np.int64)  # default: background
        class_id_mask[grey_mask == 128] = 1  # Road
        class_id_mask[grey_mask == 255] = 2  # Building
        return class_id_mask

    if mask.ndim != 3:
        raise ValueError(f"Input mask must have shape (H, W, 3), but got shape {mask.shape}")
    if mask.shape[-1] != 3:
        raise ValueError(
            f"Input mask must have 3 channels (RGB), but got {mask.shape[-1]} with shape {mask.shape}"
        )

    # Start from background everywhere, then paint each known colour's class.
    class_id_mask = np.zeros(mask.shape[:2], dtype=np.int64)
    for color, class_id in COLOR_TO_ID_MAPPING.items():
        matches = np.all(mask == np.array(color, dtype=mask.dtype), axis=-1)
        class_id_mask[matches] = class_id

    return class_id_mask


class LEVIRCCDataset(Dataset):
    """Bi-temporal image pairs with tokenised captions and optional masks.

    Each item is a dict with the normalised 'imgA' / 'imgB' tensors, one
    randomly chosen caption ('token', 'token_len'), every caption of the pair
    ('token_all', 'token_all_len'), the sample 'name' and, when segmentation
    loading is enabled, a 'seg_mask' tensor of class IDs.
    """

    def __init__(self, data_folder, processed_data_dir, split,
                 load_segmentation=True,
                 max_length=41,
                 vocab_file='vocab.json',
                 allow_unk=True,
                 max_iters=None):
        """
        Args:
            data_folder (str): Root LEVIR-MCI dataset folder (the one containing 'images/').
            processed_data_dir (str): Folder holding the preprocessing output
                (vocabulary file, '<split>.txt' lists and the 'tokens/' folder).
            split (str): One of 'train', 'val' or 'test'.
            load_segmentation (bool): Also load the label image and return it as class IDs.
            max_length (int): Expected length of every stored caption.
            vocab_file (str): Vocabulary JSON file name inside processed_data_dir.
            allow_unk (bool): Stored on the instance; not consulted by this class.
            max_iters (int, optional): If positive, the dataset reports this many
                items per epoch and cycles through the available samples.

        Raises:
            FileNotFoundError: If the vocabulary or split file is missing.
            ValueError: If the split file is empty or no sample has all its files.
        """
        self.data_folder = data_folder
        self.processed_data_dir = processed_data_dir
        self.split = split
        self.load_segmentation = load_segmentation
        self.max_length = max_length
        self.allow_unk = allow_unk

        # Standard ImageNet statistics, rescaled to the 0-255 pixel range.
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        self.mean = np.array(imagenet_mean, dtype=np.float32) * 255
        self.std = np.array(imagenet_std, dtype=np.float32) * 255

        assert self.split in {'train', 'val', 'test'}

        # ---- Vocabulary ----
        vocab_path = os.path.join(self.processed_data_dir, vocab_file)
        try:
            with open(vocab_path, 'r') as vocab_fh:
                self.word_vocab = json.load(vocab_fh)
        except FileNotFoundError:
            raise FileNotFoundError(f"Vocabulary file not found at {vocab_path}. Run preprocessing script first.")
        self.idx_to_word = {index: word for word, index in self.word_vocab.items()}
        self.pad_idx = self.word_vocab.get('<NULL>', 0)
        self.start_idx = self.word_vocab.get('<START>', 2)
        self.end_idx = self.word_vocab.get('<END>', 3)

        # ---- Sample names listed for this split ----
        split_file_path = os.path.join(self.processed_data_dir, f'{split}.txt')
        try:
            with open(split_file_path, 'r') as split_fh:
                stripped_lines = (raw_line.strip() for raw_line in split_fh)
                self.img_ids = [entry for entry in stripped_lines if entry]
        except FileNotFoundError:
            raise FileNotFoundError(f"Split file not found at {split_file_path}. Run preprocessing script first.")

        if not self.img_ids:
            raise ValueError(f"No image IDs found in split file: {split_file_path}")

        # ---- Resolve per-sample paths, keeping only complete samples ----
        image_root = os.path.join(self.data_folder, 'images', self.split)
        token_root = os.path.join(self.processed_data_dir, 'tokens')
        label_dir_name = 'label'

        self.files = []
        skipped = 0
        for base_name in self.img_ids:
            png_name = f"{base_name}.png"
            record = {
                "name": base_name,
                "imgA": os.path.join(image_root, 'A', png_name),
                "imgB": os.path.join(image_root, 'B', png_name),
                "token": os.path.join(token_root, f"{base_name}.json"),
            }
            required_paths = [record["imgA"], record["imgB"], record["token"]]
            if self.load_segmentation:
                record["seg_label"] = os.path.join(image_root, label_dir_name, png_name)
                required_paths.append(record["seg_label"])

            if all(os.path.exists(path) for path in required_paths):
                self.files.append(record)
            else:
                skipped += 1

        if skipped > 0:
            print(f"Warning: Skipped {skipped} entries due to missing files in split '{self.split}'.")

        if not self.files:
            raise ValueError(f"No valid file sets found for split '{self.split}'. Check file paths and preprocessing output.")

        # ---- Optional fixed epoch length ----
        self.max_iters = max_iters
        if max_iters is not None and max_iters > 0:
            if not self.files:
                raise ValueError(f"Cannot use max_iters > 0 when no valid files were loaded for split '{self.split}'.")
            # Repeat the list often enough that it holds at least max_iters entries.
            repeats = int(np.ceil(max_iters / len(self.files)))
            self.files = self.files * repeats

        print(f"Initialized LEVIRCCDataset for split '{self.split}' with {len(self.files)} items (after potential repeat for max_iters).")

    def __len__(self):
        """Return max_iters when it is a positive number, else the number of stored samples."""
        has_fixed_length = self.max_iters is not None and self.max_iters > 0
        return self.max_iters if has_fixed_length else len(self.files)

    def __getitem__(self, index):
        """Load and return the sample at `index` (wrapped modulo the file list) as a dict."""
        if not self.files:
            raise IndexError("Dataset is empty.")
        record = self.files[index % len(self.files)]

        # ---- Images ----
        try:
            raw_a = np.array(imread(record["imgA"]), dtype=np.uint8)
            raw_b = np.array(imread(record["imgB"]), dtype=np.uint8)
            both_rgb = (raw_a.ndim == 3 and raw_a.shape[-1] == 3
                        and raw_b.ndim == 3 and raw_b.shape[-1] == 3)
            if not both_rgb:
                raise ValueError(f"Image dimensions incorrect for {record['name']}. Expected (H, W, 3), got {raw_a.shape} and {raw_b.shape}")
        except Exception as e:
            print(f"Error loading images for {record['name']}: {e}")
            raise

        channel_mean = self.mean[:, None, None]
        channel_std = self.std[:, None, None]

        def to_normalised_chw(image_hwc):
            # HWC uint8 -> CHW float32, then per-channel standardisation.
            image_chw = np.asarray(image_hwc, np.float32).transpose(2, 0, 1)
            return (image_chw - channel_mean) / channel_std

        imgA = to_normalised_chw(raw_a)
        imgB = to_normalised_chw(raw_b)

        # ---- Segmentation label (optional) ----
        seg_mask = None
        if self.load_segmentation:
            try:
                label_image = np.array(imread(record["seg_label"]), dtype=np.uint8)
                class_ids = rgb_to_class_id_mask(label_image)
                seg_mask = torch.from_numpy(class_ids).long()
            except FileNotFoundError:
                print(f"Error loading segmentation label for {record['name']}. Path: {record['seg_label']}")
                raise
            except Exception as e:
                print(f"Error processing segmentation label for {record['name']}: {e}")
                raise

        # ---- Captions ----
        # Defaults used when the token file holds no captions: one all-padding row of length 0.
        token_all = np.array([[self.pad_idx] * self.max_length])
        token_all_len = np.array([0])
        caption_list = []

        try:
            with open(record["token"], 'r') as token_fh:
                caption_list = json.load(token_fh)

            if caption_list:
                for position, caption in enumerate(caption_list):
                    if len(caption) != self.max_length:
                        raise ValueError(f"Inconsistent caption length in {record['token']} for caption {position}. Expected {self.max_length}, got {len(caption)}.")
                token_all = np.array(caption_list, dtype=np.int64)
                # Caption length = number of non-padding tokens in the row.
                token_all_len = np.array(
                    [np.count_nonzero(row != self.pad_idx) for row in token_all],
                    dtype=np.int64,
                )
            else:
                print(f"Warning: Empty caption list found in file {record['token']} for {record['name']}. Using default padding.")
        except FileNotFoundError:
            print(f"Error loading token file for {record['name']}. Path: {record['token']}")
            raise
        except Exception as e:
            print(f"Error processing token file {record['token']}: {e}")
            raise

        # ---- Pick one caption at random (row 0 when there are none) ----
        chosen = randint(0, len(caption_list) - 1) if caption_list else 0
        token = token_all[chosen]
        token_len = token_all_len[chosen].item()

        # ---- Assemble the sample ----
        batch = {
            'imgA': torch.from_numpy(imgA).float(),
            'imgB': torch.from_numpy(imgB).float(),
            'token': torch.from_numpy(token).long(),
            'token_len': torch.tensor(token_len).long(),
            'token_all': torch.from_numpy(token_all).long(),
            'token_all_len': torch.from_numpy(token_all_len).long(),
            'name': record['name'],
        }
        if self.load_segmentation and seg_mask is not None:
            batch['seg_mask'] = seg_mask

        return batch

# =============================================================================
# --- model_arch.py content starts here ---
# =============================================================================

# --- Positional Encoding ---
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    The table is stored in the non-trainable buffer ``pe`` with shape
    (max_len, 1, d_model): even feature indices hold sines, odd indices hold
    cosines of the same angles.

    Args:
        d_model (int): Embedding dimension. Both even and odd values are supported.
        dropout (float): Dropout probability applied after adding the encoding.
        max_len (int): Largest sequence length the table covers.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # the cosine angles are trimmed to fit instead of raising a shape error.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # Shape: (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]. seq_len must
               not exceed the max_len given at construction.

        Returns:
            Tensor of the same shape as x: x plus the encoding of each position,
            passed through dropout.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# --- Image Encoder ---
class Encoder(nn.Module):
    """Convolutional image backbone loaded through ``torch.hub``.

    Args:
        encoder_type (str): "resnet50" or "efficientnet_b0".
        pretrained (bool): Forwarded to the hub entry point as ``pretrained``.
        freeze (bool): If True, ``requires_grad`` is switched off for every
            backbone parameter.

    Attributes:
        backbone: Feature extractor without the pooling / classification head.
        out_channels (int): Channel count of the feature map returned by ``forward``.

    Raises:
        ValueError: If ``encoder_type`` is unknown or the channel count cannot
            be determined. Any error raised while loading the model is printed
            and re-raised unchanged.
    """

    def __init__(self, encoder_type="resnet50", pretrained=True, freeze=True):
        super().__init__()
        self.encoder_type = encoder_type
        self.freeze = freeze
        self.out_channels = 0  # Will be set after loading backbone

        try:
            if encoder_type == "resnet50":
                resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=pretrained)
                # Read the channel count before dropping the head (2048 for ResNet50).
                self.out_channels = resnet.fc.in_features
                # Keep everything up to layer4, i.e. drop avgpool and fc.
                self.backbone = nn.Sequential(*list(resnet.children())[:-2])
            elif encoder_type == "efficientnet_b0":
                # EfficientNet entry points were added to torchvision in 0.11;
                # the v0.10.0 hubconf does not expose 'efficientnet_b0'.
                effnet = torch.hub.load('pytorch/vision:v0.11.0', 'efficientnet_b0', pretrained=pretrained)
                self.backbone = effnet.features  # Keep the feature extractor part
                # The classifier is (Dropout, Linear); the Linear input width is
                # the channel count of the last feature block (1280 for B0).
                self.out_channels = effnet.classifier[1].in_features
            else:
                raise ValueError(f"Unknown encoder type: {encoder_type}")

            if self.out_channels == 0:
                raise ValueError(f"Could not determine output channels for encoder type: {encoder_type}")

        except Exception as e:
            print(f"Error loading pretrained model '{encoder_type}': {e}")
            raise

        if freeze:
            print(f"Freezing weights for encoder: {encoder_type}")
            # NOTE(review): this only disables gradients. Normalisation layers in
            # the backbone still follow the module's train()/eval() mode, so their
            # running statistics may keep changing during training - confirm this
            # is intended.
            for param in self.backbone.parameters():
                param.requires_grad = False
        else:
            print(f"Training weights for encoder: {encoder_type}")

    def forward(self, img):
        """Return the backbone feature map for `img`.

        Args:
            img: Image batch, presumably (B, 3, H, W) - verify against caller.

        Returns:
            Feature tensor with ``out_channels`` channels. The original notes
            expect roughly (B, C_out, H/32, W/32) for both backbones (to be
            confirmed for EfficientNet).
        """
        return self.backbone(img)


# --- Attentive Feature Fusion ---
class AttentiveEncoder(nn.Module):
    """Fuse the two per-image feature maps into a single change feature map.

    Despite the name, fusion is purely convolutional: the two maps are
    concatenated along the channel axis, squeezed back to `encoder_dim`
    channels by a 1x1 convolution, then mixed spatially by a 3x3 convolution,
    each followed by ReLU. A cross-attention variant was considered as an
    alternative but is not implemented here.

    Args:
        encoder_dim (int): Channel count of each input map and of the output.
        n_layers (int): Accepted for interface compatibility; unused.
        heads (int): Accepted for interface compatibility; unused.
        dropout (float): Accepted for interface compatibility; unused.
    """

    def __init__(self, encoder_dim, n_layers=1, heads=8, dropout=0.1):
        super().__init__()
        self.encoder_dim = encoder_dim

        channel_squeeze = nn.Conv2d(encoder_dim * 2, encoder_dim, kernel_size=1, padding=0)
        local_mixing = nn.Conv2d(encoder_dim, encoder_dim, kernel_size=3, padding=1)
        self.combination = nn.Sequential(
            channel_squeeze,
            nn.ReLU(),
            local_mixing,
            nn.ReLU(),
        )

    def forward(self, feat1, feat2):
        """Return the fused features for a pair of (B, C, H, W) feature maps.

        Args:
            feat1: Features of the first image, shape (B, encoder_dim, H, W).
            feat2: Features of the second image, same shape as feat1.

        Returns:
            Tensor of shape (B, encoder_dim, H, W).
        """
        _, channels, _, _ = feat1.shape
        assert channels == self.encoder_dim, f"Input feature dim ({channels}) doesn't match encoder_dim ({self.encoder_dim})"

        stacked = torch.cat([feat1, feat2], dim=1)  # (B, 2*C, H, W)
        return self.combination(stacked)            # (B, C, H, W)

# --- Caption Decoder (Transformer-based) ---
class DecoderTransformer(nn.Module):
    """Transformer caption decoder that cross-attends to fused image features.

    Args:
        vocab_path (str): Path to the vocabulary JSON file (word -> index).
        embed_dim (int): Token embedding / decoder model dimension.
        encoder_dim_attentive (int): Channel count of the fused image features.
        n_layers (int): Number of transformer decoder layers.
        heads (int): Number of attention heads per layer.
        dropout (float): Dropout probability used throughout the decoder.
        ff_dim (int): Hidden width of each layer's feed-forward network.

    Raises:
        FileNotFoundError: If the vocabulary file does not exist.
    """

    def __init__(self, vocab_path, embed_dim, encoder_dim_attentive, n_layers=2, heads=8, dropout=0.1, ff_dim=2048):
        super().__init__()

        try:
            with open(vocab_path, 'r') as f:
                self.word_vocab = json.load(f)
        except FileNotFoundError:
            raise FileNotFoundError(f"Vocabulary file not found at {vocab_path}")
        self.vocab_size = len(self.word_vocab)
        self.pad_idx = self.word_vocab.get('<NULL>', 0)
        self.start_idx = self.word_vocab.get('<START>', 2)
        self.end_idx = self.word_vocab.get('<END>', 3)
        print(f"Decoder loaded vocabulary with size: {self.vocab_size}")

        self.embed_dim = embed_dim
        self.encoder_dim_attentive = encoder_dim_attentive

        self.embedding = nn.Embedding(self.vocab_size, embed_dim, padding_idx=self.pad_idx)
        self.pos_encoder = PositionalEncoding(embed_dim, dropout)

        # Projects the fused encoder features to the decoder's model dimension.
        self.encoder_to_decoder_proj = nn.Linear(encoder_dim_attentive, embed_dim)

        # batch_first=True so every tensor handed to the decoder is (B, seq, dim).
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, nhead=heads, dim_feedforward=ff_dim,
            dropout=dropout, batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.fc_out = nn.Linear(embed_dim, self.vocab_size)
        self.dropout = nn.Dropout(dropout)  # kept as an attribute; not used by forward/generate

    def _generate_square_subsequent_mask(self, sz, device):
        """Return the (sz, sz) boolean causal mask for self-attention.

        Entry [i, j] is True when position i must NOT attend to position j,
        i.e. for every future position j > i. A boolean mask is used so that
        it has the same dtype as the boolean key-padding mask; mixing a float
        attention mask with a boolean padding mask is deprecated in PyTorch.
        """
        positions = torch.arange(sz, device=device)
        return positions.unsqueeze(0) > positions.unsqueeze(1)

    def _project_memory(self, attentive_features):
        """Flatten (B, C_enc, H, W) features to (B, H*W, embed_dim) decoder memory."""
        memory = attentive_features.flatten(2).permute(0, 2, 1)  # (B, H*W, C_enc)
        return self.encoder_to_decoder_proj(memory)

    def _embed_with_positions(self, tokens):
        """Embed (B, L) token ids, scale by sqrt(embed_dim) and add positional encoding."""
        embedded = self.embedding(tokens) * math.sqrt(self.embed_dim)  # (B, L, embed_dim)
        # The positional encoder works on (L, B, embed_dim), so swap axes around the call.
        return self.pos_encoder(embedded.permute(1, 0, 2)).permute(1, 0, 2)

    def forward(self, attentive_features, captions):
        """Teacher-forced forward pass.

        Args:
            attentive_features: Fused image features, shape (B, C_enc, H, W).
            captions: Token ids, shape (B, max_len); padding uses ``pad_idx``.

        Returns:
            Logits of shape (B, max_len, vocab_size).
        """
        memory = self._project_memory(attentive_features)
        tgt = self._embed_with_positions(captions)

        # Causal mask blocks attention to future tokens; the padding mask
        # (True = ignore) blocks attention to padded caption positions.
        tgt_mask = self._generate_square_subsequent_mask(captions.size(1), captions.device)
        tgt_padding_mask = captions == self.pad_idx  # (B, max_len)

        decoder_output = self.transformer_decoder(
            tgt=tgt,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=None,  # image memory has no padded positions
        )  # (B, max_len, embed_dim)

        return self.fc_out(decoder_output)  # (B, max_len, vocab_size)

    @torch.no_grad()  # Ensure no gradients are computed during generation
    def generate_caption(self, attentive_features, max_length=41, beam_size=1):
        """Generate captions by greedy decoding.

        Note: this call switches the module to eval mode and leaves it there.

        Args:
            attentive_features: Fused image features, shape (B, C_enc, H, W).
            max_length (int): Length of the returned sequences, including <START>.
            beam_size (int): Must be 1; beam search is not implemented.

        Returns:
            LongTensor of shape (B, max_length). Each row starts with
            ``start_idx``; every position after the first ``end_idx`` (and any
            position never generated) holds ``pad_idx``.

        Raises:
            NotImplementedError: If ``beam_size`` is not 1.
        """
        if beam_size != 1:
            raise NotImplementedError("Beam search not yet implemented.")

        self.eval()
        batch_size = attentive_features.size(0)
        device = attentive_features.device

        memory = self._project_memory(attentive_features)

        # Every sequence starts with <START>; track which ones already emitted <END>.
        generated_caps = torch.full((batch_size, 1), self.start_idx, dtype=torch.long, device=device)
        completed_sequences = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(1, max_length):
            tgt = self._embed_with_positions(generated_caps)
            tgt_mask = self._generate_square_subsequent_mask(generated_caps.size(1), device)

            # No padding mask here: the whole prefix generated so far is attended.
            decoder_output = self.transformer_decoder(
                tgt=tgt,
                memory=memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None,
            )  # (B, current_len, embed_dim)

            # Greedy choice from the logits at the last position.
            predicted_idx = self.fc_out(decoder_output[:, -1, :]).argmax(dim=-1)  # (B,)
            # Finished sequences only receive padding from now on.
            predicted_idx = predicted_idx.masked_fill(completed_sequences, self.pad_idx)
            generated_caps = torch.cat([generated_caps, predicted_idx.unsqueeze(1)], dim=1)

            completed_sequences |= predicted_idx == self.end_idx
            if completed_sequences.all():
                break

        # Bring every row to exactly max_length tokens.
        final_len = generated_caps.size(1)
        if final_len < max_length:
            padding = torch.full((batch_size, max_length - final_len), self.pad_idx, dtype=torch.long, device=device)
            generated_caps = torch.cat([generated_caps, padding], dim=1)
        elif final_len > max_length:
            generated_caps = generated_caps[:, :max_length]

        # Overwrite everything strictly after the first <END> with padding.
        # `ends_before` counts the <END> tokens located before each position.
        is_end = (generated_caps == self.end_idx).long()
        ends_before = is_end.cumsum(dim=1) - is_end
        generated_caps = generated_caps.masked_fill(ends_before > 0, self.pad_idx)

        return generated_caps  # (B, max_length)


# --- Segmentation Head ---
# Simple example using ConvTranspose2d for upsampling
# Assumes attentive_features are H/32, W/32 relative to input image size
class SegmentationHead(nn.Module):
    """Transposed-convolution decoder producing per-pixel class logits.

    Five stride-2 ``ConvTranspose2d`` stages each double the spatial size
    (32x overall) and halve the channel count. The first three stages are
    followed by a 3x3 refinement convolution; a final 1x1 convolution maps
    to ``num_classes`` channels.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of output channels (segmentation classes).
        target_size: Stored as ``self.target_size``. The forward pass does not
            resize to it; the output size is simply 32x the input size.
    """

    def __init__(self, in_channels, num_classes, target_size=(256, 256)):
        super().__init__()
        self.target_size = target_size

        # Channel width after each upsampling stage (halved every stage),
        # e.g. 2048 -> 1024 -> 512 -> 256 -> 128 -> 64.
        width1 = in_channels // 2
        width2 = width1 // 2
        width3 = width2 // 2
        width4 = width3 // 2
        width5 = width4 // 2

        def upsample(c_in, c_out):
            # kernel 3 / stride 2 / padding 1 / output_padding 1 => exact 2x upsampling
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1)

        def refine(channels):
            # Shape-preserving 3x3 convolution applied after an upsampling step
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        # NOTE: attribute names and creation order are kept stable so that
        # checkpoints (state_dict keys) and seeded initialisation stay compatible.

        # Stage 1
        self.upconv1 = upsample(in_channels, width1)
        self.bn1 = nn.BatchNorm2d(width1)
        self.relu1 = nn.ReLU()
        self.conv1 = refine(width1)
        self.bn1_c = nn.BatchNorm2d(width1)
        self.relu1_c = nn.ReLU()

        # Stage 2
        self.upconv2 = upsample(width1, width2)
        self.bn2 = nn.BatchNorm2d(width2)
        self.relu2 = nn.ReLU()
        self.conv2 = refine(width2)
        self.bn2_c = nn.BatchNorm2d(width2)
        self.relu2_c = nn.ReLU()

        # Stage 3
        self.upconv3 = upsample(width2, width3)
        self.bn3 = nn.BatchNorm2d(width3)
        self.relu3 = nn.ReLU()
        self.conv3 = refine(width3)
        self.bn3_c = nn.BatchNorm2d(width3)
        self.relu3_c = nn.ReLU()

        # Stages 4 and 5: upsampling only, no refinement convolution
        self.upconv4 = upsample(width3, width4)
        self.bn4 = nn.BatchNorm2d(width4)
        self.relu4 = nn.ReLU()

        self.upconv5 = upsample(width4, width5)
        self.bn5 = nn.BatchNorm2d(width5)
        self.relu5 = nn.ReLU()

        # 1x1 projection to class logits
        self.final_conv = nn.Conv2d(width5, num_classes, kernel_size=1)

    def forward(self, x):
        """Upsample ``x`` 32x and return logits of shape (B, num_classes, 32*h, 32*w)."""
        pipeline = (
            (self.upconv1, self.bn1, self.relu1),
            (self.conv1, self.bn1_c, self.relu1_c),
            (self.upconv2, self.bn2, self.relu2),
            (self.conv2, self.bn2_c, self.relu2_c),
            (self.upconv3, self.bn3, self.relu3),
            (self.conv3, self.bn3_c, self.relu3_c),
            (self.upconv4, self.bn4, self.relu4),
            (self.upconv5, self.bn5, self.relu5),
        )
        for layer, norm, activation in pipeline:
            x = activation(norm(layer(x)))

        return self.final_conv(x)

# --- Main Multi-Task Model ---
class ChangeDetectionCaptioningModel(nn.Module):
    """Joint change-segmentation and change-captioning network.

    One shared encoder is applied to both images, an attentive encoder fuses
    the two feature maps, and the fused features feed both a segmentation head
    and a transformer caption decoder.
    """

    def __init__(self, args, vocab_path):
        """Build the encoder, fusion module, segmentation head and decoder.

        Args:
            args: Hyper-parameter namespace. Reads ``encoder_type``,
                ``encoder_load_random``, ``freeze_encoder``, ``attn_layers``,
                ``heads``, ``dropout``, ``image_size``, ``embed_dim``,
                ``decoder_layers`` and ``ff_dim`` here, and ``max_length``
                in ``forward``.
            vocab_path: Path handed to the caption decoder.
        """
        super().__init__()
        self.args = args

        # Shared image encoder, applied to both the 'before' and 'after' image.
        self.encoder = Encoder(
            encoder_type=args.encoder_type,
            pretrained=not args.encoder_load_random,
            freeze=args.freeze_encoder,
        )
        feature_channels = self.encoder.out_channels

        # Attentive fusion of the two feature maps. Its output is treated as
        # having the same channel count as the encoder output.
        self.attentive_encoder = AttentiveEncoder(
            encoder_dim=feature_channels,
            n_layers=args.attn_layers,
            heads=args.heads,
            dropout=args.dropout,
        )

        # Segmentation branch.
        self.segmentation_head = SegmentationHead(
            in_channels=feature_channels,
            num_classes=NUM_CLASSES,
            target_size=(args.image_size, args.image_size),
        )

        # Captioning branch.
        self.decoder = DecoderTransformer(
            vocab_path=vocab_path,
            embed_dim=args.embed_dim,
            encoder_dim_attentive=feature_channels,
            n_layers=args.decoder_layers,
            heads=args.heads,
            dropout=args.dropout,
            ff_dim=args.ff_dim,
        )

    def forward(self, imgA, imgB, captions=None, mode='train'):
        """Run both task branches on an image pair.

        Args:
            imgA (Tensor): Batch of 'before' images (B, C, H, W).
            imgB (Tensor): Batch of 'after' images (B, C, H, W).
            captions (Tensor, optional): Ground-truth captions (B, max_len);
                required when ``mode == 'train'``.
            mode (str): ``'train'`` or ``'eval'``.

        Returns:
            dict with:
            - 'seg_logits': segmentation logits (always present).
            - 'caption_logits': decoder logits (train mode only).
            - 'generated_captions': generated token indices (eval mode only).

        Raises:
            ValueError: If ``mode`` is unknown, or captions are missing in
                train mode. Both checks run after the encoder, fusion and
                segmentation passes.
        """
        # Encode both images with the same encoder, then fuse.
        features_before = self.encoder(imgA)
        features_after = self.encoder(imgB)
        fused = self.attentive_encoder(features_before, features_after)

        # The segmentation branch runs regardless of mode.
        outputs = {'seg_logits': self.segmentation_head(fused)}

        if mode == 'train':
            if captions is None:
                raise ValueError("Captions must be provided in training mode.")
            # Teacher-forced decoding against the ground-truth captions.
            outputs['caption_logits'] = self.decoder(fused, captions)
            return outputs

        if mode == 'eval':
            # Free-running generation up to args.max_length tokens.
            outputs['generated_captions'] = self.decoder.generate_caption(
                fused,
                max_length=self.args.max_length,
            )
            return outputs

        raise ValueError(f"Unknown mode: {mode}")


# =============================================================================
# --- preprocessing.py content starts here ---
# =============================================================================

# --- Special Tokens ---
# Reserved vocabulary entries. A token's position in the tuple below is its
# index, so the order must not change.
SPECIAL_TOKENS = {
    token: index
    for index, token in enumerate((
        '<NULL>',   # 0: padding
        '<UNK>',    # 1: unknown word
        '<START>',  # 2: start of sequence
        '<END>',    # 3: end of sequence
    ))
}

# --- Helper Functions (from preprocessing.py) ---
def build_vocab(sequences, min_token_count=1):
    """Create the token -> index mapping from pre-tokenized caption data.

    Args:
        sequences: Iterable of dicts, each with a ``'tokenized_captions'`` list
            of token lists (special tokens may already be present).
        min_token_count: Minimum number of occurrences a word needs in order
            to receive its own index.

    Returns:
        dict mapping every special token to its reserved index, followed by the
        qualifying words numbered consecutively in order of decreasing
        frequency (ties keep first-seen order).
    """
    # Tally all ordinary words; special tokens have reserved indices and are not counted.
    frequencies = defaultdict(int)
    for entry in sequences:
        for caption in entry['tokenized_captions']:
            for word in caption:
                if word in SPECIAL_TOKENS:
                    continue
                frequencies[word] += 1

    # Reserved entries come first, with their fixed indices.
    vocab = dict(SPECIAL_TOKENS)

    # Most frequent first; the sort is stable, so equal counts keep insertion order.
    ranked = sorted(frequencies.items(), key=lambda pair: -pair[1])
    for word, occurrences in ranked:
        if occurrences < min_token_count or word in vocab:
            continue
        # Next free index is always the current vocabulary size.
        vocab[word] = len(vocab)

    return vocab

def encode(seq_tokens, token_to_idx, allow_unk=False):
    """Encode a sequence of tokens into vocabulary indices.

    Args:
        seq_tokens: Iterable of token strings.
        token_to_idx: Mapping from token to integer index.
        allow_unk: If True, out-of-vocabulary tokens are mapped to the index of
            ``'<UNK>'``; if False, they raise ``KeyError``.

    Returns:
        list of indices, one per token. The only token ever dropped is a
        literal ``'<UNK>'`` when ``allow_unk`` is True and the vocabulary has
        no ``'<UNK>'`` entry; a warning is printed in that case.

    Raises:
        KeyError: If a token is out of vocabulary and ``allow_unk`` is False,
            or if ``allow_unk`` is True but the vocabulary has no ``'<UNK>'``
            entry to fall back on.
    """
    unk_idx = token_to_idx.get('<UNK>')
    seq_idx = []

    for token in seq_tokens:
        # Decide "known vs unknown" by membership, not by comparing the looked-up
        # index with unk_idx: a real vocabulary entry must never be mistaken for
        # an unknown word just because of the index it carries.
        if token in token_to_idx:
            seq_idx.append(token_to_idx[token])
            continue

        # --- token is out of vocabulary from here on ---
        if not allow_unk:
            if token == '<UNK>':
                raise KeyError('Token "<UNK>" itself is not in vocab (check threshold) and allow_unk=False')
            raise KeyError(f'Token "{token}" not in vocab and allow_unk=False')

        if unk_idx is not None:
            seq_idx.append(unk_idx)
        elif token == '<UNK>':
            # No <UNK> entry exists to map to; skip the token but make it visible.
            print("Warning: Encountered '<UNK>' token which is somehow not in the vocabulary.")
        else:
            raise KeyError(f'Unhandled case: Token "{token}" not found.')

    return seq_idx


def pad_sequence(sequence, max_length, pad_value=0):
    """Return a copy of ``sequence`` that is exactly ``max_length`` items long.

    Shorter lists are right-padded with ``pad_value``; longer lists are
    truncated from the right.
    """
    shortfall = max_length - len(sequence)
    if shortfall >= 0:
        return sequence + [pad_value] * shortfall
    # Too long: keep only the leading max_length items.
    return sequence[:max_length]

# --- Main Preprocessing Logic (as a function) ---
def run_preprocessing(args):
    """Build the vocabulary, encoded caption files and split lists for LEVIR_MCI.

    Reads ``LevirCCcaptions.json`` from ``DATASET_ROOT`` and writes into
    ``SAVE_OUTPUT_DIR``:
      * ``vocab.json`` - token -> index mapping,
      * ``tokens/<image base name>.json`` - padded, encoded captions per image,
      * ``train.txt`` / ``val.txt`` / ``test.txt`` - image base names per split.

    Entries with an unknown split or a missing A/B image are skipped with a
    warning.

    Args:
        args: Namespace providing ``dataset``, ``word_count_threshold`` and
            ``max_length``.

    Raises:
        ValueError: If ``args.dataset`` is not ``'LEVIR_MCI'``.
        SystemExit: With status 1 if the caption file is missing or has no
            ``'images'`` key.
    """
    if args.dataset != 'LEVIR_MCI':
        raise ValueError(f"Dataset '{args.dataset}' not supported by this script.")

    # Inputs live under the dataset root, outputs under the save directory.
    captions_json_path = os.path.join(DATASET_ROOT, 'LevirCCcaptions.json')
    images_root = os.path.join(DATASET_ROOT, 'images')
    tokens_dir = os.path.join(SAVE_OUTPUT_DIR, 'tokens')

    print(f"Using LEVIR_MCI dataset from: {DATASET_ROOT}")
    print(f"Saving processed data to: {SAVE_OUTPUT_DIR}")

    for directory in (SAVE_OUTPUT_DIR, tokens_dir):
        os.makedirs(directory, exist_ok=True)

    print('Loading captions from JSON...')
    try:
        with open(captions_json_path, 'r') as f:
            data = json.load(f)['images']
    except FileNotFoundError:
        print(f"Error: Caption file not found at {captions_json_path}")
        sys.exit(1)
    except KeyError:
        print(f"Error: JSON file {captions_json_path} does not have the expected 'images' key.")
        sys.exit(1)

    print('Processing pre-tokenized captions...')
    kept_entries = []
    names_by_split = {'train': [], 'val': [], 'test': []}

    for entry in data:
        filename = entry['filename']
        split = entry['filepath']  # train, val, or test
        imgid = entry['imgid']

        if split not in names_by_split:
            print(f"Warning: Unknown filepath '{split}' found for {filename}. Skipping.")
            continue

        # Wrap the dataset's own tokenization in <START>/<END>; keep raw text for reference.
        caption_pairs = [
            (sentence['raw'], ['<START>'] + sentence['tokens'] + ['<END>'])
            for sentence in entry['sentences']
        ]
        raw_captions = [raw for raw, _ in caption_pairs]
        tokenized_captions = [tokens for _, tokens in caption_pairs]

        # Both images of the pair must exist on disk, otherwise the entry is dropped.
        pair_paths = [os.path.join(images_root, split, side, filename) for side in ('A', 'B')]
        if not all(os.path.exists(path) for path in pair_paths):
            print(f"Warning: Image pair not found for {filename} in {split}. Skipping caption processing for this entry.")
            continue

        kept_entries.append({
            'filename': filename,
            'filepath': split,
            'imgid': imgid,
            'tokenized_captions': tokenized_captions,
            'raw_captions': raw_captions
        })
        # Split lists hold base names without the file extension.
        names_by_split[split].append(os.path.splitext(filename)[0])

    print(f"Processed captions for {len(kept_entries)} images.")

    print('Building vocabulary...')
    # Words seen fewer than args.word_count_threshold times get no index of their own.
    word_to_idx = build_vocab(kept_entries, args.word_count_threshold)
    print(f'Vocabulary size: {len(word_to_idx)}')

    vocab_json_path = os.path.join(SAVE_OUTPUT_DIR, 'vocab.json')
    print(f'Saving vocabulary to {vocab_json_path}')
    with open(vocab_json_path, 'w') as f:
        json.dump(word_to_idx, f, indent=4)

    print('Encoding and padding captions...')
    pad_idx = SPECIAL_TOKENS['<NULL>']
    for item in kept_entries:
        base_name = os.path.splitext(item['filename'])[0]
        token_file_path = os.path.join(tokens_dir, f'{base_name}.json')

        # Out-of-vocabulary words become <UNK>; each caption is padded/truncated to max_length.
        encoded_captions = [
            pad_sequence(
                encode(tokens, word_to_idx, allow_unk=True),
                args.max_length,
                pad_value=pad_idx,
            )
            for tokens in item['tokenized_captions']
        ]

        with open(token_file_path, 'w') as f:
            json.dump(encoded_captions, f)

    # One text file per split, one image base name per line.
    for split, base_names in names_by_split.items():
        split_file_path = os.path.join(SAVE_OUTPUT_DIR, f'{split}.txt')
        print(f"Saving {split} image list to {split_file_path} ({len(base_names)} images)")
        with open(split_file_path, 'w') as f:
            f.writelines(f"{name}\n" for name in base_names)

    print('Preprocessing finished.')


# =============================================================================
# --- training.py content starts here ---
# =============================================================================

# Helper for per-class IoU; validation accumulates these values across batches for mIoU.
def calculate_iou(pred, target, num_classes):
    """Calculate Intersection over Union (IoU) per class.

    Args:
        pred: Logits with the class dimension at index 1, e.g. (B, C, H, W).
        target: Integer class labels with as many elements as ``pred`` has
            after removing the class dimension, e.g. (B, H, W).
        num_classes: Number of classes to evaluate (ids ``0 .. num_classes-1``).

    Returns:
        np.ndarray of length ``num_classes`` with the IoU of each class. A
        class that appears in neither prediction nor target (empty union) is
        reported as 0.0, not NaN or 1.0, so it pulls a mean over classes down.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed or sliced masks.
    pred_labels = torch.argmax(pred, dim=1).reshape(-1)
    target_labels = target.reshape(-1)

    ious = []
    for cls in range(num_classes):
        pred_mask = pred_labels == cls
        target_mask = target_labels == cls
        intersection = (pred_mask & target_mask).sum().item()
        union = pred_mask.sum().item() + target_mask.sum().item() - intersection
        # Empty union: class absent everywhere; report 0.0 by convention.
        ious.append(intersection / union if union else 0.0)
    return np.array(ious)


class Trainer:
    def __init__(self, args):
        """Set up device, data, model, optimizer and losses, then resume if possible.

        Args:
            args: Hyper-parameter namespace. Reads ``max_length``,
                ``batch_size``, ``num_workers`` and ``learning_rate`` here and
                is passed on to ``ChangeDetectionCaptioningModel``.

        Raises:
            ValueError: If the train or validation dataset is empty.

        If ``checkpoint_latest.pth`` exists in the run directory, model and
        optimizer state, the epoch counter and the best metrics are restored
        from it. Any failure while loading is reported and training starts
        from epoch 0.
        """
        self.args = args

        # ---- Device Setup ----
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.num_gpus = torch.cuda.device_count()
            print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
            if self.num_gpus > 1:
                print(f"Detected {self.num_gpus} GPUs. Will use DataParallel.\n")
            else:
                print("Only one GPU detected. DataParallel will not be used.\n")
        else:
            self.device = torch.device("cpu")
            self.num_gpus = 0
            print("Using CPU\n")

        # ---- Create Directories ----
        os.makedirs(SAVE_OUTPUT_DIR, exist_ok=True)
        # Checkpoints and logs go into a sub-directory of the main save dir.
        # Note: Kaggle's /kaggle/working is ephemeral per session unless saved,
        # but files saved here are included when you "Save & Run All" or commit.
        self.run_save_dir = os.path.join(SAVE_OUTPUT_DIR, 'training_output')
        os.makedirs(self.run_save_dir, exist_ok=True)
        print(f"Checkpoints and logs will be saved in: {self.run_save_dir}")

        # Path of the checkpoint used for resuming
        self.latest_checkpoint_path = os.path.join(self.run_save_dir, 'checkpoint_latest.pth')

        # ---- Dataset & Dataloaders ----
        print("Loading Datasets...")
        self.train_dataset = LEVIRCCDataset(
            data_folder=DATASET_ROOT,  # raw dataset (Kaggle input)
            processed_data_dir=SAVE_OUTPUT_DIR,  # preprocessing outputs (Kaggle working)
            split='train',
            load_segmentation=True,
            max_length=args.max_length
        )
        self.val_dataset = LEVIRCCDataset(
            data_folder=DATASET_ROOT,
            processed_data_dir=SAVE_OUTPUT_DIR,
            split='val',
            load_segmentation=True,
            max_length=args.max_length
        )

        if len(self.train_dataset) == 0 or len(self.val_dataset) == 0:
            raise ValueError("One or both datasets are empty. Check paths and preprocessing.")

        # With DataParallel each GPU receives batch_size // num_gpus samples, so
        # batch_size should be divisible by the number of GPUs.
        # drop_last=True on the training loader avoids a small final batch
        # (helps if batch norm is sensitive to batch size).
        self.train_loader = DataLoader(
            self.train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True, drop_last=True
        )
        self.val_loader = DataLoader(
            self.val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True
        )

        # ---- Model Initialization ----
        print("Initializing Model...")
        vocab_path = os.path.join(SAVE_OUTPUT_DIR, 'vocab.json')
        self.model = ChangeDetectionCaptioningModel(args, vocab_path)

        # Wrap with DataParallel if multiple GPUs are available. The wrapped
        # model's state_dict keys then carry a 'module.' prefix.
        if self.num_gpus > 1:
            print(f"Wrapping model with DataParallel across {self.num_gpus} GPUs.\n")
            self.model = nn.DataParallel(self.model)

        # Move model to device (primary GPU if DataParallel is used)
        self.model.to(self.device)

        # ---- Optimizer ----
        # Only parameters that require gradients are optimised.
        self.optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=args.learning_rate
        )

        # ---- Loss Functions ----
        # Captioning loss ignores padding positions.
        pad_idx = self.train_dataset.pad_idx
        self.caption_criterion = nn.CrossEntropyLoss(ignore_index=pad_idx).to(self.device)
        # Plain cross-entropy for multi-class segmentation.
        self.segmentation_criterion = nn.CrossEntropyLoss().to(self.device)

        # ---- BLEU Score Smoothing ----
        self.smooth_fn = SmoothingFunction().method1

        # ---- Track Best Metrics ----
        self.best_bleu = 0.0
        self.best_miou = 0.0
        self.start_epoch = 0

        # ---- Resume from Checkpoint ----
        if os.path.exists(self.latest_checkpoint_path):
            print(f"Found checkpoint at {self.latest_checkpoint_path}. Resuming training...")
            try:
                # map_location makes the checkpoint loadable regardless of the device it was saved on
                checkpoint = torch.load(self.latest_checkpoint_path, map_location=self.device)

                # Reconcile the 'module.' key prefix between checkpoint and current model.
                # The decision is based on whether the *current* model is wrapped, so a
                # DataParallel checkpoint also loads on a single GPU or on CPU.
                model_state_dict = checkpoint['model_state_dict']
                prefix = 'module.'
                model_is_wrapped = isinstance(self.model, nn.DataParallel)
                ckpt_is_wrapped = next(iter(model_state_dict), '').startswith(prefix)
                if model_is_wrapped and not ckpt_is_wrapped:
                    model_state_dict = {prefix + k: v for k, v in model_state_dict.items()}
                elif ckpt_is_wrapped and not model_is_wrapped:
                    # Strip only the leading prefix; 'module.' elsewhere in a key is left alone.
                    model_state_dict = {
                        (k[len(prefix):] if k.startswith(prefix) else k): v
                        for k, v in model_state_dict.items()
                    }

                self.model.load_state_dict(model_state_dict)

                # Load optimizer state
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

                # Resume epoch
                self.start_epoch = checkpoint['epoch']
                print(f"Resuming from epoch {self.start_epoch + 1}")

                # Best metrics so far; .get() keeps older checkpoints loadable
                self.best_bleu = checkpoint.get('best_bleu', 0.0)
                self.best_miou = checkpoint.get('best_miou', 0.0)
                print(f"Loaded best BLEU: {self.best_bleu:.4f}, best mIoU: {self.best_miou:.4f}")

            except Exception as e:
                # Deliberate best-effort: a broken checkpoint must not abort the run.
                # NOTE: if the failure happens after load_state_dict succeeded, the
                # model keeps the checkpoint weights while the counters are reset.
                print(f"Error loading checkpoint {self.latest_checkpoint_path}: {e}")
                print("Starting training from scratch.")
                self.start_epoch = 0
                self.best_bleu = 0.0
                self.best_miou = 0.0
        else:
            print("No checkpoint found. Starting training from scratch.")


    def get_word_vocab(self):
        # Access vocab from the original module if using DataParallel
        if isinstance(self.model, nn.DataParallel):
             return self.model.module.decoder.word_vocab # Access decoder vocab via the original module
        else:
             return self.model.decoder.word_vocab # Access decoder vocab directly

    def train_epoch(self, epoch):
        """Run one optimisation pass over the training loader.

        For every batch: forward pass in ``'train'`` mode, captioning loss plus
        ``args.seg_loss_weight`` times the segmentation loss, backward pass,
        optional gradient clipping (``args.grad_clip > 0``) and optimizer step.

        Args:
            epoch: Zero-based epoch index (used for display only).

        Returns:
            float: Combined loss averaged over the batches of this epoch.

        Raises:
            ValueError: If the training loader yields no batches.
        """
        self.model.train()

        num_batches = len(self.train_loader)
        if num_batches == 0:
            # Fail with a clear message instead of a ZeroDivisionError when averaging.
            raise ValueError(
                "Training DataLoader yields no batches; with drop_last=True the "
                "training set must contain at least batch_size samples."
            )

        # The decoder's vocabulary size is constant, so look it up once
        # (through the DataParallel wrapper if there is one).
        core_model = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
        decoder_vocab_size = core_model.decoder.vocab_size

        total_cap_loss = 0
        total_seg_loss = 0
        total_combined_loss = 0

        # tqdm shows the per-batch losses passed to set_postfix()
        progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.args.epochs} [Train]")

        for batch in progress_bar:
            # Inputs go to the primary device; DataParallel scatters them from there.
            imgA = batch['imgA'].to(self.device)
            imgB = batch['imgB'].to(self.device)
            captions = batch['token'].to(self.device)  # target captions (B, max_len)
            seg_mask_gt = batch['seg_mask'].to(self.device)  # ground-truth masks (B, H, W)

            self.optimizer.zero_grad()

            # ---- Forward Pass ----
            outputs = self.model(imgA, imgB, captions, mode='train')
            seg_logits = outputs['seg_logits']  # (B, num_classes, H, W)
            caption_logits = outputs['caption_logits']  # (B, max_len, vocab_size)

            # ---- Loss Calculation ----
            # 1. Captioning: the logits at step t are scored against the token at
            #    step t+1, so drop the last logit step and the first target token.
            cap_loss = self.caption_criterion(
                caption_logits[:, :-1, :].reshape(-1, decoder_vocab_size),
                captions[:, 1:].reshape(-1)
            )

            # 2. Segmentation: logits (B, num_classes, H, W) vs. masks (B, H, W).
            #    No resizing is done here, so the spatial sizes must already match.
            seg_loss = self.segmentation_criterion(seg_logits, seg_mask_gt)

            # 3. Weighted sum of both task losses
            combined_loss = cap_loss + self.args.seg_loss_weight * seg_loss

            # ---- Backward Pass & Optimization ----
            combined_loss.backward()

            if self.args.grad_clip > 0:
                # A DataParallel wrapper exposes the same parameter tensors as the
                # wrapped module, so no special-casing is needed.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)

            self.optimizer.step()

            # ---- Accumulate and Display Loss ----
            cap_value = cap_loss.item()
            seg_value = seg_loss.item()
            combined_value = combined_loss.item()
            total_cap_loss += cap_value
            total_seg_loss += seg_value
            total_combined_loss += combined_value
            progress_bar.set_postfix({
                'CapL': f'{cap_value:.3f}',
                'SegL': f'{seg_value:.3f}',
                'CombL': f'{combined_value:.3f}'
            })

        # ---- Epoch Averages ----
        avg_cap_loss = total_cap_loss / num_batches
        avg_seg_loss = total_seg_loss / num_batches
        avg_combined_loss = total_combined_loss / num_batches
        print(f"Epoch {epoch+1} Train Avg Loss -> Combined: {avg_combined_loss:.4f} (Cap: {avg_cap_loss:.4f}, Seg: {avg_seg_loss:.4f})\n")
        return avg_combined_loss


    def validate_epoch(self, epoch):
        """Evaluate the model on the validation loader for one epoch.

        Runs the model with ``mode='eval'`` over ``self.val_loader``, sums the
        per-batch IoU values returned by ``calculate_iou`` and collects
        hypothesis/reference token lists for a corpus-level BLEU score.

        Args:
            epoch: Zero-based epoch index; only used for progress/log output.

        Returns:
            tuple: ``(bleu_score, mean_iou)``. ``bleu_score`` is 0.0 when no
            usable references were collected or when ``corpus_bleu`` raises;
            ``mean_iou`` is 0.0 when the loader produced no batches.
        """
        self.model.eval()  # Set the main model to evaluation mode

        references_corpus = []  # One entry per sample: list of reference token lists
        hypotheses_corpus = []  # One entry per sample: hypothesis token list
        total_iou = np.zeros(NUM_CLASSES)
        num_val_batches = 0

        # Vocabulary lookups come from the validation dataset object.
        word_vocab = self.val_dataset.word_vocab
        idx_to_word = self.val_dataset.idx_to_word
        special_indices = {idx for token, idx in word_vocab.items() if token in ('<NULL>', '<START>', '<END>')}
        end_idx = word_vocab.get('<END>')  # None if the vocabulary has no <END> entry

        def decode(token_ids):
            """Map token ids to words, stopping at the first <END> and skipping special tokens."""
            words = []
            for idx in token_ids:
                if idx == end_idx:
                    # Anything emitted after <END> is not part of the caption. This is a
                    # no-op when the sequence is already padded after <END>.
                    break
                if idx not in special_indices:
                    words.append(idx_to_word.get(idx, '<UNK>'))
            return words

        progress_bar = tqdm(self.val_loader, desc=f"Epoch {epoch+1}/{self.args.epochs} [Validate]")

        with torch.no_grad():
            for batch in progress_bar:
                # Only the model inputs need to live on the device. The reference
                # captions and the ground-truth mask are consumed on the CPU below,
                # so they are deliberately not transferred.
                imgA = batch['imgA'].to(self.device)
                imgB = batch['imgB'].to(self.device)

                # ---- Forward Pass (Evaluation Mode) ----
                outputs = self.model(imgA, imgB, mode='eval')
                seg_logits = outputs['seg_logits']
                generated_indices = outputs['generated_captions']

                # ---- Accumulate IoU ----
                # NOTE(review): assumes calculate_iou returns one value per class
                # (it is added to an array of length NUM_CLASSES) -- confirm.
                total_iou += calculate_iou(seg_logits.cpu(), batch['seg_mask'].cpu(), NUM_CLASSES)
                num_val_batches += 1

                # ---- Collect BLEU inputs ----
                # tolist() yields plain Python ints, which match the vocabulary keys.
                # token_all is indexed as [sample, reference, token].
                batch_hypotheses = generated_indices.tolist()
                batch_references = batch['token_all'].tolist()

                for hyp_ids, ref_id_sets in zip(batch_hypotheses, batch_references):
                    # Drop references that are empty once special tokens are removed.
                    item_references = [words for words in map(decode, ref_id_sets) if words]
                    if item_references:
                        # Hypothesis and references are appended together, so the two
                        # corpora always stay aligned and equally long.
                        references_corpus.append(item_references)
                        hypotheses_corpus.append(decode(hyp_ids))

        # ---- Calculate Final Metrics ----
        # mIoU: mean over classes of the per-batch average IoU.
        mean_iou = np.mean(total_iou / num_val_batches) if num_val_batches > 0 else 0.0

        # BLEU score
        bleu_score = 0.0
        if not references_corpus or not hypotheses_corpus:
            print("Warning: No valid references or hypotheses collected for BLEU score calculation.")
        else:
            try:
                # corpus_bleu takes, per sample, a list of reference token lists
                # and a single hypothesis token list.
                bleu_score = corpus_bleu(
                    references_corpus,
                    hypotheses_corpus,
                    smoothing_function=self.smooth_fn
                )
            except Exception as e:
                # Best effort: a metric failure must not abort training.
                print(f"Error calculating BLEU score: {e}")

        print(f"Epoch {epoch+1} Validation -> BLEU-4: {bleu_score:.4f}, mIoU: {mean_iou:.4f}\n")
        return bleu_score, mean_iou


    def run_training(self):
        """Run the train/validate loop and write a checkpoint after every epoch.

        Iterates from ``self.start_epoch`` up to ``self.args.epochs``. After each
        epoch it updates ``self.best_bleu`` / ``self.best_miou``, writes
        ``checkpoint_latest.pth`` into ``self.run_save_dir`` and, when BLEU
        improved, also writes ``checkpoint_best_bleu.pth``.

        Checkpoints are first written to a temporary file and then renamed over
        the target, so an interrupted session cannot leave a truncated checkpoint
        in place of the previous good one.
        """
        def save_checkpoint(payload, final_path):
            # Serialize next to the target, then swap it in with a single rename.
            tmp_path = final_path + '.tmp'
            torch.save(payload, tmp_path)
            os.replace(tmp_path, final_path)

        print("Starting Training...\n")

        # Start the training loop from the loaded epoch or 0
        for epoch in range(self.start_epoch, self.args.epochs):
            self.train_epoch(epoch)
            val_bleu, val_miou = self.validate_epoch(epoch)

            # ---- Track best metrics ----
            is_best_bleu = val_bleu > self.best_bleu
            if is_best_bleu:
                self.best_bleu = val_bleu
                print(f"*** New best BLEU score: {self.best_bleu:.4f} ***\n")

            is_best_miou = val_miou > self.best_miou
            if is_best_miou:
                self.best_miou = val_miou
                print(f"*** New best mIoU score: {self.best_miou:.4f} ***\n")

            # ---- Save Checkpoint ----
            # Unwrap DataParallel so the saved keys carry no 'module.' prefix.
            if isinstance(self.model, nn.DataParallel):
                model_state = self.model.module.state_dict()
            else:
                model_state = self.model.state_dict()

            checkpoint_data = {
                'epoch': epoch + 1,  # Number of completed epochs, i.e. the epoch to resume from
                'args': self.args,  # Saved for reproducibility
                'best_bleu': self.best_bleu,
                'best_miou': self.best_miou,
                'model_state_dict': model_state,
                'optimizer_state_dict': self.optimizer.state_dict(),
            }

            save_checkpoint(checkpoint_data, os.path.join(self.run_save_dir, 'checkpoint_latest.pth'))

            # Only BLEU is used to select the "best" checkpoint; mIoU is tracked
            # and reported but does not trigger a separate save.
            if is_best_bleu:
                best_checkpoint_path = os.path.join(self.run_save_dir, 'checkpoint_best_bleu.pth')
                save_checkpoint(checkpoint_data, best_checkpoint_path)
                print(f"Saved best BLEU checkpoint to {best_checkpoint_path}\n")

        print("Training finished.\n")
        print(f"Best Validation BLEU-4 achieved: {self.best_bleu:.4f}")
        print(f"Best Validation mIoU achieved: {self.best_miou:.4f}")


# =============================================================================
# --- Main Execution Block ---
# This block will run when you execute the notebook cell.
# It parses arguments (or uses defaults) and starts preprocessing/training.
# =============================================================================

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train the Multi-Task Change Detection and Captioning Model')

    # Dataset and output paths are not exposed as CLI options here; the script
    # uses the module-level DATASET_ROOT / SAVE_OUTPUT_DIR constants instead.

    # Every option is declared as (flag, add_argument keyword arguments) and
    # registered in this exact order, so --help output keeps the same layout.
    argument_table = [
        # --- Dataset ---
        ('--dataset', dict(type=str, default='LEVIR_MCI', help='The name of the dataset')),
        ('--max_length', dict(type=int, default=41, help='Maximum caption length used during preprocessing')),
        ('--image_size', dict(type=int, default=256, help='Input image size (H and W)')),
        ('--word_count_threshold', dict(default=5, type=int,
                                        help='Minimum word count to include in vocabulary (for preprocessing)')),

        # --- Model Architecture ---
        ('--encoder_type', dict(type=str, default='resnet50', choices=['resnet50', 'efficientnet_b0'],
                                help='Image encoder backbone')),
        ('--encoder_load_random', dict(action='store_true', help='Initialize encoder with random weights')),
        ('--freeze_encoder', dict(action='store_true', help='Freeze weights of the image encoder backbone')),
        ('--embed_dim', dict(type=int, default=512, help='Embedding dimension for decoder')),
        ('--ff_dim', dict(type=int, default=2048, help='Feed-forward dimension in Transformer layers')),
        ('--attn_layers', dict(type=int, default=1,
                               help='Number of layers in attentive fusion (if using MHA version)')),
        ('--decoder_layers', dict(type=int, default=2, help='Number of layers in transformer decoder')),
        ('--heads', dict(type=int, default=8, help='Number of attention heads')),
        ('--dropout', dict(type=float, default=0.1, help='Dropout probability')),

        # --- Training ---
        ('--epochs', dict(type=int, default=30, help='Number of training epochs')),
        # With DataParallel the batch size should be divisible by the GPU count;
        # lower it if the GPUs run out of memory.
        ('--batch_size', dict(type=int, default=16,
                              help='Training batch size (reduce if OOM). Should be divisible by num_gpus.')),
        # Tune to the CPU cores available in the Kaggle session.
        ('--num_workers', dict(type=int, default=4, help='Number of dataloader workers (adjust based on system)')),
        ('--learning_rate', dict(type=float, default=1e-4, help='Learning rate for the optimizer')),
        ('--grad_clip', dict(type=float, default=5.0, help='Gradient clipping value (0 for no clipping)')),
        ('--seg_loss_weight', dict(type=float, default=1.0, help='Weight for the segmentation loss term')),

        # --- Execution Mode ---
        ('--run_mode', dict(type=str, default='train', choices=['preprocess', 'train'],
                            help='Which part of the pipeline to run: preprocess or train.')),
    ]
    for flag, options in argument_table:
        parser.add_argument(flag, **options)

    # Inside a notebook, options are supplied as an explicit list, e.g.
    #   parser.parse_args(['--run_mode', 'preprocess'])
    #   parser.parse_args(['--run_mode', 'train', '--batch_size', '32', '--epochs', '50'])
    # An empty list selects every default, which means run_mode == 'train'.
    args = parser.parse_args([])

    print(f"Running in mode: {args.run_mode}")

    if args.run_mode == 'preprocess':
        run_preprocessing(args)
    elif args.run_mode == 'train':
        # Training needs the preprocessing artifacts; build them on demand when
        # any of the expected outputs is absent from SAVE_OUTPUT_DIR.
        required_outputs = ('vocab.json', 'train.txt', 'tokens')
        outputs_present = all(
            os.path.exists(os.path.join(SAVE_OUTPUT_DIR, name)) for name in required_outputs
        )
        if not outputs_present:
            print("Preprocessing output not found. Running preprocessing first.")
            run_preprocessing(args)

        trainer = Trainer(args)
        trainer.run_training()
    else:
        print(f"Unknown run_mode: {args.run_mode}")


Running in mode: train
Preprocessing output not found. Running preprocessing first.
Using LEVIR_MCI dataset from: /kaggle/input/levir-mci-dataset/LEVIR-MCI-dataset
Saving processed data to: /kaggle/working/
Loading captions from JSON...
Processing pre-tokenized captions...
Processed captions for 10077 images.
Building vocabulary...
Vocabulary size: 501
Saving vocabulary to /kaggle/working/vocab.json
Encoding and padding captions...
Saving train image list to /kaggle/working/train.txt (6815 images)
Saving val image list to /kaggle/working/val.txt (1333 images)
Saving test image list to /kaggle/working/test.txt (1929 images)
Preprocessing finished.
Using CUDA GPU: Tesla T4
Detected 2 GPUs. Will use DataParallel.

Checkpoints and logs will be saved in: /kaggle/working/training_output
Loading Datasets...
Initialized LEVIRCCDataset for split 'train' with 6815 items (after potential repeat for max_iters).
Initialized LEVIRCCDataset for split 'val' with 1333 items (after potential repeat for 

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 218MB/s]


Training weights for encoder: resnet50
Decoder loaded vocabulary with size: 501
Wrapping model with DataParallel across 2 GPUs.

No checkpoint found. Starting training from scratch.
Starting Training...



Epoch 1/30 [Train]: 100%|██████████| 425/425 [04:01<00:00,  1.76it/s, CapL=1.759, SegL=0.357, CombL=2.116]


Epoch 1 Train Avg Loss -> Combined: 2.5386 (Cap: 2.0488, Seg: 0.4898)



Epoch 1/30 [Validate]: 100%|██████████| 84/84 [00:47<00:00,  1.78it/s]


Epoch 1 Validation -> BLEU-4: 0.2041, mIoU: 0.4123

*** New best BLEU score: 0.2041 ***

*** New best mIoU score: 0.4123 ***

Saved best BLEU checkpoint to /kaggle/working/training_output/checkpoint_best_bleu.pth



Epoch 2/30 [Train]: 100%|██████████| 425/425 [04:04<00:00,  1.74it/s, CapL=1.217, SegL=0.161, CombL=1.378]


Epoch 2 Train Avg Loss -> Combined: 1.6848 (Cap: 1.4543, Seg: 0.2305)



Epoch 2/30 [Validate]: 100%|██████████| 84/84 [00:40<00:00,  2.08it/s]


Epoch 2 Validation -> BLEU-4: 0.6292, mIoU: 0.5436

*** New best BLEU score: 0.6292 ***

*** New best mIoU score: 0.5436 ***

Saved best BLEU checkpoint to /kaggle/working/training_output/checkpoint_best_bleu.pth



Epoch 3/30 [Train]: 100%|██████████| 425/425 [04:05<00:00,  1.73it/s, CapL=1.825, SegL=0.188, CombL=2.014]


Epoch 3 Train Avg Loss -> Combined: 1.4697 (Cap: 1.3127, Seg: 0.1570)



Epoch 3/30 [Validate]: 100%|██████████| 84/84 [00:41<00:00,  2.04it/s]


Epoch 3 Validation -> BLEU-4: 0.6063, mIoU: 0.5715

*** New best mIoU score: 0.5715 ***



Epoch 4/30 [Train]: 100%|██████████| 425/425 [04:04<00:00,  1.74it/s, CapL=1.067, SegL=0.134, CombL=1.201]


Epoch 4 Train Avg Loss -> Combined: 1.3778 (Cap: 1.2530, Seg: 0.1248)



Epoch 4/30 [Validate]: 100%|██████████| 84/84 [00:41<00:00,  2.02it/s]


Epoch 4 Validation -> BLEU-4: 0.5381, mIoU: 0.5915

*** New best mIoU score: 0.5915 ***



Epoch 5/30 [Train]: 100%|██████████| 425/425 [04:04<00:00,  1.74it/s, CapL=0.982, SegL=0.088, CombL=1.070]


Epoch 5 Train Avg Loss -> Combined: 1.3194 (Cap: 1.2094, Seg: 0.1100)



Epoch 5/30 [Validate]: 100%|██████████| 84/84 [00:41<00:00,  2.03it/s]


Epoch 5 Validation -> BLEU-4: 0.5818, mIoU: 0.6067

*** New best mIoU score: 0.6067 ***



Epoch 6/30 [Train]:  59%|█████▊    | 249/425 [02:23<01:40,  1.75it/s, CapL=0.755, SegL=0.034, CombL=0.789]