In [None]:
"""
XraySpectraFitter: Fits x-ray spectra generated from XraySpectraGenerator
Created by Jerry LaRue, larue@chapman.edu, 1/2016
Last modified by Jerry LaRue, larue@chapman.edu, 2/2016
"""

import numpy as np
import h5py
import os
import matplotlib.pyplot as plt
import GaussFitter

class SpectraFitter ( object ) :
    """
    Fit sums of Gaussians to the delay-resolved spectra of an HDF5 file.

    Go() does all of the work. Afterwards the outcome is read through the
    File_Parameters, Folder_Output, File_Output, Experiment and Success
    attributes. They are read-only properties: Success is False until a Go()
    call passes its checks, and the other four are '' while Success is False.
    """

    def __init__ ( self ) :

        # Backing fields of the read-only properties at the end of the class.
        # They use private names because an instance attribute called e.g.
        # 'Success' would hide a class-level accessor of the same name.
        self._File_Parameters = ''
        self._Folder_Output = ''
        self._File_Output = ''
        self._Experiment = ''
        self._Success = False

    def Go ( self, File_Parameters, Input, Output ) :
        """
        Fit every delay spectrum of the input file and save the fitted curves.

        Parameters:
            File_Parameters: name of an importable module providing a
                Parameters class. Its Analysis() and Fitting() methods are
                called before the Spectra_*, Fit_*, Show_Plots and Experiment
                attributes are read.
            Input: (folder, file name) of the HDF5 file to read. The two are
                joined with '+', so the folder needs its trailing separator.
            Output: (folder, file name) of the HDF5 file to write. When it is
                the same path as Input the fits are appended to that file,
                otherwise a new file is written (replacing an existing one).

        Nothing is returned. Progress and validation problems are printed, and
        the Success property tells whether the checks passed. Exceptions raised
        by h5py, GaussFitter or the parameters module are not caught.
        """

        self._Success = False
        self._File_Parameters = File_Parameters
        self.Folder_Input = Input[0]
        self.File_Input = Input[1]
        self._Folder_Output = Output[0]
        self._File_Output = Output[1]
        Path_Input = self.Folder_Input + self.File_Input
        Path_Output = self._Folder_Output + self._File_Output

        print('----------------------------------------')
        print('Fitting spectra from: ')
        print(Path_Input)

        ##### Parameters #####

        # NOTE: level -1 (implicit relative, then absolute import) only exists
        # in Python 2; Python 3 rejects a negative level.
        ParametersFile = __import__(self._File_Parameters, globals(), locals(), [], -1)
        par = ParametersFile.Parameters()
        par.Analysis()
        par.Fitting()
        self._Experiment = par.Experiment
        NumRef = par.Fit_Reference_NumGauss
        NumSig = par.Fit_Signal_NumGauss

        ##### Check for files #####

        self._Success = True
        f = None
        if os.path.isfile(Path_Input) :
            f = h5py.File(Path_Input, 'r')
        else :
            self._Success = False
            print(self.File_Input + ' file missing')

        if f is not None :

            # Data sets every input file has to provide, with the message
            # printed when one is absent
            Required = [
                ('Run_List', 'Run list missing'),
                ('Experiment', 'Experiment name missing'),
                ('Energy/Units', 'Energy units missing'),
                ('Energy/Resolution', 'Energy resolution missing'),
                ('Energy/Values', 'Energy values missing'),
                ('Delay/Units', 'Delay units missing'),
                ('Delay/Resolution', 'Delay resolution missing'),
                ('Delay/Values', 'Delay Values missing'),
                (par.Spectra_Signal, 'Signal spectra missing'),
            ]
            for Name, Message in Required :
                if not Name in f :
                    self._Success = False
                    print(Message)

            # Spectra
            if NumSig == 0 :
                self._Success = False
                print('Need to have at least 1 gauss to fit spectra')
            if NumRef > 0 :
                if not par.Spectra_Reference in f :
                    self._Success = False
                    print('Reference spectra missing')

            # XES Scale
            if par.Spectra_Type == 'XES' :
                if not 'XES/Values' in f :
                    self._Success = False
                    print('XES values missing')
                if not 'XES/Units' in f :
                    self._Success = False
                    print('XES units missing')

        if self._Success :

            ##### Load data #####

            try :
                # General
                Runs = f['Run_List'][...]
                Experiment = f['Experiment'][...]

                # Energy
                Energy_Units = f['Energy/Units'][...]
                Energy_Resolution = f['Energy/Resolution'][...]
                Energy_Values = f['Energy/Values'][...]

                # Delay
                Delay_Units = f['Delay/Units'][...]
                Delay_Resolution = f['Delay/Resolution'][...]
                Delay_Values = f['Delay/Values'][...]

                # Spectra
                Spectra_Signal = f[par.Spectra_Signal][...]
                if NumRef > 0 :
                    Spectra_Reference = f[par.Spectra_Reference][...]

                # XES Scale
                if par.Spectra_Type == 'XES' :
                    Spectra_XES_Values = f['XES/Values'][...]
                    Spectra_XES_Units = f['XES/Units'][...]
            finally :
                f.close()

            ##### XES spectra #####

            if par.Spectra_Type == 'XES' :
                # Keep only the slice at the energy closest to the requested
                # one; this indexing expects a 3-D signal and a 2-D reference
                # whose second axis follows Energy_Values.
                Index_Energy = (np.abs(par.Fit_Energy_Value - Energy_Values)).argmin()
                Spectra_Signal = Spectra_Signal[:,Index_Energy,:]
                if NumRef > 0 :
                    Spectra_Reference = Spectra_Reference[:,Index_Energy]
                X_Values = Spectra_XES_Values
                X_Units = Spectra_XES_Units
            elif par.Spectra_Type == 'XAS' :
                X_Values = Energy_Values
                X_Units = Energy_Units
            else :
                print('Select either XAS or XES')
                self._Success = False

            if NumSig <= 0 :
                print('Number of Gaussians for Signal spectra must be >0')
                self._Success = False

            # A step that is not positive would make the grid loop below run
            # forever
            if not par.Fit_X_Delta > 0 :
                print('Fit_X_Delta must be >0')
                self._Success = False

        elif f is not None :
            # Validation failed: release the input file
            f.close()

        if self._Success :

            print('Delay in units of ' + Delay_Units)
            if par.Spectra_Type == 'XAS' :
                print('Energy in units of ' +X_Units)
            if par.Spectra_Type == 'XES' :
                print('XES in units of ' +X_Units)

            ###### Trim data #####

            # Data Range: samples closest to the region-of-interest limits
            X_Index_Min = (np.abs(X_Values - par.Fit_ROI_Min)).argmin()
            X_Index_Max = (np.abs(X_Values - par.Fit_ROI_Max)).argmin()

            # Data: after the transpose Spectra_Signal[i] is the spectrum that
            # is fitted for Delay_Values[i]
            Spectra_Signal = np.transpose(Spectra_Signal[X_Index_Min:X_Index_Max])
            if NumRef > 0 :
                Spectra_Reference = np.transpose(Spectra_Reference[X_Index_Min:X_Index_Max])
            X_Values = X_Values[X_Index_Min:X_Index_Max]

            ##### Fit reference spectra #####

            # Parameters are stored as [amplitude, position, width] per peak
            Shape_Fields = ['Amplitude', 'Position', 'Width']
            Parameters_Fit_Reference = np.zeros((0))
            if NumRef > 0 :
                Parameters_Guess = np.append(np.zeros((0)), self._Peak_Values(par, 'Reference', NumRef, Shape_Fields))
                Parameters_Fit_Found = GaussFitter.multigaussfit(X_Values, Spectra_Reference, ngauss=NumRef, params=Parameters_Guess)
                Parameters_Fit_Reference = Parameters_Fit_Found[0]

            # Start values for the signal fits: the fitted reference peaks
            # followed by the guesses for the signal peaks. Every reference
            # parameter is flagged as fixed; the signal flags come from par.
            Parameters_Fit = np.append(Parameters_Fit_Reference, self._Peak_Values(par, 'Signal', NumSig, Shape_Fields))
            Fixed_Flags = [True] * (3 * NumRef)
            Fixed_Flags = Fixed_Flags + self._Peak_Values(par, 'Signal', NumSig, ['Amplitude_Fix', 'Position_Fix', 'Width_Fix'])
            Parameters_Fixed = np.append(np.zeros((0)), Fixed_Flags)

            ##### Fit signal spectra #####

            # Grid on which the fitted curves are evaluated and saved
            Grid = []
            j = 0
            while par.Fit_ROI_Min + j * par.Fit_X_Delta <= par.Fit_ROI_Max :
                Grid.append(par.Fit_ROI_Min + j * par.Fit_X_Delta)
                j = j + 1
            Fit_X_Values = np.append(np.zeros((0)), Grid)
            Fit_Spectra_Signal = np.zeros((len(Delay_Values),len(Fit_X_Values)))
            if NumRef > 0 :
                Fit_Spectra_Reference = self._Gaussian_Sum(Fit_X_Values, Parameters_Fit_Reference, NumRef)

            Counter_Max = 500
            for i in range(len(Delay_Values)) :

                # Fit the Data. The reference amplitudes are fixed during a
                # fit, so their common scale factor is scanned instead: it
                # starts at 0.1 and grows by 0.01 per fit. Streak counts how
                # often the last element of the fit result (presumably the
                # chi-square) rose compared with the previous step; it is never
                # reset. Once it exceeds 10 one final fit is made with the
                # scale factor that gave the smallest recorded value.
                Intensities = np.array([0.1])
                ChiSquares = np.zeros((0))
                Streak = 0
                Scanning = True
                Counter = 0
                while Scanning :
                    if Streak > 10 :
                        Scanning = False
                        Index = (np.abs(ChiSquares - min(ChiSquares))).argmin()
                    else :
                        Index = -1
                    Parameters_Fit_Intensities = np.append(np.zeros((0)), Parameters_Fit)
                    if NumRef > 0 :
                        for j in range(NumRef) :
                            Parameters_Fit_Intensities[j * 3] = Parameters_Fit[j * 3] * Intensities[Index]
                    else :
                        # Nothing to scan without reference peaks: fit once
                        Scanning = False
                    Parameters_Fit_Found = GaussFitter.multigaussfit(X_Values, Spectra_Signal[i], ngauss=NumRef+NumSig, params=Parameters_Fit_Intensities, err=True, fixed=Parameters_Fixed)
                    Parameters_Fit_Signal = Parameters_Fit_Found[0]
                    if Counter > Counter_Max :
                        Scanning = False
                        print('Warning: Fit failed after ' + str(Counter_Max) + ' iterations.')
                    if Scanning :
                        if len(ChiSquares) > 0 :
                            if Parameters_Fit_Found[-1] > ChiSquares[-1] :
                                Streak = Streak + 1
                        ChiSquares = np.append(ChiSquares,Parameters_Fit_Found[-1])
                        Intensities = np.append(Intensities, Intensities[-1] + 0.01)
                    Counter = Counter + 1

                ##### Make fits #####

                Fit_Spectra_Signal[i] = self._Gaussian_Sum(Fit_X_Values, Parameters_Fit_Signal, NumRef + NumSig)

                ##### Plot data #####

                if par.Show_Plots == True :
                    # Reference & Signal
                    fig = plt.figure()
                    Signal = fig.add_subplot(111)
                    Plot_Spectra_Signal = Signal.plot(X_Values, Spectra_Signal[i])
                    Plot_Fit_Signal = Signal.plot(Fit_X_Values, Fit_Spectra_Signal[i])
                    if NumRef > 0 :
                        Reference = fig.add_subplot(111)
                        Diff = fig.add_subplot(111)
                        Plot_Spectra_Reference = Reference.plot(X_Values, Spectra_Reference)
                        Plot_Fit_Reference = Reference.plot(Fit_X_Values, Fit_Spectra_Reference)
                        # Fitted signal minus fitted reference
                        Diff = Signal.plot(Fit_X_Values, Fit_Spectra_Signal[i] - Fit_Spectra_Reference)
                    Signal.set_xlabel('Energy')
                    Signal.set_ylabel('Intensity [au]')
                    Signal.set_title('Data ' + str(int(Delay_Values[i])) + ' fs')
                    fig.show()

                # Total intensities: sum of amplitude * width over the peaks
                Intensity_Reference = 0
                Intensity_Signal = 0
                for j in range(NumRef + NumSig) :
                    if j < NumRef :
                        Intensity_Reference = Intensity_Reference + Parameters_Fit_Reference[j * 3] * Parameters_Fit_Reference[j * 3 + 2]
                    Intensity_Signal = Intensity_Signal + Parameters_Fit_Signal[j * 3] * Parameters_Fit_Signal[j * 3 + 2]

                print('********************')
                print('Delay = ' + str(Delay_Values[i]) + ' fs')
                if NumRef > 0 :
                    print('Reference Intensity = ' + str(Intensity_Reference))
                print('Signal Intensity = ' + str(Intensity_Signal))

            ##### Save to file #####

            print('********************')

            # Group name
            if par.Spectra_Type == 'XAS' :
                Fits_Name = 'Fits_XAS/'
            if par.Spectra_Type == 'XES' :
                Fits_Name = 'Fits_XES_' + str(Energy_Values[Index_Energy]) +'/'
                Fits_Name = Fits_Name.replace('.','_')

            # Output File: opened only now, so that a run which stops earlier
            # leaves an existing output file untouched
            Same_File = Path_Output == Path_Input
            if Same_File :
                print('Appending fits to:')
                print(Path_Output)
                f = h5py.File(Path_Output, 'a')
            else :
                print('Saving fits to:')
                print(Path_Output)
                # An empty folder name means the current directory, which
                # os.makedirs cannot create
                if self._Folder_Output and not os.path.exists(self._Folder_Output) :
                    os.makedirs(self._Folder_Output)
                f = h5py.File(Path_Output, 'w')

            try :
                dt = h5py.special_dtype(vlen=bytes)

                if not Same_File :

                    # General
                    f.create_dataset('Run_List', data = Runs, dtype = 'int32')
                    f.create_dataset('Experiment', data = Experiment, dtype = dt)

                    # Energy
                    f.create_dataset('Energy/Units', data = Energy_Units, dtype = dt)
                    f.create_dataset('Energy/Resolution', data = Energy_Resolution, dtype = np.float64)
                    f.create_dataset('Energy/Values', data = Energy_Values, dtype = np.float64)

                    # Delay
                    f.create_dataset('Delay/Units', data = Delay_Units, dtype = dt)
                    f.create_dataset('Delay/Resolution', data = Delay_Resolution, dtype = np.float64)
                    f.create_dataset('Delay/Values', data = Delay_Values, dtype = np.float64)

                # Clear data for overwriting
                if Fits_Name in f :
                    del f[Fits_Name]

                # X
                if par.Spectra_Type == 'XAS' :
                    f.create_dataset(Fits_Name + 'Energy_Units', data = X_Units, dtype = dt)
                    f.create_dataset(Fits_Name + 'Energy_Values', data = Fit_X_Values, dtype = np.float64)
                if par.Spectra_Type == 'XES' :
                    f.create_dataset(Fits_Name + 'XES_Units', data = X_Units, dtype = dt)
                    f.create_dataset(Fits_Name + 'XES_Values', data = Fit_X_Values, dtype = np.float64)

                # Fits
                f.create_dataset(Fits_Name + 'Signal', data = Fit_Spectra_Signal, dtype = np.float64)
                if NumRef > 0 :
                    f.create_dataset(Fits_Name + 'Reference', data = Fit_Spectra_Reference, dtype = np.float64)
            finally :
                f.close()

            print('Done\n')

        else :
            print('Fitting cancelled\n')

    @staticmethod
    def _Peak_Values ( par, Kind, Count, Fields ) :
        """
        Collect par.Fit_<Kind>_Peak<n>_<Field> for n = 1..Count.

        The values are returned as a flat list, peak by peak, with the fields
        of each peak in the order given. A missing attribute raises
        AttributeError.
        """

        Values = []
        for Peak in range(1, Count + 1) :
            Prefix = 'Fit_' + Kind + '_Peak' + str(Peak) + '_'
            for Field in Fields :
                Values.append(getattr(par, Prefix + Field))
        return Values

    @staticmethod
    def _Gaussian_Sum ( X, Parameters, Count ) :
        """
        Evaluate the sum of the first Count Gaussians of Parameters at X.

        X is a 1-D numpy array. Parameters holds [amplitude, position, width]
        for each Gaussian, one after the other. Returns an array as long as X.
        """

        Total = np.zeros((len(X)))
        for k in range(Count) :
            Total = Total + Parameters[3 * k] * np.exp( -(X - Parameters[3 * k + 1])**2 / (2.0*Parameters[3 * k + 2]**2) )
        return Total

    @property
    def File_Parameters ( self ) :
        """Parameters module name given to the last Go(), or '' unless it succeeded."""

        return self._File_Parameters if self._Success else ''

    @property
    def Folder_Output ( self ) :
        """Output folder given to the last Go(), or '' unless it succeeded."""

        return self._Folder_Output if self._Success else ''

    @property
    def File_Output ( self ) :
        """Output file name given to the last Go(), or '' unless it succeeded."""

        return self._File_Output if self._Success else ''

    @property
    def Experiment ( self ) :
        """Experiment attribute of the parameters object, or '' unless Go() succeeded."""

        return self._Experiment if self._Success else ''

    @property
    def Success ( self ) :
        """True when the last Go() passed all of its checks, False otherwise."""

        return self._Success