# In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan  5 15:10:40 2022

@author: mostafa3_local  (Alireza Mostafavi)
"""
# A network is developed for pothole semantic segmentation and
# depth estimation. First, the model is trained for
# pothole segmentation, and then the model trained for segmentation
# is further trained for pothole depth estimation. With this approach, the best results
# are obtained for both segmentation and depth estimation.
# The input for segmentation is a depth map. The input for
# segmentation can be changed to an RGB image, but the performance will
# be reduced. The input for depth estimation is an RGB image.

# To train the depth estimation network, I first pretrained it for segmentation.
# The reason is that all monocular depth estimation networks have used a powerful backbone
# such as DenseNet or ResNet (DenseNet and ResNet are image classification networks).
# However, we did not want to use this kind of heavy network with over 40M parameters.
# As a result, I decided to create my own backbone: I first pretrain the network for
# image segmentation and then continue training for depth estimation.

#%% importing data
import os
from progress.bar import Bar
import numpy as np
import cv2
# create npy from the dataset
# Raw dataset folders (RGB images and their depth maps, train and test).
TRAIN_RGB_PATH = './data/train/rgb/'
TRAIN_DEPTH_PATH = './data/train/depths/'
TEST_DEPTH_PATH = './data/test/depths/'
TEST_RGB_PATH = './data/test/rgb/'

# os.listdir() returns entries in arbitrary order, so every listing is sorted:
# sample i of an RGB array must belong to sample i of the matching depth array,
# and the cached .npy files must come out the same on every run.
# NOTE(review): this assumes an RGB image and its depth map have file names
# that sort into the same position - confirm against the dataset layout.
# Only the training RGB listing is filtered by extension.
rgb_train = sorted(file for file in os.listdir(TRAIN_RGB_PATH) if file.endswith('JPG'))
depth_train = sorted(os.listdir(TRAIN_DEPTH_PATH))
depth_test = sorted(os.listdir(TEST_DEPTH_PATH))
rgb_test = sorted(os.listdir(TEST_RGB_PATH))

# make sure the length of the dataset is the same
assert len(rgb_train) == len(depth_train), 'The length of the dataset is not the same, {} rgb vs {} depths'.format(len(rgb_train), len(depth_train))
assert len(rgb_test) == len(depth_test), 'The length of the test dataset is not the same, {} rgb vs {} depths'.format(len(rgb_test), len(depth_test))


def _read_images(folder, file_names, flags, progress):
    """Read every listed image from `folder` with cv2.imread.

    Args:
        folder: directory prefix (with trailing separator) of the images.
        file_names: file names inside `folder`, in the order to load them.
        flags: cv2 imread flag (colour or grayscale).
        progress: progress bar advanced once per image.

    Returns:
        List of decoded images in the order of `file_names`.

    Raises:
        IOError: if cv2 cannot decode a file (imread returned None), instead
            of silently caching a None entry.
    """
    images = []
    for file_name in file_names:
        img = cv2.imread(folder + file_name, flags)
        if img is None:
            raise IOError('could not read image: {}'.format(folder + file_name))
        images.append(img)
        progress.next()
    return images


# pickle the data using numpy
train_rgb_npy = []
train_depth_npy = []
test_depth_npy = []
test_rgb_npy = []

# npy files do not exist, lets create them here. This will take a while.
if not os.path.exists('./data/npy/train_rgb.npy'):
    # The progress bar is only created when there is actual work to show.
    progress_bar = Bar('Processing', max=len(rgb_train)+len(rgb_test)+len(depth_train)+len(depth_test))
    train_rgb_npy = _read_images(TRAIN_RGB_PATH, rgb_train, cv2.IMREAD_COLOR, progress_bar)
    train_depth_npy = _read_images(TRAIN_DEPTH_PATH, depth_train, cv2.IMREAD_GRAYSCALE, progress_bar)
    test_depth_npy = _read_images(TEST_DEPTH_PATH, depth_test, cv2.IMREAD_GRAYSCALE, progress_bar)
    test_rgb_npy = _read_images(TEST_RGB_PATH, rgb_test, cv2.IMREAD_COLOR, progress_bar)
    progress_bar.finish()
    # save the npy files
    os.makedirs("./data/npy", exist_ok=True)
    np.save('./data/npy/train_rgb.npy', train_rgb_npy)
    np.save('./data/npy/train_depth.npy', train_depth_npy)
    np.save('./data/npy/test_depth.npy', test_depth_npy)
    np.save('./data/npy/test_rgb.npy', test_rgb_npy)



# Load the cached arrays (either pre-existing or just written above).
if os.path.exists('./data/npy/train_rgb.npy'):
    test_rgb = np.load('data/npy/test_rgb.npy')
    test_depth = np.load('data/npy/test_depth.npy')
    train_rgb = np.load('data/npy/train_rgb.npy')
    train_depth = np.load('data/npy/train_depth.npy')
    # NOTE(review): the label arrays are not created or loaded in this section
    # (only the commented-out lines below), yet train_label / test_label are
    # used further down in this file - they must be provided before running on.
    # train_label = np.load('/home/mostafa3_local/Documents/data/train_label.npy')
    # test_label = np.load('/home/mostafa3_local/Documents/data/test_label.npy')
    # train_label = train_label.reshape(420, 400, 400, 1)


import matplotlib.pyplot as plt


# Split the held-out data: the first 100 samples stay the test set, the
# remainder becomes the validation set (deterministic, no shuffling).
#indices = np.random.permutation(test_rgb.shape[0])
# An integer index array (rather than a range object) keeps the fancy indexing
# below valid even when one side of the split is empty.
indices = np.arange(test_rgb.shape[0])
test_idx, validation_idx = indices[:100], indices[100:]
test_rgb, validation_rgb = test_rgb[test_idx,:], test_rgb[validation_idx,:]
test_depth, validation_depth = test_depth[test_idx,:], test_depth[validation_idx,:]
# NOTE(review): test_label is not defined anywhere above in this file (its
# np.load is commented out); this line raises NameError until it is loaded.
test_label, validation_label = test_label[test_idx,:], test_label[validation_idx,:]

# Give the label masks an explicit trailing channel axis (N, 400, 400, 1).
test_label = test_label.reshape(test_label.shape[0],400,400,1)
validation_label = validation_label.reshape(validation_label.shape[0],400,400,1)

# Preview one sample. Each image gets its own subplot; consecutive
# plt.imshow calls on the same axes would leave only the last one visible.
preview_fig = plt.figure(figsize=(15, 5))
for position, preview_image in enumerate((test_rgb[20], test_depth[20], test_label[20]), start=1):
    preview_fig.add_subplot(1, 3, position)
    plt.imshow(preview_image)

#%% preprocessing
import math
batch_size = 20
total_samples = len(train_rgb)
# Number of training mini-batches per epoch; ceil keeps the last partial batch.
n_iterations = math.ceil(total_samples/batch_size)
print(n_iterations)
# Same convention for validation. An integer batch count lets range() and the
# loss averaging agree; a fractional count dropped the tail samples from the
# loop while still dividing the summed loss by the fraction.
validation_iteration = math.ceil(validation_rgb.shape[0]/batch_size)
print(validation_iteration)
# NOTE(review): deliberate debugging stop - everything below this line is
# unreachable when the file is run top to bottom; cells must be run manually.
raise Exception('stop')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchsummary import summary

from torch.optim import lr_scheduler
import time

#pip install matplotlib
import matplotlib.pyplot as plt
# Device configuration
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cpu')

cuda = torch.device('cuda')
cuda0 = torch.device('cuda:0')
# Prefer the third GPU (the original hard-coded choice) when the machine has
# one; otherwise fall back to the first GPU, or to the CPU without CUDA.
if torch.cuda.is_available():
    device = torch.device('cuda:2' if torch.cuda.device_count() > 2 else 'cuda:0')
else:
    device = torch.device('cpu')

print(device)


def _to_nchw_tensor(array):
    """Convert an image batch from channels-last layout to an NCHW tensor.

    Args:
        array: NumPy array shaped (N, H, W, C), or (N, H, W) for
            single-channel images such as grayscale depth maps.

    Returns:
        torch tensor shaped (N, C, H, W) that shares memory with `array`.
    """
    array = np.asarray(array)
    if array.ndim == 3:
        # Grayscale images are loaded without a channel axis; add one so the
        # 4-axis transpose below does not fail.
        array = array[..., np.newaxis]
    return torch.from_numpy(np.transpose(array, (0, 3, 1, 2)))


# NOTE(review): the tensors keep the NumPy dtype of the cached arrays; if those
# are uint8 they presumably need converting to float (and scaling) before being
# fed to the network - confirm against the data actually used for training.
# NOTE(review): train_label is not defined in the visible part of this file.
train_rgb_tns = _to_nchw_tensor(train_rgb)
train_depth_tns = _to_nchw_tensor(train_depth)
train_label_tns = _to_nchw_tensor(train_label)

valid_rgb_tns = _to_nchw_tensor(validation_rgb)
valid_depth_tns = _to_nchw_tensor(validation_depth)
valid_label_tns = _to_nchw_tensor(validation_label)

test_rgb_tns = _to_nchw_tensor(test_rgb)
test_depth_tns = _to_nchw_tensor(test_depth)
test_label_tns = _to_nchw_tensor(test_label)
#%% learning rate scheduling

from torch.optim.lr_scheduler import _LRScheduler

class PolynomialLRDecay(_LRScheduler):
    """Polynomial learning rate decay until step reaches max_decay_steps.

    At step ``s`` every parameter group gets
    ``(base_lr - end_learning_rate) * (1 - s / max_decay_steps) ** power + end_learning_rate``;
    once ``s`` exceeds ``max_decay_steps`` the rate is ``end_learning_rate``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_decay_steps: after this step, we stop decreasing the learning rate.
        end_learning_rate: final learning rate reached when decay stops.
        power: The power of the polynomial.

    Raises:
        ValueError: if ``max_decay_steps`` is not greater than 1.
    """

    def __init__(self, optimizer, max_decay_steps, end_learning_rate=0.0001, power=1.0):
        if max_decay_steps <= 1.:
            raise ValueError('max_decay_steps should be greater than 1.')
        self.max_decay_steps = max_decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.last_step = 0
        # The base constructor records base_lrs and performs an initial step().
        super().__init__(optimizer)

    def get_lr(self):
        """Return the learning rate of every parameter group for `last_step`."""
        if self.last_step > self.max_decay_steps:
            return [self.end_learning_rate for _ in self.base_lrs]

        decay = (1 - self.last_step / self.max_decay_steps) ** (self.power)
        return [(base_lr - self.end_learning_rate) * decay + self.end_learning_rate
                for base_lr in self.base_lrs]

    def step(self, step=None):
        """Advance the schedule and write the new rates into the optimizer.

        Args:
            step: explicit step index; when None the internal counter is
                incremented by one. A value of 0 is treated as step 1.
        """
        if step is None:
            step = self.last_step + 1
        self.last_step = step if step != 0 else 1
        # Single source of truth for the formula; it also clamps to
        # end_learning_rate when a caller jumps past max_decay_steps.
        new_lrs = self.get_lr()
        for param_group, lr in zip(self.optimizer.param_groups, new_lrs):
            param_group['lr'] = lr
        # Keep the base-class bookkeeping in sync so get_last_lr() works.
        self._last_lr = new_lrs


#%% defining depthwise separable convolutions

class depthwise_separable_conv(nn.Module):
    """Depthwise-separable convolution block with a ReLU output.

    Data flow: 3x3 depthwise conv -> ReLU -> 1x1 pointwise conv -> BatchNorm -> ReLU.

    Args:
        nin: number of input channels (one depthwise group per channel).
        nout: number of output channels of the pointwise convolution.
        n_stride: stride of the depthwise convolution.
        n_padding: padding of the depthwise convolution.
    """

    def __init__(self, nin, nout, n_stride, n_padding):
        super().__init__()
        depth_multiplier = 1
        mid_channels = nin * depth_multiplier
        self.depthwise = nn.Conv2d(
            nin,
            mid_channels,
            kernel_size=3,
            stride=n_stride,
            padding=n_padding,
            groups=nin,
        )
        self.pointwise = nn.Conv2d(mid_channels, nout, kernel_size=1)
        self.bn = nn.BatchNorm2d(int(nout))
        # One in-place ReLU module is reused for both activations.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the block to `x` and return the activated feature map."""
        features = self.relu(self.depthwise(x))
        features = self.bn(self.pointwise(features))
        return self.relu(features)

