# A Transformer-Based Siamese Network for Vegetation Change Detection

### Imports

In [1]:
import os, sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, jaccard_score
from models.ChangeFormer import ChangeFormerV3
sys.path.append(os.path.abspath(".."))
from image_preprocessing.image_preprocessing import load_image_pairs_labels

  from .autonotebook import tqdm as notebook_tqdm


### Create Training Dataset

In [2]:
# Each entry: (T1 image path, T2 image path, then integers forwarded to
# load_image_pairs_labels). NOTE(review): the first two integers match the
# (H, W) of the rasters the loader reports; the next four look like a crop
# window (row start/end, col start/end, each spanning 2560 px) and the last
# like the tile size -- confirm against image_preprocessing.load_image_pairs_labels.
image_paths = [
    ('../../Data/Antwerpen/Antwerpen_2018/JPEG2000/OMWRGB18VL_11002.jp2', '../../Data/Antwerpen/Antwerpen_2022/JPEG2000/OMWRGB22VL_11002.jp2', 8500, 7000, 4420, 6980, 3320, 5880, 256),
    ('../../Data/Leuven/Leuven_2018/JPEG2000/OMWRGB18VL_24062.jp2', '../../Data/Leuven/Leuven_2022/JPEG2000/OMWRGB22VL_24062.jp2', 8500, 7000, 3620, 6180, 2320, 4880, 256),
    ('../../Data/Kortrijk/Kortrijk_2018/JPEG2000/OMWRGB18VL_34022.jp2', '../../Data/Kortrijk/Kortrijk_2022/JPEG2000/OMWRGB22VL_34022.jp2', 8500, 7000, 2120, 4680, 1520, 4080, 256),
    ('../../Data/Brugge/Brugge_2018/JPEG2000/OMWRGB18VL_31005.jp2', '../../Data/Brugge/Brugge_2022/JPEG2000/OMWRGB22VL_31005.jp2', 8000, 6500, 4470, 7030, 2020, 4580, 256),
    ('../../Data/Hasselt/Hasselt_2018/JPEG2000/OMWRGB18VL_71022.jp2', '../../Data/Hasselt/Hasselt_2022/JPEG2000/OMWRGB22VL_71022.jp2', 8500, 7000, 2570, 5130, 3020, 5580, 256),
    ('../../Data/Mechelen/Mechelen_2018/JPEG2000/OMWRGB18VL_12025.jp2', '../../Data/Mechelen/Mechelen_2022/JPEG2000/OMWRGB22VL_12025.jp2', 8500, 7000, 3570, 6130, 3020, 5580, 256),
]

image_pairs_train, labels_train = load_image_pairs_labels(image_paths, normalized=False)

# Cleanup: masks stored as 0/255 are binarised to 0/1 so they can be used as
# class indices for the 2-class loss. A single linear scan for 255 is enough;
# np.unique would sort every label pixel just to answer a membership question.
if np.any(np.equal(labels_train, 255)):
    labels_train = np.clip(labels_train, 0, 1).astype(np.uint8)

Reading ../../Data/Antwerpen/Antwerpen_2018/JPEG2000/OMWRGB18VL_11002.jp2 into shape (3, 8500, 7000)
Reading ../../Data/Antwerpen/Antwerpen_2022/JPEG2000/OMWRGB22VL_11002.jp2 into shape (3, 8500, 7000)
Reading ../../Data/Leuven/Leuven_2018/JPEG2000/OMWRGB18VL_24062.jp2 into shape (3, 8500, 7000)
Reading ../../Data/Leuven/Leuven_2022/JPEG2000/OMWRGB22VL_24062.jp2 into shape (3, 8500, 7000)
Reading ../../Data/Kortrijk/Kortrijk_2018/JPEG2000/OMWRGB18VL_34022.jp2 into shape (3, 8500, 7000)
Reading ../../Data/Kortrijk/Kortrijk_2022/JPEG2000/OMWRGB22VL_34022.jp2 into shape (3, 8500, 7000)
Reading ../../Data/Brugge/Brugge_2018/JPEG2000/OMWRGB18VL_31005.jp2 into shape (3, 8000, 6500)
Reading ../../Data/Brugge/Brugge_2022/JPEG2000/OMWRGB22VL_31005.jp2 into shape (3, 8000, 6500)
Reading ../../Data/Hasselt/Hasselt_2018/JPEG2000/OMWRGB18VL_71022.jp2 into shape (3, 8500, 7000)
Reading ../../Data/Hasselt/Hasselt_2022/JPEG2000/OMWRGB22VL_71022.jp2 into shape (3, 8500, 7000)
Reading ../../Data/Mechele

### Create Test Dataset

