# Graveyard code Seger

In [None]:
############################## Training Loop (all epochs) #################################
def standard_train(disp_net, pose_net, loss_function, epipolar_loss, mvs_loss, optimizer, n_epochs, train_loader, val_loader, test_loader, log_interval=3):
    """ Train the depth, pose and flow networks and validate after every epoch
    Besides its parameters this function relies on names that must exist at
    module / notebook level: flow_net, device, focal2intrinsics,
    save_images_norm, output_directory_train, plot_train_curve,
    validate_with_gt_during_training, plot_metrics and plot_curve.
    Args:
        disp_net: disparity network, called on a single image
        pose_net: network called as pose_net(tgt_img, src_images); returns (pose, focal_lengths)
        loss_function: photometric loss, called as
                       loss_function(tgt_img, src_images, depths_tgt, flow, pose, intrinsics)
                       and returning three values (the loss first)
        epipolar_loss: called as epipolar_loss(intrinsics, pose, src_images, flow);
                       only printed, not part of the optimised loss
        mvs_loss: called as mvs_loss(intrinsics, pose, src_images, tgt_img, depths_src, depths_tgt)
        optimizer: optimizer that is zeroed and stepped once per training batch
        n_epochs: number of passes over train_loader
        train_loader: yields dicts with the keys 'target_image', 'source_image_prev'
                      and 'source_image_next'
        val_loader: same batch format as train_loader, evaluated once per epoch
        test_loader: passed to validate_with_gt_during_training at every logging step
        log_interval: number of batches between intermediate logging / validation
                      steps (default 3, the cadence previously hard-coded)
    Return:
        None; results are reported through prints and the plotting helpers
    """
    abs_diff, abs_rel, sq_rel, rmse, rmse_log = [], [], [], [], []
    # interval_loss_list holds the intermediate (per logging interval) averages for plotting
    loss_train, loss_val, interval_loss_list = [], [], []

    # loop over the total number of epochs
    for epoch in range(n_epochs):
        running_loss = 0
        running_loss_val = 0
        interval_loss = 0
        interval_batches = 0

        ############################ Training #################################

        # set the network architectures to training mode
        disp_net.train()
        pose_net.train()
        flow_net.train()

        # loop through the batches
        for i, data in enumerate(train_loader):

            # extract source and target image for 1 forward pass. Send to GPU
            tgt_img = data['target_image'].to(device)
            src_img_prev = data['source_image_prev'].to(device)
            src_img_next = data['source_image_next'].to(device)
            src_images = [src_img_prev, src_img_next]

            # concatenate target with each source image along the channel axis
            concat_tgt_src = [torch.cat((tgt_img, src_img), 1) for src_img in src_images]  # [[B, 6, H, W], .. (2x)]

            ############################# DEPTH #############################
            # predict disparity of target image and translate to depth at four scales
            disparities_tgt = disp_net(tgt_img)
            depths_tgt = [1/disp for disp in disparities_tgt]  # [[B, 1, H, W], ... (4x) ]

            # predict disparities of source images and translate to depth at four scales for each src image
            disparities_src = [*disp_net(src_img_prev), *disp_net(src_img_next)]
            depths_src = [1/disp for disp in disparities_src]  # [[B, 1, H, W], ... (8x) ]

            #############################  POSE  ##############################
            # predict pose and focal lengths
            pose, focal_lengths = pose_net(tgt_img, src_images)  # [B, C=2, 6] and [B, C=2, 2]

            # calculate the camera intrinsics matrix, one per source image
            intrinsics = torch.stack(
                [focal2intrinsics(focal_lengths[:, c, :], tgt_img) for c in range(len(src_images))],
                dim=1)  # [B, C=2, 3, 3]

            ############################## FLOW ################################
            # predict flow maps between each source image and the target image. Put in one list
            flow = [*flow_net(concat_tgt_src[0]), *flow_net(concat_tgt_src[1])]  # [[B, 2, H, W], ...x8]

            ############################## LOSS ###############################
            loss_pc, _, _ = loss_function(tgt_img, src_images, depths_tgt, flow, pose, intrinsics)
            # the epipolar loss is only reported below; it is not part of the optimised loss
            loss_e = epipolar_loss(intrinsics, pose, src_images, flow)
            loss_mvs = mvs_loss(intrinsics, pose, src_images, tgt_img, depths_src, depths_tgt)
            loss = loss_pc + 0.1 * loss_mvs  # + 0.01 * loss_e

            ########################### BACKPROP ##############################
            # Zero the gradients, backprop and optimization of parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ######################## CALCULATE LOSSES ######################
            # Accumulate a detached copy: summing the live loss tensor would keep
            # every batch's autograd graph alive for the whole epoch.
            batch_loss = loss.detach()
            running_loss += batch_loss
            interval_loss += batch_loss
            interval_batches += 1

            ################# PRINTING EVERY log_interval BATCHES ###################
            if i % log_interval == 0 and i != 0:
                print('loss e =', loss_e)
                print('loss mvs =', loss_mvs)
                print('loss ap =', loss_pc)
                # average over the batches actually accumulated since the last report
                avg_interval_loss = float(interval_loss / interval_batches)
                interval_loss_list.append(avg_interval_loss)
                print('batch: ', i)
                print('{} batch avg loss = '.format(interval_batches), avg_interval_loss)

                # Save the results and plot the training curve
                save_images_norm(output_directory_train, depths_tgt, tgt_img, i)
                plot_train_curve(interval_loss_list)

                interval_loss = 0
                interval_batches = 0

                ############### VALIDATE EVERY log_interval BATCHES #####################
                abs_diff, abs_rel, sq_rel, rmse, rmse_log = validate_with_gt_during_training(
                    test_loader, disp_net, abs_diff, abs_rel, sq_rel, rmse, rmse_log)
                plot_metrics(abs_diff, abs_rel, sq_rel, rmse, rmse_log)

                # NOTE(review): the validation helper presumably switches disp_net to eval
                # mode; re-enter training mode so the remaining batches train as intended.
                disp_net.train()

                # NOTE(review): saving is currently disabled, yet the message below is still printed.
                # save_models(disp_net, pose_net, flow_net, model_path_disp, epoch) ################################## COMMENT IN!
                print('Model_saved')

        ############################# Validation Phase ################################
        disp_net.eval()
        pose_net.eval()
        flow_net.eval()

        with torch.no_grad():
            for data in val_loader:
                # extract source and target image for 1 forward pass. Send to GPU
                tgt_img = data['target_image'].to(device)
                src_img_prev = data['source_image_prev'].to(device)
                src_img_next = data['source_image_next'].to(device)
                src_images = [src_img_prev, src_img_next]

                # concatenate target with each source image along the channel axis
                concat_tgt_src = [torch.cat((tgt_img, src_img), 1) for src_img in src_images]  # [[B, 6, H, W], .. (2x)]

                ############################# DEPTH #############################
                # NOTE(review): unlike the training branch, an extra dimension is inserted
                # here with unsqueeze(1) -- presumably disp_net returns a different shape in
                # eval mode; confirm.
                disparities_tgt = disp_net(tgt_img)
                depths_tgt = [(1/disp).unsqueeze(1) for disp in disparities_tgt]

                disparities_src = [*disp_net(src_img_prev), *disp_net(src_img_next)]
                depths_src = [(1/disp).unsqueeze(1) for disp in disparities_src]

                #############################  POSE  ##############################
                pose, focal_lengths = pose_net(tgt_img, src_images)

                # calculate the camera intrinsics matrix, one per source image
                intrinsics = torch.stack(
                    [focal2intrinsics(focal_lengths[:, c, :], tgt_img) for c in range(len(src_images))],
                    dim=1)  # [B, C=2, 3, 3]

                ############################# FLOW ################################
                flow = [*flow_net(concat_tgt_src[0]), *flow_net(concat_tgt_src[1])]  # [[B, 2, H, W], ...x8]

                ############################## LOSS ###############################
                loss_pc, _, _ = loss_function(tgt_img, src_images, depths_tgt, flow, pose, intrinsics)
                loss_mvs = mvs_loss(intrinsics, pose, src_images, tgt_img, depths_src, depths_tgt)
                loss = loss_pc + 0.1 * loss_mvs  # + 0.01 * loss_e
                running_loss_val += loss

        ###############################################################################

        # Calculate training and validation loss and plot the resulting curves.
        # Divide by the number of batches instead of assuming a fixed batch size.
        running_loss_avg = running_loss / max(len(train_loader), 1)
        running_loss_val_avg = running_loss_val / max(len(val_loader), 1)
        loss_train.append(running_loss_avg)
        loss_val.append(running_loss_val_avg)
        plot_curve(loss_train, loss_val)

        print('epoch:', epoch)
        print('Training Loss: {:.4f}'.format(running_loss_avg))
        print('Validation Loss: {:.4f}'.format(running_loss_val_avg))


 
        
        

