In [13]:
from skimage.io import imread
import matplotlib.pyplot as plt 
import numpy as np
from os import listdir
from torchvision.transforms.functional import crop
import torch
from PIL import Image
import seaborn as sns

In [49]:
def stitch_together(dir_name, row_count):
    """
    Stitch together the predictions made by the model into a single image.

    Tiles are read in sorted filename order and laid out without overlap:
    every ``row_count`` consecutive tiles are placed side by side
    (``np.hstack``), and the resulting rows are stacked top to bottom
    (``np.vstack``).

    Args:
        dir_name: directory containing all of the images (a trailing path
            separator is optional)
        row_count: number of images in a row

    Returns:
        stitched_img: image where everything is stitched together

    Raises:
        ValueError: if ``row_count`` is not positive, or if the directory does
            not contain at least one complete row of images.
    """
    import os

    if row_count <= 0:
        raise ValueError("row_count must be positive, got {}".format(row_count))

    rows = []
    cur_row = []
    for img_name in sorted(listdir(dir_name)):
        # os.path.join works whether or not dir_name ends with a separator.
        cur_row.append(imread(os.path.join(dir_name, img_name)))
        # Use the caller's row_count (previously a hard-coded 10 was used).
        if len(cur_row) == row_count:
            rows.append(np.hstack(cur_row))
            cur_row = []
    # Trailing images that do not fill a complete row are ignored.
    if not rows:
        raise ValueError(
            "no complete row of {} images found in {!r}".format(row_count, dir_name)
        )
    return np.vstack(rows)

def generate_overlapping_images(img, window_size, step_size, out_dir="overlap"):
    """
    Generate crops of size window_size by window_size from an image and save them as PNGs.

    Crops overlap when ``step_size < window_size``. They are taken in
    row-major order: all horizontal offsets for the first vertical offset,
    then the next vertical offset, and so on. Each crop is written to
    ``out_dir/Stack_NNNN.png`` with a zero-padded running index, so sorting
    the filenames reproduces the generation order.

    Args:
        img: image as a numpy array with three dimensions
            (rows, columns, channels)
        window_size: size in pixels of the square crops that are generated
        step_size: step size in pixels for the window, applied both
            vertically and horizontally
        out_dir: directory the crops are written to; it is created if
            missing. Defaults to "overlap".

    Returns:
        list of str: paths of the saved crops, in generation order. The list
        is empty if the image is smaller than the window.

    Raises:
        ValueError: if ``step_size`` is not positive (the window would never
            advance).
    """
    import os

    if step_size <= 0:
        raise ValueError("step_size must be positive, got {}".format(step_size))

    m, n, _ = img.shape
    pil_img = Image.fromarray(img)
    max_x = m - window_size
    max_y = n - window_size
    print("Max_x is {}".format(max_x))
    print("Max_y is {}".format(max_y))

    os.makedirs(out_dir, exist_ok=True)
    saved_paths = []
    for x in range(0, max_x + 1, step_size):
        for y in range(0, max_y + 1, step_size):
            # torchvision's crop signature is (img, top, left, height, width):
            # x walks the rows and y walks the columns.
            crop_img = crop(pil_img, x, y, window_size, window_size)
            file_name = "Stack_{}.png".format(str(len(saved_paths)).zfill(4))
            path = os.path.join(out_dir, file_name)
            crop_img.save(path)
            saved_paths.append(path)
    return saved_paths
            
def avg_blend(img1, img2):
    """
    Blend two images by averaging them pixel-wise.

    Both inputs are promoted to float64 before adding. Otherwise uint8 inputs
    would wrap around (e.g. 200 + 100 overflows to 44 in uint8).

    Args:
        img1: first image (array-like); must broadcast against ``img2``
        img2: second image (array-like)

    Returns:
        np.ndarray of dtype uint8 holding the mean of the two inputs,
        truncated toward zero by the cast
    """
    total = np.asarray(img1, dtype=np.float64) + np.asarray(img2, dtype=np.float64)
    return (total / 2).astype(np.uint8)

def merge_horiz(left_img, right_img, step_size):
    """
    Merge two horizontally adjacent tiles whose facing edges overlap.

    The rightmost ``step_size`` columns of ``left_img`` and the leftmost
    ``step_size`` columns of ``right_img`` cover the same region. They are
    averaged with ``avg_blend``; the non-overlapping parts are kept as-is.

    NOTE(review): ``step_size`` acts as the overlap width here. That equals
    the stride used to generate the tiles only when
    window_size == 2 * step_size. Confirm for other configurations.

    Args:
        left_img: the image on the left, 3-D array (rows, cols, channels)
        right_img: the image on the right, same number of rows as left_img
        step_size: width in pixels of the overlapping strip

    Returns:
        combined_img: image of width ``W_left + W_right - step_size`` where
            the overlapping region is merged using the blending function

    Raises:
        ValueError: if the images have different heights or ``step_size`` is
            negative or wider than either image.
    """
    height, left_w, _ = left_img.shape
    right_h, right_w, _ = right_img.shape
    if height != right_h:
        raise ValueError(
            "images must have the same height, got {} and {}".format(height, right_h)
        )
    if not 0 <= step_size <= min(left_w, right_w):
        raise ValueError(
            "step_size {} out of range for widths {} and {}".format(step_size, left_w, right_w)
        )

    left = left_img[:, :left_w - step_size, :]
    overlap_left = left_img[:, left_w - step_size:, :].astype('float')
    # The overlapping strip of the right tile is its *first* step_size
    # columns, mirroring how merge_vert slices the bottom tile.
    overlap_right = right_img[:, :step_size, :].astype('float')
    right = right_img[:, step_size:, :]
    blended = avg_blend(overlap_left, overlap_right)
    return np.hstack((left, blended, right))

