# This script runs a trained semantic segmenter on HiRISE images using leave-one-region-out cross-validation.

In [None]:
# Folder that every data and model path in this notebook is built from.
base_folder = '../../'
# Folder handed to Load_P4_Data_PreSaved when the validation images are loaded.
presaved_folder = base_folder+'Data/Images/HiRISE_8bit_and_P4_mask/'

# Size of the patches fed to the network. The patch is chosen large enough to
# fit the entire image width, so a single chunk across is sufficient.
PatchHeight = 4096
PatchWidth = 4096
NumChunksAcross = 1
# Passed to AddHeightBorderVal and used again when cropping the prediction.
ExtraPaddingPixels = 128

In [None]:
#select a GPU
# These environment variables are set before tensorflow is imported below.
# NOTE(review): presumably this restricts the process to GPU number 4, with
# devices numbered in PCI bus order - the index is machine specific, adjust it
# for the machine the notebook runs on.
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

#imports
import numpy as np
import pandas as pd
import copy
import time
import PIL
import tensorflow
import glob as glob
print("tensorflow version = ",tensorflow.__version__)

from p4tools import io
from SemanticSegmentation import model_hrnet_P4,GetStats
from P4_DataHandlers import GetMarkingsCentres,AddHeightBorderVal,CreateFullImageMasks,Load_P4_Data_PreSaved


In [None]:
# Meta data table; used later to look up SOLAR_LONGITUDE and START_TIME per image.
metadata_df = io.get_meta_data()

# Centres of the fan and ellipse markings in the P4 catalog (slow: about 80 seconds).
MarkingsCentres_df = GetMarkingsCentres()

# Tile table, and the distinct HiRISE observation ids that appear in it.
Tiles_df = io.get_tile_coordinates()
UniqueHiRiseImages = Tiles_df['obsid'].unique()
print("Number of HiRise Images = ",UniqueHiRiseImages.shape[0])

# Region names keyed by observation id, with three entries set by hand.
region_names_df = io.get_region_names().set_index('obsid')
for manual_obsid, manual_region in (('ESP_012620_0975','Buffalo'),
                                    ('ESP_012277_0975','Buffalo'),
                                    ('ESP_012348_0975','Taichung')):
    region_names_df.at[manual_obsid,'roi_name'] = manual_region

# Per-image results table: a second copy of the meta data keyed by observation
# id, joined column-wise with the region names. Rows containing any missing
# value after the join are dropped.
ImageResults_df = io.get_meta_data().set_index('OBSERVATION_ID')
ImageResults_df = pd.concat([ImageResults_df, region_names_df], axis=1, sort=False).dropna()

# Distinct regions; the cross-validation loop runs once per entry.
UniqueP4Regions = ImageResults_df['roi_name'].unique()
print("Number of P4 regions = ",len(UniqueP4Regions))


In [None]:

# Reload the tile table and store the tile coordinates as integers so they can
# be used directly as array offsets later on.
Tiles_df = io.get_tile_coordinates()
Tiles_df['x_hirise']=Tiles_df['x_hirise'].astype('int')
Tiles_df['y_hirise']=Tiles_df['y_hirise'].astype('int')

# Every tile row belongs to exactly one obsid, so the total number of tiles is
# simply the number of rows (no need to count them image by image).
TotalTiles = Tiles_df.shape[0]

# Copies of the identifying columns placed at the front of the table.
Tiles_df.insert(0,'tile_id0',Tiles_df['tile_id'].values)
Tiles_df.insert(1,'obsid0',Tiles_df['obsid'].values)
# obsid1 holds the last four characters of the obsid; the vectorised string
# slice does not depend on the index labels of the table.
Tiles_df.insert(2,'obsid1',Tiles_df['obsid'].str[-4:].values)

# Placeholder columns, temporarily filled with the obsid. They are overwritten
# per tile when the models are run; tiles that are never evaluated keep the
# placeholder value.
for position, placeholder_col in enumerate(('region','SOLAR_LONGITUDE','START_TIME'), start=3):
    Tiles_df.insert(position,placeholder_col,Tiles_df['obsid'].values)

# Per-tile statistics, initialised to zero.
for position, stat_col in enumerate(('JI','Dice','Precision','Recall','TP','TN','FP','FN'), start=6):
    Tiles_df.insert(position,stat_col,np.zeros(TotalTiles))
Tiles_df.head()

In [None]:
#create placeholder cols to hold model results
# Scalar fill values are broadcast by DataFrame.insert to every row of
# ImageResults_df, so the columns always match the table being extended.
# (The previous np.empty(..., dtype=np.str) fails on NumPy >= 1.24, where the
# np.str alias no longer exists.)
for position, text_col in enumerate(('Region name','Image name')):
    ImageResults_df.insert(position,text_col,'')

# Numeric result columns start at zero and are filled in image by image.
NumericResultCols = ('JI','Dice','Area Ratio','Recall','Precision',
                     'TP','TN','FP','FN','Centre Correct')
for position, stat_col in enumerate(NumericResultCols, start=2):
    ImageResults_df.insert(position,stat_col,0.0)
ImageResults_df

