In [2]:
# Install modules if not existing
# NOTE(review): the `!pip` and `%matplotlib` lines are IPython magics, so this
# cell only runs inside Jupyter/IPython, not as a plain Python script. The
# bitstruct/bitstring version pins are presumably what gprMax_to_dzt below was
# written against -- confirm before upgrading either package.

!pip install -q vtk
!pip install -q bitstruct==3.10.0
!pip install -q bitstring==3.1.5

# Import necessary libraries

%matplotlib inline
from IPython.display import clear_output, HTML, display
# NOTE(review): this HTML object is neither passed to display() nor the last
# expression of the cell, so the kernel-restart script is never rendered and
# this line has no effect.
HTML("<script>Jupyter.notebook.kernel.restart()</script>")
# Cap scrollable cell output at 30em high.
display(HTML("<style>div.output_scroll  {height: 30em}; </style>"))
import ipywidgets as wd
import vtk
import numpy as np
from vtk.util.numpy_support import vtk_to_numpy
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches
from xml.etree import ElementTree as ET
import os
import h5py
import subprocess
import warnings
# Silence every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Set figure size

plt.rcParams["figure.figsize"] = (15,10)

# Add gprMax to session. This needs to be executed every time a new Noteable session is started

c_dir=os.getcwd()

# Probe for gprMax by running its help as a module. If that exits with an
# error, build and install it from ~/gprMax. Note that '/dev/null' is passed
# to gprMax as a plain argument (no shell redirection happens here); the
# output is already swallowed by capture_output=True.
try:
    subprocess.run(['python', '-m', 'gprMax', '-h', '/dev/null'], check = True, capture_output=True)
except subprocess.CalledProcessError:
    os.chdir(os.path.expanduser("~"))
    os.chdir('gprMax')
    # NOTE(review): the exit status of the install is ignored, and if either
    # chdir above fails the working directory is not restored below.
    os.system('python setup.py install >/dev/null')
    
# Printed unconditionally, whether or not the install actually succeeded.
print('gprMax has been added to session!')
os.chdir(c_dir)

def gprMax_Ascan(filename, rxnumber, rxcomponent):
    """Gets A-scan output data from a model.

    Args:
        filename (string): Filename (including path) of output file.
        rxnumber (int): Receiver output number.
        rxcomponent (str): Receiver output field/current component.

    Returns:
        ascan (array): A-scan data.
        time (array): Time of each sample in nanoseconds; sample k is at
            k * dt, where dt is the file's 'dt' attribute.
        pos (array): The receiver's 'Position' attribute.

    Raises:
        CmdInputError: If the file contains no receivers, or the requested
            component is not stored for this receiver. (ValueError when
            gprMax itself cannot be imported.)
    """
    # CmdInputError was referenced but never imported, so the error paths
    # raised NameError. Use gprMax's class when available.
    try:
        from gprMax.exceptions import CmdInputError
    except ImportError:
        CmdInputError = ValueError

    path = '/rxs/rx' + str(rxnumber) + '/'

    # The context manager closes the HDF5 file even when we raise below.
    with h5py.File(filename, 'r') as f:
        nrx = f.attrs['nrx']
        dt = f.attrs['dt']

        # Check there are any receivers
        if nrx == 0:
            raise CmdInputError('No receivers found in {}'.format(filename))

        group = f[path]
        availableoutputs = list(group.keys())
        pos = np.array(group.attrs['Position'])

        # Check if requested output is in file
        if rxcomponent not in availableoutputs:
            raise CmdInputError('{} output requested, but the available output for receiver {} is {}'.format(rxcomponent, rxnumber, ', '.join(availableoutputs)))

        ascan = np.array(group[rxcomponent])

    # N samples spaced dt apart end at (N - 1) * dt, not N * dt. Divide by 1e-9
    # to convert seconds to nanoseconds.
    time = np.linspace(0, (ascan.size - 1) * dt, ascan.size) / 1e-9

    return ascan, time, pos

