In [1]:
# IMPORTS
import os
import numpy as np
import cv2
from sklearn.metrics import jaccard_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support as prfs
import skimage.io as io
from collections import defaultdict

# UNet imports
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.nn.functional import relu

from torch.autograd import Variable

# Custom imports
# from utilities import *

# tany
from torch.utils.data import Dataset, DataLoader
import torch.utils.data
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

from tqdm import tqdm
import random
import logging
import datetime
from tensorboardX import SummaryWriter
# import metrics
import gc


In [2]:
# initialize cuda
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## ***Deep Learning Methods***

1. UNet
2. Siamese
3. SUNet

### **UNet**

UNet Architecture
1. Encoder (Contracting Path): down sampling the input image size while depth increases

    Each Block:
    - Two 3*3 Convolutional Layers zero-padded with stride=1 Each Followed by a RELU Activation
    - Max Pooling Layer 2*2 with stride=2 (Dimension halved)(Same Depth) [‚¨á Down Sampling] 

2. Decoder

### **Siamese UNet**

1. Load the dataset using dataloaders

In [3]:
class LoadDataset(Dataset):
    def __init__(self, input_folder, transforms_list=[]):
        
        self.before_folder = os.path.join(input_folder, 'A')
        self.after_folder = os.path.join(input_folder, 'B')
        self.label_folder = os.path.join(input_folder, 'label')

        self.file_names = os.listdir(self.before_folder) # any folder msh far2a

        self.transforms = transforms_list
        
    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        before_image = io.imread(os.path.join(self.before_folder, self.file_names[idx]))
        after_image = io.imread(os.path.join(self.before_folder, self.file_names[idx]))
        label = io.imread(os.path.join(self.label_folder, self.file_names[idx]))

        
        label = label.astype('float32')  # Convert to floating point to allow division
        label = label > 0
        label = label.astype(np.int64)
        label = torch.as_tensor(label, dtype=torch.float32)
        label = label.squeeze()

        if len(self.transforms) == 2:
            before_image = self.transforms[0](before_image)
            after_image = self.transforms[1](after_image)


        return {'images': (before_image, after_image), 'label': label}
    
# Define the transformations
transform = [transforms.Compose([transforms.ToTensor()]), transforms.Compose([transforms.ToTensor()])]

# Load the dataset
dataset = LoadDataset('/kaggle/input/sat-dataset/trainval', transform)

# Split the dataset into training, test, and validation sets (80, 10, 10)
train_set, test_set = train_test_split(dataset, test_size=0.2, random_state=42)
# val_set, test_set = train_test_split(temp_set, test_size=0.5, random_state=42)

# create the DataLoader
dataloader = {
    'train': DataLoader(train_set, batch_size=16, shuffle=True),
#     'val': DataLoader(val_set, batch_size=16, shuffle=False),
    'test': DataLoader(test_set, batch_size=16, shuffle=False)
}

print("DATASET LOADED")

DATASET LOADED


2. Build the Siamese model

<img src="siamese_architecture.jpg"/>


3. Train the model

4. Test the model

In [4]:
# class FocalLoss(nn.Module):
#     def __init__(self, gamma=0, alpha=None, size_average=True):
#         super(FocalLoss, self).__init__()
#         self.gamma = gamma
#         self.alpha = alpha
#         if isinstance(alpha, (float, int)):
#             self.alpha = torch.Tensor([alpha, 1-alpha])
#         if isinstance(alpha, list):
#             self.alpha = torch.Tensor(alpha)
#         self.size_average = size_average

#     def forward(self, input, target):
#         if input.dim() > 2:
#             # N, C, H, W => N, H*W, C
#             input = input.view(input.size(0), input.size(1), -1).transpose(1, 2).contiguous()
#             # N, H*W, C => N*H*W, C
#             input = input.view(-1, input.size(2))
#         else:
#             input = input.contiguous().view(-1, input.size(1))


