In [1]:
import os
import torch
import torchvision.models as models
import cv2
import torch.optim as optim
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import numpy as np
import os
import utils

In [None]:
# ---------------------------------------------------------------------------
# Experiment setup: configuration, frozen VGG-19, content/style images, their
# feature maps, and the image G that is optimised.
# ---------------------------------------------------------------------------
# c / s are the content and style image paths; savename names the output folder.
c='lehigh.jpg'; s='starry.jpg'; savename='trm-lehigh-starry'
# makedirs creates ./generated and the per-run sub-directory in one call and is
# a no-op when they already exist (no exists()/mkdir() race).
os.makedirs(f'./generated/{savename}', exist_ok=True)
# alpha / beta / c_layer are forwarded to utils.compute_total_cost by the
# optimisation cell; epochs and printevery are re-assigned there.
epochs=5000; c_layer=5; alpha=1; beta=1e4; printevery=100; starting=0
print(f"Content Image:{c} | Style Image:{s} | savename: {savename}", flush=True)
# load model
model = models.vgg19(pretrained=True)
model = model.cuda()
# The network is only a fixed feature extractor: freeze every weight so that
# autograd differentiates with respect to the image alone.
# NOTE(review): consider model.eval() if utils.get_layers ever touches the
# classifier's Dropout layers -- confirm against utils.
for param in model.parameters():
    param.requires_grad = False