def gprMax_Bscan(filename, rx, rxcomponent):
    """Gets B-scan output data from a series of model runs.

    Runs gprMax's ``tools.outputfiles_merge`` on the series base name, reads
    one receiver component from ``<base>_merged.out``, and takes the time
    axis from ``<base>1.out``.

    Args:
        filename (string): Output filename including path. The last four
            characters (the ``.out`` extension) are stripped to get the
            series base name.
        rx (int): Receiver output number.
        rxcomponent (str): Receiver output field/current component.

    Returns:
        bscan (array): Array of A-scans, i.e. B-scan data.
        time (array): Time steps of the simulation (ns).

    Raises:
        subprocess.CalledProcessError: If the merge tool exits with an error.
        KeyError: If the merged file has no such receiver/component.
    """
    import subprocess

    import h5py

    basename = filename[0:-4]

    # Pass an argument list (no shell) so paths containing spaces or shell
    # metacharacters are handled verbatim. check=True stops us from silently
    # reading a stale merged file after a failed merge.
    subprocess.run(['python', '-m', 'tools.outputfiles_merge', basename], check=True)

    with h5py.File(f'{basename}_merged.out', 'r') as merged:
        # Indexing (unlike .get) raises KeyError for a missing dataset instead
        # of yielding a meaningless 0-d object array.
        bscan = np.array(merged[f'/rxs/rx{rx}/{rxcomponent}'])

    # Every run shares the same time axis, so read it from the first A-scan.
    _, time, _ = gprMax_Ascan(basename + '1' + '.out', rx, rxcomponent)

    return bscan, time

def plot_Ascan(scan, time, rotate = False):
    """ Plot A-scan data.

    Draws the trace as a black line with its positive lobes filled (wiggle
    style). The field axis is made symmetric about zero, the aspect ratio is
    derived from the data ranges, and the figure is shown.

    Args:
        scan (array): A-scan Field values
        time (array): Time steps
        rotate (boolean): If truthy, rotate the A-scan plot by 90 degrees so
            time runs along the horizontal axis. Otherwise time runs
            downwards on the vertical axis.
    Returns:
        None
    """
    offset = 0          # lobes above this value are filled
    aspect = 0.5        # target height/width ratio of the data box

    field = scan
    t = time
    fmin, fmax = np.min(field), np.max(field)
    tmin, tmax = np.min(t), np.max(t)
    flim = np.max(np.abs(field))

    def span(hi, lo, log_scale):
        # Data extent in axis units. np.log replaces scipy.log, which was
        # never imported here and has been removed from SciPy.
        return np.log(hi) - np.log(lo) if log_scale else hi - lo

    plt.tight_layout()
    ax = plt.gca()

    if not rotate:
        plt.plot(field, t, 'k-')
        plt.fill_betweenx(t, offset, field, where=(field > offset), color='k')
        ax.set_xlim([-flim, flim])
        ax.set_ylim([tmin, tmax])
        ax.invert_yaxis()                   # time increases downwards
        ax.set_ylabel('Time [ns]')
        ax.set_xlabel('Field Strength [V/m]')
        ax.xaxis.tick_top()
        log_scale = ax.get_yaxis().get_scale() == 'log'
        asp = abs(span(fmax, fmin, log_scale) / span(tmax, tmin, log_scale)) / aspect
    else:
        plt.plot(t, field, 'k-')
        plt.fill_between(t, offset, field, where=(field > offset), color='k')
        ax.set_ylim([-flim, flim])
        ax.set_xlim([tmin, tmax])
        ax.set_xlabel('Time [ns]')
        ax.set_ylabel('Field Strength [V/m]')
        ax.xaxis.tick_bottom()
        log_scale = ax.get_yaxis().get_scale() == 'log'
        asp = abs(span(tmax, tmin, log_scale) / span(fmax, fmin, log_scale)) / (aspect * 3)

    # A flat trace (zero range) or log of non-positive data gives a
    # non-finite or zero ratio, which set_aspect rejects. Let matplotlib
    # choose the aspect in that case.
    if not np.isfinite(asp) or asp <= 0:
        asp = 'auto'

    ax.grid(which='both', axis='both', linestyle='-.')
    ax.set_aspect(asp)
    plt.show()

