In [None]:
import torch, copy
from seeing import nethook, setting, show, renormalize, zdataset, pbar
from seeing.encoder_loss import cor_square_error
from torch.nn.functional import mse_loss, l1_loss
# Inference-only by default; the refinement functions re-enable grad locally.
torch.set_grad_enabled(False)

# Church ProgGAN generator and its fixed set of sample latents.
unwrapped_G = setting.load_proggan('church').cuda()
zds = zdataset.z_sample_for_model(unwrapped_G).cuda()

# Ground truth for one sample latent (index 10, kept as a batch of one):
# the latent itself, the retained activations of layer1..layer4, and the image.
gt = dict(z=zds[10:11])
with nethook.InstrumentedModel(unwrapped_G) as inst_G:
    inst_G.retain_layers(['layer1', 'layer2', 'layer3', 'layer4'])
    target_x = inst_G(gt['z'])
    gt.update(inst_G.retained_features())
    gt['x'] = target_x

# Show the generated image next to the shape of every recorded tensor.
shape_rows = [(name, 'shape') + tuple(tensor.shape) for name, tensor in gt.items()]
show([[renormalize.as_image(gt['x'][0])], shape_rows])

In [None]:
from seeing import encoder_net, nethook

# E1: hybrid encoder wrapped in nethook.InstrumentedModel, loaded from the
# invert_hybrid_cse checkpoint.
filename = 'results/church/invert_hybrid_cse/snapshots/epoch_1000.pth.tar'
E1 = nethook.InstrumentedModel(encoder_net.HybridLayerNormEncoder())
E1.load_state_dict(torch.load(filename)['state_dict'])
E1.eval().cuda()

# E2: plain hybrid encoder loaded from the invert_hybrid_bottom_b5 checkpoint.
filename = 'results/church/invert_hybrid_bottom_b5/snapshots/epoch_1000.pth.tar'
E2 = encoder_net.HybridLayerNormEncoder()
E2.load_state_dict(torch.load(filename)['state_dict'])
E2.eval().cuda()

# Later cells refer to the default encoder as E; that is E2.
E = E2

# Encoder's initial latent guess for the first cell's target, and its rendering.
init_z = E2(target_x)
renormalize.as_image(unwrapped_G(init_z)[0])

In [None]:
# Feature network F (make_over5_resnet), used by the refinement functions for
# the feature-space loss; loaded from its epoch-100 checkpoint.
filename = 'results/church/invert_over5_resnet/snapshots/epoch_100.pth.tar'
F = encoder_net.make_over5_resnet()
F.load_state_dict(torch.load(filename)['state_dict'])
F.eval().cuda();

In [None]:
def _load_bottom_inverter(layer_num, epoch=1000):
    """Load the HybridLayerNormEncoder tail starting at ``inv<layer_num>``.

    Builds ``nethook.subsequence(HybridLayerNormEncoder(), first_layer='inv<k>')``.
    It then loads the matching ``invert_hybrid_bottom_b<k>`` checkpoint and
    returns the model in eval mode on the GPU.

    Args:
        layer_num: which generator layer (1-4) the inverter starts from.
        epoch: checkpoint epoch to load (all inverters here use 1000).

    Returns:
        The loaded inverter module.
    """
    inverter = nethook.subsequence(encoder_net.HybridLayerNormEncoder(),
                                   first_layer='inv%d' % layer_num)
    checkpoint = ('results/church/invert_hybrid_bottom_b%d/snapshots/epoch_%d.pth.tar'
                  % (layer_num, epoch))
    inverter.load_state_dict(torch.load(checkpoint)['state_dict'])
    inverter.eval().cuda()
    return inverter


# Inverters from layer k features back towards z; estimate_z averages them.
R4 = _load_bottom_inverter(4)
R3 = _load_bottom_inverter(3)
R2 = _load_bottom_inverter(2)
R1 = _load_bottom_inverter(1)

In [None]:
from seeing.LBFGS import FullBatchLBFGS

