# Filtering lidar point clouds #

To run this notebook you need several packages installed. One way to install them is through Anaconda, with the following commands:

`conda env create -n jack -f environment.yml`

`conda activate jack`

This notebook was designed to filter lidar point clouds to a desired pulse density. It works through some functions defined in the first code cell, and then applies these to a test file in the following code cells. The process involves:
- reading the data into a recArray
- removing broken pulses from the data (pulses whose returns are not all present in the data)
- creating point and pulse density raster images from the point clouds
- spatially filtering pulses by random sampling within 10 m grid cells to achieve the desired pulse density
- creating new LAZ files of the filtered data

The first code cell imports the required libraries and functions.

In [None]:
import os
import sys
import laspy
import pynninterp
import numpy
import matplotlib.pyplot as plt
from numba import jit
from osgeo import gdal
from osgeo import osr
from scipy import ndimage
from scipy import interpolate
from pyproj import crs


def las2rec(infile):
    """
    Read a LAS/LAZ file into a numpy recArray with one record per return.

    Fields (numpy format): RETURN_NUMBER (u1), NUMBER_OF_RETURNS (u1),
    TIMESTAMP (<f8), INTENSITY (<i4), CLASSIFICATION (u1), X, Y, Z (<f8).

    Records are sorted by TIMESTAMP, with ties broken by RETURN_NUMBER and
    then by the remaining fields in dtype order. Returns that share a
    timestamp therefore sit next to each other in return-number order.

    Returns
    -------
    (recArray, header) where header is the laspy header of the input file.
    """
    with laspy.open(infile) as reader:
        las = reader.read()

    # (field name, numpy format, source values)
    # NOTE(review): `return_num` / `num_returns` are laspy 1.x attribute names;
    # laspy 2.x calls them `return_number` / `number_of_returns`. Confirm the
    # installed laspy resolves the old names.
    columns = [
        ('RETURN_NUMBER', 'u1', las.return_num),
        ('NUMBER_OF_RETURNS', 'u1', las.num_returns),
        ('TIMESTAMP', '<f8', las.gps_time),
        ('INTENSITY', '<i4', las.intensity),
        ('CLASSIFICATION', 'u1', las.classification),
        ('X', '<f8', las.x),
        ('Y', '<f8', las.y),
        ('Z', '<f8', las.z),
    ]
    recordType = numpy.dtype([(name, fmt) for name, fmt, _ in columns])
    records = numpy.rec.fromarrays([values for _, _, values in columns],
                                   dtype=recordType)

    # Naming RETURN_NUMBER explicitly documents the tie-break that numpy would
    # otherwise apply implicitly (unspecified fields are compared in dtype order).
    records.sort(order=['TIMESTAMP', 'RETURN_NUMBER'])

    return records, las.header


def removeBrokenPulses(data):
    """
    Remove broken pulses (pulses missing some of their returns) from a recArray.

    For every first return (RETURN_NUMBER == 1) with NUMBER_OF_RETURNS == n,
    the pulse is kept only if the n records starting at that first return
    all exist, all share the first return's TIMESTAMP, and carry RETURN_NUMBERs
    exactly 1, 2, ..., n. All other records are dropped.

    The input is assumed to be sorted by TIMESTAMP and then RETURN_NUMBER (as
    las2rec produces), so the returns of a pulse are contiguous.

    Parameters
    ----------
    data : recArray with RETURN_NUMBER, NUMBER_OF_RETURNS and TIMESTAMP fields.

    Returns
    -------
    recArray containing only the returns of complete pulses, in input order.
    """
    returnNumber = data['RETURN_NUMBER']
    numberOfReturns = data['NUMBER_OF_RETURNS']
    timestamp = data['TIMESTAMP']
    nRecords = returnNumber.size
    keep = numpy.zeros(nRecords, dtype=bool)

    for start in numpy.flatnonzero(returnNumber == 1):
        n = int(numberOfReturns[start])
        stop = start + n
        if n < 1 or stop > nRecords:
            # Invalid return count, or the pulse runs past the end of the data.
            continue
        # Compare element-wise rather than via sums. A sum of return numbers
        # would accept e.g. (1, 1, 4) as "1..3", and comparing summed float
        # timestamps against t * n is sensitive to rounding.
        sameTime = numpy.all(timestamp[start:stop] == timestamp[start])
        completeReturns = numpy.array_equal(returnNumber[start:stop],
                                            numpy.arange(1, n + 1))
        if sameTime and completeReturns:
            keep[start:stop] = True

    return data[keep]