class depthwise_separable_conv_sig(nn.Module):
    """Depthwise-separable convolution block with a sigmoid output.

    Data flow: 3x3 depthwise conv -> ReLU -> 1x1 pointwise conv -> BatchNorm -> sigmoid.

    Args:
        nin: number of input channels (one depthwise group per channel).
        nout: number of output channels of the pointwise convolution.
        n_stride: stride of the depthwise convolution.
        n_padding: padding of the depthwise convolution.
    """

    def __init__(self, nin, nout, n_stride, n_padding):
        super().__init__()
        depth_multiplier = 1
        mid_channels = nin * depth_multiplier
        self.depthwise = nn.Conv2d(
            nin,
            mid_channels,
            kernel_size=3,
            stride=n_stride,
            padding=n_padding,
            groups=nin,
        )
        self.pointwise = nn.Conv2d(mid_channels, nout, kernel_size=1)
        self.bn = nn.BatchNorm2d(int(nout))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the block to `x`; every output value lies in (0, 1)."""
        hidden = self.relu(self.depthwise(x))
        normalised = self.bn(self.pointwise(hidden))
        return torch.sigmoid(normalised)


#%% Attention based on Atrous Special Pooling Pyramid
class Attention(nn.Module):
    """Channel attention concatenated with a multi-dilation spatial attention map.

    The channel branch gates the input with a sigmoid of a shared MLP applied
    to the global average- and max-pooled features. The spatial branch runs
    three dilated depthwise-separable convolutions (dilations 2, 4, 8) plus a
    1x1 convolution on the input, fuses them with a 1x1 convolution and a
    sigmoid. The two results are concatenated, so the output has twice the
    input channels.

    NOTE: another class named ``Attention`` is defined further down in this
    file; being defined later, that one is what the name refers to afterwards.

    Args:
        n_input_channels: number of channels of the incoming feature map.
    """

    def __init__(self, n_input_channels):
        super().__init__()
        channels = n_input_channels
        reduced = channels // 4

        # Shared bottleneck MLP of the channel branch (both layers use a bias).
        self.fc = nn.Sequential(
            nn.Linear(channels, reduced, bias=True),
            nn.ReLU(),
            nn.Linear(reduced, channels, bias=True),
        )

        self.PW = nn.Conv2d(in_channels=channels, out_channels=channels // 2,
                            kernel_size=1, stride=1, padding='same')
        self.dilated_conv1 = self._dilated_branch(channels, 2)
        self.dilated_conv2 = self._dilated_branch(channels, 4)
        self.dilated_conv3 = self._dilated_branch(channels, 8)

        # Fuses the four half-width branches back to `channels` maps.
        self.conv1by1 = nn.Conv2d(in_channels=channels * 2, out_channels=channels,
                                  kernel_size=1, stride=1, padding='same')

    @staticmethod
    def _dilated_branch(channels, dilation):
        """Build one dilated 5x5 depthwise conv + 1x1 pointwise conv + ReLU branch."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=5, stride=1, padding='same',
                      dilation=dilation, groups=channels),
            nn.Conv2d(channels, channels // 2, kernel_size=1),
            nn.ReLU(),
        )

    def forward(self, input):
        """Return cat(channel-attended input, spatial attention map) along dim 1."""
        size = input.size()
        pooled_avg = torch.flatten(F.adaptive_avg_pool2d(input, (1, 1)), 1)
        pooled_max = torch.flatten(F.adaptive_max_pool2d(input, (1, 1)), 1)

        gate = torch.sigmoid(self.fc(pooled_avg) + self.fc(pooled_max))
        channel_attention = input * gate.view(size[0], size[1], 1, 1)

        branches = (
            self.dilated_conv1(input),
            self.dilated_conv2(input),
            self.dilated_conv3(input),
            self.PW(input),
        )
        spatial_attention = torch.sigmoid(self.conv1by1(torch.cat(branches, 1)))

        return torch.cat((channel_attention, spatial_attention), 1)

