In [1]:
"""
## author: Kevin Sanchez
## contact: kevin.j.sanchez@nasa.gov
## project: Satillite_Contrail_Unet
## status: Complete, but needs organizing.

Description: 
        Plot a single granule (chosen by its index, given by the Image_Num parameter). Click points on the plot to form 
    a polygon, then use a key event (see Key events below) to set every pixel inside the polygon either to a positive 
    label (contrail) or a negative label (no contrail). The mouse scroll wheel zooms in and out of the image. Holding 
    the right mouse button and dragging the mouse pans the image until the button is released. Key events are also 
    used to close the figure and save the updated mask, undo the last click, and undo the last labeling action (i.e. 
    using a polygon to add or remove contrail labels). The Query_Overwrite parameter controls whether an existing 
    updated mask is loaded for further editing or the original mask is loaded (and the updated mask overwritten).


    Parameters
    ----------
    Data_Dir_Path: str
        Path to directory with Granules.
    Image_Num: int
        Granule index to skip to from derived list of granules.
    scale: float
        Scales how fast the scroll zoom will zoom in and out.
    Query_Overwrite: bool
        If False, load the existing updated mask for further editing; if True, load the original mask and overwrite the updated mask when saving.
        
        
    Click events
    ------------
    Mouse left click:
        A mouse left click on the plotted image (mask or satellite image) will add a point to a list of points that 
        are used to form a polygon. The point will appear on the image and lines will be drawn between each point in order. 
        Note: mouse clicks must be performed in order around the polygon (i.e. in clockwise or counter-clockwise order).
    Mouse right click:
        Holding the right mouse button and moving the mouse will drag (pan) the image until the button is released.
        
    Scroll events
    -------------
    Mouse scroll:
        Scrolling the mouse will zoom in and out of the image.
    
    Key events
    ----------
    ! (shift+1): 
        Remove the last point added with a click event.
    a:
        Add positive labels within the polygon (i.e. label the points inside the polygon as contrail on the mask).
        Requires at least three points.
    d:
        Delete positive labels within the polygon (i.e. label the points inside the polygon as not a contrail on the mask).
        Requires at least three points.
    u:
        Undo the last label action (the 'a' or 'd' key events). The points from the click events for that polygon are replotted.
    `:
        Toggle between the satellite image and the mask. The polygon from click events will remain.
    C (shift+c):
        Reset the view to the full granule, save the updated mask as a png ('_finalMask.png') and as a 
        '.contrail-maskUpdate' data file, remove the temporary undo files, and close the figure.
        When rerunning the code, the Query_Overwrite parameter selects whether the existing updated mask 
        or the original mask is loaded.
            
        
"""
import matplotlib.pyplot as plt
from matplotlib.path import Path
from matplotlib import pyplot as plt
%matplotlib tk
import numpy as np
import struct
import os
import glob
import sys

#Parameters
scale = 1.5  # mouse-scroll zoom factor applied per wheel step (larger = faster zoom)
Image_Num = 0  # index of the granule to load from the list of granules found under Data_Dir_Path
Query_Overwrite = False  # False = open updated mask to edit, True = open original mask to edit and overwrite updated mask
# Root directory searched (recursively) for granules. Set the CONTRAIL_DATA_DIR environment
# variable to point the notebook at another machine's data without editing this cell.
Data_Dir_Path = os.environ.get('CONTRAIL_DATA_DIR', 'C:/Users/kjsanche/Desktop/contrail_practice/')
#Data_Dir_Path = Data_Dir_Path + '2018MYD/092/A2018092.0820/' # load specific granule (set Image_Num = 0)

C:/Users/kjsanche/Desktop/contrail_practice/


## Obtain satellite image and mask paths

In [2]:
# Locate every granule: each granule directory contains a band-01 '*01__1km.AUX' header file.
# glob returns files in filesystem order, which is not guaranteed, so sort the list to make
# Image_Num select the same granule on every machine.
auxList = sorted(glob.glob(os.path.join(Data_Dir_Path, '**', '*01__1km.AUX'), recursive=True))
if not auxList:
    raise FileNotFoundError('No *01__1km.AUX files found under ' + Data_Dir_Path)
if not -len(auxList) <= Image_Num < len(auxList):
    raise IndexError('Image_Num = {} is out of range: {} granules found'.format(Image_Num, len(auxList)))