def filterByPulse(points, density, pixelSize=10):
    """
    Randomly thin a point cloud to a fixed number of pulses per grid cell.

    Pulses are binned into pixelSize x pixelSize cells by the X/Y of their
    first return (RETURN_NUMBER == 1). In every cell holding at least
    `density` pulses, exactly `density` pulses are drawn with
    numpy.random.choice (without replacement), and all returns of those pulses
    are kept. Cells holding fewer than `density` pulses are dropped entirely,
    presumably so the output has a uniform pulse density.

    With the default 10 m cells, `density` is pulses per 100 m^2, so
    300 = 3 pulses per m^2.

    Parameters
    ----------
    points : recArray
        Needs RETURN_NUMBER, NUMBER_OF_RETURNS, X and Y fields. Must be sorted
        by TIMESTAMP/RETURN_NUMBER with broken pulses removed (las2rec +
        removeBrokenPulses), because the later returns of a pulse are taken to
        be the records directly following its first return.
    density : int
        Number of pulses to keep per cell.
    pixelSize : float, optional
        Cell size in map units (default 10).

    Returns
    -------
    recArray of the selected pulses' returns, in input order.
    """
    returnNumber = points['RETURN_NUMBER']
    numberOfReturns = points['NUMBER_OF_RETURNS']
    firstReturns = numpy.flatnonzero(returnNumber == 1)
    if firstReturns.size == 0:
        return points[:0].copy()

    # Grid origin: left edge at floor(min X), moved one cell further left when
    # min X is a whole number; top edge at ceil(max Y); rows count downwards.
    # The extent is taken from all returns, not only first returns.
    x, y = points['X'], points['Y']
    minX = numpy.min(x)
    originX = numpy.floor(minX)
    if originX == minX:
        originX -= pixelSize
    originY = numpy.ceil(numpy.max(y))
    col = ((x[firstReturns] - originX) / pixelSize).astype(numpy.int64)
    row = ((originY - y[firstReturns]) / pixelSize).astype(numpy.int64)

    # Group pulses by cell in one sort instead of scanning every point once per
    # cell. The stable sort on a row-major cell id keeps each cell's pulses in
    # ascending input order, and cells are visited row by row.
    cellId = row * (int(col.max()) + 1) + col
    order = numpy.argsort(cellId, kind='stable')
    pulsesByCell = firstReturns[order]
    _, cellStarts = numpy.unique(cellId[order], return_index=True)

    firstReturnsToKeep = numpy.zeros(returnNumber.size, dtype=bool)
    for cellPulses in numpy.split(pulsesByCell, cellStarts[1:]):
        # '>=' so that a cell holding exactly `density` pulses is kept.
        if cellPulses.size >= density:
            selection = numpy.random.choice(cellPulses, size=density, replace=False)
            firstReturnsToKeep[selection] = True

    # Extend the selection to the later returns of each kept pulse, which sit
    # at the indices directly after its first return.
    pointsToKeep = firstReturnsToKeep.copy()
    keptFirst = numpy.flatnonzero(firstReturnsToKeep)
    if keptFirst.size:
        keptCounts = numberOfReturns[keptFirst]
        for offset in range(1, int(keptCounts.max())):
            pointsToKeep[keptFirst[keptCounts > offset] + offset] = True

    return points[pointsToKeep]
    