In [3]:
# Held-out region (Gent, 2020 vs 2024). Same tuple layout as the training list:
# (T1 path, T2 path, then integers forwarded to load_image_pairs_labels).
# NOTE(review): integer meanings are inferred, not verified -- confirm against the loader.
image_paths = [
    ('../../Data/Gent/Gent_2020/JPEG2000/OMWRGB20VL_44021.jp2', '../../Data/Gent/Gent_2024/JPEG2000/OMWRGB24VL_44021.jp2', 8500, 7000, 4220, 6780, 2520, 5080, 256)
]

test_image_pairs, test_labels = load_image_pairs_labels(image_paths, normalized=False)

# Cleanup: binarise 0/255 masks to 0/1 class indices. A linear scan for 255
# avoids sorting every pixel as np.unique would.
if np.any(np.equal(test_labels, 255)):
    test_labels = np.clip(test_labels, 0, 1).astype(np.uint8)

Reading ../../Data/Gent/Gent_2020/JPEG2000/OMWRGB20VL_44021.jp2 into shape (3, 8500, 7000)
Reading ../../Data/Gent/Gent_2024/JPEG2000/OMWRGB24VL_44021.jp2 into shape (3, 8500, 7000)


In [4]:
class ChangeDataset(Dataset):
    """Map-style dataset of (T1 image, T2 image, change mask) triples.

    Each item is returned as ``(img1, img2, label)``:
    the two images are channel-last numpy arrays ``[H, W, C]`` converted to
    float32 tensors ``[C, H, W]``; the label is converted to an int64 tensor
    ``[H, W]`` (class indices, as required by ``nn.CrossEntropyLoss``).

    Args:
        image_pairs: indexable sequence of ``(img1, img2)`` numpy arrays.
        labels: indexable sequence of numpy masks, one per image pair.

    Raises:
        ValueError: if ``image_pairs`` and ``labels`` have different lengths.
    """

    def __init__(self, image_pairs, labels):
        # __len__ is driven by image_pairs; a mismatch would otherwise show up
        # as an IndexError mid-epoch (too few labels) or silently ignore the
        # surplus labels (too many).
        if len(image_pairs) != len(labels):
            raise ValueError(
                "image_pairs and labels must have the same length, "
                f"got {len(image_pairs)} and {len(labels)}"
            )
        self.image_pairs = image_pairs
        self.labels = labels

    def __len__(self):
        """Number of image pairs."""
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Return ``(img1, img2, label)`` tensors for pair ``idx``."""
        img1, img2 = self.image_pairs[idx]
        label = self.labels[idx]
        # Channel-last [H, W, C] -> channel-first [C, H, W] float tensors.
        img1 = torch.from_numpy(img1.transpose(2, 0, 1)).float()
        img2 = torch.from_numpy(img2.transpose(2, 0, 1)).float()
        label = torch.from_numpy(label).long()  # [H, W] class indices
        return img1, img2, label

In [19]:
batch_size = 4

# Wrap the preprocessed tiles; training batches are reshuffled every epoch,
# while evaluation walks the test tiles in a fixed order, one tile per batch.
train_dataset = ChangeDataset(image_pairs_train, labels_train)
test_dataset = ChangeDataset(test_image_pairs, test_labels)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

In [6]:
# Seed torch's global RNG (all devices) so weight initialisation and the
# training shuffle order -- which DataLoader draws from the global generator
# when no explicit generator is passed -- are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): decoder_softmax=False presumably makes the model emit raw
# scores, which is what nn.CrossEntropyLoss expects -- confirm in models.ChangeFormer.
model = ChangeFormerV3(input_nc=3, output_nc=2, decoder_softmax=False).to(device)

# ========= Loss & Optimizer ==========
# TODO(review): evaluation shows high accuracy but low recall on the change
# class, which suggests class imbalance; consider nn.CrossEntropyLoss(weight=...)
# or an added Dice/focal term.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

### Train Model

In [7]:
def train(num_epochs=10):
    """Train the notebook-level ``model`` on ``train_loader``.

    Uses the globals ``model``, ``optimizer``, ``criterion``, ``train_loader``
    and ``device``, and prints the mean per-batch loss after every epoch.

    Args:
        num_epochs: number of full passes over ``train_loader``.

    Returns:
        list[float]: mean training loss of each epoch, in order.

    Raises:
        ValueError: if ``train_loader`` has no batches (mean loss undefined).
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        for img1, img2, label in train_loader:
            img1, img2, label = img1.to(device), img2.to(device), label.to(device)
            optimizer.zero_grad()
            output = model(img1, img2)  # output: [B, 2, H, W]
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / num_batches
        epoch_losses.append(avg_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
    return epoch_losses


# Keep the loss history for later plotting; assigning it also prevents Jupyter
# from echoing the list as cell output.
epoch_losses = train(num_epochs=10)

Epoch [1/10], Loss: 0.5584
Epoch [2/10], Loss: 0.5071
Epoch [3/10], Loss: 0.5119
Epoch [4/10], Loss: 0.5032
Epoch [5/10], Loss: 0.4993
Epoch [6/10], Loss: 0.4994
Epoch [7/10], Loss: 0.4944
Epoch [8/10], Loss: 0.4863
Epoch [9/10], Loss: 0.4820
Epoch [10/10], Loss: 0.4549


### Evaluate Model

In [8]:
def visualize_predictions(img1, img2, gt, pred, idx, out_dir="."):
    """Plot T1, T2, ground truth and prediction side by side, then save and show.

    Args:
        img1, img2: channel-first image arrays ``[C, H, W]``; transposed to
            ``[H, W, C]`` and shown as 8-bit RGB (presumably on a 0-255 scale,
            since the loaders are called with ``normalized=False``).
        gt: 2-D ground-truth mask, shown in grayscale.
        pred: 2-D predicted mask, shown in grayscale.
        idx: index embedded in the file name ``prediction_vs_gt_{idx}.png``.
        out_dir: directory the PNG is written to (default: current directory).
    """
    def to_rgb(img):
        # Clip before the uint8 cast: out-of-range values would otherwise wrap
        # around (e.g. 256 -> 0, -1 -> 255) and produce misleading colours.
        return np.clip(img.transpose(1, 2, 0), 0, 255).astype(np.uint8)

    fig, axs = plt.subplots(1, 4, figsize=(12, 4))
    panels = (
        (to_rgb(img1), "Image T1", None),
        (to_rgb(img2), "Image T2", None),
        (gt, "Ground Truth", "gray"),
        (pred, "Prediction", "gray"),
    )
    for ax, (data, title, cmap) in zip(axs, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, f"prediction_vs_gt_{idx}.png"))
    plt.show()
    # Release the figure; repeated calls in a loop would otherwise keep every
    # figure open in pyplot's registry.
    plt.close(fig)

