# Imports

Crucial packages are imported here

## Packages

In [1]:
import numpy as np
import random
import os
import sys
import json
from tqdm.auto import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from collections import OrderedDict
import matplotlib.pyplot as plt
from clip import clip
import sys
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


The device is set here. By default it is 'cuda:1', falling back to 'cpu' when CUDA is unavailable.

In [2]:
# Prefer the second GPU ('cuda:1') when it exists. On a single-GPU machine 'cuda:1'
# is not a valid device, so fall back to 'cuda:0'; without CUDA use the CPU.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
device

'cuda:1'

## Functions

These are functions to automate common tasks

In [3]:
def load_json(path):  # Loads json files
    """Parse and return the JSON document stored at *path*."""
    # The context manager closes the handle; no explicit close() is needed.
    with open(path) as handle:
        return json.load(handle)
def save_json(file, path):  # Saves json files
    """Serialize *file* as JSON into *path* (overwriting it), then print a confirmation."""
    with open(path, 'w') as out_handle:
        json.dump(file, out_handle)
    print("Saved Successfully")
def rem_print(word, width=250):  # Allows for printed statements to be overwritten
    """
    Print *word* padded with spaces to *width* characters and end the line with a
    carriage return instead of a newline, so the next call overwrites it in place.

    Args:
        word (str): text to print; text longer than *width* is printed unpadded.
        width (int): total printed width before the carriage return (default 250).
    """
    # str.ljust pads in one step (the old char-by-char loop was quadratic).
    print(word.ljust(width), end='\r')

# Captioning

This section deals with tasks involving captioning only

## Packages

In [4]:
from torch.utils import data
from torch.nn.utils.rnn import pack_padded_sequence
import time
sys.path.append('/home/guest/Documents/Siraj TM/RSCaMa')
from utils_tool.utils import *
sys.path.append('/home/guest/Documents/Siraj TM/RSCaMa')
from model.model_decoder import DecoderTransformer
from model.model_encoder_attMamba import EnhancedEncoder
from model.model_encoder_attMamba import CrossAttentiveEncoder
import torchvision.transforms.functional as TF
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.layers import drop_path, trunc_normal_



## Functions

In [5]:
def manual_preprocess(images, size=224):
    """
    Manual preprocessing function that replicates:
    Resize(bicubic, shorter side -> size) -> CenterCrop(size) -> ToTensor -> Normalize
    with the CLIP mean/std (no RGB conversion is performed).

    Args:
        images: a single image or a batch of images.
            - Single image: PIL Image, 3-D tensor (C, H, W), or numpy array
              (H, W) / (H, W, C) accepted by ``Image.fromarray``.
            - Batch: list/tuple of single images, 4-D tensor (B, C, H, W), or
              4-D numpy array (B, H, W, C). Channel-first numpy batches are NOT
              supported, because each element is passed to ``Image.fromarray``.
        size: target side length for the resize and the square crop (default: 224).

    Returns:
        torch.Tensor: preprocessed image tensor(s)
            - Single image: (C, size, size)
            - Batch: (B, C, size, size)

    Raises:
        ValueError: if a tensor image is not 3-D (C, H, W).
    """
    # CLIP normalization constants, built once per call rather than once per image.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1)

    def process_single_image(img):
        # Convert tensors / numpy arrays to PIL so resizing matches the PIL-based pipeline.
        if isinstance(img, torch.Tensor):
            if img.dim() != 3:
                raise ValueError(f"Unexpected tensor dimensions: {img.dim()}")
            img = TF.to_pil_image(img)
        elif isinstance(img, np.ndarray):
            img = Image.fromarray(img)

        # 1. Bicubic resize so the shorter side equals `size` (aspect ratio kept).
        w, h = img.size
        if w < h:
            new_w, new_h = size, int(size * h / w)
        else:
            new_w, new_h = int(size * w / h), size
        img = img.resize((new_w, new_h), Image.BICUBIC)

        # 2. Center crop, 3. convert to a [0, 1] tensor, 4. normalize.
        img = TF.center_crop(img, (size, size))
        tensor = TF.to_tensor(img)
        return (tensor - mean) / std

    # A list/tuple, or a 4-D tensor/array, is a batch iterated along its first axis.
    is_batch = (
        isinstance(images, (list, tuple))
        or (isinstance(images, torch.Tensor) and images.dim() == 4)
        or (isinstance(images, np.ndarray) and images.ndim == 4)
    )
    if is_batch:
        return torch.stack([process_single_image(img) for img in images])
    return process_single_image(images)

In [6]:
from imageio import imread
def _levir_paths(name, root='data/LEVIR-MCI-dataset/images'):
    """Return the (A, B, label) file paths for a LEVIR-MCI file name such as 'train_000001.png'."""
    # The split directory is the file-name prefix before the first '_'.
    split = name.split('_')[0]
    return tuple(f'{root}/{split}/{sub}/{name}' for sub in ('A', 'B', 'label'))


def _load_levir_sample(name, root):
    """Open one sample: (RGB image A, RGB image B, label as a numpy array)."""
    path_a, path_b, path_label = _levir_paths(name, root)
    with Image.open(path_a) as im_a, Image.open(path_b) as im_b, Image.open(path_label) as im_label:
        return im_a.convert('RGB'), im_b.convert('RGB'), np.asarray(im_label)