def gridDensity(points, epsg, outfile, imageType, pixelSize=10):
    """
    Create a raster grid of point or pulse counts and write it as a GeoTIFF.

    Parameters
    ----------
    points : recArray
        Returns with X, Y, Z and RETURN_NUMBER fields (e.g. from las2rec).
    epsg : int
        EPSG code written as the raster's spatial reference.
    outfile : str
        Path of the output GeoTIFF.
    imageType : str
        "pointcount" counts every return per cell; "pulsecount" counts only
        first returns (RETURN_NUMBER == 1), i.e. one per pulse.
    pixelSize : float, optional
        Cell size in map units (default 10).

    Raises
    ------
    ValueError
        If imageType is neither "pointcount" nor "pulsecount".
    """
    if imageType not in ("pointcount", "pulsecount"):
        raise ValueError("imageType must be either 'pointcount' or "
                         "'pulsecount', got %r" % (imageType,))

    (x, y, z) = (points['X'], points['Y'], points['Z'])
    minX, maxX, minY, maxY, minZ, maxZ = get_mmXYZ(x, y, z)

    # Grid extent from all returns. A whole-number min X / min Y moves that
    # edge out by one cell; for Y this gives a point lying exactly on the
    # bottom edge a row to fall in.
    originX = numpy.floor(minX)
    if originX == minX:
        originX -= pixelSize
    bottomY = numpy.floor(minY)
    if bottomY == minY:
        bottomY -= pixelSize
    originY = numpy.ceil(maxY)
    nRows = int(numpy.ceil((originY - bottomY) / pixelSize))
    nCols = int(numpy.ceil((numpy.ceil(maxX) - originX) / pixelSize))

    if imageType == "pulsecount":
        isFirstReturn = points['RETURN_NUMBER'] == 1
        x, y = x[isFirstReturn], y[isFirstReturn]
    (row, col) = xyToRowCol(x, y, int(originX), int(originY), pixelSize)

    # A point on the right-hand cell boundary at max X can map to
    # col == nCols, so grow the grid until every index is inside it.
    # pointGridding is numba-compiled, and numba does not bounds-check array
    # indexing by default, so an out-of-range index would write past the array.
    if row.size:
        nRows = max(nRows, int(row.max()) + 1)
        nCols = max(nCols, int(col.max()) + 1)

    # NOTE(review): counts are uint16 (they wrap above 65535 per cell), and 9999
    # is tagged as nodata although empty cells hold 0. Confirm both are intended.
    counts = numpy.zeros((nRows, nCols), dtype=numpy.uint16)
    pointGridding(counts, row, col)
    writeImage(counts, outfile, driver='GTiff', tlx=int(originX),
               tly=int(originY), binsize=pixelSize, epsg=epsg,
               nullVal=9999)


def writeImage(image, outfile, driver='GTiff', tlx=0.0, tly=0.0, binsize=0.0,
               epsg=None, nullVal=None):
    """
    Write a 2-D (rows, cols) or 3-D (bands, rows, cols) array to a raster
    file with GDAL.

    Parameters
    ----------
    image : numpy.ndarray
        Pixel data of dtype uint8, int16, uint16, uint32, int32, float32 or
        float64.
    outfile : str
        Output path.
    driver : str, optional
        GDAL driver short name (default 'GTiff'). The dataset is created with
        the 'COMPRESS=LZW' creation option.
    tlx, tly : float, optional
        Map coordinates of the top-left corner of the raster.
    binsize : float, optional
        Pixel size; the geotransform uses -binsize so rows run downwards.
    epsg : int, optional
        EPSG code for the spatial reference; none is set when omitted.
    nullVal : number, optional
        Nodata value set on every band when given.

    Raises
    ------
    ValueError
        If the array is not 2-D/3-D, its dtype is unsupported, or the driver
        name is unknown.
    RuntimeError
        If GDAL fails to create the output dataset.
    """
    gdalTypes = {
        'uint8': gdal.GDT_Byte,
        'int16': gdal.GDT_Int16,
        'uint16': gdal.GDT_UInt16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float32': gdal.GDT_Float32,
        'float64': gdal.GDT_Float64,
    }

    # Treat a 2-D image as a single-band stack so both cases share one path.
    if image.ndim == 2:
        bands = image[numpy.newaxis, :, :]
    elif image.ndim == 3:
        bands = image
    else:
        raise ValueError('image must be 2-D or 3-D, got shape %s' % (image.shape,))
    dtypeName = image.dtype.name
    if dtypeName not in gdalTypes:
        raise ValueError('unsupported image dtype: %s' % dtypeName)

    gdalDriver = gdal.GetDriverByName(driver)
    if gdalDriver is None:
        raise ValueError('unknown GDAL driver: %s' % driver)
    nz, ny, nx = bands.shape
    ds = gdalDriver.Create(outfile, nx, ny, nz, gdalTypes[dtypeName], ['COMPRESS=LZW'])
    if ds is None:
        raise RuntimeError('GDAL could not create %s' % outfile)
    ds.SetGeoTransform([tlx, binsize, 0, tly, 0, -binsize])

    if epsg is not None:
        proj = osr.SpatialReference()
        proj.ImportFromEPSG(epsg)
        ds.SetProjection(proj.ExportToWkt())

    for i in range(nz):
        band = ds.GetRasterBand(i + 1)
        band.WriteArray(bands[i], 0, 0)
        if nullVal is not None:
            band.SetNoDataValue(nullVal)

    ds.FlushCache()
    # Dereference band and dataset so GDAL closes the file now.
    band = None
    ds = None
    

def get_mmXYZ(x,y,z):
    """
    Return the extents of three coordinate arrays as a flat tuple:
    (minX, maxX, minY, maxY, minZ, maxZ).
    """
    extents = []
    for values in (x, y, z):
        extents.extend((numpy.min(values), numpy.max(values)))
    return tuple(extents)

        
