In [None]:
import torch.nn as nn
from PIL import Image
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches as patches
%matplotlib inline

from tqdm import tqdm_notebook as tqdm

from utils.plot_utils import create_overlap
from developing_suite import *

In [None]:

# Configure DevelopingSuite the same way its command line would be used:
# a UResNet in test mode, resuming from ./model.pth.tar, reading the bundled
# example dataset with one worker and a batch size of one.
cli_options = {
    "model": "UResNet",
    "save-dir": "./runs/",
    "resume": "./model.pth.tar",
    "mode": "test",
    "workers": "1",
    "data": "full_image",
    "datafolder": "./example_dataset/",
    "batch-size": "1",
}
# dicts keep insertion order, so the argv list matches the order above
str_args = [f"--{option}={value}" for option, value in cli_options.items()]

args = parser.parse_args(str_args)
developingSuite = DevelopingSuite(args)

# Walk the test *dataset* itself (not its DataLoader); each next() in the
# following cell pulls one sample straight from the dataset.
test_dataset = developingSuite.dataloaders["test"].dataset
iterator_test_dataset = iter(test_dataset)

In [None]:
# Fetch the next test sample and show its grayscale input.
# Re-running this cell advances the iterator to the following sample.
sample = next(iterator_test_dataset)
x = sample["x"]

fig, ax = plt.subplots(figsize=(16, 8))
ax.imshow(x.numpy().squeeze(), cmap="gray")
ax.set_title(sample["name"])
ax.set_xticks([])
ax.set_yticks([])
plt.show()


In [None]:
# Run the colorization on the current input. Besides the color image `y`,
# colorize_image returns intermediate results (z_rough, z, rgb_filter_image)
# that the following cells save and plot.
y, z_rough, z, rgb_filter_image, objective_refinement = developingSuite.colorize_image(x)

# move axis 0 of the squeezed output to the end so imshow gets (H, W, channels)
color_hwc = y.squeeze().permute(1, 2, 0)

fig, ax = plt.subplots(figsize=(16, 8))
ax.imshow(color_hwc)
ax.set_xticks([])
ax.set_yticks([])
plt.show()

In [None]:
# Save the input, the intermediate results and the final color image to the
# working directory. Each file name is the stage prefix followed by the
# sample's own name.
#
# NOTE(review): the tensors are assumed to be CPU tensors holding intensities
# nominally in [0, 1] (they are scaled by 255 and converted with .numpy()) —
# TODO confirm against DevelopingSuite.colorize_image.


def _to_uint8(tensor):
    """Scale a [0, 1] tensor to 0-255 and return it as a squeezed uint8 array.

    Values are clipped before the cast. A bare ``astype('uint8')`` on values
    outside [0, 255] wraps around (e.g. 510 -> 254, -1 -> 255), which corrupts
    saturated pixels. ``rgb_filter_image + z`` below can exceed 1, so this
    clipping matters in practice. Within range, values are truncated exactly
    as before.
    """
    scaled = (255 * tensor).numpy()
    return np.clip(scaled, 0, 255).astype('uint8').squeeze()


rgb_filter_image_z = rgb_filter_image + z

# (file-name prefix, tensor to save, PIL mode); RGB tensors are permuted so the
# channel axis (axis 0 after squeeze) comes last, as PIL expects.
outputs_to_save = [
    ("input", x, 'L'),
    ("z_rough", z_rough, 'L'),
    ("z", z, 'L'),
    ("color", y.squeeze().permute(1, 2, 0), 'RGB'),
    ("rgb_filter_image", rgb_filter_image_z.squeeze().permute(1, 2, 0), 'RGB'),
]

for prefix, tensor, mode in outputs_to_save:
    Image.fromarray(_to_uint8(tensor), mode).save(prefix + sample["name"])


In [None]:
# Zoom into a small window of the image and compare, column by column, the
# grayscale input with each channel of the RGB filter image and with z.
# The window is given in pixel coordinates: rows range_h[0]:range_h[1] and
# columns range_w[0]:range_w[1] (end-exclusive, like Python slices).
range_h = 1480, 1530
range_w = 300, 400
import matplotlib.colors as mcolors

image_2d = x.squeeze().numpy()

# Fail loudly if the window does not fit the current image. Otherwise numpy
# slicing silently truncates or empties the crop, np.mean returns NaN, and the
# rectangle is drawn partly or fully off-image.
assert 0 <= range_h[0] < range_h[1] <= image_2d.shape[0], (
    f"row window {range_h} outside image with {image_2d.shape[0]} rows")
assert 0 <= range_w[0] < range_w[1] <= image_2d.shape[1], (
    f"column window {range_w} outside image with {image_2d.shape[1]} columns")


def _column_profile(array_2d, rows, cols):
    """Average a 2-D array over the rows of the window, giving one value per column."""
    return np.mean(array_2d[rows[0]:rows[1], cols[0]:cols[1]], axis=0)


# --- overview: full input image with the analysed window outlined in red ---
fig, ax = plt.subplots(figsize=(16, 8))
ax.imshow(image_2d, cmap="gray")
ax.set_xticks([])
ax.set_yticks([])
ax.add_patch(patches.Rectangle(
    (range_w[0], range_h[0]),   # top-left corner as (column, row)
    range_w[1] - range_w[0],    # width  = number of columns
    range_h[1] - range_h[0],    # height = number of rows
    linewidth=1, edgecolor='r', facecolor='none'))
plt.show()

# --- column profiles inside the window ---
columns = np.arange(range_w[0], range_w[1])  # absolute column indices for the x-axis
rgb_channels = rgb_filter_image.squeeze()    # channel i is indexed as rgb_channels[i]

fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(columns, _column_profile(image_2d, range_h, range_w),
        "-", color="tab:grey", label="grayscale input image")
for channel_index, (color, label) in enumerate(
        [("tab:red", "red"), ("tab:green", "green"), ("tab:blue", "blue")]):
    ax.plot(columns, _column_profile(rgb_channels[channel_index].numpy(), range_h, range_w),
            "o-", color=color, label=label)
ax.plot(columns, _column_profile(z.squeeze().numpy(), range_h, range_w),
        "o-", color="k", label="lenticule boundary (z)")
ax.set_xlabel("image column")
ax.set_ylabel(f"mean over rows {range_h[0]}-{range_h[1] - 1}")
ax.legend()
plt.show()