#%% Attention based on Dual attention module ( this performs better) and building the model
class Attention(nn.Module):
    """Dual attention: channel attention followed by spatial attention.

    The input is first gated per channel (sigmoid of a shared MLP applied to
    the global average- and max-pooled features). Three dilated
    depthwise-separable convolutions (dilations 4, 8, 12) of that gated map
    are fused by a 1x1 convolution and a sigmoid into a spatial mask, which
    multiplies the gated map. The output has the same shape as the input.

    Args:
        n_input_channels: number of channels of the incoming feature map.
    """

    def __init__(self, n_input_channels):
        super().__init__()
        channels = n_input_channels
        reduced = channels // 4

        # Shared bottleneck MLP of the channel branch (both layers use a bias).
        self.fc = nn.Sequential(
            nn.Linear(channels, reduced, bias=True),
            nn.ReLU(),
            nn.Linear(reduced, channels, bias=True),
        )

        self.dilated_conv1 = self._dilated_branch(channels, 4)
        self.dilated_conv2 = self._dilated_branch(channels, 8)
        self.dilated_conv3 = self._dilated_branch(channels, 12)

        # Fuses the three half-width branches back to `channels` maps.
        self.conv1by1 = nn.Conv2d(in_channels=channels * 3 // 2, out_channels=channels,
                                  kernel_size=1, stride=1, padding='same')

    @staticmethod
    def _dilated_branch(channels, dilation):
        """Build one dilated 5x5 depthwise conv + 1x1 pointwise conv branch."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=5, stride=1, padding='same',
                      dilation=dilation, groups=channels),
            nn.Conv2d(channels, channels // 2, kernel_size=1),
        )

    def forward(self, input):
        """Return the input re-weighted by channel and then spatial attention."""
        size = input.size()
        pooled_avg = torch.flatten(F.adaptive_avg_pool2d(input, (1, 1)), 1)
        pooled_max = torch.flatten(F.adaptive_max_pool2d(input, (1, 1)), 1)

        gate = torch.sigmoid(self.fc(pooled_avg) + self.fc(pooled_max))
        channel_attention = input * gate.view(size[0], size[1], 1, 1)

        branches = tuple(
            branch(channel_attention)
            for branch in (self.dilated_conv1, self.dilated_conv2, self.dilated_conv3)
        )
        spatial_attention = torch.sigmoid(self.conv1by1(torch.cat(branches, 1)))

        return channel_attention * spatial_attention



class NET(nn.Module):
    """Encoder / attention bottleneck / decoder network with a sigmoid output.

    Structure as built below:
      * entry: strided 7x7 grouped conv + 1x1 conv on a 3-channel input (H/2).
      * encoder: three pairs of depthwise-separable convs; the stride-2 ones
        (down2, down4, down6) halve the resolution down to H/16, each followed
        by dropout (p=0.2).
      * bottleneck: three identical stages at H/16, each fusing a conv branch
        and an Attention branch through a 1x1 conv.
      * decoder: four bilinear x2 upsampling stages; the first three concatenate
        an encoder feature map (x5, x3, x1) and add an Attention branch.
      * head: UP8 maps features[0] channels to 1 channel through a sigmoid.

    Args:
        features: channel widths of the five resolution levels.
            NOTE(review): the default is a mutable list; it is only read here.
    """

    def __init__(self,features=[16, 32, 64, 128,256]):
        super(NET, self).__init__()
        # Entry block: stride 2 halves the spatial size before the encoder.
        self.entry = nn.Sequential(
        nn.Conv2d(3, 3*2, kernel_size=7, padding=3, groups=3,stride =2), 
        nn.Conv2d(3*2, features[0], kernel_size=1),
        nn.ReLU()) 
        
        # One dropout module shared by all three encoder levels.
        self.dropout = nn.Dropout(0.2)
        
        # Encoder. Inputs of down2/down4/down6 are concatenations of two
        # earlier feature maps, hence the summed / doubled channel counts.
        self.down1 = depthwise_separable_conv(features[0],features[1],1,1)
        self.down2 = depthwise_separable_conv(features[1]+features[0],features[2],2,1)
        self.down3 = depthwise_separable_conv(features[2],features[2],1,1)
        self.down4 = depthwise_separable_conv(features[2]*2,features[3],2,1)
        self.down5 = depthwise_separable_conv(features[3],features[3],1,1)
        self.down6 = depthwise_separable_conv(features[3]*2,features[4],2,1)
        # Bottleneck stage 1: conv branch + attention branch fused by Conv1.
        self.mid1 = depthwise_separable_conv(features[4],features[4],1,1)
        self.Attention1 = Attention(features[4])
        self.Conv1 =  nn.Conv2d(in_channels =features[4]*2, out_channels = features[4], kernel_size = 1, stride=1, padding='same')
        self.mid2 = depthwise_separable_conv(features[4],features[4],1,1)
        
        # Bottleneck stage 2 (same layout as stage 1).
        self.mid1_1 = depthwise_separable_conv(features[4],features[4],1,1)
        self.Attention1_1 = Attention(features[4])
        self.Conv1_1 =  nn.Conv2d(in_channels =features[4]*2, out_channels = features[4], kernel_size = 1, stride=1, padding='same')
        self.mid2_1 = depthwise_separable_conv(features[4],features[4],1,1)
        
        # Bottleneck stage 3 (same layout as stage 1).
        self.mid1_2 = depthwise_separable_conv(features[4],features[4],1,1)
        self.Attention1_2 = Attention(features[4])
        self.Conv1_2 =  nn.Conv2d(in_channels =features[4]*2, out_channels = features[4], kernel_size = 1, stride=1, padding='same')
        self.mid2_2 = depthwise_separable_conv(features[4],features[4],1,1)

        # Decoder level 1 (H/16 -> H/8).
        self.UPsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=None)
        self.UP1 = depthwise_separable_conv(features[4],features[3],1,1)
        self.UP2 = depthwise_separable_conv(features[3]*2,features[3],1,1)
        self.Attention2 = Attention(features[3])
        self.Conv2 =  nn.Conv2d(in_channels =features[3]*2, out_channels = features[3], kernel_size = 1, stride=1, padding='same')
        
        # Decoder level 2 (H/8 -> H/4).
        self.UPsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=None)
        self.UP3 = depthwise_separable_conv(features[3],features[2],1,1)
        self.UP4 = depthwise_separable_conv(features[2]*2,features[2],1,1)
        self.Attention3 = Attention(features[2])
        self.Conv3 =  nn.Conv2d(in_channels =features[2]*2, out_channels = features[2], kernel_size = 1, stride=1, padding='same')

        # Decoder level 3 (H/4 -> H/2).
        self.UPsample3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=None)
        self.UP5 = depthwise_separable_conv(features[2],features[1],1,1)
        self.UP6 = depthwise_separable_conv(features[1]*2,features[1],1,1)
        self.Attention4 = Attention(features[1])
        self.Conv4 =  nn.Conv2d(in_channels =features[1]*2, out_channels = features[1], kernel_size = 1, stride=1, padding='same')
        
        # Decoder level 4 (H/2 -> H) and the single-channel sigmoid head.
        self.UPsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=None)
        self.UP7 = depthwise_separable_conv(features[1],features[0],1,1)
        self.UP8 = depthwise_separable_conv_sig(features[0],1,1,1)
        #self.UP9 = depthwise_separable_conv_sig(4,1,1,1)

  
    def forward(self, input):           #H*W
        """Run the network on `input` and return the output of the UP8 head.

        The trailing comments give the spatial size relative to the input.
        NOTE(review): the skip concatenations presumably require H and W to be
        divisible by 16 so that upsampled maps match the encoder maps - confirm.
        """
        x0 = self.entry(input)          #H/2*W/2
        x1 = self.down1(x0)             #H/2*W/2
        # Dense-style connection: each strided conv sees the two previous maps.
        x2 = torch.cat((x0, x1), 1)
        x2 = self.down2(x2)             #H/4*W/4
        x2 = self.dropout(x2)
        x3 = self.down3(x2)             #H/4*W/4
        x4 = torch.cat((x2, x3), 1)
        x4 = self.down4(x4)             #H/8*W/8
        x4 = self.dropout(x4)
        x5 = self.down5(x4)             #H/8*W/8
        x6 = torch.cat((x4, x5), 1)
        x6 = self.down6(x6)             #H/16*W/16
        x6 = self.dropout(x6)
        # Bottleneck stage 1: both branches read x6 and are fused by Conv1.
        mid1 = self.mid1(x6)            #H/16*W/16
        attention1 = self.Attention1(x6)
        conv1 = torch.relu(self.Conv1(torch.cat((attention1,mid1),1)))
        mid2 = self.mid2(conv1)                     #H/16*W/16
        
        # Bottleneck stage 2.
        mid1_1 = self.mid1_1(mid2)                  #H/16*W/16
        attention1_1 = self.Attention1_1(mid2)
        conv1_1 = torch.relu(self.Conv1_1(torch.cat((attention1_1,mid1_1),1)))
        mid2_1 = self.mid2_1(conv1_1)               #H/16*W/16
        
        
        # Bottleneck stage 3.
        mid1_2 = self.mid1_2(mid2_1)                  #H/16*W/16
        attention1_2 = self.Attention1_2(mid2_1)
        conv1_2 = torch.relu(self.Conv1_2(torch.cat((attention1_2,mid1_2),1)))
        mid2_2 = self.mid2_2(conv1_2)               #H/16*W/16
        
        # Decoder level 1: skip connection from x5; the attention branch reads
        # up1 (before the skip) and is fused with up2 (after the skip).
        upsample1 = self.UPsample1(mid2_2)          #H/8*W/8
        up1 = self.UP1(upsample1)                   #H/8*W/8
        cat1 = torch.cat((up1,x5),1)    
        up2 = self.UP2(cat1)                        #H/8*W/8
        attention2 = self.Attention2(up1)
        conv2 = torch.relu(self.Conv2(torch.cat((attention2,up2),1)))
        
        # Decoder level 2: skip connection from x3.
        upsample2 = self.UPsample2(conv2)           #H/4*W/4
        up3 = self.UP3(upsample2)                   #H/4*W/4
        cat2 = torch.cat((up3,x3),1)    
        up4 = self.UP4(cat2)                        #H/4*W/4
        attention3 = self.Attention3(up3)
        conv3 = torch.relu(self.Conv3(torch.cat((attention3,up4),1)))
        
        # Decoder level 3: skip connection from x1.
        upsample3 = self.UPsample3(conv3)            #H/2*W/2
        up5 = self.UP5(upsample3)                    #H/2*W/2
        cat3 = torch.cat((up5,x1),1)    
        up6 = self.UP6(cat3)                         #H/2*W/2
        attention4 = self.Attention4(up5)
        conv4 = torch.relu(self.Conv4(torch.cat((attention4,up6),1)))
        
        # Decoder level 4: back to the input resolution, no skip connection.
        upsample4 = self.UPsample4(conv4)            #H*W
        up7 = self.UP7(upsample4)                    #H*W
        up8 = self.UP8(up7)                          #H*W
        #up9 = self.UP9(up8)       #H*W

        return up8

#torch.cuda.empty_cache()
# Model instance that is trained below; it lives on the globally chosen device.
mynetwork = NET().to(device)
#summary(mynetwork,(3,400,400), device = device)
#%%
# Separate throw-away copy used only to print the layer summary for a
# 3x400x400 input. The device string is chosen at run time instead of being
# hard-wired to 'cuda', so this cell also runs on a CPU-only machine.
# NOTE(review): torchsummary presumably expects the plain strings 'cuda' or
# 'cpu' here, which is why `device` itself is not passed - confirm.
summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = NET().to(summary_device)
summary(net,(3,400,400), device = summary_device)

#%% Loss and optimizer
class IoU(nn.Module):
    """Combined segmentation loss: 0.5 * soft-IoU loss + 0.5 * BCE + soft-F1 loss.

    The IoU and F1 terms are computed from soft true-positive, false-positive
    and false-negative sums taken over the whole batch.

    Args:
        weight: accepted for signature compatibility; not used.
        size_average: accepted for signature compatibility; not used.
    """

    def __init__(self, weight=None, size_average=True):
        super(IoU, self).__init__()

    def forward(self, inputs, targets):
        """Return the scalar loss.

        Args:
            inputs: predicted probabilities in [0, 1] (required by the BCE term).
            targets: ground-truth mask with the same shape as `inputs`.
        """
        # Keeps both ratios finite when tp, fp and fn are all zero
        # (empty mask and empty prediction); the F1 term used to become NaN.
        epsilon = 1e-7

        tp = (targets * inputs).sum().to(torch.float32)
        fp = ((1 - targets) * inputs).sum().to(torch.float32)
        fn = (targets * (1 - inputs)).sum().to(torch.float32)

        iou = (1 - tp / (tp + fp + fn + epsilon)).to(torch.float32)
        F1Score = (1 - 2 * tp / (2 * tp + fp + fn + epsilon)).to(torch.float32)

        # Functional form: same mean-reduced BCE without building a module
        # object on every call.
        bce = nn.functional.binary_cross_entropy(inputs, targets)
        return 0.5 * iou + 0.5 * bce + F1Score


# Loss used for the segmentation stage: the combined IoU / BCE / F1 loss above.
loss_fn = IoU().to(device)
#loss_fn = nn.BCELoss() 
criterion = loss_fn

# AdamW over all parameters of the network, starting at lr = 0.003.
optimizer = torch.optim.AdamW(mynetwork.parameters(), lr=0.003, betas=(0.9, 0.999), eps=1e-07, weight_decay=0.01, amsgrad=False)
num_epochs =100
# Polynomial decay (power 0.9) from the initial lr down to 1e-5 over num_epochs
# steps; the training loop steps the scheduler once per epoch.
scheduler = PolynomialLRDecay(optimizer, max_decay_steps=num_epochs, end_learning_rate=0.00001, power=0.9)

#%% training

#PATH = '/home/mostafa3_local/Documents/saved-models/2021_12_31_new_AddedMidflow.pth'
PATH = '/home/mostafa3_local/Documents/saved-models/2021_12_31_trainOnDepth.pth'

iter_loss_saved = []    # loss of every training mini-batch
epoch_loss_saved = []   # mean training loss per epoch
valid_loss_saved = []   # mean validation loss per epoch

# Number of validation batches that are actually evaluated. The mean below is
# taken over this count, not over a possibly fractional samples/batch_size.
n_valid_batches = int(validation_iteration)

since = time.time()
for epoch in range(num_epochs):
    train_loss = 0.0
    mynetwork.train()
    # Batches are taken in fixed order (no shuffling); depth maps are the
    # network input and label masks the target in this segmentation stage.
    for i in range(n_iterations):
        batch = slice(batch_size*i, batch_size*(i+1))
        images = train_depth_tns[batch].to(device)
        labels = train_label_tns[batch].to(device)

        # Clear the gradients, then forward pass
        optimizer.zero_grad()
        outputs = mynetwork(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        train_loss += batch_loss
        iter_loss_saved.append(batch_loss)

        if (i+1) % 2 == 0:
            print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_iterations}], Loss: {batch_loss:.4f}')

    scheduler.step(epoch)
    mynetwork.eval()
    valid_loss = 0.0
    # No gradients are needed for validation; without no_grad every batch
    # would build (and keep alive) an autograd graph.
    with torch.no_grad():
        for i in range(n_valid_batches):
            batch = slice(batch_size*i, batch_size*(i+1))
            images = valid_depth_tns[batch].to(device)
            labels = valid_label_tns[batch].to(device)

            # Forward pass and loss
            pred_valid_depth = mynetwork(images)
            loss = criterion(pred_valid_depth, labels)
            valid_loss += loss.item()
    mean_valid_loss = valid_loss / max(n_valid_batches, 1)
    print(f'Epoch {epoch+1} \t\t Training Loss: {train_loss / n_iterations} \t\t Validation Loss: {mean_valid_loss}')
    epoch_loss_saved.append(train_loss / n_iterations)
    valid_loss_saved.append(mean_valid_loss)
    print('minimum loss: ', np.min(valid_loss_saved))

    if valid_loss_saved[epoch] <= np.min(valid_loss_saved):
        # NOTE(review): the checkpoint write is commented out, so the message
        # below is printed although nothing is saved.
        #torch.save(mynetwork, PATH)
        print('model saved')

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

#%%  saving model
# Persist the trained segmentation network. torch.save is given the module
# object itself (not a state_dict), so the whole model is pickled to PATH.
PATH = '/home/mostafa3_local/Documents/saved-models/2021_12_31_trainOnDepth_man.pth'
torch.save(mynetwork, PATH)
#mynetwork = torch.load(PATH)


#%% Evaluation, disparity map as input


# Learning curves collected by the training loop.
plt.plot(epoch_loss_saved)
plt.plot(valid_loss_saved)
plt.xlabel('epoch')
plt.ylabel('loss')
# A list, not a set: labels are matched to the curves by position, and a set
# literal has no guaranteed order.
plt.legend(['training loss', 'validation loss'])

# Start offsets of the test mini-batches (the last batch may be shorter).
test_batch_starts = range(0, test_depth_tns.shape[0], batch_size)
test_iteration = len(test_batch_starts)
print(test_iteration)
test_loss = 0
test_predictions = []
mynetwork.eval()
with torch.no_grad():
    for start in test_batch_starts:
        images = test_depth_tns[start:start + batch_size].to(device)
        labels = test_label_tns[start:start + batch_size].to(device)
        # One forward pass serves both the stored prediction and the loss.
        pred_test_label0 = mynetwork(images)
        test_predictions.append(pred_test_label0.to('cpu').numpy())
        test_loss += criterion(pred_test_label0, labels).item()
# Sized from the data instead of a hard-coded (100, 1, 400, 400); float64 to
# match the dtype of the previously pre-allocated array.
pred_test_label = np.concatenate(test_predictions).astype(np.float64)
#pred_test_depth = mynetwork(test_rgb_tns[batch_size*i:batch_size*(i+1),:,:,:])
test_loss = test_loss/test_iteration
print(f'test Loss: {test_loss}')


print(pred_test_label.shape)
# Channels-first (N, C, H, W) -> channels-last (N, H, W, C) for plotting/metrics.
pred_test_label1 = np.moveaxis(pred_test_label, 1, 3)
print(pred_test_label1.shape)
print(np.max(pred_test_label1))


# Predictions for the whole training set. The loop covers every training
# batch; it used to run for the number of *test* batches, which left all
# remaining rows of the result at zero.
train_predictions = []
with torch.no_grad():
    for start in range(0, train_depth_tns.shape[0], batch_size):
        images = train_depth_tns[start:start + batch_size].to(device)
        train_predictions.append(mynetwork(images).to('cpu').numpy())
pred_train_label = np.concatenate(train_predictions).astype(np.float64)

pred_train_label = np.moveaxis(pred_train_label, 1, 3)


#%%  Evaluation

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero.

    Mirrors the convention this script already used for precision (0 when
    there are no positive predictions) and keeps NaN out of the averages.
    """
    return numerator / denominator if denominator != 0 else 0.0


