In [75]:
!pip install -q gprMax
!pip install -q vtk


%matplotlib inline
from ipywidgets import *
import vtk
import numpy as np
from vtk.util.numpy_support import vtk_to_numpy
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches
from xml.etree import ElementTree as ET

# Default figure size (width, height) for every plot produced by the helpers below.
plt.rcParams["figure.figsize"] = (15,10)

def gprMax_model(filename):
    """Plot a 2D gprMax geometry view (``.vti`` file) with a legend of object names.

    The file is scanned line by line for ``<Material``, ``<Sources``, ``<PML``
    and ``<Receivers`` XML tags, each mapping an integer ID (element text) to a
    name (``name`` attribute). The cell-data arrays ``Material``,
    ``Sources_PML`` and ``Receivers`` are reshaped onto the model's 2D plane.
    Every cell is then relabelled with the index of its object name and drawn
    with ``imshow``. Sources/PML are layered over materials, and receivers are
    layered over both.

    Args:
        filename (str): Path to the geometry view file.

    Raises:
        ValueError: If no axis of the model is exactly one cell thick
            (i.e. the model is not 2D).
    """
    materials, srcs_pml, rxs = _read_id_tags(filename)

    reader = vtk.vtkXMLImageDataReader()
    reader.SetFileName(filename)
    reader.Update()
    vti = reader.GetOutput()

    (S1, d1), (S2, d2) = _plane_geometry(vti.GetExtent(), vti.GetSpacing())

    cell_data = vti.GetCellData()
    domain = vtk_to_numpy(cell_data.GetArray('Material')).reshape(S1, S2)
    PML_Tx = vtk_to_numpy(cell_data.GetArray('Sources_PML')).reshape(S1, S2)
    Rx = vtk_to_numpy(cell_data.GetArray('Receivers')).reshape(S1, S2)

    # Relabel into a fresh array so remapped indices can never be confused
    # with raw IDs that are still to be matched. Cells that match no tag stay
    # at -1 and are shown as 'unknown'.
    objects = []
    labels = np.full((S1, S2), -1, dtype=int)
    for id_array, tags in ((domain, materials), (PML_Tx, srcs_pml), (Rx, rxs)):
        for tag_id, name in tags:
            mask = id_array == tag_id
            if mask.any():
                if name not in objects:
                    objects.append(name)
                labels[mask] = objects.index(name)

    im = plt.imshow(labels, cmap=plt.get_cmap('jet'), interpolation='nearest',
                    origin='lower', extent=[0, S2 * d2, 0, S1 * d1])

    # One legend patch per value actually present in the image, labelled by
    # that value (not by its position in the unique list), so labels stay
    # correct even when some object index is absent from the plot.
    patches = []
    for entry in np.unique(labels):
        entry = int(entry)
        label = objects[entry] if entry >= 0 else 'unknown'
        patches.append(mpatches.Patch(color=im.cmap(im.norm(entry)), label=label))

    plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    ax = plt.gca()
    ax.set_xlabel('Metres [m]')
    ax.set_ylabel('Metres [m]')
    plt.show()


def _read_id_tags(filename):
    """Collect ``(id, name)`` pairs from tag lines of a geometry view file.

    Returns three lists: materials, sources/PML and receivers. Only lines that
    start (at column 0) with one of the recognised tags are parsed.
    """
    materials, srcs_pml, rxs = [], [], []
    with open(filename, 'rb') as f:
        for line in f:
            if line.startswith(b'<Material'):
                target = materials
            elif line.startswith((b'<Sources', b'<PML')):
                target = srcs_pml
            elif line.startswith(b'<Receivers'):
                target = rxs
            else:
                continue
            element = ET.fromstring(line)  # parse once per tag line
            target.append((int(element.text), element.attrib.get('name')))
    return materials, srcs_pml, rxs


def _plane_geometry(extent, spacing):
    """Return ``((rows, row_step), (cols, col_step))`` for a 2D model's plane.

    Cell counts are taken as ``stop - start`` of each extent pair. If several
    axes are one cell thick, x takes precedence over y, and y over z.
    """
    nx = extent[1] - extent[0]
    ny = extent[3] - extent[2]
    nz = extent[5] - extent[4]
    dx, dy, dz = spacing[0], spacing[1], spacing[2]
    if nx == 1:
        return (nz, dz), (ny, dy)
    if ny == 1:
        return (nz, dz), (nx, dx)
    if nz == 1:
        return (ny, dy), (nx, dx)
    raise ValueError(
        'No axis is one cell thick (extent {}); only 2D models can be plotted.'.format(extent))

    