def estimate_z(G, gt):
    """Estimate the latent represented by a residual generator and score it.

    Reads ``G.retained_features()``. For each i in 1..4 for which ``G`` has an
    attribute ``d<i>``, that attribute is added to ``layer<i>``.

    The latent estimate depends on whether ``G`` has ``dz``:
    - If it does, the estimate is ``G.init_z + G.dz``.
    - Otherwise, it is the mean of the inverters R1..R4 applied to layer1..layer4.

    NOTE(review): the dict returned by ``retained_features()`` is updated in
    place; whether that aliases G's internal storage depends on nethook.

    Args:
        G: object exposing ``retained_features()`` and ``init_z``.
        gt: dict of ground-truth tensors keyed like the features plus 'z', or None.

    Returns:
        ``(z_estimate, err)``. ``err`` maps every feature key and 'z' to
        ``cor_square_error(gt[key], value)``, or is ``{}`` when ``gt`` is None.
    """
    feats = G.retained_features()
    for idx in range(1, 5):
        delta_attr = 'd%d' % idx
        if not hasattr(G, delta_attr):
            continue
        layer_key = 'layer%d' % idx
        feats[layer_key] = feats[layer_key] + getattr(G, delta_attr)

    z_est = G.init_z
    if hasattr(G, 'dz'):
        z_est = z_est + G.dz
    else:
        z_est = (R1(feats['layer1']) + R2(feats['layer2'])
                 + R3(feats['layer3']) + R4(feats['layer4'])) / 4
    feats['z'] = z_est

    if gt is None:
        return feats['z'], {}
    scores = {key: cor_square_error(gt[key], val) for key, val in feats.items()}
    return feats['z'], scores

def refine_z(init_z, target_x, gt, optimize_over=None,
             lr=0.02, lambda_f=0.25, num_steps=3000, show_every=100):
    """Refine a latent estimate by optimizing generator residuals with Adam.

    Builds ``encoder_net.ResidualGenerator`` around a deep copy of the global
    ``unwrapped_G``, seeded with ``init_z``. Only the generator's own
    (non-recursive) parameters are trained. The minimized objective is
    ``l1_loss(target_x, G()) + lambda_f * mse_loss(F(target_x), F(G()))``.

    Args:
        init_z: starting latent passed to ResidualGenerator.
        target_x: image to reconstruct.
        gt: ground-truth dict for error reporting (see ``estimate_z``), or None.
        optimize_over: names passed to ResidualGenerator; defaults to ['layer1'].
        lr: Adam learning rate.
        lambda_f: weight of the feature-space loss computed with ``F``.
        num_steps: number of Adam updates. Iteration 0 only evaluates the loss.
        show_every: display progress every this many iterations; 0/None disables.

    Returns:
        ``(est_z, err)`` from ``estimate_z`` on the final generator.
    """
    if optimize_over is None:
        optimize_over = ['layer1']
    show.flush()
    G = encoder_net.ResidualGenerator(copy.deepcopy(unwrapped_G), init_z, optimize_over)
    G.retain_layers(['layer1', 'layer2', 'layer3', 'layer4'])

    parameters = list(G.parameters(recurse=False))
    # Freeze everything except G's own residual parameters. F must be frozen
    # too: it is part of the autograd graph via loss_f, and otherwise every
    # backward() would accumulate never-cleared gradients into F's weights.
    nethook.set_requires_grad(False, G, E, F)
    nethook.set_requires_grad(True, *parameters)
    optimizer = torch.optim.Adam(parameters, lr=lr)

    with torch.no_grad():
        target_f = F(target_x)

    with torch.enable_grad():
        for step_num in pbar(range(num_steps + 1)):
            current_x = G()
            loss_x = l1_loss(target_x, current_x)
            loss_f = mse_loss(target_f, F(current_x))
            loss = loss_x + loss_f * lambda_f
            if show_every and step_num % show_every == 0:
                with torch.no_grad():
                    _, err = estimate_z(G, gt)
                    show.a(
                        ['step %d' % step_num] +
                        ['loss: %f' % loss.item()] +
                        ['loss_x: %f' % loss_x.item()] +
                        ['loss_f: %f' % loss_f.item()] +
                        ['err in %s: %f' % (n, e) for n, e in err.items()] +
                        [[renormalize.as_image(current_x[0])]], cols=3)
            optimizer.zero_grad()
            loss.backward()
            # Step 0 only reports the starting loss; updates begin at step 1.
            if step_num > 0:
                optimizer.step()
        show.flush()
    est_z, err = estimate_z(G, gt)
    return est_z, err