def get_image(path, preprocess, root='data/LEVIR-MCI-dataset/images'):
    """
    Load LEVIR-MCI image pair(s) and change-label mask(s) by file name and preprocess them.

    Args:
        path (str | list[str] | tuple[str]): one file name or several. The prefix before
            the first '_' selects the split directory.
        preprocess (callable): transform applied to every RGB image (e.g. CLIP's preprocess).
        root (str): dataset image root holding ``<split>/A``, ``<split>/B`` and ``<split>/label``.

    Returns:
        Single name: ``(imgA, imgB, seg)``.
            - ``imgA``/``imgB`` are the ``preprocess`` outputs.
            - ``seg`` is a bool tensor of shape (1, 224, 224): the first channel of the mask.
        Several names: ``(imgsA, imgsB, segs)``.
            - The preprocessed images are concatenated along a new leading batch axis.
            - ``segs`` is a float {0, 1} tensor of shape (B, 1, 224, 224).
        The mask is binarised, resized with ``manual_preprocess`` and thresholded at > 0
        after normalization. The label image is assumed to be (H, W, C) — TODO confirm
        for grayscale labels.

    Raises:
        ValueError: if ``path`` is an empty sequence.
    """
    names = [path] if isinstance(path, str) else list(path)
    if not names:
        raise ValueError("get_image() needs at least one file name")

    if len(names) == 1:
        imageA, imageB, imageseg = _load_levir_sample(names[0], root)
        mask = (torch.tensor(np.array(imageseg)) / 128 > 0).float().permute(2, 0, 1)
        seg_output = manual_preprocess(mask) > 0
        # Keep only the first channel (result stays a bool tensor).
        return preprocess(imageA), preprocess(imageB), seg_output[:1, :, :]

    imagesA, imagesB, imagesegs = [], [], []
    for name in names:
        imageA, imageB, imageseg = _load_levir_sample(name, root)
        imagesA.append(preprocess(imageA).unsqueeze(0))
        imagesB.append(preprocess(imageB).unsqueeze(0))
        imagesegs.append(imageseg)

    # np.stack first: building a tensor from a list of ndarrays is very slow.
    masks = (torch.from_numpy(np.stack(imagesegs)) / 128 > 0).float().permute(0, 3, 1, 2)
    seg_output = (manual_preprocess(masks) > 0).float()  # binary {0, 1} mask
    return torch.cat(imagesA, dim=0), torch.cat(imagesB, dim=0), seg_output[:, :1, :, :]

## Data

### Dataloader

In [7]:
import torch
from torch.utils.data import Dataset
from preprocess_data import encode
import json
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
#import cv2 as cv
from imageio import imread
from random import *
class LEVIRMCIDataset_Modified(Dataset):
    """
    A PyTorch Dataset over a LEVIR-MCI split (image pair, change label, captions).

    ``__init__`` eagerly loads every image pair and label of the split into
    ``self.files``. ``__getitem__`` reads the caption token file on demand.
    """

    def __init__(self, data_folder, list_path, split, preprocess=False, token_folder=None,
                 vocab_file=None, max_length=41, allow_unk=0, max_iters=None):
        """
        :param data_folder: folder where image files are stored (``<split>/A``, ``<split>/B``, ``<split>/label``)
        :param list_path: folder holding the ``train/val/test.txt`` name lists and the vocab json
            (concatenated directly, so it must end with a path separator)
        :param split: split, one of 'train', 'val' or 'test'
        :param preprocess: falsy to keep raw arrays (normalized in ``__getitem__``), or a transform
            passed to ``get_image`` so images are preprocessed at load time
        :param token_folder: folder where token files are stored (must end with a path separator)
        :param vocab_file: vocab file name (without ``.json``) inside ``list_path``
        :param max_length: the maximum length of each encoded caption
        :param allow_unk: 1 to allow tokens to contain unknown words when encoding
        :param max_iters: if given, cycle the name list so it holds exactly this many entries
        """
        self.mean = [0.39073*255,  0.38623*255, 0.32989*255]
        self.std = [0.15329*255,  0.14628*255, 0.13648*255]
        self.list_path = list_path
        self.split = split
        self.max_length = max_length
        self.preprocess = preprocess
        # Set unconditionally so __getitem__ never hits a missing attribute.
        self.allow_unk = allow_unk

        assert self.split in {'train', 'val', 'test'}
        with open(list_path + split + '.txt') as f:
            self.img_ids = [i_id.strip() for i_id in f]
        if vocab_file is not None:
            with open(list_path + vocab_file + '.json', 'r') as f:
                self.word_vocab = json.load(f)
        if max_iters is not None:
            self.img_ids = self._repeat_ids(self.img_ids, max_iters)

        self.files = [self._load_entry(data_folder, name, token_folder) for name in self.img_ids]

    @staticmethod
    def _repeat_ids(img_ids, max_iters):
        """Cycle ``img_ids`` so the returned list has exactly ``max_iters`` entries."""
        if not img_ids:
            return []
        n_repeat = int(np.ceil(max_iters / len(img_ids)))
        return (img_ids * n_repeat)[:max_iters]

    def _load_entry(self, data_folder, name, token_folder):
        """Load one sample's images/label and build its metadata dict."""
        if self.split == 'train':
            # Training entries may carry a '-<caption index>' suffix after the file name.
            base_name = name.split('-')[0]
            token_id = name.split('-')[-1] if '-' in name else None
        else:
            base_name = name
            token_id = None

        if self.preprocess:
            # Pass the bare file name: the '-<idx>' suffix is not part of the path.
            imgA, imgB, seg_label = get_image(base_name, preprocess=self.preprocess)
        else:
            split_dir = os.path.join(data_folder, self.split)
            imgA = imread(os.path.join(split_dir, 'A', base_name))
            imgB = imread(os.path.join(split_dir, 'B', base_name))
            seg_label = imread(os.path.join(split_dir, 'label', base_name))

        if token_folder is not None:
            token_file = token_folder + name.split('.')[0] + '.txt'
        else:
            token_file = None

        return {
            "imgA": imgA,
            "imgB": imgB,
            "seg_label": seg_label,
            "token": token_file,
            "token_id": token_id,
            "name": base_name,
        }

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """
        Return ``(imgA, imgB, seg_label, token_all, token_all_len, token, token_len, name)``.

        Without ``preprocess``:
            - images are converted to float32 (C, H, W) and normalized with ``self.mean``/``self.std``;
            - the label's first channel is remapped 255 -> 2 and 128 -> 1.
        ``token`` is the caption selected by ``token_id``, or a random one when there is no id.
        """
        datafiles = self.files[index]
        name = datafiles["name"]

        imgA = datafiles["imgA"]
        imgB = datafiles["imgB"]
        seg_label = datafiles["seg_label"]

        if not self.preprocess:
            # np.array always copies, so in-place normalization never touches the cache.
            imgA = np.array(imgA, np.float32).transpose(2, 0, 1)
            imgB = np.array(imgB, np.float32).transpose(2, 0, 1)
            # Copy: the remapping below is in place and must not mutate self.files.
            seg_label = seg_label.transpose(2, 0, 1)[0].copy()
            seg_label[seg_label == 255] = 2
            seg_label[seg_label == 128] = 1

            for i in range(len(self.mean)):
                imgA[i, :, :] -= self.mean[i]
                imgA[i, :, :] /= self.std[i]
                imgB[i, :, :] -= self.mean[i]
                imgB[i, :, :] /= self.std[i]

        if datafiles["token"] is not None:
            with open(datafiles["token"]) as f:
                caption_list = json.load(f)

            nochange_cap = ['<START>', 'the', 'scene', 'is', 'the', 'same', 'as', 'before', '<END>']
            # In training, a sample annotated with the canonical no-change caption uses it for every slot.
            force_nochange = self.split == 'train' and nochange_cap in caption_list

            token_all = np.zeros((len(caption_list), self.max_length), dtype=int)
            token_all_len = np.zeros((len(caption_list), 1), dtype=int)
            for j, tokens in enumerate(caption_list):
                if force_nochange:
                    tokens = nochange_cap
                tokens_encode = encode(tokens, self.word_vocab,
                                       allow_unk=self.allow_unk == 1)
                token_all[j, :len(tokens_encode)] = tokens_encode
                token_all_len[j] = len(tokens_encode)

            if datafiles["token_id"] is not None:
                caption_idx = int(datafiles["token_id"])
            else:
                caption_idx = randint(0, len(caption_list) - 1)
            token = token_all[caption_idx]
            token_len = token_all_len[caption_idx].item()
        else:
            token_all = np.zeros(1, dtype=int)
            token = np.zeros(1, dtype=int)
            token_len = np.zeros(1, dtype=int)
            token_all_len = np.zeros(1, dtype=int)
        return imgA, imgB, seg_label, token_all.copy(), token_all_len.copy(), token.copy(), np.array(token_len), name