def gprMax_Ascan(filename, rxnumber, rxcomponent):
    """Get A-scan output data for one receiver from a gprMax output file.

    Args:
        filename (string): Filename (including path) of output file.
        rxnumber (int): Receiver output number (receiver groups are stored
            under ``/rxs/rx<rxnumber>``).
        rxcomponent (str): Receiver output field/current component.

    Returns:
        outputdata (array): The A-scan of the requested component.
        time (array): Time of each sample in nanoseconds; sample ``k`` is at
            ``k * dt``.
        pos (array): The receiver group's ``Position`` attribute.

    Raises:
        CmdInputError: If the file has no receivers, or the requested
            component is not available for this receiver. When gprMax is not
            importable, ``ValueError`` is raised instead.
    """
    import h5py
    try:
        from gprMax.exceptions import CmdInputError
    except ImportError:
        # Keep the error path working without gprMax installed.
        CmdInputError = ValueError

    # The context manager guarantees the file is closed even if we raise.
    with h5py.File(filename, 'r') as f:
        nrx = f.attrs['nrx']
        dt = f.attrs['dt']

        if nrx == 0:
            raise CmdInputError('No receivers found in {}'.format(filename))

        path = '/rxs/rx' + str(rxnumber) + '/'
        group = f[path]
        availableoutputs = list(group.keys())
        pos = np.array(group.attrs['Position'])

        if rxcomponent not in availableoutputs:
            raise CmdInputError(
                '{} output requested, but the available output for receiver {} is {}'.format(
                    rxcomponent, rxnumber, ', '.join(availableoutputs)))

        outputdata = np.array(group[rxcomponent])

    # N samples span (N - 1) * dt; convert seconds to nanoseconds.
    time = np.linspace(0, (outputdata.size - 1) * dt, outputdata.size) / 1e-9
    return outputdata, time, pos

def gprMax_Bscan(filename, rx, rxcomponent):
    """Assemble a B-scan by stacking A-scans read with gprMax_Ascan.

    The first trace is read from ``<filename>.out``. Further traces come from
    all but the last entry of ``glob('<filename>*.out')``, with names
    containing ``_merged`` excluded, in the order glob returns them.

    NOTE(review): glob order is not guaranteed, and ``<filename>.out`` itself
    matches the glob pattern. Trace order and duplication therefore depend on
    the directory contents. Confirm this matches the intended file layout.

    Args:
        filename (str): Base path of the output files, without ``.out``.
        rx (int): Receiver number passed through to gprMax_Ascan.
        rxcomponent (str): Field/current component passed to gprMax_Ascan.

    Returns:
        bscan (array): Traces stacked as columns.
        time (array): Time axis from the last gprMax_Ascan call.
        spos (array): Receiver positions, one row per trace.
    """
    import glob

    candidates = [path for path in glob.glob(filename + '*.out')
                  if '_merged' not in path]

    trace, time, position = gprMax_Ascan(filename + '.out', rx, rxcomponent)
    columns = [np.array(trace, ndmin=2).T]
    rows = [np.array(position, ndmin=2)]

    for path in candidates[:-1]:
        trace, time, position = gprMax_Ascan(path, rx, rxcomponent)
        columns.append(np.array(trace, ndmin=2).T)
        rows.append(np.array(position, ndmin=2))

    bscan = np.concatenate(columns, axis=1)
    spos = np.concatenate(rows, axis=0)
    return bscan, time, spos