def refine_z_lbfgs(init_z, target_x, gt, optimize_over=None, lambda_f=0.25,
                   num_steps=1000, show_every=100):
    """Refine a latent estimate by optimizing generator residuals with L-BFGS.

    Builds ``encoder_net.ResidualGenerator`` around a deep copy of the global
    ``unwrapped_G``, seeded with ``init_z``. Its own (non-recursive) parameters
    are optimized with ``FullBatchLBFGS`` (``max_ls=10``). The objective is
    ``l1_loss(target_x, G()) + lambda_f * mse_loss(F(target_x), F(G()))``.
    The feature term is skipped entirely when ``lambda_f`` is falsy.

    Args:
        init_z: starting latent passed to ResidualGenerator.
        target_x: image to reconstruct.
        gt: ground-truth dict for error reporting (see ``estimate_z``), or None.
        optimize_over: names passed to ResidualGenerator; defaults to ['layer1'].
        lambda_f: weight of the feature-space loss computed with ``F``.
        num_steps: number of L-BFGS steps. Iteration 0 only seeds the optimizer.
        show_every: display progress every this many iterations. 0 or None
            disables display; previously this raised an exception.

    Returns:
        ``(est_z, err)`` from ``estimate_z`` on the final generator.
    """
    if optimize_over is None:
        optimize_over = ['layer1']
    show.flush()

    with torch.no_grad():
        target_f = F(target_x)

    G = encoder_net.ResidualGenerator(copy.deepcopy(unwrapped_G), init_z, optimize_over)
    # Also retain the final output under the name 'x' so progress images can be
    # shown from the optimizer's last evaluation without another forward pass.
    G.retain_layers(['layer1', 'layer2', 'layer3', 'layer4', ('output_256x256', 'x')])

    parameters = list(G.parameters(recurse=False))
    # Freeze everything except G's own residual parameters. F must be frozen
    # too: it is part of the autograd graph via the feature loss, and otherwise
    # every backward() would accumulate never-cleared gradients into F.
    nethook.set_requires_grad(False, G, E, F)
    nethook.set_requires_grad(True, *parameters)
    optimizer = FullBatchLBFGS(parameters)

    def closure():
        """Recompute the loss at the current parameters (does not call backward)."""
        optimizer.zero_grad()
        current_x = G()
        loss = l1_loss(target_x, current_x)
        if lambda_f:
            loss = loss + mse_loss(target_f, F(current_x)) * lambda_f
        return loss

    with torch.enable_grad():
        for step_num in pbar(range(num_steps + 1)):
            if step_num == 0:
                # Seed the optimizer: evaluate loss and gradient once, no update.
                loss = closure()
                loss.backward()
                lr = 0
            else:
                options = {'closure': closure, 'current_loss': loss, 'max_ls': 10}
                loss, _, lr, _, _, _, _, _ = optimizer.step(options)
            if show_every and step_num % show_every == 0:
                with torch.no_grad():
                    _, err = estimate_z(G, gt)
                    show.a(
                        ['step %d' % step_num] +
                        ['loss: %f' % loss.item()] +
                        ['lr: %f' % lr] +
                        ['err in %s: %f' % (n, e) for n, e in err.items()] +
                        [[renormalize.as_image(G.retained_layer('x')[0])]], cols=3)
        if show_every:
            show.flush()
    est_z, err = estimate_z(G, gt)
    return est_z, err


In [None]:
# Multi-phase Adam refinement of the first cell's target_x, starting from the
# encoder's init_z. Disabled by default because it runs 27,000 Adam steps in
# total. Set RUN_ADAM_PHASES = True to run it.
RUN_ADAM_PHASES = False
if RUN_ADAM_PHASES:
    print("Phase 1")
    new_z_1, _ = refine_z(init_z,  target_x, gt, optimize_over=['layer1'], num_steps=5000, show_every=1000)
    print("Phase 2")
    new_z_2, _ = refine_z(new_z_1, target_x, gt, optimize_over=['layer1'], num_steps=2000, show_every=1000)
    print("Phase 3")
    new_z_3, _ = refine_z(new_z_2, target_x, gt, optimize_over=['z'], num_steps=10000, show_every=1000)
    print("Phase 4")
    new_z_4, _ = refine_z(new_z_3, target_x, gt, optimize_over=['z'], lambda_f=0.5, num_steps=10000, show_every=1000)

In [None]:
from collections import defaultdict

def get_gt(true_z):
    """Run the global generator on ``true_z`` and record ground-truth values.

    Returns a dict containing:
    - 'z': the input latent.
    - The retained features of layer1..layer4, keyed as returned by
      ``retained_features()``.
    - 'x': the generated image.
    """
    hooked_layers = ['layer%d' % idx for idx in range(1, 5)]
    result = {'z': true_z}
    with nethook.InstrumentedModel(unwrapped_G) as inst_G:
        inst_G.retain_layers(hooked_layers)
        generated = inst_G(result['z'])
        result.update(inst_G.retained_features())
        result['x'] = generated
    return result

