In [25]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
from torch import tensor
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn.functional as F
from skimage import io, color
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torchmetrics import JaccardIndex, Dice
from tqdm import tqdm
from sklearn.metrics import jaccard_score
import gc
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
# Directories of the two image sets (A, B) and of the label masks
folder_A, folder_B, folder_label = (
    f'testval/{subfolder}' for subfolder in ('A', 'B', 'label')
)

In [3]:
# Run on the GPU when CUDA is available, otherwise on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# Show which device was selected
print(device)

cuda


In [5]:
def load_images_from_folder(folder, is_gray = True):
    """Read every file in `folder` as an image and return them as a list.

    Files are read in sorted filename order. `os.listdir` makes no ordering
    promise, and the lists loaded from different folders are later paired by
    index, so a deterministic order is required for the pairs to line up.

    Args:
        folder: Path of the directory to read; every entry is passed to
            `io.imread`.
        is_gray: When true, each image is passed through `color.rgb2gray`
            before being stored; otherwise the uint8 array is stored as read.

    Returns:
        List with one array per file, in sorted filename order.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        img = io.imread(os.path.join(folder, filename)).astype(np.uint8)
        images.append(color.rgb2gray(img) if is_gray else img)
    return images


In [6]:
# Load the images of folder A without grayscale conversion
images_A = load_images_from_folder(folder_A, is_gray=False)

In [7]:
# Load the images of folder B without grayscale conversion
images_B = load_images_from_folder(folder_B, is_gray=False)

In [8]:
# Load the label masks as stored on disk (no grayscale conversion)
labels = load_images_from_folder(folder_label, is_gray = False)

In [9]:
print(len(images_A), len(images_B), len(labels))
# The three lists are consumed index by index, so a length mismatch would
# silently drop samples or fail later inside the dataset.
assert len(images_A) == len(images_B) == len(labels), \
    "folders A, B and label must contain the same number of images"

1200 1200 1200


In [10]:
# Convert an image to a tensor, then normalize it as (x - 0.5) / 0.5.
# Normalize's signature is (mean, std, inplace): a third positional value is
# the `inplace` flag, not a third channel, so mean and std are named here.
# NOTE: ChangeDetectionDataset does not use this pipeline; it divides by 255
# itself.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
])

In [11]:
class ChangeDetectionDataset(Dataset):
    """Dataset of (image_A, image_B, label) triples taken from three parallel sequences.

    Each item is converted to a float32 tensor and divided by 255. The two
    images additionally have their last axis moved to the front.
    """

    def __init__(self, images_A, images_B, labels):
        self.images_A, self.images_B, self.labels = images_A, images_B, labels

    def __len__(self):
        # The label sequence defines the dataset size
        return len(self.labels)

    @staticmethod
    def _as_float(array):
        """Copy `array` into a float32 tensor."""
        return torch.tensor(array, dtype=torch.float32)

    def __getitem__(self, idx):
        first, second = (
            self._as_float(stack[idx]).permute(2, 0, 1) / 255
            for stack in (self.images_A, self.images_B)
        )
        target = self._as_float(self.labels[idx]) / 255
        return first, second, target

In [12]:
class Spatial_Attention(torch.nn.Module):
    """Parameter-free attention over the two trailing (spatial) axes.

    Every activation is weighted by a sigmoid of its squared deviation from
    the per-channel spatial mean, divided by twice the per-channel variance.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Unpacking requires a 4-D input
        _, _, height, width = x.shape
        eps = np.finfo(np.float32).eps  # keeps the division finite on constant maps

        channel_mean = x.mean(dim=[2, 3], keepdim=True)
        sq_dev = (x - channel_mean) ** 2
        variance = sq_dev.sum(dim=[2, 3], keepdim=True) / (height * width)

        score = sq_dev / (2 * variance + eps) + 0.5
        return x * torch.sigmoid(score)

