In [None]:
import glob
import os
from datetime import datetime
from pathlib import Path

import laspy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import confusion_matrix
from torch_geometric.data import Data, InMemoryDataset

In [None]:
def read_las(pointcloudfile, get_attributes=False, useevery=1):
    """
    Read a las/laz point cloud file and return its coordinates, optionally with all attributes.

    :param pointcloudfile: specification of input file (format: las or laz)
    :param get_attributes: if True, will return all attributes in file, otherwise will only return XYZ (default is False)
    :param useevery: value specifies every n-th point to use from input, i.e. simple subsampling (default is 1, i.e. returning every point)
    :return: if get_attributes is falsy, an array with one row per point and the columns x, y, z
             (subsampled by 'useevery'); otherwise a tuple (coords, attributes), where attributes is a
             dict mapping each dimension name of the file's point format to its values, subsampled
             the same way as coords
    """
    # Read file
    las = laspy.read(pointcloudfile)

    # Stack x, y, z as columns, then keep every n-th point
    coords = np.vstack((las.x, las.y, las.z)).transpose()
    coords = coords[::useevery, :]

    # Coordinates only
    if not get_attributes:
        return coords

    # Coordinates and attributes: every dimension of the point format is
    # collected, including the X, Y, Z dimensions themselves.
    attributes = {
        dimension.name: las.points[dimension.name][::useevery]
        for dimension in las.points.point_format.dimensions
    }
    return coords, attributes

In [None]:
class PointCloudsInFiles(InMemoryDataset):
    """Point cloud dataset where one data point is a file."""

    def __init__(
        self, root_dir, glob="*", column_name="", max_points=200_000, use_columns=None, augment=None
    ):
        """
        Args:
            root_dir (string): Directory with the datasets
            glob (string): Glob string passed to pathlib.Path.glob
            column_name (string): Column name to use as target variable (e.g. "Classification")
            max_points (int): Number of points every sample is resampled to
            use_columns (list[string]): Column names to add as additional input
            augment: Augmentation setting, stored as False when None is given.
                It is only stored for now; no augmentation is applied yet.
        """
        self.files = list(Path(root_dir).glob(glob))
        self.column_name = column_name
        self.max_points = max_points
        self.use_columns = [] if use_columns is None else use_columns
        self.augment = False if augment is None else augment
        super().__init__()

    def __len__(self):
        """Return the number of samples, which is the number of matched files."""
        # TODO: multiply by the number of augmentations once augmentation is
        # implemented; until then the length is the same with or without it.
        return len(self.files)

    def __getitem__(self, idx):
        """
        Load the file at position idx and turn it into a Data sample.

        Args:
            idx (int or tensor): Position in the list of files

        Returns:
            Data with x (the use_columns attributes, or the original coordinates
            if no use_columns were given), y (the unique target values among the
            sampled points) and pos (the centralized coordinates), each with
            max_points sampled points; None if the file has fewer than 100 points.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Read las/laz file
        coords, attrs = read_las(str(self.files[idx]), get_attributes=True)

        # Reject small files before doing any work. This also keeps an empty
        # file from failing in np.random.choice, which cannot sample from 0 points.
        n_points = coords.shape[0]
        if n_points < 100:
            return None

        # Resample to exactly max_points; sampling with replacement is only
        # needed when the file holds fewer points than requested.
        use_idx = np.random.choice(
            n_points, self.max_points, replace=n_points < self.max_points
        )

        # Get x values
        if len(self.use_columns) > 0:
            x = np.empty((self.max_points, len(self.use_columns)), np.float32)
            for eix, entry in enumerate(self.use_columns):
                x[:, eix] = attrs[entry][use_idx]
        else:
            # NOTE(review): taken before centralizing, so x holds the original
            # coordinates while pos holds the centralized ones - confirm intended.
            x = coords[use_idx, :]

        # Centralize coordinates around the mean of all points in the file
        coords = coords - np.mean(coords, axis=0)

        # Impute missing target values with the mean of the remaining ones
        target = attrs[self.column_name]
        missing = np.isnan(target)
        if missing.any():
            target[missing] = np.nanmean(target)

        # Transform data to tensor
        return Data(
            x=torch.from_numpy(x).float(),
            # np.unique flattens its input, so y is a 1D tensor of the distinct labels
            y=torch.from_numpy(np.unique(np.array(target[use_idx]))).type(torch.LongTensor),
            pos=torch.from_numpy(coords[use_idx, :]).float(),
        )