def plot_Bscan(scan, time, cmap='seismic', time_offset=0):
    """ Plot B-scan data.

    Shows the traces as an image: one column per trace, time (shifted by
    ``time_offset``) increasing downwards. Colour limits are symmetric about
    zero, and a colorbar is added.

    Args:
        scan (array): B-scan Field values (2D; columns are traces)
        time (array): Time steps
        cmap (string): Color map
        time_offset (float): Time offset subtracted from the time axis
    Returns:
        None
    """
    # Symmetric limits so zero field maps to the centre of the colormap.
    scan_max = np.max(np.abs(scan))
    plt.imshow(scan, cmap=cmap,
               extent=[0, scan.shape[1], np.max(time) - time_offset, 0 - time_offset],
               aspect=6, vmin=-scan_max, vmax=scan_max)

    # Grab the image axes before adding the colorbar, which creates its own axes.
    ax = plt.gca()
    # Previously written as bare `plt.colorbar` (never called), so no colorbar appeared.
    plt.colorbar()
    ax.set_xlabel('Trace Number ')
    ax.set_ylabel('Time [ns]')
    plt.show()
    

def plot_source(type, amp, center_freq):
    """Plot a gprMax source waveform using gprMax's plotting tool.

    The time window covers 15 periods of ``center_freq`` with 500 samples per
    period. ``check_timewindow`` turns it into an iteration count, and
    ``mpl_plot`` then draws the waveform with ``fft=True``.

    Args:
        type (string): Waveform type name assigned to gprMax's ``Waveform.type``.
        amp (float): Waveform amplitude.
        center_freq (float): Centre frequency assigned to ``Waveform.freq``
            (presumably in Hz, since the period is taken as 1/center_freq --
            confirm against gprMax).
    Returns:
        None
    """
    from gprMax.waveforms import Waveform
    from tools.plot_source_wave import check_timewindow, mpl_plot

    waveform = Waveform()
    waveform.type = type
    waveform.amp = amp
    waveform.freq = center_freq

    timewindow = 15*(1.0/center_freq)
    dt = 1/(500*center_freq)

    timewindow, iterations = check_timewindow(timewindow, dt)
    # The return value is deliberately not bound. The old `plt = ...` created an
    # unused local that shadowed the module-level pyplot name.
    mpl_plot(waveform, timewindow, dt, iterations, fft=True)
    
def view_file(filename):
    """ Print on screen gprMax input file.

    Args:
        filename (string): Name of file to print on screen (.in extension)
    Returns:
        None
    """
    # The context manager closes the handle; the original left it open.
    with open(filename, 'r') as f:
        print(f.read())
    
def _vti_read_labels(filename):
    """Collect (id, name) pairs from the label XML lines of a gprMax .vti file.

    Returns three lists: lines starting with <Material; lines starting with
    <Sources or <PML; and lines starting with <Receivers.
    """
    materials, srcs_pml, rxs = [], [], []
    with open(filename, 'rb') as f:
        for line in f:
            if line.startswith(b'<Material'):
                target = materials
            elif line.startswith((b'<Sources', b'<PML')):
                target = srcs_pml
            elif line.startswith(b'<Receivers'):
                target = rxs
            else:
                continue
            # Parse once and reuse the element for both the id and the name.
            node = ET.fromstring(line)
            target.append((int(node.text), node.attrib.get('name')))
    return materials, srcs_pml, rxs


def _vti_plane_dims(extent, spacing):
    """Return (rows, row_step, cols, col_step) for a model one cell thick
    along some axis.

    Raises:
        ValueError: If no axis has an extent stop of 1 (i.e. a 3D model).
    """
    x_stop, y_stop, z_stop = extent[1], extent[3], extent[5]
    dx, dy, dz = spacing[0], spacing[1], spacing[2]
    # Same precedence as the original sequential ifs: x-thin beats y-thin,
    # which beats z-thin.
    if x_stop == 1:
        return z_stop, dz, y_stop, dy
    if y_stop == 1:
        return z_stop, dz, x_stop, dx
    if z_stop == 1:
        return y_stop, dy, x_stop, dx
    raise ValueError('gprMax_model can only display 2D models (one axis must be a '
                     'single cell thick); got extent {}'.format(extent))


