In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from reconstruction import Model, TopKLayer
from PIL import Image
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import os
import itertools
from torch import optim

In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# getImage starts every reconstruction from torch.rand noise; seeding torch
# (which seeds the CPU and CUDA generators) makes re-runs of this notebook
# reproducible.
SEED = 0
torch.manual_seed(SEED)

In [8]:
def getImage(img_path, p, epochs=100):
    """Reconstruct an image by optimising random pixels against a ``Model`` loss.

    A ``Model`` is built for ``img_path`` from the settings in ``p``. A random
    image tensor is then optimised with L-BFGS for ``epochs`` outer steps to
    minimise ``m.loss()``.

    Args:
        img_path: Path of the image handed to ``Model``.
        p: Settings dict with keys ``'topk'``, ``'device'``, ``'dimensions'``
            and ``'mode'``, forwarded to ``Model``.
        epochs: Number of ``optimizer.step`` calls. Each L-BFGS step may
            evaluate the closure several times.

    Returns:
        The optimised image tensor of shape ``(1, 3, H, W)`` with
        ``(H, W) = m.dimensions``, placed on ``p['device']`` and still
        requiring grad. Values are not clamped to ``[0, 1]``.
    """
    m = Model(img_path, topk=p['topk'], device=p['device'], dimensions=p['dimensions'], mode=p['mode'])
    # Use the device from the settings (not the notebook-global ``device``)
    # so the image lives wherever the model was told to run.
    I = torch.rand((1, 3, m.dimensions[0], m.dimensions[1])).to(p['device'])
    I.requires_grad_(True)
    optimizer = optim.LBFGS([I], lr=1)

    def closure():
        # L-BFGS re-evaluates the objective (possibly several times per step).
        optimizer.zero_grad()
        m(I)
        loss = m.loss()
        # retain_graph kept from the original implementation; presumably
        # Model keeps parts of the graph alive between evaluations
        # — TODO confirm it is actually required.
        loss.backward(retain_graph=True)
        return loss

    status_width = 0
    for epoch in range(1, epochs + 1):
        optimizer.step(closure)
        status = f'Epoch {epoch}: Loss {m.loss().item()}'
        # '\r' rewrites the line in place but does not erase it, so pad to the
        # longest status printed so far; otherwise a shorter line leaves stale
        # trailing characters from the previous one.
        print('\r' + status.ljust(status_width), end='')
        status_width = max(status_width, len(status))
    print()
    return I

In [9]:
def get_path(img_name, img_ext='.jpg', img_dir='./'):
    """Build the path of an input image from its base name.

    Args:
        img_name: File name without extension, e.g. ``'rocks'``.
        img_ext: Extension appended to ``img_name`` (default ``'.jpg'``).
        img_dir: Directory that contains the image (default ``'./'``).

    Returns:
        Tuple ``(img_path, img_name, img_ext)``. The name and extension are
        returned as well so callers can reuse them when naming output files.
    """
    img_path = os.path.join(img_dir, img_name + img_ext)
    return img_path, img_name, img_ext

In [10]:
# Baseline settings forwarded to getImage (and from there to Model);
# the experiment loop varies 'topk' and 'mode' per run.
parameters = dict(
    topk=0.05,
    device=device,
    dimensions=(500, 500),
    mode='topk',
)

In [None]:
# Experiment grid: every (topk, mode, image) combination is reconstructed.
topks = [0.05]  # other values tried previously: 0.5, 0.95
modes = ['non-topk', 'topk', 'both']
img_names = ['rocks', 'jeep1']
EPOCHS = 10
RESULTS_DIR = './results'

# Image.save fails if the output directory is missing, so create it up front.
os.makedirs(RESULTS_DIR, exist_ok=True)

I = None
for topk, mode, name in itertools.product(topks, modes, img_names):
    print((topk, mode, name))
    img_path, img_name, img_ext = get_path(name)

    # Per-run copy: the shared ``parameters`` dict from the earlier cell stays
    # untouched, so this cell can be re-run without hidden state.
    run_params = {**parameters, 'topk': topk, 'mode': mode}

    I = getImage(img_path, run_params, epochs=EPOCHS)
    # Hand ToPILImage the clamped CHW float tensor directly: it scales float
    # tensors to 8-bit itself (older torchvision rejects float numpy arrays).
    out_img = transforms.ToPILImage()(I.detach().cpu().squeeze(0).clamp(0, 1))
    out_img.save(os.path.join(RESULTS_DIR, f'{img_name}-{topk}-{mode}{img_ext}'))

(0.05, 'non-topk', 'rocks')
Epoch 10: Loss 670866.255
(0.05, 'non-topk', 'jeep1')
Epoch 10: Loss 781452.3125
(0.05, 'topk', 'rocks')
Epoch 10: Loss 173122.96875
(0.05, 'topk', 'jeep1')
Epoch 10: Loss 206952.921875
(0.05, 'both', 'rocks')
Epoch 10: Loss 1823.9219970703125
(0.05, 'both', 'jeep1')
Epoch 8: Loss 9513.8847656255