#         target = target.view(-1, 1)
#         logpt = F.log_softmax(input)
#         logpt = logpt.gather(1, target)
#         logpt = logpt.view(-1)
#         pt = Variable(logpt.data.exp())

#         if self.alpha is not None:
#             if self.alpha.type() != input.data.type():
#                 self.alpha = self.alpha.type_as(input.data)
#             at = self.alpha.gather(0, target.data.view(-1))
#             logpt = logpt * Variable(at)

#         loss = -1 * (1-pt)**self.gamma * logpt

#         if self.size_average:
#             return loss.mean()
#         else:
#             return loss.sum()

# # def dice_loss(logits, true, device, eps=1e-7):
# #     """Computes the S√∏rensen‚ÄìDice loss.
# #     Note that PyTorch optimizers minimize a loss. In this
# #     case, we would like to maximize the dice loss so we
# #     return the negated dice loss.
# #     Args:
# #         true: a tensor of shape [B, 1, H, W].
# #         logits: a tensor of shape [B, C, H, W]. Corresponds to
# #             the raw output or logits of the model.
# #         eps: added to the denominator for numerical stability.
# #     Returns:
# #         dice_loss: the S√∏rensen‚ÄìDice loss.
# #     """
# #     num_classes = logits.shape[1]
# #     if num_classes == 1:
# #         true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
# #         true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
# #         true_1_hot_f = true_1_hot[:, 0:1, :, :]
# #         true_1_hot_s = true_1_hot[:, 1:2, :, :]
# #         true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
# #         pos_prob = torch.sigmoid(logits)
# #         neg_prob = 1 - pos_prob
# #         probas = torch.cat([pos_prob, neg_prob], dim=1)
# #     else:
# #         true_1_hot = torch.eye(num_classes, device=device)[true.squeeze(1)]
# #         true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
# #         probas = F.softmax(logits, dim=1)
# #     true_1_hot = true_1_hot.type(logits.type())
# #     dims = (0,) + tuple(range(2, true.ndimension()))
# #     intersection = torch.sum(probas * true_1_hot, dims)
# #     cardinality = torch.sum(probas + true_1_hot, dims)
# #     dice_loss = (2. * intersection / (cardinality + eps)).mean()
# #     return (1 - dice_loss)


# # # def dice_loss(pred, target, smooth=1.):
# # #     '''
# # #      The Dice coefficient D between two sets ùê¥ and ùêµ is defined as:
# # #      D= (2√ó‚à£A‚à©B‚à£)/ (‚à£A‚à£+‚à£B‚à£)
# # #      ‚à£A‚à©B‚à£: total no of pixels in pred,gold that has +ve
# # #     '''
# # #     pred = pred.contiguous() # contiguous() is a method that is used to ensure that the tensor is stored in a contiguous block of memory.
# # #     target = target.contiguous()
    
# # #     print("Predicted:",pred)
# # #     print("Target:", target)
    

# # #     intersection = (pred * target).sum(dim=2).sum(dim=2)  # Sumation of Both Width & Height

# # #     loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))

# # #     return loss.mean()

In [5]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1-alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            # N,C,H,W => N,C,H*W
            input = input.view(input.size(0), input.size(1), -1)

            # N,C,H*W => N,H*W,C
            input = input.transpose(1, 2)

            # N,H*W,C => N*H*W,C
            input = input.contiguous().view(-1, input.size(2))


        target = target.view(-1, 1)
        logpt = F.log_softmax(input,dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1-pt)**self.gamma * logpt

        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

