In [None]:
import skimage
from skimage import io
import matplotlib.pyplot as plt
import json
import cv2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.patches as patches
import matplotlib

In [None]:
def preprocess_image(img, mask, homography):
    '''
    Blank out masked pixels and warp an image into the perspective of the center camera.

    img: image array of shape (H, W, C) with C >= 3; only the first three channels are used.
        Values are divided by 255, so an 8-bit image is presumably expected (TODO confirm).
    mask: array used as img[mask == 0]; presumably (H, W) so one mask entry covers all
        channels of a pixel (verify against mask.png). Pixels where mask == 0 become NaN.
    homography: 3x3 perspective transform (array or nested list) for cv2.warpPerspective.

    Returns a float32 array of shape (H, W, 3). Masked pixels and pixels that are warped in
    from outside the source image are NaN.
    '''
    img = img[..., :3].astype(np.float32) / 255
    img[mask == 0] = np.nan
    homography = np.asarray(homography, dtype=np.float64)
    # OpenCV expects dsize as (width, height), the reverse of numpy's (rows, cols) order;
    # passing img.shape[:2] directly only works for square images.
    height, width = img.shape[:2]
    processed_img = np.empty_like(img)
    # Warp channel by channel so the scalar NaN border value applies to every channel
    for dim in range(img.shape[2]):
        processed_img[:, :, dim] = cv2.warpPerspective(
            img[:, :, dim],
            homography,
            (width, height),
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=np.nan,
        )
    return processed_img

In [None]:
def integrate_images(img_list):
    '''
    Average all images in img_list pixel-wise (ignoring NaNs) and remove pixels of the
    average image that are not covered by all images.

    img_list: non-empty sequence of equally shaped (H, W, C) float arrays in which NaN marks
        pixels not covered by that image.
    Returns an (H, W, C) array holding the NaN-ignoring mean. A pixel is set to NaN in every
    channel when, for each channel, at least one image is NaN there (since preprocess_image
    masks all channels of a pixel together, this means "missing from at least one image").
    Raises ValueError if img_list is empty.
    '''
    import warnings

    stack = np.asarray(img_list)
    if len(stack) == 0:
        raise ValueError('img_list must contain at least one image')
    with warnings.catch_warnings():
        # All-NaN pixels (outside every image) are expected; they stay NaN without a warning
        warnings.simplefilter('ignore', category=RuntimeWarning)
        integrated_img = np.nanmean(stack, axis = 0)
    # Coverage mask is computed once for all channels: (n, H, W, C) -> (H, W)
    uncovered = np.isnan(stack).any(axis = 0).all(axis = -1)
    integrated_img[uncovered] = np.nan
    return integrated_img

In [None]:
def crop_image(img):
    '''
    Return the fixed region rows 370..684 and columns 140..839 of img (all channels).
    The region is hardcoded so that no NaN values occur in any of the integrated images.
    The result is a slice of img, not a copy.
    '''
    row_window = slice(370, 685)
    col_window = slice(140, 840)
    return img[row_window, col_window, :]

In [None]:
def label_image(img, label, shift = False):
    '''
    Draw the provided label as a closed red rectangle outline (1 pixel wide) onto img.

    img: image of shape (H, W, 3) with values in [0, 1]; it is modified in place and
        also returned
    label: sequence [x, y, w, h]; x is the column and y the row of the top-left corner
        in the coordinates of the original (uncropped) image
    shift: set True for cropped images and False for images in original size; True
        subtracts the crop origin used by crop_image (column 140, row 370)

    The outline covers rows y..y+h+1 and columns x..x+w+1 (after shifting), i.e. it
    encloses the w x h pixels starting at column x+1, row y+1. Parts of the outline that
    fall outside the image are skipped instead of wrapping around or raising IndexError.
    '''
    if shift:
        x_shift, y_shift = 140, 370
    else:
        x_shift, y_shift = 0, 0
    x = int(label[0]) - x_shift
    y = int(label[1]) - y_shift
    w = int(label[2])
    h = int(label[3])
    red = [1, 0, 0]

    height, width = img.shape[:2]
    top, bottom = y, y + h + 1
    left, right = x, x + w + 1
    # Edge spans clipped to the image; an empty slice simply draws nothing
    cols = slice(max(left, 0), min(right, width - 1) + 1)
    rows = slice(max(top, 0), min(bottom, height - 1) + 1)
    for row in (top, bottom):
        if 0 <= row < height:
            img[row, cols] = red
    for col in (left, right):
        if 0 <= col < width:
            img[rows, col] = red
    return img