granulePath = os.path.normpath(os.path.dirname(auxList[Image_Num]))  # parent directory of the AUX file

# 11 micron (31__) and 12 micron (32__) images, plus the SZ raw file (currently unused)
satImage11 = glob.glob(os.path.join(granulePath, '31__1km.raw'))
satImage12 = glob.glob(os.path.join(granulePath, '32__1km.raw'))
#satImage11_12 = glob.glob(os.path.join(granulePath, '31-32_1km.raw'))
SZ = glob.glob(os.path.join(granulePath, 'SZ__1km.raw'))
if not satImage11 or not satImage12:
    raise FileNotFoundError('Missing 31__1km.raw or 32__1km.raw in ' + granulePath)

# Contrail mask for this granule; '_sw' files are excluded (checked on the file name only,
# so a directory name containing '_sw' cannot hide every mask)
maskFile = sorted(f for f in glob.glob(os.path.join(granulePath, '*.contrail-mask'))
                  if '_sw' not in os.path.basename(f))
if not maskFile:
    raise FileNotFoundError('No .contrail-mask file (excluding _sw files) found in ' + granulePath)

# Common prefix for the PNGs written by this notebook: the mask path without its extension
saveImgPath = os.path.splitext(maskFile[0])[0]

cntWrites = 0  # number of undo checkpoint files ('<mask>Update<n>') written so far

print('Granule Path:' + saveImgPath)


Granule Path:C:\Users\kjsanche\Desktop\contrail_practice\A2018091.0505\MOD021KM-A2018091.0505


## Functions to read data

In [4]:
def readContrailMask(path, imgDim):
    """Read a binary contrail mask (one unsigned byte per pixel) into a 2-D array.

    Parameters
    ----------
    path : str
        Path to the '.contrail-mask' file.
    imgDim : sequence of int
        [n_columns, n_rows] as returned by getImageDimentions; the file must hold
        exactly n_columns * n_rows bytes.

    Returns
    -------
    numpy.ndarray
        uint16 array of shape (n_rows, n_columns).

    Raises
    ------
    ValueError
        If the file size does not match imgDim.
    """
    with open(path, mode='rb') as file:  # binary
        raw = file.read()
    # Decode the bytes directly; struct.unpack with a multi-million-character format
    # string builds a huge Python tuple first and is far slower.
    pixels = np.frombuffer(raw, dtype=np.uint8)
    expected = imgDim[0] * imgDim[1]
    if pixels.size != expected:
        raise ValueError('{}: expected {} bytes for image {}x{}, found {}'.format(
            path, expected, imgDim[0], imgDim[1], pixels.size))
    return pixels.reshape(imgDim[1], imgDim[0]).astype(np.uint16)


def readSatelliteImage(path, imgDim):
    """Read a raw satellite band (native-endian uint16 per pixel) into a 2-D array.

    Parameters
    ----------
    path : str
        Path to the '.raw' band file.
    imgDim : sequence of int
        [n_columns, n_rows] as returned by getImageDimentions; the file must hold
        exactly n_columns * n_rows 16-bit values.

    Returns
    -------
    numpy.ndarray
        uint32 array of shape (n_rows, n_columns).

    Raises
    ------
    ValueError
        If the file size does not match imgDim.
    """
    with open(path, mode='rb') as file:  # binary
        raw = file.read()
    # np.uint16 uses native byte order, matching struct's default '@H' format used previously.
    values = np.frombuffer(raw, dtype=np.uint16)
    expected = imgDim[0] * imgDim[1]
    if values.size != expected:
        raise ValueError('{}: expected {} values for image {}x{}, found {}'.format(
            path, expected, imgDim[0], imgDim[1], values.size))
    return values.reshape(imgDim[1], imgDim[0]).astype(np.uint32)


def getImageDimentions(path, verbose=True):
    """Read the raw image dimensions from a granule AUX header file.

    Scans the file for the first line containing "RawDefinition" and returns its
    first two numbers, e.g. "RawDefinition: 2462 2104 1" gives [2462, 2104].
    The readers in this notebook reshape data to (dim[1], dim[0]), i.e. the
    result is used as [n_columns, n_rows].

    Parameters
    ----------
    path : str
        Path to the AUX header file.
    verbose : bool, optional
        Print every header line read, up to and including the RawDefinition
        line (default True, the original behaviour).

    Returns
    -------
    list of int
        The two integers following "RawDefinition".

    Raises
    ------
    ValueError
        If the file has no "RawDefinition" line.
    """
    with open(path, "r") as f:
        for line in f:
            if verbose:
                print(line)
            if "RawDefinition" in line:
                fields = line.split()
                return [int(fields[1]), int(fields[2])]
    # Previously this fell through to an UnboundLocalError on `dim`.
    raise ValueError('No "RawDefinition" line found in {}'.format(path))