# Binarise the predicted probabilities at 0.5 (dtype of the predictions kept).
pred_test_label2 = (pred_test_label1 > 0.5).astype(pred_test_label1.dtype)

# Complements (1 where the original is 0) used to count negatives.
pred_test_label_trp = (pred_test_label2 == 0).astype(pred_test_label2.dtype)
test_label_trp = (test_label == 0).astype(test_label.dtype)

# Running sums of the per-image metrics; divided by the image count on print.
mPrecition =0;mRecall=0;mAccuracy=0;mF1Score=0;mIoU =0

n_eval_images = pred_test_label2.shape[0]
for i in range(n_eval_images):
    # Confusion-matrix counts of image i.
    # NOTE(review): the products assume the label masks hold 0/1 values - confirm.
    TP = np.sum(np.multiply(pred_test_label2[i],test_label[i]))
    FN = np.sum(np.multiply(pred_test_label_trp[i],test_label[i]))
    FP = np.sum(np.multiply(pred_test_label2[i],test_label_trp[i]))
    TN = np.sum(np.multiply(pred_test_label_trp[i],test_label_trp[i]))

    Precition = _safe_ratio(TP, TP+FP)
    Recall = _safe_ratio(TP, TP+FN)
    Accuracy = _safe_ratio(TP+TN, TN+TP+FP+FN)
    F1Score = _safe_ratio(2*TP, 2*TP+FP+FN)
    # Deliberately not named `IoU`: that would rebind the module-level name of
    # the IoU loss class defined earlier in this file.
    iou_score = _safe_ratio(TP, TP+FP+FN)

    mPrecition += Precition
    mRecall += Recall
    mAccuracy += Accuracy
    mF1Score += F1Score
    mIoU += iou_score

