# Projet Apprentissage Supervisé : Change detection in bi-temporal remote sensing images

### Clément Guigon - Ophélia Urbing - Etienne Bardet

#### Libraries

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TVF
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
from network import *
from network import *
from tqdm import tqdm
from utils import *
from skimage import exposure
from torcheval.metrics.functional import binary_f1_score

%load_ext autoreload
%autoreload 2

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### Dataset

In [None]:
#Simple Dataset class for our change detection dataset

class ChangeDetectionDataset(Dataset):
    """Bi-temporal change-detection dataset returning random, augmented crops.

    Each CSV row holds three paths (first image, second image, change mask) in
    columns 0, 1 and 2; they are joined to ``data_dir`` with ``'/'``.

    Args:
        csv_file: CSV listing the image/mask triplets.
        data_dir: directory the CSV paths are relative to.
        batch_size: stored on the instance; not used by the dataset itself.
        transform: stored on the instance; not applied by this class.
        crop_size: side length of the square random crop.
        repeats: rows are resampled with replacement to ``repeats`` times the
            CSV length, so one epoch draws many random crops per image pair.
    """

    def __init__(self, csv_file="data.csv", data_dir="./data", batch_size=1, transform=None, crop_size=128,
                 repeats=30):
        # Bootstrap-resample the rows (with replacement) to `repeats` x the CSV size;
        # since crops/flips are random, duplicated rows still give different samples.
        self.data = pd.read_csv(csv_file).sample(frac=repeats, replace=True)
        self.data_dir = data_dir
        self.transform = transform
        self.batch_size = batch_size
        self.crop_size = crop_size

    def __len__(self):
        return len(self.data)

    def random_crop(self, img1, img2, cm, size):
        """Crop the same random ``size`` x ``size`` window from all three tensors.

        Tensors are indexed (channel, H, W); only the first channel of ``cm`` is kept.

        Raises:
            ValueError: if the image is smaller than ``size`` in either dimension.
        """
        height, width = img1.shape[1], img1.shape[2]
        if height < size or width < size:
            raise ValueError(f"crop size {size} exceeds image size {height}x{width}")
        # The upper bound of randint is exclusive: +1 so the last valid offset
        # (and images exactly `size` wide) are allowed. The torch RNG is used because
        # PyTorch reseeds it per DataLoader worker, unlike numpy's global RNG.
        x = int(torch.randint(0, width - size + 1, (1,)).item())
        y = int(torch.randint(0, height - size + 1, (1,)).item())
        img1 = img1[:, y:y + size, x:x + size]
        img2 = img2[:, y:y + size, x:x + size]
        cm = cm[0:1, y:y + size, x:x + size]
        return img1, img2, cm

    def random_flip(self, img1, img2, cm, chance=0.5):
        """Independently apply a horizontal and a vertical flip, each with probability ``chance``.

        The same flips are applied to the three tensors so they stay aligned.
        """
        # Horizontal flip = reverse the last (width) axis.
        if torch.rand(()).item() < chance:
            img1, img2, cm = img1.flip(-1), img2.flip(-1), cm.flip(-1)
        # Vertical flip = reverse the second-to-last (height) axis.
        if torch.rand(()).item() < chance:
            img1, img2, cm = img1.flip(-2), img2.flip(-2), cm.flip(-2)
        return img1, img2, cm

    def apply_hist_dynamic(self, source, target, eps=1e-8):
        """Match the mean/std of each of the first three channels of ``target`` to ``source``.

        ``target`` is modified in place and returned. ``eps`` lower-bounds the
        target std so a constant channel maps to the source mean instead of NaN.
        """
        for c in range(3):
            src, tgt = source[c, :, :], target[c, :, :]
            target[c, :, :] = ((tgt - tgt.mean()) / tgt.std().clamp_min(eps)) * src.std() + src.mean()
        return target

    @staticmethod
    def _equalize(tensor):
        """Apply scikit-image CLAHE to a (C, H, W) float tensor and return a (C, H, W) tensor."""
        hwc = tensor.permute(1, 2, 0).numpy()
        return torch.Tensor(exposure.equalize_adapthist(hwc)).permute(2, 0, 1)

    def __getitem__(self, idx):
        """Return ``(img1, img2, cm)`` for row ``idx``.

        Images are (3, crop, crop) float tensors; ``cm`` is (1, crop, crop) divided by 255.
        """
        paths = [self.data_dir + '/' + self.data.iloc[idx, col] for col in range(3)]
        img1, img2, cm = (read_image(path) for path in paths)

        crop1, crop2, cropcm = self.random_crop(img1, img2, cm, self.crop_size)
        crop1, crop2, cropcm = self.random_flip(crop1, crop2, cropcm)

        # Copying into fixed-shape buffers broadcasts single-channel inputs to
        # 3 channels and raises for any other channel count than 1 or 3.
        size = self.crop_size
        img1Tensor = torch.zeros((3, size, size), dtype=torch.float32)
        img2Tensor = torch.zeros((3, size, size), dtype=torch.float32)
        cmTensor = torch.zeros((1, size, size), dtype=torch.float32)
        img1Tensor[:, :, :] = crop1.float() / 255
        img2Tensor[:, :, :] = crop2.float() / 255
        cmTensor[:, :, :] = cropcm.float() / 255

        # Align image 2's per-channel statistics to image 1, then clip back to [0, 1].
        img2Tensor = torch.clamp(self.apply_hist_dynamic(img1Tensor, img2Tensor), 0, 1)
        # Adaptive histogram equalisation, applied to each image independently.
        return self._equalize(img1Tensor), self._equalize(img2Tensor), cmTensor
        