def dice_loss(logits, true, positive_weight=1, eps=1e-7):
    """Computes the S√∏rensen‚ÄìDice loss with weighted positive class."""
    true_1_hot = torch.eye(2, device=logits.device)[true.squeeze(1)]
    true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
    probas = F.softmax(logits, dim=1)
    true_1_hot = true_1_hot.type(logits.type())

    dims = (0,) + tuple(range(2, true.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)

    # Weighted sum of intersection and cardinality for positive class
    weighted_intersection = intersection[:, 1] * positive_weight
    weighted_cardinality = cardinality[:, 1] * positive_weight

    dice_loss = (2. * weighted_intersection / (weighted_cardinality + eps)).mean()

    return (1 - dice_loss)

### **Siamese UNet ECAM**

In [6]:
# Model
 
# The convolution block architecture consists of:
# 1. Convolution layer with kernel size 3x3 and padding 1 (in_channels, mid_channel)
# 2. Batch normalization
# 3. ReLU activation
# 4. Second convolution layer with kernel size 3x3 and padding 1 (mid_channel, out_channels)
# 5. Batch normalization
# 6. ReLU activation of the fist convolution layer with the output from second batch normalization

class ConvBlock(nn.Module):
    def __init__(self, in_channels, mid_channel, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, mid_channel, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(mid_channel)
        self.conv2 = nn.Conv2d(mid_channel, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True) # activation function (inplace modifies input directly)
    
    def forward(self, input):
        input = self.conv1(input) # first convolution layer

        # save the result of the first convolution for the last layer
        x = input

        input = self.bn1(input) # first batch normalization
        input = self.relu(input) # activation function

        input = self.conv2(input) # second convolution layer
        input = self.bn2(input)

        # add the result of the first convolution to the output of the second convolution
        input += x
        output = self.relu(input) # final activation function
        return output


# The channel attention module

class ChannelAttention(nn.Module):
    def __init__(self, in_channels, ratio = 16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_channels,in_channels//ratio,1,bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_channels//ratio, in_channels,1,bias=False)
        self.sigmod = nn.Sigmoid()

    def forward(self,x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmod(out)
    

# cuild the model
class SiameseUNetECAM(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(SiameseUNetECAM, self).__init__()
        torch.nn.Module.dump_patches = True # enables a feature in PyTorch where any changes to the module hierarchy are tracked and patches are dumped to files.

        n1 = 32     # the initial number of channels of feature map
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv0_0 = ConvBlock(input_channels, filters[0], filters[0])
        self.conv1_0 = ConvBlock(filters[0], filters[1], filters[1])

        self.Up1_0 = nn.ConvTranspose2d(filters[1], filters[1], 2, stride=2)

        self.conv2_0 = ConvBlock(filters[1], filters[2], filters[2])

        self.Up2_0 = nn.ConvTranspose2d(filters[2], filters[2], 2, stride=2)

        self.conv3_0 = ConvBlock(filters[2], filters[3], filters[3])

        self.Up3_0 = nn.ConvTranspose2d(filters[3], filters[3], 2, stride=2)
        self.conv4_0 = ConvBlock(filters[3], filters[4], filters[4])

        self.Up4_0 = nn.ConvTranspose2d(filters[4], filters[4], 2, stride=2)

        self.conv0_1 = ConvBlock(filters[0] * 2 + filters[1], filters[0], filters[0])
        self.conv1_1 = ConvBlock(filters[1] * 2 + filters[2], filters[1], filters[1])
        self.Up1_1 = nn.ConvTranspose2d(filters[1], filters[1], 2, stride=2)
        self.conv2_1 = ConvBlock(filters[2] * 2 + filters[3], filters[2], filters[2])
        self.Up2_1 = nn.ConvTranspose2d(filters[2], filters[2], 2, stride=2)
        self.conv3_1 = ConvBlock(filters[3] * 2 + filters[4], filters[3], filters[3])
        self.Up3_1 = nn.ConvTranspose2d(filters[3], filters[3], 2, stride=2)

        self.conv0_2 = ConvBlock(filters[0] * 3 + filters[1], filters[0], filters[0])
        self.conv1_2 = ConvBlock(filters[1] * 3 + filters[2], filters[1], filters[1])
        self.Up1_2 = nn.ConvTranspose2d(filters[1], filters[1], 2, stride=2)
        self.conv2_2 = ConvBlock(filters[2] * 3 + filters[3], filters[2], filters[2])
        self.Up2_2 = nn.ConvTranspose2d(filters[2], filters[2], 2, stride=2)

        self.conv0_3 = ConvBlock(filters[0] * 4 + filters[1], filters[0], filters[0])
        self.conv1_3 = ConvBlock(filters[1] * 4 + filters[2], filters[1], filters[1])
        self.Up1_3 = nn.ConvTranspose2d(filters[1], filters[1], 2, stride=2)

        self.conv0_4 = ConvBlock(filters[0] * 5 + filters[1], filters[0], filters[0])

        self.ca = ChannelAttention(filters[0] * 4, ratio=16)
        self.ca1 = ChannelAttention(filters[0], ratio=16 // 4)

        self.conv_final = nn.Conv2d(filters[0] * 4, output_channels, kernel_size=1)

        # msh fahma dy beta3mel eh bas mashy ba3deen
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


    def forward(self, xA, xB):
        '''xA'''
        x0_0A = self.conv0_0(xA)
        x1_0A = self.conv1_0(self.pool(x0_0A))
        x2_0A = self.conv2_0(self.pool(x1_0A))
        x3_0A = self.conv3_0(self.pool(x2_0A))
        # x4_0A = self.conv4_0(self.pool(x3_0A))
        '''xB'''
        x0_0B = self.conv0_0(xB)
        x1_0B = self.conv1_0(self.pool(x0_0B))
        x2_0B = self.conv2_0(self.pool(x1_0B))
        x3_0B = self.conv3_0(self.pool(x2_0B))
        x4_0B = self.conv4_0(self.pool(x3_0B))

        x0_1 = self.conv0_1(torch.cat([x0_0A, x0_0B, self.Up1_0(x1_0B)], 1))
        x1_1 = self.conv1_1(torch.cat([x1_0A, x1_0B, self.Up2_0(x2_0B)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0A, x0_0B, x0_1, self.Up1_1(x1_1)], 1))


        x2_1 = self.conv2_1(torch.cat([x2_0A, x2_0B, self.Up3_0(x3_0B)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0A, x1_0B, x1_1, self.Up2_1(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0A, x0_0B, x0_1, x0_2, self.Up1_2(x1_2)], 1))

        x3_1 = self.conv3_1(torch.cat([x3_0A, x3_0B, self.Up4_0(x4_0B)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0A, x2_0B, x2_1, self.Up3_1(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0A, x1_0B, x1_1, x1_2, self.Up2_2(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0A, x0_0B, x0_1, x0_2, x0_3, self.Up1_3(x1_3)], 1))

        output = torch.cat([x0_1, x0_2, x0_3, x0_4], 1)

        intra = torch.sum(torch.stack((x0_1, x0_2, x0_3, x0_4)), dim=0)
        ca1 = self.ca1(intra)
        output = self.ca(output) * (output + ca1.repeat(1, 4, 1, 1))
        output = self.conv_final(output)

        return (output, )

In [7]:
# some functions and definitions for training
parameters = {
  "patch_size": 256,
  "num_gpus": 1,
  "num_workers": 8,
  "num_channel": 3,
  "epochs": 1,
  "batch_size": 16,
  "learning_rate": 1e-3,
  "loss_function": "hybrid",
  "dataset_dir": "./dataset/trainval/",
  "weight_dir": "./content/",
  "log_dir": "./log/"
}

train_set = dataloader['train']
# val_set = dataloader['val']
test_set = dataloader['test']

def seed_torch(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def initialize_metrics():
    """Generates a dictionary of metrics with metrics as keys
       and empty lists as values

    Returns
    -------
    dict
        a dictionary of metrics

    """
    metrics = {
        'cd_losses': [],
        'cd_corrects': [],
        'cd_precisions': [],
        'cd_recalls': [],
        'cd_f1scores': [],
        'learning_rate': [],
        'jaccard_scores': []
    }

    return metrics

def set_metrics(metric_dict, cd_loss, cd_corrects, cd_report, lr, jaccard_score):
    """Updates metric dict with batch metrics

    Parameters
    ----------
    metric_dict : dict
        dict of metrics
    cd_loss : dict(?)
        loss value
    cd_corrects : dict(?)
        number of correct results (to generate accuracy
    cd_report : list
        precision, recall, f1 values

    Returns
    -------
    dict
        dict of  updated metrics


    """
    metric_dict['cd_losses'].append(cd_loss.item())
    metric_dict['cd_corrects'].append(cd_corrects.item())
    metric_dict['cd_precisions'].append(cd_report[0])
    metric_dict['cd_recalls'].append(cd_report[1])
    metric_dict['cd_f1scores'].append(cd_report[2])
    metric_dict['learning_rate'].append(lr)
    metric_dict['jaccard_scores'].append(jaccard_score)

    return metric_dict



def get_mean_metrics(metric_dict):
    """takes a dictionary of lists for metrics and returns dict of mean values

    Parameters
    ----------
    metric_dict : dict
        A dictionary of metrics

    Returns
    -------
    dict
        dict of floats that reflect mean metric value

    """
    return {k: np.mean(v) for k, v in metric_dict.items()}


def hybrid_loss(predictions, target):
    """Calculating the loss"""
    loss = 0

    # gamma=0, alpha=None --> CE
    focal = FocalLoss()

    for prediction in predictions:

        bce = focal(prediction, target)
        dice = dice_loss(prediction, target)
        loss += bce + dice

    return loss


def jaccard_index(pred, target, smooth=1.0):
    '''
    Jaccard Index (IoU) between two sets ùê¥ and ùêµ is defined as:
    J(A, B) = 1 - (‚à£A‚à©B‚à£ / ‚à£A‚à™B‚à£)
    Where:
    ‚à£A‚à©B‚à£: Intersection of sets A and B
    ‚à£A‚à™B‚à£: Union of sets A and B
    '''
    pred = pred.contiguous() 
    target = target.contiguous() 

    intersection = (pred * target).sum(dim=2).sum(dim=2)  
    union = pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) - intersection

    IOU = ((intersection + smooth) / (union + smooth))
    
    return 1- IOU.mean()


def calc_loss(predictions, labels, metrics, bce_weight=0.5):
    # Binary Cress Entropy
    # In PyTorch, binary_cross_entropy_with_logits is a loss function that combines a sigmoid activation function and binary cross-entropy loss.
    # However, it doesn't explicitly apply the sigmoid function to the input. Instead, it expects the input to be logits, which are the raw outputs of a model without applying any activation function.
    for prediction, label in zip(predictions, labels):
        
#         print("Prediction:", prediction)
#         print("Label:", label)
#         print(label.size(0))
        
#         print(type(label))
#         print(type(prediction))
        prediction = F.sigmoid(prediction)
    
        dice = dice_loss(prediction, label)

        bce = F.binary_cross_entropy_with_logits(prediction.float(), label.float())


        # Custom Loss function that combines bce & dice losses
        # Binary Cross-Entropy (BCE) Loss: BCE loss aims to minimize the difference between the predicted probability distribution and the ground truth binary labels.
        # It penalizes deviations from the true binary labels, typically encouraging the model to output probabilities that align well with the ground truth.
        # Dice Loss: Dice loss aims to maximize the overlap between the predicted segmentation mask and the ground truth mask.
        # It penalizes deviations from the true segmentation mask, typically encouraging the model to produce segmentations that align well with the ground truth boundaries.
        loss = bce * bce_weight + dice * (1 - bce_weight)

        jac_index=jaccard_index(prediction, label)


        metrics['bce'] += bce.data.cpu().numpy() * label.size(0)
        metrics['dice'] += dice.data.cpu().numpy() * label.size(0)
        metrics['loss'] += loss.data.cpu().numpy() * label.size(0)
        metrics['jaccrod_index']+=jac_index.data.cpu().numpy() * label.size(0)

    return loss, metrics


In [8]:
def evaluate(model):
    model.eval()
    
    jaccard_scores = []
    
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    with torch.no_grad():
        tbar = tqdm(test_set)
        
        for batch in tbar:
            # load the data to the device
            before_images = batch['images'][0].to(device)
            after_images = batch['images'][1].to(device)
            labels = batch['label'].long().to(device)
            
            
            predictions = model(before_images, after_images)
            predictions = predictions[0]
            _, predictions = torch.max(predictions, 1)
            
            print(len(predictions))
            
            for prediction, label in zip(predictions, labels):
                predicted = prediction.cpu().numpy().astype(np.uint8)
                
                ground_truth = label.cpu().numpy().astype(np.uint8)
                
                # calculate jaccard score
                
                jaccard_scores.append(jaccard_score(predicted.flatten(), ground_truth.flatten(), zero_division=1))
            
            del before_images, after_images, labels
            
        jaccard_mean = np.mean(jaccard_scores)
            
        print("Test Jaccard Mean:", jaccard_mean)
        
        return jaccard_mean
        

In [9]:
# train the model


"""
Initialize experiments log
"""
# logging.basicConfig(level=logging.INFO)
# writer = SummaryWriter(parameters['log_dir'] + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

"""
Set up environment: define paths, download data, and set device
"""
# logging.info('GPU AVAILABLE? ' + str(torch.cuda.is_available()))

seed_torch(seed=777)


"""
Load Model then define other aspects of the model
"""
# logging.info('LOADING Model')
model = SiameseUNetECAM(3, 2).to(device)

# criterion = hybrid_loss # loss function bce + dice
criterion = hybrid_loss
optimizer = torch.optim.AdamW(model.parameters(), lr=parameters['learning_rate']) # Be careful when you adjust learning rate, you can refer to the linear scaling rule
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5)

"""
 Set starting values
"""
# best_metrics = {'cd_f1scores': -1, 'cd_recalls': -1, 'cd_precisions': -1}
# logging.info('STARTING training')
total_step = -1

validation_jacc=float('-inf')

# training loop
for epoch in range(parameters['epochs']):
    epoch_loss = []
#     train_metrics = initialize_metrics()
#     val_metrics = initialize_metrics()
    
#     metrics = defaultdict(float)

    """
    Begin Training
    """
    model.train()
#     logging.info('SET model mode to train!')

    batch_iteration = 0

    tbar = tqdm(train_set)
    for batch in tbar:
        tbar.set_description("epoch {} info ".format(epoch) + str(batch_iteration) + " - " + str(batch_iteration + parameters['batch_size']))
        batch_iteration = batch_iteration + parameters['batch_size']
        total_step += 1

        # load the data to the device
        before_images = batch['images'][0].to(device)
        after_images = batch['images'][1].to(device)
        labels = batch['label'].long().to(device)

        
        # Zero the gradient
        optimizer.zero_grad()

        # Get model predictions, calculate loss, backprop
        predictions = model(before_images, after_images)

        # calculate the loss
        cd_loss = criterion(predictions, labels)
        loss = cd_loss

        # backpropagation
        loss.backward()
        optimizer.step()        
        
        predictions = predictions[0]
#         print(len(predictions))
#         print(len(labels))
        _, predictions = torch.max(predictions, 1)
    
        epoch_loss.append(loss.item())



        
#         print(len(predictions))
#         print(len(labels))

        # evaluation and metrics
#         jac_score = jaccard_score(labels.data.cpu().numpy().flatten(),
#                                 predictions.data.cpu().numpy().flatten(), 
#                                 zero_division=1)

#         jac_score = metrics['jaccrod_index'] 

#         cd_corrects = (100 *
#                        (predictions.squeeze().byte() == labels.squeeze().byte()).sum() /
#                        (labels.size()[0] * (parameters['patch_size']**2)))

#         cd_train_report = prfs(labels.data.cpu().numpy().flatten(),
#                                predictions.data.cpu().numpy().flatten(),
#                                average='binary',
#                                zero_division=0,
#                                pos_label=1)

#         train_metrics = set_metrics(train_metrics,
#                                     cd_loss,
#                                     cd_corrects,
#                                     cd_train_report,
#                                     scheduler.get_last_lr(),
#                                     0)

        # log the batch mean metrics
#         mean_train_metrics = get_mean_metrics(train_metrics)

#         for k, v in mean_train_metrics.items():
#             writer.add_scalars(str(k), {'train': v}, total_step)

        # clear batch variables from memory
        del before_images, after_images, labels
    
#     scheduler.step()
    
    current_loss=sum(epoch_loss)/len(epoch_loss)
    gc.collect()
    
    jaccard_test = evaluate(model)

    # print("JACCARD INDEX EVALUATION ",m_jaccard)
#     print("TOTAL LOSS EVALUATION ",m_loss)
    scheduler.step()
    torch.save(model.state_dict(), f"/kaggle/working/models/pretrained_epoch_{epoch_index}.pth")
    if validation_jacc<m_jaccard:

        torch.save(model.state_dict(), f"/kaggle/working/models/best_pretrained_post_aug_pretrained.pth")
        validation_jacc=m_jaccard

#     logging.info("EPOCH {} TRAIN METRICS".format(epoch) + str(mean_train_metrics))
#     print("EPOCH {} TRAIN METRICS".format(epoch) + str(mean_train_metrics))

    print('An epoch finished.')
    
    
# writer.close()  # close tensor board
print('Done!')


epoch 0 info 3888 - 3904: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 244/244 [06:53<00:00,  1.70s/it]
  0%|          | 0/61 [00:00<?, ?it/s]

16


  2%|‚ñè         | 1/61 [00:01<01:07,  1.13s/it]

16


  3%|‚ñé         | 2/61 [00:02<01:06,  1.13s/it]

16


  5%|‚ñç         | 3/61 [00:03<01:05,  1.13s/it]

16


  7%|‚ñã         | 4/61 [00:04<01:03,  1.12s/it]

16


  8%|‚ñä         | 5/61 [00:05<01:02,  1.12s/it]

16


 10%|‚ñâ         | 6/61 [00:06<01:01,  1.12s/it]

16


 11%|‚ñà‚ñè        | 7/61 [00:07<01:00,  1.12s/it]

16


 13%|‚ñà‚ñé        | 8/61 [00:08<00:59,  1.12s/it]

16


 15%|‚ñà‚ñç        | 9/61 [00:10<00:58,  1.12s/it]

16


 16%|‚ñà‚ñã        | 10/61 [00:11<00:56,  1.11s/it]

16


 18%|‚ñà‚ñä        | 11/61 [00:12<00:55,  1.11s/it]

16


 20%|‚ñà‚ñâ        | 12/61 [00:13<00:54,  1.12s/it]

16


 21%|‚ñà‚ñà‚ñè       | 13/61 [00:14<00:53,  1.11s/it]

16


 23%|‚ñà‚ñà‚ñé       | 14/61 [00:15<00:52,  1.12s/it]

16


 25%|‚ñà‚ñà‚ñç       | 15/61 [00:16<00:51,  1.12s/it]

16


 26%|‚ñà‚ñà‚ñå       | 16/61 [00:17<00:50,  1.11s/it]

16


 28%|‚ñà‚ñà‚ñä       | 17/61 [00:18<00:48,  1.11s/it]

16


 30%|‚ñà‚ñà‚ñâ       | 18/61 [00:20<00:47,  1.11s/it]

16


 31%|‚ñà‚ñà‚ñà       | 19/61 [00:21<00:46,  1.11s/it]

16


 33%|‚ñà‚ñà‚ñà‚ñé      | 20/61 [00:22<00:45,  1.12s/it]

16


 34%|‚ñà‚ñà‚ñà‚ñç      | 21/61 [00:23<00:44,  1.12s/it]

16


 36%|‚ñà‚ñà‚ñà‚ñå      | 22/61 [00:24<00:43,  1.11s/it]

16


 38%|‚ñà‚ñà‚ñà‚ñä      | 23/61 [00:25<00:42,  1.12s/it]

16


 39%|‚ñà‚ñà‚ñà‚ñâ      | 24/61 [00:26<00:41,  1.12s/it]

16


 41%|‚ñà‚ñà‚ñà‚ñà      | 25/61 [00:27<00:39,  1.11s/it]

16


 43%|‚ñà‚ñà‚ñà‚ñà‚ñé     | 26/61 [00:28<00:38,  1.11s/it]

16


 44%|‚ñà‚ñà‚ñà‚ñà‚ñç     | 27/61 [00:30<00:37,  1.11s/it]

16


 46%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 28/61 [00:31<00:36,  1.11s/it]

16


 48%|‚ñà‚ñà‚ñà‚ñà‚ñä     | 29/61 [00:32<00:35,  1.11s/it]

16


 49%|‚ñà‚ñà‚ñà‚ñà‚ñâ     | 30/61 [00:33<00:34,  1.11s/it]

16


 51%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 31/61 [00:34<00:33,  1.11s/it]

16


 52%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè    | 32/61 [00:35<00:32,  1.12s/it]

16


 54%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç    | 33/61 [00:36<00:31,  1.12s/it]

16


 56%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå    | 34/61 [00:37<00:30,  1.12s/it]

16


 57%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã    | 35/61 [00:39<00:28,  1.11s/it]

16


 59%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ    | 36/61 [00:40<00:27,  1.11s/it]

16


 61%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 37/61 [00:41<00:26,  1.11s/it]

16


 62%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè   | 38/61 [00:42<00:25,  1.11s/it]

16


 64%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç   | 39/61 [00:43<00:24,  1.11s/it]

16


 66%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå   | 40/61 [00:44<00:23,  1.11s/it]

16


 67%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 41/61 [00:45<00:22,  1.11s/it]

16


 69%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ   | 42/61 [00:46<00:21,  1.11s/it]

16


 70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 43/61 [00:47<00:19,  1.11s/it]

16


 72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 44/61 [00:48<00:18,  1.10s/it]

16


 74%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç  | 45/61 [00:50<00:17,  1.11s/it]

16


 75%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå  | 46/61 [00:51<00:16,  1.11s/it]

16


 77%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã  | 47/61 [00:52<00:15,  1.11s/it]

16


 79%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä  | 48/61 [00:53<00:14,  1.12s/it]

16


 80%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 49/61 [00:54<00:13,  1.12s/it]

16


 82%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè | 50/61 [00:55<00:12,  1.13s/it]

16


 84%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé | 51/61 [00:56<00:11,  1.12s/it]

16


 85%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå | 52/61 [00:57<00:10,  1.12s/it]

16


 87%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã | 53/61 [00:59<00:08,  1.12s/it]

16


 89%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä | 54/61 [01:00<00:07,  1.12s/it]

16


 90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 55/61 [01:01<00:06,  1.11s/it]

16


 92%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè| 56/61 [01:02<00:05,  1.11s/it]

16


 93%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé| 57/61 [01:03<00:04,  1.11s/it]

16


 95%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå| 58/61 [01:04<00:03,  1.11s/it]

16


 97%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã| 59/61 [01:05<00:02,  1.11s/it]

16


 98%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä| 60/61 [01:06<00:01,  1.11s/it]

14


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 61/61 [01:07<00:00,  1.11s/it]


Test Jaccard Mean: 0.01430858270003216


NameError: name 'epoch_index' is not defined