In [47]:
""" Assignment 1 - Nagaraj Raparthi
""" 

# Import required libraries
import numpy as np
import time
from matplotlib import pyplot as plt
from skimage.transform import resize

# Function to retrieve r, g, b planes from Prokudin-Gorskii glass plate images
def read_strip(path):
    """Load a glass-plate scan and split it into its three colour strips.

    The file is read with ``plt.imread`` and cut along axis 0 into three
    strips of equal height: the first is returned as blue, the second as
    green and the third as red. Rows left over when the image height is not
    a multiple of three are dropped.

    The pixel values are divided by the full-scale value of the image's
    storage type so that integer scans end up in the [0, 1] range.

    Parameters
    ----------
    path : str
        Location of the image file.

    Returns
    -------
    tuple
        ``(r, g, b)`` channel arrays, in that order.
    """
    image = plt.imread(path)
    height = image.shape[0] // 3

    # Choose the scale from the storage type rather than from the brightest
    # pixel: a dark 16-bit scan whose maximum happens to be <= 255 must still
    # be divided by 65535, and a float image that is already in [0, 1] must
    # not be divided by 255 a second time.
    if image.dtype == np.uint8:
        scalingFactor = 255
    elif image.dtype == np.uint16:
        scalingFactor = 65535
    elif np.issubdtype(image.dtype, np.floating) and np.max(image) <= 1.0:
        scalingFactor = 1
    else:
        # Unknown storage type: fall back to the value-based guess.
        scalingFactor = 255 if (np.max(image) <= 255) else 65535

    # Separating the glass image into R, G, and B channels
    b = image[:height, :] / scalingFactor
    g = image[height: 2 * height, :] / scalingFactor
    r = image[2 * height: 3 * height, :] / scalingFactor
    return r, g, b


# circshift implementation similar to matlab
def circ_shift(channel, shift):
    """Circularly shift an array along its first two axes (like MATLAB's circshift).

    ``shift[0]`` is the displacement along axis 0 and ``shift[1]`` the
    displacement along axis 1; elements pushed past one edge re-enter at the
    opposite edge. A new array is returned, the input is left untouched.
    """
    # One np.roll call handles both axes when given matching tuples.
    return np.roll(channel, (shift[0], shift[1]), axis=(0, 1))


# The main part of the code. Implement the FindShift function
def find_shift(im1, im2, xLow, xHigh, yLow, yHigh):
    """Exhaustively search for the circular shift of im1 that best matches im2.

    Every displacement whose first component lies in ``range(xLow, xHigh)``
    and whose second component lies in ``range(yLow, yHigh)`` (upper bounds
    excluded) is applied to ``im1`` with ``circ_shift``. Each candidate is
    scored by the sum of squared differences against ``im2``. The displacement
    with the lowest score is printed and returned; on a tie the one visited
    first wins.

    Raises ValueError (from ``min``) when the search window is empty.
    """
    # All candidate displacements, in the same row-major order as a nested loop.
    candidates = [
        (first, second)
        for first in range(xLow, xHigh)
        for second in range(yLow, yHigh)
    ]

    # Sum of squared differences for each candidate.
    errors = [
        np.sum((im2 - circ_shift(im1, offset)) ** 2)
        for offset in candidates
    ]

    best = candidates[errors.index(min(errors))]
    print(best)
    return best

