In [1]:
# IMPORTS
import os
import numpy as np
import cv2
from sklearn.metrics import jaccard_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support as prfs
import skimage.io as io

# UNet imports
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.nn.functional import relu


# Custom imports
from utilities import *

# tany
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

from tqdm import tqdm
import random
import logging
import datetime
from tensorboardX import SummaryWriter
import metrics


In [2]:
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
# The bare `device` expression on the last line displays the selection as the cell output.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# # Read dataset
# def read_images(folder_path):
#    '''
#     Read images from a folder and return them as a list of numpy arrays.
#     Normalize the images
#    '''
#    images = []
#    file_names = os.listdir(folder_path)
   
#    for file_name in file_names:
#       # image = np.array(io.imread(os.path.join(folder_path, file_name)))/255.0
#       # images.append(image)
#       image = np.array(io.imread(os.path.join(folder_path, file_name)))
#       images.append(image)
   
#    return images
 
# # read dataset 
# A = read_images('dataset/trainval/A') # initial images
# B = read_images('dataset/trainval/B') # images after a certain amount of time
# labels = read_images('dataset/trainval/label') # ground truth images (actual change)

# assert len(A) == len(B) == len(labels), "Number of images in A, B and labels are not equal."

In [4]:
# Split dataset into training, validation and test sets
# A_train, A_test, B_train, B_test, labels_train, labels_test = train_test_split(A, B, labels, test_size=0.3, random_state=42)

# we can use the same function to split the training set into training and validation sets
# A_train, A_val, B_train, B_val, labels_train, labels_val = train_test_split(A_train, B_train, labels_train, test_size=0.2, random_state=42)

In [5]:
# # Show a random image in the training set
# random_index = np.random.randint(0, len(A_train))

# print("Random image before change")
# io.imshow(A_train[random_index])
# io.show()

# print("Random image after change")
# io.imshow(B_train[random_index])
# io.show()

# print("Change ground truth of the random image")
# io.imshow(labels_train[random_index])
# io.show()

## ***Deep Learning Methods***

1. UNet
2. Siamese
3. SUNet

### **UNet**

UNet Architecture
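1. Encoder (Contracting Path): down-samples the input image while the feature depth increases

    Each block:
    - Two 3×3 convolutional layers (zero-padded, stride 1), each followed by a ReLU activation
    - A 2×2 max-pooling layer with stride 2, which halves the spatial dimensions while keeping the depth unchanged [⬇ down-sampling]

2. Decoder (Expanding Path): up-samples the feature maps back towards the input resolution

    Each block:
    - A 2×2 transposed convolution with stride 2, which doubles the spatial dimensions [⬆ up-sampling]
    - Concatenation with the matching encoder output (skip connection), followed by a convolution block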

In [6]:
# # First we build our convolution block
# # A convolution block for the UNet consists of two 3x3 convolutions with ReLU activation and batch normalization
# class ConvBlock(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(ConvBlock, self).__init__()
#         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
#         self.bn = nn.BatchNorm2d(out_channels)
#         self.relu = nn.ReLU()
    
#     def forward(self, input):
#         # first 3x3 convolution
#         conv_output = self.conv(input)
#         conv_output = self.bn(conv_output)
#         conv_output = self.relu(conv_output)

#         # second  convolution
#         conv_output = self.conv(conv_output)
#         conv_output = self.bn(conv_output)
#         conv_output = self.relu(conv_output)

#         return conv_output
    
# # Next we build the encoder block
# # The encoder block consists of 4 blocks with the following layers:
# # 1. A convolution block
# # 2. A max pooling layer
# # Note: The last encoder block does not have a max pooling layer

# class EncoderBlock(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(EncoderBlock, self).__init__()
#         self.conv_block = ConvBlock(in_channels, out_channels)
#         self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
    
#     def forward(self, input):
#         conv_output = self.conv_block(input)
#         pool_output = self.max_pool(conv_output)
#         return conv_output, pool_output
    
# # Next we build the decoder block
# # The decoder block consists of 4 blocks with the following layers:
# # 1. nn.ConvTranspose2d convolution (transposed convolution)
# # 2. convoloition block
# class DecoderBlock(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(DecoderBlock, self).__init__()
#         self.up_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)
#         self.conv_block = ConvBlock(2*out_channels, out_channels)
    
#     def forward(self, inputs, skip):
#         up_conv_output = self.up_conv(inputs)
#         up_conv_output = torch.cat([up_conv_output, skip], axis=1)
#         up_conv_output = self.conv_block(up_conv_output)
#         return up_conv_output


In [7]:
# # build the UNet model
# class UNet(nn.Module):
#     def __init__(self, n_classes):
#         super(UNet, self).__init__()
#         ''' Encoder '''
#         self.encoder_block1 = EncoderBlock(n_classes, 64)
#         self.encoder_block2 = EncoderBlock(64, 128)
#         self.encoder_block3 = EncoderBlock(128, 256)
#         self.encoder_block4 = EncoderBlock(256, 512)
        
#         ''' Bottle Neck'''
#         self.bottleneck = ConvBlock(512, 1024)
#         # No max pooling in the bottleneck => considered last layer in the encoder
        
#         ''' Decoder '''
#         self.decoder_block4 = DecoderBlock(1024, 512)
#         self.decoder_block3 = DecoderBlock(512, 256)
#         self.decoder_block2 = DecoderBlock(256, 128)
#         self.decoder_block1 = DecoderBlock(128, 64)

        
#         ''' Classifier '''
#         self.output_conv = nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0)
    
#     def forward(self, input):
#         # encoder
#         e1, p1 = self.encoder_block1(input)
#         e2, p2 = self.encoder_block2(p1)
#         e3, p3 = self.encoder_block3(p2)
#         e4, p4 = self.encoder_block4(p3)
        
#         # bottle neck
#         bn = self.center(p4)
        
#         # decoder
#         d4 = self.decoder_block4(bn, e4)
#         d3 = self.decoder_block3(d4, e3)
#         d2 = self.decoder_block2(d3, e2)
#         d1 = self.decoder_block1(d2, e1)
        
#         # output
#         output = self.output_conv(d1)
#         return output

In [8]:
# loss_per_epoch,bce_loss_per_epoch,dice_loss_per_epoch,jacord_index_per_epoch,loss_per_val_epoch,bce_loss_per_val_epoch,dice_loss_per_val_epoch,jacord_index_per_val_epoch = run(UNet, A, B)

### **Siamese UNet**

1. Load the dataset using dataloaders

In [9]:
class LoadDataset(Dataset):
    """Change-detection dataset yielding (before, after, label) triplets.

    ``input_folder`` must contain three sub-folders -- ``A`` (images before the
    change), ``B`` (images after the change) and ``label`` (change masks) --
    whose files share the same names.

    Parameters
    ----------
    input_folder : str
        Root folder holding the ``A``, ``B`` and ``label`` sub-folders.
    transforms_list : sequence of callables, optional
        When it holds exactly two callables, the first is applied to the
        "before" image and the second to the "after" image. Any other length
        (including the default, no transforms) leaves the images untouched.
    """

    def __init__(self, input_folder, transforms_list=None):
        self.before_folder = os.path.join(input_folder, 'A')
        self.after_folder = os.path.join(input_folder, 'B')
        self.label_folder = os.path.join(input_folder, 'label')

        # Any of the three folders would do, since the file names are shared.
        # os.listdir returns entries in arbitrary order, so sort them to make
        # sample indices (and therefore any seeded split) reproducible.
        self.file_names = sorted(os.listdir(self.before_folder))

        # Avoid a shared mutable default argument.
        self.transforms = [] if transforms_list is None else transforms_list

    def __len__(self):
        """Return the number of image triplets."""
        return len(self.file_names)

    def __getitem__(self, idx):
        """Load the triplet at position ``idx``.

        Returns
        -------
        dict
            ``{'images': (before_image, after_image), 'label': label}`` where
            ``label`` is a float32 tensor with 1.0 on changed pixels and 0.0
            elsewhere.
        """
        file_name = self.file_names[idx]
        before_image = io.imread(os.path.join(self.before_folder, file_name))
        # The "after" image lives in folder B; reading it from folder A would
        # feed the network two identical images.
        after_image = io.imread(os.path.join(self.after_folder, file_name))
        label = io.imread(os.path.join(self.label_folder, file_name))

        # Binarise the mask (any non-zero pixel counts as "changed") and drop
        # singleton dimensions.
        label = torch.as_tensor(label > 0, dtype=torch.float32).squeeze()

        if len(self.transforms) == 2:
            before_image = self.transforms[0](before_image)
            after_image = self.transforms[1](after_image)

        return {'images': (before_image, after_image), 'label': label}
    
# Define the transformations (one pipeline per image of the before/after pair)
transform = [transforms.Compose([transforms.ToTensor()]), transforms.Compose([transforms.ToTensor()])]

# Load the dataset
dataset = LoadDataset('dataset/trainval', transform)

# Split the dataset into training, validation, and test sets (70 / 15 / 15):
# 30% is held out first and then halved into validation and test.
# The split is performed on sample *indices* rather than on the dataset object
# itself: passing the dataset to train_test_split indexes every element, i.e.
# loads and keeps every image in memory. The selected indices are the same
# because the shuffle depends only on the number of samples and random_state.
all_indices = np.arange(len(dataset))
train_indices, temp_indices = train_test_split(all_indices, test_size=0.3, random_state=42)
val_indices, test_indices = train_test_split(temp_indices, test_size=0.5, random_state=42)

# Subset views load samples lazily, one at a time, when the DataLoader asks for them.
train_set = torch.utils.data.Subset(dataset, train_indices.tolist())
val_set = torch.utils.data.Subset(dataset, val_indices.tolist())
test_set = torch.utils.data.Subset(dataset, test_indices.tolist())

# create the DataLoader
dataloader = {
    'train': DataLoader(train_set, batch_size=8, shuffle=True),
    'val': DataLoader(val_set, batch_size=8, shuffle=False),
    'test': DataLoader(test_set, batch_size=8, shuffle=False)
}

print("DATASET LOADED")

DATASET LOADED


2. Build the Siamese model

<img src="siamese_architecture.jpg"/>


In [10]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d

class SiameseUNet(nn.Module):
    """SiamUnet_diff segmentation network.

    Both input images go through the same encoder (shared weights). The
    decoder starts from the pooled deepest features of the second image and,
    at every scale, is concatenated with the absolute difference of the two
    images' encoder features.

    Parameters
    ----------
    input_nbr : int
        Number of channels of each input image.
    label_nbr : int
        Number of output classes (channels of the returned log-probabilities).
    """

    def __init__(self, input_nbr, label_nbr):
        super().__init__()

        self.input_nbr = input_nbr

        # ----- Encoder, stage 1 (16 channels) -----
        self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(16)
        self.do11 = nn.Dropout2d(p=0.2)
        self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(16)
        self.do12 = nn.Dropout2d(p=0.2)

        # ----- Encoder, stage 2 (32 channels) -----
        self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(32)
        self.do21 = nn.Dropout2d(p=0.2)
        self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(32)
        self.do22 = nn.Dropout2d(p=0.2)

        # ----- Encoder, stage 3 (64 channels) -----
        self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(64)
        self.do31 = nn.Dropout2d(p=0.2)
        self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(64)
        self.do32 = nn.Dropout2d(p=0.2)
        self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(64)
        self.do33 = nn.Dropout2d(p=0.2)

        # ----- Encoder, stage 4 (128 channels) -----
        self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(128)
        self.do41 = nn.Dropout2d(p=0.2)
        self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(128)
        self.do42 = nn.Dropout2d(p=0.2)
        self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(128)
        self.do43 = nn.Dropout2d(p=0.2)

        # ----- Decoder, stage 4 -----
        # Stride-2 transposed convolution doubles the spatial size.
        self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)

        # 256 input channels = 128 up-sampled + 128 from the feature difference.
        self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(128)
        self.do43d = nn.Dropout2d(p=0.2)
        self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(128)
        self.do42d = nn.Dropout2d(p=0.2)
        self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(64)
        self.do41d = nn.Dropout2d(p=0.2)

        # ----- Decoder, stage 3 -----
        self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(64)
        self.do33d = nn.Dropout2d(p=0.2)
        self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(64)
        self.do32d = nn.Dropout2d(p=0.2)
        self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(32)
        self.do31d = nn.Dropout2d(p=0.2)

        # ----- Decoder, stage 2 -----
        self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(32)
        self.do22d = nn.Dropout2d(p=0.2)
        self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(16)
        self.do21d = nn.Dropout2d(p=0.2)

        # ----- Decoder, stage 1 -----
        self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(16)
        self.do12d = nn.Dropout2d(p=0.2)
        self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)

        # Log-probabilities over the class (channel) dimension.
        self.sm = nn.LogSoftmax(dim=1)

    def _encode(self, x):
        """Run one image through the shared encoder.

        Returns ``(skips, pooled)`` where ``skips`` holds the four pre-pooling
        feature maps (stage 1 to stage 4) and ``pooled`` is the max-pooled
        stage-4 output.
        """
        # Stage 1
        x11 = self.do11(F.relu(self.bn11(self.conv11(x))))
        x12 = self.do12(F.relu(self.bn12(self.conv12(x11))))
        x1p = F.max_pool2d(x12, kernel_size=2, stride=2)

        # Stage 2
        x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
        x22 = self.do22(F.relu(self.bn22(self.conv22(x21))))
        x2p = F.max_pool2d(x22, kernel_size=2, stride=2)

        # Stage 3
        x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
        x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
        x33 = self.do33(F.relu(self.bn33(self.conv33(x32))))
        x3p = F.max_pool2d(x33, kernel_size=2, stride=2)

        # Stage 4
        x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
        x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
        x43 = self.do43(F.relu(self.bn43(self.conv43(x42))))
        x4p = F.max_pool2d(x43, kernel_size=2, stride=2)

        return (x12, x22, x33, x43), x4p

    @staticmethod
    def _pad_like(x, reference):
        """Replicate-pad ``x`` on the right/bottom to the spatial size of ``reference``.

        Needed when an odd-sized feature map was pooled, so the up-sampled
        map comes out one pixel short of its skip connection.
        """
        pad_w = reference.size(3) - x.size(3)
        pad_h = reference.size(2) - x.size(2)
        return F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

    def forward(self, x1, x2):
        """Forward method.

        Parameters
        ----------
        x1, x2 : torch.Tensor
            The two images to compare, each with ``input_nbr`` channels.

        Returns
        -------
        torch.Tensor
            Log-softmax class scores with ``label_nbr`` channels.
        """
        # Encode both images with the same weights (image 1 first, then image 2).
        (x12_1, x22_1, x33_1, x43_1), _ = self._encode(x1)
        (x12_2, x22_2, x33_2, x43_2), x4p = self._encode(x2)

        # Stage 4d
        x4d = self._pad_like(self.upconv4(x4p), x43_1)
        x4d = torch.cat((x4d, torch.abs(x43_1 - x43_2)), 1)
        x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
        x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
        x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))

        # Stage 3d
        x3d = self._pad_like(self.upconv3(x41d), x33_1)
        x3d = torch.cat((x3d, torch.abs(x33_1 - x33_2)), 1)
        x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
        x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
        x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))

        # Stage 2d
        x2d = self._pad_like(self.upconv2(x31d), x22_1)
        x2d = torch.cat((x2d, torch.abs(x22_1 - x22_2)), 1)
        x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
        x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))

        # Stage 1d
        x1d = self._pad_like(self.upconv1(x21d), x12_1)
        x1d = torch.cat((x1d, torch.abs(x12_1 - x12_2)), 1)
        x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
        x11d = self.conv11d(x12d)

        return self.sm(x11d)

3. Train the model

In [11]:
# epochs = 10
# criterion = nn.NLLLoss()
# model = SiameseUNet(3, 2) # 3 input channels, 2 classes

# model.cuda()

# model.to(device)

# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.95)

# for epoch in tqdm(range(epochs)):
#     model.train()

#     print('Epoch {}/{}'.format(epoch + 1, epochs))
#     print('-' * 10)

#     running_loss = 0.0

#     for batch in dataloader['train']:

#         # load the data to the device
#         before_image = batch[0].to(device)
#         after_image = batch[1].to(device)
#         label = batch[2].to(device)

#         optimizer.zero_grad()

#         output = model(before_image, after_image)


#         loss = criterion(output, label.long())
#         loss.backward()
#         optimizer.step()

        
#         # Calculate accuracy
#         _, predicted_indices = torch.max(output.data, 1)

#         running_loss += loss.item()

#         predicted_indices = predicted_indices.int().cpu().numpy()

#         label_np = label.cpu().numpy()    

4. Test the model

In [12]:
# jaccard_indices=[]

# model.eval()

# for batch in dataloader['test']:
#     before_image = batch[0].to(device)
#     after_image = batch[1].to(device)
#     label = batch[2].to(device)

#     output = model(before_image, after_image)

#     _, predicted_indices = torch.max(output.data, 1)

#     predicted_indices = predicted_indices.int().cpu().numpy()

#     label_np = label.cpu().numpy()

    
#     for i in range(predicted_indices.shape[0]):
#         cv2.imwrite("/content/images/"+f'{i}'+".jpg", predicted_indices[i].reshape(256 , 256, 1)*255)
#         cv2.imwrite("/content/labels/"+f'{i}'+".jpg", label_np[i].reshape(256 , 256, 1)*255)

#     jaccard_indices.append(jaccard_score(label_np.flatten(), predicted_indices.flatten(), zero_division=1))
    

# avg_jaccard_index_sklearn = np.mean(jaccard_indices)*100

# print("AVERAGE JACCARD INDEX sklearn + ",avg_jaccard_index_sklearn)


### **Siamese UNet ECAM**

In [13]:
# Model
 
# The convolution block architecture consists of:
# 1. Convolution layer with kernel size 3x3 and padding 1 (in_channels, mid_channel)
# 2. Batch normalization
# 3. ReLU activation
# 4. Second convolution layer with kernel size 3x3 and padding 1 (mid_channel, out_channels)
# 5. Batch normalization
# 6. Residual connection: the output of the first convolution layer (before batch normalization) is added to the output of the second batch normalization, followed by a ReLU activation

class ConvBlock(nn.Module):
    """Two 3x3 convolutions with batch norm and a residual shortcut.

    The raw output of the first convolution (before its batch norm) is added
    to the batch-normalised output of the second convolution, and the sum goes
    through a final ReLU. The addition requires ``mid_channel`` and
    ``out_channels`` to be equal.
    """

    def __init__(self, in_channels, mid_channel, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, mid_channel, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(mid_channel)
        self.conv2 = nn.Conv2d(mid_channel, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # inplace=True overwrites the tensor it is given instead of allocating a new one
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input):
        # First convolution; its raw output doubles as the residual shortcut.
        shortcut = self.conv1(input)

        # conv1 -> BN -> ReLU
        features = self.relu(self.bn1(shortcut))

        # conv2 -> BN
        features = self.bn2(self.conv2(features))

        # Residual addition (in place), then the final activation.
        features += shortcut
        return self.relu(features)


# The channel attention module

class ChannelAttention(nn.Module):
    """Per-channel attention weights computed from globally pooled features.

    The input is reduced to one value per channel twice (average pooling and
    max pooling). Both descriptors pass through the same two-layer 1x1
    convolution bottleneck (``in_channels -> in_channels // ratio ->
    in_channels``); the results are summed and squashed with a sigmoid.
    """

    def __init__(self, in_channels, ratio = 16):
        super().__init__()
        hidden_channels = in_channels // ratio
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden_channels, in_channels, 1, bias=False)
        self.sigmod = nn.Sigmoid()

    def _bottleneck(self, pooled):
        """Apply the shared fc1 -> ReLU -> fc2 bottleneck to a pooled descriptor."""
        return self.fc2(self.relu1(self.fc1(pooled)))

    def forward(self,x):
        from_avg = self._bottleneck(self.avg_pool(x))
        from_max = self._bottleneck(self.max_pool(x))
        return self.sigmod(from_avg + from_max)
    

# Build the model
class SiameseUNetECAM(nn.Module):
    """Siamese network with a nested, densely connected decoder and channel attention.

    Two images (``xA`` and ``xB``) are encoded with the same ``conv{i}_0``
    blocks (shared weights). Decoder nodes ``x{i}_{j}`` combine the encoder
    features of both images at level ``i``, every earlier node of that level,
    and the up-sampled node from the level below. The four full-resolution
    nodes ``x0_1`` .. ``x0_4`` are concatenated, re-weighted by two
    ``ChannelAttention`` modules and mapped to ``output_channels`` by a 1x1
    convolution.

    NOTE(review): the layout appears to follow the SNUNet-CD design -- confirm
    against the reference the notebook was based on.
    """
    def __init__(self, input_channels, output_channels):
        """Create all layers and initialise their weights.

        Parameters
        ----------
        input_channels : int
            Number of channels of each input image.
        output_channels : int
            Number of channels produced by the final 1x1 convolution.
        """
        super(SiameseUNetECAM, self).__init__()
        torch.nn.Module.dump_patches = True # NOTE(review): this assigns a class-level attribute on nn.Module itself, so it affects every module in the process, not just this one; presumably a legacy serialization/debugging flag -- confirm it is still needed

        n1 = 32     # the initial number of channels of feature map
        # Channel width per level: [32, 64, 128, 256, 512]
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        # 2x2 max pooling shared by all encoder levels (halves the spatial size)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder blocks conv{i}_0 and the 2x up-sampling layers Up{i}_0 applied to their outputs
        self.conv0_0 = ConvBlock(input_channels, filters[0], filters[0])
        self.conv1_0 = ConvBlock(filters[0], filters[1], filters[1])

        self.Up1_0 = nn.ConvTranspose2d(filters[1], filters[1], 2, stride=2)

        self.conv2_0 = ConvBlock(filters[1], filters[2], filters[2])

        self.Up2_0 = nn.ConvTranspose2d(filters[2], filters[2], 2, stride=2)

        self.conv3_0 = ConvBlock(filters[2], filters[3], filters[3])

        self.Up3_0 = nn.ConvTranspose2d(filters[3], filters[3], 2, stride=2)
        self.conv4_0 = ConvBlock(filters[3], filters[4], filters[4])

        self.Up4_0 = nn.ConvTranspose2d(filters[4], filters[4], 2, stride=2)

        # First decoder column: inputs are the level's A and B encoder features
        # (2 * filters[i]) plus the up-sampled features of the level below (filters[i + 1]).
        self.conv0_1 = ConvBlock(filters[0] * 2 + filters[1], filters[0], filters[0])
        self.conv1_1 = ConvBlock(filters[1] * 2 + filters[2], filters[1], filters[1])
        self.Up1_1 = nn.ConvTranspose2d(filters[1], filters[1], 2, stride=2)
        self.conv2_1 = ConvBlock(filters[2] * 2 + filters[3], filters[2], filters[2])
        self.Up2_1 = nn.ConvTranspose2d(filters[2], filters[2], 2, stride=2)
        self.conv3_1 = ConvBlock(filters[3] * 2 + filters[4], filters[3], filters[3])
        self.Up3_1 = nn.ConvTranspose2d(filters[3], filters[3], 2, stride=2)

        # Second decoder column: one more same-level node is concatenated (3 * filters[i]).
        self.conv0_2 = ConvBlock(filters[0] * 3 + filters[1], filters[0], filters[0])
        self.conv1_2 = ConvBlock(filters[1] * 3 + filters[2], filters[1], filters[1])
        self.Up1_2 = nn.ConvTranspose2d(filters[1], filters[1], 2, stride=2)
        self.conv2_2 = ConvBlock(filters[2] * 3 + filters[3], filters[2], filters[2])
        self.Up2_2 = nn.ConvTranspose2d(filters[2], filters[2], 2, stride=2)

        # Third decoder column (4 * filters[i]).
        self.conv0_3 = ConvBlock(filters[0] * 4 + filters[1], filters[0], filters[0])
        self.conv1_3 = ConvBlock(filters[1] * 4 + filters[2], filters[1], filters[1])
        self.Up1_3 = nn.ConvTranspose2d(filters[1], filters[1], 2, stride=2)

        # Fourth decoder column (5 * filters[0]).
        self.conv0_4 = ConvBlock(filters[0] * 5 + filters[1], filters[0], filters[0])

        # `ca` weights the concatenation of the four full-resolution nodes (4 * filters[0] channels);
        # `ca1` weights their element-wise sum (filters[0] channels).
        self.ca = ChannelAttention(filters[0] * 4, ratio=16)
        self.ca1 = ChannelAttention(filters[0], ratio=16 // 4)

        # 1x1 convolution mapping the attended features to the output channels
        self.conv_final = nn.Conv2d(filters[0] * 4, output_channels, kernel_size=1)

        # Weight initialisation: Kaiming-normal (fan_out, ReLU gain) for every nn.Conv2d,
        # and weight = 1 / bias = 0 for every BatchNorm2d or GroupNorm layer.
        # NOTE(review): the nn.ConvTranspose2d layers do not appear to match the
        # nn.Conv2d check and would keep their default initialisation -- confirm this is intended.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


    def forward(self, xA, xB):
        '''Compute the output map for an image pair.

        Parameters
        ----------
        xA, xB : torch.Tensor
            The two images to compare; both go through the same encoder blocks.

        Returns
        -------
        tuple
            A one-element tuple ``(output,)`` holding the result of the final
            1x1 convolution.
        '''
        # Encoder pass for xA (levels 0-3 only; the level-4 features of A are not used)
        x0_0A = self.conv0_0(xA)
        x1_0A = self.conv1_0(self.pool(x0_0A))
        x2_0A = self.conv2_0(self.pool(x1_0A))
        x3_0A = self.conv3_0(self.pool(x2_0A))
        # x4_0A = self.conv4_0(self.pool(x3_0A))
        '''xB'''
        # Encoder pass for xB (all five levels)
        x0_0B = self.conv0_0(xB)
        x1_0B = self.conv1_0(self.pool(x0_0B))
        x2_0B = self.conv2_0(self.pool(x1_0B))
        x3_0B = self.conv3_0(self.pool(x2_0B))
        x4_0B = self.conv4_0(self.pool(x3_0B))

        # Nested decoder: each node concatenates (along channels) the A and B encoder
        # features of its level, all earlier nodes of that level, and the up-sampled
        # node from the level below. Only the B branch feeds the first up-sampling step.
        x0_1 = self.conv0_1(torch.cat([x0_0A, x0_0B, self.Up1_0(x1_0B)], 1))
        x1_1 = self.conv1_1(torch.cat([x1_0A, x1_0B, self.Up2_0(x2_0B)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0A, x0_0B, x0_1, self.Up1_1(x1_1)], 1))


        x2_1 = self.conv2_1(torch.cat([x2_0A, x2_0B, self.Up3_0(x3_0B)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0A, x1_0B, x1_1, self.Up2_1(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0A, x0_0B, x0_1, x0_2, self.Up1_2(x1_2)], 1))

        x3_1 = self.conv3_1(torch.cat([x3_0A, x3_0B, self.Up4_0(x4_0B)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0A, x2_0B, x2_1, self.Up3_1(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0A, x1_0B, x1_1, x1_2, self.Up2_2(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0A, x0_0B, x0_1, x0_2, x0_3, self.Up1_3(x1_3)], 1))

        # Concatenate the four full-resolution nodes (4 * filters[0] channels)
        output = torch.cat([x0_1, x0_2, x0_3, x0_4], 1)

        # Element-wise sum of the same four nodes (filters[0] channels)
        intra = torch.sum(torch.stack((x0_1, x0_2, x0_3, x0_4)), dim=0)
        ca1 = self.ca1(intra)
        # ca1 is tiled 4x along channels to match the concatenation, added to it,
        # and the result is scaled by the attention computed on the concatenation.
        output = self.ca(output) * (output + ca1.repeat(1, 4, 1, 1))
        output = self.conv_final(output)

        # Returned as a tuple: hybrid_loss iterates over its `predictions` argument
        return (output, )

In [14]:
# some functions and definitions for training
# Run configuration shared by the training cells below.
parameters = dict(
    patch_size=256,                      # side length (pixels) of the square input patches
    num_gpus=1,
    num_workers=8,
    num_channel=3,
    epochs=10,
    batch_size=8,
    learning_rate=1e-3,
    loss_function="hybrid",
    dataset_dir="./dataset/trainval/",
    weight_dir="./content/",
    log_dir="./log/",                    # root folder for the TensorBoard event files
)

# NOTE: from here on `train_set` and `val_set` refer to the DataLoaders built in the
# dataset-loading cell, replacing the split datasets those names were bound to there.
train_set = dataloader['train']
val_set = dataloader['val']

def seed_torch(seed):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), exports
    ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode with
    auto-tuning disabled.

    Parameters
    ----------
    seed : int
        The seed value applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Python, NumPy and PyTorch generators
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # cuDNN: no benchmark-based algorithm selection, deterministic kernels only
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def initialize_metrics():
    """Create a fresh metrics container.

    Returns
    -------
    dict
        Maps every tracked metric name to its own new, empty list.
    """
    metric_names = (
        'cd_losses',
        'cd_corrects',
        'cd_precisions',
        'cd_recalls',
        'cd_f1scores',
        'learning_rate',
        'jaccard_scores',
    )
    return {name: [] for name in metric_names}

def set_metrics(metric_dict, cd_loss, cd_corrects, cd_report, lr, jaccard_score):
    """Append one batch's metrics to the running metric lists.

    Parameters
    ----------
    metric_dict : dict
        Metric name -> list of per-batch values; modified in place.
    cd_loss : object with ``.item()``
        Batch loss; its ``.item()`` value is stored.
    cd_corrects : object with ``.item()``
        Batch accuracy figure; its ``.item()`` value is stored.
    cd_report : sequence
        Precision, recall and F1 at positions 0, 1 and 2.
    lr
        Learning-rate value to record as given.
    jaccard_score
        Jaccard score of the batch, recorded as given.

    Returns
    -------
    dict
        The same ``metric_dict`` object, after the update.
    """
    def record(name, value):
        metric_dict[name].append(value)

    record('cd_losses', cd_loss.item())
    record('cd_corrects', cd_corrects.item())
    record('cd_precisions', cd_report[0])
    record('cd_recalls', cd_report[1])
    record('cd_f1scores', cd_report[2])
    record('learning_rate', lr)
    record('jaccard_scores', jaccard_score)

    return metric_dict



def get_mean_metrics(metric_dict):
    """Average every metric list of a metrics dictionary.

    Parameters
    ----------
    metric_dict : dict
        Metric name -> list of recorded values.

    Returns
    -------
    dict
        Metric name -> ``np.mean`` of the recorded values, with the keys in
        the same order as the input.
    """
    averaged = {}
    for name, values in metric_dict.items():
        averaged[name] = np.mean(values)
    return averaged


def hybrid_loss(predictions, target, device):
    """Sum the focal and dice losses over every prediction.

    Parameters
    ----------
    predictions : iterable
        Model outputs; each element is scored against ``target``.
    target
        Ground-truth labels passed to both loss terms.
    device
        Forwarded to ``metrics.dice_loss``.

    Returns
    -------
    The accumulated loss, or the integer ``0`` when ``predictions`` is empty.
    """
    # gamma=0, alpha=None --> CE
    focal_criterion = metrics.FocalLoss(gamma=0, alpha=None)

    total = 0
    for output in predictions:
        focal_term = focal_criterion(output, target)
        dice_term = metrics.dice_loss(output, target, device)
        total += focal_term + dice_term

    return total

In [15]:
# train the model

# --- Initialize the experiment log --------------------------------------------
logging.basicConfig(level=logging.INFO)
# One TensorBoard run directory per launch, named after the start time.
run_name = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
writer = SummaryWriter(os.path.join(parameters['log_dir'], run_name))

# --- Set up the environment ---------------------------------------------------
logging.info('GPU AVAILABLE? ' + str(torch.cuda.is_available()))

seed_torch(seed=777)

# --- Load the model, then define the loss, optimizer and LR schedule -----------
logging.info('LOADING Model')
model = SiameseUNetECAM(3, 2).to(device)

criterion = hybrid_loss # loss function bce + dice
optimizer = torch.optim.AdamW(model.parameters(), lr=parameters['learning_rate']) # Be careful when you adjust learning rate, you can refer to the linear scaling rule
# Halve the learning rate every 8 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5)

# --- Starting values ----------------------------------------------------------
best_metrics = {'cd_f1scores': -1, 'cd_recalls': -1, 'cd_precisions': -1}
logging.info('STARTING training')
total_step = -1  # global batch counter used as the TensorBoard step

# Per-epoch history containers. They are not filled by this loop; they are kept
# so that cells relying on these names keep working.
loss_per_epoch=[]
bce_loss_per_epoch=[]
dice_loss_per_epoch=[]
jacord_index_per_epoch=[]

loss_per_val_epoch=[]
bce_loss_per_val_epoch=[]
dice_loss_per_val_epoch=[]
jacord_index_per_val_epoch=[]

# training loop
try:
    for epoch in range(parameters['epochs']):
        train_metrics = initialize_metrics()
        val_metrics = initialize_metrics()
        # Defined up front so the epoch summary below works even for an empty loader.
        mean_train_metrics = {}

        # --- Begin training ---
        model.train()
        logging.info('SET model mode to train!')

        batch_iteration = 0

        tbar = tqdm(train_set)
        for batch in tbar:
            tbar.set_description("epoch {} info ".format(epoch) + str(batch_iteration) + " - " + str(batch_iteration + parameters['batch_size']))
            batch_iteration = batch_iteration + parameters['batch_size']
            total_step += 1

            # load the data to the device
            before_images = batch['images'][0].to(device)
            after_images = batch['images'][1].to(device)
            labels = batch['label'].long().to(device)

            # Zero the gradient
            optimizer.zero_grad()

            # Get model predictions, calculate loss, backprop
            predictions = model(before_images, after_images)

            # calculate the loss
            cd_loss = criterion(predictions, labels, device)
            loss = cd_loss

            # backpropagation
            loss.backward()
            optimizer.step()

            # Hard class prediction per pixel, taken from the last model output.
            predictions = predictions[-1].detach()
            _, predictions = torch.max(predictions, 1)

            # evaluation and metrics (the flattened CPU copies are shared by both sklearn calls)
            labels_flat = labels.detach().cpu().numpy().flatten()
            predictions_flat = predictions.cpu().numpy().flatten()

            jac_score = jaccard_score(labels_flat,
                                      predictions_flat,
                                      zero_division=1)

            # Pixel accuracy in percent. Dividing by the actual number of label
            # pixels keeps the figure right for any patch size or partial batch.
            cd_corrects = 100 * (predictions == labels).sum() / labels.numel()

            cd_train_report = prfs(labels_flat,
                                   predictions_flat,
                                   average='binary',
                                   zero_division=0,
                                   pos_label=1)

            train_metrics = set_metrics(train_metrics,
                                        cd_loss,
                                        cd_corrects,
                                        cd_train_report,
                                        scheduler.get_last_lr(),
                                        jac_score)

            # log the batch mean metrics
            mean_train_metrics = get_mean_metrics(train_metrics)

            for k, v in mean_train_metrics.items():
                writer.add_scalars(str(k), {'train': v}, total_step)

            # clear batch variables from memory
            del before_images, after_images, labels

        scheduler.step()
        logging.info("EPOCH {} TRAIN METRICS".format(epoch) + str(mean_train_metrics))

        print('An epoch finished.')
finally:
    # Flush and close TensorBoard even when training is interrupted
    # (e.g. by a CUDA out-of-memory error or a keyboard interrupt).
    writer.close()

print('Done!')


INFO:root:GPU AVAILABLE? True


INFO:root:LOADING Model
INFO:root:STARTING training
INFO:root:SET model mode to train!
  logpt = F.log_softmax(input)
epoch 0 info 0 - 8:   0%|          | 0/426 [02:35<?, ?it/s]


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
