In [None]:
# Mount the user's Google Drive under /content/drive (the google.colab module
# is only available inside Google Colab).
# NOTE(review): no later cell visible in this notebook reads from
# /content/drive — confirm the mount is still needed.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Download the dataset
!gdown 1sFJ-rCj4eDRi8CBaBrkY_lT0JLhpCEGQ --output data.zip
!unzip data.zip -d data

In [None]:
!git clone https://KirollosSamy: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@github.com/KirollosSamy/Change_Detection.git

In [None]:
# Move paths.pkl from the cloned repository into the current working directory.
# NOTE(review): presumably the project code opens paths.pkl relative to the
# working directory — confirm against the loader implementation.
!mv Change_Detection/paths.pkl .
# Install torchmetrics (unpinned), which supplies the Binary* metric classes
# this notebook imports.
!pip install torchmetrics

In [None]:
# Make the cloned repository importable so that the `src.*` imports resolve.
# The membership check keeps the cell idempotent: re-running it does not pile
# up duplicate entries on sys.path.
import sys

if "Change_Detection" not in sys.path:
    sys.path.append("Change_Detection")

In [None]:
from src.datasets.classical_loader import create_classical_loader
from src.training.evaluation import jaccard_batch
from src.models.classical import ImageDiff

import os
import torch
from torchvision.utils import save_image
from torchmetrics.classification import BinaryAccuracy, BinaryRecall, BinaryPrecision
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

In [None]:
def visualize_batch(A, B, delta, change_map):
    """Plot every sample of a batch as one figure with four panels.

    For each index ``i`` along the first dimension of ``A`` a figure is drawn
    showing, left to right: image A, image B, the model output and the delta.

    Args:
        A: Batch of first images. ``A[i]`` is permuted with ``(1, 2, 0)``
            before plotting, so each sample must be 3-D (presumably
            channels-first — confirm against the loader).
        B: Batch of second images, handled exactly like ``A``.
        delta: Batch of reference change masks (the target that
            ``evaluate_classical`` scores against); each sample is squeezed
            before plotting.
        change_map: Batch of model outputs; each sample is squeezed before
            plotting.

    Returns:
        None. Figures are shown with ``plt.show()`` and then closed.
    """
    batch_size = A.shape[0]

    for i in range(batch_size):
        # (title, tensor) pairs in display order.
        panels = (
            ('Image A', A[i].permute(1, 2, 0)),
            ('Image B', B[i].permute(1, 2, 0)),
            ('Model Output', change_map[i].squeeze()),
            ('Delta', delta[i].squeeze()),
        )

        fig, axes = plt.subplots(1, len(panels), figsize=(8, 4))
        for ax, (title, tensor) in zip(axes, panels):
            # detach + cpu so tensors that live on an accelerator or carry
            # autograd history can still be converted to NumPy.
            ax.imshow(tensor.detach().cpu().numpy(), cmap='gray')
            ax.set_title(title)
            ax.axis('off')

        plt.show()
        # Release the figure explicitly; one figure per sample adds up quickly.
        plt.close(fig)

def evaluate_classical(model, dataloader, verbose=False):
    """Evaluate a classical change-detection model and print averaged metrics.

    Each batch is unpacked as ``(A, B, delta)``; ``model.predict(A, B)`` is
    scored against ``delta`` with ``jaccard_batch`` and with torchmetrics'
    binary accuracy, precision and recall. The printed figures are the
    unweighted mean of the per-batch values.

    NOTE(review): because every batch counts equally, a smaller final batch is
    over-weighted, and a mean of per-batch precision/recall is not the
    dataset-level value. If dataset-level numbers are wanted, the metric
    objects' ``compute()`` would provide them — confirm which is intended.

    Args:
        model: Object exposing ``predict(A, B)``.
        dataloader: Iterable of ``(A, B, delta)`` batches.
        verbose: When true, plot every batch with ``visualize_batch``.

    Returns:
        None. Results are printed; if the loader yields no batches a notice
        is printed instead of dividing by zero.
    """
    metrics = {
        'Accuracy': BinaryAccuracy(),
        'Precision': BinaryPrecision(),
        'Recall': BinaryRecall(),
    }
    totals = dict.fromkeys(metrics, 0.0)
    total_jaccard = 0.0
    # Counted while iterating so loaders without a usable len() also work.
    num_batches = 0

    for A, B, delta in tqdm(dataloader):
        change_map = model.predict(A, B)

        if verbose:
            visualize_batch(A, B, delta, change_map)

        total_jaccard += jaccard_batch(change_map, delta)
        for name, metric in metrics.items():
            totals[name] += metric(change_map, delta).item()
        num_batches += 1

    if num_batches == 0:
        print('No batches to evaluate.')
        return

    print(f'Jaccard Index: {total_jaccard / num_batches:.6f}')
    for name, total in totals.items():
        print(f'{name}: {total / num_batches:.6f}')

