In [1]:
import ast
import datetime
import os
import time
from os.path import join, exists

import albumentations as albu
import cv2
import matplotlib
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from torch.utils.data import Dataset

import utils
from utils import saver
from utils.transforms import get_augmentations, get_transforms

In [2]:
# Start from the training config and override it for interactive nodule generation.
config = saver.load_config_dump("configs/heatmap_augmentations_weightedloss_borders.yaml")
config.dataset.train['csv_path'] = "playground/synth_nodule.csv"
config.dataset.val['csv_path'] = "both_clean.csv"  # source images browsed by the annotation tool
config.dataset.test['csv_path'] = "generated_annotation.csv"

# NIH image root: machine-specific, so allow overriding it via the NIH_ROOT environment variable.
config.dataset.root['NIH'] = os.environ.get('NIH_ROOT', "/home/ailab_user/work/data/NIH/images")
# Generated rows store paths relative to the working directory ('generated_images/...').
config.dataset.root['generated'] = ""
config.dataset.augmentations = []
config.protocol = "ae"
config.dataset.img_size = 512

In [3]:
def prune_df(sdf, annotation_path='generated_annotation.csv'):
    """Drop rows of ``sdf`` whose source image already has a generated sample.

    Generated samples are listed in ``annotation_path``; their 'Image Index' has the form
    ``generated_images/<stem>_<dd-mm-YYYY_(HH-MM-SS)>_synthetic.png``. Stripping that
    fixed-width prefix and suffix recovers ``<stem>``, which is compared as ``<stem>.png``
    against ``sdf['Image Index']``.

    :param sdf: DataFrame with an 'Image Index' column of source image file names.
    :param annotation_path: CSV with the already generated samples.
    :return: the filtered DataFrame (original index preserved).
    """
    generated = pd.read_csv(annotation_path)
    old_len = len(sdf)
    prefix_len = len('generated_images/')                     # 17
    suffix_len = len('_dd-mm-YYYY_(HH-MM-SS)_synthetic.png')  # 36
    processed = (generated['Image Index'].dropna().astype(str)
                 .str[prefix_len:-suffix_len] + '.png')
    # Hash-based membership instead of scanning an array for every row.
    sdf = sdf[~sdf['Image Index'].isin(processed)]
    print('DF pruned from {} to {} # of entries in order not to contain processed samples'.format(old_len, len(sdf)))
    return sdf

In [4]:
class CustomDataset(Dataset):
    """Dataset yielding (image with bbox region blanked, original image) pairs.

    Each item is resized to ``config.dataset.img_size`` x ``config.dataset.img_size``
    together with its bbox keypoints. ``sample['image']`` is a copy of the resized image
    with the bbox rectangle set to 0, and ``sample['target']`` is the untouched resized image.
    """

    def __init__(self, config, phase='train', df=None):
        """
        :param config: experiment config (``config.dataset.*`` is read).
        :param phase: split name; augmentations are enabled only when it contains 'train'.
        :param df: DataFrame with 'Image Index', 'Dataset' and 'bbox' columns and a 0..n-1
            index (``__getitem__`` uses ``.loc``). When omitted, the notebook-global ``a``
            is used for backward compatibility.
        """
        # NOTE(review): the fallback reads the global ``a``; pass ``df`` explicitly so the
        # dataset does not depend on whatever ``a`` is bound to at construction time.
        self.df = a if df is None else df
        self.phase = phase
        self.transforms = get_transforms(config)
        self.augmentations = get_augmentations(config) if 'train' in phase else None
        self.config = config

    def __len__(self):
        # The frame is virtually repeated ``repeat_dataset`` extra times; see ``idx %=`` below.
        return len(self.df) * (self.config.dataset.repeat_dataset + 1)

    def __getitem__(self, idx):
        """
        Load one sample.

        :param idx: item index, taken modulo ``len(self.df)``.
        :return: {'image', 'target', 'bbox', 'img_name'}; bbox is (4, 2):
                 [upLeft, upRight, downLeft, downRight] as (x, y) points.
        """
        idx %= len(self.df)
        img_name = self.df.loc[idx, 'Image Index']
        img_path = os.path.join(self.config.dataset.root[self.df.loc[idx, 'Dataset']],
                                img_name)

        img_path = img_path.replace(".IMG", ".png")
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError('Could not read image: {}'.format(img_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # literal_eval parses the list literal stored in the CSV without executing code.
        bbox = np.array(ast.literal_eval(self.df.loc[idx, 'bbox']))

        resize_pair = albu.Compose([albu.Resize(self.config.dataset.img_size, self.config.dataset.img_size)],
                                   keypoint_params={'format': 'xy', "remove_invisible": False})
        augmented = resize_pair(image=image, keypoints=bbox)

        image = augmented['image']
        bbox = np.round(np.array(augmented['keypoints'])).astype(int)

        # Model input: the image with the bbox rectangle zeroed; the target is the full image.
        target = np.copy(image)
        target[bbox[0, 1]:bbox[2, 1], bbox[0, 0]:bbox[1, 0], :] = 0

        sample = {'image': target, 'target': image, 'bbox': bbox, 'img_name': img_name}

        if self.augmentations and 'train' in self.phase:
            augmented = self.augmentations(image=sample['image'], target=sample['target'], keypoints=sample['bbox'])
            # Keep the un-augmented sample if any keypoint left the image.
            if ((np.array(augmented['keypoints']).min() < 0) or
                    (np.array(augmented['keypoints']).max() > self.config.dataset.img_size)):
                print('Wrong augmentations')
            else:
                sample['image'] = augmented['image']
                sample['target'] = augmented['target']
                sample['bbox'] = np.array(augmented['keypoints'])

        if self.transforms:
            # Apply transform to numpy.ndarray which represents sample image
            transformed = self.transforms(image=sample['image'], target=sample['target'], keypoints=sample['bbox'])
            sample['image'] = transformed['image']
            sample['target'] = transformed['target']
            sample['bbox'] = torch.tensor(transformed['keypoints'])

        # NOTE(review): .float() assumes the transforms above produced tensors.
        sample['image'] = sample['image'].float()
        sample['target'] = sample['target'].float()
        sample['bbox'] = sample['bbox'].float()
        return sample

In [6]:
# Source images to annotate: validation CSV minus JSRT rows and minus images that already
# have a generated sample, shuffled. CustomDataset(config) reads this frame via the global `a`.
a = pd.read_csv(config.dataset["val"].csv_path)
a = a[a.Dataset != "JSRT"]
a = prune_df(a)
# drop=True: don't carry the old index along as spurious 'index'/'level_0' columns.
a = a.sample(frac=1).reset_index(drop=True)
ds = CustomDataset(config)

DF pruned from 9 to 9 # of entries in order not to contain processed samples


In [7]:
import torch
import importlib

from torch import nn

from utils.activations import get_activation
from utils.param_initialization import get_init_func
# Build the generator named in the config and append the configured output activation.
models_module = importlib.import_module('protocols.{}.models'.format(config.protocol))
model = getattr(models_module, config.model.name)()
model = nn.Sequential(model, get_activation(config))

# NOTE(review): absolute, machine-specific checkpoint path.
path_to_weights = "/home/ailab_user/work/CancerAstro/playground/G_latest.pth"
# The model is not moved to any device in this notebook, so load GPU-saved tensors onto
# the CPU instead of requiring CUDA at load time.
state_dict = torch.load(path_to_weights, map_location='cpu')
# Checkpoints saved from nn.DataParallel prefix every key with 'module.'; strip exactly
# that prefix, only on keys that have it.
_dp_prefix = 'module.'
state_dict = {(key[len(_dp_prefix):] if key.startswith(_dp_prefix) else key): value
              for key, value in state_dict.items()}

model.load_state_dict(state_dict)

<All keys matched successfully>

In [8]:
from skimage.transform import match_histograms
from skimage import exposure
import cv2
def hist_match(img, ref, bbox=None, rounds=1, margin = 5):
    """Match the histogram of ``img`` to ``ref``, optionally only around a bbox.

    :param img: image to adjust (a copy is modified, not the input).
    :param ref: reference image, indexed with the same coordinates as ``img``.
    :param bbox: optional ndarray of (x, y) corner points. When given, only the rectangle
        spanning the points grown by ``margin`` pixels is matched; anything that is not an
        ndarray (``None``, ``[]``, a tensor) matches the whole image.
    :param rounds: number of times matching is applied.
    :param margin: pixels added around the bbox.
    :return: the result as float32.
    """
    img = img.copy()
    if bbox is not None and isinstance(bbox, np.ndarray):
        bbox = bbox.astype(int)
        # Clamp starts at 0: a negative start would wrap to the far edge of the image and
        # select an empty or wrong region when the bbox is near the top/left border.
        y_min, y_max = max(bbox[:, 1].min() - margin, 0), bbox[:, 1].max() + margin
        x_min, x_max = max(bbox[:, 0].min() - margin, 0), bbox[:, 0].max() + margin
        region = (slice(y_min, y_max), slice(x_min, x_max))
        for _ in range(rounds):
            img[region] = match_histograms(img[region], ref[region])
    else:
        for _ in range(rounds):
            img = match_histograms(img, ref)
    return img.astype(np.float32)

def sharpen(img, kernel, bbox=None, rounds=1, margin=0):
    """Convolve ``img`` with ``kernel`` via ``cv2.filter2D``, optionally only inside a bbox.

    :param img: image array (a copy is modified, not the input).
    :param kernel: 2-D convolution kernel, e.g. an entry of ``kernels``.
    :param bbox: optional ndarray of (x, y) corner points; anything else filters the whole image.
    :param rounds: number of times the filter is applied.
    :param margin: pixels added around the bbox (0 = exactly the bbox rectangle).
    :return: the filtered copy (same depth as the input, ddepth=-1).
    """
    img = img.copy()
    if bbox is not None and isinstance(bbox, np.ndarray):
        bbox = bbox.astype(int)
        # Clamp starts at 0 so a negative start cannot wrap to the opposite edge.
        y_min, y_max = max(bbox[:, 1].min() - margin, 0), bbox[:, 1].max() + margin
        x_min, x_max = max(bbox[:, 0].min() - margin, 0), bbox[:, 0].max() + margin
        for _ in range(rounds):
            img[y_min:y_max, x_min:x_max] = cv2.filter2D(img[y_min:y_max, x_min:x_max], -1, kernel)
    else:
        for _ in range(rounds):
            img = cv2.filter2D(img, -1, kernel)
    return img

def apply(img, func, bbox=None, rounds=1, margin = 5, only_border=False,  **kwargs):
    """Apply ``func(region, **kwargs)`` to the whole image, a bbox neighbourhood, or its border ring.

    :param img: image array (a copy is modified, not the input).
    :param func: callable taking an array region plus ``kwargs`` and returning an array of the
        same shape (e.g. ``cv2.medianBlur`` with ``ksize``).
    :param bbox: optional ndarray of (x, y) corner points. When given, the bbox grown by
        ``margin`` is processed; anything else (``None``, ``[]``, a tensor) processes the whole image.
    :param rounds: how many times ``func`` is applied.
    :param margin: pixels added around the bbox (and, with ``only_border``, removed inside it).
    :param only_border: process only the ring between the bbox shrunk by ``margin`` and the bbox
        grown by ``margin``, as four strips (top, bottom, left, right).
    :param kwargs: forwarded to ``func``.
    :return: the processed copy.
    """
    img = img.copy()
    if bbox is not None and isinstance(bbox, np.ndarray):
        bbox = bbox.astype(int)
        # Outer box. Starts are clamped at 0: a negative start would wrap to the far edge and
        # select an empty or wrong region when the bbox is near the top/left border.
        y_min, y_max = max(bbox[:, 1].min() - margin, 0), bbox[:, 1].max() + margin
        x_min, x_max = max(bbox[:, 0].min() - margin, 0), bbox[:, 0].max() + margin

        # Inner box: the bbox shrunk by `margin`.
        y_min_center_box, y_max_center_box = bbox[:, 1].min() + margin, bbox[:, 1].max() - margin
        x_min_center_box, x_max_center_box = bbox[:, 0].min() + margin, bbox[:, 0].max() - margin

        if only_border:
            regions = [
                (slice(y_min, y_min_center_box), slice(x_min, x_max)),                      # top
                (slice(y_max_center_box, y_max), slice(x_min, x_max)),                      # bottom
                (slice(y_min_center_box, y_max_center_box), slice(x_min, x_min_center_box)),  # left
                (slice(y_min_center_box, y_max_center_box), slice(x_max_center_box, x_max)),  # right
            ]
        else:
            regions = [(slice(y_min, y_max), slice(x_min, x_max))]
        for _ in range(rounds):
            for region in regions:
                img[region] = func(img[region], **kwargs)
    else:
        for _ in range(rounds):
            img = func(img, **kwargs)
    return img

# Convolution kernels used by `sharpen`.
#   'sharpen'      : 3x3 sharpening kernel (weights sum to 1).
#   'sharpen_mask' : 5x5 unsharp mask, 2*identity - binomial_blur, written as
#                    (512*delta - outer(b, b)) / 256 with b = [1, 4, 6, 4, 1]
#                    (outer(b, b) sums to 256, so the kernel sums to 1).
kernels = {
    'sharpen': np.array([[0, -1, 0],
                         [-1, 5, -1],
                         [0, -1, 0]]),
    'sharpen_mask': (np.outer([1, 4, 6, 4, 1], [1, 4, 6, 4, 1])
                     - np.pad([[512]], 2)) * -1 / 256,
}
def post_proc(out, trg, bbox):
    """Sharpen/denoise the generated nodule region, then match its histogram to the original.

    Three passes of (unsharp mask, median blur) are run inside the bbox with median kernel
    sizes 3, 5 and 3, followed by histogram matching against ``trg`` around the bbox.

    out: image with the generated nodule
    trg: initial image (histogram reference)
    bbox: bbox corner points
    """
    result = out.copy()
    for median_ksize in (3, 5, 3):
        result = sharpen(result, kernels['sharpen_mask'], bbox, rounds=1)
        result = apply(result, cv2.medianBlur, bbox=bbox, rounds=1, ksize=median_ksize)
    return hist_match(result, trg, bbox=bbox)

def post_proc2(out, trg, bbox):
    """Histogram-match exactly the bbox region to the original, then median-blur around it.

    out: image with the generated nodule
    trg: initial image (histogram reference)
    bbox: bbox corner points; the blur covers the bbox grown by 2 pixels
    """
    matched = hist_match(out.copy(), trg, bbox=bbox, margin=0)
    blurred = apply(matched, cv2.medianBlur, bbox=bbox,
                    rounds=1, margin=2, ksize=5, only_border=False)
    return blurred

In [9]:
import matplotlib.pylab as plt
import numpy as np
from PIL import Image
%matplotlib qt
import time

# Shared state for the interactive event handlers below.
i, j = 1, 0                 # i: current sample index into ds; j: next entry of `process` for 'c'
data = ds[i]
img = data['target'].permute(1, 2, 0)
bbox, pos = [], []          # current bbox corners / pending mouse clicks
is_bbox_plotted = 0
scatter_obj1 = scatter_obj2 = None

def onclick(event):
    """Mouse handler: two clicks define a box; the model fills it in and the result is shown.

    Every click resets ``pic`` to the target image of sample ``i``. After the second click,
    the rectangle spanned by the two clicks (in either order) is zeroed in that image, run
    through ``model``, and the output is drawn into ``pic``. The corners are stored in the
    global ``bbox`` as [upLeft, upRight, downLeft, downRight] (x, y) integer points.
    """
    global pos
    global i
    global img, bbox
    if event.xdata is None or event.ydata is None:
        # Click outside any axes: matplotlib provides no data coordinates.
        return
    pos.append([event.xdata, event.ydata])
    data = ds[i]
    img = data['target']
    pic.set_data(img.permute(1, 2, 0))

    if len(pos) == 2:
        (x_a, y_a), (x_b, y_b) = pos
        # Normalise so the box is valid whichever corner was clicked first.
        x0, x1 = min(x_a, x_b), max(x_a, x_b)
        y0, y1 = min(y_a, y_b), max(y_a, y_b)
        # [upLeft, upRight, downLeft, downRight]
        bbox = np.round(np.array([[x0, y0], [x1, y0], [x0, y1], [x1, y1]])).astype(int)
        target = np.copy(img)
        target[:, bbox[0, 1]:bbox[2, 1], bbox[0, 0]:bbox[1, 0]] = 0
        with torch.no_grad():  # inference only: no autograd graph needed
            output = model(torch.tensor(target).unsqueeze(0))[0].permute(1, 2, 0)
        pic.set_data(output.detach().numpy())
        f.canvas.draw()
        f.canvas.flush_events()
        pos = []
        
def _redraw_canvas():
    """Push pending artist changes on figure ``f`` to the screen."""
    f.canvas.draw()
    f.canvas.flush_events()


def onpress(event):
    """Keyboard handler for the annotation figure.

    Keys:
      x / z -- show the next / previous sample and clear the current bbox and result;
      c     -- run the next function of ``process`` on the left image and show it on ax[1];
      a     -- save the ax[1] image under ``generate_name(tag)`` and append a row to
               generated_annotation.csv (needs a bbox and a prior 'c');
      b     -- toggle red markers at the bbox corners on both axes.
    """
    global i, j
    global pic2, bbox
    global tag
    global is_bbox_plotted, scatter_obj1, scatter_obj2
    global annotation
    global data
    tag = 'init'
    if event.key in ("x", "z"):
        i += 1 if event.key == "x" else -1
        pic.set_data(get_current_image())
        bbox = []
        if not isinstance(pic2, list):
            # Discard the previous sample's result so it cannot be saved under the new name.
            pic2.remove()
            pic2 = []
        _redraw_canvas()
    if event.key == "c":
        cur_img = pic.get_array()
        orig_img = get_current_image().numpy()
        postprocess = process[j](cur_img, orig_img, bbox)
        tag = process[j].__name__
        j = (j + 1) % len(process)  # cycle through the post-processing functions
        if not isinstance(pic2, list):
            pic2.remove()  # replace, rather than stack, the previous result on ax[1]
        pic2 = ax[1].imshow(postprocess)
        _redraw_canvas()
    if event.key == "a":
        if isinstance(pic2, list) or len(bbox) == 0:
            print('Nothing to save: draw a bbox and apply post-processing (c) first')
        else:
            img = pic2.get_array().filled()
            img_name = generate_name(tag)
            # Save the image before the CSV row so a failed save never leaves a dangling entry.
            os.makedirs(os.path.dirname(img_name), exist_ok=True)
            # Assumes float pixels in [0, 1] (hence * 255); clip so out-of-range values
            # saturate instead of wrapping around in the uint8 cast.
            Image.fromarray((np.clip(img, 0, 1) * 255).astype(np.uint8)).save(img_name)
            row = {'bbox': str(bbox.tolist()), 'Image Index': img_name, 'Dataset': 'generated'}
            # DataFrame.append was removed in pandas 2.0; concat is the supported equivalent.
            annotation = pd.concat([annotation, pd.DataFrame([row])], ignore_index=True)
            annotation.to_csv('generated_annotation.csv', index=False)
    if event.key == "b":
        if data['img_name'].startswith('generated_images'):
            # Already generated samples carry their own bbox.
            bbox = data['bbox']
        if is_bbox_plotted:
            scatter_obj1.remove()
            scatter_obj2.remove()
            _redraw_canvas()
            is_bbox_plotted = 0
        elif not len(bbox) == 0:
            scatter_obj1 = ax[0].scatter(bbox[:, 0], bbox[:, 1], c='r', s=2 ** 3)
            scatter_obj2 = ax[1].scatter(bbox[:, 0], bbox[:, 1], c='r', s=2 ** 3)
            is_bbox_plotted = 1
            _redraw_canvas()

        
def generate_name(tag):
    """Return the output path for a synthetic image derived from sample ``i``.

    Layout: 'generated_images/' + <img_name minus its last 4 chars> + '_' +
    <dd-mm-YYYY_(HH-MM-SS)> + '_synthetic.png'. ``prune_df`` relies on this fixed-width
    prefix/suffix. ``tag`` is accepted but not used.
    """
    stem = ds[i]['img_name'][:-4]
    stamp = datetime.datetime.now().strftime("%d-%m-%Y_(%H-%M-%S)")
    return 'generated_images/' + stem + '_' + stamp + "_synthetic.png"

def identity_proc(out, trg, bbox):
    """No-op post-processing: hand back the generator output untouched (trg/bbox ignored)."""
    unchanged = out
    return unchanged

def get_current_image():
    """Reload sample ``i`` into the global ``data`` and return its target with dim 0 moved last.

    Presumably converts (C, H, W) to (H, W, C) for imshow; confirm against the dataset transforms.
    """
    global data
    data = ds[i]
    return data['target'].permute(1, 2, 0)

process = [post_proc, post_proc2, identity_proc]  # cycled by the 'c' key
f, ax = plt.subplots(1, 2, figsize=(60, 30), sharey=True)
# Draw on ax[0] directly: rebinding the global `a` here would clobber the DataFrame that
# CustomDataset(config) reads.
pic = ax[0].imshow(img)  # left panel: current sample / model output
plt.tight_layout()
pic2 = []  # sentinel until the first 'c' press creates the right-hand image
annotation = pd.read_csv('generated_annotation.csv')

f.canvas.mpl_connect("key_press_event", onpress)
f.canvas.mpl_connect('button_press_event', onclick)
f.show()

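'''
Usage notes:
- Saving (key 'a') works only after a post-processing step has been applied (key 'c').
- generated_annotation.csv must exist before this cell is run.
- To browse already generated images with their bboxes, set the dataset CSV path to
  generated_annotation.csv (the annotation file this tool writes).
- Only images that have not been processed before are shown.
''';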

In [20]:
# Inspect the most recently appended generated-sample annotations.
df = pd.read_csv('generated_annotation.csv')
df.tail(n=5)

Unnamed: 0.1,Unnamed: 0,Image Index,bbox,Dataset
35,,generated_images/00027224_000_13-02-2020_(15-2...,"[[114, 234], [130, 234], [114, 254], [130, 254]]",generated
36,,generated_images/00021009_000_13-02-2020_(15-2...,"[[114, 211], [133, 211], [114, 233], [133, 233]]",generated
37,,generated_images/00000699_000_13-02-2020_(15-2...,"[[144, 140], [168, 140], [144, 164], [168, 164]]",generated
38,,generated_images/00019961_001_13-02-2020_(15-3...,"[[171, 163], [191, 163], [171, 178], [191, 178]]",generated
39,,generated_images/00018185_000_13-02-2020_(15-3...,"[[357, 226], [374, 226], [357, 245], [374, 245]]",generated
