In [None]:
import os
import sys
sys.path.append('..')

from utils import *
from model import *
from tqdm.auto import tqdm, trange
import matplotlib.pyplot as plt
from PIL import Image
%load_ext autoreload
%autoreload 2

In [None]:
# Dataset directory. Set the DATASET_ROOT environment variable to point at a local
# copy instead of editing this cell; the placeholder below is used otherwise.
root_dir = os.environ.get("DATASET_ROOT", "/path/to/dataset/")

In [None]:
# Options handed to `load_data` in the next cell.
# NOTE(review): 'AIF' is False here, yet later cells index batch['aif'] — confirm
# the loader still returns that key, or enable AIF.
dataset_config = dict(
    root_dir=root_dir,
    shuffle=True,
    img_num=50,
    visible_img=5,
    focus_dist=[.3, .45, .75, 1.2, 1.8],  # presumably metres — TODO confirm
    recon_all=True,
    near=0.1,
    RGBFD=True,
    DPT=True,
    AIF=False,
    far=1,
)

In [None]:
train_dl, val_dl, test_dl = load_data(dataset_config, "DefocusNet", 1)

In [None]:
gt_dpt = test_dl.dataset[0]['dpt'][0].unsqueeze(-1).numpy()
plt.imshow(gt_dpt)
np.max(gt_dpt)

In [None]:
fs = test_dl.dataset[0]['output'][::2,:3,:,:]

In [None]:
diff_fs = fs[1:] - fs[:-1]

In [None]:
diff_fs.shape

In [None]:
diff = torch.mean(diff_fs[:, :, 151, 151], dim=1)
diff

In [None]:
samp = diff_fs[17].numpy().transpose(1, 2, 0)
plt.imshow(np.abs(samp))

In [None]:
# Global depth range over the training set. The bounds start at -inf / +inf so the
# first sample initialises both of them (with min_d starting at 0, the minimum
# stayed 0 whenever all depths were positive).
max_d = -np.inf
min_d = np.inf

for data in tqdm(train_dl.dataset):
    # Local name so the `gt_dpt` shown in the test-set cell is not overwritten.
    sample_dpt = data['dpt'][0].numpy()
    max_d = max(max_d, sample_dpt.max())
    min_d = min(min_d, sample_dpt.min())

min_d, max_d

In [None]:
dl = iter(train_dl)

In [None]:
a = next(dl)

In [None]:
a['output'][0].shape

In [None]:
plt.figure(figsize=(10, 6))
plt.subplot(1, 2, 1)
plt.imshow(a['output'][0][0].numpy().transpose(1, 2, 0))
plt.subplot(1, 2, 2)
plt.imshow(a['dpt'][0].numpy().transpose(1, 2, 0))
plt.colorbar()

In [None]:
dpt = a['dpt'][0]
dpt.shape

In [None]:
# All-in-focus image of the first sample with a leading axis of size 1 added;
# the render cells below (`render(aif.cuda(), ...)`, `aif.expand(...)`) need this name.
aif = a['aif'][0].unsqueeze(0)

In [None]:
def thin_len_coc(FN, dpt, focal_length, focus_dist, pixel_size=1.21e-5):
    """Thin-lens defocus blur of each depth value, expressed in pixels.

    Args:
        FN: f-number; the aperture diameter is ``focal_length / FN``.
        dpt: depth tensor (``torch.abs`` is applied to the result, so a tensor is needed).
        focal_length: focal length, in the same length unit as ``dpt`` (presumably metres).
        focus_dist: distance the lens is focused at; must differ from ``focal_length``.
        pixel_size: sensor pixel pitch, in the same length unit.

    Returns:
        ``|CoC| / 2 / pixel_size`` — half the circle-of-confusion diameter, in pixels.
        It is zero where ``dpt == focus_dist``.
    """
    aperture = focal_length / FN
    # Lens-to-sensor distance that brings `focus_dist` into focus (thin-lens equation).
    image_dist = focus_dist * focal_length / (focus_dist - focal_length)
    defocus_term = 1 / focal_length - 1 / image_dist - 1 / dpt
    coc = aperture * image_dist * defocus_term
    return torch.abs(coc / 2 / pixel_size)

In [None]:
# Thin-lens blur (sigma in pixels) of the ground-truth depth for a sweep of focus
# distances. Each map is computed once and reused for both the raw and the clipped view.
# NOTE(review): this sweep differs from dataset_config['focus_dist'] — confirm which
# distances the captured slices in the last figure correspond to.
coc_focus_dists = [0.1, 0.15, 0.3, 0.7, 1.5]
coc_focal_length = 2.9 * 1e-3


def show_row(images):
    """Display `images` side by side in a single 25x6 figure and return the figure."""
    fig, axes = plt.subplots(1, len(images), figsize=(25, 6))
    for ax, img in zip(np.atleast_1d(axes), images):
        ax.imshow(img)
    return fig


