In [1]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TVF
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
from network import Modele
from tqdm import tqdm
from utils import class_weights

%load_ext autoreload

%autoreload 2

In [2]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
#Simple Dataset class for our change detection dataset

class ChangeDetectionDataset(Dataset):
    def __init__(self, csv_file="data.csv", data_dir="./data", batch_size=1, transform=None, crop_size=128):
        #repeat the data 5 times to have more data
        self.data = pd.read_csv(csv_file).sample(frac=20, replace=True)
        self.data_dir = data_dir
        self.transform = transform
        self.batch_size = batch_size
        self.crop_size = crop_size
    def __len__(self):
        return len(self.data)

    def random_crop(self, img1, img2, cm, size):
        x = np.random.randint(0, img1.shape[2]-size)
        y = np.random.randint(0, img1.shape[1]-size)
        img1 = img1[:,y:y+size, x:x+size]
        img2 = img2[:,y:y+size, x:x+size]
        cm = cm[0:1, y:y+size, x:x+size]
        return img1, img2, cm

    def random_flip(self, img1,img2,cm, chance=0.5):
        if (np.random.randint(0,1)> chance):
            img1 = TVF.hflip(img1)
            img2 = TVF.hflip(img2)
            cm = TVF.hflip(cm)

        if (np.random.randint(0,1)> chance):
            img1 = TVF.vflip(img1)
            img2 = TVF.vflip(img2)
            cm = TVF.vflip(cm)

        return img1, img2, cm

    def __getitem__(self, idx):
        img1 = read_image(self.data_dir+'/'+self.data.iloc[idx,0])
        img2 = read_image(self.data_dir+'/'+self.data.iloc[idx,1])
        cm = read_image(self.data_dir+'/'+self.data.iloc[idx,2])

        img1Tensor = torch.zeros((3, self.crop_size, self.crop_size), dtype=torch.float32)
        img2Tensor = torch.zeros((3, self.crop_size, self.crop_size), dtype=torch.float32)
        cmTensor = torch.zeros((1, self.crop_size, self.crop_size), dtype=torch.float32)
        
        crop1, crop2, cropcm = self.random_crop(img1[:,:,:], img2[:,:,:], cm[:,:,:], self.crop_size)
        crop1, crop2, cropcm = self.random_flip(crop1[:,:,:], crop2[:,:,:], cropcm[:,:,:])
        img1Tensor[:,:,:] = crop1.float()/255
        img2Tensor[:,:,:] = crop2.float()/255
        cmTensor[:,:,:] = cropcm.float()/255
        #apply the same transformation to all images as batch dimension
        return img1Tensor, img2Tensor, cmTensor
        

In [4]:
#Simple DataLoader class for our change detection dataset
batch_size = 16

weights = class_weights("data.csv")
weights = torch.tensor(weights[1]/weights[0]).to(device)

train_dataset= ChangeDetectionDataset(data_dir="data",csv_file="train.csv", batch_size=batch_size, transform=None)
val_dataset = ChangeDetectionDataset(data_dir="data",csv_file="val.csv", batch_size=1, transform=None)
train_loader = DataLoader(batch_size=batch_size, dataset=train_dataset, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, shuffle=True)
#Simple function to display a batch of images

def show_batch(batch):
    img1s, img2s, cms = batch

    for i in range(len(img1s)):
        img1 = img1s[i,:,:,:]
        img2 = img2s[i,:,:,:]
        cm = cms[i,:,:,:]
        fig, ax = plt.subplots(1,3)
        ax[0].imshow(img1.permute(1,2,0))
        ax[1].imshow(img2.permute(1,2,0))
        ax[2].imshow(cm.permute(1,2,0), cmap='gray')
        plt.show()

#a = next(iter(train_loader))
#show_batch(a)

  weights = torch.tensor(weights[1]/weights[0]).to(device)


In [None]:
n_epoch = 50
learning_rate = 0.0005
model = Modele()
loss_fn = nn.BCEWithLogitsLoss(pos_weight=weights)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
train_loss, val_loss = [], []
train_precision, val_precision = [], []
model = model.to(device)
best_loss =1000


for epoch in range(n_epoch):
    model.train()
    loss_cumu = 0
    prec = 0
    for img1,img2,cm in tqdm(train_loader, ascii=" >="):
        img1,img2,cm = img1.to(device),img2.to(device),cm.to(device)
        #Forward pass
        y_pred = model(img1, img2)
        loss = loss_fn(y_pred, cm)
        loss_cumu += loss
        #Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        prec += torch.sum((torch.round(torch.sigmoid(y_pred))*cm))/(torch.sum(cm)+1)
    train_loss.append(loss_cumu/len(train_loader))
    train_precision.append(prec/len(train_loader))

    model.eval()
    loss_cumu=0
    prec=0
    with torch.no_grad():
        for img1,img2,cm in val_loader:
            img1,img2,cm = img1.to(device),img2.to(device),cm.to(device)
            y_pred=model(img1,img2)
            #y_pred, cm = torch.flatten(y_pred, start_dim=1), torch.flatten(cm, start_dim=1)
            loss = loss_fn(y_pred,cm)
            loss_cumu += loss
            prec += torch.sum((torch.round(torch.sigmoid(y_pred))*cm))/(torch.sum(cm)+1)
            if(loss_cumu/len(val_loader) < best_loss):
                best_loss = loss_cumu/len(val_loader)
                torch.save(model.state_dict(), "best_model.pt")
        val_loss.append(loss_cumu/len(val_loader))
        val_precision.append(prec/len(val_loader))
    print(f"Epoch {epoch+1} : Training, loss: {train_loss[-1]}, accuracy: {train_precision[-1]} | Validation, loss: {val_loss[-1]}, accuracy: {val_precision[-1]}")
    



Epoch 1 : Training, loss: 0.7341694235801697, accuracy: 0.8535774946212769 | Validation, loss: 0.7057780027389526, accuracy: nan




Epoch 2 : Training, loss: 0.7046517729759216, accuracy: 0.9461609125137329 | Validation, loss: 0.6965171098709106, accuracy: nan




Epoch 3 : Training, loss: 0.6858428716659546, accuracy: 0.5260029435157776 | Validation, loss: 0.6824877858161926, accuracy: nan




Epoch 4 : Training, loss: 0.6796907782554626, accuracy: 0.08978279680013657 | Validation, loss: 0.679288387298584, accuracy: nan


 29%|==>       | 7/24 [00:04<00:10,  1.65it/s]


KeyboardInterrupt: 

In [None]:
model.cpu()
model.eval()
im1,im2,cm = next(iter(train_loader))
cm_pred = model(im1,im2, with_attn=False)
plt.figure()
plt.subplot(1,2,1)
plt.imshow(cm[0].permute(1,2,0), cmap='gray')
plt.subplot(1,2,2)
plt.imshow(cm_pred.detach()[0].permute(1,2,0), cmap='gray')
plt.show()

plt.figure()
plt.subplot(1,2,1)
plt.imshow(im1[0].permute(1,2,0))
plt.subplot(1,2,2)
plt.imshow(im2[0].permute(1,2,0))
plt.show()


In [None]:
model = torch.load("best_model.pt")

model.eval()


# Notes
U-Net peut être utile

# Soutenance
 - Explication du problème et comment le transcrire
 - Pré-traitement des données
 - Architecture du réseau
 - Présentation des résultats
 
# Rendu 
 - Slides de présentation (10 minutes+ 10 min de questions)
 - Notebook avec le code

À rendre en séance. 