In [8]:
import matplotlib.pyplot as plt

def plot_metrics_over_epochs(metrics_dict):
    """
    Show BLEU-1..4 (in percent) and validation time as functions of the epoch.

    Args:
        metrics_dict (dict): epoch -> per-epoch metrics dict. Each inner dict must contain
            'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4' and 'test_time' (the keys returned by
            model_validation); any other keys are ignored.

    Two figures are displayed: the four BLEU curves, then test_time on its own axis.
    """
    bleu_names = ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4']
    xs = sorted(metrics_dict.keys())

    def column(key):
        # Values of one metric, ordered by epoch.
        return [metrics_dict[ep][key] for ep in xs]

    bleu_columns = {key: column(key) for key in bleu_names}
    time_column = column('test_time')

    # Figure 1: BLEU scores scaled to percent.
    plt.figure(figsize=(12, 7))
    for key in bleu_names:
        plt.plot(xs, np.array(bleu_columns[key]) * 100, marker='o', label=key)
    plt.xlabel('Epoch')
    plt.ylabel('Score (%)')
    plt.title('Evaluation Metrics over Epochs')
    plt.legend()
    plt.grid(True)
    plt.show()

    # Figure 2: validation wall-clock time.
    plt.figure(figsize=(8, 4))
    plt.plot(xs, time_column, marker='o', color='gray', label='test_time')
    plt.xlabel('Epoch')
    plt.ylabel('Test Time (s)')
    plt.title('Test Time over Epochs')
    plt.legend()
    plt.grid(True)
    plt.show()

In [9]:

def model_validation(encoder, encoder_trans, decoder, dataloader):
    """
    Greedily caption every sample in ``dataloader`` and score the captions with BLEU-1..4.

    The three modules are switched to eval mode for the pass, so dropout is disabled.
    Afterwards each is restored to its previous train/eval mode, even if an error occurs.

    Relies on notebook globals:
        - ``device``: target device for the inputs;
        - ``word_vocab``: must contain '<START>', '<END>' and '<NULL>';
        - ``Bleu``: scorer whose ``compute_score(refs, hypos)`` returns a 2-tuple whose
          first item is the list of corpus scores.

    Assumes batch_size == 1: only the first sample's reference captions in each batch
    are read. The val/test loaders in this notebook use batch_size=1.

    Args:
        encoder: called as ``encoder(imgA, imgB, mask)``; returns ``(feat1, feat2)``.
        encoder_trans: called as ``encoder_trans(feat1, feat2)``.
        decoder: exposes ``sample(feat, k=1)``, returning the predicted token ids.
        dataloader: yields the 8-tuples produced by ``LEVIRMCIDataset_Modified``.

    Returns:
        dict: 'Bleu_1'..'Bleu_4' as fractions (callers multiply by 100 for %), plus
        'test_time', the wall-clock seconds spent captioning.
    """
    test_start_time = time.time()
    references = []  # per sample: list of reference token-id lists
    hypotheses = []  # per sample: predicted token-id list
    # Special tokens are stripped from both references and predictions before scoring.
    except_tokens = {word_vocab['<START>'], word_vocab['<END>'], word_vocab['<NULL>']}

    modules = (encoder, encoder_trans, decoder)
    was_training = [m.training for m in modules]
    for m in modules:
        m.eval()

    print("Validation.....\n")
    try:
        with torch.no_grad():
            for batch_data in dataloader:
                imgA, imgB, seg_label, token_all, _, _, _, _ = batch_data

                feat1, feat2 = encoder(
                    imgA.to(device),
                    imgB.to(device),
                    (seg_label > 0).int().to(device),
                )
                feat = encoder_trans(feat1, feat2)
                seq = decoder.sample(feat, k=1)

                # token_all.tolist()[0] -> the reference captions of the (only) sample in the batch.
                img_tokens = [[w for w in c if w not in except_tokens] for c in token_all.tolist()[0]]
                references.append(img_tokens)
                hypotheses.append([w for w in seq if w not in except_tokens])
    finally:
        for m, mode in zip(modules, was_training):
            m.train(mode)

    test_time = time.time() - test_start_time

    # The scorer takes space-joined id strings: one hypothesis and a list of references per sample.
    hypo = [[' '.join(str(w) for w in hyp)] for hyp in hypotheses]
    ref = [[' '.join(str(w) for w in r) for r in refs] for refs in references]

    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"])
    ]
    score_dict = {}
    for scorer, method_i in scorers:
        score_i, _ = scorer.compute_score(ref, hypo)
        if not isinstance(score_i, list):
            score_i = [score_i]
        if not isinstance(method_i, list):
            method_i = [method_i]
        score_dict.update(zip(method_i, score_i))

    return {
        'Bleu_1': score_dict['Bleu_1'],
        'Bleu_2': score_dict['Bleu_2'],
        'Bleu_3': score_dict['Bleu_3'],
        'Bleu_4': score_dict['Bleu_4'],
        'test_time': test_time,
    }
    

