In [1]:
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils.dataloader import CadastreSen2Dataset
from utils.index_calculation import BSI, NDBI, NDMI, NDVI, NDWI

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs so weight initialisation, the train/val split and DataLoader
# shuffling are reproducible across "Restart & Run All".
# torch.manual_seed seeds the CPU and all CUDA devices.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load the data

Get the data

Prepare the PyTorch dataset

In [3]:
dataset = CadastreSen2Dataset(image_path="./data/")
# Some scene folders may contain no patches (the constructor reports them).
# If none are found at all, fail here with a clear message instead of a
# ZeroDivisionError in the training loop.
assert len(dataset) > 0, "No patches found under ./data/"

No numpy patches found in ./data/31035\patches
No numpy patches found in ./data/57591\patches


Prepare the training and validation dataloaders

In [4]:
batch_size = 32
num_threads = 0
split_seed = 42  # fixed so the train/validation split is identical on every run

# Split into train (80 %) and validation (20 %) subsets. A dedicated, seeded generator
# makes the split reproducible regardless of which cells ran before this one.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Only the training data benefits from reshuffling each epoch. A fixed validation
# order makes the per-epoch validation loss deterministic.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_threads)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_threads)

# Model definition

In [5]:
class Conv2DRegressionModel(nn.Module):
    """Predict the 2024 image of a patch from its 2018 image (plus a mask channel).

    Two Conv -> BatchNorm -> ReLU -> MaxPool stages downsample the input by 4 along
    each spatial axis. A single fully connected layer then maps the flattened
    features to every output pixel of every output channel.

    Args:
        int_channels: number of input channels. The output has ``int_channels - 1``
            channels. With the default of 11, the training loop feeds 10 image
            channels concatenated with 1 mask channel.
        patch_size: height and width of the square input patches. Must be divisible
            by 4 because of the two 2x2 max-pools. Defaults to 64, the previously
            hard-coded size.

    Raises:
        ValueError: if ``patch_size`` is not divisible by 4.

    NOTE(review): with the defaults the final ``nn.Linear(256*16*16, 64*64*10)`` holds
    ~2.7e9 weights (~10.7 GB in float32, before gradients and Adam state). Consider a
    convolutional decoder (e.g. ConvTranspose2d / Upsample) instead.
    """

    def __init__(self, int_channels: int = 11, patch_size: int = 64):
        super().__init__()
        if patch_size % 4 != 0:
            raise ValueError(f"patch_size must be divisible by 4, got {patch_size}")
        self.nb_channel = int_channels
        self.patch_size = patch_size
        pooled_size = patch_size // 4  # two 2x2 max-pools halve each side twice

        # Module order (indices 0..9) is unchanged, so state_dict keys such as
        # "layers.1.weight" still match previously saved checkpoints.
        self.layers = nn.Sequential(
            nn.Conv2d(int_channels, 64, kernel_size=3, padding=1),
            # The channel counts are known, so eager BatchNorm2d is used instead of
            # LazyBatchNorm2d. Its parameters exist right away, not only after a
            # first forward pass, and it has the same defaults and state_dict keys.
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2)),
            nn.Conv2d(64, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2)),
            nn.Flatten(),
            nn.Linear(256 * pooled_size * pooled_size, patch_size * patch_size * (int_channels - 1)),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Map ``X`` of shape (batch, int_channels, patch_size, patch_size) to a tensor
        of shape (batch, int_channels - 1, patch_size, patch_size)."""
        y = self.layers(X)
        # The Linear layer outputs a flat vector per sample; reshape it to image form.
        return y.view(-1, self.nb_channel - 1, self.patch_size, self.patch_size)


# Train the model

In [None]:
import ignite.metrics as im

In [6]:
n_epochs = 10
learning_rate = 0.001

model = Conv2DRegressionModel()
model = model.to(device)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

train_set_len = len(train_loader)
val_set_len = len(val_loader)
# The per-epoch averages below divide by these; fail early with a clear message.
if train_set_len == 0 or val_set_len == 0:
    raise ValueError("train_loader and val_loader must each yield at least one batch")

# Weights of the five index terms, in the order NDVI, NDWI, NDBI, NDMI, BSI
# (weighting still to be tuned). Built once instead of once per batch.
weights = torch.tensor([0.4, 0.2, 0.2, 0.1, 0.1], device=device)


def _index_losses(y_pred: torch.Tensor, y: torch.Tensor):
    """Return ``(loss_comp, loss_pred, loss_comp_pred, total_loss)`` for one batch.

    - loss_comp: indices computed from the predicted image vs. the target indices.
    - loss_pred: indices predicted directly (channels 5 onward) vs. the target indices.
    - loss_comp_pred: computed indices vs. directly predicted indices.

    Each term is a weighted sum of per-index MSE losses; total_loss is their sum.
    """
    ind_comp = [NDVI(y_pred), NDWI(y_pred), NDBI(y_pred), NDMI(y_pred), BSI(y_pred)]
    # Assumes channels 5.. of y / y_pred hold the five indices in the same order as
    # ind_comp -- TODO confirm. Split along dim=1 (channels). Zipping over the tensor
    # itself would iterate the *batch* axis and pair each index with a single sample.
    ind_real = y[:, 5:].unbind(dim=1)
    ind_pred = y_pred[:, 5:].unbind(dim=1)
    if not (len(ind_comp) == len(ind_real) == len(ind_pred) == len(weights)):
        raise ValueError(
            f"expected {len(weights)} index channels in y and y_pred, "
            f"got {len(ind_real)} and {len(ind_pred)}"
        )
    # Assumes each index function returns one value per pixel and sample, e.g.
    # (B, H, W) or (B, 1, H, W) -- TODO confirm. reshape_as aligns it with the
    # (B, H, W) targets and fails loudly on any other shape, instead of silently
    # broadcasting inside MSELoss.
    ind_comp = [comp.reshape_as(real) for comp, real in zip(ind_comp, ind_real)]

    loss_comp = sum(w * loss_fn(comp, real) for w, comp, real in zip(weights, ind_comp, ind_real))
    loss_pred = sum(w * loss_fn(pred, real) for w, pred, real in zip(weights, ind_pred, ind_real))
    loss_comp_pred = sum(w * loss_fn(comp, pred) for w, comp, pred in zip(weights, ind_comp, ind_pred))
    return loss_comp, loss_pred, loss_comp_pred, loss_comp + loss_pred + loss_comp_pred


def _run_epoch(loader, train: bool):
    """Run one pass over ``loader`` and return the per-batch mean of the four losses.

    With ``train=True`` the model is in training mode and the optimizer steps on the
    total loss. Otherwise the model is in eval mode and gradients are disabled.
    """
    model.train(train)
    sums = [0.0, 0.0, 0.0, 0.0]
    with torch.set_grad_enabled(train):
        for x, y, mask in loader:
            x, y, mask = x.to(device).float(), y.to(device).float(), mask.to(device).float()
            if train:
                optimizer.zero_grad()
            # Model input: image channels followed by the mask channel.
            # Output: the predicted 2024 image (from 2018).
            y_pred = model(torch.cat((x, mask), dim=1))
            losses = _index_losses(y_pred, y)
            if train:
                losses[-1].backward()
                optimizer.step()
            for i, loss in enumerate(losses):
                sums[i] += loss.item()
    return [s / len(loader) for s in sums]


train_loss_comp, val_loss_comp = [], []
train_loss_pred, val_loss_pred = [], []
train_loss_comp_pred, val_loss_comp_pred = [], []
train_loss_tot, val_loss_tot = [], []
best_loss = float("inf")  # best mean validation loss seen so far

for epoch in tqdm(range(n_epochs)):
    comp, pred, comp_pred, tot = _run_epoch(train_loader, train=True)
    train_loss_comp.append(comp)
    train_loss_pred.append(pred)
    train_loss_comp_pred.append(comp_pred)
    train_loss_tot.append(tot)

    comp, pred, comp_pred, tot = _run_epoch(val_loader, train=False)
    val_loss_comp.append(comp)
    val_loss_pred.append(pred)
    val_loss_comp_pred.append(comp_pred)
    val_loss_tot.append(tot)

    if tot < best_loss:
        best_loss = tot
        torch.save(model.state_dict(), "best_model.pt")

    # tqdm.write prints without breaking the progress bar, unlike print.
    tqdm.write(f"Epoch {epoch + 1}/{n_epochs} - Train loss: {train_loss_tot[-1]:.4f} - Val loss: {val_loss_tot[-1]:.4f}")

    if epoch % 10 == 0:
        torch.save(model.state_dict(), f"model_{epoch}.pt")
    

  0%|          | 0/10 [00:00<?, ?it/s]

Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), After: (10, 64, 64), Mask: (1, 64, 64)
Before: (10, 64, 64), Af

  0%|          | 0/10 [00:04<?, ?it/s]

torch.Size([32, 10, 64, 64])





RuntimeError: Input type (double) and bias type (float) should be the same

In [None]:
def plot_loss_curves(train_losses, val_losses, name):
    """Plot the per-epoch training and validation curves of one loss term.

    Epochs are numbered from 1 to match the "Epoch k/n" lines printed during training.
    The two lists may differ in length if training was interrupted mid-epoch.
    """
    fig, ax = plt.subplots()
    ax.plot(range(1, len(train_losses) + 1), train_losses, label=f"Train Loss '{name}'")
    ax.plot(range(1, len(val_losses) + 1), val_losses, label=f"Validation Loss '{name}'")
    ax.set(title=f"Loss '{name}'", xlabel="Epochs", ylabel="Loss")
    ax.legend()
    ax.grid(True)
    plt.show()


for name, train_losses, val_losses in [
    ("comp", train_loss_comp, val_loss_comp),
    ("pred", train_loss_pred, val_loss_pred),
    ("comp_pred", train_loss_comp_pred, val_loss_comp_pred),
    ("total", train_loss_tot, val_loss_tot),
]:
    plot_loss_curves(train_losses, val_losses, name)