def merge_vert(top_img, bot_img, step_size):
    """
    Merge two vertically adjacent images whose facing edges overlap.

    The bottom ``step_size`` rows of ``top_img`` and the top ``step_size``
    rows of ``bot_img`` cover the same region. They are averaged with
    ``avg_blend``; the non-overlapping parts are kept as-is.

    Args:
        top_img: the image on top, 3-D array (rows, cols, channels)
        bot_img: the image below, same number of columns as top_img
        step_size: height in pixels of the overlapping strip

    Returns:
        combined_img: image of height ``H_top + H_bot - step_size`` where the
            overlapping region is merged using the blending function

    Raises:
        ValueError: if the images have different widths or ``step_size`` is
            negative or taller than either image.
    """
    top_h, width, _ = top_img.shape
    bot_h, bot_w, _ = bot_img.shape
    if width != bot_w:
        raise ValueError(
            "images must have the same width, got {} and {}".format(width, bot_w)
        )
    if not 0 <= step_size <= min(top_h, bot_h):
        # Without this check a too-large step_size yields negative slice
        # bounds and silently produces a malformed image.
        raise ValueError(
            "step_size {} out of range for heights {} and {}".format(step_size, top_h, bot_h)
        )

    top = top_img[:top_h - step_size, :, :]
    overlap_top = top_img[top_h - step_size:, :, :].astype('float')
    overlap_bot = bot_img[:step_size, :, :].astype('float')
    bot = bot_img[step_size:, :, :]
    blended = avg_blend(overlap_top, overlap_bot)
    return np.vstack((top, blended, bot))

def gen_merged_horiz(img_list, step_size):
    """
    Fold a sequence of tiles left-to-right into one image with merge_horiz.

    Args:
        img_list: non-empty indexable sequence of images, ordered left to right
        step_size: the step size forwarded to merge_horiz for every pair

    Returns:
        img: the merged image; a single-element list returns its element unchanged

    Raises:
        IndexError: if img_list is empty.
    """
    merged = img_list[0]
    for idx in range(1, len(img_list)):
        merged = merge_horiz(merged, img_list[idx], step_size)
    return merged

def gen_merged_vert(img_list, step_size):
    """
    Fold a sequence of images top-to-bottom into one image with merge_vert.

    Args:
        img_list: non-empty indexable sequence of images, ordered top to bottom
        step_size: the step size forwarded to merge_vert for every pair

    Returns:
        img: the merged image; a single-element list returns its element unchanged

    Raises:
        IndexError: if img_list is empty.
    """
    result = img_list[0]
    position = 1
    total = len(img_list)
    while position < total:
        result = merge_vert(result, img_list[position], step_size)
        position += 1
    return result

def stitch_blend(dir_name, row_count, step_size):
    """
    Stitch together overlapping predictions made by the model, blending the overlaps.

    Tiles are read in sorted filename order. Every ``row_count`` consecutive
    tiles are merged left-to-right with ``gen_merged_horiz``. The resulting
    rows are then merged top-to-bottom with ``gen_merged_vert``. In both
    directions the overlapping regions, sized by ``step_size``, are blended.

    Args:
        dir_name: directory containing all of the images (a trailing path
            separator is optional)
        row_count: number of images in a row
        step_size: overlap size in pixels forwarded to the merge helpers

    Returns:
        out: image where everything is stitched together

    Raises:
        ValueError: if ``row_count`` is not positive.
    """
    import os

    if row_count <= 0:
        raise ValueError("row_count must be positive, got {}".format(row_count))

    rows = []
    cur_row = []
    for img_name in sorted(listdir(dir_name)):
        # os.path.join works whether or not dir_name ends with a separator.
        cur_row.append(imread(os.path.join(dir_name, img_name)))
        if len(cur_row) == row_count:
            rows.append(gen_merged_horiz(cur_row, step_size))
            cur_row = []
    # Trailing images that do not fill a complete row are ignored.
    return gen_merged_vert(rows, step_size)
        
        
    
    
    
            
        
    

In [47]:
# Reassemble the tiles in ./real_A/ (10 per row, no overlap) and the
# overlapping tiles in ./overlap/ (19 per row, blended with step_size=256).
real = stitch_together('./real_A/', row_count=10)
ov = stitch_blend('./overlap/', row_count=19, step_size=256)