In [10]:
def training_history_plot(hist, save_path='data/Pre-Trained Models/Finetuning/training_history.4.2.1.json'):
    """
    Plot training loss and top-5 accuracy against iteration, save the figure as PNG and show it.

    Args:
        hist: array-like of shape (N, >=3), laid out as filled by the training loop:
            column 0 = batch time (s), column 1 = loss, column 2 = top-5 accuracy.
            Rows whose batch time is 0 are treated as unused pre-allocated rows and skipped.
        save_path: path of the history JSON. The figure is written to the same path with
            '.json' replaced by '.png'.
    """
    import matplotlib.pyplot as plt
    hist = np.asarray(hist)
    # The training loop pre-allocates `hist` with zeros; keep only rows that were filled.
    hist = hist[hist[:, 0] > 0]
    # The x-axis is the iteration index. Column 0 holds batch time, not an iteration count.
    iterations = np.arange(len(hist))
    plt.figure(figsize=(10, 5))
    plt.plot(iterations, hist[:, 1], label='Loss')
    plt.plot(iterations, hist[:, 2], label='Top-5 Accuracy')
    plt.xlabel('Iteration')
    plt.ylabel('Value')
    plt.title('Training History')
    plt.legend()
    plt.savefig(save_path.replace('.json', '.png'))
    plt.show()

### Loading data

In [11]:
# Dataset locations: image root and caption-token folder.
Dataset_Path, token_path = (
    'data/LEVIR-MCI-dataset/images',
    'Change-Agent/Multi_change/data/LEVIR_MCI/tokens/',
)

# CLIP backbone. The returned `preprocess` is what the datasets below receive.
network = 'CLIP-ViT-B/32'
clip_model_type = network.replace("CLIP-", "")  # -> 'ViT-B/32'
clip_model, preprocess = clip.load(clip_model_type, device=device)

In [12]:
# Training split. Every sample is loaded (and preprocessed) when the dataset is built.
train_loader = data.DataLoader(
    LEVIRMCIDataset_Modified(
        data_folder=Dataset_Path,
        list_path='Change-Agent/Multi_change/data/LEVIR_MCI/',
        preprocess=preprocess,
        split='train',
        token_folder=token_path,
        vocab_file='vocab',
        max_length=42,
        allow_unk=1,
    ),
    batch_size=8,
    shuffle=True,
    num_workers=36,
    pin_memory=True,
)

In [13]:
# Validation split. batch_size=1 matches what model_validation expects.
val_loader = data.DataLoader(
    LEVIRMCIDataset_Modified(
        data_folder=Dataset_Path,
        list_path='Change-Agent/Multi_change/data/LEVIR_MCI/',
        preprocess=preprocess,
        split='val',
        token_folder=token_path,
        vocab_file='vocab',
        max_length=42,
        allow_unk=1,
    ),
    batch_size=1,
    shuffle=True,
    num_workers=36,
    pin_memory=True,
)

In [14]:
# Word -> id vocabulary for the decoder and for stripping special tokens at validation.
# NOTE(review): the datasets encode captions with
# Change-Agent/Multi_change/data/LEVIR_MCI/vocab.json; confirm both files map words identically.
word_vocab = load_json(path='assets/vocab_mci.json')

## Models

VERSION encodes the model's architecture as `layers.atten_layers.decoder_layers.heads` (e.g. '4.3.1.16').

In [15]:
# VERSION = '<layers>.<atten_layers>.<decoder_layers>.<heads>'
VERSION = '4.3.1.16'
layers, atten_layers, decoder_layers, heads = VERSION.split('.')

# The decoder used in RSCaMa
decoder = DecoderTransformer(decoder_type='transformer_decoder', embed_dim=768,
                             vocab_size=len(word_vocab), max_lengths=42,
                             word_vocab=word_vocab, n_head=8,
                             n_layers=int(decoder_layers), dropout=0.1, device=device).to(device)

# Encoder Transformer of RSCaMa with cross attention
encoder_trans = CrossAttentiveEncoder(n_layers=int(layers),
                                      feature_size=[7, 7, 768],
                                      heads=int(heads), dropout=0.1, atten_layers=int(atten_layers),
                                      device=device).to(device)

# Encoder with image enhancement features; uses the same backbone string as the data cell.
encoder = EnhancedEncoder(network).to(device)
encoder.fine_tune(True)

# Informational copies of the decoder settings above (not read by the visible cells).
# A string literal avoids depending on a bare `transformer_decoder` name being in scope.
decoder_n_layers = int(decoder_layers)
decoder_type = 'transformer_decoder'