def gprMax_model(filename):
    """ Display gprMax model 2D geometry on screen.

    Reads the material / source-PML / receiver label tables from the .vti
    file's XML lines. Loads the 'Material', 'Sources_PML' and 'Receivers'
    cell arrays with VTK, relabels each cell with the index of the object it
    belongs to, and plots the result with a legend.

    Args:
        filename (string): Name of geometry file to load (.vti extension)
    Returns:
        None
    Raises:
        ValueError: If the model is not 2D.
    """
    materials, srcs_pml, rxs = _vti_read_labels(filename)

    reader = vtk.vtkXMLImageDataReader()
    reader.SetFileName(filename)
    reader.Update()
    vti = reader.GetOutput()

    rows, row_step, cols, col_step = _vti_plane_dims(vti.GetExtent(), vti.GetSpacing())

    def cell_array(name):
        # flatten() copies, so relabelling never writes into VTK-owned memory.
        return vtk_to_numpy(vti.GetCellData().GetArray(name)).flatten().reshape(rows, cols)

    material_ids = cell_array('Material')
    pml_tx_ids = cell_array('Sources_PML')
    rx_ids = cell_array('Receivers')

    # Relabel into a copy. Always matching against the untouched ID arrays
    # prevents a just-written object index from being mistaken for a later
    # material ID.
    domain = material_ids.copy()
    objects = []
    for ids, labels in ((material_ids, materials), (pml_tx_ids, srcs_pml), (rx_ids, rxs)):
        for label_id, name in labels:
            mask = ids == label_id
            if mask.any():
                objects.append(name)
                domain[mask] = objects.index(name)

    # Create the domain plot.
    im = plt.imshow(domain, cmap=plt.get_cmap('jet'), interpolation='nearest', origin='lower',
                    extent=[0, cols*col_step, 0, rows*row_step])

    # One legend patch per value actually present. Label each patch by the
    # object that value indexes (not by position), so labels stay aligned
    # with colours even when an object's index no longer appears in the plot.
    patches = []
    for entry in np.unique(domain.ravel()):
        idx = int(entry)
        label = objects[idx] if 0 <= idx < len(objects) else entry
        patches.append(mpatches.Patch(color=im.cmap(im.norm(entry)), label="{l}".format(l=label)))

    plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    ax = plt.gca()
    ax.set_xlabel('Metres [m]')
    ax.set_ylabel('Metres [m]')

    plt.show()
    
    