def xyToRowCol(x, y, xMin, yMax, pixSize):
    """
    Map x/y coordinate arrays to (row, col) indices of a regular grid.

    The grid's left edge is floor(xMin) and its top edge is ceil(yMax); rows
    increase downwards. Indices are truncated to uint32.

    Returns
    -------
    (row, col) : tuple of uint32 arrays
    """
    leftEdge = numpy.floor(xMin)
    topEdge = numpy.ceil(yMax)
    colIndex = ((x - leftEdge) / pixSize).astype(numpy.uint32)
    rowIndex = ((topEdge - y) / pixSize).astype(numpy.uint32)
    return rowIndex, colIndex


def rec2las(outlas, data, header, epsg):
    """
    Write a recArray of returns (fields as produced by las2rec) to `outlas`.

    The point format and file version come from `header` (e.g. the header
    returned by las2rec). Its scales (and offsets, see note) are copied, and a
    CRS built from `epsg` with pyproj is added to the new header.
    """
    las = laspy.create(point_format=header.point_format, file_version=header.version)
    newHeader = las.header
    newHeader.scales = header.scales
    # NOTE(review): laspy 2.x exposes header offsets as `offsets`. Confirm that
    # `offset` is honoured by the installed laspy, otherwise the input offsets
    # may not be carried over.
    newHeader.offset = header.offset
    newHeader.add_crs(crs.CRS.from_user_input(epsg))

    # laspy attribute <- recArray field, assigned in this order.
    fieldMap = (
        ('return_num', 'RETURN_NUMBER'),
        ('num_returns', 'NUMBER_OF_RETURNS'),
        ('gps_time', 'TIMESTAMP'),
        ('intensity', 'INTENSITY'),
        ('classification', 'CLASSIFICATION'),
        ('x', 'X'),
        ('y', 'Y'),
        ('z', 'Z'),
    )
    for attribute, field in fieldMap:
        setattr(las, attribute, data[field])
    las.write(outlas)
    

@jit
def pointGridding(grid, row, col):
    """
    Add 1 to grid[r, c] for every (r, c) pair taken from `row` and `col`.

    Modifies `grid` in place and returns nothing. Indices are not range-checked
    here, so callers must keep every pair inside grid's shape.
    """
    for r, c in zip(row, col):
        grid[r, c] += 1


### Example 1 ###

The following cell demonstrates how to read a LAZ file, grid the point and pulse densities, remove the broken pulses, and output a filtered LAZ file.

In [None]:
# Example 1: read the test tile and grid its point and pulse counts. Then remove
# broken pulses, grid the counts again and write the cleaned LAZ file.
lasFile = r'test.laz'
outLas = r'test_fixed.laz'
epsg = 28356
# (output filename suffix, gridDensity imageType)
densityRasters = (('_pointcount.tif', 'pointcount'),
                  ('_pulsecount.tif', 'pulsecount'))

(data, header) = las2rec(lasFile)
print('Number of returns %i'%data['X'].size)
for suffix, imageType in densityRasters:
    gridDensity(data, epsg, lasFile.replace('.laz', suffix), imageType)

data = removeBrokenPulses(data)
print('Returns after removing broken pulses %i'%data['X'].size)
for suffix, imageType in densityRasters:
    gridDensity(data, epsg, outLas.replace('.laz', suffix), imageType)
rec2las(outLas, data, header, epsg)

# Re-read the written file so data/header reflect what is on disk.
(data, header) = las2rec(outLas)

### Example 2 ###

The following cell demonstrates how to filter a point cloud to a desired pulse density. 

In [None]:
# Example 2: thin the cleaned tile from Example 1 to 190 pulses per 100 m^2
# (per 10 m cell), write it out, and grid its point and pulse counts.
# 190 is just under the minimum density for this dataset.
lasFile = r'test_fixed.laz'
outLas = r'test_fixed_190.laz'
epsg = 28356
targetPulsesPerCell = 190

(data, header) = las2rec(lasFile)
print('Number of returns %i'%data['X'].size)

data = filterByPulse(data, targetPulsesPerCell)
print('Number of returns %i'%data['X'].size)
rec2las(outLas, data, header, epsg)
for suffix, imageType in (('_pointcount.tif', 'pointcount'),
                          ('_pulsecount.tif', 'pulsecount')):
    gridDensity(data, epsg, outLas.replace('.laz', suffix), imageType)