## Run functions to load data

In [5]:
imgDim = getImageDimentions(auxList[Image_Num])

# Choose which mask to edit. Query_Overwrite=False continues from the previously saved
# '<mask>Update' file. If that file does not exist yet (first time this granule is
# opened), fall back to the original mask instead of failing with FileNotFoundError.
updatedMaskPath = maskFile[0] + 'Update'
if not Query_Overwrite and os.path.isfile(updatedMaskPath):
    with open(updatedMaskPath, 'rb') as f:
        mask = np.load(f)
else:
    if not Query_Overwrite:
        print('No updated mask found; loading the original mask: ' + maskFile[0])
    mask = readContrailMask(maskFile[0], imgDim)

# Read the corresponding satellite image files and take the difference between the 11 and 12 micron images
satImg11 = readSatelliteImage(satImage11[0], imgDim)
satImg12 = readSatelliteImage(satImage12[0], imgDim)
#SZangle = readSatelliteImage(SZ[0], imgDim)
# Stored values are divided by 100 before differencing
satImg = satImg11.astype(float) / 100 - satImg12.astype(float) / 100
# Differences above 10 are zeroed (presumably invalid/fill pixels -- TODO confirm)
satImg[satImg > 10] = 0

#print('SZ angle =', SZangle.mean()/100)


AuxilaryTarget: SZ__1km.raw

RawDefinition: 2462 2104 1



## Create click, key, scroll events to label dataset