# Run each trained model on its left-out validation region, using the downloaded HiRISE images:


In [None]:
# Leave-one-region-out evaluation: for every region, load the weights of the
# model trained without that region, predict a mask for each of the region's
# images, and record statistics per image (ImageResults_df) and per tile
# (Tiles_df). Both tables are written to CSV at the end.
ModelCount = 0
model = model_hrnet_P4(3,2,16,0)

# Columns filled from GetStats for each image, in the order GetStats returns them.
ImageStatColumns = ['JI','Dice','Area Ratio','Recall','Precision','TP','TN','FP','FN']

# Row label of the first occurrence of every tile id, built once so the
# per-tile lookup below is a dictionary access instead of a scan over the
# whole tile table. Assumes the index and the tile_id0 column of Tiles_df are
# not changed while the loop runs (only other columns are written).
FirstTileIndex = {}
for TileLabel, KnownTileId in zip(Tiles_df.index, Tiles_df['tile_id0'].values):
    FirstTileIndex.setdefault(KnownTileId, TileLabel)

for ModelCount, ToLeaveOut in enumerate(UniqueP4Regions, start=1):

    print('running trained model ',str(ModelCount),' of ',str(len(UniqueP4Regions)),'; using leave out region: ',ToLeaveOut)

    #load the model's weights for this left out region
    model.load_weights(base_folder+'Data/Models/HiRISE_segmenter_leave_out_'+ToLeaveOut+'final.h5')

    #load this region's images
    [ImageList_val,Scales_val,ValImageNames,RegionNames] = Load_P4_Data_PreSaved(presaved_folder,
                                                                      [ToLeaveOut],
                                                                      ImageResults_df, #will be indexed by obs_id
                                                                      Tiles_df) #will be indexed by obs_id

    #run the model on all loaded images and get stats
    for CC, val_im in enumerate(ImageList_val):

        Im = ValImageNames[CC]
        RegionInd = metadata_df.index[metadata_df['OBSERVATION_ID'] == Im].tolist()[0]
        SolarLong = metadata_df.at[RegionInd,'SOLAR_LONGITUDE']
        StartTime = metadata_df.at[RegionInd,'START_TIME']
        print(Im,ToLeaveOut,SolarLong,StartTime)

        ImageResults_df.at[Im,'Region name'] = ToLeaveOut
        ImageResults_df.at[Im,'Image name'] = Im

        #images are too big to feed to neural network in one go, so chunk them
        NumChunksDown = int(np.ceil(val_im.shape[0]/PatchHeight))

        #pad the image for running predictions on: a canvas filled with 128,
        #with the first three channels of the image centred horizontally
        Padded_val_im = 128*np.ones((1,NumChunksDown*PatchHeight,NumChunksAcross*PatchWidth,3),'uint8')
        WidthOffset = int((PatchWidth - val_im.shape[1])/2)
        Padded_val_im[0,0:val_im.shape[0],WidthOffset:WidthOffset+val_im.shape[1],:] = val_im[:,:,0:3]
        # NOTE(review): AddHeightBorderVal is defined elsewhere; presumably it adds
        # ExtraPaddingPixels rows of border to the height - confirm. If it does,
        # the chunk loop below only covers the first NumChunksDown*PatchHeight
        # rows of the bordered image and the remaining rows stay zero in PredMask.
        Padded_val_im = np.expand_dims(AddHeightBorderVal(Padded_val_im[0,:,:,:],ExtraPaddingPixels),0)

        #don't pad the mask (fourth channel of the loaded image)
        val_mask = val_im[:,:,3]

        #run the model chunk by chunk; the predicted class is the argmax over the last axis
        PredMask = np.zeros((Padded_val_im.shape[1],Padded_val_im.shape[2]),'uint8')
        for i in range(NumChunksDown):
            for j in range(NumChunksAcross):
                ChunkRows = slice(i*PatchHeight,(i+1)*PatchHeight)
                ChunkCols = slice(j*PatchWidth,(j+1)*PatchWidth)
                InputIm = Padded_val_im[0:1,ChunkRows,ChunkCols,:]
                PredMask[ChunkRows,ChunkCols] = np.argmax(np.squeeze(model.predict(InputIm)),axis=-1)

        #take the extra height padding and the width padding off the prediction mask
        PredMask = PredMask[ExtraPaddingPixels:ExtraPaddingPixels+val_mask.shape[0],
                            WidthOffset:WidthOffset+val_mask.shape[1]]

        #get stats; .loc is used because several columns are assigned at once
        #(.at is a scalar accessor and recent pandas rejects a list of columns)
        ImageResults_df.loc[Im,ImageStatColumns] = list(GetStats(val_mask,PredMask))

        #get stats on centres: count catalogue markings whose centre pixel is predicted as class 1
        This_df = MarkingsCentres_df[MarkingsCentres_df['Image Name']==Im]
        CorrectCount = 0.0
        for MarkRow, MarkCol in zip(This_df['row'].values,This_df['column'].values):
            if MarkRow<PredMask.shape[0] and MarkCol<PredMask.shape[1]:
                if PredMask[MarkRow,MarkCol]==1:
                    CorrectCount = CorrectCount+1.0
        NumMarkings = This_df.shape[0]
        #an image without catalogue markings has no defined fraction: store NaN instead of dividing by zero
        ImageResults_df.at[Im,'Centre Correct Fraction'] = CorrectCount/NumMarkings if NumMarkings>0 else np.nan
        ImageResults_df.at[Im,'Centre Correct Count'] = CorrectCount
        ImageResults_df.at[Im,'Centre P4 Count'] = NumMarkings

        #per-tile stats for every tile of this image
        ThisImageTiles = Tiles_df[Tiles_df['obsid']==Im]
        for TileId, TileX, TileY in zip(ThisImageTiles['tile_id'].values,
                                        ThisImageTiles['x_hirise'].values,
                                        ThisImageTiles['y_hirise'].values):

            # NOTE(review): the window is 324 rows by 420 columns starting at
            # (y_hirise-324, x_hirise-420), clipped to the mask size - confirm
            # that this matches the intended tile footprint.
            Start_col = int(TileX-420)
            Start_row = int(TileY-324)
            TileRows = slice(Start_row,min(Start_row+324,PredMask.shape[0]))
            TileCols = slice(Start_col,min(Start_col+420,PredMask.shape[1]))
            Tile_pred = PredMask[TileRows,TileCols]
            Tile_truth = val_mask[TileRows,TileCols]

            JI,Dice,AreaRatio,Recall,Precision,TP,TN,FP,FN = GetStats(Tile_truth,Tile_pred)

            Ind = FirstTileIndex[TileId]
            TileValues = (('region',ToLeaveOut),
                          ('SOLAR_LONGITUDE',SolarLong),
                          ('START_TIME',StartTime),
                          ('JI',JI),
                          ('Dice',Dice),
                          ('Recall',Recall),
                          ('Precision',Precision),
                          ('TP',TP),
                          ('TN',TN),
                          ('FP',FP),
                          ('FN',FN))
            for TileColumn, TileValue in TileValues:
                Tiles_df.at[Ind,TileColumn] = TileValue