defocuses = []
defocus_imgs = []
for fd in coc_focus_dists:
    # `defocus` stays bound to the last focus distance; the render cell below uses it.
    defocus = thin_len_coc(1, dpt, coc_focal_length, fd)
    defocuses.append(defocus.unsqueeze(0))
    defocus_imgs.append(defocus.numpy().transpose(1, 2, 0))

show_row(defocus_imgs)
# Same maps with every value below 1 raised to 1.
show_row([np.clip(img, 1, np.inf) for img in defocus_imgs])
# First five entries of a['output'][0] (captured slices) for visual comparison.
show_row([a['output'][0, i].numpy().transpose(1, 2, 0) for i in range(5)]);

In [None]:
torch.cuda.set_device(5)
render = GaussPSF(7)
render.cuda()

In [None]:
recon = render(aif.cuda(), defocus.cuda())

In [None]:
plt.figure(figsize=(10, 6))
plt.subplot(1, 2, 1)
plt.imshow(recon.squeeze().cpu().numpy().transpose(1, 2, 0))
plt.subplot(1, 2, 2)
plt.imshow(a['output'][0, -1].numpy().transpose(1, 2, 0))

In [None]:
recon_loss = BlurMetric('recon')
sharp_loss = BlurMetric('sharp')
ssim_loss = BlurMetric('ssim')
l1_loss = BlurMetric('l1')
mse_loss = BlurMetric('mse')

In [None]:
recon_loss(recon, a['output'][0, -1].unsqueeze(0).cuda())

In [None]:
sharp_loss(recon, a['output'][0, -1].unsqueeze(0).cuda())

In [None]:
defocus_ = torch.cat(defocuses).squeeze()

In [None]:
aif_ = aif.expand(5, *aif.shape[1:]).contiguous()

In [None]:
recon_ = render(aif_.cuda(), defocus_.cuda())

In [None]:
metric_recon = np.zeros((6, 6))
metric_sharp = np.zeros((6, 6))
metric_ssim = np.zeros((6, 6))
metric_l1 = np.zeros((6, 6))
metric_mse = np.zeros((6, 6))

In [None]:
recon_loss(recon_[0].unsqueeze(0), a['output'][:, 0].cuda())

In [None]:
for i in range(6):
    for j in range(6):
        if i != 5:
            inp = recon_[i].unsqueeze(0)
        else:
            inp = a['aif'].cuda()
        if j != 5:
            tar = a['output'][:, j].cuda()
        else:
            tar = a['aif'].cuda()
        metric_recon[i, j] = recon_loss(inp, tar).item()
        metric_sharp[i, j] = sharp_loss(inp, tar).item()
        metric_ssim[i, j] = ssim_loss(inp, tar).item()        
        metric_l1[i, j] = l1_loss(inp, tar).item()
        metric_mse[i, j] = l1_loss(inp, tar).item()                

In [None]:
for i in range(6):
    for j in range(6):
        if i != 5:
            inp = a['output'][:, i].cuda()
        else:
            inp = a['aif'].cuda()
        if j != 5:
            tar = a['output'][:, j].cuda()
        else:
            tar = a['aif'].cuda()
        metric_recon[i, j] = recon_loss(inp, tar).item()
        metric_sharp[i, j] = sharp_loss(inp, tar).item()
        metric_ssim[i, j] = ssim_loss(inp, tar).item()        
        metric_l1[i, j] = l1_loss(inp, tar).item()        
        metric_mse[i, j] = l1_loss(inp, tar).item()  

In [None]:
plt.figure(figsize=(25, 4))
for i, n in enumerate([metric_recon, metric_sharp, 1-metric_ssim, metric_l1, metric_mse]):
    plt.subplot(1, 5, i+1)
    plt.imshow(n[:5, :5])
    plt.colorbar()

In [None]:
def gradient(inp):
    """Backward finite differences of a 4-D tensor along its last two axes.

    The shifted copy is zero-padded on the leading edge, so the first column (x)
    and first row (y) of each result equal the input itself.

    Returns:
        (D_dx, D_dy): differences along the last axis and along the second-to-last axis.
    """
    shifted_y = F.pad(inp[:, :, :-1, :], (0, 0, 1, 0))
    shifted_x = F.pad(inp[:, :, :, :-1], (1, 0, 0, 0))
    return inp - shifted_x, inp - shifted_y