def _show_test_batch(A, B, change_map):
    """Plot image A, image B and the model output for every sample of a batch."""
    for i in range(A.shape[0]):
        panels = (
            ('Image A', A[i].permute(1, 2, 0)),
            ('Image B', B[i].permute(1, 2, 0)),
            ('Model Output', change_map[i].squeeze()),
        )

        fig, axes = plt.subplots(1, len(panels), figsize=(8, 4))
        for ax, (title, tensor) in zip(axes, panels):
            ax.imshow(tensor.detach().cpu().numpy(), cmap='gray')
            ax.set_title(title)
            ax.axis('off')

        plt.show()
        plt.close(fig)

def test_classical(model, test_loader, device='cpu', verbose=False, save_dir=None):
    """Run the model over a loader and optionally save or plot its change maps.

    Args:
        model: Object exposing ``predict(A, B)``.
        test_loader: Iterable of batches whose first two elements are the
            image batches ``A`` and ``B``. Any further elements (for example
            a reference mask) are ignored, so a loader that also yields
            targets can be passed as well.
        device: Unused; kept so existing calls keep working.
        verbose: When true, plot the inputs and the prediction of every sample.
        save_dir: Directory to write the change maps into as ``NNNN.png``,
            numbered consecutively across batches. Created if missing. Nothing
            is written when ``None``.

    Returns:
        None.
    """
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    # Running index across all batches. Deriving it from
    # ``batch_idx * len(change_map)`` reuses earlier numbers whenever the last
    # batch is shorter than the others, silently overwriting files.
    num_saved = 0

    for batch in tqdm(test_loader):
        A, B, *_ = batch
        change_map = model.predict(A, B)

        # Save change maps
        if save_dir is not None:
            for single_map in change_map:
                image_name = os.path.join(save_dir, f'{num_saved:04d}.png')
                # float32 rather than float16: save_image rescales the tensor
                # arithmetically, and half-precision CPU kernels are not
                # available in every PyTorch version.
                save_image(single_map.to(torch.float32), image_name)
                num_saved += 1

        if verbose:
            # No reference mask is available here, so only three panels are shown.
            _show_test_batch(A, B, change_map)

In [None]:
# Notebook configuration.
data_dir = 'data'  # directory the dataset archive is extracted into
batch_size = 16  # batch size handed to the loader factory
classical_dir = 'classical'  # passed as save_dir for the predicted change maps

In [None]:
# Build the loader over `data_dir` with the grayscale option enabled.
# NOTE(review): evaluate_classical unpacks every batch as (A, B, delta) —
# confirm that is what this loader yields.
dataloader = create_classical_loader(data_dir, batch_size, grayscale=True)

In [None]:
# Instantiate the classical model with a threshold of 0.5.
# NOTE(review): presumably the threshold binarises the difference between the
# two images — confirm in src/models/classical.
model = ImageDiff(threshold=0.5)

In [None]:
# Score the model on the loader and print the averaged metrics, without
# plotting individual batches.
evaluate_classical(model, dataloader, verbose=False)

In [None]:
# Run inference and write the predicted change maps into `classical_dir`.
# NOTE(review): this reuses the loader from the evaluation cell — confirm that
# test_classical accepts the batches it yields, or build a dedicated test loader.
test_classical(model, dataloader, verbose=False, save_dir=classical_dir)