class First_DoubleConv(nn.Module):
    """Two stacked stages of 3x3 convolution (padding 1), batch norm and ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order determine the state_dict keys
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        return self.conv(input)

class Conv_With_Attention(nn.Module):
    """Convolution block whose output concatenates a plain and an attended half.

    A 1x1 convolution produces `out_channels // 2` channels. Those pass
    through a grouped 3x3 convolution followed by spatial attention, and the
    two halves are concatenated along the channel axis.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        half = out_channels // 2

        self.conv1 = self._conv_bn_relu(in_channels, half, kernel_size=1, padding=0, groups=1)
        self.conv2 = self._conv_bn_relu(half, half, kernel_size=3, padding=1, groups=half // 4)
        self.attention = Spatial_Attention()

    @staticmethod
    def _conv_bn_relu(c_in, c_out, kernel_size, padding, groups):
        """Bias-free stride-1 convolution followed by batch norm and ReLU."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=kernel_size, stride=1,
                      padding=padding, groups=groups, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        pointwise = self.conv1(x)
        attended = self.attention(self.conv2(pointwise))
        return torch.cat([pointwise, attended], dim=1)


class DoubleConv(nn.Module):
    """Two stacked stages of Conv_With_Attention, batch norm and ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                Conv_With_Attention(stage_in, out_channels),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute name and layer order determine the state_dict keys
        self.Conv = nn.Sequential(*stages)

    def forward(self, input):
        return self.Conv(input)


class UNet_Siamese(nn.Module):
    """Two-branch U-Net that decodes the absolute difference of the branch features.

    Each input goes through its own encoder (modules ending in "_1" for the
    first input and "_2" for the second; the branches are separate module
    instances). At each of the five encoder levels the absolute difference of
    the two feature maps is kept. The deepest difference is upsampled step by
    step and concatenated with the difference of the level above. A 1x1
    convolution and a sigmoid produce the output.

    Four 2x2 poolings are undone by four stride-2 transposed convolutions, so
    height and width should be divisible by 16 for the concatenations to match.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        widths = (32, 64, 128, 256, 512)

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder level 1 uses plain double convolutions
        self.Conv1_1 = First_DoubleConv(in_channels, widths[0])
        self.Conv1_2 = First_DoubleConv(in_channels, widths[0])
        # Encoder levels 2..5: Conv2_1, Conv2_2, ..., Conv5_1, Conv5_2
        for level, (narrow, wide) in enumerate(zip(widths, widths[1:]), start=2):
            setattr(self, f'Conv{level}_1', DoubleConv(narrow, wide))
            setattr(self, f'Conv{level}_2', DoubleConv(narrow, wide))

        # Decoder levels 5..2: Up5, Up_conv5, ..., Up2, Up_conv2.
        # Each Up halves the channels; Up_conv takes the concatenation with the skip.
        for level in range(5, 1, -1):
            wide, narrow = widths[level - 1], widths[level - 2]
            setattr(self, f'Up{level}', nn.ConvTranspose2d(wide, narrow, kernel_size=2, stride=2))
            setattr(self, f'Up_conv{level}', DoubleConv(wide, narrow))

        self.Conv_1x1 = nn.Conv2d(widths[0], out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x1, x2):
        # Encoding: run both branches level by level and keep |difference| per level
        feat_1 = self.Conv1_1(x1)
        feat_2 = self.Conv1_2(x2)
        differences = [torch.abs(feat_1 - feat_2)]
        for level in range(2, 6):
            feat_1 = getattr(self, f'Conv{level}_1')(self.Maxpool(feat_1))
            feat_2 = getattr(self, f'Conv{level}_2')(self.Maxpool(feat_2))
            differences.append(torch.abs(feat_1 - feat_2))

        # Decoding: start from the deepest difference and merge upwards
        decoded = differences.pop()
        for level in range(5, 1, -1):
            upsampled = getattr(self, f'Up{level}')(decoded)
            skip = differences.pop()
            decoded = getattr(self, f'Up_conv{level}')(torch.cat((skip, upsampled), dim=1))

        return torch.sigmoid(self.Conv_1x1(decoded))

In [13]:
# def load_checkpoint(filepath):
#     checkpoint = torch.load(filepath)
#     model = checkpoint['model']
#     model.load_state_dict(checkpoint['state_dict'])
#     for parameter in model.parameters():
#         parameter.requires_grad = False
#     model = model.to(device)
#     model.eval()
#     return model

# model = load_checkpoint('checkpoint_2.pth')
# print(model)

In [14]:
# Rebuild the architecture, restore the saved weights and switch to inference mode
model = UNet_Siamese(in_channels=3, out_channels=1)
PATH = 'model.pth'
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved from a GPU also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load(PATH, map_location=device))
model = model.to(device)
model.eval()

UNet_Siamese(
  (Maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv1_1): First_DoubleConv(
    (conv): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (Conv1_2): First_DoubleConv(
    (conv): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(in

In [19]:
# Number of decimal digits in the image count; getImageName reads
# num_of_digits as the zero-padding width of the output file names.
images_len = len(images_A)
num_of_digits = len(str(images_len))
print(num_of_digits)

4


In [22]:
def getImageName(number):
    """Return 'predicted/<number>.png' with <number> left-padded with zeros.

    The padding width is the module-level `num_of_digits`; values that are
    already at least that wide are left unchanged.
    """
    padded = str(number).rjust(num_of_digits, '0')
    return 'predicted/' + padded + '.png'

In [28]:
# Run the model over the whole dataset, save every predicted mask as a PNG and
# report three Jaccard (IoU) figures:
#   * pooled: total intersection / total union over all pixels of all images
#   * per-image mean computed with numpy
#   * per-image mean computed with sklearn's jaccard_score
# Both per-image variants score an image whose prediction and label are both
# empty as 1, so their mean can be well above the pooled value.
dataset = ChangeDetectionDataset(images_A, images_B, labels)
data_loader = DataLoader(dataset, batch_size=16)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
val_total_intersections = 0
val_total_unions = 0

# plt.imsave does not create missing folders, so make sure the target exists
output_dir = os.path.dirname(getImageName(0))
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

with torch.no_grad():
    jaccardArr = []
    skArr = []
    counter = 0
    for a, b, l in tqdm(data_loader):
        # Labels are only compared on the host, so they are not moved to the device
        a, b = a.to(device), b.to(device)
        outputs = model(a, b)
        outputs = outputs.reshape((-1, l.shape[1], l.shape[2]))
        outputs_binary = (outputs > 0.5) * 1
        if counter == 0:
            print(l.shape[1], l.shape[2])

        # One device-to-host copy per batch instead of several per sample
        pred_batch = outputs_binary.cpu().numpy()
        label_batch = l.cpu().numpy()
        val_total_intersections += np.logical_and(pred_batch, label_batch).sum()
        val_total_unions += np.logical_or(pred_batch, label_batch).sum()

        for pred, target in zip(pred_batch, label_batch):
            # Fixed vmin/vmax: without them imsave rescales to the data range,
            # so a mask made only of ones would be written as black.
            plt.imsave(getImageName(counter), pred, vmin=0, vmax=1, cmap='gray')
            counter += 1

            skArr.append(jaccard_score(pred.flatten(), target.flatten(), average='binary', zero_division=1))
            intersection = np.logical_and(pred, target).sum()
            union = np.logical_or(pred, target).sum()
            jaccardArr.append(intersection / union if union != 0 else 1)

# Same convention as the per-image scores: an empty union counts as a perfect match
val_jaccard = val_total_intersections / val_total_unions if val_total_unions != 0 else 1.0
secondWay = np.array(jaccardArr).mean()
sklearnWay = np.array(skArr).mean()
print(f'Validation Jaccard Index : {val_jaccard} | Jaccard Avg : {secondWay} | sklearnWay : {sklearnWay} ')

  0%|          | 0/75 [00:00<?, ?it/s]

256 256


100%|██████████| 75/75 [00:44<00:00,  1.67it/s]


Validation Jaccard Index : 0.4969266234190536 | Jaccard Avg : 0.8744428411173835 | sklearnWay : 0.8744428411173835 