#### DataLoaders for our change detection dataset

In [None]:
#Simple DataLoader class for our change detection dataset
batch_size = 32

# Positive-class weight for BCEWithLogitsLoss, built from the ratio of the two
# values returned by utils.class_weights (presumably negative/positive pixel
# statistics — TODO confirm against utils).
weights = class_weights("data.csv")
weights = torch.tensor(weights[0] / weights[1]).to(device)

train_dataset = ChangeDetectionDataset(
    csv_file="train.csv",
    data_dir="data",
    transform=None,
    batch_size=batch_size,
)
val_dataset = ChangeDetectionDataset(
    csv_file="val.csv",
    data_dir="data",
    transform=None,
    batch_size=1,
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, shuffle=True)


#### Function to display a batch of images

In [None]:
#Simple function to display a batch of images

def show_batch(batch):
    """Show each sample of an ``(img1s, img2s, cms)`` batch in its own 1x3 figure.

    Panels, left to right: first image, second image, change map (grayscale).
    Tensors are indexed (batch, channel, H, W) and permuted to (H, W, C) for imshow.
    """
    before_batch, after_batch, mask_batch = batch

    for idx in range(len(before_batch)):
        before = before_batch[idx, :, :, :].permute(1, 2, 0)
        after = after_batch[idx, :, :, :].permute(1, 2, 0)
        mask = mask_batch[idx, :, :, :].permute(1, 2, 0)

        _, axes = plt.subplots(1, 3)
        axes[0].imshow(before)
        axes[1].imshow(after)
        axes[2].imshow(mask, cmap='gray')
        plt.show()
#a = next(iter(train_loader))
#show_batch(a)

#### Model

In [None]:
n_epoch = 30
learning_rate = 0.0005
# Logits strictly above this value count as "change" (sigmoid(0.1) ~= 0.525).
logit_threshold = 0.1


def _confusion_counts(logits, target, threshold=0.1):
    """Return pixel-level ``(tp, fp, fn)`` as Python ints.

    ``logits`` are binarised with ``logits > threshold``; ``target`` is positive where >= 0.5.
    """
    pred = logits > threshold
    truth = target >= 0.5
    tp = (pred & truth).sum().item()
    fp = (pred & ~truth).sum().item()
    fn = (~pred & truth).sum().item()
    return tp, fp, fn