print(f' mPrecition: {mPrecition/n_eval_images} \n mRecall: {mRecall/n_eval_images} \n mAccuracy: {mAccuracy/n_eval_images} \n mF1Score: {mF1Score/n_eval_images} \n mIoU: {mIoU/n_eval_images}')
    
  
#%% 
##################################################################################################
######################################## Visualization ###########################################
##################################################################################################


def showdepth_train(n1,n2,n3):
    """Plot three training samples side by side in a 3x3 grid.

    Each column is one sample (n1, n2, n3 index the training arrays); the rows
    show, top to bottom, train_rgb, train_label and pred_train_label.
    """
    fig = plt.figure(figsize=(15, 15))
    row_sources = (train_rgb, train_label, pred_train_label)
    # Column-major fill: all three rows of one sample before the next sample.
    for column, sample in enumerate((n1, n2, n3)):
        for row, source in enumerate(row_sources):
            fig.add_subplot(3, 3, 3 * row + column + 1)
            plt.imshow(source[sample])
            plt.axis('off')


def showdepths(n1,n2,n3):
    """Plot three test samples side by side in a 3x3 grid.

    Each column is one sample (n1, n2, n3 index the test arrays); the rows
    show, top to bottom, test_rgb, test_label and the thresholded prediction
    pred_test_label2.
    """
    fig = plt.figure(figsize=(15, 15))
    row_sources = (test_rgb, test_label, pred_test_label2)
    # Column-major fill: all three rows of one sample before the next sample.
    for column, sample in enumerate((n1, n2, n3)):
        for row, source in enumerate(row_sources):
            fig.add_subplot(3, 3, 3 * row + column + 1)
            plt.imshow(source[sample])
            plt.axis('off')