In [6]:
class ZoomPanClick:
    """Interactive zoom, pan, polygon drawing and mask labelling for a matplotlib Axes.

    zoom_factory connects scroll-wheel zooming. pan_click_factory connects
    right-button panning, left-click polygon vertices and the labelling keys.
    The handlers created by pan_click_factory read and write notebook-level
    globals (clickpt, lines, old_clicks, mask, maskImg, satImgax, imgDim,
    maskFile, saveImgPath, cntWrites), which must exist before events fire.
    """

    def __init__(self):
        self.press = None     # set while a right-button pan is in progress, None otherwise
        self.cur_xlim = None  # x limits (array) shifted during a pan
        self.cur_ylim = None  # y limits (array) shifted during a pan
        self.x0 = None
        self.y0 = None
        self.x1 = None
        self.y1 = None
        self.xpress = None    # data x coordinate where the pan started
        self.ypress = None    # data y coordinate where the pan started

    def zoom_factory(self, ax, base_scale = 1.5):
        """Zoom ``ax`` in/out around the cursor on mouse scroll.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes to zoom.
        base_scale : float
            Factor by which the view width/height changes per scroll step.

        Returns
        -------
        callable
            The scroll handler that was connected to the figure canvas.
        """
        def zoom(event):
            # Scrolling outside the axes gives no data coordinates to zoom around.
            if event.inaxes != ax or event.xdata is None or event.ydata is None:
                return
            cur_xlim = ax.get_xlim()
            cur_ylim = ax.get_ylim()

            xdata = event.xdata  # cursor x location in data coordinates
            ydata = event.ydata  # cursor y location in data coordinates

            if event.button == 'up':
                scale_factor = 1 / base_scale  # zoom in
            elif event.button == 'down':
                scale_factor = base_scale      # zoom out
            else:
                # deal with something that should never happen
                scale_factor = 1
                print(event.button)

            new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
            new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

            # Keep the point under the cursor at the same relative position in the view.
            relx = (cur_xlim[1] - xdata) / (cur_xlim[1] - cur_xlim[0])
            rely = (cur_ylim[1] - ydata) / (cur_ylim[1] - cur_ylim[0])

            ax.set_xlim([xdata - new_width * (1 - relx), xdata + new_width * relx])
            ax.set_ylim([ydata - new_height * (1 - rely), ydata + new_height * rely])
            ax.figure.canvas.draw()

        fig = ax.get_figure()  # get the figure of interest
        fig.canvas.mpl_connect('scroll_event', zoom)

        return zoom

    def pan_click_factory(self, ax):
        """Connect panning, polygon clicks and labelling keys to ``ax``'s figure.

        Mouse: right-drag pans, left click adds a polygon vertex.
        Keys: '`' toggles satellite image/mask, '!' removes the last vertex,
        'a'/'d' set pixels inside the polygon to 1/0, 'u' undoes the last
        'a'/'d', 'C' saves the final mask and closes the figure.

        Returns
        -------
        callable
            The motion handler (the original return value, kept for compatibility).
        """
        fig = ax.get_figure()  # get the figure of interest

        def redraw_polygon():
            """Replace the drawn polygon outline with one built from the current clickpt."""
            global lines
            # Remove the previous outline only if it is still on the axes; it may already have
            # been removed. Artist.remove() works on all matplotlib versions, unlike ax.lines.remove().
            if lines and lines[0] in ax.lines:
                lines[0].remove()
            if len(clickpt) == 1:
                lines = ax.plot(*zip(*clickpt + [clickpt[0]]), '.r', markersize=1)
            elif len(clickpt) > 1:
                lines = ax.plot(*zip(*clickpt + [clickpt[0]]), 'r', linewidth=0.2)
            else:
                lines = []
            ax.figure.canvas.draw()

        def onPressPan(event):
            """Right button press: remember where the pan started and the current limits."""
            if event.button != 3 or event.inaxes != ax:
                return
            # get_xlim()/get_ylim() return tuples in current matplotlib; convert to arrays so
            # onMotion can shift them (tuple -= float raises TypeError).
            self.cur_xlim = np.array(ax.get_xlim(), dtype=float)
            self.cur_ylim = np.array(ax.get_ylim(), dtype=float)
            self.press = self.x0, self.y0, event.xdata, event.ydata
            self.x0, self.y0, self.xpress, self.ypress = self.press

        def onRelease(event):
            """Right button release: end the pan."""
            if event.button == 3:
                self.press = None
                ax.figure.canvas.draw()

        def onMotion(event):
            """While panning, shift the view so the grabbed point follows the cursor."""
            if self.press is None or event.inaxes != ax:
                return
            dx = event.xdata - self.xpress
            dy = event.ydata - self.ypress
            self.cur_xlim -= dx
            self.cur_ylim -= dy
            ax.set_xlim(self.cur_xlim)
            ax.set_ylim(self.cur_ylim)
            ax.figure.canvas.draw()

        def onPressPoint(event):
            """Left click inside the axes: append a polygon vertex and redraw the outline."""
            if event.button != 1 or event.inaxes != ax:
                return
            if event.xdata is None or event.ydata is None:
                return
            clickpt.append((event.xdata, event.ydata))
            redraw_polygon()

        def add_remove_contrail(add_remove):
            """Set every mask pixel inside the current polygon to ``add_remove`` and reset the polygon."""
            global clickpt, old_clicks
            # Pixel centres outside the polygon's bounding box cannot be inside it, so only
            # test that window instead of every pixel in the granule.
            verts = np.asarray(clickpt, dtype=float)
            x_lo = max(int(np.floor(verts[:, 0].min())), 0)
            x_hi = min(int(np.ceil(verts[:, 0].max())), imgDim[0] - 1)
            y_lo = max(int(np.floor(verts[:, 1].min())), 0)
            y_hi = min(int(np.ceil(verts[:, 1].max())), imgDim[1] - 1)
            if x_lo <= x_hi and y_lo <= y_hi:
                xs, ys = np.meshgrid(np.arange(x_lo, x_hi + 1), np.arange(y_lo, y_hi + 1))
                inside = Path(clickpt).contains_points(np.column_stack((xs.ravel(), ys.ravel())))
                # Boolean assignment through the slice view writes into the global mask.
                mask[y_lo:y_hi + 1, x_lo:x_hi + 1][inside.reshape(xs.shape)] = add_remove

            # update figure
            maskImg.set_data(mask)
            fig.canvas.draw()
            fig.canvas.flush_events()
            old_clicks = list(clickpt)  # copy so later vertex edits cannot alter the undo state
            clickpt = []  # reset click values
            redraw_polygon()

        def commit_label(value):
            """Checkpoint the current mask to disk, then label the polygon interior with ``value``."""
            global cntWrites
            if len(clickpt) <= 2:
                return  # a polygon needs at least three vertices
            cntWrites += 1
            # The numbered file holds the pre-edit mask that 'u' restores; the plain
            # 'Update' file is also rewritten with this pre-edit state.
            for path in (maskFile[0] + 'Update' + str(cntWrites), maskFile[0] + 'Update'):
                with open(path, 'wb') as f:
                    np.save(f, mask)
            add_remove_contrail(value)

        def onClickToggleSave(event):
            """Dispatch the labelling keys (see pan_click_factory docstring)."""
            global mask, cntWrites, clickpt
            key = event.key

            if key == '`':
                # toggle the visible state of the satellite and mask images
                satImgax.set_visible(not satImgax.get_visible())
                maskImg.set_visible(not maskImg.get_visible())
                fig.canvas.draw_idle()

            elif key == 'u':
                # undo last add/remove: reload the checkpoint written before it
                if cntWrites > 0:
                    with open(maskFile[0] + 'Update' + str(cntWrites), 'rb') as f:
                        mask = np.load(f)
                    cntWrites -= 1
                    maskImg.set_data(mask)
                    fig.canvas.draw()
                    fig.canvas.flush_events()
                # Restore the last committed polygon; old_clicks does not exist until the first 'a'/'d'.
                previous = globals().get('old_clicks', [])
                if previous:
                    clickpt = list(previous)
                    redraw_polygon()

            elif key == '!':
                # undo last click point
                if clickpt:
                    clickpt.pop()
                    redraw_polygon()

            elif key == 'a':
                commit_label(1)  # label polygon interior as contrail

            elif key == 'd':
                commit_label(0)  # label polygon interior as no contrail

            elif key == 'C':
                # Save the final mask data first, then a full-granule PNG showing only the mask
                # (the satellite image is drawn on top, so it must be hidden), then clean up.
                with open(maskFile[0] + 'Update', 'wb') as f:
                    np.save(f, mask)
                ax.axis([0, imgDim[0], imgDim[1], 0])
                maskImg.set_visible(True)
                satImgax.set_visible(False)
                fig.canvas.draw()
                fig.canvas.flush_events()
                fig.savefig(saveImgPath + '_finalMask.png')
                plt.close(fig)
                # remove the per-edit undo checkpoints
                while cntWrites > 0:
                    os.remove(maskFile[0] + 'Update' + str(cntWrites))
                    cntWrites -= 1

        # attach the call backs
        fig.canvas.mpl_connect('button_press_event', onPressPan)
        fig.canvas.mpl_connect('button_release_event', onRelease)
        fig.canvas.mpl_connect('motion_notify_event', onMotion)
        fig.canvas.mpl_connect('button_press_event', onPressPoint)
        fig.canvas.mpl_connect('key_press_event', onClickToggleSave)

        #return the function
        return onMotion

