## Understand how patch match works

The code below is a simplified version of patch match applied to feature maps.



In [4]:
import os 
import torch 
import torch.nn as nn 
import torch.optim as optim 
import torchvision.transforms as transforms 
import torchvision.models as models
import torchvision
import torch.nn.functional as F 
from PIL import Image 
import argparse
import copy
import math 
import numpy as np 
import scipy.interpolate as interpolate
import matplotlib.pyplot as plt

In [32]:
# Toy inputs for the patch-match demo: two random feature maps of identical shape.
# A locally seeded generator makes them identical on every Restart & Run All
# without touching torch's global RNG state.
SEED = 0
FM_SHAPE = (1, 64, 15, 15)  # 1 x C x H x W; H and W are multiples of the 3x3 patch
_fm_rng = torch.Generator().manual_seed(SEED)
style_fm = torch.rand(FM_SHAPE, generator=_fm_rng)
img_fm = torch.rand(FM_SHAPE, generator=_fm_rng)


In [33]:
    def match_fm(style_fm, img_fm, patch_size=3, verbose=True):
        '''
        Build a "matched" style feature map by patch matching against a content map.

        Input :
            style_fm, img_fm : 1 * C * H * W feature maps (style / content)
            patch_size       : side length of the square, non-overlapping patches
                               (default 3, i.e. one patch is C * 3 * 3)
            verbose          : when True (default) print the intermediate shapes,
                               score maps and matched indices for every patch

        Process:
            Both maps are cut into a grid of non-overlapping patch_size * patch_size
                patches (all C channels each).
            Each content patch is used as a convolution kernel and slid over the
                style map with stride == patch_size, which yields one score per
                style patch. The score is the raw (unnormalised) cross-correlation
                between the two patches; no L2 normalisation is applied here.
            The style patch with the highest score is copied into the output at
                the position of the content patch.
            The matching covers the whole feature map; no mask is applied inside
                this function.

        Output :
            A new tensor with the same shape as style_fm. style_fm and img_fm are
                not modified. Rows / columns that do not fill a whole patch
                (H or W not divisible by patch_size) keep their style_fm values.

        Raises :
            ValueError if patch_size is smaller than 1.
        '''
        if patch_size < 1:
            raise ValueError('patch_size must be >= 1, got {}'.format(patch_size))

        # Start from a copy of the style map so the output keeps its size and the
        # caller's tensor is left untouched.
        style_fm_masked = style_fm.clone()

        # Tensor layout is (B, C, H, W): dim 2 is the height, dim 3 is the width.
        n_patch_h = style_fm_masked.shape[2] // patch_size
        n_patch_w = style_fm_masked.shape[3] // patch_size
        if verbose:
            print('style_fm_masked shape', style_fm_masked.shape)

        # Stride equal to the patch size makes the convolution visit exactly the
        # non-overlapping patch grid, so score_map[r, c] belongs to style patch (r, c).
        stride = patch_size
        for i in range(n_patch_h):
            rows = slice(i * patch_size, (i + 1) * patch_size)
            for j in range(n_patch_w):
                cols = slice(j * patch_size, (j + 1) * patch_size)
                if verbose:
                    print('For i : {} j : {}'.format(str(i), str(j)))

                # The content patch at grid position (i, j) acts as the kernel.
                kernel = img_fm[:, :, rows, cols]
                if verbose:
                    print('kernal size', kernel.shape)

                # For F.conv2d : https://pytorch.org/docs/stable/nn.functional.html
                # Only batch element 0 / output channel 0 is used (inputs are 1 * C * H * W).
                score_map = F.conv2d(style_fm, kernel, stride=stride)[0, 0, :, :]
                if verbose:
                    print('score_map shape', score_map.shape)
                    print('score map is', score_map)

                # argmax works on the flattened map; divmod recovers (row, col)
                # with exact integer arithmetic.
                flat_idx = torch.argmax(score_map).item()
                matched_style_idx = divmod(flat_idx, score_map.shape[1])
                if verbose:
                    print('matched_style_idx {}'.format(str(matched_style_idx)))

                src_rows = slice(matched_style_idx[0] * patch_size,
                                 (matched_style_idx[0] + 1) * patch_size)
                src_cols = slice(matched_style_idx[1] * patch_size,
                                 (matched_style_idx[1] + 1) * patch_size)
                # Always read from the untouched style_fm, never from the partially
                # rewritten copy.
                style_fm_masked[:, :, rows, cols] = style_fm[:, :, src_rows, src_cols]

        return style_fm_masked

In [34]:
# Run the patch matcher on the two feature maps defined above; the result is a
# re-arranged copy of style_fm with the same shape.
out = match_fm(style_fm=style_fm, img_fm=img_fm)

style_fm_masked shape torch.Size([1, 64, 15, 15])
For i : 0 j : 0
kernal size torch.Size([1, 64, 3, 3])
score_map shape torch.Size([5, 5])
score map is tensor([[140.2602, 137.7970, 140.2724, 134.9038, 145.0046],
        [138.7862, 143.9214, 134.7990, 141.7458, 142.0261],
        [140.4134, 137.0959, 132.2034, 128.7953, 141.7768],
        [138.5422, 136.0244, 141.8092, 138.5504, 137.4596],
        [135.9445, 129.9725, 141.7546, 146.6207, 137.0418]])
matched_style_idx (4, 3)
For i : 0 j : 1
kernal size torch.Size([1, 64, 3, 3])
score_map shape torch.Size([5, 5])
score map is tensor([[141.8513, 142.4510, 145.1426, 138.9975, 145.7724],
        [136.0936, 141.9220, 136.1649, 143.5208, 142.1674],
        [139.0216, 141.4182, 137.4155, 131.4725, 140.9673],
        [140.1776, 136.9272, 142.6810, 139.5533, 139.0167],
        [138.2072, 134.0869, 141.8370, 143.8979, 138.4894]])
matched_style_idx (0, 4)
For i : 0 j : 2
kernal size torch.Size([1, 64, 3, 3])
score_map shape torch.Size([5, 5])
score