def sharpness(image):
    """Per-pixel sharpness score of a 4-D image tensor (larger, i.e. closer to 0, for smoother regions).

    Sums three penalties, each negated: squared gradient magnitude, relative deviation
    from a 7x7 local mean, and squared deviation from that mean.
    """
    kernel = 7
    d_x, d_y = gradient(image)
    # 1e-8 guards the division below against a zero local mean.
    local_mean = F.avg_pool2d(image, kernel, 1, kernel // 2, count_include_pad=False) + 1e-8
    residual = image - local_mean
    grad_energy = d_x ** 2 + d_y ** 2
    return -grad_energy - torch.abs(residual / local_mean) - torch.pow(residual, 2)

In [None]:
plt.imshow(-sharpness(recon_[0].unsqueeze(0))[0].cpu().numpy().transpose(1, 2, 0))

In [None]:
plt.imshow(-sharpness(a['output'][:, 0].cuda())[0].cpu().numpy().transpose(1, 2, 0))

In [None]:
plt.imshow(-sharpness(a['aif'][:].cuda())[0].cpu().numpy().transpose(1, 2, 0))

In [None]:
clear_pix = torch.min(defocus_, dim=0)

In [None]:
clear_pix[1].view(256, 256, 1).shape

In [None]:
plt.figure(figsize=(10, 4))
for i in range(2):
    plt.subplot(1, 2, i + 1)
    plt.imshow(clear_pix[i].numpy())
    plt.colorbar()

In [None]:
# Coarse all-in-focus estimate: each pixel takes its colour from the focal-stack
# slice whose defocus map is smallest there (clear_pix[1] holds that slice index).
# Vectorised with torch.gather instead of a per-pixel Python loop, and sized from
# the data rather than a hard-coded 256x256.
focal_stack = a['output'][0, :, :3]  # assumes (slices, channels, H, W); first 3 channels — TODO confirm
sharpest_idx = clear_pix[1]  # (H, W) int64 indices from torch.min
gather_idx = sharpest_idx[None, None].expand(1, focal_stack.shape[1], *sharpest_idx.shape)
coarse_aif = torch.gather(focal_stack, 0, gather_idx).squeeze(0).float()

In [None]:
plt.figure(figsize=(10, 6))
plt.subplot(1, 2, 1)
plt.imshow(a['aif'][0].numpy().transpose(1, 2, 0))
plt.subplot(1, 2, 2)
plt.imshow(coarse_aif.numpy().transpose(1, 2, 0))

In [None]:
recon_loss(a['aif'], coarse_aif.unsqueeze(0))

In [None]:
ssim_loss(a['aif'], coarse_aif.unsqueeze(0))

In [None]:
l1_loss(a['aif'], coarse_aif.unsqueeze(0))

In [None]:
sharp_loss(a['aif'], coarse_aif.unsqueeze(0))

In [None]:
mse_loss(a['aif'], coarse_aif.unsqueeze(0))

In [None]:
import OpenEXR

In [None]:
def read_dpt(img_dpt_path):
    """Read the depth map stored in the R channel of an OpenEXR file.

    Args:
        img_dpt_path: path to the ``.exr`` file.

    Returns:
        Read-only float16 numpy array of shape (height, width), sized from the
        file's data window.

    Note:
        The raw channel bytes are interpreted as float16, i.e. this assumes the
        file stores HALF channels — TODO confirm for other datasets.
    """
    exr = OpenEXR.InputFile(img_dpt_path)
    try:
        window = exr.header()['dataWindow']
        width = window.max.x - window.min.x + 1
        height = window.max.y - window.min.y + 1
        # Only R is used (G and B were decoded and discarded before).
        raw = exr.channel("R")
    finally:
        # Release the file handle even if decoding fails.
        exr.close()
    return np.frombuffer(raw, dtype=np.float16).reshape(height, width)

In [None]:
def list_files_with_suffix(directory, suffix):
    """Return the sorted names of regular files in `directory` whose name ends with `suffix`.

    Sorting makes the order deterministic (os.listdir order is arbitrary).
    """
    return sorted(
        name for name in os.listdir(directory)
        if name.endswith(suffix) and os.path.isfile(os.path.join(directory, name))
    )


imglist_dpt = list_files_with_suffix(root_dir, "Dpt.exr")
imglist_all = list_files_with_suffix(root_dir, "All.tif")
imglist_aif = list_files_with_suffix(root_dir, "Aif.tif")

In [None]:
# Sort each file list in place so processing order is deterministic.
for file_list in (imglist_dpt, imglist_all, imglist_aif):
    file_list.sort()

In [None]:
# Convert every *Dpt.exr depth map to <stem>.npy with a trailing axis: (H, W, 1).
# The loop variable is `dpt_name`, so the depth tensor `dpt` from earlier cells
# survives. `depth` stays bound to the last converted map for the cells below.
# os.path.splitext keeps dotted stems intact (split('.') truncated them).
for dpt_name in imglist_dpt:
    stem = os.path.splitext(dpt_name)[0]
    depth = read_dpt(os.path.join(root_dir, dpt_name))
    np.save(os.path.join(root_dir, stem + '.npy'), depth[:, :, None])

In [None]:
for img in imglist_all:
    prefix = img.split('.')[0]
    img_all_path = os.path.join(root_dir, img)
    im = Image.open(img_all_path)
    save_path = os.path.join(root_dir, prefix+'.png')
    im.save(save_path)

In [None]:
for img in imglist_aif:
    prefix = img.split('.')[0]
    img_all_path = os.path.join(root_dir, img)
    im = Image.open(img_all_path)
    save_path = os.path.join(root_dir, prefix+'.png')
    im.save(save_path)

In [None]:
plt.imshow(depth)
plt.colorbar()

In [None]:
depth[:, :, None].shape

In [None]:
np.max(depth[:,0])