# Image Retargeting via Deformation Fields
In this notebook, you train from scratch the models required to retarget an image. The directory `results` contains the output of this notebook for the image `scenes/balloons.jpg` resized to 50% and 150% of its original width.

__Note__: This code version is intended to be compact.

To train on your own image, place it (as `.jpg` or `.png`) in the `scenes` directory and set the `SCENE` parameter in the user-input cell below to its file name without the extension. You can also choose any other resizing factor via `STRETCH_FACTOR` (0.5 resizes the image to 50% of its width, while 1.5 resizes it to 150% of its width).

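Then, simply rerun the entire notebook. Note that trained networks are cached in the `net` directory (file names keyed by scene name and stretch factor) and reloaded on later runs; delete the corresponding files to force retraining, e.g. after replacing an image under the same name.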

In [None]:
#############################################
###############  USER INPUTS  ###############
#############################################

# Name of the input image inside `scenes/`, without its file extension.
SCENE = "balloons"

# Target width relative to the original width (x-direction only):
# 0.5 -> half as wide, 2.0 -> twice as wide.
STRETCH_FACTOR = 0.5

# Pre-scaling applied to the input before anything is trained:
# values > 1.0 shrink the image, values < 1.0 enlarge it.
DOWNSCALE_FACTOR = 2

# Show the intermediate results of every stage while the notebook runs.
VISUALIZE = True

In [None]:
#############################################
#################  IMPORTS  #################
#############################################

from tqdm import tqdm

import numpy as np
import torch
import torch.nn.parallel
import torch.utils.data

from torchvision.utils import save_image

import imageio
from PIL import Image

import models
import misc_functions as misc

import os

# Make sure the cache directory for trained networks (`net/`) and the output
# directory for rendered images (`results/`) exist. `exist_ok=True` avoids the
# check-then-create race of `os.path.exists` followed by `os.makedirs`.
os.makedirs("net", exist_ok=True)
os.makedirs("results", exist_ok=True)

In [None]:
#############################################
################  CONSTANTS  ################
#############################################

# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Passed as `dim_l_embed` (positional-embedding size) to every network built below.
DIM_L_EMBED = 10

# Adam learning rate used when fitting the image and initial-deformation networks.
LR = 1e-3

# Optimiser steps performed inside each outer (tqdm) iteration.
UPDATE_STEPS_PER_EPOCH = 100

# Global multiplier for all outer iteration counts.
ITERATION_MULTIPLIER = 1

In [None]:
######################################################
##################  PRECOMPUTATIONS  #################
######################################################

# 1. Learn a continuous representation of the image
def learn_continuous_image_representation(img, grid, SCENE, ITS_INITIAL):
    """Fit (or load from cache) an Energy2D network mapping pixel positions to RGB.

    Args:
        img: image tensor, flattened with ``view(3, -1)`` so it must hold 3 channels.
        grid: position grid, flattened with ``view(2, -1)`` so that column k of the
            flattened grid is the 2-D position of flattened pixel k of ``img``.
        SCENE: scene name used in the cache file name.
        ITS_INITIAL: number of outer iterations; each runs UPDATE_STEPS_PER_EPOCH
            Adam steps of a full-image mean-squared-error fit.

    Returns:
        The trained (or cached) network, on DEVICE. A freshly trained network is
        saved to ``net/`` so later runs can skip training.
    """
    param_net = models.Energy2D(dim_l_embed=DIM_L_EMBED).to(DEVICE)
    # Cache key: scene name plus the module-level STRETCH_FACTOR.
    net_path = "net/" + SCENE + "_" + str(int(STRETCH_FACTOR * 100.)) + "_continuous_representation.net"

    if os.path.isfile(net_path):
        # map_location makes a GPU-trained checkpoint loadable on a CPU-only machine.
        param_net.load_state_dict(torch.load(net_path, map_location=DEVICE))
        return param_net

    optimiser = torch.optim.Adam(param_net.parameters(), lr=LR)
    tensor_rgb = img.view(3, -1).transpose(0, 1).to(DEVICE)
    tensor_pos = grid.view(2, -1).transpose(0, 1).to(DEVICE)

    for _ in tqdm(range(ITS_INITIAL)):
        for _ in range(UPDATE_STEPS_PER_EPOCH):
            optimiser.zero_grad()
            loss = (param_net(tensor_pos) - tensor_rgb).square().mean()
            loss.backward()
            optimiser.step()

    torch.save(param_net.state_dict(), net_path)
    return param_net

# 2.1. Learn the initial deformation
def learn_initial_deformation_net(tensor_pos, stretched_pos_train, ITS_DEFORM):
    """Fit (or load from cache) the initial Deform2d network.

    The network predicts an offset that is added to column 1 (and beyond) of each
    stretched position. It is trained so that ``stretched + offset`` reproduces the
    matching original position in ``tensor_pos`` (mean squared error).

    Args:
        tensor_pos: (N, 2) original sample positions on DEVICE.
        stretched_pos_train: (N, 2) the same samples with column 1 scaled by
            STRETCH_FACTOR, on DEVICE.
        ITS_DEFORM: number of outer iterations of UPDATE_STEPS_PER_EPOCH Adam steps.

    Returns:
        The trained (or cached) network. A freshly trained network is saved to
        ``net/`` under a name keyed by the module-level SCENE and STRETCH_FACTOR.
    """
    deform_net = models.Deform2d(dim_l_embed=DIM_L_EMBED).to(DEVICE)
    net_path = "net/" + SCENE + "_" + str(int(STRETCH_FACTOR * 100.)) + "_initial_deform.net"

    if os.path.isfile(net_path):
        # map_location makes a GPU-trained checkpoint loadable on a CPU-only machine.
        deform_net.load_state_dict(torch.load(net_path, map_location=DEVICE))
        return deform_net

    optimiser = torch.optim.Adam(deform_net.parameters(), lr=LR)
    for _ in tqdm(range(ITS_DEFORM)):
        for _ in range(UPDATE_STEPS_PER_EPOCH):
            optimiser.zero_grad()

            deformed_pos = stretched_pos_train.clone()
            deformed_pos[:, 1:] += deform_net(deformed_pos)
            loss = (deformed_pos - tensor_pos).square().mean()

            loss.backward()
            optimiser.step()

    torch.save(deform_net.state_dict(), net_path)
    return deform_net

# 2.2. Learn the cumulative energy.
def learn_cumulative_energy(cumsum_grd, stretched_pos_train, ITS_INITIAL):
    """Fit (or load from cache) a Deform2d network regressing the cumulative energy.

    The query points are ``stretched_pos_train`` with column 1 divided by
    STRETCH_FACTOR, which undoes the stretch applied in precomputations(). The
    regression target is ``cumsum_grd`` flattened to one value per point.

    Args:
        cumsum_grd: cumulative energy map; ``view(-1)`` must yield one value per
            row of ``stretched_pos_train``.
        stretched_pos_train: (N, 2) stretched sample positions.
        ITS_INITIAL: number of outer iterations of UPDATE_STEPS_PER_EPOCH Adam steps.

    Returns:
        The trained (or cached) network. A freshly trained network is saved to
        ``net/`` under a name keyed by the module-level SCENE and STRETCH_FACTOR.
    """
    cumgrad_net = models.Deform2d(dim_l_embed=DIM_L_EMBED).to(DEVICE)
    net_path = "net/" + SCENE + "_" + str(int(STRETCH_FACTOR * 100.)) + "_cumgrad.net"

    if os.path.isfile(net_path):
        # map_location makes a GPU-trained checkpoint loadable on a CPU-only machine.
        cumgrad_net.load_state_dict(torch.load(net_path, map_location=DEVICE))
        return cumgrad_net

    # Smaller learning rate than LR, as in the original training setup.
    optimiser_cumgrad_net = torch.optim.Adam(cumgrad_net.parameters(), lr=0.0001)

    pts_for_cumgrad = stretched_pos_train.clone().to(DEVICE)
    pts_for_cumgrad[:, 1] /= STRETCH_FACTOR
    target = cumsum_grd.view(-1)[:, None]

    for _ in tqdm(range(ITS_INITIAL)):
        for _ in range(UPDATE_STEPS_PER_EPOCH):
            optimiser_cumgrad_net.zero_grad()
            cumsum_out = cumgrad_net(pts_for_cumgrad.clone())
            loss = (cumsum_out - target).square().mean()
            loss.backward()
            optimiser_cumgrad_net.step()

    torch.save(cumgrad_net.state_dict(), net_path)
    return cumgrad_net

# Precompute the initial network required to optimise the final deformation.
def _position_grid(height, width, x_scale=1.0):
    """Return a (2, height, width) grid of normalised sample coordinates.

    Channel 0 holds the row coordinate ``i / height``. Channel 1 holds the column
    coordinate ``j / width`` multiplied by ``x_scale``, for i in [0, height) and
    j in [0, width).
    """
    rows = torch.linspace(0, height - 1, height)[:, None] / height
    cols = torch.linspace(0, width - 1, width)[None, :] / width
    grid = torch.ones(2, height, width)
    grid[0] = grid[0] * rows
    grid[1] = grid[1] * cols * x_scale
    return grid


def precomputations(img, ITS_INITIAL, ITS_DEFORM):
    """Train (or load) the three networks needed before optimising the deformation.

    1. ``param_net``: continuous image representation (position -> RGB).
    2.1 ``deform_net``: initial deformation mapping stretched positions back onto
        the original image.
    2.2 ``cumgrad_net``: continuous cumulative-energy field.

    After each stage an intermediate rendering is saved under ``results/`` and
    shown when VISUALIZE is set.

    NOTE(review): this function (and its learn_* helpers) reads the module-level
    STRETCH_FACTOR and SCENE, not any factor passed to continuous_seam_carve().

    Args:
        img: image tensor indexed (channel, row, column).
        ITS_INITIAL: iterations for the image and cumulative-energy networks.
        ITS_DEFORM: iterations for the initial deformation network.

    Returns:
        Tuple ``(param_net, deform_net, cumgrad_net)``.
    """
    DIM_1 = img.size()[1]
    DIM_2 = img.size()[2]
    DIM_2_TEST = int(DIM_2 * STRETCH_FACTOR)
    result_prefix = "results/" + SCENE + "_" + str(int(STRETCH_FACTOR * 100.))

    grid = _position_grid(DIM_1, DIM_2)
    tensor_pos = grid.view(2, -1).transpose(0, 1).to(DEVICE)

    ##############################################
    ### STEP 1. - INITIALISE & LEARN IMAGE NET ###
    ##############################################
    print("*** STEP 1: LEARN CONTINUOUS IMAGE REPRESENTATION ***")

    param_net = learn_continuous_image_representation(img, grid, SCENE, ITS_INITIAL)

    print("*** DONE WITH LEARNING THE IMAGE ***")

    # Render intermediate visualization
    with torch.no_grad():
        test_img = param_net(tensor_pos).transpose(0, 1).view(3, DIM_1, DIM_2).cpu()

    if VISUALIZE:
        misc.show(test_img)

    dir_name = result_prefix + "_continuous_representation.png"
    save_image(test_img, dir_name)
    print("- Rendered continuous image representation was saved under ", dir_name)

    ##############################################
    ### STEP 2.1. - INITIALISE DEFORMATION NET ###
    ##############################################

    # Training samples: one per original pixel, with the column coordinate
    # squeezed into [0, STRETCH_FACTOR) (e.g. [0.0, 0.5) for a factor of 0.5).
    stretched_pos_train = _position_grid(DIM_1, DIM_2, STRETCH_FACTOR).view(2, -1).transpose(0, 1).to(DEVICE)

    # Evaluation samples: one per pixel of the retargeted image, same coordinate range.
    stretched_pos_test = _position_grid(DIM_1, DIM_2_TEST, STRETCH_FACTOR).view(2, -1).transpose(0, 1).to(DEVICE)

    print("*** STEP 2.1: LEARN THE INITIAL DEFORMATION ***")

    deform_net = learn_initial_deformation_net(tensor_pos, stretched_pos_train, ITS_DEFORM)

    print("*** DONE WITH LEARNING THE INITAL DEFORMATION ***")

    # Render intermediate visualization: sample the image at stretched + offset.
    with torch.no_grad():
        deformed_positions = stretched_pos_test.clone()
        deformed_positions[:, 1:] += deform_net(stretched_pos_test)
        test_img = param_net(deformed_positions).transpose(0, 1).view(3, DIM_1, DIM_2_TEST).cpu()

    if VISUALIZE:
        misc.show(test_img)

    dir_name = result_prefix + "_initial_deform.png"
    save_image(test_img, dir_name)
    print("- Initially deformed image was saved under ", dir_name)

    ###################################################
    ### STEP 2.2. - LEARN INITIAL CUMULATIVE ENERGY ###
    ###################################################

    # Ground-truth cumulative energy: cumulative sum of the image gradient along
    # dim 1, normalised so that its maximum is 1.
    # NOTE(review): assumes misc.gradient returns one value per pixel, with dim 1
    # being the column axis — TODO confirm.
    grd = misc.gradient(img)
    cumsum_grd = torch.cumsum(grd, 1)
    cumsum_grd /= cumsum_grd.size()[1]
    cumsum_grd = cumsum_grd.to(DEVICE)
    max_energy = cumsum_grd.max()
    if max_energy != 0:
        # Guard: a zero-energy image would otherwise produce 0/0 = NaN targets.
        cumsum_grd /= max_energy

    print("*** STEP 2.2: LEARN THE INITIAL CUMULATIVE ENERGY ***")

    cumgrad_net = learn_cumulative_energy(cumsum_grd, stretched_pos_train, ITS_INITIAL)

    print("*** DONE WITH LEARNING THE INITAL CUMULATIVE ENERGY ***")

    # Render intermediate visualization (undo the stretch, normalise to [0, 1]).
    with torch.no_grad():
        pts_for_cumgrad = stretched_pos_train.clone().to(DEVICE)
        pts_for_cumgrad[:, 1] /= STRETCH_FACTOR
        cumgrad_out = cumgrad_net(pts_for_cumgrad.clone())
        test_img = cumgrad_out.view(1, DIM_1, DIM_2).cpu().repeat(3, 1, 1)
        test_img -= test_img.min()
        test_img /= test_img.max()

    if VISUALIZE:
        misc.show(test_img)

    dir_name = result_prefix + "_cumgrad.png"
    save_image(test_img, dir_name)
    print("- Cumulative energy visualization was saved under ", dir_name)

    return param_net, deform_net, cumgrad_net

In [None]:
#############################################
##################  METHOD  #################
#############################################

def monotonicity_loss(deform_net, stretched_pos_test, STRETCH_FACTOR, OAX):
    """Penalise offsets that do not change monotonically along the retargeting axis.

    Each sample point is paired with a copy shifted by a random amount in
    [0.0025, 0.0525) along every coordinate except ``OAX``. The offsets predicted
    at both points are then compared:

    * shrinking (``STRETCH_FACTOR < 1``): the offset must not decrease towards the
      right, so ``offset_left > offset_right`` is penalised;
    * expanding (``STRETCH_FACTOR > 1``): the offset must not increase towards the
      right, and positive offsets are penalised as well;
    * same size: no constraint, a zero loss is returned.
    """
    relu = torch.nn.functional.relu

    left_pts = stretched_pos_test.clone()
    shift = torch.rand_like(left_pts) * 0.05 + 0.0025
    shift[:, OAX] = 0.0
    right_pts = stretched_pos_test.clone() + shift

    off_left = deform_net(left_pts)
    off_right = deform_net(right_pts)

    if STRETCH_FACTOR < 1.0:
        return relu(off_left - off_right).mean()
    if STRETCH_FACTOR > 1.0:
        # Offsets start at 0 and may only become more negative.
        return relu(off_right - off_left).mean() + relu(off_left).mean()
    return torch.zeros(1, device=DEVICE).mean()
  
def boundary_loss(deform_net, STRETCH_FACTOR, stretched_pos_test, AX):
    """Pin the deformation at the two image borders along axis ``AX``.

    Every sample is moved onto the left border (coordinate 0 along ``AX``), where
    the predicted offset must be 0. It is also moved onto the right border
    (coordinate ``STRETCH_FACTOR``), where the offset must be ``1 - STRETCH_FACTOR``,
    so that ``position + offset`` lands on 1.0. The other coordinates are kept,
    so the constraint is applied on every row.

    Returns:
        Sum of the mean squared errors at the left and right borders.
    """
    # Left border: no displacement allowed.
    input_left = stretched_pos_test.clone()
    input_left[:, AX] = 0.0
    target_left = torch.zeros_like(input_left[:, AX:(AX + 1)])

    # Right border: map STRETCH_FACTOR onto 1.0.
    input_right = stretched_pos_test.clone()
    input_right[:, AX] = STRETCH_FACTOR
    target_right = 1.0 - STRETCH_FACTOR

    loss_left = (deform_net(input_left) - target_left).square().mean()
    loss_right = (deform_net(input_right) - target_right).square().mean()
    return loss_left + loss_right

def gradient_flow(deform_net, param_net, cumgrad_net, stretched_pos_test, STRETCH_FACTOR, AX, OAX, DISTANCE_FOR_GRADIENT, DIM_1, DIM_2_TEST, loss_boundaries):
    """Energy-weighted penalty on how fast the deformation changes along ``AX``.

    For each sample ``x`` and its neighbour ``x + DISTANCE_FOR_GRADIENT`` along
    ``AX``, the change in predicted offset is weighted by a detached energy:

    * shrinking / same size: difference of the cumulative-energy network between
      the two deformed positions (clamped to [0, 1]);
    * expanding: mean colour distance of ``param_net`` between the deformed
      position and its four axis-aligned neighbours.

    An extra term couples the last sampled column to the right border at
    ``STRETCH_FACTOR`` (same rows), scaled by ``1 / DIM_1``.

    For expansion, deformed positions that leave [0, 1] along ``AX`` are added to
    ``loss_boundaries``.

    Returns:
        ``(energy_for_shearing, loss_gradients, loss_boundaries)``. Here
        ``energy_for_shearing`` is the detached (DIM_1, DIM_2_TEST) colour
        distance along ``AX`` at the deformed positions.
    """
    def _colour_distance(pts_a, pts_b):
        # Detached, epsilon-stabilised Euclidean distance between predicted colours.
        return ((param_net(pts_a.clone()) - param_net(pts_b.clone())).square().sum(-1).abs() + 0.000001).sqrt().detach()

    # Samples and their neighbours one step along AX.
    pt_x = stretched_pos_test.clone()
    pt_x_eps = pt_x.clone()
    pt_x_eps[:, AX:(AX + 1)] += DISTANCE_FOR_GRADIENT

    # Deformed positions: position + predicted offset along AX.
    offset_pt_x = deform_net(pt_x.clone())
    D_pt_x = pt_x.clone()
    D_pt_x[:, AX:(AX + 1)] += offset_pt_x.clone()

    D_pt_x_eps = pt_x_eps.clone()
    offset_pt_x_eps = deform_net(D_pt_x_eps.clone())
    D_pt_x_eps[:, AX:(AX + 1)] += offset_pt_x_eps

    if STRETCH_FACTOR > 1.0:
        # Expansion: penalise deformed positions outside [0, 1] (before clamping).
        deformed_ax = D_pt_x[:, AX]
        n_samples = deformed_ax.size()[0]
        below = deformed_ax < 0.0
        if below.any():
            loss_boundaries = loss_boundaries - deformed_ax[below].sum() / n_samples
        above = deformed_ax > 1.0
        if above.any():
            loss_boundaries = loss_boundaries + deformed_ax[above].sum() / n_samples

    D_pt_x = D_pt_x.clamp(0.0, 1.0)

    if STRETCH_FACTOR > 1.0:
        # Expansion: use the LOCAL colour gradient (content cannot be jumped over),
        # averaged over the four axis-aligned neighbours.
        energy = None
        for axis, step in ((AX, DISTANCE_FOR_GRADIENT), (AX, -DISTANCE_FOR_GRADIENT),
                           (OAX, DISTANCE_FOR_GRADIENT), (OAX, -DISTANCE_FOR_GRADIENT)):
            neighbour = D_pt_x.clone()
            neighbour[:, axis:(axis + 1)] += step
            term = _colour_distance(neighbour, D_pt_x)
            energy = term if energy is None else energy + term
        energy = energy / 4.0
    else:
        # Shrinking / same size: energy integrated between the two deformed positions.
        D_pt_x_eps = D_pt_x_eps.clamp(0.0, 1.0)
        energy = torch.abs(cumgrad_net(D_pt_x_eps.clone()) - cumgrad_net(D_pt_x.clone()))[:, 0].detach()

    offset_1 = deform_net(pt_x.clone())
    offset_2 = deform_net(pt_x_eps.clone())
    deformation_magnitude = (offset_1 - offset_2).abs().sum(dim=1)
    loss_gradients = (deformation_magnitude.view(DIM_1, DIM_2_TEST) * energy.view(DIM_1, DIM_2_TEST)).mean()

    # Include the gap between the last sampled column and the right border.
    # The border points keep the fringe points' row coordinates; only AX moves.
    fringe = stretched_pos_test.view(DIM_1, DIM_2_TEST, 2)[:, -1, :].clone()
    boundary = fringe.clone()
    boundary[:, AX] = STRETCH_FACTOR

    offset_1 = deform_net(fringe.clone())
    offset_2 = deform_net(boundary.clone())

    D_fringe = fringe.clone()
    D_fringe[:, AX:(AX + 1)] += offset_1
    D_boundary = boundary.clone()
    D_boundary[:, AX:(AX + 1)] += offset_2

    deformation_magnitude_fringe = (offset_1 - offset_2).abs().sum(dim=1)
    energy_fringe = torch.abs(cumgrad_net(D_boundary.clone()) - cumgrad_net(D_fringe.clone()))[:, 0].detach() * 10.0

    impact = 1.0 / DIM_1
    loss_gradients = loss_gradients + (deformation_magnitude_fringe * energy_fringe).mean() * impact

    # For shearing, only the colour gradient along AX is used.
    D_pt_eps_x = D_pt_x.clone()
    D_pt_eps_x[:, AX:(AX + 1)] += DISTANCE_FOR_GRADIENT
    energy_for_shearing = _colour_distance(D_pt_eps_x, D_pt_x).view(DIM_1, DIM_2_TEST)

    return energy_for_shearing, loss_gradients, loss_boundaries

def shearing_loss(deform_net, stretched_pos_test, OAX, DISTANCE_FOR_GRADIENT, DIM_1, DIM_2_TEST, energy_for_shearing):
    """Penalise offsets that differ between vertically neighbouring samples.

    Each sample is compared with a copy shifted by ``DISTANCE_FOR_GRADIENT`` along
    ``OAX``. The absolute offset difference is reshaped to (DIM_1, DIM_2_TEST),
    weighted by ``energy_for_shearing`` (same shape), and averaged.
    """
    pts = stretched_pos_test.clone()
    pts_shifted = pts.clone()
    pts_shifted[:, OAX:(OAX + 1)] += DISTANCE_FOR_GRADIENT

    offset = deform_net(pts.clone())
    offset_shifted = deform_net(pts_shifted.clone())

    diff = (offset - offset_shifted).abs().view(DIM_1, DIM_2_TEST)
    return (diff * energy_for_shearing).mean()

def _expansion_cap_loss(deform_net, stretched_pos_test, AX, DIM_2, DIM_2_TEST):
    """Limit how strongly expansion may stretch a region locally.

    A random step ``d`` of up to ``1.1 * max(1/DIM_2_TEST, 1/DIM_2)`` is taken
    along ``AX``. The loss penalises samples whose offset drops by more than
    ``0.75 * d`` over that step, i.e. whose source position advances by less than
    ``d / 4``.
    """
    pt_x = stretched_pos_test.clone()
    pt_x_eps = pt_x.clone()
    max_dist = max(1.0 / DIM_2_TEST, 1.0 / DIM_2)
    distance_values = max_dist * 1.1 * torch.rand(pt_x.size()[0], device=DEVICE) + 0.0000001
    pt_x_eps[:, AX] += distance_values

    offset_pt_x = deform_net(pt_x.clone()).view(-1)
    offset_pt_x_eps = deform_net(pt_x_eps.clone()).view(-1)

    return torch.nn.functional.relu((offset_pt_x - offset_pt_x_eps) - distance_values * .75).mean()


def continuous_seam_carve(img, STRETCH_FACTOR, ITS_INITIAL, ITS_DEFORM, ITS_OPTIM):
    """Retarget ``img`` along its width by ``STRETCH_FACTOR``.

    Steps 1-2 (precomputations) fit or load the image, initial-deformation and
    cumulative-energy networks. Step 3 fine-tunes the deformation network with a
    weighted sum of monotonicity, boundary, gradient-flow and shearing losses,
    plus a cap loss for expansion.

    NOTE(review): precomputations() reads the module-level STRETCH_FACTOR, not
    this argument.

    Args:
        img: image tensor indexed (channel, row, column).
        STRETCH_FACTOR: target width / source width (1 = same, 0.5 = half, 2 = double).
        ITS_INITIAL: iterations forwarded to precomputations().
        ITS_DEFORM: iterations forwarded to precomputations().
        ITS_OPTIM: optimisation epochs for step 3 (halved for expansion).

    Returns:
        The image rendered by ``misc.output_deformed_image`` for the epoch with the
        lowest accumulated loss. The matching network is saved to ``net/``.
    """
    DISTANCE_FOR_GRADIENT = 0.01
    DIM_1 = img.size()[1]
    DIM_2 = img.size()[2]
    DIM_2_TEST = int(DIM_2 * STRETCH_FACTOR)
    AX = 1   # retargeted axis (column coordinate)
    OAX = 0  # orthogonal axis (row coordinate)

    # Loss weights (constant across epochs).
    LAMBDAS = {
        'mono': 10000.0,
        'bound': 10000.0,
        'grad': 1000.0,
        'shear': 250.0,
        'cap': 100000.0,
    }

    ######################################################
    ### STEP 1-2 - PRECOMPUTE ALL THE INITIAL NETWORKS ###
    ######################################################

    param_net, deform_net, cumgrad_net = precomputations(img, ITS_INITIAL, ITS_DEFORM)

    cumgrad_net.train(False)
    param_net.train(False)

    #############################################
    ### STEP 3 - OPTIMISE DEFORMATION NETWORK ###
    #############################################

    # Samples on which we regularise: one per output pixel, columns in [0, STRETCH_FACTOR).
    strip_a = torch.linspace(0, DIM_1 - 1, DIM_1)[:, None] / DIM_1
    strip_b = torch.linspace(0, DIM_2_TEST - 1, int(DIM_2_TEST))[None, :] / DIM_2_TEST
    grid_test = torch.ones(2, img.size()[1], int(DIM_2_TEST))
    grid_test[0] = grid_test[0] * strip_a                  # e.g. [0.0, 1.0]
    grid_test[1] = grid_test[1] * strip_b * STRETCH_FACTOR # e.g. [0.0, 0.5]
    stretched_pos_test = grid_test.view(2, -1).transpose(0, 1).to(DEVICE)

    print("*** STEP 3: OPTIMISE DEFORMATION ***")

    optimiser = torch.optim.AdamW(deform_net.parameters(), lr=0.001, weight_decay=0.0, amsgrad=True)

    best_image = None
    best_loss = None

    if STRETCH_FACTOR > 1.0:
        ITS_OPTIM = int(ITS_OPTIM * 0.5)

    for epoch in tqdm(range(ITS_OPTIM)):
        avg_loss = 0.0

        # Over the last quarter of epochs, scale the loss down linearly. With
        # AMSGrad (running max of squared gradients) this shrinks the steps,
        # acting like a learning-rate decay.
        loss_amp = 1.0
        if epoch > int(ITS_OPTIM * 0.75):
            loss_amp = (ITS_OPTIM - epoch) / ITS_OPTIM * 4.0

        for _ in range(UPDATE_STEPS_PER_EPOCH):
            optimiser.zero_grad()

            loss_mono = monotonicity_loss(deform_net, stretched_pos_test, STRETCH_FACTOR, OAX)
            loss_boundaries = boundary_loss(deform_net, STRETCH_FACTOR, stretched_pos_test, AX)
            energy_for_shearing, loss_gradients, loss_boundaries = gradient_flow(
                deform_net, param_net, cumgrad_net, stretched_pos_test, STRETCH_FACTOR,
                AX, OAX, DISTANCE_FOR_GRADIENT, DIM_1, DIM_2_TEST, loss_boundaries)
            loss_shearing = shearing_loss(deform_net, stretched_pos_test, OAX,
                                          DISTANCE_FOR_GRADIENT, DIM_1, DIM_2_TEST,
                                          energy_for_shearing)

            loss = (loss_mono * LAMBDAS['mono']
                    + loss_boundaries * LAMBDAS['bound']
                    + loss_gradients * LAMBDAS['grad']
                    + loss_shearing * LAMBDAS['shear'])
            if STRETCH_FACTOR > 1.0:
                loss = loss + _expansion_cap_loss(deform_net, stretched_pos_test, AX, DIM_2, DIM_2_TEST) * LAMBDAS['cap']

            loss = loss * loss_amp
            avg_loss += loss.item()

            loss.backward()

            # Epoch 0 evaluates a single step without updating the network.
            if epoch == 0:
                optimiser.zero_grad()
                break
            optimiser.step()

        # Keep track of the best image and net. Epoch 0 sums only one step, so its
        # loss is not comparable; epoch 1 therefore always replaces it.
        if best_loss is None or avg_loss < best_loss or epoch == 1:
            best_loss = avg_loss
            best_image = misc.output_deformed_image(deform_net, param_net, stretched_pos_test, DIM_1, DIM_2_TEST)
            torch.save(deform_net.state_dict(), "net/" + SCENE + "_" + str(int(STRETCH_FACTOR * 100.)) + "_optimized_deform.net")

    print("*** DONE WITH OPTIMISING DEFORMATION ***")
    return best_image

In [None]:
#############################################
###################  MAIN  ##################
#############################################

def resize_image(img, factor_x, factor_y = 1.0):
    """Pre-scale ``img`` by DOWNSCALE_FACTOR, then retarget it with continuous_seam_carve.

    Args:
        img: image array accepted by ``PIL.Image.fromarray``. ``shape[0]`` is used
            as the height and ``shape[1]`` as the width. It is converted via RGBA,
            and only the first three channels are kept.
        factor_x: width stretch factor.
        factor_y: height stretch factor. When it differs from 1.0, the image is
            transposed, retargeted, and transposed back.

    NOTE(review): the y-pass starts from the pre-scaled input, not from the x-pass
    result, so with both factors != 1.0 the x-pass output is discarded. Also,
    precomputations() and the cached network file names use the module-level
    STRETCH_FACTOR/SCENE rather than ``factor_y``. Confirm before relying on
    vertical retargeting.
    """
    # PIL expects (width, height).
    new_size = (int(img.shape[1] // DOWNSCALE_FACTOR), int(img.shape[0] // DOWNSCALE_FACTOR))
    pil_img = Image.fromarray(img).convert("RGBA").resize(new_size, Image.Resampling.LANCZOS)

    # (H, W, 4) uint8 -> (3, H, W) float in [0, 1], alpha dropped.
    tensor_img = torch.tensor(np.array(pil_img)).permute(2, 0, 1).float()[0:3] / 255.0

    iterations = dict(ITS_INITIAL=int(2 * 50 * ITERATION_MULTIPLIER),
                      ITS_DEFORM=int(20 * ITERATION_MULTIPLIER),
                      ITS_OPTIM=int(100 * ITERATION_MULTIPLIER))

    # The x-pass also runs when both factors are 1.0, so that something is returned.
    if factor_x != 1.0 or factor_y == 1.0:
        result = continuous_seam_carve(img=tensor_img, STRETCH_FACTOR=factor_x, **iterations)

    if factor_y != 1.0:
        swapped = continuous_seam_carve(img=tensor_img.transpose(1, 2).contiguous(),
                                        STRETCH_FACTOR=factor_y, **iterations)
        result = swapped.transpose(1, 2).contiguous()

    return result

# -----------------------------------------

# Load the input image: `scenes/<SCENE>.jpg`, falling back to `scenes/<SCENE>.png`.
# An explicit existence check replaces the former bare `except:`, which also
# swallowed decode errors of an existing file (and even KeyboardInterrupt).
_candidate_paths = ["scenes/" + SCENE + ext for ext in (".jpg", ".png")]
_image_path = next((path for path in _candidate_paths if os.path.isfile(path)), None)
if _image_path is None:
    raise FileNotFoundError("No input image found; tried " + ", ".join(_candidate_paths))
img = imageio.v2.imread(_image_path)

result = resize_image(img, STRETCH_FACTOR, 1.0)

if VISUALIZE:
    misc.show(result)

dir_name = "results/"+SCENE+"_"+str(int(STRETCH_FACTOR*100.))+"_final_deform.png"
save_image(result, dir_name)
print("- The final deformed image was saved under ", dir_name)