In [None]:
def evaluate(num_visualize=0):
    """Run the notebook-level ``model`` over ``test_loader`` and score it per pixel.

    Predictions are the argmax over the class dimension. Labels and predictions
    from every batch are flattened and pooled, then scored with scikit-learn
    (binary metrics with default positive label 1).

    Args:
        num_visualize: number of leading test batches whose first sample is
            plotted with ``visualize_predictions``. Default 0 (no plots).

    Returns:
        tuple: ``(metrics, y_pred, y_true)`` where ``metrics`` maps
        "Accuracy", "Precision", "Recall", "IoU" and "f1" to scores, and
        ``y_pred`` / ``y_true`` are 1-D numpy arrays of per-pixel predicted /
        true labels.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    true_chunks = []
    pred_chunks = []
    with torch.no_grad():
        for idx, (img1, img2, label) in enumerate(test_loader):
            img1, img2 = img1.to(device), img2.to(device)
            output = model(img1, img2)  # [B, 2, H, W]
            # Keep the whole batch: taking only [0] would drop predictions and
            # misalign them with the flattened labels whenever batch_size > 1.
            pred = torch.argmax(output, dim=1).cpu().numpy()  # [B, H, W]
            label_np = label.numpy()
            true_chunks.append(label_np.ravel())
            pred_chunks.append(pred.ravel())

            if idx < num_visualize:
                visualize_predictions(
                    img1.cpu().numpy()[0],
                    img2.cpu().numpy()[0],
                    label_np[0],
                    pred[0],
                    idx,
                )

    if not true_chunks:
        raise ValueError("test_loader is empty; nothing to evaluate")

    # Concatenate once; extending Python lists pixel by pixel would create one
    # Python object per pixel.
    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)

    metrics = {
        "Accuracy": accuracy_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred, zero_division=0),
        "Recall": recall_score(y_true, y_pred, zero_division=0),
        "IoU": jaccard_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
    }
    return metrics, y_pred, y_true

metrics, y_pred, y_true = evaluate()

for name, value in metrics.items():
    print(f"{name}: {value:.4f}")

Accuracy: 0.7628
Precision: 0.4534
Recall: 0.1077
IoU: 0.0953
f1: 0.1740


### Save Model

In [22]:
# Persist only the learned parameters (state_dict); to restore, rebuild
# ChangeFormerV3 with the same constructor arguments and call load_state_dict().
checkpoint_path = "change_former.pth"
torch.save(model.state_dict(), checkpoint_path)