def merge_files(basefilename, removefiles=False):
    """Merge traces (A-scans) from a numbered series of output files into one file.

    Reads ``<basefilename>1.out`` ... ``<basefilename>N.out`` and writes
    ``<basefilename>_merged.out``. In the merged file, each receiver output
    dataset has shape ``(Iterations, N)``, and column ``k`` holds the trace
    from file ``k + 1``. N is the number of files named ``<basefilename>``
    followed by digits and ``.out``. The Title, Iterations, dt and nrx
    attributes are copied from the first file.

    Args:
        basefilename (string): Base name of output file series including path.
        removefiles (boolean): Flag to remove individual output files after merge.
    """
    import h5py
    import os
    import glob
    import re
    from gprMax._version import __version__

    outputfile = basefilename + '_merged.out'
    # Count only '<base><digits>.out' files: the loop below opens exactly
    # those names, so stray matches such as '<base>.out' must not inflate N.
    numbered = re.compile(re.escape(basefilename) + r'\d+\.out')
    outputfiles = [name for name in glob.glob(basefilename + '*.out')
                   if numbered.fullmatch(name)]
    modelruns = len(outputfiles)
    print(modelruns)

    # Context managers close every file even if a read or write fails.
    with h5py.File(outputfile, 'w') as fout:
        for model in range(modelruns):
            with h5py.File(basefilename + str(model + 1) + '.out', 'r') as fin:
                nrx = fin.attrs['nrx']

                # Write properties and allocate datasets on first iteration.
                if model == 0:
                    fout.attrs['Title'] = fin.attrs['Title']
                    fout.attrs['gprMax'] = __version__
                    fout.attrs['Iterations'] = fin.attrs['Iterations']
                    fout.attrs['dt'] = fin.attrs['dt']
                    fout.attrs['nrx'] = nrx
                    for rx in range(1, nrx + 1):
                        path = '/rxs/rx' + str(rx)
                        grp = fout.create_group(path)
                        for output in fin[path].keys():
                            grp.create_dataset(
                                output,
                                (fout.attrs['Iterations'], modelruns),
                                dtype=fin[path + '/' + output].dtype,
                            )

                # Copy every receiver output into this model's column.
                for rx in range(1, nrx + 1):
                    path = '/rxs/rx' + str(rx) + '/'
                    for output in fin[path].keys():
                        fout[path + output][:, model] = fin[path + output][:]

    if removefiles:
        for model in range(modelruns):
            os.remove(basefilename + str(model + 1) + '.out')

def plot_Ascan(x, y):
    """Plot a single A-scan as a vertical wiggle trace and show it.

    Time runs down the vertical axis (earliest sample at the top) and field
    strength runs along the horizontal axis, which is symmetric about zero.
    Lobes where ``x > 0`` are filled in black.

    Args:
        x (array-like): Field strength samples [V/m].
        y (array-like): Sample times [ns], same length as ``x``.
    """
    # Accept lists too: the fill mask below needs elementwise comparison.
    x = np.asarray(x)
    y = np.asarray(y)
    offset = 0

    plt.plot(x, y, 'k-')
    plt.fill_betweenx(y, offset, x, where=(x > offset), color='k')
    # Positional arguments to tight_layout are not accepted by current matplotlib.
    plt.tight_layout()
    ax = plt.gca()

    half_width = np.max(np.abs(x))
    ax.set_xlim([-half_width, half_width])
    ax.set_ylim([np.max(y), np.min(y)])  # inverted: time increases downwards
    ax.xaxis.tick_top()
    ax.set_ylabel('Time [ns]')
    ax.set_xlabel('Field Strength [V/m]')

    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    aspect = 0.5
    if ax.get_yaxis().get_scale() == 'log':
        asp = abs((np.log(xmax) - np.log(xmin)) / (np.log(ymax) - np.log(ymin))) / aspect
    else:
        # Linear formula as the fallback so `asp` is always defined.
        asp = abs((xmax - xmin) / (ymax - ymin)) / aspect

    ax.grid(which='both', axis='both', linestyle='-.')
    ax.set_aspect(asp)
    plt.show()

def plot_Bscan(scan, time, time_offset=0, aspect=100):
    """Display a B-scan image with a colour scale symmetric about zero.

    Args:
        scan (array-like): 2D array with one A-scan per column (the layout
            built by gprMax_Bscan).
        time (array-like): Sample times. Only the maximum is used, to set the
            vertical extent.
        time_offset (float): Amount subtracted from both ends of the time axis.
        aspect (float): Aspect ratio passed to ``imshow``. The default 100
            matches the previously hard-coded value.
    """
    scan = np.asarray(scan)
    scan_max = np.max(np.abs(scan))
    plt.imshow(scan, cmap='seismic',
               extent=[0, scan.shape[1], np.max(time) - time_offset, 0 - time_offset],
               aspect=aspect, vmin=-scan_max, vmax=scan_max)
    # The original referenced plt.colorbar without calling it, so no bar appeared.
    plt.colorbar()
    plt.show()