def merge_files(basefilename, removefiles=False):
    """Merges traces (A-scans) from multiple output files into one new file,
        then optionally removes the series of output files.

    The series is ``<basefilename>1.out`` ... ``<basefilename>N.out``. The
    result is written to ``<basefilename>_merged.out`` with one column per
    model run in each receiver dataset. The number of runs found is printed.

    Args:
        basefilename (string): Base name of output file series including path.
        removefiles (boolean): Flag to remove individual output files after merge.
    """
    import glob
    import os

    import h5py
    from gprMax._version import __version__

    outputfile = basefilename + '_merged.out'

    # Count only files named exactly <base><digits>.out. The merged file, or
    # unrelated files that merely share the prefix (e.g. <base>_old.out),
    # are not model runs. glob.escape keeps [ ] * ? in the base name literal.
    stem = os.path.basename(basefilename)
    outputfiles = [
        name for name in glob.glob(glob.escape(basefilename) + '*.out')
        if os.path.basename(name)[len(stem):-len('.out')].isdigit()
    ]
    modelruns = len(outputfiles)
    print(modelruns)

    # Context managers close every HDF5 file even if copying fails midway.
    with h5py.File(outputfile, 'w') as fout:
        for model in range(modelruns):
            with h5py.File(basefilename + str(model + 1) + '.out', 'r') as fin:
                nrx = fin.attrs['nrx']

                # The first run supplies the file attributes and the dataset
                # layout: (Iterations, modelruns) per receiver output.
                if model == 0:
                    fout.attrs['Title'] = fin.attrs['Title']
                    fout.attrs['gprMax'] = __version__
                    fout.attrs['Iterations'] = fin.attrs['Iterations']
                    fout.attrs['dt'] = fin.attrs['dt']
                    fout.attrs['nrx'] = fin.attrs['nrx']
                    for rx in range(1, nrx + 1):
                        path = '/rxs/rx' + str(rx)
                        grp = fout.create_group(path)
                        for output in fin[path].keys():
                            grp.create_dataset(output, (fout.attrs['Iterations'], modelruns),
                                               dtype=fin[path + '/' + output].dtype)

                # Copy this run's trace into column `model` of every output.
                for rx in range(1, nrx + 1):
                    path = '/rxs/rx' + str(rx) + '/'
                    for output in fin[path].keys():
                        fout[path + '/' + output][:, model] = fin[path + '/' + output][:]

    if removefiles:
        for model in range(modelruns):
            os.remove(basefilename + str(model + 1) + '.out')

    