# load image
# contentImage
img = cv2.imread(c)
if img is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise FileNotFoundError(f"Could not read content image: {c}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Only height and width are needed; slicing keeps `c` (the content path)
# from being overwritten by the channel count.
h, w = img.shape[:2]
# avoid memory overflow for hessian-vector product
h, w = h//2, w//2
img = cv2.resize(img, (w, h))
contentImage = torch.tensor(img / 255.0).float().cuda()
# style image (resized to the content image's working resolution)
img = cv2.imread(s)
if img is None:
    raise FileNotFoundError(f"Could not read style image: {s}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (w, h))
styleImage = torch.tensor(img / 255.0).float().cuda()
layers = utils.get_layers(model)

# Reference feature maps of the content and style images.
aCs = utils.get_feature_maps(contentImage, layers)
aSs = utils.get_feature_maps(styleImage, layers)

# The optimisation variable starts from the content image. It is cloned from a
# tensor that already lives on the GPU, so it is a leaf tensor on that device.
G = contentImage.detach().clone().requires_grad_(True)
# Uniform weights over 16 style layers.
# NOTE(review): assumes utils.get_layers yields 16 style layers -- confirm.
style_layer_weights = [1.0 / 16] * 16

In [None]:
def to_boundary(z,d,radius):
    """
        Find tau >= 0 such that ||z + tau*d|| = radius.

        Expanding the squared norm gives the quadratic
            ||d||^2 tau^2 + 2 (z.d) tau + (||z||^2 - radius^2) = 0
        whose larger root is returned. When z lies inside the ball
        (||z|| <= radius) that root is non-negative and is the point where the
        ray z + tau*d leaves the trust region.

        Parameters
        ----------
        z : torch.Tensor
            Current point.
        d : torch.Tensor
            Search direction, same shape as z; must be non-zero.
        radius : float
            Trust-region radius.

        Returns
        -------
        torch.Tensor
            0-dim tensor holding tau.
    """
    norm_d_sq = (d * d).sum()
    z_dot_d = (z * d).sum()
    norm_z_sq = (z * z).sum()
    # Quarter of the quadratic's discriminant; clamp guards against a tiny
    # negative value produced by rounding when z sits on the boundary.
    disc = torch.clamp(z_dot_d * z_dot_d + norm_d_sq * (radius ** 2 - norm_z_sq), min=0.0)
    tau = (disc.sqrt() - z_dot_d) / norm_d_sq
    return tau

In [None]:
def cgsteihaug(gradf, radius, cg_maxiter=10):
    """
        Use CG-Steihaug method to approximately solve the
        trust-region subproblem:
            min_d 1/2 d^T H d + grad^T d
            s.t. ||d|| <= radius

        H is never formed: each Hessian-vector product is obtained by
        differentiating (grad . d) with respect to the module-level tensor G.

        Parameters
        ----------
        gradf : tuple of torch.Tensor
            Result of torch.autograd.grad(loss, G, create_graph=True);
            gradf[0] must still carry its graph.
        radius : float
            Trust-region radius.
        cg_maxiter : int
            Maximum number of CG iterations.

        Returns
        -------
        torch.Tensor
            The step, same shape as G. Its termination reason is printed as
            one of 'cgtol' (residual small enough), 'negcv' (negative
            curvature, step on the boundary), 'posbd' (iterate left the
            region, step on the boundary) or 'cgmax' (iteration cap).
    """
    # initialize CG-Steihaug
    cg_iter = 0
    with torch.no_grad():
        z = torch.zeros_like(G)
        # residual starts as a detached copy of the gradient
        r = gradf[0].detach().clone()
        d = -r
        norm_r0 = torch.norm(r).item()
        cg_tol = min(0.5, norm_r0) * norm_r0

    if norm_r0 <= cg_tol:
        # Only reachable when the gradient is exactly zero: the zero step is
        # already optimal, and iterating would divide by ||d||^2 = 0.
        step = z
        cgflag = 'cgtol'
    else:
        for cg_iter in range(1, cg_maxiter + 1):
            # Hessian vector product; the graph is retained because it is
            # differentiated again on later iterations and by the caller.
            Hd, = torch.autograd.grad((gradf[0] * d).sum(), G, retain_graph=True)
            with torch.no_grad():
                dTHd = (d * Hd).sum().item()
                # negative curvature: follow d all the way to the boundary
                if dTHd <= 0:
                    tau = to_boundary(z, d, radius)
                    step = z + tau * d
                    cgflag = 'negcv'
                    break
                # positive curvature along d: ordinary CG step
                norm_r_sq = (r * r).sum().item()
                step_len = norm_r_sq / dTHd
                z_next = z + step_len * d
                if torch.norm(z_next).item() >= radius:
                    # The full CG step leaves the region: stop where the
                    # segment from the last interior iterate z crosses it.
                    tau = to_boundary(z, d, radius)
                    step = z + tau * d
                    cgflag = 'posbd'
                    break
                z = z_next
                r_next = r + step_len * Hd
                r_next_norm = torch.norm(r_next).item()
                if r_next_norm < cg_tol:
                    # residual is small enough: accept the interior iterate
                    step = z
                    cgflag = 'cgtol'
                    break
                momentum = r_next_norm ** 2 / norm_r_sq
                d = -r_next + momentum * d
                r = r_next
        else:
            # iteration budget exhausted without meeting any stopping test
            # print("Reach cg max iterations!")
            step = z
            cgflag = 'cgmax'
    print(f'   CG-Steihaug: radius:{radius:3.3e} | current gradf_norm:{norm_r0:3.3e} | {cg_iter}/{cg_maxiter} | terminate with: {cgflag}')
    return step

In [None]:
# ---------------------------------------------------------------------------
# Trust-region Newton-CG optimisation of the generated image G.
#   radius / radius_max : initial and maximum trust-region radius
#   cg_maxiter          : CG iterations allowed per subproblem
#   eta                 : minimum actual/predicted reduction ratio to accept a step
# ---------------------------------------------------------------------------
radius=5; radius_max=100; cg_maxiter=10; eta=0.01; printevery=100; epochs = 2000
# Relative tolerance for deciding that a step reached the boundary. The step
# norm comes from float32 tensors, so an absolute 1e-12 test would practically
# never succeed and the radius would never be enlarged.
boundary_rtol = 1e-4
for it in range(epochs):
    aGs = utils.get_feature_maps(G, layers)
    loss, content_cost, style_cost = utils.compute_total_cost(aGs, aCs, aSs, style_layer_weights,
                                                content_layer_idx=c_layer, alpha=alpha, beta=beta)
    print(f'Iter:{it+1} | loss:{loss.data.cpu().item():2.3e} | content: {content_cost.item():2.3e} | style_cost:{style_cost.item():2.3e}', flush=True)
    # create_graph=True keeps the gradient differentiable so Hessian-vector
    # products can be taken from it.
    gradf = torch.autograd.grad(loss, G, create_graph=True)
    p = cgsteihaug(gradf, radius, cg_maxiter)
    # actual decrease at the trial point (pure evaluation, no graph needed)
    with torch.no_grad():
        Gnew = G + p
        aGnew = utils.get_feature_maps(Gnew, layers)
        loss_new, _content_new, _style_new = utils.compute_total_cost(aGnew, aCs, aSs, style_layer_weights,
                                                content_layer_idx=c_layer, alpha=alpha, beta=beta)
    # plain floats: keeps rho from holding on to the double-backward graph
    actual_decrease = loss.item() - loss_new.item()
    # model decrease: m(0) - m(p) = -g^T p - 1/2 p^T H p
    gTp = (gradf[0].detach() * p).sum().item()
    Hp, = torch.autograd.grad((gradf[0] * p).sum(), G)
    pTHp = (p * Hp).sum().item()
    model_decrease = -gTp - pTHp / 2
    if model_decrease > 0:
        rho = actual_decrease / model_decrease
    else:
        # The quadratic model predicts no improvement (e.g. a zero step):
        # treat the step as rejected instead of dividing by a non-positive value.
        rho = 0.0
    norm_p = (p*p).sum().sqrt().item()
    if rho < 0.25:
        radius *= 0.25
        radius_flag = 'shrink'
    else:
        # enlarge only when the model was accurate AND the step hit the boundary
        if rho > 0.75 and abs(norm_p - radius) <= boundary_rtol * radius:
            radius = min(2*radius, radius_max)
            radius_flag = 'enlarge'
        else:
            radius_flag = 'unchange'
    if rho > eta:
        # accept the trial point (Gnew == G + p, computed without grad)
        G.data = Gnew
        update_flag = 'move'
    else:
        update_flag = 'stay'
    print(f'   Trust-Region: {radius_flag:10s} | new radius:{radius:3.3e} | x-update:{update_flag}')
    if (it + 1) % printevery == 0 or it == 0:
        # G is permuted from (H, W, C) to the (C, H, W) layout before saving
        save_image(G.permute(2, 0, 1).cpu().detach(), fp='./generated/{}/iter_{}.png'.format(savename, it+1))