In [None]:
# Use if main package is not in python path
# import os, sys
# sys.path.append(os.path.dirname(os.getcwd()))

import numpy as np
import torch
import matplotlib.pyplot as plt
import celldetection as cd

from torch.cuda.amp import autocast
from models.ArtefactLocalization_forward import ArtefactLocalization
from models.cpn_custom_forward_2 import CPN
from models.BinaryInpainting_forward import BinaryInpainting
from models.ImageInpainting_forward import ImageInpainting
from models.NeuralInpainting import NeuralInpainting
from utils.datasets import TrainSetArtefacts, TrainSetIntact
from utils.ops import convert_to_uint8, mask_cpn

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Pick the compute device once for all models below. Fall back to the CPU when no
# GPU is present instead of naming a CUDA device that does not exist.
# NOTE(review): later cells use fp16 (`mask.half()`, autocast) and may still
# require a GPU in practice.
cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda else "cpu")

# Artifact Localization

In [None]:
# Artifact-localization setup: one artefact dataset plus the localization model,
# both parameterised by the same crop size.
crop_dim = 256
data = TrainSetArtefacts(1, crop_dim)
AL = ArtefactLocalization(cuda, device, crop_dim)

In [None]:
# Run artifact localization on one crop and show the crop next to the prediction.
crop, _ = data[0]

# Pure inference: disable autograd so no computation graph is kept in memory.
with torch.no_grad():
    anno = AL.forward(crop.unsqueeze(1))  # presumably (N, 1, H, W) input - TODO confirm
anno = anno.detach().cpu().numpy()

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, title, image in zip(axes, ("Crop", "Predicted Artifact"), (crop[0], anno[0, 0])):
    ax.set_title(title)
    ax.imshow(image, cmap="gray")
    ax.axis("off")
plt.show()

# Binary Inpainting

In [None]:
# Binary-inpainting setup: an intact test set, the CPN segmentation model and the
# binary-inpainting model.
crop_dim = 256
data = TrainSetIntact(1, crop_dim, n_test=1, test=True)

cpn = CPN(cuda, device, order=2)
# Use the same `order` as the CPN above, matching the Image/Neural Inpainting
# sections. Presumably BI must consume Fourier descriptors of the order CPN emits.
BI = BinaryInpainting(cuda, device, max_len=192, order=2)

In [None]:
# Segment one intact crop with CPN, mask the CPN output, and binary-inpaint it.
crop, mask = data[0]

crop = crop.unsqueeze(1)
mask = mask.unsqueeze(1)

# Pure inference: disable autograd so no computation graph is kept in memory.
with torch.no_grad():
    cpn_out = cpn.inference(crop)
    cpn_masked = mask_cpn(cpn_out, mask)[0]

    # Rasterise the masked contours *before* running BI (see TODO below).
    fourier = cd.asnumpy(cpn_masked["final_fourier"])[0]
    locs = cd.asnumpy(cpn_masked["xy"])[0]
    contours = cd.data.cpn.fourier2contour(fourier, locs)
    labels = cd.data.cpn.contours2labels(contours, size=crop.shape[-2:])

    # TODO: BI.forward might modify `cpn_masked` in place - investigate.
    painting = BI.forward(cpn_masked, mask)[0]

crop = crop[0, 0].detach().cpu().numpy()
mask = mask[0, 0].detach().cpu().numpy()
painting = painting[0]

fig, (ax_crop, ax_seg, ax_paint) = plt.subplots(1, 3, figsize=(16, 5))
ax_crop.set_title("Crop", size=12)
ax_crop.imshow(crop, cmap="gray", vmin=-1, vmax=1)
ax_seg.set_title("Masked Segmentation", size=12)
ax_seg.imshow(np.any(labels > 0, axis=-1), cmap="gray", vmin=0, vmax=1)
ax_seg.imshow(mask, cmap="gray", alpha=0.15)  # faint overlay of the mask
ax_paint.set_title("Binary Painting", size=12)
ax_paint.imshow(painting, cmap="gray", vmin=0, vmax=1)
for ax in (ax_crop, ax_seg, ax_paint):
    ax.axis("off")
plt.show()

# Image Inpainting