#image pyramid to descale the image
def image_pyramid(Red,Green,Blue,scaleFactor):
    """Return the three channels downscaled by an integer factor.

    Each channel is resampled with ``resize`` to
    ``(rows // scaleFactor, cols // scaleFactor)`` computed from that
    channel's own shape.

    Parameters
    ----------
    Red, Green, Blue : array
        The channel images to shrink.
    scaleFactor : int
        Divisor applied to both image dimensions.

    Returns
    -------
    tuple
        The downscaled ``(Red, Green, Blue)`` channels, in that order.
    """
    def _shrink(channel):
        # Integer division keeps the target shape a tuple of ints.
        return resize(
            channel,
            (channel.shape[0] // scaleFactor, channel.shape[1] // scaleFactor),
        )

    # Operate on the arguments; the previous body read the module-level
    # r/g/b variables and silently ignored what the caller passed in.
    return _shrink(Red), _shrink(Green), _shrink(Blue)
            

def _refinement_bounds(coarseShift, radius=2):
    """Return find_shift bounds centred on a shift found at half resolution.

    The coarse shift is doubled to express it at the next (twice as large)
    pyramid level. The returned ``(xLow, xHigh, yLow, yHigh)`` covers
    ``-radius .. +radius`` around it inclusively. The ``+ 1`` is needed
    because find_shift treats the upper bounds as exclusive.
    """
    centerX = coarseShift[0] * 2
    centerY = coarseShift[1] * 2
    return (centerX - radius, centerX + radius + 1,
            centerY - radius, centerY + radius + 1)


if __name__ == '__main__':
    # Setting the input output file path
    imageDir = '../Images/'
    imageName = 'cathedral.jpg'
    outDir = '../Results/'

    # Split on the LAST dot so names containing several dots, or extensions
    # that are not three characters long, are still handled correctly.
    baseName, separator, extension = imageName.rpartition('.')
    extension = extension.lower()
    if not baseName or extension not in ('jpg', 'jpeg', 'tif', 'tiff'):
        # Fail early with a clear message instead of a NameError further down.
        raise ValueError('Unsupported image file: ' + imageName)

    # Get r, g, b channels from image strip
    r, g, b = read_strip(imageDir + imageName)

    startTime = time.time()

    if extension in ('jpg', 'jpeg'):
        print('loading JPG image')
        print('Name of the image: ' + imageName)

        # Low-resolution scan: one exhaustive search over -20..20 on each axis
        print('R to G shift:')
        rShift = find_shift(r, g, -20, 21, -20, 21)
        print('B to G shift:')
        bShift = find_shift(b, g, -20, 21, -20, 21)

    else:
        print('loading TIF image')
        print('Name of the image: ' + imageName)

        print('Initial scale of HighRes image:',r.shape[0],'*',r.shape[1])

        # Coarse-to-fine search: the smallest level gets the exhaustive
        # -20..20 window, every later level only refines around twice the
        # shift found one level below.
        rBounds = bBounds = (-20, 21, -20, 21)
        for scaleFactor in (8, 4, 2):
            # Calculate Pyramid
            rSmall, gSmall, bSmall = image_pyramid(r, g, b, scaleFactor)
            print('Scale down by ' + str(scaleFactor) + ':',rSmall.shape[0],'*',rSmall.shape[1])

            # Calculate the shift at this level
            print('R to G shift by scaleFactor ' + str(scaleFactor) + ':')
            rShift = find_shift(rSmall, gSmall, *rBounds)
            print('B to G shift by scaleFactor ' + str(scaleFactor) + ':')
            bShift = find_shift(bSmall, gSmall, *bBounds)

            rBounds = _refinement_bounds(rShift)
            bBounds = _refinement_bounds(bShift)

        # Top of the pyramid
        print('Back to original image resolution:',r.shape[0],'*',r.shape[1])

        # Calculate last shift
        print('R to G shift:')
        rShift = find_shift(r, g, *rBounds)
        print('B to G shift:')
        bShift = find_shift(b, g, *bBounds)

    # Shifting the images using the obtained shift values
    finalB = circ_shift(b, bShift)
    finalG = g
    finalR = circ_shift(r, rShift)

    endTime = time.time()

    print('Time taken in seconds :',endTime-startTime)

    # Putting together the aligned channels to form the color image
    finalImage = np.stack((finalR, finalG, finalB), axis = 2)

    print(finalImage[0])

    # Writing the image to the Results folder
    plt.imsave(outDir + baseName + '.jpg', finalImage)
    
    

loading JPG image
Name of the image: cathedral.jpg
R to G shift:
(7, 0)
B to G shift:
(-1, 1)
Time taken in seconds : 6.243360996246338
[[0.99607843 0.99607843 0.99607843]
 [0.99607843 0.99607843 0.99215686]
 [0.99607843 0.99215686 0.99215686]
 ...
 [0.99607843 0.99607843 0.99607843]
 [0.99607843 0.99607843 0.99607843]
 [0.99607843 0.99607843 0.99607843]]