def test_image(true_z, num_steps=10000, show_every=5000):
    """Invert G(true_z) and regenerate from the recovered latent.

    Uses encoder ``E`` to get an initial latent for ``G(true_z)`` and refines
    it over 'z' with ``refine_z_lbfgs``. It then re-runs the generator on the
    result.

    Args:
        true_z: latent batch; the evaluation loop passes ``zds[n][None, ...]``.
        num_steps: L-BFGS steps forwarded to ``refine_z_lbfgs``.
        show_every: display interval forwarded to ``refine_z_lbfgs``.

    Returns:
        ``(new_z, gt, est)``. ``new_z`` is the refined latent. ``gt`` is the
        ``get_gt`` dict for ``true_z`` and ``est`` the one for ``new_z``.
    """
    gt = get_gt(true_z)
    target_x = gt['x']
    init_z = E(target_x)
    new_z, _ = refine_z_lbfgs(init_z, target_x, gt, optimize_over=['z'],
                              num_steps=num_steps, show_every=show_every)
    est = get_gt(new_z)
    return new_z, gt, est

# Invert the first NUM_TEST_IMAGES sample latents. For every key returned by
# get_gt, collect flattened ground-truth and reconstructed values as numpy
# arrays for the scatter plots below.
# Loop variables deliberately avoid the name `gt`, so the first cell's global
# ground truth (for zds[10:11]) is not silently overwritten.
NUM_TEST_IMAGES = 10
all_gt, all_est = defaultdict(list), defaultdict(list)
for sample_index in pbar(range(NUM_TEST_IMAGES)):
    _, sample_gt, sample_est = test_image(zds[sample_index][None, ...])
    for key in sample_gt:
        all_gt[key].append(sample_gt[key].view(-1).cpu().numpy())
        all_est[key].append(sample_est[key].view(-1).cpu().numpy())

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import numpy, matplotlib.pyplot as plt
plt.style.use('dark_background')

fig, axes = plt.subplots(1, 3, figsize=(6.1,2), dpi=300)
for i, k in enumerate(['z', 'layer4', 'x']):
    gtcat = numpy.concatenate(all_gt[k])
    estcat = numpy.concatenate(all_est[k])
    if k == 'z':
        estcat /= numpy.e
    ax = axes[i]
    ax.axis('equal')
    ax.scatter(gtcat, estcat, alpha=0.5, s=0.5, color="cornflowerblue")
    ax.set_title('%s: corr %.6f' % (k, numpy.corrcoef(gtcat, estcat)[0,1]))
    ax.set_ylabel(dict(z='true $z$', layer4='true $g_4(z)$', x='true $G(z)$')[k])
    ax.set_xlabel(dict(z='$G^{-1}(G(z))$', layer4='$g_4(G^{-1}(G(z)))$', x='$G(G^{-1}(G(z)))$')[k])
fig.tight_layout()
fig.show()


In [None]:
# Image indices for setting.load_image('church', i), grouped by notable content:
# People: 120, 407, 441, 447, 457, 463, 515, 520, 523, 569, 571, 594, 639, 646, 751, 787, 874, 882, 883, 895, 906, 911
# Buildings: 90
# Fence: 767
# Busy street: 638
# Text: 469, 503
# Monument: 477, 485
# Load church image 485 (listed under "Monument" above) as a batch of one on
# the GPU, and display it.
real_x = setting.load_image('church', 485).unsqueeze(0).cuda()
renormalize.as_image(real_x[0])

In [None]:
# NOTE(review): leftover scratch cell. It only displays the literal 1 and has
# no other effect; consider deleting it.
1

In [None]:
# Invert a fixed list of real church photos: show each one, initialise z with
# encoder E2, then refine z with L-BFGS (progress shown at start and end).
# Loop-local names are distinct so the globals `real_x` and `init_z` set by
# earlier cells are not clobbered.
for image_index in [90, 120, 407, 441, 447, 457, 463, 515, 520, 523, 569, 571, 594, 639, 646, 751, 787, 874, 882, 883, 895, 906, 911]:
    photo_x = setting.load_image('church', image_index)[None, ...].cuda()
    show.a([renormalize.as_image(photo_x[0])])
    photo_init_z = E2(photo_x)
    refine_z_lbfgs(photo_init_z, photo_x, None, optimize_over=['z'],
                   lambda_f=0.25, num_steps=5000, show_every=5000)


In [None]:
# Same real-photo inversion, but initialising z with encoder E1 instead of E2.
# Loop-local names are distinct so the globals `real_x` and `init_z` set by
# earlier cells are not clobbered.
for image_index in [90, 120, 407, 441, 447, 457, 463, 515, 520, 523, 569, 571, 594, 639, 646, 751, 787, 874, 882, 883, 895, 906, 911]:
    photo_x = setting.load_image('church', image_index)[None, ...].cuda()
    show.a([renormalize.as_image(photo_x[0])])
    photo_init_z = E1(photo_x)
    refine_z_lbfgs(photo_init_z, photo_x, None, optimize_over=['z'],
                   lambda_f=0.25, num_steps=5000, show_every=5000)