def _dzt_datetime_bytes(when):
    """Return the 4-byte GSSI DZT date/time header field for ``when``.

    Bit fields are joined from the most significant bit: year-1980 (7 bits),
    month (4), day (5), hour (5), minute (6), seconds/2 (5). The bytes are
    then packed with the same '>r32<' bitstruct format the header used
    before.
    """
    import bitstruct
    from bitstring import Bits

    fields = [
        Bits(uint=when.year - 1980, length=7),
        Bits(uint=when.month, length=4),
        Bits(uint=when.day, length=5),
        Bits(uint=when.hour, length=5),
        Bits(uint=when.minute, length=6),
        # The 5-bit seconds field holds seconds/2 (0-29). Storing raw seconds
        # clamped to 29 mis-encoded every timestamp.
        Bits(uint=when.second // 2, length=5),
    ]
    return bitstruct.pack('>r32<', Bits().join(fields).tobytes())


def gprMax_to_dzt(filename, rx, rxcomponent, centerFreq, distTx_Rx, trace_step):
    """ Convert gprMax output files to .dzt format.

    Reads ``<filename>_merged.out`` and scales the chosen receiver component
    to the signed 16-bit range. Each trace is resampled to 1024 samples, and
    ``<filename>.dzt`` is written with its header followed by the trace data
    from byte 1024.

    Args:
        filename (string): Base name of the merged output (without the
            ``_merged.out`` suffix). Also used for the .dzt filename and,
            truncated/padded to 12 characters, the header file-name field.
        rx (integer): Receiver number
        rxcomponent (string): Field component
        centerFreq (float): Center frequency (not currently written to the file)
        distTx_Rx (float): Distance between Tx and Rx (not currently written to the file)
        trace_step (float): Step size between traces; the header's scans-per-metre is 1 / trace_step
    Returns:
        None
    """
    import datetime

    import bitstruct
    from scipy import signal

    # ------------------------------ Read gprMax data ------------------------------
    with h5py.File(f'{filename}_merged.out', 'r') as file:
        data = np.array(file[f'/rxs/rx{rx}/{rxcomponent}'])    # samples x traces
        time_step = file.attrs['dt'] * 10**9                    # s -> ns

    noSamples, noTraces = np.shape(data)
    time_window = time_step * noSamples                         # ns

    # Scale to the signed 16-bit range. An all-zero record stays zero instead
    # of being divided by zero into NaNs.
    peak = np.max(np.abs(data))
    if peak > 0:
        data = (data * 32767) / peak

    # Resample each trace to 1024 samples, THEN round and clip. FFT
    # resampling can overshoot the input peak, and anything outside
    # [-32768, 32767] would wrap when offset and cast to unsigned 16-bit.
    data = signal.resample(data, 1024)
    data = np.clip(np.round(data), -32768, 32767)

    # ------------------------------- DZT file header ------------------------------
    tag = 255                                      # 0x00ff if header, 0xfnff for old file Header
    dataOffset = 1024                              # Constant 1024
    noSamples = np.shape(data)[0]                  # Number of samples (after resampling)
    bits = 16                                      # Bits per data word (8 or 16)
    binaryOffset = 32768                           # Binary offset (8 bit -> 128, 16 bit -> 32768)
    spm = 1 / trace_step                           # Scans per metre
    noChannels = 1                                 # Number of channels
    epsr = 5                                       # Average dielectric constant
    vel = (299792458 / np.sqrt(epsr)) * 10 ** -9   # Wave speed (m/ns)
    range0 = vel * (time_window / 2)               # Range (meters)
    xFinish = noTraces*trace_step-trace_step       # X finish coordinate
    dataType = 2                                   # Data type
    antennaName = 'antName'[:14].ljust(14)         # Fixed 14-character field
    fName = filename[:12].ljust(12)                # Fixed 12-character field
    dateField = _dzt_datetime_bytes(datetime.datetime.now())   # created == modified

    # (bitstruct format, field size in bytes, value), in on-disk order.
    before_dates = [
        ('u16<', 2, tag),
        ('u16<', 2, dataOffset),
        ('u16<', 2, noSamples),
        ('u16<', 2, bits),
        ('s16<', 2, binaryOffset),
        ('f32<', 4, 0),              # Scans per second
        ('f32<', 4, spm),
        ('f32<', 4, 0),              # Meters per mark
        ('f32<', 4, 0),              # Position (ns)
        ('f32<', 4, time_window),    # Time window (ns)
        ('u16<', 2, 0),              # Number of passes for 2D files
    ]
    after_dates = [
        ('u16<', 2, 0),              # Offset to range gain function
        ('u16<', 2, 0),              # Size of range gain function
        ('u16<', 2, 0),              # Offset to text
        ('u16<', 2, 0),              # Size of text
        ('u16<', 2, 0),              # Offset to processing history
        ('u16<', 2, 0),              # Size of processing history
        ('u16<', 2, noChannels),
        ('f32<', 4, epsr),
        ('f32<', 4, 0),              # Top position (m)
        ('f32<', 4, range0),
        ('f32<', 4, 0),              # X start coordinate
        ('f32<', 4, xFinish),
        ('f32<', 4, 0),              # Gain servo level
        ('r24<', 3, 0),              # Reserved
        ('u8<', 1, 0),               # Antenna configuration
        ('u16<', 2, 0),              # Setup configuration
        ('u16<', 2, 0),              # Scans per pass
        ('u16<', 2, 0),              # Line number
        ('f32<', 4, 0),              # Y start coordinate
        ('f32<', 4, 0),              # Y finish coordinate
        ('u8<', 1, 0),               # Line order
        ('r8<', 1, dataType),
    ]

    def write_fields(fid, fields):
        # Pack each value into a buffer of exactly its field size, then write it.
        for fmt, size, value in fields:
            buf = bytearray(size)
            bitstruct.pack_into(fmt, buf, 0, value)
            fid.write(buf)

    # ------------------------------- Write the file -------------------------------
    with open(filename + '.dzt', 'wb') as fid:
        write_fields(fid, before_dates)
        fid.write(dateField)                       # Date and time created
        fid.write(dateField)                       # Date and time modified
        write_fields(fid, after_dates)
        fid.write(bitstruct.pack('t14<', antennaName))
        write_fields(fid, [('u16<', 2, 0)])        # Channel mask
        fid.write(bitstruct.pack('t12<', fName))
        write_fields(fid, [('u16<', 2, 0)])        # Header checksum (not computed)

        # Trace data starts at dataOffset. Shift to unsigned and write each
        # trace's samples contiguously (hence the transpose).
        fid.seek(dataOffset)
        fid.write((data + binaryOffset).T.astype('<H').tobytes())

    print('Dzt file has been written!')
    