## Plot the image and mask and attach the event-handling class (ZoomPanClick()) to label data

In [7]:

# Initialize the polygon state read and written by the ZoomPanClick event handlers
clickpt = []     # (x, y) vertices of the polygon being drawn
lines = []       # Line2D list returned by the most recent polygon plot
old_clicks = []  # vertices of the last committed polygon, restored by the 'u' key

fig = plt.figure()
ax = plt.axes([.04, .04, .95, .95])
zp = ZoomPanClick()
figZoom = zp.zoom_factory(ax, base_scale=scale)
figPan = zp.pan_click_factory(ax)

# Plot the mask first and the satellite image on top of it; the '`' key toggles between them
maskImg = ax.imshow(mask, cmap='gray')
satImgax = ax.imshow(satImg, cmap='gray')
manager = plt.get_current_fig_manager()
manager.resize(*manager.window.maxsize())  # maximise the window (Tk backend)
plt.show()

# Save the satellite view (satellite image currently on top) if not already created
if not os.path.isfile(saveImgPath + '_visImg.png'):
    fig.savefig(saveImgPath + '_visImg.png')

# Hide the satellite image so the mask is shown for editing. This must happen before saving
# the mask snapshot, otherwise the satellite image drawn on top would be saved instead.
satImgax.set_visible(False)

# Save the starting mask view if not already created
if not os.path.isfile(saveImgPath + '_origMask.png'):
    fig.savefig(saveImgPath + '_origMask.png')

fig.canvas.draw_idle()