def _f1_from_counts(tp, fp, fn):
    """F1 = 2TP / (2TP + FP + FN); returns 0.0 when there are no positives and no predictions."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


model = ChangeDetectUnet(in_chan=9).to(device)
# The model outputs raw logits; pos_weight rebalances the rare "change" class.
loss_fn = nn.BCEWithLogitsLoss(pos_weight=weights)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
train_loss, val_loss = [], []
# These lists hold F1 scores (names kept because the plotting cell uses them).
train_precision, val_precision = [], []
best_loss = float("inf")

for epoch in range(n_epoch):
    # ---- Training ----
    model.train()
    loss_cumu = 0.0
    tp = fp = fn = 0

    for img1, img2, cm in tqdm(train_loader, ascii=" >="):
        img1, img2, cm = img1.to(device), img2.to(device), cm.to(device)

        # Forward pass
        y_pred = model(img1, img2)
        loss = loss_fn(y_pred, cm)
        loss_cumu += loss.item()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate confusion counts so F1 is computed over the whole epoch,
        # not averaged per batch (per-batch F1 is 0 for crops without change).
        b_tp, b_fp, b_fn = _confusion_counts(y_pred.detach(), cm, logit_threshold)
        tp, fp, fn = tp + b_tp, fp + b_fp, fn + b_fn

    train_loss.append(loss_cumu / len(train_loader))
    train_precision.append(_f1_from_counts(tp, fp, fn))

    # ---- Validation ----
    model.eval()
    loss_cumu = 0.0
    tp = fp = fn = 0
    with torch.no_grad():
        for img1, img2, cm in val_loader:
            img1, img2, cm = img1.to(device), img2.to(device), cm.to(device)

            y_pred = model(img1, img2)
            loss_cumu += loss_fn(y_pred, cm).item()

            b_tp, b_fp, b_fn = _confusion_counts(y_pred, cm, logit_threshold)
            tp, fp, fn = tp + b_tp, fp + b_fp, fn + b_fn

    val_loss.append(loss_cumu / len(val_loader))
    val_precision.append(_f1_from_counts(tp, fp, fn))

    # Checkpoint on the complete epoch's mean validation loss (checking inside the
    # batch loop compared a partial running sum and saved the wrong model).
    if val_loss[-1] < best_loss:
        best_loss = val_loss[-1]
        torch.save(model.state_dict(), "best_model.pt")

    print(f"Epoch {epoch+1} : Training, loss: {train_loss[-1]:.4f}, f1: {train_precision[-1]:.4f} | Validation, loss: {val_loss[-1]:.4f}, f1: {val_precision[-1]:.7f}")

In [None]:

# Plot the training curves: loss (left) and F1 score (right) per epoch.


def _as_floats(values):
    """Convert numbers or single-element tensors (possibly on GPU) to Python floats for matplotlib."""
    return [float(v) for v in values]


plt.figure(figsize=(12, 6))

# Loss curves
plt.subplot(1, 2, 1)
plt.plot(_as_floats(train_loss), label='Train Loss')
plt.plot(_as_floats(val_loss), label='Validation Loss')
plt.title('Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# F1 curves (the *_precision lists are filled with binary F1 scores)
plt.subplot(1, 2, 2)
plt.plot(_as_floats(train_precision), label='Train F1')
plt.plot(_as_floats(val_precision), label='Validation F1')
plt.title('F1 Score Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('F1 score')
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
# Qualitative check on one validation sample, on CPU, with the model's current weights.
model.cpu()
model.eval()

im1, im2, cm = next(iter(val_loader))

with torch.no_grad():
    cm_pred = model(im1, im2)  # raw logits
# Probabilities for the heatmap, and a strict 0/1 mask (logit > 0.1, same rule as training).
cm_prob = torch.sigmoid(cm_pred)
cm_pred_bin = (cm_pred > 0.1).float()

### Masks
# vmin/vmax pin the colour scale so an all-zero mask is shown as black, not auto-scaled.
plt.figure()
plt.subplot(1, 3, 1)
plt.imshow(cm[0].permute(1, 2, 0), cmap='gray', vmin=0, vmax=1)
plt.title("Ground truth")
plt.subplot(1, 3, 2)
plt.imshow(cm_prob[0].permute(1, 2, 0), cmap='plasma', vmin=0, vmax=1)
plt.title("Heatmap prediction")
plt.subplot(1, 3, 3)
plt.imshow(cm_pred_bin[0].permute(1, 2, 0), cmap='gray', vmin=0, vmax=1)
plt.title("Prediction")
plt.show()

### Bi-temporal images
dif = im1[0] - im2[0]
# Absolute difference shifted by the mean signed difference, clipped to [0, 1] for display.
dif_norm = torch.clamp((torch.abs(dif) - torch.mean(dif)), 0, 1)
plt.figure()
plt.subplot(1, 3, 1)
plt.imshow(im1[0].permute(1, 2, 0))
plt.title("Image 1")
plt.subplot(1, 3, 2)
plt.imshow(im2[0].permute(1, 2, 0))
plt.title("Image 2")
plt.subplot(1, 3, 3)
plt.imshow(dif_norm.permute(1, 2, 0))
plt.title("Normalized difference")

# Overlay: difference image, predicted probability heatmap and ground-truth mask.
plt.figure()
plt.imshow(dif_norm.permute(1, 2, 0))
plt.imshow(cm_prob[0].permute(1, 2, 0), cmap='plasma', alpha=0.15, vmin=0, vmax=1)
plt.imshow(cm[0].permute(1, 2, 0), cmap='winter', alpha=0.15)
plt.title("Normalized difference with heatmap")

plt.show()

# Notes
U-Net peut être utile

# Soutenance
 - Explication du problème et comment le transcrire
 - Pré-traitement des données
 - Architecture du réseau
 - Présentation des résultats
 
# Rendu 
 - Slides de présentation (10 minutes + 10 minutes de questions)
 - Notebook avec le code

À rendre en séance. 