In [None]:
def create_all_images(dataset, save = False, plot = True, label = False):
    '''
    Create integrated and cropped images for every sequence folder of a dataset split, as
    well as an image showing the per-pixel variance of the cropped images across the 7
    timesteps and an image depicting only pixels with the highest 0.1% of variance
    (var_threshold).

    dataset: 'train', 'test' or 'validation'; sequences are read from the subfolders of
        data_WiSAR/data/<dataset> whose names start with <dataset> or 'valid'
    save: set to True if images should be saved in subfolders of 'data_WiSAR/data'
        (<dataset>_processed_images[_labelled]/<folder>)
    plot: set to True if images should be plotted in the jupyter notebook
    label: set to True if images should be labelled given that labels are provided in
        data_WiSAR/data/<dataset>/labels.json - labelled images get saved in a separate
        folder. If that file cannot be opened or parsed, a message is printed and the
        images are created without labels.
    '''
    data_dir = os.path.join('data_WiSAR', 'data')
    mask = skimage.io.imread(os.path.join(data_dir, 'mask.png'))
    timepoints = range(7)
    cameras = ('B01', 'B02', 'B03', 'B04', 'B05', 'G01', 'G02', 'G03', 'G04', 'G05')

    if label:
        labels_path = os.path.join(data_dir, dataset, 'labels.json')
        try:
            with open(labels_path) as f:
                labels = json.load(f)
        except (OSError, ValueError) as err:
            # Missing or malformed labels: fall back to unlabelled output, but say so
            print('Could not load labels from {} ({}); creating unlabelled images'.format(labels_path, err))
            label = False
    labelled_flag = '_labelled' if label else ''
    output_dir = os.path.join(data_dir, dataset + '_processed_images' + labelled_flag)
    if save:
        os.makedirs(output_dir, exist_ok = True)

    # sorted() gives a deterministic processing (and plotting) order
    for folder in sorted(os.listdir(os.path.join(data_dir, dataset))):
        if not (folder.startswith(dataset) or folder.startswith('valid')):
            continue
        sequence_dir = os.path.join(data_dir, dataset, folder)
        if save:
            save_folder = os.path.join(output_dir, folder)
            os.makedirs(save_folder, exist_ok = True)
        with open(os.path.join(sequence_dir, 'homographies.json')) as f:
            homographies = json.load(f)
        if label:
            label_coords = labels[folder][0]

        all_processed_images = []
        cropped_images = []
        for timepoint in timepoints:
            processed_images = []
            for camera in cameras:
                key = str(timepoint) + '-' + camera
                image = skimage.io.imread(os.path.join(sequence_dir, key + '.png'))
                homography = np.array(homographies[key])
                processed_images.append(preprocess_image(image, mask, homography))
            all_processed_images.extend(processed_images)
            integrated_image = integrate_images(processed_images)
            cropped_image = crop_image(integrated_image)
            # crop_image returns a slice (a view into integrated_image), so labels are drawn
            # on copies only; otherwise the red rectangle would leak into cropped_images and
            # from there into the variance images
            cropped_images.append(cropped_image)
            if label:
                integrated_image = label_image(integrated_image.copy(), label = label_coords)
                cropped_image = label_image(cropped_image.copy(), label = label_coords, shift = True)

            if save:
                matplotlib.image.imsave(os.path.join(save_folder, 'integrated_image_' + str(timepoint) + labelled_flag + '.png'), integrated_image)
                matplotlib.image.imsave(os.path.join(save_folder, 'cropped_image_' + str(timepoint) + labelled_flag + '.png'), cropped_image)

            if plot:
                plt.figure()
                plt.imshow(integrated_image)
                plt.figure()
                plt.imshow(cropped_image)

        integrated_image_all = integrate_images(all_processed_images)
        cropped_image_all = crop_image(integrated_image_all)
        var_raw = np.nanvar(cropped_images, axis = 0)
        # nanmax: a stray NaN inside the crop must not turn the whole normalised image into NaN
        var_image = var_raw / np.nanmax(var_raw)
        var_threshold_image = np.where(var_image > np.nanpercentile(var_image, 99.9), 1.0, 0.0)

        if label:
            integrated_image_all = label_image(integrated_image_all.copy(), label = label_coords)
            cropped_image_all = label_image(cropped_image_all.copy(), label = label_coords, shift = True)
            var_image = label_image(var_image, label = label_coords, shift = True)
            var_threshold_image = label_image(var_threshold_image, label = label_coords, shift = True)

        # (file name stem, image) pairs, in save/plot order
        summary_images = (
            ('integrated_image_all', integrated_image_all),
            ('cropped_image_all', cropped_image_all),
            ('image_var', var_image),
            ('image_var_threshold', var_threshold_image),
        )
        if save:
            for name, summary_image in summary_images:
                matplotlib.image.imsave(os.path.join(save_folder, name + labelled_flag + '.png'), summary_image)
        if plot:
            for _, summary_image in summary_images:
                plt.figure()
                plt.imshow(summary_image)

In [None]:
# create and save all images: an unlabelled pass for every split, then a labelled pass for validation
processing_runs = [
    ('train', False),
    ('test', False),
    ('validation', False),
    ('validation', True),
]
for split_name, with_labels in processing_runs:
    create_all_images(split_name, save = True, plot = False, label = with_labels)