In [None]:
# Image-inpainting setup: an intact test set and the three models chained in the
# next cell (CPN -> binary inpainting -> image inpainting).
crop_dim = 256
data = TrainSetIntact(1, crop_dim, n_test=1, test=True)

# CPN and BI are both configured with Fourier order 2.
cpn = CPN(cuda, device, order=2)
BI = BinaryInpainting(cuda, device, max_len=192, order=2)
II = ImageInpainting(cuda, device)

In [None]:
# Full chain on one intact crop: CPN segmentation -> binary painting -> image painting.
crop, mask = data[0]

crop = crop.unsqueeze(1)
mask = mask.unsqueeze(1)

# Pure inference: disable autograd so no computation graph is kept in memory.
with torch.no_grad():
    cpn_out = cpn.inference(crop)
    cpn_masked = mask_cpn(cpn_out, mask)[0]
    # Mixed precision via the `autocast` alias imported at the top
    # (same object as torch.cuda.amp.autocast).
    with autocast():
        # TODO: BI.forward might modify `cpn_masked` in place - investigate.
        BI_painting = BI.forward(cpn_masked, mask)[0][0]
        painting = II.forward(crop, mask.half(), BI_painting)

crop = crop[0, 0].detach().cpu().numpy()
mask = mask[0, 0].detach().cpu().numpy()
painting = painting[0, 0].detach().cpu().numpy()

fig, (ax_crop, ax_binary, ax_image) = plt.subplots(1, 3, figsize=(16, 5))
ax_crop.set_title("Masked Crop")
ax_crop.imshow(crop * (1 - mask), cmap="gray", vmin=-1, vmax=1)
ax_binary.set_title("Binary Painting")
ax_binary.imshow(BI_painting, cmap="gray", vmin=0, vmax=1)
ax_image.set_title("Image Painting")
ax_image.imshow(painting, cmap="gray", vmin=-1, vmax=1)
for ax in (ax_crop, ax_binary, ax_image):
    ax.axis("off")
plt.show()

# Neural Inpainting

In [None]:
# Neural-inpainting setup. The single `nin.inference` call in the next cell
# returns every pipeline stage (mask, CPN output, binary and image paintings).
crop_dim = 256
data = TrainSetArtefacts(100, crop_dim, n_test=10, test=True)
nin = NeuralInpainting(cuda, device, max_len=192, order=2)

In [None]:
# Run the end-to-end NeuralInpainting pipeline on one crop and show every stage.
crop, _ = data[0]
crop = crop.unsqueeze(1)

out = nin.inference(crop)
print(out.keys())

# Deliberately NOT named `cpn`: that name holds the CPN model from the inpainting
# setup cells. Rebinding it to this dict breaks those cells on re-run.
nin_cpn = out["cpn"]
labels_masked = cd.data.cpn.contours2labels(cd.asnumpy(nin_cpn["contours"][0]), size=crop.shape[2:])
labels_masked = np.any(labels_masked > 0, axis=-1)

crop = out["crop"][0, 0].detach().cpu().numpy()
mask = out["mask"][0, 0].detach().cpu().numpy()
crop_masked = out["crop_masked"][0, 0].detach().cpu().numpy()
labels = out["binary_painting"][0, 0].detach().cpu().numpy()
painting = out["painting"][0, 0].detach().cpu().numpy()

# Bring every panel to uint8. Image-like outputs use [-1, 1], masks/labels [0, 1].
crop = convert_to_uint8(crop, inp_range=(-1, 1))
mask = convert_to_uint8(mask, inp_range=(0, 1))
crop_masked = convert_to_uint8(crop_masked, inp_range=(-1, 1))
labels = convert_to_uint8(labels, inp_range=(0, 1))
labels_masked = convert_to_uint8(labels_masked, inp_range=(0, 1))
painting = convert_to_uint8(painting, inp_range=(-1, 1))

panels = [
    ("Crop", crop),
    ("Predicted Artifact", mask),
    ("Masked Crop", crop_masked),
    ("Cell Segmentation", labels_masked),
    ("Binary Painting", labels),
    ("Image Painting", painting),
]
fig, axes = plt.subplots(2, 3, figsize=(16, 11))
for ax, (title, image) in zip(axes.flat, panels):
    ax.set_title(title, size=15)
    ax.imshow(image, cmap="gray", vmin=0, vmax=255)
    ax.axis("off")
plt.show()