In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np


from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.models.feature_extraction as feat
from torchvision.io import decode_image

import multiprocessing
from multiprocessing import Process
from multiprocessing import Pipe

import os
import shutil

import copy
from skimage import io,color
#thanks to https://stackoverflow.com/questions/47411872/extract-multiple-windows-patches-from-an-image-array-as-defined-in-another-ar
from skimage.util.shape import view_as_windows

%matplotlib inline
plt.rcParams['figure.figsize'] = [16, 10] # matplotlib setting to control the size of display images

In [None]:
# Report whether a CUDA GPU is visible, then pick the compute device accordingly.
print(torch.cuda.is_available())
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Shared tensor -> PIL.Image converter used for displaying results (mode inferred from the tensor).
unloader = transforms.ToPILImage(mode=None)

# Implementing Patch Matching Algorithm

In [None]:
class NNF:
    """Nearest-neighbour field (NNF) from ``img_1`` (source) to ``img_2`` (dest), built with PatchMatch.

    Images are tensors indexed ``[channels, rows, cols]`` (``shape[0]`` is read as
    channels, ``shape[1]`` as rows, ``shape[2]`` as cols). ``self.matchings`` is an
    int array of shape ``(rows, cols, 2)``: ``matchings[r, c]`` is the (row, col) in
    ``dest`` that source pixel ``(r, c)`` currently maps to. It starts uniformly
    random and is refined by ``iterate`` (propagation + random search).
    """

    def __init__(self,img_1,img_2,patch_size = 7, norm = False):
        """Create a uniformly random initial mapping from ``img_1`` to ``img_2``.

        Args:
            img_1: source image tensor ``[C, H1, W1]``.
            img_2: destination image tensor ``[C, H2, W2]``.
            patch_size: side of the square patches compared; an even value is increased by 1.
            norm: if True, ``iterate`` L2-normalises every pixel across channels before matching.
        """
        if patch_size%2 == 0:
            print("Patch size should be odd. Increasing by 1.")
            patch_size = patch_size + 1
        self.normalised = norm
        self.patch_size = patch_size
        self.source = img_1
        self.dest = img_2
        self.rows = self.source.shape[1]
        self.cols = self.source.shape[2]
        self.chans = self.source.shape[0]
        #random initial coordinates: matchings[x, y] holds the dest location (x', y')
        self.matchings = np.random.rand(self.rows,self.cols,2)
        self.matchings[:,:,0] = self.matchings[:,:,0]*self.dest.shape[1]
        self.matchings[:,:,1] = self.matchings[:,:,1]*self.dest.shape[2]
        self.matchings = self.matchings.astype(int)
        return

    #src:https://stackoverflow.com/questions/13405956/convert-an-image-rgb-lab-with-python
    def __rgbToLAB(self,image_tensor):
        """Convert a ``[C, H, W]`` RGB tensor to CIELAB (not called elsewhere in this class)."""
        labImage = color.rgb2lab(image_tensor.numpy(force = True),channel_axis = 0)
        labTensor = torch.Tensor(labImage)
        return labTensor

    def __labToRGB(self,lab_tensor):
        """Convert a ``[C, H, W]`` CIELAB tensor back to RGB (not called elsewhere in this class)."""
        rgbImage = color.lab2rgb(lab_tensor.numpy(force = True), channel_axis = 0)
        rgbTensor = torch.Tensor(rgbImage)
        return rgbTensor

    @staticmethod
    def _num_search_radii(w,alpha):
        """Number of random-search radii ``w * alpha**i`` (i = 0, 1, ...) that are at least 1 pixel.

        Returns 0 when ``alpha`` is outside (0, 1) (the radii would never shrink) or ``w < 1``.
        """
        if w < 1 or not (0 < alpha < 1):
            return 0
        #the small epsilon keeps exact powers (e.g. w=8, alpha=0.5 -> radii 8,4,2,1) from
        #losing their last radius to floating-point rounding
        return int(np.floor(np.log(w)/np.log(1.0/alpha) + 1e-9)) + 1

    def __D(self,row,col,dest_row,dest_col):
        """Euclidean (L2) distance between the source patch at (row, col) and the dest patch at (dest_row, dest_col).

        Needs ``self.sWin``/``self.dWin`` (built by ``iterate``). Both images are
        reflect-padded by ``patch_size // 2`` before windowing, so window (r, c) is the
        patch centred on pixel (r, c), even at image borders.

        Raises:
            ValueError: if the two patches have different shapes. Both window views use the
                source's channel count, so this only happens when source and dest
                have different numbers of channels.
        """
        patch1 = self.sWin[:,row,col].squeeze()
        patch2 = self.dWin[:,dest_row,dest_col].squeeze()
        if patch1.shape != patch2.shape:
            raise ValueError(
                "Patch shape mismatch: source patch {} at ({}, {}) vs dest patch {} at ({}, {}). "
                "Source is {} rows by {} cols, dest is {} rows by {} cols.".format(
                    patch1.shape,row,col,patch2.shape,dest_row,dest_col,
                    self.rows,self.cols,self.dest.shape[1],self.dest.shape[2]))
        return float(((patch1 - patch2)**2).sum())**0.5

    def __D_vec(self,source_loc,dest_locs):
        """Vectorised ``__D``: return the closest of several candidate dest locations.

        Args:
            source_loc: (row, col) of the source pixel.
            dest_locs: int array of shape (k, 2) holding candidate dest (row, col) positions.

        Returns:
            ``(location, distance)`` of the best candidate; location is an int array of length 2.
        """
        #Window (r, c) of the padded image is the patch centred on pixel (r, c).
        #Fancy-index all k windows at once. [0][0] drops the leading singleton axes.
        #This assumes the channel-window axis has size 1, i.e. source and dest have the same channel count.
        dest_patch_coords = (dest_locs[:,0].reshape(1,dest_locs.shape[0]),dest_locs[:,1].reshape(1,dest_locs.shape[0]))
        dest_patches = self.dWin[:,dest_patch_coords[0],dest_patch_coords[1]][0][0]

        source_patch = self.sWin[:,source_loc[0],source_loc[1]].squeeze()

        #sum squared differences over every axis except the candidate axis
        dists = ((source_patch - dest_patches)**2).sum(axis = tuple(range(1,len(dest_patches.shape))))**0.5
        closest = np.argmin(dists)
        return np.array(dest_locs[closest]).astype(int),float(dists[closest])

    def __step(self,row,col,alpha = 0.5,reverse_prop = False,threshold = 0):
        """One PatchMatch update (propagation, then random search) for source pixel (row, col).

        Args:
            row, col: source pixel being updated.
            alpha: ratio between successive random-search radii; expected in (0, 1).
            reverse_prop: if True, propagate from the pixel below / to the right,
                otherwise from the pixel above / to the left.
            threshold: a random-search candidate is accepted when its distance is
                ``<= current best + threshold``.
        """
        dest_rows = self.dest.shape[1]
        dest_cols = self.dest.shape[2]

        D_min = self.dists[row,col]
        if D_min == np.inf:
            D_min = self.__D(row,col,self.matchings[row,col,0],self.matchings[row,col,1])
            self.dists[row,col] = D_min

        #propagation step: try the neighbour's match, shifted by one pixel
        #pm1 is +1 when propagating in reverse, -1 otherwise
        pm1 = (reverse_prop * 2) - 1

        if (not reverse_prop and row != 0) or (reverse_prop and row != self.rows - 1):
            prev = self.matchings[row + pm1,col]
            if (not reverse_prop and prev[0] < dest_rows - 1) or (reverse_prop and prev[0] > 0):
                candidate = self.__D(row,col,prev[0] - pm1,prev[1])
                if candidate <= D_min:
                    D_min = candidate
                    self.matchings[row,col,0] = prev[0] - pm1
                    self.matchings[row,col,1] = prev[1]
                    self.dists[row,col] = D_min

        if (not reverse_prop and col != 0) or (reverse_prop and col != self.cols - 1):
            prev = self.matchings[row,col + pm1]
            if (not reverse_prop and prev[1] < dest_cols - 1) or (reverse_prop and prev[1] > 0):
                candidate = self.__D(row,col,prev[0],prev[1] - pm1)
                if candidate <= D_min:
                    D_min = candidate
                    self.matchings[row,col,0] = prev[0]
                    self.matchings[row,col,1] = prev[1] - pm1
                    self.dists[row,col] = D_min

        #random search step: sample around the *current best match* at radii w*alpha**i,
        #for every radius >= 1 pixel (w = largest dest dimension).
        #Each offset component is drawn uniformly from [-radius, radius] (a square, not a disc).
        w = max(dest_rows,dest_cols)
        num_iters = self._num_search_radii(w,alpha)
        if num_iters > 0:
            radii = w*alpha**np.arange(num_iters).reshape(num_iters,1)
            offsets = np.rint((np.random.rand(num_iters,2) - 0.5)*2*radii).astype(int)
            prods = self.matchings[row,col].astype(int) + offsets
            #keep only candidates that lie inside the dest image
            in_bounds = (prods[:,0] >= 0) & (prods[:,0] < dest_rows) & (prods[:,1] >= 0) & (prods[:,1] < dest_cols)
            prods = prods[in_bounds]
            if len(prods) > 0:
                candidate,dist = self.__D_vec([row,col],prods)
                if dist <= D_min + threshold:
                    self.num_switch += 1
                    D_min = dist
                    self.matchings[row,col] = candidate
                    self.dists[row,col] = D_min
        return

    def iterate(self,n = 5,alpha = 0.5):
        """Refine ``self.matchings`` with ``n`` PatchMatch passes over every source pixel.

        Even-numbered passes (2nd, 4th, ...) scan from the bottom-right to the top-left and
        propagate from the below/right neighbours. Each pass then uses neighbours that were
        already updated in that pass, as in the PatchMatch paper.

        Args:
            n: number of passes.
            alpha: ratio between successive random-search radii.
        """
        norm = self.normalised
        #pad by half a patch so that every pixel (including borders) has a full, centred window
        pad_dims = ((0,0),(self.patch_size//2,self.patch_size//2),(self.patch_size//2,self.patch_size//2))
        #numpy(force=True) also works when the tensors live on a GPU
        if norm:
            #normalise every location across channels; the paper reports better feature matching
            self.sNorm = torch.Tensor(np.pad(F.normalize(self.source.detach(), dim = 0).numpy(force = True),pad_dims,mode = 'reflect'))
            self.dNorm = torch.Tensor(np.pad(F.normalize(self.dest.detach(),dim = 0).numpy(force = True),pad_dims,mode='reflect'))
        else:
            self.sNorm = torch.Tensor(np.pad(self.source.detach().numpy(force = True),pad_dims,mode = 'reflect'))
            self.dNorm = torch.Tensor(np.pad(self.dest.detach().numpy(force = True),pad_dims,mode='reflect'))
        #strided window views: patch lookup without copying the image for each patch
        window_shape = (self.source.shape[0],self.patch_size,self.patch_size)
        self.sWin = view_as_windows(self.sNorm.numpy(),window_shape)
        self.dWin = view_as_windows(self.dNorm.numpy(),window_shape)

        self.dists = np.full((self.rows,self.cols),np.inf)
        self.num_switch = 0

        for x in range(n):
            print("Iteration {}:".format(x + 1))
            reverse = (x + 1)%2 == 0
            row_order = range(self.rows - 1,-1,-1) if reverse else range(self.rows)
            col_order = range(self.cols - 1,-1,-1) if reverse else range(self.cols)
            for done,row in enumerate(row_order):
                #print progress at every multiple of 10%
                if int((done/self.rows)*100)%10 == 0 and int(((done - 1)/self.rows)*100)%10 != 0:
                    print("\t{}%".format(int(100*done/self.rows)))
                for col in col_order:
                    self.__step(row,col,alpha,reverse)
            #fewer than 1% of pixels improved in random search -> report convergence (search keeps going)
            if (self.num_switch/(self.rows*self.cols)) < 0.01:
                print("\t Mapping Reached.")
            self.num_switch = 0

        #drop the working buffers so only the mapping is kept
        del(self.num_switch)
        del(self.sNorm)
        del(self.dNorm)
        del(self.sWin)
        del(self.dWin)
        del(self.dists)
        return
    
    
    
        
                          

### A past experiment in trying to restrict the range of pixels certain segments of an image could draw from

In [None]:
class restricted_NNF(NNF):
    """Experimental ``NNF`` whose random search stays close to each source pixel.

    ``mapping`` is an array whose first axis length divides ``self.rows`` by an
    integer ``scale_factor``. It is scaled and upsampled by that factor and stored
    as ``self.mapping``, but the search below does not read it yet. The random
    search is centred on the source location (row, col) with a maximum radius of
    ``max(scale_factor, 2)``. ``iterate`` stops after three consecutive settled passes.

    ``__D``, ``__D_vec``, ``__step`` and ``iterate`` are re-declared here because
    double-underscore method names are mangled per class, so ``NNF.iterate`` would
    always call ``NNF``'s own ``__step``.
    """
    def __init__(self,img_1,img_2,mapping,patch_size = 7, norm = False):
        """
        Args:
            img_1, img_2: source and destination tensors ``[C, H, W]``.
            mapping: array supporting ``.repeat(n, axis=...)`` (e.g. numpy) whose first
                axis length divides ``img_1``'s row count.
            patch_size, norm: forwarded to ``NNF``.

        Raises:
            ValueError: if ``img_1``'s rows are not an integer multiple of ``mapping.shape[0]``.
        """
        #forward the caller's settings (previously hard-coded to 7 / False)
        super().__init__(img_1,img_2,patch_size = patch_size, norm = norm)

        scale_factor = self.rows/mapping.shape[0]
        if scale_factor != int(scale_factor):
            raise ValueError("Expected integer value scaling. Instead, the image is scaled by {}.".format(scale_factor))
        #np.repeat requires an integer repeat count
        scale_factor = int(scale_factor)
        self.scale_factor = scale_factor
        self.mapping = (mapping*scale_factor).repeat(scale_factor,axis = 0).repeat(scale_factor,axis = 1)
        self.settle_streak = 0
        return

    def __D(self,row,col,dest_row,dest_col):
        """L2 distance between the source patch at (row, col) and the dest patch at (dest_row, dest_col).

        Raises:
            ValueError: if the patches differ in shape (source and dest channel counts differ).
        """
        patch1 = self.sWin[:,row,col].squeeze()
        patch2 = self.dWin[:,dest_row,dest_col].squeeze()
        if patch1.shape != patch2.shape:
            raise ValueError(
                "Patch shape mismatch: source patch {} at ({}, {}) vs dest patch {} at ({}, {}). "
                "Source is {} rows by {} cols, dest is {} rows by {} cols.".format(
                    patch1.shape,row,col,patch2.shape,dest_row,dest_col,
                    self.rows,self.cols,self.dest.shape[1],self.dest.shape[2]))
        return float(((patch1 - patch2)**2).sum())**0.5

    def __D_vec(self,source_loc,dest_locs):
        """Vectorised ``__D``: return ``(location, distance)`` of the closest of the (k, 2) candidates ``dest_locs``."""
        #window (r, c) of the padded image is the patch centred on pixel (r, c)
        dest_patch_coords = (dest_locs[:,0].reshape(1,dest_locs.shape[0]),dest_locs[:,1].reshape(1,dest_locs.shape[0]))
        dest_patches = self.dWin[:,dest_patch_coords[0],dest_patch_coords[1]][0][0]

        source_patch = self.sWin[:,source_loc[0],source_loc[1]].squeeze()

        dists = ((source_patch - dest_patches)**2).sum(axis = tuple(range(1,len(dest_patches.shape))))**0.5
        closest = np.argmin(dists)
        return np.array(dest_locs[closest]).astype(int),float(dists[closest])

    def __step(self,row,col,alpha = 0.5,reverse_prop = False,threshold = 0):
        """One restricted PatchMatch update for source pixel (row, col).

        Propagation is the same as in ``NNF``. The random search is centred on the
        source location itself, with a maximum radius of ``max(self.scale_factor, 2)``.
        """
        dest_rows = self.dest.shape[1]
        dest_cols = self.dest.shape[2]

        D_min = self.dists[row,col]
        if D_min == np.inf:
            D_min = self.__D(row,col,self.matchings[row,col,0],self.matchings[row,col,1])
            self.dists[row,col] = D_min

        #propagation step (allowed across block boundaries)
        #pm1 is +1 when propagating in reverse, -1 otherwise
        pm1 = (reverse_prop * 2) - 1

        if (not reverse_prop and row != 0) or (reverse_prop and row != self.rows - 1):
            prev = self.matchings[row + pm1,col]
            if (not reverse_prop and prev[0] < dest_rows - 1) or (reverse_prop and prev[0] > 0):
                candidate = self.__D(row,col,prev[0] - pm1,prev[1])
                if candidate <= D_min:
                    D_min = candidate
                    self.matchings[row,col,0] = prev[0] - pm1
                    self.matchings[row,col,1] = prev[1]
                    self.dists[row,col] = D_min

        if (not reverse_prop and col != 0) or (reverse_prop and col != self.cols - 1):
            prev = self.matchings[row,col + pm1]
            if (not reverse_prop and prev[1] < dest_cols - 1) or (reverse_prop and prev[1] > 0):
                candidate = self.__D(row,col,prev[0],prev[1] - pm1)
                if candidate <= D_min:
                    D_min = candidate
                    self.matchings[row,col,0] = prev[0]
                    self.matchings[row,col,1] = prev[1] - pm1
                    self.dists[row,col] = D_min

        #random search step, centred on the source location with a small radius
        w = max(self.scale_factor,2)
        num_iters = int(np.emath.logn(alpha,1/w))
        if num_iters > 0:
            radii = w*alpha**np.arange(num_iters).reshape(num_iters,1)
            prods = np.full((num_iters,2),[row,col],dtype = int) + np.rint((np.random.rand(num_iters,2) - 0.5)*2*radii).astype(int)

            sf = self.scale_factor
            #NOTE: since sf*(p//sf) <= p < (p//sf + 1)*sf for every integer p, these per-block
            #limits currently reduce to the plain bounds check 0 <= p < dest size.
            min_rows = np.maximum(sf*(prods[:,0]//sf),0)
            max_rows = np.minimum((prods[:,0]//sf + 1)*sf,dest_rows)
            min_cols = np.maximum(sf*(prods[:,1]//sf),0)
            max_cols = np.minimum((prods[:,1]//sf + 1)*sf,dest_cols)

            rows_in_bounds = np.logical_and(np.greater_equal(prods[:,0],min_rows),np.less(prods[:,0],max_rows))
            cols_in_bounds = np.logical_and(np.greater_equal(prods[:,1],min_cols),np.less(prods[:,1],max_cols))
            prods = prods[np.logical_and(rows_in_bounds,cols_in_bounds)]

            if len(prods) > 0:
                candidate,dist = self.__D_vec([row,col],prods)
                if dist <= D_min + threshold:
                    self.num_switch += 1
                    D_min = dist
                    self.matchings[row,col] = candidate
                    self.dists[row,col] = D_min
        return

    def iterate(self,n = 5,alpha = 0.5):
        """Refine ``self.matchings`` with up to ``n`` passes.

        Stops early after 3 consecutive passes in which fewer than 1% of pixels
        improved during random search. Even-numbered passes scan in reverse order.
        """
        norm = self.normalised
        pad_dims = ((0,0),(self.patch_size//2,self.patch_size//2),(self.patch_size//2,self.patch_size//2))
        #numpy(force=True) also works when the tensors live on a GPU
        if norm:
            self.sNorm = torch.Tensor(np.pad(F.normalize(self.source.detach(), dim = 0).numpy(force = True),pad_dims,mode = 'reflect'))
            self.dNorm = torch.Tensor(np.pad(F.normalize(self.dest.detach(),dim = 0).numpy(force = True),pad_dims,mode='reflect'))
        else:
            self.sNorm = torch.Tensor(np.pad(self.source.detach().numpy(force = True),pad_dims,mode = 'reflect'))
            self.dNorm = torch.Tensor(np.pad(self.dest.detach().numpy(force = True),pad_dims,mode='reflect'))
        window_shape = (self.source.shape[0],self.patch_size,self.patch_size)
        self.sWin = view_as_windows(self.sNorm.numpy(),window_shape)
        self.dWin = view_as_windows(self.dNorm.numpy(),window_shape)

        self.dists = np.full((self.rows,self.cols),np.inf)
        self.num_switch = 0

        for x in range(n):
            print("Iteration {}:".format(x + 1))
            reverse = (x + 1)%2 == 0
            #reverse passes also scan in reverse so propagation reads already-updated neighbours
            row_order = range(self.rows - 1,-1,-1) if reverse else range(self.rows)
            col_order = range(self.cols - 1,-1,-1) if reverse else range(self.cols)
            for done,row in enumerate(row_order):
                if int((done/self.rows)*100)%10 == 0 and int(((done - 1)/self.rows)*100)%10 != 0:
                    print("\t{}%".format(int(100*done/self.rows)))
                for col in col_order:
                    self.__step(row,col,alpha,reverse)
            if (self.num_switch/(self.rows*self.cols)) < 0.01:
                print("\t Mapping Reached.")
                self.settle_streak += 1
                if self.settle_streak >= 3:
                    print("\t Reached a streak of 3 settled states. Stopping early.")
                    break
            else:
                self.settle_streak = 0
            self.num_switch = 0

        #drop the working buffers so only the mapping is kept
        del(self.num_switch)
        del(self.sNorm)
        del(self.dNorm)
        del(self.sWin)
        del(self.dWin)
        del(self.dists)
        return
    
            
        
    


# Bidirectional Similarity (BDS)

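Two versions are offered: one implemented as described in the paper, and one that blends the voted pixels with the current target image according to a per-pixel vote confidence, controlled by an adjustable exponent (`default_weight`). Both cells define a class named `BDS`, so whichever cell is run last is the one used.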

### Paper version

In [None]:
#Technically the correct one
class BDS:
    """Bidirectional-similarity patch voting using the paper's normalised update rule.

    Holds a completeness NNF (source -> target) and a coherence NNF
    (target -> source). The target is re-synthesised by letting every matched patch
    vote for the target pixels it covers. Images are tensors ``[channels, rows, cols]``.
    """
    def __init__(self,source,target,patch_size = 7,search_alpha = 0.5,NNFs = None, norm = False,default_weight = 0.5):
        """
        Args:
            source, target: image tensors ``[C, H, W]``.
            patch_size: patch side length; an even value is increased by 1.
            search_alpha: random-search radius ratio passed to ``NNF.iterate``.
            NNFs: optional ``(completeness, coherence)`` NNF pair to reuse; random ones are made otherwise.
            norm: pixel normalisation flag for newly created NNFs.
            default_weight: unused by this version; accepted so the signature matches the weighted ``BDS``.
        """
        if patch_size%2 == 0:
            print("Patch size should be odd. Increasing by 1.")
            patch_size = patch_size + 1
        self.source = source
        self.target = target
        self.patch_size = patch_size
        self.search_alpha = search_alpha
        if(NNFs is None):
            self.normalised = norm
            self.complete = NNF(source,target,patch_size,norm)
            self.cohere = NNF(target,source,patch_size,norm)
        else:
            self.complete,self.cohere = NNFs
            self.normalised = self.complete.normalised or self.cohere.normalised
        return

    def __rgbToLAB(self,image_arr):
        labImage = color.rgb2lab(image_arr,channel_axis = 0)
        return labImage

    def __labToRGB(self,lab_arr):
        rgbImage = color.lab2rgb(lab_arr, channel_axis = 0)
        return rgbImage

    def refine_bidirectional_maps(self, num_iters = 4):
        """Run ``num_iters`` PatchMatch passes on both the completeness and coherence maps."""
        print("Refining Completeness map (Source --> Target) for {} iterations...".format(num_iters))
        self.complete.iterate(num_iters,self.search_alpha)
        print("Getting Coherence map (Target --> Source) for {} iterations...".format(num_iters))
        self.cohere.iterate(num_iters,self.search_alpha)
        return

    def get_new_bidirectional_maps(self):
        """Replace both maps with fresh random NNFs between the current source and target."""
        #use this instance's images and normalisation setting (not module-level globals)
        self.complete = NNF(self.source,self.target,self.patch_size,self.normalised)
        self.cohere = NNF(self.target,self.source,self.patch_size,self.normalised)
        return

    def __get_matched_patch(self,NNF_mapping,row,col):
        """Return ``(from_patch, to_patch)`` index ranges for pixel (row, col) of ``NNF_mapping``.

        Each patch is ``((top, bottom), (left, right))`` with exclusive ends. The radius
        is clipped on every side so that both patches stay inside their images and have
        the same shape.
        """
        to_rows = NNF_mapping.dest.shape[1]
        to_cols = NNF_mapping.dest.shape[2]

        from_rows = NNF_mapping.source.shape[1]
        from_cols = NNF_mapping.source.shape[2]

        to_row = NNF_mapping.matchings[row,col,0]
        to_col = NNF_mapping.matchings[row,col,1]

        r = self.patch_size//2

        r_left = min(r,col,to_col)
        r_right = min(r,from_cols - (col + 1),to_cols - (to_col + 1))
        r_top = min(r,row,to_row)
        r_bott = min(r,from_rows - (row + 1),to_rows - (to_row + 1))

        from_patch = ((row - r_top,row + r_bott + 1),(col - r_left,col + r_right + 1))
        to_patch = ((to_row - r_top,to_row + r_bott + 1),(to_col - r_left,to_col + r_right + 1))

        return (from_patch,to_patch)

    def patch_vote(self,complete_cohere_weights = (0.5,0.5),switch_target = True):
        """Re-synthesise the target by normalised bidirectional patch voting.

        Each target pixel becomes::

            (w_c/Ns * sum(completeness votes) + w_h/Nt * sum(coherence votes))
            / (w_c/Ns * #completeness votes  + w_h/Nt * #coherence votes)

        where ``Ns``/``Nt`` are the source/target pixel counts and
        ``complete_cohere_weights = (w_c, w_h)``. Pixels that get no weighted votes
        keep their current target value.

        Args:
            complete_cohere_weights: (completeness weight, coherence weight).
            switch_target: if True, store the result as ``self.target``.

        Returns:
            torch.Tensor with the source's channel count and the target's rows/cols.
        """
        w_complete,w_cohere = complete_cohere_weights

        sChans,sRows,sCols = self.source.shape[0],self.source.shape[1],self.source.shape[2]
        tRows,tCols = self.target.shape[1],self.target.shape[2]
        Ns = sRows*sCols
        Nt = tRows*tCols

        votes_t_to_s = np.zeros((sChans,tRows,tCols))
        votes_s_to_t = np.zeros((sChans,tRows,tCols))
        num_votes_t_to_s = np.zeros((tRows,tCols))
        num_votes_s_to_t = np.zeros((tRows,tCols))

        src = self.source.numpy(force=True)

        #coherence: every target patch receives the source patch it is matched to
        for row in range(tRows):
            for col in range(tCols):
                t_patch,s_patch = self.__get_matched_patch(self.cohere,row,col)
                (t_top,t_bott),(t_left,t_right) = t_patch
                (s_top,s_bott),(s_left,s_right) = s_patch

                patch = src[:,s_top:s_bott,s_left:s_right]
                votes_t_to_s[:,t_top:t_bott,t_left:t_right] += patch
                num_votes_t_to_s[t_top:t_bott,t_left:t_right] += 1

        #completeness: every source patch votes for the target patch it is matched to
        for row in range(sRows):
            for col in range(sCols):
                s_patch,t_patch = self.__get_matched_patch(self.complete,row,col)
                (s_top,s_bott),(s_left,s_right) = s_patch
                (t_top,t_bott),(t_left,t_right) = t_patch

                patch = src[:,s_top:s_bott,s_left:s_right]
                votes_s_to_t[:,t_top:t_bott,t_left:t_right] += patch
                num_votes_s_to_t[t_top:t_bott,t_left:t_right] += 1

        #completeness terms are normalised by Ns, coherence terms by Nt, in both numerator and denominator
        num = votes_s_to_t*(w_complete/Ns) + votes_t_to_s*(w_cohere/Nt)
        denom = num_votes_s_to_t*(w_complete/Ns) + num_votes_t_to_s*(w_cohere/Nt)

        has_votes = denom > 0
        result = num/np.where(has_votes,denom,1.0)
        if not has_votes.all():
            #pixels without any weighted vote keep their current value
            result = np.where(has_votes,result,self.target.numpy(force = True))

        targ_result = torch.Tensor(result)
        if(switch_target):
            self.target = targ_result
        return targ_result
                
        
        
        

### Weighted Version

In [None]:
#Works best from my tests
class BDS:
    """Confidence-weighted bidirectional patch voting.

    Holds a completeness NNF (source -> target) and a coherence NNF
    (target -> source). Each direction's votes are averaged per pixel and mixed by
    their weighted vote counts. The mix is then blended with the current target
    according to a per-pixel vote confidence. Images are tensors ``[channels, rows, cols]``.
    """
    def __init__(self,source,target,patch_size = 7,search_alpha = 0.5,NNFs = None, norm = False,default_weight = 0.5):
        """
        Args:
            source, target: image tensors ``[C, H, W]``.
            patch_size: patch side length; an even value is increased by 1.
            search_alpha: random-search radius ratio passed to ``NNF.iterate``.
            NNFs: optional ``(completeness, coherence)`` NNF pair to reuse; random ones are made otherwise.
            norm: pixel normalisation flag for newly created NNFs.
            default_weight: exponent applied to the normalised vote confidence in ``patch_vote``.
                A smaller value makes the result rely more on the votes than on the existing target.
        """
        if patch_size%2 == 0:
            print("Patch size should be odd. Increasing by 1.")
            patch_size = patch_size + 1
        self.source = source
        self.target = target
        self.patch_size = patch_size
        self.search_alpha = search_alpha
        #a plain black image is a poor fallback, so weakly-voted pixels fall back to the current target
        self.default_weight = default_weight
        if(NNFs is None):
            self.normalised = norm
            self.complete = NNF(source,target,patch_size,norm)
            self.cohere = NNF(target,source,patch_size,norm)
        else:
            self.complete,self.cohere = NNFs
            self.normalised = self.complete.normalised or self.cohere.normalised
        return

    def __rgbToLAB(self,image_arr):
        labImage = color.rgb2lab(image_arr,channel_axis = 0)
        return labImage

    def __labToRGB(self,lab_arr):
        rgbImage = color.lab2rgb(lab_arr, channel_axis = 0)
        return rgbImage

    def refine_bidirectional_maps(self, num_iters = 4):
        """Run ``num_iters`` PatchMatch passes on both the completeness and coherence maps."""
        print("Refining Completeness map (Source --> Target) for {} iterations...".format(num_iters))
        self.complete.iterate(num_iters,self.search_alpha)
        print("Getting Coherence map (Target --> Source) for {} iterations...".format(num_iters))
        self.cohere.iterate(num_iters,self.search_alpha)
        return

    def get_new_bidirectional_maps(self):
        """Replace both maps with fresh random NNFs between the current source and target."""
        #use this instance's images and normalisation setting (not module-level globals)
        self.complete = NNF(self.source,self.target,self.patch_size,self.normalised)
        self.cohere = NNF(self.target,self.source,self.patch_size,self.normalised)
        return

    def __get_matched_patch(self,NNF_mapping,row,col):
        """Return ``(from_patch, to_patch)`` index ranges for pixel (row, col) of ``NNF_mapping``.

        Each patch is ``((top, bottom), (left, right))`` with exclusive ends. The radius
        is clipped so that both patches stay inside their images and have the same shape.
        """
        to_rows = NNF_mapping.dest.shape[1]
        to_cols = NNF_mapping.dest.shape[2]

        from_rows = NNF_mapping.source.shape[1]
        from_cols = NNF_mapping.source.shape[2]

        to_row = NNF_mapping.matchings[row,col,0]
        to_col = NNF_mapping.matchings[row,col,1]

        r = self.patch_size//2

        r_left = min(r,col,to_col)
        r_right = min(r,from_cols - (col + 1),to_cols - (to_col + 1))
        r_top = min(r,row,to_row)
        r_bott = min(r,from_rows - (row + 1),to_rows - (to_row + 1))

        from_patch = ((row - r_top,row + r_bott + 1),(col - r_left,col + r_right + 1))
        to_patch = ((to_row - r_top,to_row + r_bott + 1),(to_col - r_left,to_col + r_right + 1))

        return (from_patch,to_patch)

    def patch_vote(self,complete_cohere_weights = (0.5,0.5),switch_target = True):
        """Re-synthesise the target from confidence-weighted bidirectional votes.

        Args:
            complete_cohere_weights: (completeness weight, coherence weight). The vote
                counts of each direction are scaled by weight/Ns and weight/Nt respectively.
            switch_target: if True, store the result as ``self.target``.

        Returns:
            torch.Tensor with the source's channel count and the target's rows/cols.
        """
        sShape = self.source.shape
        sRows = sShape[1]
        sCols = sShape[2]
        tShape = self.target.shape
        tRows = tShape[1]
        tCols = tShape[2]

        Ns = sRows*sCols
        Nt = tRows*tCols

        votes_t_to_s = np.zeros((sShape[0],tRows,tCols))
        votes_s_to_t = np.zeros((sShape[0],tRows,tCols))

        num_votes_t_to_s = np.zeros((tRows,tCols))
        num_votes_s_to_t = np.zeros((tRows,tCols))

        src = self.source.numpy(force=True)

        #coherence: every target patch receives the source patch it is matched to
        for row in range(tRows):
            for col in range(tCols):
                t_patch,s_patch = self.__get_matched_patch(self.cohere,row,col)
                (t_top,t_bott),(t_left,t_right) = t_patch
                (s_top,s_bott),(s_left,s_right) = s_patch

                patch = src[:,s_top:s_bott,s_left:s_right]
                votes_t_to_s[:,t_top:t_bott,t_left:t_right] += patch
                num_votes_t_to_s[t_top:t_bott,t_left:t_right] += np.ones((patch.shape[1],patch.shape[2]))

        #completeness: every source patch votes for the target patch it is matched to
        for row in range(sRows):
            for col in range(sCols):
                s_patch,t_patch = self.__get_matched_patch(self.complete,row,col)
                (s_top,s_bott),(s_left,s_right) = s_patch
                (t_top,t_bott),(t_left,t_right) = t_patch

                patch = src[:,s_top:s_bott,s_left:s_right]
                num_votes_s_to_t[t_top:t_bott,t_left:t_right] += np.ones((patch.shape[1],patch.shape[2]))
                votes_s_to_t[:,t_top:t_bott,t_left:t_right] += patch

        #per-direction mean vote (max(.,1) avoids zero division where a direction gave no votes)
        votes_t_to_s/=np.maximum(num_votes_t_to_s,1)
        votes_s_to_t/=np.maximum(num_votes_s_to_t,1)

        #weighted vote mass of each direction
        num_votes_t_to_s *= complete_cohere_weights[1]*(1/Nt)
        num_votes_s_to_t *= complete_cohere_weights[0]*(1/Ns)

        #turn the masses into mixing weights that sum to 1 (epsilon avoids zero division)
        vote_sum = num_votes_t_to_s + num_votes_s_to_t + 0.00001
        num_votes_t_to_s/=vote_sum
        num_votes_s_to_t/=vote_sum

        #per-pixel confidence in (0, 1]: total vote mass relative to the best-voted pixel
        vote_sum = vote_sum/vote_sum.max()
        vote_sum = np.power(vote_sum,self.default_weight)

        #mix the two directions, then blend with the current target by confidence
        votes = votes_t_to_s*num_votes_t_to_s + votes_s_to_t*num_votes_s_to_t
        votes = votes*vote_sum + self.target.numpy(force = True)*(1 - vote_sum)
        targ_result = torch.Tensor(votes)
        if(switch_target):
            self.target = targ_result
        return targ_result
                
        
        
        

### Helper functions that aided in testing.

In [None]:
def patchup(nn):
    """Forward-warp ``nn.source`` onto the destination grid using its NNF.

    Every source pixel (row, col) is copied to the destination position stored in
    ``nn.matchings[row, col]``. Destination pixels that no source pixel maps to stay 0.
    When several source pixels map to the same destination, the one visited last
    (row-major order) wins.

    Args:
        nn: an ``NNF``-like object exposing channel-first ``source`` / ``dest`` images
            and ``matchings`` (rows x cols x 2 integer destination coordinates).

    Returns:
        torch.Tensor with the shape of ``nn.dest``; it is also drawn with ``plt.imshow``.
    """
    to_pil = transforms.ToPILImage()
    target_img, source_img = nn.dest, nn.source
    print(target_img.shape)
    canvas = np.zeros(target_img.shape)
    n_rows, n_cols = nn.matchings.shape[:2]
    for r in range(n_rows):
        for c in range(n_cols):
            dst_r, dst_c = nn.matchings[r, c]
            canvas[:, dst_r, dst_c] = source_img[:, r, c]
    warped = torch.Tensor(canvas)
    plt.imshow(to_pil(warped))
    return warped

    

In [None]:
def patchdown(nn):
    """Backward-warp ``nn.dest`` onto the source grid using its NNF.

    Each source position (row, col) receives the destination pixel found at
    ``nn.matchings[row, col]``, i.e. the source image is rebuilt from its matches.

    Args:
        nn: an ``NNF``-like object exposing channel-first ``source`` / ``dest`` images
            and ``matchings`` (rows x cols x 2 integer destination coordinates).

    Returns:
        torch.Tensor with the shape of ``nn.source``; it is also drawn with ``plt.imshow``.
    """
    to_pil = transforms.ToPILImage()
    target_img, source_img = nn.dest, nn.source
    canvas = np.zeros(source_img.shape)
    # np.ndindex walks the grid in the same row-major order as nested range loops.
    for r, c in np.ndindex(*nn.matchings.shape[:2]):
        dst_r, dst_c = nn.matchings[r, c]
        canvas[:, r, c] = target_img[:, dst_r, dst_c]
    rebuilt = torch.Tensor(canvas)
    plt.imshow(to_pil(rebuilt))
    return rebuilt

# Colour Transfer Section

In [None]:
# VGG-19 convolutional trunk, used only as a fixed feature extractor.
# `pretrained=True` is deprecated since torchvision 0.13; the explicit weights enum
# selects the same ImageNet checkpoint. Parameters are frozen because gradients
# are never taken w.r.t. the network, only w.r.t. the colour-transform maps.
vgg = (models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
       .features.to(device).eval().requires_grad_(False))

In [None]:
def labToRGB(lab_tensor):
    """Convert a channel-first CIELAB tensor into a channel-first RGB ``torch.Tensor``.

    The tensor is copied to a NumPy array (``numpy(force=True)`` detaches it and
    moves it to the CPU), converted with ``skimage.color.lab2rgb`` along axis 0,
    and wrapped back into a ``torch.Tensor``.
    """
    lab_array = lab_tensor.numpy(force = True)
    return torch.Tensor(color.lab2rgb(lab_array, channel_axis = 0))

### Loss function for optimising the colour transform

In [None]:
class error(nn.Module):
    """Objective minimised when fitting the per-pixel affine colour transform.

    The transform is ``T = a * source + b`` (element-wise, channel-first tensors).
    The loss has two parts:
      * a similarity term that pulls ``T`` towards the guidance image;
      * an edge-aware smoothness term on the coefficient maps ``a`` and ``b``.
    The smoothness term is scaled by ``smooth_weight``.

    Args:
        smooth_weight: multiplier on the smoothness term (default 1.28e-5).
    """

    def __init__(self, smooth_weight = 0.0000128):
        super(error, self).__init__()
        self.smooth_weight = smooth_weight

    def similarity_loss(self,omega_L,elp,a,b,source,guidance):
        """Weighted squared distance between the transformed source and the guidance.

        Args:
            omega_L: scalar weight for the current level.
            elp: per-pixel feature-matching error (H, W). Well-matched pixels
                (small ``elp``) receive the larger weight ``2 - elp``.
            a, b: (C, H, W) scale and offset maps.
            source, guidance: (C, H, W) images in the same colour space.
        """
        transformed = torch.mul(a, source) + b
        # The weight was changed from (1 - elp) to (2 - elp) so it stays
        # non-negative. Assumes elp <= 2, which holds for squared distances
        # between L2-normalised non-negative (post-ReLU) feature vectors.
        per_pixel = torch.sum(torch.pow(transformed - guidance, 2), dim = 0)
        return torch.sum(omega_L * (2 - elp) * per_pixel)

    @staticmethod
    def _squared_neighbour_diffs(x):
        """Squared differences to the four neighbours of every pixel, summed over channels.

        ``x`` is (C, H, W). It is reflect-padded by one pixel so border pixels have
        neighbours. Returns four (H, W) tensors: (above, below, left, right).
        """
        padded = F.pad(x, (1, 1, 1, 1), mode = 'reflect')
        vertical = padded[:, 1:, 1:-1] - padded[:, :-1, 1:-1]    # (C, H+1, W)
        horizontal = padded[:, 1:-1, 1:] - padded[:, 1:-1, :-1]  # (C, H, W+1)
        above = torch.sum(torch.pow(vertical[:, :-1], 2), dim = 0)
        below = torch.sum(torch.pow(vertical[:, 1:], 2), dim = 0)
        left = torch.sum(torch.pow(horizontal[:, :, :-1], 2), dim = 0)
        right = torch.sum(torch.pow(horizontal[:, :, 1:], 2), dim = 0)
        return above, below, left, right

    def smoothness_loss(self, omega_lum, a,b):
        """Edge-aware smoothness of the coefficient maps ``a`` and ``b``.

        Args:
            omega_lum: tuple of four (H, W) weights for the (above, below, left,
                right) neighbour of each pixel.
            a, b: (C, H, W) coefficient maps.

        The differences of ``a`` and of ``b`` are both summed over channels before
        weighting, so the two maps are penalised equally. (Summing only ``a`` and
        broadcasting against the per-channel ``b`` term would count ``a`` C times.)
        """
        a_diffs = self._squared_neighbour_diffs(a)
        b_diffs = self._squared_neighbour_diffs(b)
        return sum(
            torch.sum(weight * (da + db))
            for weight, da, db in zip(omega_lum, a_diffs, b_diffs)
        )

    def forward(self, omega_L,omega_lum,elp,a,b,source,guidance):
        """Total loss: similarity term plus ``smooth_weight`` times the smoothness term."""
        return (self.similarity_loss(omega_L, elp, a, b, source, guidance)
                + self.smooth_weight * self.smoothness_loss(omega_lum, a, b))

### The class which handles the colour transfer

In [None]:
class singleImageColourTransfer:
    """Coarse-to-fine neural colour transfer from a style (reference) image to a source image.

    For each VGG-19 level from ``relu-5_1`` down to ``relu-1_1`` this:
      1. builds a guidance image by bidirectional patch matching (``BDS``) between
         source and reference features;
      2. fits a per-pixel affine transform ``a * S + b`` in CIELAB to that guidance
         by minimising ``error`` with Adam;
      3. upsamples ``a``/``b`` and applies them to the full-resolution source,
         producing the intermediate source ``self.S_L`` used by the next level.
    Guidance images and intermediate sources are written under
    ``<output_path>/<model_name>``.
    """

    def __init__(self,source_image,style_image,patch_size = 7, normalised = True,replace = None,
                 rollingAv = False,sample_feats = False,output_path = "./outputs"):
        """Set up the VGG feature extractor and the output directories.

        Args:
            source_image: channel-first RGB tensor to recolour (values presumably in [0, 1]).
            style_image: channel-first RGB tensor providing the target colours.
            patch_size: odd patch size for matching. ``None`` derives 3..7 from the image area.
            normalised: forwarded to ``BDS`` as ``norm``.
            replace: ``None`` uses the first free ``model_<k>`` folder. Otherwise an int
                ``k`` or a folder name, whose existing contents are deleted and reused.
            rollingAv: blend each level's result with the previous level's result.
            sample_feats: randomly subsample feature channels before matching.
            output_path: root directory that holds all model folders.
        """
        self.sample_feats = sample_feats
        self.unloader = transforms.ToPILImage()
        self.ii = transforms.Compose([transforms.ToTensor()])
        self.feature_map = models.feature_extraction.create_feature_extractor(
            vgg, self._first_relu_per_block(vgg))
        self.source = source_image
        self.reference = style_image
        self.patch_size = patch_size
        if self.patch_size is None:
            ps = round(np.sqrt(self.source.shape[1]*self.source.shape[2]/300))
            # Force the estimate to be odd, then clamp it to [3, 7].
            self.patch_size = min(max(3, ps + abs(ps % 2 - 1)), 7)
            print("Patch size:", self.patch_size)
        self.normalised = normalised
        self.rollingAv = rollingAv

        os.makedirs(output_path, exist_ok = True)
        # Folder names already in use. Entries starting with "." are hidden files
        # such as .DS_Store and are ignored.
        taken = {entry for entry in os.listdir(output_path) if not entry.startswith(".")}
        # First model_<k> not yet taken. Among k = 1 .. len(taken)+1 at least one is free.
        default_dest = next("model_{}".format(k) for k in range(1, len(taken) + 2)
                            if "model_{}".format(k) not in taken)

        if replace is None:
            self.model_name = default_dest
        else:
            self.model_name = replace if isinstance(replace, str) else "model_{}".format(replace)

        print("Initialising " + self.model_name + "...")
        self.output_path = "{}/{}".format(output_path, self.model_name)
        if replace is not None and os.path.isdir(self.output_path):
            # Clear the reused folder without building a shell command from the path.
            # Paths with spaces or metacharacters would break or be unsafe with `rm -rf`.
            shutil.rmtree(self.output_path)
        self.intermediate_sources = self.output_path + "/intermediate_sources"
        self.guidance_images = self.output_path + "/guidance"
        os.makedirs(self.intermediate_sources, exist_ok = True)
        os.makedirs(self.guidance_images, exist_ok = True)

        # The intermediate source starts out as the source image itself.
        self.S_L = source_image
        self.S_Prev = None
        print("Finished initialisation.")

    @staticmethod
    def _first_relu_per_block(layers, max_levels = 5):
        """Map the index of the first ReLU after each max-pool to ``"relu-<k>_1"``.

        At most ``max_levels`` entries are returned.
        """
        names = {}
        max_pools = 0
        prev_pool = -1
        for idx, layer in enumerate(layers):
            if len(names) >= max_levels:
                break
            if isinstance(layer, nn.ReLU):
                if max_pools != prev_pool:
                    names[idx] = "relu-{}_1".format(max_pools + 1)
                    prev_pool += 1
            elif isinstance(layer, nn.MaxPool2d):
                max_pools += 1
        return names

    def _extract_features(self, image, level_name):
        """Return the VGG activation ``level_name`` for ``image`` as a CPU tensor.

        The input is moved to the extractor's device, so a CUDA-resident VGG works
        with CPU images. No autograd graph is recorded, because the features are
        only used as fixed matching targets.
        """
        model_device = next(self.feature_map.parameters()).device
        with torch.no_grad():
            return self.feature_map(image.to(model_device))[level_name].cpu()

    #=========================================================================================================#

    def __rgbToLAB(self,image_tensor):
        """Channel-first RGB tensor -> channel-first CIELAB tensor (via skimage)."""
        labImage = color.rgb2lab(image_tensor.numpy(force = True),channel_axis = 0)
        labTensor = torch.Tensor(labImage)
        return labTensor

    def __labToRGB(self,lab_tensor):
        """Channel-first CIELAB tensor -> channel-first RGB tensor (via skimage)."""
        rgbImage = color.lab2rgb(lab_tensor.numpy(force = True), channel_axis = 0)
        rgbTensor = torch.Tensor(rgbImage)
        return rgbTensor

    #=========================================================================================================#

    def limit_channels(self,image1,image2,num_channels = None):
        """Randomly keep the same subset of channels from two channel-first tensors.

        Experiment: checks whether good results survive discarding feature channels.
        The default is roughly 1/8 of the channels, and never fewer than one.
        The inputs are returned unchanged if they already have ``num_channels``
        channels or fewer.
        """
        image_channels = image1.shape[0]
        if num_channels is None:
            # round(c / 8) is 0 for c <= 4, which would select nothing and return
            # empty tensors. Keep at least one channel.
            num_channels = max(1, round(image_channels/8))
        if image_channels <= num_channels:
            return image1, image2
        indices = np.random.default_rng().choice(image_channels, size=num_channels, replace=False)
        return image1[indices], image2[indices]

    #=========================================================================================================#

    def construct_guidance(self,level,iters = 8,com_coh_weights = (8/9,1/9),source_weight = 0.4):
        """Build the guidance for ``level`` by patch matching source and reference.

        Args:
            level: VGG level 1..5 (feature ``relu-<level>_1``). 0 matches the raw images.
            iters: number of bidirectional NNF refinement iterations.
            com_coh_weights: (completeness, coherence) weights for patch voting.
            source_weight: forwarded to ``BDS`` as ``default_weight``.

        Returns:
            (G, rs, F_G, F_LS):
              * G is the colour guidance image at feature resolution;
              * rs is the intermediate source resized to that resolution;
              * F_G holds the voted guidance features;
              * F_LS holds the source features.
        """
        if level != 0:
            level_name = "relu-{}_1".format(level)
            F_LS = self._extract_features(self.S_L, level_name)
            F_LR = self._extract_features(self.reference, level_name)
        else:
            F_LS = self.S_L
            F_LR = self.reference

        if self.sample_feats:
            F_LS, F_LR = self.limit_channels(F_LS, F_LR)

        # Feature-space matching. F_G is later compared element-wise with F_LS,
        # so it is expected to have the source layout.
        guidance_map = BDS(F_LR, F_LS, patch_size = self.patch_size, norm = self.normalised,
                           default_weight = source_weight)
        guidance_map.refine_bidirectional_maps(iters)
        F_G = guidance_map.patch_vote(complete_cohere_weights = com_coh_weights, switch_target = False)

        print("FG shape", F_G.shape, "FLS shape",F_LS.shape)

        # Reuse the feature-space NNFs to vote RGB patches of the reference (resized
        # to feature resolution) into a colour guidance image.
        resize_source = transforms.Resize(F_LS.shape[1:])
        resize_ref = transforms.Resize(F_LR.shape[1:])
        rr = resize_ref(self.reference)
        rs = resize_source(self.S_L)
        print("fs", F_LS.shape,"fr", F_LR.shape, "rs", rs.shape,"rr",rr.shape)
        guidance_constructor = BDS(rr, rs, patch_size = self.patch_size,
                                   NNFs = (guidance_map.complete, guidance_map.cohere),
                                   norm = self.normalised, default_weight = source_weight)
        G = guidance_constructor.patch_vote(complete_cohere_weights = com_coh_weights, switch_target = False)
        return (G, rs, F_G, F_LS)

    #=========================================================================================================#

    def initialise_ab(self,G,patch_radius = 1):
        """Initial affine coefficients that match local colour statistics of S to G.

        The windows are (2r+1)x(2r+1) and reflect-padded. For every pixel and channel:
            a = std(G window) / (std(S window) + eps)
            b = mean(G window) - a * mean(S window)
        Here S is the CIELAB source resized to G's spatial size. Works for any
        number of channels in ``G``.

        Returns:
            (a, b): float tensors shaped like ``G`` with ``requires_grad`` enabled.
        """
        epsilon = 0.002
        k = patch_radius*2 + 1
        window_shape = (G.shape[0], k, k)
        pad_dims = ((0,0),(patch_radius,patch_radius),(patch_radius,patch_radius))
        resize_source = transforms.Resize(G.shape[1:])

        src = resize_source(self.__rgbToLAB(self.source)).detach().numpy()
        src_padded = np.pad(src, pad_dims, mode = 'reflect')
        g_padded = np.pad(G.detach().numpy(), pad_dims, mode = 'reflect')

        # view_as_windows gives (1, H, W, C, k, k); drop the singleton leading axis.
        s_win = view_as_windows(src_padded, window_shape)[0]
        g_win = view_as_windows(g_padded, window_shape)[0]

        # Per-(row, col, channel) window statistics, computed without a Python loop.
        win_axes = (-2, -1)
        a_hwc = g_win.std(axis = win_axes) / (s_win.std(axis = win_axes) + epsilon)
        b_hwc = g_win.mean(axis = win_axes) - a_hwc * s_win.mean(axis = win_axes)

        # Back to channel-first layout like G.
        a = torch.Tensor(np.ascontiguousarray(np.moveaxis(a_hwc, -1, 0)))
        a.requires_grad_()
        b = torch.Tensor(np.ascontiguousarray(np.moveaxis(b_hwc, -1, 0)))
        b.requires_grad_()
        return (a, b)

    #=========================================================================================================#

    @staticmethod
    def _luminance_weights(lum, alpha = 1.2, epsilon_lum = 0.00001):
        """Edge-aware smoothness weights from an (H, W) luminance channel.

        For each neighbour direction (above, below, left, right) this returns an
        (H, W) tensor ``1 / (|dL| ** alpha + epsilon_lum)``, with reflect padding
        at the borders. Weights are therefore small across strong luminance edges.
        """
        padded = F.pad(lum.unsqueeze(0), (1, 1, 1, 1), mode = 'reflect').squeeze()
        vertical = padded[1:, 1:-1] - padded[:-1, 1:-1]
        horizontal = padded[1:-1, 1:] - padded[1:-1, :-1]
        diffs = (vertical[:-1], vertical[1:], horizontal[:, :-1], horizontal[:, 1:])
        return tuple(
            torch.pow(torch.pow(torch.pow(torch.pow(d, 2), 0.5), alpha) + epsilon_lum, -1)
            for d in diffs
        )

    def transfer_at_level(self,level,optim_iters = 10,nnf_iters = 8,complete_cohere_weights = (4/9,5/9),source_weight = 0.4):
        """Run one coarse-to-fine step at VGG ``level`` and update ``self.S_L``.

        The guidance image and the new intermediate source are saved as
        ``guidance/guidance_<level>.png`` and ``intermediate_sources/source_<level>.png``.
        """
        G, _, F_G, F_S = self.construct_guidance(level, iters = nnf_iters,
                                                 com_coh_weights = complete_cohere_weights,
                                                 source_weight = source_weight)
        guidance_filename = self.guidance_images + "/guidance_{}.png".format(level)
        intermediate_source_filename = self.intermediate_sources + "/source_{}.png".format(level)

        # Optimisation constants: source and guidance in CIELAB at feature resolution.
        resize_source = transforms.Resize(G.shape[1:])
        S = self.__rgbToLAB(resize_source(self.source)).detach()
        G = self.__rgbToLAB(G).detach()
        F_G = F.normalize(F_G, dim = 0).detach()
        F_S = F.normalize(F_S, dim = 0).detach()

        # Similarity-term constants: the level weight, and the per-pixel matching error
        # between L2-normalised source and guidance features.
        omega_L = 4**(level - 1)
        elp = torch.sum(torch.pow((F_S - F_G), 2), dim = 0)
        print("elp shape",elp.shape)

        # Smoothness-term constants from the L (luminance) channel of S.
        omega_lum = self._luminance_weights(S[0])

        a, b = self.initialise_ab(G)
        print("A size", a.shape, "B size", b.shape, "Omega Size",omega_lum[0].shape, "S size", S.shape )

        optimiser = torch.optim.Adam([a, b], lr = 0.001)
        loss_fn = error()
        for i in range(optim_iters):
            loss = loss_fn(omega_L, omega_lum, elp, a, b, S, G)
            print("\tLoss: iter {}\n\t\t{}".format(i,loss))
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

        # Upsample the fitted coefficients and apply them to the full-resolution source.
        upsample = nn.Upsample((self.source.shape[1], self.source.shape[2]), mode = "bilinear")
        a_up = upsample(a.detach().unsqueeze(0)).squeeze()
        b_up = upsample(b.detach().unsqueeze(0)).squeeze()
        self.S_L = self.__labToRGB(torch.mul(a_up, self.__rgbToLAB(self.source)) + b_up)

        if self.rollingAv and self.S_Prev is not None:
            self.S_L = self.S_L*0.4 + self.S_Prev*0.6
        # Remember this level's (possibly blended) result. The next level then
        # averages with it, which makes this an actual rolling average.
        self.S_Prev = self.S_L

        self.unloader(self.__labToRGB(G)).save(guidance_filename)
        self.unloader(self.S_L).save(intermediate_source_filename)
        return

    #=========================================================================================================#

    def _save_pingpong_gif(self, frame_names, gif_name, total_ms):
        """Write the frames forward, then backward, then the final ``self.source``, as one GIF."""
        forward = []
        for name in frame_names:
            # Copy so each file handle is closed before the temp folder is removed.
            with Image.open(name) as frame:
                forward.append(frame.copy())
        images = forward + forward[-2::-1] + [self.unloader(self.source)]
        images[0].save(gif_name,
                       save_all=True,
                       append_images=images[1:],
                       # Spread the requested length evenly over every frame written.
                       duration=int(total_ms/len(images)),
                       loop=617)

    def transfer(self,optim_iters = 617,nnf_iters = 4,complete_cohere_weights = (2/3,1/3),source_weight = 0.4,
                 num_passes = 2,gif = True,gif_length_secs = 5,coarse_nnf_iters = 42):
        """Run ``num_passes`` coarse-to-fine passes (levels 5..1), optionally saving a GIF.

        Args:
            optim_iters: Adam iterations per level.
            nnf_iters: NNF refinement iterations for levels 4..1.
            complete_cohere_weights, source_weight: forwarded to ``transfer_at_level``.
            num_passes: number of passes. Each pass recolours the previous pass's result.
            gif: save a forward-then-backward GIF of the per-level results to
                ``<output_path>/<model_name>.gif``.
            gif_length_secs: total GIF duration in seconds.
            coarse_nnf_iters: NNF refinement iterations at the coarsest level (5).
        """
        ms = gif_length_secs*1000
        gif_name = self.output_path + '/' + self.model_name + '.gif'
        gif_folder = gif_name[:-4] + '_temp/'
        frame_names = [gif_folder + "{}.gif".format(i + 1) for i in range(5*num_passes + 1)]
        if gif:
            os.makedirs(gif_folder, exist_ok = True)
            self.unloader(self.source).save(frame_names[0])
        nth = 1
        for _ in range(num_passes):
            for level in range(5, 0, -1):
                nnf = coarse_nnf_iters if level == 5 else nnf_iters
                self.transfer_at_level(level, optim_iters = optim_iters, nnf_iters = nnf,
                                       complete_cohere_weights = complete_cohere_weights,
                                       source_weight = source_weight)
                if gif:
                    self.unloader(self.S_L).save(frame_names[nth])
                    nth += 1
            # The next pass starts from this pass's result.
            self.source = self.S_L
        if gif:
            self._save_pingpong_gif(frame_names, gif_name, ms)
            shutil.rmtree(gif_folder)
        return
        
        
        
                
                
                
        

        
    

# Testing and Helper Functions

In [None]:
from PIL import Image, ImageSequence

def slow_down_gif(path = "./outputs/model_1/city.gif",new_length_seconds = 10):
    """Rewrite the GIF at ``path`` so that it plays for ``new_length_seconds`` in total.

    Every frame gets the same duration (total / frame count). The loop count is set to 617.
    """
    total_ms = new_length_seconds * 1000
    # ImageSequence.Iterator yields the *same* Image object, seeked to each frame.
    # Each frame must therefore be copied. Otherwise the list holds N references
    # to one multi-frame image, and save_all writes every frame N times.
    with Image.open(path) as im:
        frames = [frame.copy() for frame in ImageSequence.Iterator(im)]
    # The source handle is closed, so `path` can be overwritten directly. The old
    # GIF is not deleted before the new one has been written.
    frames[0].save(path,
                   save_all=True,
                   append_images=frames[1:],
                   duration=int(total_ms/len(frames)),
                   loop=617)
    return

In [None]:
# City pair: night and day photos, loaded as RGB tensors and resized so that the
# shorter side is 100 px.
sources = "./source_images/"
ii = transforms.Compose([transforms.ToTensor()])
resize = transforms.Resize(100)
c_night, c_day = (resize(ii(Image.open(sources + fname).convert('RGB')))
                  for fname in ("city_night.jpg", "city_day.jpg"))
unloader = transforms.ToPILImage()

In [None]:
# Convert the day image tensor back into a PIL image.
day = unloader(c_day)

In [None]:
# Car photos: loaded as RGB tensors and resized so that the shorter side is 100 px.
# Uses `ii` from an earlier cell.
sources = "./source_images/"
resize = transforms.Resize(100)
car_colours, yellow_car, purp_car = (
    resize(ii(Image.open(sources + fname).convert('RGB')))
    for fname in ("有很多顏色的汽車.jpg", "黃色的汽車.jpg", "idek.webp")
)

In [None]:
# Cat / orange / black-and-white car images, loaded as RGB tensors and resized so
# that the shorter side is 100 px. Uses `ii` from an earlier cell.
sources = "./source_images/"
resize = transforms.Resize(100)
cat_geo, orange, car_bw = (
    resize(ii(Image.open(sources + fname).convert('RGB')))
    for fname in ("cat_geo.png", "orange.jpg", "car_bw.png")
)

In [None]:
# Bird pair: house wren and northern cardinal, loaded as RGB tensors and resized
# so that the shorter side is 100 px.
sources = "./source_images/"
ii = transforms.Compose([transforms.ToTensor()])
resize = transforms.Resize(100)
wren, cardinal = (resize(ii(Image.open(sources + fname).convert('RGB')))
                  for fname in ("house_wren.jpg", "northern_cardinal.jpg"))
unloader = transforms.ToPILImage()

In [None]:
# Bird pair: parrot and blackbird, loaded as RGB tensors and resized so that the
# shorter side is 100 px.
sources = "./source_images/"
ii = transforms.Compose([transforms.ToTensor()])
resize = transforms.Resize(100)
parrot, blackbird = (resize(ii(Image.open(sources + fname).convert('RGB')))
                     for fname in ("parrot.jpg", "blackbird.jpg"))
unloader = transforms.ToPILImage()

In [None]:
# Face pair: with and without special-effects make-up, loaded as RGB tensors and
# resized so that the shorter side is 100 px.
sources = "./source_images/"
ii = transforms.Compose([transforms.ToTensor()])
resize = transforms.Resize(100)
sfx_face, no_sfx = (resize(ii(Image.open(sources + fname).convert('RGB')))
                    for fname in ("sfx.jpg", "no_sfx.jpg"))
unloader = transforms.ToPILImage()

In [None]:
# Recolour the blackbird photo with the parrot photo's colours. Results go to
# ./outputs/model_1, whose previous contents are cleared (replace=1).
si = singleImageColourTransfer(
    blackbird,
    parrot,
    replace = 1,
    patch_size = 7,
    normalised = True,
    rollingAv = False,
)


In [None]:
# One coarse-to-fine pass with 6170 Adam steps per level, no source weighting,
# and a 10-second GIF.
si.transfer(
    optim_iters = 6170,
    nnf_iters = 4,
    complete_cohere_weights = (0.3, 0.7),
    source_weight = 0.0,
    num_passes = 1,
    gif_length_secs = 10,
)

In [None]:
def compare_to_paper(size = 100,sample = False,num_images = 5,source_path = "./Paper_Images/",
                     output_path = "./paper_outputs"):
    """Re-run the colour transfer on the paper's example pairs.

    Loads ``in<k>.png`` (source) and ``tar<k>.png`` (reference) for k in
    ``range(num_images)`` from ``source_path``. Each image is resized with
    ``transforms.Resize(size)``, then one transfer per pair is run into
    ``output_path``. That folder is deleted first, so model numbering restarts at 1.

    Args:
        size: size passed to ``transforms.Resize``.
        sample: forwarded as ``sample_feats`` (random feature-channel subsampling).
        num_images: number of ``in``/``tar`` pairs to process.
        source_path: folder containing the paper images (with trailing slash).
        output_path: folder that receives one ``model_<k>`` subfolder per pair.
    """
    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    to_tensor = transforms.ToTensor()
    resize = transforms.Resize(size)

    def load(name):
        # `with` closes the file handle once the pixels have been converted.
        with Image.open(source_path + name) as img:
            return resize(to_tensor(img.convert('RGB')))

    pairs = [(load("in{}.png".format(k)), load("tar{}.png".format(k))) for k in range(num_images)]

    for source, target in pairs:
        si = singleImageColourTransfer(source, target, patch_size = 3, normalised = True, replace = None,
                                       rollingAv = False, sample_feats = sample, output_path = output_path)
        si.transfer(optim_iters = 6170, nnf_iters = 4,
                    complete_cohere_weights = (0.3, 0.7), source_weight = 0.0,
                    num_passes = 1, gif_length_secs = 10)
    return

In [None]:
# Paper examples at size 256, with random feature-channel subsampling enabled.
compare_to_paper(size = 256, sample = True)

In [None]:
# Side-by-side view of two image tensors.
# NOTE(review): `ds` and `sr` are not defined in the cells shown here. They must
# come from an earlier cell; confirm before running on a fresh kernel.
panels = ((ds, 'Source Image'), (sr, 'Target Image'))
for panel_idx, (img, title) in enumerate(panels, start = 1):
    plt.subplot(1, 2, panel_idx)
    plt.imshow(unloader(img))
    plt.title(title)
    plt.axis('off')

In [None]:
# TODO: use weighted-least-squares (WLS) edge-preserving filtering for upscaling: https://www.cs.huji.ac.il/~danix/epd/epd.pdf