## Training loop

### Optimizers

In [16]:
# Optimizers
num_epochs = 50

encoder_optimizer = torch.optim.Adam(params=encoder.parameters(), lr=1e-4)
encoder_trans_optimizer = torch.optim.Adam(
    params=filter(lambda p: p.requires_grad, encoder_trans.parameters()),
    lr=1e-4)
decoder_optimizer = torch.optim.Adam(
    params=filter(lambda p: p.requires_grad, decoder.parameters()),
    lr=1e-4)

# Make sure the models live on `device`. `.to` also works when device == 'cpu',
# whereas `.cuda(device)` raises there.
encoder_trans.to(device)
decoder.to(device)

# gamma=1.0 keeps the learning rate constant; the schedulers only mark where decay would go.
encoder_lr_scheduler = torch.optim.lr_scheduler.StepLR(encoder_optimizer, step_size=5, gamma=1.0)
encoder_trans_lr_scheduler = torch.optim.lr_scheduler.StepLR(encoder_trans_optimizer, step_size=5, gamma=1.0)
decoder_lr_scheduler = torch.optim.lr_scheduler.StepLR(decoder_optimizer, step_size=5, gamma=1.0)

# One row per training iteration. The training loop writes columns 0-2:
# [batch_time, loss, top5_accuracy]; columns 3-4 are not written by the visible code.
hist = np.zeros((num_epochs * 2 * len(train_loader), 5))

l_resizeA = torch.nn.Upsample(size=(256, 256), mode='bilinear', align_corners=True)
l_resizeB = torch.nn.Upsample(size=(256, 256), mode='bilinear', align_corners=True)
index_i = 0

criterion_cap = torch.nn.CrossEntropyLoss().to(device)

### Loop

In [None]:
print_freq = 5000
EPOCHS = num_epochs
# NOTE(review): training starts at epoch 5 although no checkpoint is loaded in this cell; confirm.
START_EPOCH = 5
SELECT_METRIC = 'Bleu_1'  # validation metric used both to compare and to record the best score
# encoder.fine_tune(True) was requested and encoder_optimizer exists, so the encoder is stepped too.
FINE_TUNE_ENCODER = True
accum_steps = 1  # gradient-accumulation steps (was written as 64//64)
index_i = 0
# The CLIP model and `preprocess` from the data cell are reused. Reloading CLIP here
# only allocated a second, unused model on the device.

encoder.train()
encoder_trans.train()
decoder.train()

encoder_optimizer.zero_grad()
decoder_optimizer.zero_grad()
encoder_trans_optimizer.zero_grad()

benchmark = {}
MAX_SCORE = [0]
Prev_Epoch = False
for epoch in range(START_EPOCH, EPOCHS):
    loss_set = []
    acc_set = []
    for batch_idx, batch_data in enumerate(train_loader):
        imgA, imgB, seg_label, token_all, token_all_len, token, token_len, name = batch_data
        start_time = time.time()

        token = token.to(device)
        token_len = token_len.to(device)

        # Encode both images together with the binarised change mask, then fuse the features.
        feat1, feat2 = encoder(
            imgA.to(device),
            imgB.to(device),
            (seg_label > 0).int().to(device),
        )
        featcap = encoder_trans(feat1, feat2)

        scores, caps_sorted, decode_lengths, sort_ind = decoder(featcap, token, token_len)
        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
        targets = caps_sorted[:, 1:]
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data.long()

        loss = criterion_cap(scores, targets) / accum_steps
        loss.backward()

        if (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1) == len(train_loader):
            if FINE_TUNE_ENCODER:
                encoder_optimizer.step()
                encoder_lr_scheduler.step()
            decoder_optimizer.step()
            encoder_trans_optimizer.step()

            # Adjust learning rate
            decoder_lr_scheduler.step()
            encoder_trans_lr_scheduler.step()

            # Encoder grads are always cleared so they never accumulate across steps.
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            encoder_trans_optimizer.zero_grad()

        hist[index_i, 0] = time.time() - start_time  # batch_time
        loss_value = loss.item()
        hist[index_i, 1] = loss_value  # train_loss
        accuracy = accuracy_v0(scores, targets, 5)
        hist[index_i, 2] = accuracy  # top5
        index_i += 1

        if index_i % 5 == 0:
            rem_print(f'Training Epoch: {epoch} | Index:{index_i} | Loss: {loss_value} | Top-5 Accuracy: {accuracy} ')
        loss_set.append(loss_value)
        acc_set.append(accuracy)

    # Validate every 5th epoch, and again on the epoch right after an improvement.
    if (epoch % 5 == 0 and epoch) or Prev_Epoch:
        print(f'Training Epoch: {epoch} | Index:{index_i} | Mean Loss: {np.mean(loss_set)} | Mean Accuracy : {np.mean(acc_set)}\n')
        benchmark[epoch] = model_validation(encoder, encoder_trans, decoder, val_loader)
        # Validation may switch the modules to eval mode; resume training in train mode.
        encoder.train()
        encoder_trans.train()
        decoder.train()
        print('\n')
        save_json(benchmark, f'data/Pre-Trained Models/Finetuning/benchmark.{VERSION}.json')

        # Compare and record the SAME metric (Bleu_1 was compared against a stored Bleu_4 before).
        current_score = benchmark[epoch][SELECT_METRIC] * 100
        if current_score > MAX_SCORE[0]:
            MAX_SCORE[0] = current_score
            print(f'\nNew Best Score: {MAX_SCORE[0]} at epoch {epoch}\n')

            torch.save(encoder.state_dict(), f'data/Pre-Trained Models/Finetuning/encoder_{VERSION}_best.pt')
            torch.save(encoder_trans.state_dict(), f'data/Pre-Trained Models/Finetuning/encoder_trans_{VERSION}_best.pt')
            torch.save(decoder.state_dict(), f'data/Pre-Trained Models/Finetuning/decoder_{VERSION}_best.pt')

            Prev_Epoch = True
        else:
            print(f'No Improvement at epoch {epoch}, Previous Best {SELECT_METRIC}: {MAX_SCORE[0]}')
            Prev_Epoch = False
    print('\n')
    save_json(hist.tolist(), f'data/Pre-Trained Models/Finetuning/training_history.{VERSION}.json')