#%%
# Show training samples 5-7 and test samples 30-32 (image / label / prediction).
showdepth_train(5,6,7)
showdepths(30,31,32)


#%% Continue training for depth estimation

# Replace the single-channel sigmoid head with a 3-channel one.
mynetwork.UP8 = depthwise_separable_conv_sig(16,3,1,1).to(device) # features[0] =16

PATH = '/home/mostafa3_local/Documents/saved-models/2021_12_31_new_dualattention.pth'
#mynetwork.load_state_dict(torch.load(PATH,map_location=torch.device('cuda')))
# NOTE(review): this rebinds `mynetwork` to the object stored at PATH, so the
# UP8 replacement made above is discarded together with the old object. If the
# checkpoint does not already contain the 3-channel head, the two statements
# are presumably in the wrong order - confirm what the checkpoint holds.
mynetwork = torch.load(PATH)

# Sanity check: output size for a batch of five RGB training images.
mynetwork(train_rgb_tns[0:5].to(device)).size()

 
#%% scale invarient loss + gradient maching loss

import kornia as K
#scale invariant loss // inputs shoud be normalized between 0 and 1
class SIGMLoss(nn.Module):
    """Scale-invariant log loss plus a gradient-matching loss.

    The scale-invariant term works on log(x + 1) differences; the
    gradient-matching term is the mean absolute difference between the
    first-order spatial gradients of prediction and target.

    Args:
        weight: accepted for signature compatibility; not used.
        size_average: accepted for signature compatibility; not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets):
        """Return SI_loss + GM_loss for 4-D `inputs` and `targets`."""
        # --- scale-invariant term ---
        log_diff = torch.log(inputs + 1) - torch.log(targets + 1)
        shape = inputs.shape
        element_count = shape[0] * shape[1] * shape[2] * shape[3]
        SI_loss = torch.mean(log_diff ** 2) - (torch.sum(log_diff) ** 2) / element_count ** 2

        # --- gradient-matching term ---
        # spatial_gradient returns B x C x 2 x H x W; index 0 / 1 on dim 2
        # select the two gradient directions.
        pred_grads = K.filters.spatial_gradient(inputs, order=1)
        true_grads = K.filters.spatial_gradient(targets, order=1)
        first_dir_error = torch.abs(pred_grads[:, :, 0] - true_grads[:, :, 0])
        second_dir_error = torch.abs(pred_grads[:, :, 1] - true_grads[:, :, 1])
        GM_loss = torch.mean(first_dir_error + second_dir_error)

        return SI_loss + GM_loss

#%% SSIM + MSE loss
from piqa import SSIM

class SSIMLoss(SSIM):
    """SSIM-based loss combined with mean squared error.

    NOTE: another class named ``SSIMLoss`` is defined further down in this
    file; being defined later, that one is what the name refers to afterwards.
    """

    def forward(self, x, y):
        """Return 0.7 * (1 - SSIM(x, y)) + MSE(x, y)."""
        structural_term = 1 - super().forward(x, y)
        # Mean-reduced squared error, same as a default nn.MSELoss().
        pixel_term = nn.functional.mse_loss(x, y)
        # A gradient-matching term was tried here as well and left disabled.
        return 0.7 * structural_term + pixel_term
#%% scale invarient loss + gradient maching loss + ssim loss + MSE loss
# this loss performs better 
from piqa import SSIM
import kornia as K
#scale invariant loss // inputs shoud be normalized between 0 and 1
class SSIMLoss(SSIM):
    """Combined loss: scale-invariant + gradient-matching + SSIM + MSE.

    The result is ``SI + 0.5 * GM + 0.6 * (1 - SSIM) + MSE``. Both tensors are
    shifted by 1 before the logarithm of the scale-invariant term is taken.
    """

    def forward(self, inputs, targets):
        """Return the weighted sum of the four loss terms for 4-D tensors."""
        structural_term = 1 - super().forward(inputs, targets)

        # Scale-invariant term: mean(d^2) - (sum d)^2 / n^2, where d is the
        # difference of the shifted logs and n the number of elements.
        log_diff = torch.log(inputs + 1) - torch.log(targets + 1)
        count = 1
        for axis in range(4):
            count *= inputs.shape[axis]
        si_term = torch.mean(log_diff ** 2) - (torch.sum(log_diff) ** 2) / count ** 2

        # Gradient-matching term: mean absolute difference of the first-order
        # spatial gradients along both directions.
        pred_grads = K.filters.spatial_gradient(inputs, order=1)  # BxCx2xHxW
        true_grads = K.filters.spatial_gradient(targets, order=1)  # BxCx2xHxW
        gap_first = torch.abs(pred_grads[:, :, 0] - true_grads[:, :, 0])
        gap_second = torch.abs(pred_grads[:, :, 1] - true_grads[:, :, 1])
        gm_term = torch.mean(gap_first + gap_second)

        squared_error = nn.MSELoss()
        return si_term + 0.5 * gm_term + 0.6 * structural_term + squared_error(inputs, targets)
#%%

# Loss module, moved to the compute device.
criterion = SSIMLoss().to(device)

# AdamW with a polynomial learning-rate decay from 0.005 down to 0.00001 over
# num_epochs steps. (An earlier configuration used lr=0.00001 with the same
# remaining AdamW settings.)
optimizer = torch.optim.AdamW(
    mynetwork.parameters(),
    lr=0.005,
    betas=(0.9, 0.999),
    eps=1e-07,
    weight_decay=0.01,
    amsgrad=False,
)
num_epochs = 150
scheduler = PolynomialLRDecay(
    optimizer,
    max_decay_steps=num_epochs,
    end_learning_rate=0.00001,
    power=0.9,
)

# Smoke test: evaluate the loss on a depth batch compared with itself
# (the returned value is discarded), then report the device in use.
criterion(
    train_depth_tns[5:10].to(device),
    train_depth_tns[5:10].to(device),
)
print(device)
#%%
# Checkpoint file that receives the model whenever the validation loss reaches
# a new minimum.
PATH = '/home/mostafa3_local/Documents/saved-models/2021_12_31_trainOnDepth_SIGM_SSIM.pth'

#n_iterations =5
#validation_iteration =5
iter_loss_saved = []   # loss of every single training step
epoch_loss_saved = []  # mean training loss per epoch
valid_loss_saved = []  # mean validation loss per epoch

# The validation loop runs int(validation_iteration) batches, so the mean must
# be taken over that same count (at least 1 to avoid a division by zero).
n_valid_batches = max(int(validation_iteration), 1)

since = time.time()
for epoch in range(num_epochs):
    # ---- training pass ----
    train_loss = 0.0
    mynetwork.train()
    for i in range(n_iterations):

        images = train_rgb_tns[batch_size*i:batch_size*(i+1),:,:,:].to(device)
        labels = train_depth_tns[batch_size*i:batch_size*(i+1),:,:,:].to(device)

        # Clear the gradients, then forward pass
        optimizer.zero_grad()
        outputs = mynetwork(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        iter_loss_saved.append(loss.item())

        if (i+1) % 2 == 0:
            print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_iterations}], Loss: {loss.item():.4f}')

    scheduler.step(epoch)

    # ---- validation pass ----
    mynetwork.eval() # Optional when not using Model Specific layer
    valid_loss = 0.0
    # No gradients are needed for validation; disabling autograd avoids
    # building (and holding in memory) a graph for every validation batch.
    with torch.no_grad():
        for i in range(int(validation_iteration)):
            images = valid_rgb_tns[batch_size*i:batch_size*(i+1),:,:,:].to(device)
            labels = valid_depth_tns[batch_size*i:batch_size*(i+1),:,:,:].to(device)

            pred_valid_depth = mynetwork(images)
            loss = criterion(pred_valid_depth, labels)
            valid_loss += loss.item()

    print(f'Epoch {epoch+1} \t\t Training Loss: {train_loss / n_iterations} \t\t Validation Loss: {valid_loss / n_valid_batches}')
    epoch_loss_saved.append(train_loss / n_iterations)
    valid_loss_saved.append(valid_loss / n_valid_batches)
    print('minimum loss: ',np.min(valid_loss_saved))

    # Keep the checkpoint of the best (lowest validation loss) epoch so far.
    if valid_loss_saved[-1] <= np.min(valid_loss_saved):
        torch.save(mynetwork, PATH)
        print('model saved')

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

#%%
# Second checkpoint path, distinct from the best-validation checkpoint
# (the "_man" suffix presumably means "manual" -- TODO confirm).
PATH = '/home/mostafa3_local/Documents/saved-models/2021_12_31_trainOnDepth_SIGM_SSIM_man.pth'

# Serialize the whole model object, then read it straight back from that file.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
torch.save(mynetwork, PATH)
mynetwork = torch.load(PATH)

#%% Evaluation


# Loss curves. The legend labels are passed as a list: a set has no defined
# order, so the two labels could end up attached to the wrong curves.
plt.plot(epoch_loss_saved)
plt.plot(valid_loss_saved)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['training loss', 'validation loss'])

test_iteration = test_rgb.shape[0]/batch_size
print(test_iteration)

# Round the batch count up so that a trailing partial batch is evaluated too;
# otherwise its rows in the prediction buffer would stay zero.
num_test = test_rgb_tns.shape[0]
num_test_batches = int(np.ceil(num_test / batch_size))

test_loss = 0
pred_test_label = np.zeros((num_test,3,400,400))
mynetwork.eval()
with torch.no_grad():
    for i in range(num_test_batches):
        batch = slice(batch_size*i, batch_size*(i+1))
        images = test_rgb_tns[batch].to(device)
        labels = test_depth_tns[batch].to(device)

        # One forward pass, reused for the stored prediction and for the loss.
        pred_test_label0 = mynetwork(images)
        pred_test_label[batch] = pred_test_label0.to('cpu').numpy()

        loss = criterion(pred_test_label0, labels)
        # Weight by the batch length so a shorter final batch does not skew
        # the mean (identical to the plain mean when all batches are full).
        test_loss += loss.item() * images.shape[0]
#pred_test_depth = mynetwork(test_rgb_tns[batch_size*i:batch_size*(i+1),:,:,:])
test_loss = test_loss / max(num_test, 1)
print(f'test Loss: {test_loss}')


# Channels-first (N, C, H, W) -> channels-last (N, H, W, C)
print(pred_test_label.shape)
pred_test_label1 = np.moveaxis(pred_test_label, 1, 3)
print(pred_test_label1.shape)
print(np.max(pred_test_label1))


# Both prediction and ground truth are shifted by 1 before the metrics are
# computed (presumably to keep the denominators and logarithms away from
# zero -- TODO confirm).
pred_test_depth3 = pred_test_label1 +1
print(np.min(pred_test_depth3))
print(np.max(pred_test_depth3))

test_depth3 = test_depth +1
print(np.min(test_depth3))
print(np.max(test_depth3))

depth_error = np.abs(pred_test_depth3-test_depth3)
AbsRel = np.mean(depth_error/test_depth3)
RMSE = np.sqrt(np.mean(depth_error**2))
RMSE_log = np.sqrt(np.mean(np.abs(np.log10(pred_test_depth3)-np.log10(test_depth3))**2))
SqRel = np.mean(depth_error**2/test_depth3)

print(f' AbsRel: {AbsRel} \n RMSE: {RMSE} \n RMSE_log: {RMSE_log} \n SqRel: {SqRel} ')



# One flat row of values per image, used by the threshold-accuracy cell below.
xnp = pred_test_depth3.reshape(len(pred_test_depth3), -1)
ynp = test_depth3.reshape(len(test_depth3), -1)

#%%
# Threshold accuracy: for every image, the fraction of values whose ratio
# max(pred/gt, gt/pred) lies below 1.25, 1.25**2 and 1.25**3; the per-image
# fractions are then averaged over all images.
# The ratio is computed with array operations one image at a time, which gives
# the same numbers as an element-by-element Python loop but runs in seconds
# and keeps the temporary arrays small.
thresholds = (1.25, 1.25**2, 1.25**3)
acc = np.zeros(3)
for i in range(len(xnp)):
    ratio = np.maximum(xnp[i] / ynp[i], ynp[i] / xnp[i])
    acc += [np.mean(ratio < limit) for limit in thresholds]

print(f' Accuracy with threshold: {acc/max(len(xnp), 1)}')


#%% 
##################################################################################################
######################################## Visualization ###########################################
##################################################################################################

since = time.time()
# Predict every training frame, batch by batch. The batch count is derived
# from the training tensor itself (rounded up); looping over the *test* batch
# count would leave most rows of the buffer at zero.
num_train = train_rgb_tns.shape[0]
pred_train_label = np.zeros((num_train,3,400,400))
mynetwork.eval()
with torch.no_grad():
    for i in range(int(np.ceil(num_train / batch_size))):
        batch = slice(batch_size*i, batch_size*(i+1))
        images = train_rgb_tns[batch].to(device)
        pred_train_label[batch] = mynetwork(images).to('cpu').numpy()

time_elapsed = time.time() - since
print('{} frames complete in {:.0f}m {:.0f}s'.format(num_train, time_elapsed // 60, time_elapsed % 60))

# Channels-first (N, C, H, W) -> channels-last (N, H, W, C) for plt.imshow
pred_train_label = np.moveaxis(pred_train_label, 1, 3)



def showdepth_train(n1,n2,n3):
    """Plot a 3x3 grid comparing three training samples.

    Each column belongs to one of the indices ``n1``, ``n2``, ``n3``. From top
    to bottom the rows show ``train_rgb``, ``train_depth`` and
    ``pred_train_label`` at that index. Axes are hidden.
    """
    fig = plt.figure(figsize=(15, 15))
    for column, sample in enumerate((n1, n2, n3), start=1):
        for row, frames in enumerate((train_rgb, train_depth, pred_train_label)):
            # Subplot slots are numbered row by row, hence 3 * row + column.
            fig.add_subplot(3, 3, 3 * row + column)
            plt.imshow(frames[sample])
            plt.axis('off')


def showdepths(n1,n2,n3):
    """Plot a 3x3 grid comparing three test samples.

    Each column belongs to one of the indices ``n1``, ``n2``, ``n3``. From top
    to bottom the rows show ``test_rgb``, ``test_depth`` and
    ``pred_test_label1`` at that index. Axes are hidden.
    """
    fig = plt.figure(figsize=(15, 15))
    for column, sample in enumerate((n1, n2, n3), start=1):
        for row, frames in enumerate((test_rgb, test_depth, pred_test_label1)):
            # Subplot slots are numbered row by row, hence 3 * row + column.
            fig.add_subplot(3, 3, 3 * row + column)
            plt.imshow(frames[sample])
            plt.axis('off')

# Quick look at individual frames.
plt.imshow(test_rgb[5])
plt.imshow(pred_test_label1[13])

#%%
showdepth_train(22,23,24)
showdepths(20,21,22)
#%%
plt.imshow(pred_test_label1[28])

# Show the same prediction with its last channel zeroed. Work on a copy:
# indexing returns a view, so assigning into it directly would also overwrite
# the stored predictions.
a = pred_test_label1[28].copy()
a[:,:,2] = 0
plt.imshow(a)