#save the per-image and per-tile results
ImageResults_df=ImageResults_df.sort_values(['Region name','Image name'])
ImageResults_df.to_csv('../../Data/SummaryResults/LORO_by_image.csv')
Tiles_df.to_csv('../../Data/SummaryResults/Stats_by_tiles.csv')

In [None]:
# Summarise the per-image results by region: pooled pixel counts and the
# metrics derived from them, followed by the median / max / min of the
# per-image metrics. The table is sorted by number of images and saved to CSV.
RegionResults_df = pd.DataFrame()
PerImageMetrics = ('JI','Dice','Precision','Recall','Area Ratio','Centre Correct Fraction')
PerImageSummaries = (('Median','median'),('Max','max'),('Min','min'))

for region_label, region_rows in ImageResults_df.groupby('Region name'):
    # pixel counts and marking counts pooled over all images of the region
    tp_total = region_rows['TP'].sum()
    fp_total = region_rows['FP'].sum()
    fn_total = region_rows['FN'].sum()
    tn_total = region_rows['TN'].sum()
    p4_centres = region_rows['Centre P4 Count'].sum()
    correct_centres = region_rows['Centre Correct Count'].sum()

    RegionResults_df.at[region_label,'NumImages'] = region_rows.shape[0]
    RegionResults_df.at[region_label,'Total TP'] = tp_total
    RegionResults_df.at[region_label,'Total FP'] = fp_total
    RegionResults_df.at[region_label,'Total FN'] = fn_total
    RegionResults_df.at[region_label,'Total TN'] = tn_total
    RegionResults_df.at[region_label,'Total Centre P4 Count'] = p4_centres
    RegionResults_df.at[region_label,'Total Centre Correct Count'] = correct_centres

    # region-level metrics computed from the pooled counts
    RegionResults_df.at[region_label,'Region JI'] = tp_total/(tp_total+fn_total+fp_total)
    RegionResults_df.at[region_label,'Region Dice'] = 2*tp_total/(2*tp_total+fn_total+fp_total)
    RegionResults_df.at[region_label,'Region Precision'] = tp_total/(tp_total+fp_total)
    RegionResults_df.at[region_label,'Region Recall'] = tp_total/(tp_total+fn_total)
    RegionResults_df.at[region_label,'Region Area Ratio'] = (tp_total+fn_total)/(tp_total+fp_total)
    RegionResults_df.at[region_label,'Region Centre Correct Fraction'] = correct_centres/p4_centres

    # per image stats: one column per (summary, metric) pair, e.g. 'Median JI'
    for summary_label, summary_method in PerImageSummaries:
        for metric in PerImageMetrics:
            summary_value = getattr(region_rows[metric], summary_method)()
            RegionResults_df.at[region_label,summary_label+' '+metric] = summary_value

RegionResults_df = RegionResults_df.sort_values('NumImages')
RegionResults_df.to_csv('../../Data/SummaryResults/LORO.csv')

In [None]:
# Unweighted mean of the 'Region Dice' column across regions (each region
# counts once, regardless of how many images it has); shown as the cell output.
RegionResults_df['Region Dice'].mean()