save_json(benchmark, f'data/Pre-Trained Models/Finetuning/benchmark.{VERSION}.json')
save_json(hist.tolist(), f'data/Pre-Trained Models/Finetuning/training_history.{VERSION}.json')
plot_metrics_over_epochs(benchmark)



Training Epoch: 5 | Index:200 | Loss: 2.205578088760376 | Top-5 Accuracy: 70.51282051282053                                                                                                                                                               

KeyboardInterrupt: 

## Evaluation

### Data

In [17]:
# Test split, built the same way as the train/val splits.
test_loader = data.DataLoader(
    LEVIRMCIDataset_Modified(
        data_folder=Dataset_Path,
        list_path='Change-Agent/Multi_change/data/LEVIR_MCI/',
        preprocess=preprocess,
        split='test',
        token_folder=token_path,
        vocab_file='vocab',
        max_length=42,
        allow_unk=1,
    ),
    batch_size=1,
    shuffle=True,
    num_workers=24,
    pin_memory=True,
)

# Alternative phrasings of a "no change" caption (not read by the visible cells).
nochange_list = [
    "the scene is the same as before ",
    "there is no difference ",
    "the two scenes seem identical ",
    "no change has occurred ",
    "almost nothing has changed ",
]

# Two independent bilinear 256x256 resizers (not read by the visible cells).
l_resize1, l_resize2 = (
    torch.nn.Upsample(size=(256, 256), mode='bilinear', align_corners=True)
    for _ in range(2)
)

### Models

In [18]:
# Restore the best checkpoints written by the training loop. map_location loads the
# tensors onto `device`, even if the checkpoint was saved from a GPU that is not present now.
encoder.load_state_dict(
    torch.load(f'data/Pre-Trained Models/Finetuning/encoder_{VERSION}_best.pt', map_location=device)
)
encoder_trans.load_state_dict(
    torch.load(f'data/Pre-Trained Models/Finetuning/encoder_trans_{VERSION}_best.pt', map_location=device)
)
decoder.load_state_dict(
    torch.load(f'data/Pre-Trained Models/Finetuning/decoder_{VERSION}_best.pt', map_location=device)
)

<All keys matched successfully>

### Testing Loop

In [34]:
!pip install torchinfo