In [None]:
def adaptive_photometric_loss(tgt_img, src_imgs, depths, flows, pose, intrinsics):
    """ Calculate the adaptive photometric loss
    The rigid (depth + pose) photometric loss is summed over all depth scales,
    the flow photometric loss is summed over all flow maps, and the smaller of
    the two totals is returned.
    Relies on F (torch.nn.functional), pytorch_ssim, inverse_warp and flow_warp
    being available at module / notebook level.
    Args:
        tgt_img: target image                                     -- [B, 3, H, W]
        src_imgs: list of the source images (previous & next)     -- [[B, 3, H, W], [B, 3, H, W]]
        depths: list of depth maps of target images on 4 scales   -- [[B, 1, H, W], [B, 1, H, W], [B, 1, H, W], [B, 1, H, W]]
        flows: flow maps, grouped per source image                -- [[B, 2, H, W]....8x]
        pose: 6DoF pose parameters from target to source          -- [B, C=2, 6]
        intrinsics: camera intrinsic matrix                       -- [B, C=2, 3, 3]
    Return:
        (loss, warped_results, diff_results): the smaller of the two summed losses,
        and per depth scale the first batch element of every warped image and of
        every masked difference map produced by the rigid branch
    """
    def one_scale_pc(depth):  # [B, 1, H, W]
        assert(pose.size(1) == len(src_imgs))

        # retrieve depth size and the downscale factor relative to the target image
        _, _, h, w = depth.size()
        downscale = tgt_img.size(2)/h

        src_imgs_scaled = [F.interpolate(src_img, (h, w), mode='area') for src_img in src_imgs]  # [[B, 3, H, W], [B, 3, H, W]]

        # NOTE(review): only the third column of the first two rows is scaled
        # (presumably the principal point); the remaining entries stay unchanged --
        # confirm this matches what focal2intrinsics / inverse_warp expect.
        downscale_matrix = torch.tensor([[1, 1, 1/downscale],
                                         [1, 1, 1/downscale],
                                         [1, 1, 1]], device=intrinsics.device).unsqueeze(0).unsqueeze(0)
        intrinsics_scaled = intrinsics * downscale_matrix  # [B, C=2, 3, 3]

        # one SSIM module per scale is enough; it is reused for every source image
        ssim_loss = pytorch_ssim.SSIM(window_size=11)

        reconstruction_loss = 0
        warped_imgs, diff_maps = [], []

        for i, src_img in enumerate(src_imgs_scaled):
            # The (unscaled) target image is warped and compared against the scaled
            # source image; per the original author's note the warp input was
            # switched from src_img to tgt_img.
            projected_image, valid_points = inverse_warp(tgt_img, depth[:, 0], pose[:, i], intrinsics_scaled[:, i])  # [B, 3, H, W], [B, H, W]

            valid_mask = valid_points.unsqueeze(1).float()
            src_img_valid = src_img * valid_mask
            projected_image_valid = projected_image * valid_mask

            ssim = ssim_loss(src_img_valid, projected_image_valid)
            diff = (src_img_valid - projected_image) * valid_mask
            diff_abs = diff.abs().mean()

            reconstruction_loss += 0.85 * ((1-ssim)/2) + (1-0.85) * diff_abs

            # keep the first batch element only, for visualisation
            warped_imgs.append(projected_image[0])
            diff_maps.append(diff[0])

        return reconstruction_loss, warped_imgs, diff_maps

    def one_scale_apc(local_flow, nr_src_img):
        # retrieve the size of the flow map
        _, _, h, w = local_flow.size()

        # scale the matching source image to the size of the flow map
        src_img_scaled = F.interpolate(src_imgs[nr_src_img], (h, w), mode='area')  # [B, 3, H, W]

        # warp the target image using optical flow
        warped_image = flow_warp(local_flow, tgt_img)  # [B, 3, H, W]

        ssim_loss = pytorch_ssim.SSIM(window_size=11)
        ssim = ssim_loss(src_img_scaled, warped_image)
        diff_abs = (src_img_scaled - warped_image).abs().mean()

        return 0.85 * ((1-ssim)/2) + (1-0.85) * diff_abs

    warped_results, diff_results = [], []
    total_loss_flow, total_loss_pc = 0, 0

    # Loop over the depths to obtain rigid photometric loss
    for depth in depths:
        loss_pc, warped, diff = one_scale_pc(depth)
        total_loss_pc += loss_pc
        warped_results.append(warped)
        diff_results.append(diff)

    # Loop over the flows to obtain the photometric loss related to flow.
    # The flow list holds all scales of the first source image followed by all
    # scales of the next one, so the source index follows from the position
    # (8 flows / 2 sources -> indices 0-3 map to source 0, 4-7 to source 1).
    n_src = len(src_imgs)
    flows_per_src = max(len(flows) // max(n_src, 1), 1)
    for idx, flow in enumerate(flows):
        nr_src_img = min(idx // flows_per_src, n_src - 1)
        total_loss_flow += one_scale_apc(flow, nr_src_img)

    # the smaller of the two summed losses is the one that gets optimised
    loss_apc = min(total_loss_flow, total_loss_pc)

    return loss_apc, warped_results, diff_results

######################## NEW Photometric Loss ############################

def standard_photometric_loss(tgt_img, src_imgs, depths, pose, intrinsics):
    """ Calculate the photometric loss related to rigid displacement (not adaptive)
    Relies on F (torch.nn.functional), pytorch_ssim and inverse_warp being
    available at module / notebook level.
    Args:
        tgt_img: target image                                     -- [B, 3, H, W]
        src_imgs: list of the source images (previous & next)     -- [[B, 3, H, W], [B, 3, H, W]]
        depths: list of depth maps of target images on 4 scales   -- [[B, 1, H, W], [B, 1, H, W], [B, 1, H, W], [B, 1, H, W]]
        pose: 6DoF pose parameters from target to source          -- [B, C=2, 6]
        intrinsics: camera intrinsic matrix                       -- [B, C=2, 3, 3]
    Return:
        (total_loss, warped_results, diff_results): the loss summed over all scales
        and source images, and per scale the first batch element of every warped
        image and of every masked difference map
    """
    def one_scale(depth):  # [B, 1, H, W]
        assert(pose.size(1) == len(src_imgs))

        # retrieve depth size and the downscale factor relative to the target image
        _, _, h, w = depth.size()
        downscale = tgt_img.size(2)/h

        src_imgs_scaled = [F.interpolate(src_img, (h, w), mode='area') for src_img in src_imgs]  # [[B, 3, H, W], [B, 3, H, W]]

        # NOTE(review): only the third column of the first two rows is scaled
        # (presumably the principal point); the remaining entries stay unchanged --
        # confirm this matches what focal2intrinsics / inverse_warp expect.
        downscale_matrix = torch.tensor([[1, 1, 1/downscale],
                                         [1, 1, 1/downscale],
                                         [1, 1, 1]], device=intrinsics.device).unsqueeze(0).unsqueeze(0)
        intrinsics_scaled = intrinsics * downscale_matrix  # [B, C=2, 3, 3]

        # one SSIM module per scale is enough; it is reused for every source image
        ssim_loss = pytorch_ssim.SSIM(window_size=11)

        reconstruction_loss = 0
        warped_imgs, diff_maps = [], []

        for i, src_img in enumerate(src_imgs_scaled):
            # The (unscaled) target image is warped and compared against the scaled
            # source image; per the original author's note the warp input was
            # switched from src_img to tgt_img.
            projected_image, valid_points = inverse_warp(tgt_img, depth[:, 0], pose[:, i], intrinsics_scaled[:, i])  # [B, 3, H, W], [B, H, W]

            valid_mask = valid_points.unsqueeze(1).float()
            src_img_valid = src_img * valid_mask
            projected_image_valid = projected_image * valid_mask

            ssim = ssim_loss(src_img_valid, projected_image_valid)
            diff = (src_img_valid - projected_image) * valid_mask
            diff_abs = diff.abs().mean()

            reconstruction_loss += 0.85 * ((1-ssim)/2) + (1-0.85) * diff_abs

            # keep the first batch element only, for visualisation
            warped_imgs.append(projected_image[0])
            diff_maps.append(diff[0])

        return reconstruction_loss, warped_imgs, diff_maps

    warped_results, diff_results = [], []
    total_loss = 0

    # accumulate the loss and the visualisation outputs over all depth scales
    for depth in depths:
        loss, warped, diff = one_scale(depth)
        total_loss += loss
        warped_results.append(warped)
        diff_results.append(diff)

    return total_loss, warped_results, diff_results

######################### OLD PHOTOMETRIC LOSS ################################

# def standard_photometric_loss(tgt_img, src_imgs, depths, pose, intrinsics):
#     """ Calculate the photometric loss related to rigid displacement (not adaptive)
#     Args:
#         tgt_img: target image                                     -- [B, 3, H, W]
#         src_imgs: list of the source images (previous & next)     -- [[B, 3, H, W], [B, 3, H, W]]
#         depths: list of depth maps of target images on 4 scales   -- [[B, 1, H, W], [B, 1, H, W], [B, 1, H, W], [B, 1, H, W]]
#         pose: 6DoF pose parameters from target to source          -- [B, C=2, 6]
#         intrinsics: camera intrinsic matrix                       -- [B, C=2, 3, 3]
#     Return:
#         total photometric loss related to rigid displacement
#     """
#     def one_scale(depth): # [B, 1, H, W]
#         assert(pose.size(1) == len(src_imgs))
       
#         reconstruction_loss = 0

#         # retrieve depth size
#         b, _, h, w = depth.size()
#         downscale = tgt_img.size(2)/h
        
#         tgt_img_scaled = F.interpolate(tgt_img, (h, w), mode='area') # [B, 3, H, W]
#         src_imgs_scaled = [F.interpolate(src_img, (h, w), mode='area') for src_img in src_imgs] # [[B, 3, H, W], [B, 3, H, W]]
        
#         downscale_matrix = torch.tensor([[1, 1, 1/downscale],
#                         [1, 1, 1/downscale],
#                         [1, 1, 1]]).unsqueeze(0).unsqueeze(0).to(device)

#         intrinsics_scaled = intrinsics * downscale_matrix #[B, C=2, 3, 3]
        
#         warped_imgs = []
#         diff_maps = []
        
#         for i, src_img in enumerate(src_imgs_scaled):
#             current_intrinsics = intrinsics_scaled[:,i] # [B, 3, 3]
#             current_pose = pose[:,i]
            
#             projected_image, valid_points = inverse_warp(src_img, depth[:,0], current_pose, current_intrinsics) # [B, 3, H, W], # [B, H, W]

#             # calculate loss
#             ssim_loss = pytorch_ssim.SSIM(window_size = 11)
#             tgt_img_scaled_valid = tgt_img_scaled * valid_points.unsqueeze(1).float()
#             projected_image_valid = projected_image * valid_points.unsqueeze(1).float()

#             ssim = ssim_loss(tgt_img_scaled_valid, projected_image_valid) #* valid_points.unsqueeze(1).float() # value
#             #ssim_abs = ssim.mean()
#             diff = (tgt_img_scaled - projected_image) * valid_points.unsqueeze(1).float()
#             diff_abs = diff.abs().mean()
            
#             #reconstruction_loss = diff_abs
#             reconstruction_loss += 0.85 * ((1-ssim)/2) + (1-0.85) * diff_abs

#             warped_imgs.append(projected_image[0])
#             diff_maps.append(diff[0])        
                        
#         return reconstruction_loss, warped_imgs, diff_maps

#     warped_results, diff_results = [], []

#     total_loss = 0

#     for i, depth in enumerate(depths):
#         loss, warped, diff = one_scale(depth)                      
#         total_loss += loss
#         warped_results.append(warped)
#         diff_results.append(diff)


#     return total_loss, warped_results, diff_results