# Style Transfer by Relaxed Optimal Transport and Self-Similarity + Auto mask pyramid

Import the modules we need.

In [1]:

import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from imageio import imread
import numpy as np
import os
import math
import PIL
from time import time
from Patch_Match import PatchMatch

# VGG19 `features` index whose activations feed PatchMatch (it follows four
# max-pools). NOTE(review): Vgg19_Extractor hard-codes the same value in
# `self.patch_match_layer` instead of reading this constant.
patch_match_feat_layer = 29
# When True, `strotss` rebuilds the region masks at every scale from PatchMatch
# correspondences instead of using the caller-supplied `regions`.
AUTO_MASK_PYRAMID = False

print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

# Pick the best available backend: CUDA first, then Apple MPS, then CPU.
# Each availability check runs once and decides both the device and the message.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("Using the GPU!")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using the MPS!")
else:
    device = torch.device("cpu")
    print("WARNING: Could not find GPU! Using CPU only. If you want to enable GPU, please go to Edit > Notebook Settings > Hardware Accelerator and select GPU.")

PyTorch Version:  2.2.0+cu118
Torchvision Version:  0.17.0+cu118
Using the GPU!


## Step 1: helper functions

### Tensor and PIL utils

In [2]:
# Module-level flag. NOTE(review): not read anywhere in the code visible here.
use_random = True

# Load an image file as an RGB PIL image.
def pil_loader(path):
    """Open the image at *path* with PIL and return an RGB copy of it."""
    with open(path, 'rb') as handle:
        # convert() runs while the file is still open, so the returned image
        # does not depend on the handle after the `with` block exits.
        return PIL.Image.open(handle).convert('RGB')
# Tensor interpolation (resize) helper.
def tensor_resample(tensor, dst_size, mode='bilinear'):
    """Resize *tensor* to the spatial size *dst_size* with ``F.interpolate``.

    Args:
        tensor: input tensor (4-D ``(N, C, H, W)`` for the default bilinear mode).
        dst_size: target spatial size, e.g. ``[H, W]``.
        mode: any ``F.interpolate`` mode.

    ``align_corners=False`` is passed only for the interpolating modes
    (linear/bilinear/bicubic/trilinear). ``F.interpolate`` raises a ValueError
    when ``align_corners`` is set for 'nearest', 'area' or 'nearest-exact', so
    those modes get ``None``.
    """
    interpolating = mode in ('linear', 'bilinear', 'bicubic', 'trilinear')
    return F.interpolate(tensor, dst_size, mode=mode,
                         align_corners=False if interpolating else None)
# Resize an image so that its shorter edge becomes trg_size.
def pil_resize_short_edge_to(pil, trg_size):
    """Resize *pil* (bicubic) so its shorter edge equals *trg_size*, keeping the aspect ratio.

    When width == height the height is treated as the short edge, as before.
    The other edge is ``floor(long * trg_size / short)``. It is computed as
    ``long * trg_size // short`` because the float form ``int(x * (trg / x))``
    can land just below the target (e.g. ``int(49 * (64 / 49))`` is 63).
    """
    width, height = pil.width, pil.height
    if width < height:
        new_size = (int(trg_size), int(height * trg_size // width))
    else:
        new_size = (int(width * trg_size // height), int(trg_size))
    return pil.resize(new_size, PIL.Image.BICUBIC)
# Resize an image so that its longer edge becomes trg_size.
def pil_resize_long_edge_to(pil, trg_size):
    """Resize *pil* (bicubic) so its longer edge equals *trg_size*, keeping the aspect ratio.

    When width == height the width is treated as the long edge, as before.
    The other edge is ``floor(short * trg_size / long)``, computed with integer-first
    arithmetic so float round-off cannot shave a pixel off the target edge.
    """
    width, height = pil.width, pil.height
    if width < height:
        new_size = (int(width * trg_size // height), int(trg_size))
    else:
        new_size = (int(trg_size), int(height * trg_size // width))
    return pil.resize(new_size, PIL.Image.BICUBIC)

def np_to_pil(npy):
    """Convert an image array to a PIL image.

    Values are clipped to [0, 255] before the uint8 cast. A bare ``astype(np.uint8)``
    wraps out-of-range integers (256 -> 0, -1 -> 255), and its result for
    negative floats is platform dependent.
    """
    return PIL.Image.fromarray(np.clip(npy, 0, 255).astype(np.uint8))

def pil_to_np(pil):
    """Return the pixels of *pil* as a newly allocated NumPy array."""
    pixels = np.array(pil)
    return pixels

In [3]:
def tensor_to_np(tensor, cut_dim_to_3=True):
    """Detach *tensor*, move it to the CPU and return it channel-last as NumPy.

    * 4-D input, ``cut_dim_to_3`` true: keep only the first batch item and return
      it with axes ``(1, 2, 0)`` (channel axis moved last).
    * 4-D input, ``cut_dim_to_3`` false: return the whole batch with axes ``(0, 2, 3, 1)``.
    * 3-D input: return it with axes ``(1, 2, 0)``.
    """
    is_batch = tensor.dim() == 4
    if is_batch and not cut_dim_to_3:
        return tensor.detach().cpu().numpy().transpose((0, 2, 3, 1))
    single = tensor[0] if is_batch else tensor
    return single.detach().cpu().numpy().transpose((1, 2, 0))

def np_to_tensor(npy, space):
    """Convert a channel-last image array into a ``(1, C, H, W)`` float tensor.

    For ``space == 'vgg'`` this delegates to ``np_to_tensor_correct``. Any other
    space maps pixel values ``v`` to ``v / 127.5 - 1``, so 0..255 becomes -1..1.
    """
    if space != 'vgg':
        centered = torch.Tensor(npy.astype(np.float64) / 127.5) - 1.0
        channel_first = centered.permute((2, 0, 1))
        return channel_first.unsqueeze(0)
    return np_to_tensor_correct(npy)

def np_to_tensor_correct(npy):
    """Convert an image array into a normalized ``(1, C, H, W)`` tensor for the VGG extractor.

    The array goes through ``np_to_pil`` and ``transforms.ToTensor``. It is then
    normalized per channel with the same mean/std constants that the extractors'
    ``forward`` applies to non-'vgg' inputs.
    """
    to_vgg_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = np_to_pil(npy)
    return to_vgg_input(image).unsqueeze(0)

### Laplacian Pyramid

In [4]:
def laplacian(x):
    """Return the high-frequency residual ``x - upsample(downsample(x))``.

    *x* is a 4-D ``(N, C, H, W)`` tensor. It is bilinearly downsampled by 2 and
    resized back to ``(H, W)``. Each downsampled side is clamped to at least
    1 pixel, as in make_laplace_pyramid, so a tensor with a spatial side of 1
    does not ask F.interpolate for a zero-sized output (which raises).
    """
    height, width = x.shape[2], x.shape[3]
    low = tensor_resample(x, [max(height // 2, 1), max(width // 2, 1)])
    return x - tensor_resample(low, [height, width])

def make_laplace_pyramid(x, levels):
    """Decompose *x* into a Laplacian pyramid.

    Returns a list of ``levels + 1`` tensors: ``levels`` band-pass residuals
    (finest first) followed by the remaining low-pass image. Each step halves
    the spatial size, never going below 1 pixel per side.
    """
    bands = []
    low = x
    for _ in range(levels):
        bands.append(laplacian(low))
        next_size = (max(low.shape[2] // 2, 1), max(low.shape[3] // 2, 1))
        low = tensor_resample(low, next_size)
    bands.append(low)
    return bands

def fold_laplace_pyramid(pyramid):
    """Rebuild an image from a Laplacian pyramid (the inverse of make_laplace_pyramid).

    Start from the coarsest entry. Then, for each finer level, upsample the
    running image to that level's size and add the level's residual.
    """
    image = pyramid[-1]
    for band in reversed(pyramid[:-1]):
        image = band + tensor_resample(image, (band.shape[2], band.shape[3]))
    return image

### Feature sampling

In [5]:
def sample_indices(feat_content, feat_style_all, r, ri, xx, xy, yx):
    """Pick content sample locations inside region *ri* and append them to the index dicts.

    For every entry ``i`` of ``feat_style_all[ri]``, a regular grid with a random
    offset is laid over ``feat_content[i]`` (indexed as ``(N, C, H, W)``). There
    is a single iteration when that entry is a ``(1, C, N)`` tensor, as built by
    ``optimize``. The strides come from ``sqrt(H * W // 128**2)``, so large maps
    are thinned to roughly ``128**2`` points. Only points whose mask entry
    ``r[col, row]`` is true are kept; *r* is a boolean array laid out as ``(W, H)``.

    Appends one array per ``i`` to:
        xx[ri]: row (dim 2) coordinates of the kept points.
        xy[ri]: column (dim 3) coordinates of the kept points.
        yx[ri]: ``arange`` over the style samples, i.e. the last axis of
            ``feat_style_all[ri][i]``.
    """
    target_points = 128 ** 2  # 16384

    feat_style = feat_style_all[ri]
    for i in range(len(feat_style)):
        feat_cont = feat_content[i]
        height, width = feat_cont.shape[2], feat_cont.shape[3]
        ratio = math.sqrt(height * width // target_points)

        # Random grid offsets so successive scales cover different pixels.
        stride_x = int(max(math.floor(ratio), 1))
        offset_x = np.random.randint(stride_x)
        stride_y = int(max(math.ceil(ratio), 1))
        offset_y = np.random.randint(stride_y)
        xx_arr, xy_arr = np.meshgrid(np.arange(height)[offset_x::stride_x],
                                     np.arange(width)[offset_y::stride_y])

        xx_arr = xx_arr.flatten()
        xy_arr = xy_arr.flatten()
        xc = np.stack([xx_arr, xy_arr], axis=1)
        # Keep only grid points inside the region (mask is indexed [col, row]).
        # Bad indices raise directly: the old bare-except retry was a no-op.
        xc = xc[r[xy_arr, xx_arr], :]

        xx[ri].append(xc[:, 0])
        xy[ri].append(xc[:, 1])

        # One index per style sample: samples lie along the last axis of the
        # per-entry (C, N) style slice.
        yx[ri].append(np.arange(feat_style[i].shape[-1]).astype(np.int32))

def get_feature_indices(xx_dict, xy_dict, yx_dict, ri=0, i=0, cnt=32**2):
    """Return the first *cnt* entries of the row, column and style index arrays.

    The arrays are those stored for region *ri*, slot *i*, by sample_indices.
    """
    head = slice(None, cnt)
    rows = xx_dict[ri][i][head]
    cols = xy_dict[ri][i][head]
    style_idx = yx_dict[ri][i][head]
    return rows, cols, style_idx


### Extract spatial features from VGG layers

In [6]:
def spatial_feature_extract(feat_result, feat_content, xx, xy):
    """Bilinearly sample hypercolumns of the output and the content at given points.

    Args:
        feat_result: list of feature maps of the current output. Each map is
            viewed as ``(1, C, H*W, 1)`` below, so a batch size of 1 is assumed.
        feat_content: list of content feature maps with the same layout.
        xx: NumPy array of row (dim 2) coordinates in the first map's resolution.
        xy: NumPy array of column (dim 3) coordinates, same length as *xx*.

    Returns:
        ``(x_st, c_st)``: tensors of shape ``(1, sum(C) + 2, len(xx), 1)``. Each
        holds the sampled output / content features of every layer,
        concatenated along dim 1, followed by two channels with the
        coordinates as scaled for the last layer.
    """

    l2, l3 = [], []
    device = feat_result[0].device

    # for each extracted layer
    for i in range(len(feat_result)):
        fr = feat_result[i]
        fc = feat_content[i]

        # hack to detect reduced scale: assumes every spatial reduction is exactly
        # 2x, so coordinates are halved once whenever a layer shrinks
        if i>0 and feat_result[i-1].size(2) > feat_result[i].size(2):
            xx = xx/2.0
            xy = xy/2.0


        # go back to ints and get residual: integer part (still float32 here)
        # plus the fractional part used for the bilinear weights
        xxm = np.floor(xx).astype(np.float32)
        xxr = xx - xxm

        xym = np.floor(xy).astype(np.float32)
        xyr = xy - xym

        # do bilinear resample: weights of the four neighbours
        # (first digit: row floor/next, second digit: column floor/next)
        w00 = torch.from_numpy((1.-xxr)*(1.-xyr)).float().view(1, 1, -1, 1).to(device)
        w01 = torch.from_numpy((1.-xxr)*xyr).float().view(1, 1, -1, 1).to(device)
        w10 = torch.from_numpy(xxr*(1.-xyr)).float().view(1, 1, -1, 1).to(device)
        w11 = torch.from_numpy(xxr*xyr).float().view(1, 1, -1, 1).to(device)

        xxm = np.clip(xxm.astype(np.int32),0,fr.size(2)-1)
        xym = np.clip(xym.astype(np.int32),0,fr.size(3)-1)

        # flat (row-major over H*W) indices of the four neighbours, clipped to the map
        s00 = xxm*fr.size(3)+xym
        s01 = xxm*fr.size(3)+np.clip(xym+1,0,fr.size(3)-1)
        s10 = np.clip(xxm+1,0,fr.size(2)-1)*fr.size(3)+(xym)
        s11 = np.clip(xxm+1,0,fr.size(2)-1)*fr.size(3)+np.clip(xym+1,0,fr.size(3)-1)

        # advanced indexing returns new tensors, so the in-place mul_/add_ below
        # do not modify the original feature maps
        fr = fr.view(1,fr.size(1),fr.size(2)*fr.size(3),1)
        fr = fr[:,:,s00,:].mul_(w00).add_(fr[:,:,s01,:].mul_(w01)).add_(fr[:,:,s10,:].mul_(w10)).add_(fr[:,:,s11,:].mul_(w11))

        fc = fc.view(1,fc.size(1),fc.size(2)*fc.size(3),1)
        fc = fc[:,:,s00,:].mul_(w00).add_(fc[:,:,s01,:].mul_(w01)).add_(fc[:,:,s10,:].mul_(w10)).add_(fc[:,:,s11,:].mul_(w11))

        l2.append(fr)
        l3.append(fc)

    # stack all layers into one hypercolumn per sample point
    x_st = torch.cat([li.contiguous() for li in l2],1)
    c_st = torch.cat([li.contiguous() for li in l3],1)

    # append the (last-layer-scaled) coordinates as two extra channels
    xx = torch.from_numpy(xx).view(1,1,x_st.size(2),1).float().to(device)
    yy = torch.from_numpy(xy).view(1,1,x_st.size(2),1).float().to(device)
    
    x_st = torch.cat([x_st,xx,yy],1)
    c_st = torch.cat([c_st,xx,yy],1)
    return x_st, c_st

### Color and distance helper functions

In [7]:
def rgb_to_yuv(rgb):
    """Apply a fixed 3x3 colour transform to a ``(3, K)`` matrix of RGB column vectors.

    The first output row is (R + G + B) / sqrt(3), a luminance-like value. The
    other two rows are chroma-like differences. The result stays on ``rgb``'s device.
    """
    transform = torch.Tensor([
        [0.577350, 0.577350, 0.577350],
        [-0.577350, 0.788675, -0.211325],
        [-0.577350, -0.211325, 0.788675],
    ]).to(rgb.device)
    return torch.mm(transform, rgb)

def pairwise_distances_cos(x, y, eps=1e-8):
    """Return the ``(n, m)`` matrix of cosine distances ``1 - cos(x_i, y_j)``.

    Args:
        x: ``(n, d)`` tensor of row vectors.
        y: ``(m, d)`` tensor of row vectors.
        eps: lower bound on each row norm.

    Squared row norms are clamped to ``eps**2`` before the square root.
    Otherwise an all-zero row makes the forward pass divide 0 by 0 (NaN) and
    sqrt's gradient at 0 is infinite. With the clamp such a row gets distance 1
    and a finite gradient. Rows with norm >= eps are unaffected.
    """
    x_norm = torch.sqrt((x ** 2).sum(1).clamp_min(eps ** 2)).view(-1, 1)
    y_norm = torch.sqrt((y ** 2).sum(1).clamp_min(eps ** 2)).view(1, -1)
    return 1. - torch.mm(x, y.t()) / x_norm / y_norm

def pairwise_distances_sq_l2(x, y):
    """Per-dimension squared Euclidean distance between every row of *x* and of *y*.

    Uses ``||x||^2 + ||y||^2 - 2 x.y``. The result is clamped to [1e-5, 1e5],
    which also removes small negative values from round-off, and then divided
    by the feature dimension ``x.size(1)``.
    """
    sq_x = x.pow(2).sum(dim=1, keepdim=True)        # (n, 1)
    sq_y = y.pow(2).sum(dim=1, keepdim=True).t()    # (1, m)
    cross = torch.mm(x, y.t())                       # (n, m)
    return torch.clamp(sq_x + sq_y - 2.0 * cross, 1e-5, 1e5) / x.size(1)

### Create mask by ignoring some color

In [8]:
def create_mask_from_image(image, ignore_color=(0, 0, 0)):
    """
    Create a mask from an image: pixels equal to ignore_color become 0, all others 1.

    :param image: input image as a NumPy array, either 2-D (grayscale) or with
        colour channels on the last axis.
    :param ignore_color: colour to ignore, default black ``(0, 0, 0)``. For a
        grayscale image only its first component is used.
    :return: float mask tensor of shape ``(1, H, W)``. 1 marks important areas
        and 0 marks areas to ignore.
    """
    ignore = np.asarray(ignore_color)
    if image.ndim == 2:  # grayscale image
        keep = image != ignore[0]
    else:  # colour image
        # A pixel is kept as soon as ANY channel differs from ignore_color.
        # Requiring all channels to differ would drop e.g. (0, 100, 0) when
        # ignoring black.
        keep = np.any(image != ignore, axis=-1)

    return torch.from_numpy(keep).unsqueeze(0).float()

### Regions building functions

In [9]:
def extract_color_regions(content_path, style_path, min_pixels=10000):
    """Build paired region masks from two colour-coded guidance images.

    Both images are read with ``imread`` and transposed with ``(1, 0, 2)``; a
    colour image with a channel axis is assumed. Every colour covering more than
    *min_pixels* pixels of the style guidance image defines one region. The
    content and style masks mark the pixels whose colour exactly equals that
    colour on every channel.

    Args:
        content_path: path of the content guidance image.
        style_path: path of the style guidance image.
        min_pixels: minimum pixel count (exclusive) for a style colour to form a region.

    Returns:
        ``[content_masks, style_masks]``: two lists of float32 masks (shape of
        the transposed images without the channel axis), ordered by the sorted
        colours returned by ``np.unique``.
    """
    s_regions = imread(style_path).transpose(1, 0, 2)
    c_regions = imread(content_path).transpose(1, 0, 2)

    colors, counts = np.unique(s_regions.reshape(-1, s_regions.shape[2]), axis=0, return_counts=True)
    colors = colors[counts > min_pixels]

    c_out = []
    s_out = []
    for color in colors:
        # Per-channel equality. Summing channel differences can cancel out
        # (e.g. +1 / -1) for signed or float images.
        s_out.append(np.all(s_regions == color, axis=2).astype(np.float32))
        c_out.append(np.all(c_regions == color, axis=2).astype(np.float32))
    return [c_out, s_out]

In [10]:
def build_from_matches(matchs, num_of_maxpool, img1, img2):
    """Turn a PatchMatch nearest-neighbour field into paired region masks.

    ``matchs[i, j]`` holds the cell ``(x2, y2)`` of the second feature map that
    was matched to cell ``(i, j)`` of the first one. First-map cells that share
    a match are grouped. For every group one mask pair is produced:

    * a content mask (shape ``img1.size``) covering the image area of every
      first-map cell in the group, and
    * a style mask (shape ``img2.size``) covering the matched second-map cell.

    Along each axis, feature cell ``k`` maps to the pixel window
    ``[k * stride, k * stride + window)``. Here ``stride = 2**num_of_maxpool`` is
    the feature map's downsampling factor, and ``window = stride + 2`` is a
    slightly larger, overlapping patch that approximates the receptive field.

    Returns:
        ``[c_out, s_out]``: lists of float64 masks, one pair per distinct match
        target, in first-encounter order.
    """
    groups = {}
    for i in range(matchs.shape[0]):
        for j in range(matchs.shape[1]):
            x2, y2 = matchs[i, j]
            groups.setdefault((x2, y2), []).append((i, j))

    stride = 2 ** num_of_maxpool
    window = stride + 2  # 18 for four max-pools
    c_out, s_out = [], []
    for (x2, y2), sources in groups.items():
        mask1 = np.zeros(img1.size)
        mask2 = np.zeros(img2.size)
        # NumPy slicing clips the stop index at the array border.
        mask2[x2 * stride:x2 * stride + window, y2 * stride:y2 * stride + window] = 1
        for x1, y1 in sources:
            mask1[x1 * stride:x1 * stride + window, y1 * stride:y1 * stride + window] = 1
        c_out.append(mask1)
        s_out.append(mask2)
    return [c_out, s_out]

### Compute distance of two vectors

In [11]:
def distmat(x, y, cos_d=True):
    """Pairwise distance matrix between the rows of *x* and *y*.

    Cosine distance when *cos_d* is true. Otherwise the square root of
    ``pairwise_distances_sq_l2``, i.e. a dimension-normalized Euclidean distance.
    """
    if not cos_d:
        return torch.sqrt(pairwise_distances_sq_l2(x, y))
    return pairwise_distances_cos(x, y)

## Step 2: Loss functions

### Content loss function

Formal Definition: Let  $D^{X}$  be the pairwise cosine distance matrix of all feature vectors extracted from  $X^{(t)}$ , and let  $D^{I_{C}}$  represent the corresponding matrix for the content image.

The content loss is defined as:
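$\mathcal{L}_{\text{content}}(X, C) = \frac{1}{n^2} \sum_{i,j} \left| \frac{D_{ij}^{X}}{\sum_i D_{ij}^{X}} - \frac{D_{ij}^{I_{C}}}{\sum_i D_{ij}^{I_{C}}} \right|$

(The `content_loss` implementation below skips the column normalization. It averages $\left| D_{ij}^{X} - D_{ij}^{I_{C}} \right|$ over all pairs.)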

Interpretation of this loss: For the normalized cosine distances between feature vectors (row-wise normalization), the values should remain consistent between the content image and the output image.

Effect of this loss: This loss preserves the structure of the output image without enforcing direct pixel-wise similarity to the content image. As a result, the semantics and local layout of the content are maintained, while allowing the pixel values in  $X^{(t)}$  to differ significantly from those in the content image.


In [12]:
def content_loss(feat_result, feat_content):
    """Self-similarity content loss between sampled output and content features.

    Both inputs have the channel axis at dim 1; every other position is one
    sample point. The last two channels, the appended spatial coordinates, are
    dropped. The loss is the mean absolute difference between the two
    cosine-distance self-similarity matrices. No column normalization is applied.
    """
    channels = feat_result.size(1)

    def _as_points(feat):
        # (N, C, ...) -> (num_points, C) without the two coordinate columns
        return feat.transpose(0, 1).contiguous().view(channels, -1).transpose(0, 1)[:, :-2]

    points_result = _as_points(feat_result)
    points_content = _as_points(feat_content)

    self_sim_result = distmat(points_result, points_result)
    self_sim_content = distmat(points_content, points_content)
    return torch.abs(self_sim_result - self_sim_content).mean()


### Style loss functions

Let  $A = \{ A_1, \dots, A_n \}$  represent a set of  n  feature vectors extracted from  $X^{(t)}$ , and  $B = \{ B_1, \dots, B_m \}$  represent a set of  m  feature vectors extracted from the style image  $I_S$ . The style loss is derived from the Earth Mover’s Distance (EMD):


$$EMD(A, B) = \min_{T \geq 0} \sum_{i,j} T_{ij} C_{ij}$$

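- Constraints on $T$: $$\sum_j T_{ij} = \frac{1}{n}, \quad \sum_i T_{ij} = \frac{1}{m}$$ Each of the $n$ elements of $A$ sends total mass $\frac{1}{n}$ and each of the $m$ elements of $B$ receives total mass $\frac{1}{m}$, so the total transported mass is 1.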
  - T  is the transport matrix, where  $T_{ij}$  represents the weight of the distance between the  i -th element of  A  and the  j -th element of  B  in the EMD computation.
  - $T_{ij}$  is non-negative and is optimized during the computation.


- Cost Matrix  C :
  - $C_{ij}$  represents the distance between the  i -th element of  A  and the  j -th element of  B .

The EMD measures the distance between two sets. However, computing the optimal  T  has a time complexity of  $O(\max(n, m)^3)$ , which is computationally impractical for gradient-based style transfer.

Relaxed EMD:

To address this, a relaxed version of the EMD is used. Two auxiliary distances are defined by considering only one of the original constraints at a time:

1.	Relaxed EMD for  A :

$$R_A(A, B) = \min_{T \geq 0} \sum_{i,j} T_{ij} C_{ij}, \quad \text{s.t.} \quad \sum_j T_{ij} = \frac{1}{n}$$

2.	Relaxed EMD for  B :

$$R_B(A, B) = \min_{T \geq 0} \sum_{i,j} T_{ij} C_{ij}, \quad \text{s.t.} \quad \sum_i T_{ij} = \frac{1}{m}$$


The relaxed EMD is then defined as:

$$l_r = REMD(A, B) = \max(R_A(A, B), R_B(A, B))$$


This can be equivalently expressed as:

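$$l_r = \max\left( \frac{1}{n} \sum_i \min_j C_{ij}, \frac{1}{m} \sum_j \min_i C_{ij} \right)$$

With only one constraint, each element simply sends all its mass to its cheapest partner, which is why each relaxed problem reduces to an average of row-wise (or column-wise) minima.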


In [13]:
def style_loss(X, Y, cos_d=True):
    """Relaxed Earth Mover's Distance (REMD) between two sets of feature vectors.

    Args:
        X, Y: tensors with the channel axis at dim 1; every other position is
            one feature vector. 3-channel inputs are treated as RGB and mapped
            with ``rgb_to_yuv``, and an extra Euclidean term is added to the cost.
        cos_d: use the cosine distance as the base cost (default). When false,
            use the dimension-normalized Euclidean distance from ``distmat``.
            The parameter was previously ignored.

    Returns:
        ``max(mean_i min_j C_ij, mean_j min_i C_ij)`` as a scalar tensor.
    """
    d = X.shape[1]
    X = X.transpose(0, 1).contiguous().view(d, -1)
    Y = Y.transpose(0, 1).contiguous().view(d, -1)
    if d == 3:
        X = rgb_to_yuv(X)
        Y = rgb_to_yuv(Y)
    X = X.transpose(0, 1)
    Y = Y.transpose(0, 1)

    # Relaxed EMD cost matrix
    cost = distmat(X, Y, cos_d=cos_d)
    if d == 3:
        cost = cost + distmat(X, Y, cos_d=False)

    nearest_for_x, _ = cost.min(1)
    nearest_for_y, _ = cost.min(0)
    return torch.max(nearest_for_x.mean(), nearest_for_y.mean())

### Moment loss function

The relaxed EMD ( $l_r$ ) effectively transfers the structural style of the source image to the target image. However, cosine distance ( $D_{\cos}$ ) overlooks the magnitudes of feature vectors, leading to visual artifacts in the output, such as oversaturation or undersaturation of colors.

To address this issue, a Moment Matching Loss is introduced. This loss improves the visual quality of the output, especially for color saturation and detail retention:


$$l_m = \frac{1}{d} \| \mu_A - \mu_B \|_1 + \frac{1}{d^2} \| \Sigma_A - \Sigma_B \|_1$$

- Components:
  - $\mu_A$ : The mean of feature vectors in  A .
  - $\Sigma_A$ : The covariance of feature vectors in  A .
  - Similarly,  $\mu_B$  and  $\Sigma_B$  are the mean and covariance of feature vectors in  B .
  - d : The dimensionality of the feature vectors.
- Interpretation:
  - The first term ( $\frac{1}{d} \| \mu_A - \mu_B \|_1$ ) aligns the mean feature vectors of $A$ and $B$.
  - The second term ( $\frac{1}{d^2} \| \Sigma_A - \Sigma_B \|_1$ ) aligns the variances and correlations within the feature sets of  A  and  B .

In [14]:
def moment_loss(X, Y, moments=(1, 2)):
    """First/second moment matching loss between two sets of feature vectors.

    Args:
        X, Y: tensors with the channel axis at dim 1. Every other position is
            one feature vector; the tensors are flattened to ``(num_points, C)``.
        moments: which terms to include. ``1`` adds the mean absolute
            difference of the feature means, and ``2`` adds the mean absolute
            difference of the (biased) covariance matrices. Any container
            supporting ``in`` works; a tuple default avoids a shared mutable list.

    Returns:
        The summed loss (a tensor), or the float ``0.`` if neither term is selected.
    """
    d = X.size(1)
    X = X.transpose(0, 1).contiguous().view(d, -1).transpose(0, 1)
    Y = Y.transpose(0, 1).contiguous().view(d, -1).transpose(0, 1)

    loss = 0.
    mu_x = torch.mean(X, 0, keepdim=True)
    mu_y = torch.mean(Y, 0, keepdim=True)

    if 1 in moments:
        loss = loss + torch.abs(mu_x - mu_y).mean()

    if 2 in moments:
        x_centered = X - mu_x
        y_centered = Y - mu_y
        sig_x = torch.mm(x_centered.transpose(0, 1), x_centered) / X.size(0)
        sig_y = torch.mm(y_centered.transpose(0, 1), y_centered) / Y.size(0)
        loss = loss + torch.abs(sig_x - sig_y).mean()

    return loss

### Calculating the loss

This approach uses gradient descent with RMSprop to minimize the objective function based on the output image  X , given a style image ( $I_S$ ) and a content image ( $I_C$ ):


$$L(X, I_C, I_S) = \frac{\alpha l_C + l_M + l_r + \frac{1}{\alpha} l_p}{2 + \alpha + \frac{1}{\alpha}}$$

Components of the Objective Function:
1.	Content Term ( $\alpha l_C$ ):
    - This term ensures that the output image  X  retains the structure and semantic layout of the content image  $I_C$ .
    - Weighted by the hyperparameter  $\alpha$ , which controls the emphasis on preserving content.
2.	Style Terms ( $l_M + l_r + \frac{1}{\alpha} l_p$ ):
    - Moment Matching Loss ( $l_M$ ): Helps align the color and detail distribution between  X  and the style image  $I_S$ , addressing issues like oversaturation.
    - Relaxed EMD ( $l_r$ ): Ensures structural and style similarity by aligning the feature sets of  X  and  $I_S$ .
    - Palette Matching Loss ( $l_p$ ): a relaxed EMD computed on the pixel colors (converted to YUV), scaled by $\frac{1}{\alpha}$. It pushes the output to use the color palette of the style image.


In [15]:
def calculate_loss(feat_result, feat_content, feat_style, xx_dict, xy_dict, yx_dict, content_weight, regions, moment_weight=1.0):
    """Compute the STROTSS objective, averaged over regions.

    For each region ``ri``, the first ``num_locations`` stored sample coordinates
    are used to bilinearly sample output and content hypercolumns
    (``spatial_feature_extract``). These are compared with the pre-sampled
    style features ``feat_style[ri]`` (assumed ``(1, C, N)``) through:

    * the content self-similarity loss, weighted by *content_weight*;
    * the relaxed EMD over all feature channels;
    * moment matching plus a palette (RGB/YUV) REMD term weighted by
      ``1 / max(content_weight, 1)``, all scaled by *moment_weight*.

    Each region's loss is normalized by ``content_weight + 1 + moment_weight``,
    and the per-region losses are averaged. ``yx_dict`` is fetched but not used.
    Prints the style/content losses of every region.
    """
    num_locations = 1024
    loss_total = 0.

    for ri in range(len(regions[0])):
        xx, xy, yx = get_feature_indices(xx_dict, xy_dict, yx_dict, ri=ri, cnt=num_locations)
        spatial_result, spatial_content = spatial_feature_extract(feat_result, feat_content, xx, xy)

        loss_content = content_loss(spatial_result, spatial_content)

        d = feat_style[ri][0].shape[0]
        spatial_style = feat_style[ri][0].view(1, d, -1, 1)

        # Number of real feature channels (the whole hypercolumn). spatial_result
        # carries two extra coordinate channels at the end, which the slice drops.
        # A hard-coded 3+2*64+128*2+256*3+512*2 matches only the VGG16 capture
        # set and would cut a VGG19 hypercolumn part-way through a layer.
        feat_max = d

        loss_remd = style_loss(spatial_result[:, :feat_max, :, :], spatial_style[:, :feat_max, :, :])

        loss_moment = moment_loss(spatial_result[:, :-2, :, :], spatial_style, moments=[1, 2])
        # palette matching: REMD on the first three (RGB) channels
        content_weight_frac = 1. / max(content_weight, 1.)
        loss_moment += content_weight_frac * style_loss(spatial_result[:, :3, :, :], spatial_style[:, :3, :, :])

        loss_style = loss_remd + moment_weight * loss_moment
        print(f'Style: {loss_style.item():.3f}, Content: {loss_content.item():.3f}')

        style_weight = 1.0 + moment_weight
        loss_total += (content_weight * loss_content + loss_style) / (content_weight + style_weight)

    return loss_total / len(xx_dict)

## Step 3: VGG network

### Network design

#### User Control in Style Transfer
1.	Integrating User Constraints
    - User-defined constraints can restrict the output style by enforcing correspondence between two sets of local regions in  $X^{(t)}$  (output image) and  $I_S$  (style image).
    - These regions are designed to have low style loss between the corresponding areas.
2.	Point-to-Point User Control
    - Each correspondence involves a pair of single local points, defined as  $(X_{t_1}, S_{s_1}), \dots, (X_{t_K}, S_{s_K})$ .

#### Redefining REMD with User Constraints

The cost matrix  $C_{ij}$  in the Relaxed Earth Mover’s Distance (REMD) is redefined as:


$$C_{ij} =
\begin{cases}
\beta \cdot D_{\cos}(A_i, B_j), & \text{if } i \in X_{t_k}, j \in S_{s_k} \\
\infty, & \text{if } \exists k \text{ such that } i \in X_{t_k}, j \notin S_{s_k} \\
D_{\cos}(A_i, B_j), & \text{otherwise}.
\end{cases}$$

- Explanation:
  - $\beta$  controls the weight of user-specified constraints relative to unconstrained areas.
  - In this study,  $\beta$  is set to 5 in all experiments.
  - $D_{\cos}(A_i, B_j)$  represents the cosine distance between feature vectors  $A_i$  (output image) and  $B_j$  (style image).

#### Enhancing Point-to-Point Control

- To strengthen the constraints, for each originally defined point, a 3×3 grid of additional constraint points is created around it.
- These new constraints include points within a 20-pixel horizontal and vertical distance of the original point in a 512×512 image.
- The spacing can be adjusted according to specific requirements.

In [16]:
RESAMPLE_FREQ = 1  # NOTE(review): not read anywhere in the code visible here


class Vgg19_Extractor(nn.Module):
    """Frozen VGG19 feature extractor that produces hypercolumns for STROTSS.

    ``forward`` returns a list: the (normalized) input image followed by the
    activations of every layer index in ``capture_layers``.
    """

    def __init__(self, space):
        super().__init__()
        self.vgg_layers = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features

        # The network is only used as a fixed feature extractor.
        for param in self.parameters():
            param.requires_grad = False
        # indices into vgg19().features whose outputs are captured
        self.capture_layers = set([1, 3, 6, 8, 11, 13, 15, 17, 20, 22, 24, 26, 29, 31, 33, 35])
        # layer whose activations feed PatchMatch
        self.patch_match_layer = 29
        # number of max-pools before patch_match_layer (used to map feature cells back to pixels)
        self.num_max_pool = 4
        # 'vgg': inputs are already normalized; otherwise inputs are in [-1, 1]
        self.space = space

    def _normalize_input(self, x):
        """Map non-'vgg' inputs ((x + 1) / 2, then per-channel mean/std) to the VGG input range."""
        if self.space == 'vgg':
            return x
        x = (x + 1.) / 2.
        x = x - (torch.Tensor([0.485, 0.456, 0.406]).to(x.device).view(1, -1, 1, 1))
        return x / (torch.Tensor([0.229, 0.224, 0.225]).to(x.device).view(1, -1, 1, 1))

    def patch_match_features(self, x):
        """Return the ``patch_match_layer`` activation of *x*, divided by its global L2 norm.

        The input gets the same normalization as in ``forward``; without it,
        PatchMatch would compare features of un-normalized images. Returns
        ``None`` if ``patch_match_layer`` is out of range.
        """
        x = self._normalize_input(x)
        for i, layer in enumerate(self.vgg_layers):
            x = layer(x)
            if i == self.patch_match_layer:
                return x / torch.norm(x)
        return None

    def forward_base(self, x):
        """Run *x* through all layers and return ``[x, *captured activations]``."""
        feat = [x]
        for i, layer in enumerate(self.vgg_layers):
            x = layer(x)
            if i in self.capture_layers:
                feat.append(x)
        return feat

    def forward(self, x):
        return self.forward_base(self._normalize_input(x))

    @staticmethod
    def _grid_coords(X):
        """All ``[row, col]`` pixel positions of *X* as a ``(W*H, 2)`` array, ordered column-major."""
        rows, cols = np.meshgrid(np.arange(X.shape[2]), np.arange(X.shape[3]))
        return np.stack([rows.flatten(), cols.flatten()], axis=1)

    @staticmethod
    def _sample_hypercolumns(feat, xc, samps):
        """Shuffle candidate positions *xc* in place, keep up to *samps*, and gather every layer there.

        Coordinates are halved (and floored) each time a layer is spatially
        smaller than the previous one, then clipped to that layer's extent.
        Returns a detached ``(N, sum(C), samples)`` tensor.
        """
        samples = min(samps, xc.shape[0])
        np.random.shuffle(xc)
        rows = xc[:samples, 0]
        cols = xc[:samples, 1]

        picked = []
        for i, layer_feat in enumerate(feat):
            # hack to detect lower resolution
            if i > 0 and layer_feat.size(2) < feat[i - 1].size(2):
                rows = rows / 2.0
                cols = cols / 2.0
            rows = np.clip(rows, 0, layer_feat.shape[2] - 1).astype(np.int32)
            cols = np.clip(cols, 0, layer_feat.shape[3] - 1).astype(np.int32)
            picked.append(layer_feat[:, :, rows, cols].clone().detach())
        return torch.cat(picked, 1)

    def forward_samples_hypercolumn(self, X, samps=100):
        """Hypercolumns at up to *samps* random pixel positions of *X*."""
        feat = self.forward(X)
        return self._sample_hypercolumns(feat, self._grid_coords(X), samps)

    def forward_cat(self, X, r, samps=100):
        """Hypercolumns at up to *samps* random pixel positions of *X* inside region *r*.

        *r* is a mask laid out as ``(W, H)`` (indexed ``[col, row]``), optionally
        with a trailing channel axis, of which channel 0 is used. Positions with
        value > 0.5 are kept. A (near-)empty mask (max < 0.1) selects every position.
        """
        feat = self.forward(X)

        if r.ndim > 2:
            r = r[:, :, 0]

        flat = r.flatten()
        if r.max() < 0.1:
            flat = flat + 1.
        region_mask = np.greater(flat, 0.5)

        xc = self._grid_coords(X)[region_mask, :]
        return self._sample_hypercolumns(feat, xc, samps)

In [17]:
RESAMPLE_FREQ = 1  # NOTE(review): not read anywhere in the code visible here


class Vgg16_Extractor(nn.Module):
    """Frozen VGG16 feature extractor that produces hypercolumns (VGG16 variant of the VGG19 extractor).

    ``forward`` returns a list: the (normalized) input image followed by the
    activations of every layer index in ``capture_layers``.
    """

    def __init__(self, space):
        super().__init__()
        self.vgg_layers = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features

        # The network is only used as a fixed feature extractor.
        for param in self.parameters():
            param.requires_grad = False
        # indices into vgg16().features whose outputs are captured
        self.capture_layers = set([1, 3, 6, 8, 11, 13, 15, 22, 29])
        # 'vgg': inputs are already normalized; otherwise inputs are in [-1, 1]
        self.space = space

    def forward_base(self, x):
        """Run *x* through all layers and return ``[x, *captured activations]``."""
        feat = [x]
        for i, layer in enumerate(self.vgg_layers):
            x = layer(x)
            if i in self.capture_layers:
                feat.append(x)
        return feat

    def forward(self, x):
        if self.space != 'vgg':
            x = (x + 1.) / 2.
            x = x - (torch.Tensor([0.485, 0.456, 0.406]).to(x.device).view(1, -1, 1, 1))
            x = x / (torch.Tensor([0.229, 0.224, 0.225]).to(x.device).view(1, -1, 1, 1))
        return self.forward_base(x)

    @staticmethod
    def _grid_coords(X):
        """All ``[row, col]`` pixel positions of *X* as a ``(W*H, 2)`` array, ordered column-major."""
        rows, cols = np.meshgrid(np.arange(X.shape[2]), np.arange(X.shape[3]))
        return np.stack([rows.flatten(), cols.flatten()], axis=1)

    @staticmethod
    def _sample_hypercolumns(feat, xc, samps):
        """Shuffle candidate positions *xc* in place, keep up to *samps*, and gather every layer there.

        Coordinates are halved (and floored) each time a layer is spatially
        smaller than the previous one, then clipped to that layer's extent.
        Returns a detached ``(N, sum(C), samples)`` tensor.
        """
        samples = min(samps, xc.shape[0])
        np.random.shuffle(xc)
        rows = xc[:samples, 0]
        cols = xc[:samples, 1]

        picked = []
        for i, layer_feat in enumerate(feat):
            # hack to detect lower resolution
            if i > 0 and layer_feat.size(2) < feat[i - 1].size(2):
                rows = rows / 2.0
                cols = cols / 2.0
            rows = np.clip(rows, 0, layer_feat.shape[2] - 1).astype(np.int32)
            cols = np.clip(cols, 0, layer_feat.shape[3] - 1).astype(np.int32)
            picked.append(layer_feat[:, :, rows, cols].clone().detach())
        return torch.cat(picked, 1)

    def forward_samples_hypercolumn(self, X, samps=100):
        """Hypercolumns at up to *samps* random pixel positions of *X*."""
        feat = self.forward(X)
        return self._sample_hypercolumns(feat, self._grid_coords(X), samps)

    def forward_cat(self, X, r, samps=100):
        """Hypercolumns at up to *samps* random pixel positions of *X* inside region *r*.

        *r* is a mask laid out as ``(W, H)`` (indexed ``[col, row]``), optionally
        with a trailing channel axis, of which channel 0 is used. Positions with
        value > 0.5 are kept. A (near-)empty mask (max < 0.1) selects every position.
        """
        feat = self.forward(X)

        if r.ndim > 2:
            r = r[:, :, 0]

        flat = r.flatten()
        if r.max() < 0.1:
            flat = flat + 1.
        region_mask = np.greater(flat, 0.5)

        xc = self._grid_coords(X)[region_mask, :]
        return self._sample_hypercolumns(feat, xc, samps)

### Optimize

#### Optimizing Computation for Loss Calculation
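1.	Random Sampling Strategy
  - To avoid computing distances between every pixel in the style and output images, a random sampling approach is employed:
    - Style image: randomly sample feature vectors. The original description uses 1024 points; this implementation draws 5 × 1000 hypercolumns per region once per scale.
    - Content/output image: use a uniform grid with a random $x, y$ offset to pick sample locations; 1024 of them enter each loss evaluation, which covers a broader image area than purely random points.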
2.	Efficiency in Loss Calculation
  - Only the feature vectors corresponding to the sampled points are used for loss calculation.
  - Gradients (used in backpropagation) are also computed only for these sampled points, avoiding unnecessary optimization over all pixels.
3.	Dynamic Sampling
  - After each optimization step (using the RMSprop algorithm), the sampling positions are reshuffled.
  - This dynamic sampling strategy helps to:
    - avoid overfitting to specific regions of the image;
    - gradually cover the entire image’s feature space across multiple iterations.

In [18]:
def _resize_region(region, target):
    """Return a region mask as a 2-D NumPy array sized for the image tensor *target*.

    3-D masks are reduced to their first channel. If AUTO_MASK_PYRAMID is set,
    the mask was built at the current scale and is returned as-is. Otherwise it
    is bilinearly resized to ``[target.size(3), target.size(2)]``, i.e. laid out
    as (W, H).
    """
    if region.ndim > 2:
        region = region[:, :, 0]
    region = torch.from_numpy(region).unsqueeze(0).unsqueeze(0).contiguous()
    if AUTO_MASK_PYRAMID:
        return region[0, 0, :, :].numpy()
    return tensor_resample(region, [target.size(3), target.size(2)])[0, 0, :, :].numpy()


def optimize(result, content, style, content_path, style_path, content_weight, lr, extractor, regions=0):
    """Optimize the Laplacian pyramid of *result* with RMSprop for one STROTSS scale.

    Args:
        result: initial output image tensor. Its 5-level Laplacian pyramid is
            what RMSprop updates (200 iterations).
        content, style: content and style image tensors at the current scale.
        content_path, style_path: unused here; kept for interface compatibility.
        content_weight: weight of the content term.
        lr: RMSprop learning rate.
        extractor: feature extractor providing ``__call__`` and ``forward_cat``.
        regions: ``[content_masks, style_masks]``, paired lists of masks. Must be
            supplied despite the ``0`` default.

    Returns:
        The stylized image folded from the pyramid after the final update.
    """
    result_pyramid = make_laplace_pyramid(result, 5)
    result_pyramid = [level.data.requires_grad_() for level in result_pyramid]

    opt_iter = 200
    optimizer = optim.RMSprop(result_pyramid, lr=lr)

    feat_content = extractor(content)
    stylized = fold_laplace_pyramid(result_pyramid)

    n_regions = len(regions[0])

    # Style features: 5 x 1000 random hypercolumns per region, sampled inside
    # that region's STYLE mask (regions[1]). Content masks have the content
    # image's layout and must not be applied to the style image.
    feat_style_all = []
    for ri in range(n_regions):
        r = _resize_region(regions[1][ri], style)
        with torch.no_grad():
            chunks = [extractor.forward_cat(style, r, samps=1000) for _ in range(5)]
        feat_style_all.append(torch.cat(chunks, dim=2))

    # Content sample locations per region, drawn inside the CONTENT mask (regions[0]).
    xx = {ri: [] for ri in range(n_regions)}
    xy = {ri: [] for ri in range(n_regions)}
    yx = {ri: [] for ri in range(n_regions)}
    for ri in range(n_regions):
        r = _resize_region(regions[0][ri], stylized)
        # a (near-)empty mask falls back to the whole image
        r = np.greater(r + 1., 0.5) if r.max() < 0.1 else np.greater(r, 0.5)
        sample_indices(feat_content, feat_style_all, r, ri, xx, xy, yx)  # slot 0 = first extracted layer

    for it in range(opt_iter):
        optimizer.zero_grad()
        stylized = fold_laplace_pyramid(result_pyramid)

        if it != 0:
            # Re-shuffle the sample points with ONE permutation per region so
            # each (row, col) pair stays intact. Independent shuffles would
            # combine coordinates of different points and leave the region mask.
            for ri in xx:
                perm = np.random.permutation(len(xx[ri][0]))
                xx[ri][0] = xx[ri][0][perm]
                xy[ri][0] = xy[ri][0][perm]
                np.random.shuffle(yx[ri][0])

        feat_result = extractor(stylized)

        loss = calculate_loss(feat_result, feat_content, feat_style_all, xx, xy, yx, content_weight, regions)

        loss.backward()
        optimizer.step()

    # Fold once more so the returned image includes the last optimizer step.
    return fold_laplace_pyramid(result_pyramid)

## Step 5: STROTSS process

1.	Iterative Resolution Adjustment
    - The method iteratively increases the resolution and halves the value of $\alpha$ at each step.
2.	Initial Resolution
    - At the start, the longer side of both the content image and style image is scaled down to 64 pixels.
    - The output from each scaled resolution is bilinearly upsampled to twice the resolution and used as the initialization for the next scale’s output.
3.	Default Settings
    - By default, the process iterates four times.
    - The initial value is $\alpha = 16$, so that $\alpha = 1$ at the final resolution.
4.	Initialization at the Lowest Resolution
    - At the lowest resolution:
        - The output is initialized as the finest (high-frequency) level of the content image's Laplacian pyramid, `laplacian(content)`, plus the average color of the style image.
        - This initialized image is then decomposed into a 5-level Laplacian pyramid.
5.	Optimization Using RMSprop
    - RMSprop is used to update the components of the Laplacian pyramid, minimizing the objective function.
    - Optimizing the Laplacian pyramid instead of raw pixels significantly accelerates convergence.
6.	Optimization Details
    - For each resolution:
        - RMSprop is run for 200 updates.
        - The learning rate is set to 0.002 for all resolutions except the last one, where it is reduced to 0.001.

In [19]:
def _strotss_scale_divisors(short_side, min_side=33, max_levels=10):
    """Return the power-of-two downscale divisors for the STROTSS pyramid.

    A divisor ``2**k`` (``0 <= k < max_levels``) is kept when ``short_side``
    integer-divided by it is still at least ``min_side`` pixels. The list is
    ordered coarsest (largest divisor) first and is empty when
    ``short_side < min_side``.
    """
    return [2 ** k for k in reversed(range(max_levels)) if short_side // 2 ** k >= min_side]


def _strotss_auto_regions(extractor, content, style):
    """Build content/style region correspondences for one scale via PatchMatch.

    Features come from ``extractor.patch_match_features``; ``PatchMatch`` is
    run for 5 propagation iterations (third constructor argument 3 is
    presumably the patch size -- confirm against Patch_Match) and its
    nearest-neighbour field is turned into regions by ``build_from_matches``
    at this scale's resolution.
    """
    feat_content = extractor.patch_match_features(content)
    feat_style = extractor.patch_match_features(style)
    pm = PatchMatch(feat_content, feat_style, 3)
    pm.propagate(iters=5)
    matches = pm.nnf
    im_content = np_to_pil(tensor_to_np(content))
    im_style = np_to_pil(tensor_to_np(style))
    return build_from_matches(matches, extractor.num_max_pool, im_content, im_style)


def strotss(content_pil, style_pil, content_path, style_path, regions, content_weight=1.0*16.0, device='cuda:0', space='uniform', content_mask=None, style_mask=None, auto_mask_pyramid=None):
    """Run multi-scale STROTSS style transfer and return the stylized image.

    The result is optimized coarse-to-fine over a power-of-two pyramid:
    the coarsest scale is initialized from ``laplacian(content)`` plus the
    style image's mean over dims 2 and 3 (its mean colour, assuming NCHW),
    each finer scale starts from the upsampled previous result,
    ``content_weight`` is halved after every scale, and the learning rate
    drops from 2e-3 to 1e-3 on the finest scale.

    Args:
        content_pil, style_pil: PIL images; converted with ``pil_to_np`` and
            ``np_to_tensor`` in colour ``space``.
        content_path, style_path: forwarded unchanged to ``optimize``.
        regions: region correspondences forwarded to ``optimize``; replaced
            at every scale by PatchMatch-derived regions when the auto mask
            pyramid is enabled.
        content_weight: content-loss weight used at the coarsest scale.
        device: torch device for the tensors and the ``Vgg19_Extractor``.
        space: colour space name; ``'uniform'`` clamps the final result to
            [-1, 1], any other value to [-1.7, 1.7].
        content_mask, style_mask: optional; moved to ``device`` when both
            are given, but not otherwise used in this function.
        auto_mask_pyramid: when not None, overrides the module-level
            ``AUTO_MASK_PYRAMID`` flag.

    Returns:
        ``np_to_pil`` of the result resized to the full content resolution
        and min-max renormalized to [0, 255].

    Raises:
        ValueError: if the content image's short side is below 33 px, so no
            optimization scale can be built.
    """
    content_np = pil_to_np(content_pil)
    style_np = pil_to_np(style_pil)
    content_full = np_to_tensor(content_np, space).to(device)
    style_full = np_to_tensor(style_np, space).to(device)

    if content_mask is not None and style_mask is not None:
        # NOTE(review): the masks are moved to the device but not used
        # anywhere else in this function.
        content_mask = content_mask.to(device)
        style_mask = style_mask.to(device)

    if auto_mask_pyramid is None:
        auto_mask_pyramid = AUTO_MASK_PYRAMID

    short_side = min(content_pil.width, content_pil.height)
    scales = _strotss_scale_divisors(short_side)
    if not scales:
        # Without this check `result` would be used unbound after the loop.
        raise ValueError(
            f"content image too small for STROTSS: short side is {short_side} px, "
            f"need at least 33 px"
        )

    lr = 2e-3
    extractor = Vgg19_Extractor(space=space).to(device)

    for scale in scales:
        # Rescale content and style to the current pyramid level.
        content = tensor_resample(content_full, [content_full.shape[2] // scale, content_full.shape[3] // scale])
        style = tensor_resample(style_full, [style_full.shape[2] // scale, style_full.shape[3] // scale])
        if auto_mask_pyramid:
            # The caller's regions only exist at full resolution, so rebuild
            # PatchMatch-based correspondences for every scale.
            regions = _strotss_auto_regions(extractor, content, style)

        print(f'Optimizing at resoluton [{content.shape[2]}, {content.shape[3]}]')
        if scale == scales[0]:
            # Coarsest scale: content Laplacian plus the style's mean colour.
            result = laplacian(content) + style.mean(2, keepdim=True).mean(3, keepdim=True)
        elif scale == scales[-1]:
            # Finest scale: plain upsample and a smaller learning rate.
            result = tensor_resample(result, [content.shape[2], content.shape[3]])
            lr = 1e-3
        else:
            result = tensor_resample(result, [content.shape[2], content.shape[3]]) + laplacian(content)

        result = optimize(result, content, style, content_path, style_path, content_weight=content_weight, lr=lr, extractor=extractor, regions=regions)

        # Each finer scale weights content less.
        content_weight /= 2.0

    clow = -1.0 if space == 'uniform' else -1.7
    chigh = 1.0 if space == 'uniform' else 1.7
    result_image = tensor_to_np(tensor_resample(torch.clamp(result, clow, chigh), [content_full.shape[2], content_full.shape[3]]))
    # Min-max renormalize to [0, 1]; skip the division for a constant image,
    # which would otherwise divide by zero and yield NaNs.
    result_image -= result_image.min()
    peak = result_image.max()
    if peak > 0:
        result_image /= peak
    return np_to_pil(result_image * 255.)

In [20]:
args = {"content": "content.jpg", "style": "style.jpg", "content_mask": None, "style_mask": None, "weight": 1.0, "output": "strotss.png", "device": "cuda:0", "ospace": "uniform", "resize_to": 512, "AUTO_MASK_PYRAMID": False}
args["content"] = "/Users/xiongjiangkai/Downloads/anni.jpg"
args["style"] = "/Users/xiongjiangkai/Downloads/mona.jpg"
# Output file "<content stem>_<style stem>_masked.png". splitext keeps dots
# inside file names; basename uses the platform's path separator.
content_stem = os.path.splitext(os.path.basename(args["content"]))[0]
style_stem = os.path.splitext(os.path.basename(args["style"]))[0]
args["output"] = content_stem + "_" + style_stem + "_masked.png"
# Also rebinds the module-level flag that strotss() falls back to.
AUTO_MASK_PYRAMID = args["AUTO_MASK_PYRAMID"] = True
# args["content_mask"] ="/Users/xiongjiangkai/Downloads/cv_heol/test_content/content_palace_single1.jpg"
# args["style_mask"] = "/Users/xiongjiangkai/Downloads/cv_heol/test_style/vangogh_starry_night_single1.jpg"

# Require a long side of at least 256 px. strotss() can still fail when the
# resized short side is below 33 px, because no optimization scale is built.
if args["resize_to"] < 2**8:
    print("Resulution too low.")
    raise SystemExit(1)

# Load and resize each image once; the resized copies feed both the default
# regions and strotss().
content_pil, style_pil = pil_loader(args["content"]), pil_loader(args["style"])
content_small = pil_resize_long_edge_to(content_pil, args["resize_to"])
style_small = pil_resize_long_edge_to(style_pil, args["resize_to"])
content_mask, style_mask = None, None


def _whole_image_region(img_np):
    """Return an all-ones floating-point region map with img_np's height x width.

    Arrays with a channel axis are reduced to their first channel first.
    """
    plane = img_np[:, :, 0] if img_np.ndim >= 3 else img_np
    return plane * 0. + 1.


if args["content_mask"] and args["style_mask"]:
    # User-supplied colour masks. When AUTO_MASK_PYRAMID is on, strotss()
    # replaces these regions with PatchMatch regions at every scale.
    regions = extract_color_regions(args["content_mask"], args["style_mask"])

    pil_content_mask = pil_loader(args["content_mask"])
    pil_style_mask = pil_loader(args["style_mask"])

    pil_content_mask = pil_resize_long_edge_to(pil_content_mask, args["resize_to"])
    pil_style_mask = pil_resize_long_edge_to(pil_style_mask, args["resize_to"])

    content_mask = pil_to_np(pil_content_mask)
    style_mask = pil_to_np(pil_style_mask)

    content_mask = create_mask_from_image(content_mask)
    style_mask = create_mask_from_image(style_mask)
else:
    # No masks: a single region covering each whole (resized) image.
    regions = [[_whole_image_region(pil_to_np(content_small))], [_whole_image_region(pil_to_np(style_small))]]

content_weight = args["weight"] * 16.0

start = time()
# NOTE: args["device"] is not used; the run uses the auto-detected `device`.
result = strotss(content_small, style_small, args["content"], args["style"], regions, content_weight, device, args["ospace"], content_mask=content_mask, style_mask=style_mask)
print(f'Done in {time()-start:.3f}s')
result.save(args["output"])

FileNotFoundError: [Errno 2] No such file or directory: '/Users/xiongjiangkai/Downloads/anni.jpg'