Collecting torchinfo
  Obtaining dependency information for torchinfo from https://files.pythonhosted.org/packages/72/25/973bd6128381951b23cdcd8a9870c6dcfc5606cb864df8eabd82e529f9c1/torchinfo-1.8.0-py3-none-any.whl.metadata
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0
[0m

In [31]:
import os
import torch
from torchinfo import summary

# NOTE: CUDA_LAUNCH_BLOCKING is only honoured if it is set before CUDA is initialised.
# When `device` is a CUDA device, earlier cells have already moved the models to the GPU,
# so in this kernel the line below has no effect. To get synchronous CUDA error
# tracebacks, set it in the environment, or at the very top of a fresh kernel.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# ===================================================================
# Model summaries
# ===================================================================

# Take one sample batch from the test loader. Nothing in the dataset/collate path of this
# notebook moves data to the GPU, so these tensors are on the CPU.
sample_batch = next(iter(test_loader))
imgA_sample, imgB_sample, seg_label_sample, _, _, _, _, _ = sample_batch


# --- Step 1: inspect the encoder inputs ---
print("\n" + "="*50)
print("🧐 INSPECTING MODEL INPUTS")
print("="*50)
# Same binarised mask the training loop feeds to the encoder.
seg_input = (seg_label_sample > 0).int()
print(f"imgA shape: {imgA_sample.shape}, dtype: {imgA_sample.dtype}, device: {imgA_sample.device}")
print(f"imgB shape: {imgB_sample.shape}, dtype: {imgB_sample.dtype}, device: {imgB_sample.device}")
print(f"seg_input shape: {seg_input.shape}, dtype: {seg_input.dtype}, device: {seg_input.device}")
print(f"🚨 seg_input Min value: {seg_input.min()}, Max value: {seg_input.max()}")
print("="*50 + "\n")
# If the encoder fails, first check that seg_input's min/max are values the encoder
# accepts (it is binarised to {0, 1} above).


# --- Step 2: run the encoder summary on the CPU ---
print("--- Encoder Summary ---")
# NOTE: encoder.to('cpu') moves the encoder in place, so it stays on the CPU after this cell.
# Call encoder.to(device) before training or evaluating on the GPU again.
summary(encoder.to('cpu'),
        input_data=[imgA_sample.to('cpu'), imgB_sample.to('cpu'), seg_input.to('cpu')],
        device="cpu",
        col_names=["input_size", "output_size", "num_params", "mult_adds"])

# --- Decoder Summary ---
# To summarize the decoder, you need the output shape from the encoder part.
# We'll run a single forward pass on the CPU with the sample data to get it.
with torch.no_grad():
    feat1_sample, feat2_sample = encoder.to('cpu')(imgA_sample.to('cpu'), imgB_sample.to('cpu'), (seg_label_sample > 0).int().to('cpu'))
    feat_sample = encoder_trans.to('cpu')(feat1_sample.to('cpu'), feat2_sample.to('cpu'))
    decoder_input_shape = feat_sample.shape

# ... (Your code to get feat_sample and define the device) ...
# Ensure all sample data and the model are on the same device
feat_sample = feat_sample.to(device)
decoder.to(device)


print("\n--- Decoder Summary ---")
# The traceback shows the decoder's forward pass needs caption data (for teacher-forcing).
# We'll create dummy tensors with appropriate shapes to satisfy the model's signature.

# Get the batch size from your feature sample
batch_size = feat_sample.shape[0]

# 1. Create a dummy tensor for `encoded_captions`
# This is usually shaped (batch_size, max_caption_length).
# Adjust `max_caption_length` if you know your model's specific value.
max_caption_length = 42
dummy_captions = torch.randint(low=1, high=1000, size=(batch_size, max_caption_length), device=device)

# 2. Create a dummy tensor for `caption_lengths`
# This is usually shaped (batch_size,)
dummy_lengths = torch.full(size=(batch_size,), fill_value=max_caption_length, dtype=torch.long, device=device)


# 3. Provide all required arguments as a list to the `input_data` parameter
summary(decoder.to('cpu'),
        input_data=[feat_sample.to('cpu'), dummy_captions.to('cpu'), dummy_lengths.to('cpu')],
        device='cpu',
        col_names=["input_size", "output_size", "num_params", "mult_adds"])

# Move models back to GPU for testing
encoder.to(device)
encoder_trans.to(device)
decoder.to(device)
# ===================================================================


# Your existing testing loop starts here
test_start_time = time.time()
references = list()
hypotheses = list()

# ... (the rest of your testing loop code) ...


🧐 INSPECTING MODEL INPUTS
imgA shape: torch.Size([1, 3, 224, 224]), dtype: torch.float32, device: cpu
imgB shape: torch.Size([1, 3, 224, 224]), dtype: torch.float32, device: cpu
seg_input shape: torch.Size([1, 1, 224, 224]), dtype: torch.int32, device: cpu
🚨 seg_input Min value: 0, Max value: 1

--- Encoder Summary ---


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [39]:
# Move encoder_trans (in place) to the CPU; the returned module is displayed
# as the cell output.
# NOTE(review): this leaves encoder_trans on the CPU, so GPU inference will
# fail until it is moved back with encoder_trans.to(device).
encoder_trans.to('cpu')

CrossAttentiveEncoder(
  (h_embedding): Embedding(7, 384)
  (w_embedding): Embedding(7, 384)
  (CaMalayer_list): ModuleList(
    (0-3): 4 x ModuleList(
      (0-1): 2 x CaMambaModel(
        (embeddings): Embedding(50280, 768)
        (layers): ModuleList(
          (0): CaMambaBlock(
            (norm): MambaRMSNorm()
            (mixer): MambaMixer(
              (conv1d): Conv1d(1536, 1536, kernel_size=(3,), stride=(1,), padding=(2,), groups=1536)
              (conv1d_back): Conv1d(1536, 1536, kernel_size=(3,), stride=(1,), padding=(2,), groups=1536)
              (act): SiLU()
              (in_proj): Linear(in_features=768, out_features=3072, bias=False)
              (in_proj_dif): Linear(in_features=768, out_features=3072, bias=False)
              (x_proj): Linear(in_features=1536, out_features=80, bias=False)
              (x_proj_back): Linear(in_features=1536, out_features=80, bias=False)
              (x_proj_dif): Linear(in_features=1536, out_features=80, bias=False)
    

In [37]:
# Display the first feature output of the encoder for the sample batch (inspection only).
feat1_sample

tensor([[[ 0.1016,  0.7099,  0.1865,  ..., -0.0747,  0.3210, -0.1381],
         [ 0.2476,  1.2200,  0.0899,  ...,  0.0731,  0.1660, -0.0896],
         [ 0.5294,  1.1039,  0.0724,  ..., -0.2011,  0.3734,  0.5038],
         ...,
         [ 0.2633,  0.4328,  0.0766,  ..., -0.2279,  0.4134,  0.0924],
         [ 0.5186,  0.9856,  0.2324,  ...,  0.0081,  0.4800, -0.4232],
         [-0.1222,  0.8506,  0.0174,  ..., -0.0981,  0.4955, -0.1269]]])

In [23]:
# Beam width handed to decoder.sample(k=...).
BEAM_SIZE = 1

# Inference mode for all captioning modules. They come straight from the
# (possibly interrupted) fine-tuning loop, and load_state_dict() does not
# change the train/eval flag, so dropout would otherwise stay active and
# perturb the test scores.
encoder.eval()
encoder_trans.eval()
decoder.eval()

# Special tokens stripped from both references and predictions; built once
# rather than once per batch.
except_tokens = {word_vocab['<START>'], word_vocab['<END>'], word_vocab['<NULL>']}

test_start_time = time.time()
references = list()  # references (true captions) for calculating BLEU-4 score
hypotheses = list()  # hypotheses (predictions)

with torch.no_grad():
    for batch_data in tqdm(test_loader, desc=f"test_ EVALUATING AT BEAM SIZE {BEAM_SIZE}"):
        imgA, imgB, seg_label, token_all, token_all_len, token, token_len, name = batch_data

        # .to(device) works for both CUDA and CPU-only runs (.cuda('cpu') raises).
        imgA = imgA.to(device)
        imgB = imgB.to(device)
        # The encoder takes the binarised change mask (label > 0) as its third input.
        seg_input = (seg_label > 0).int().to(device)

        feat1, feat2 = encoder(imgA, imgB, seg_input)
        feat = encoder_trans(feat1, feat2)
        seq = decoder.sample(feat, k=BEAM_SIZE)

        # The loader uses batch_size=1, so index [0] selects the single
        # sample's list of reference token sequences.
        img_tokens = [[w for w in caption if w not in except_tokens]
                      for caption in token_all.tolist()[0]]
        references.append(img_tokens)

        pred_seq = [w for w in seq if w not in except_tokens]
        hypotheses.append(pred_seq)

test_time = time.time() - test_start_time

from eval_func.bleu.bleu import Bleu
from eval_func.rouge.rouge import Rouge
from eval_func.cider.cider import Cider
from eval_func.meteor.meteor import Meteor

# (scorer, metric name(s)) pairs; Bleu(4) reports BLEU-1..BLEU-4 at once.
scorers = [
    (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
    #(Meteor(), "METEOR"),
    (Rouge(), "ROUGE_L"),
    (Cider(), "CIDEr"),
]

# The scorers take whitespace-joined token strings: one hypothesis per sample
# and a list of reference strings per sample.
hypo = [[' '.join(str(w) for w in hyp)] for hyp in hypotheses]
ref = [[' '.join(str(w) for w in caption) for caption in refs] for refs in references]

score = []
method = []
for scorer, method_i in tqdm(scorers):
    score_i, scores_i = scorer.compute_score(ref, hypo)
    # Multi-metric scorers (BLEU) return a list of scores with a list of names.
    if isinstance(score_i, list):
        score.extend(score_i)
    else:
        score.append(score_i)
    if isinstance(method_i, list):
        method.extend(method_i)
    else:
        method.append(method_i)
score_dict = dict(zip(method, score))

Bleu_1 = score_dict['Bleu_1']
Bleu_2 = score_dict['Bleu_2']
Bleu_3 = score_dict['Bleu_3']
Bleu_4 = score_dict['Bleu_4']
#Meteor = score_dict['METEOR']
# Stored under their own names: rebinding `Rouge` / `Cider` would shadow the
# scorer classes imported above and break any later `Rouge()` / `Cider()` call.
rouge_l = score_dict['ROUGE_L']
cider = score_dict['CIDEr']

print(f"{VERSION} Results")
print(f'Testing:\n Time: {test_time}s\n BLEU-1: {Bleu_1*100}  %\n BLEU-2: {Bleu_2*100}  %\n BLEU-3: {Bleu_3*100}  %\n BLEU-4: {Bleu_4*100}  %\n Rouge: {rouge_l*100}  %\n Cider: {cider}\t')


test_ EVALUATING AT BEAM SIZE 1:  82%|████████▏ | 1579/1929 [01:31<00:20, 17.25it/s]


KeyboardInterrupt: 

In [None]:
# NOTE: duplicate of the earlier torchinfo install cell; re-running it is harmless but unnecessary.
!pip install torchinfo

# Segmenting

### Training Loop

In [None]:
import matplotlib.pyplot as plt
import random

random.seed(3)  # seed 3 fixes which test samples are shown
Binarize = False  # NOTE(review): not read in this cell — confirm whether anything still uses it

# Inference mode for every model used below. The captioning modules may still
# be in training mode after fine-tuning (load_state_dict does not change it).
ChangeCLIP.eval()
encoder.eval()
encoder_trans.eval()
decoder.eval()

# Special tokens removed from the predicted caption (defined here so this
# cell does not depend on another cell having run).
except_tokens = {word_vocab['<START>'], word_vocab['<END>'], word_vocab['<NULL>']}

for itr in range(5):
    fig, axes = plt.subplots(1, 5, figsize=(20, 15))
    # NOTE(review): only indices 1..1900 are sampled; confirm this range is intended.
    random_indices = [random.randint(1, 1900) for _ in range(1)]

    # NOTE(review): all indices draw into the same five axes, so with more than
    # one index the later samples overwrite the earlier panels.
    for i, idx in enumerate(random_indices):
        batch_data = test_loader.dataset[idx]
        imgA, imgB, seg_label, token_all, token_all_len, token, token_len, name = batch_data

        Ground_Truths = [token_to_text(tokens) for tokens in token_all]

        # as_tensor accepts arrays or tensors without the extra copy and the
        # warning that torch.tensor() gives for tensor input.
        Before = torch.as_tensor(imgA)
        After = torch.as_tensor(imgB)
        Ground_Truth = torch.as_tensor(seg_label)

        # Placeholder text prompts (the commented-out alternative was
        # get_text_inputs(name, split='test')).
        Texts = ["l" for _ in range(4)]

        with torch.no_grad():
            Pred = ChangeCLIP(xA=Before.unsqueeze(0).to(device), xB=After.unsqueeze(0).to(device), Texts=Texts)
            # The captioning encoder returns two feature maps, which encoder_trans
            # fuses, matching how encoder/encoder_trans are called in the
            # evaluation loop.
            # NOTE(review): the dataset is already built with `preprocess`;
            # confirm that manual_preprocess does not normalise a second time.
            feat1, feat2 = encoder(
                manual_preprocess(Before).unsqueeze(0).to(device),
                manual_preprocess(After).unsqueeze(0).to(device),
                manual_preprocess(Pred).to(device),
            )
            featcap = encoder_trans(feat1, feat2)
            seq = decoder.sample(featcap, k=1)

        Pred_UB = Pred[0].permute(1, 2, 0).cpu().detach().numpy()
        Pred_B = binarize(Pred)[0].permute(1, 2, 0).cpu().detach().numpy()

        pred_seq = [w for w in seq if w not in except_tokens]
        caption = [invert[token] for token in pred_seq]
        output = ''.join(f'{word} ' for word in caption)

        print(f"Predicted_Caption : {output}")
        for j, ground_truth in enumerate(Ground_Truths):
            print(f'Ground Truth {j} : {ground_truth}')
        print('\n')

        # (image, title, colormap) for each of the five panels.
        panels = [
            (Before.permute(1, 2, 0), 'Before', None),
            (After.permute(1, 2, 0), 'After', None),
            (Ground_Truth, 'Ground Truth', 'gray'),
            (Pred_B, 'Prediction', 'gray'),
            (Pred_UB, 'Prediction Unbinarized', 'gray'),
        ]
        for ax, (image, title, cmap) in zip(axes, panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(f'{title}\nIndex: {idx}')
            ax.axis('off')

    plt.tight_layout()

    # Save the figure after completing all plots
    #plt.savefig(f'/home/guest/Documents/Siraj TM/DATA/Predictions_{itr}.png', dpi=200)
    plt.show()


# Both