diff --git a/tardis_pytorch/dist_pytorch/datasets/patches.py b/tardis_pytorch/dist_pytorch/datasets/patches.py index f5feeead..f49e6eba 100644 --- a/tardis_pytorch/dist_pytorch/datasets/patches.py +++ b/tardis_pytorch/dist_pytorch/datasets/patches.py @@ -8,7 +8,7 @@ # MIT License 2021 - 2023 # ####################################################################### -from typing import Optional, Tuple, Union +from typing import Tuple, Union import numpy as np import torch @@ -21,30 +21,23 @@ class PatchDataSet: """ BUILD PATCHED DATASET + Main change in v0.1.0RC3 + - Build moved to 3D patches + Class for computing optimal patch size for a maximum number of points per patch. The optimal size of the patch is determined by max number of points. It works by first - checking if 'init_patch_size' != 0, if True then 'init_patch_size' is set as - max size that contain the whole point cloud. Then class optimizes the size of - 'init_patch_size' by sequentially dropping it by a set 'drop_rate' value. This - action is continue till 'init_patch_size' is < 0 or, class found 'init_patch_size' - where all computed patches have a number of points below the threshold. - - Patches are computed building voxel 2D/3D grid of the size given by 'init_patch_size' + calculating boundary box which is used to build 3D voxels. Voxel size is initiate + and reduced to optimize voxel sizes fo point cloud can be cut fo patches with + specified 'max_number_of_points'. In the end, patches with a smaller number of points are marge with their neighbor in a way that will respect 'max_number_of_points' policy. Output is given as a list of arrays as torch.Tensor or np.ndarray. Args: - label_cls (np.ndarray, None): Optional class id array for each point in the - point cloud. - rgb (np.ndarray, None): Optional RGB feature array for each point in the point - cloud. - patch_3d (bool): If True, compute patches in 3D. If False, patches are - computed in 2D and if coord (N, 3), Z dimension is np.inf. max_number_of_points (int): Maximum allowed a number of points per patch. - init_patch_size (float): Initial patch size. If 0, the initial patch size - is taken as the highest value from the computed boundary box. + voxel_size (float): Initial voxel size from which optimization starts. + overlap (float): Percentage of overlapping voxel size drop_rate (float): Optimizer step size for reducing the size of patches. graph (bool): If True output computed graph for each patch of point cloud. tensor (bool): If True output all datasets as torch.Tensor. @@ -52,332 +45,250 @@ class PatchDataSet: def __init__( self, - label_cls=None, - rgb=None, - patch_3d=False, max_number_of_points=500, - init_patch_size=0, - drop_rate=1, + voxel_size=15, overlap=0.15, + drop_rate=0.1, graph=True, tensor=True, ): - # Global data setting - self.label_cls = label_cls - self.rgb = rgb - self.segments_id = None - self.coord = None - # Point cloud down-sampling setting self.DOWNSAMPLING_TH = max_number_of_points # Patch setting + self.voxel = voxel_size + self.drop_rate = drop_rate self.TORCH_OUTPUT = tensor self.GRAPH_OUTPUT = graph - self.PATCH_3D = patch_3d - self.init_patch_size = init_patch_size - self.INIT_PATCH_SIZE = init_patch_size - self.expand = 0.025 - self.EXPAND = 0.025 # Expand boundary box by 2.5% - self.stride = overlap + self.EXPAND = 0.1 # Expand boundary box by 10% self.STRIDE = overlap # Create 15% overlaps between patches - if init_patch_size == 0: - self.SIZE_EXPAND = self.EXPAND - self.PATCH_STRIDE = self.STRIDE - else: - self.SIZE_EXPAND = init_patch_size * self.EXPAND - self.PATCH_STRIDE = init_patch_size * self.STRIDE + # Initialization + self.INIT_PATCH_SIZE = np.zeros((2, 3)) - self.drop_rate = drop_rate - self.DROP_RATE = drop_rate - - def _init_parameters(self): - # Patch setting - self.INIT_PATCH_SIZE = self.init_patch_size - if self.init_patch_size == 0: - self.EXPAND = self.expand - self.STRIDE = self.stride - self.SIZE_EXPAND = self.EXPAND - self.PATCH_STRIDE = self.STRIDE - else: - self.SIZE_EXPAND = self.init_patch_size * self.expand - self.PATCH_STRIDE = self.init_patch_size * self.stride - - def _boundary_box(self) -> np.ndarray: + def boundary_box(self, coord) -> np.ndarray: """ Utile class function to compute boundary box in 2D or 3D Returns: np.ndarray: Boundary box dimensions """ - box_dim = self.coord.shape[1] - # Define x,y and (z) min and max sizes - if box_dim in [2, 3]: - min_x = np.min(self.coord[:, 0]) - self.SIZE_EXPAND - max_x = np.max(self.coord[:, 0]) + self.SIZE_EXPAND - - min_y = np.min(self.coord[:, 1]) - self.SIZE_EXPAND - max_y = np.max(self.coord[:, 1]) + self.SIZE_EXPAND - if box_dim == 3 and np.min(self.coord[:, 2]) != 0: - min_z = np.min(self.coord[:, 2]) - self.SIZE_EXPAND - max_z = np.max(self.coord[:, 2]) + self.SIZE_EXPAND + if coord.shape[1] == 3: + min_x, min_y, min_z = np.min(coord, axis=0) + max_x, max_y, max_z = np.max(coord, axis=0) else: + min_x, min_y = np.min(coord, axis=0) + max_x, max_y = np.max(coord, axis=0) min_z, max_z = 0, 0 - return np.array([(min_x, min_y, min_z), (max_x, max_y, max_z)]) + dx = ((min_x + max_x) / 2) - min_x + min_x, max_x = min_x - dx * self.EXPAND, max_x + dx * self.EXPAND - def _collect_patch_idx(self, patches: np.ndarray) -> Tuple[list, list]: - """ - Utile class function to compute patch metrics. + dy = ((min_y + max_y) / 2) - min_y + min_y, max_y = min_y - dy * self.EXPAND, max_y + dy * self.EXPAND - Args: - patches (np.ndarray): List of all patch centers. + dz = ((min_z + max_z) / 2) - min_z + min_z, max_z = min_z - dz * self.EXPAND, max_z + dz * self.EXPAND - Returns: - Tuple[list, list]: ID's list of all patches with more than 1 point - and list of point numbers in each patch. - """ - not_empty_patch = [ - i - for i, patch in enumerate(patches) - if len(self._points_in_patch(patch_center=patch)) > 1 - ] - points_no = [ - self.coord[self._points_in_patch(patch_center=patch)].shape[0] - for patch in patches - if len(self._points_in_patch(patch_center=patch)) > 1 - ] - - return not_empty_patch, points_no + return np.array([(min_x, min_y, min_z), (max_x, max_y, max_z)]) @staticmethod - def _normalize_idx(coord_with_idx: np.ndarray) -> np.ndarray: + def center_patch(bbox, voxel_size=1) -> np.ndarray: """ - Utile class function to replace ids with ordered output ID values for - each point in patches. In other words, it produces a standardized ID for each - point so it can be identified with the source. + Creates a regular grid within a bounding box. Args: - coord_with_idx (np.ndarray): Coordinate id value i. - Returns: - np.ndarray: An array all points in a patch with corrected ID value. - """ - unique_idx, inverse_idx = np.unique(coord_with_idx[:, 0], return_inverse=True) - norm_idx = np.arange(len(unique_idx)) - - for i, id_ in enumerate(unique_idx): - mask = coord_with_idx[:, 0] == id_ - coord_with_idx[:, 0][mask] = norm_idx[inverse_idx[mask]] - return coord_with_idx - - def _output_format(self, data: np.ndarray) -> Union[np.ndarray, torch.Tensor]: - """ - Utile class function to output array in the correct format (numpy or tensor). - - Args: - data (np.ndarray): Input data for format change. + bbox: list or tuple of 6 floats representing the bounding box as (xmin, ymin, zmin, xmax, ymax, zmax) + voxel_size: float representing the size of each voxel Returns: - np.ndarray: Array in file format specified by self.torch_output. + np.ndarray: Of shape (N, 3) representing the center coordinates of each voxel """ - if self.TORCH_OUTPUT: - data = torch.from_numpy(data).type(torch.float32) - - return data - - def _patch_centers(self, boundary_box: np.ndarray) -> np.ndarray: - """ - Utile class function to compute patches given stored patch size and - boundary box to output center coordinate for all possible overlapping - patches. + # Calculate the number of voxels along each axis + n_x = int(np.ceil((bbox[1, 0] - bbox[0, 0]) / voxel_size)) + n_y = int(np.ceil((bbox[1, 1] - bbox[0, 1]) / voxel_size)) + n_z = int(np.ceil((bbox[1, 2] - bbox[0, 2]) / voxel_size)) // 2 + + # Calculate the coordinates of the voxel centers + if n_x < 2: + n_x = 2 + x = np.linspace( + bbox[0, 0] - voxel_size / 2, bbox[1, 0] + voxel_size / 2, n_x + ) - Args: - boundary_box (np.ndarray): Computer point cloud boundary box. + if n_y < 2: + n_y = 2 + y = np.linspace( + bbox[0, 1] - voxel_size / 2, bbox[1, 1] + voxel_size / 2, n_y + ) - Returns: - np.ndarray: Array with XYZ coordinates to localize patch centers. - """ - patch = [] - patch_positions_x, patch_positions_y = [], [] - bb_min, bb_max = boundary_box[0], boundary_box[1] + if n_z < 2: + n_z = 2 + z = np.linspace( + bbox[0, 2] - voxel_size / 2, bbox[1, 2] + voxel_size / 2, n_z + ) - if len(bb_min) == 3: - z_mean = bb_max[2] / 2 - else: - z_mean = 0 - - # Find X positions for patches - x_pos = bb_min[0] + (self.INIT_PATCH_SIZE / 2) - patch_positions_x.append(x_pos) - - while bb_max[0] > x_pos: - x_pos = x_pos + self.INIT_PATCH_SIZE - self.PATCH_STRIDE - patch_positions_x.append(x_pos) - - # Find Y positions for patch - y_pos = bb_min[1] + (self.INIT_PATCH_SIZE / 2) - patch_positions_y.append(y_pos) - while bb_max[1] > y_pos: - y_pos = y_pos + self.INIT_PATCH_SIZE - self.PATCH_STRIDE - patch_positions_y.append(y_pos) - - # Bind X and Y patch positions - patch_positions_x = patch_positions_x[::2] - patch_positions_y = patch_positions_y[::2] - - # Find Z position for patch - if not self.PATCH_3D: # Get 3D patches. Z position is center of bb - for i in patch_positions_x: - patch.append( - np.vstack( - ( - [i] * len(patch_positions_y), - patch_positions_y, - [z_mean] * len(patch_positions_y), - ) - ).T - ) - else: # Get 3D patches. Z position is computed as X and Y position - patch_positions_z = [] - - z_pos = bb_min[2] + (self.INIT_PATCH_SIZE / 2) - patch_positions_z.append(z_pos) - - while bb_max[2] > z_pos: - z_pos = z_pos + self.INIT_PATCH_SIZE - self.PATCH_STRIDE - patch_positions_z.append(z_pos) - - for i in patch_positions_x: - for j in patch_positions_z: - patch.append( - np.vstack( - ( - [i] * len(patch_positions_y), - patch_positions_y, - [j] * len(patch_positions_y), - ) - ).T - ) + xv, yv, zv = np.meshgrid(x, y, z, indexing="ij") + voxel_centers = np.column_stack((xv.flatten(), yv.flatten(), zv.flatten())) - return np.vstack(patch) + return voxel_centers - def _points_in_patch(self, patch_center: np.ndarray) -> tuple: + def points_in_patch(self, coord: np.ndarray, patch_center: np.ndarray) -> bool: """ Utile class function for filtering point cloud and output only point in patch. Args: + coord (np.ndarray): 3D coordinate array. patch_center (np.ndarray): Array (1, 3) for the given patch center. Returns: tuple(bool): Array of all points that are enclosed in the given patch. """ - patch_size = self.INIT_PATCH_SIZE + self.PATCH_STRIDE + patch_size_x = self.INIT_PATCH_SIZE[0] + patch_size_y = self.INIT_PATCH_SIZE[1] + patch_size_z = self.INIT_PATCH_SIZE[2] + + # Bounding box for patch center + patch_x, patch_y, patch_z = patch_center + patch_min_x, patch_max_x = patch_x - patch_size_x, patch_x + patch_size_x + patch_min_y, patch_max_y = patch_y - patch_size_y, patch_y + patch_size_y + patch_min_z, patch_max_z = patch_z - patch_size_z, patch_z + patch_size_z coord_idx = ( - (self.coord[:, 0] <= (patch_center[0] + patch_size)) - & (self.coord[:, 0] >= (patch_center[0] - patch_size)) - & (self.coord[:, 1] <= (patch_center[1] + patch_size)) - & (self.coord[:, 1] >= (patch_center[1] - patch_size)) + (coord[:, 0] <= patch_max_x) + & (coord[:, 0] >= patch_min_x) + & (coord[:, 1] <= patch_max_y) + & (coord[:, 1] >= patch_min_y) + & (coord[:, 2] <= patch_max_z) + & (coord[:, 2] >= patch_min_z) ) return coord_idx - def optimize_patch_size(self) -> Union[Tuple[list, list], Tuple[np.ndarray, list]]: + def optimal_patches(self, coord: np.ndarray) -> list[bool]: """ Main class function to compute optimal patch size. - The function takes init stored variable and iteratively searches for patch size + The function takes init stored variable and iteratively searches for voxel size small enough that allow for all patches to have an equal or less max number of points. - """ - """ Initial check for patches """ - b_box = self._boundary_box() - - if self.coord.shape[0] <= self.DOWNSAMPLING_TH: - patch_coord_x = b_box[1][0] - ((abs(b_box[0][0]) + abs(b_box[1][0])) / 2) - patch_coord_y = b_box[1][1] - ((abs(b_box[0][1]) + abs(b_box[1][1])) / 2) - if b_box.shape[1] == 3: - patch_coord_z = b_box[1][2] - ( - (abs(b_box[0][2]) + abs(b_box[1][2])) / 2 - ) - patches_coord = [patch_coord_x, patch_coord_y, patch_coord_z] - else: - patches_coord = [patch_coord_x, patch_coord_y] + Args: + coord (np.ndarray): List of coordinates for voxelize + """ + bbox = self.boundary_box(coord) + voxel = self.voxel + self.drop_rate + all_patch = [] - patch_idx = [0] + """ Find points index in patches """ + th = 1 + while th != 0: + all_patch = [] - return patches_coord, patch_idx + """Initialize search with new voxel size""" + voxel = round(voxel - self.drop_rate, 1) - # Initial patronization with self.INIT_PATCH_SIZE - if self.INIT_PATCH_SIZE == 0: - self.INIT_PATCH_SIZE = np.max(b_box) - patch_size = self.INIT_PATCH_SIZE - self.PATCH_STRIDE = patch_size * self.STRIDE + patch_grid = self.center_patch(bbox=bbox, voxel_size=voxel) + if len(patch_grid) < 2: + continue - patches_coord = self._patch_centers(boundary_box=b_box) - patch_idx, piv = self._collect_patch_idx(patches=patches_coord) + x_pos = np.sort(np.unique(patch_grid[:, 0])) + if len(x_pos) > 1: + x_pos = (x_pos[1] - x_pos[0]) / 2 + else: + x_pos = bbox[[1, 0]] - x_pos[0] - # Optimize patch size based on no_point threshold - break_if = 0 + y_pos = np.sort(np.unique(patch_grid[:, 1])) + if len(y_pos) > 1: + y_pos = (y_pos[1] - y_pos[0]) / 2 + else: + y_pos = bbox[1, 1] - y_pos[0] - drop_rate = self.DROP_RATE - while not all(i <= self.DOWNSAMPLING_TH for i in piv): - self.INIT_PATCH_SIZE = self.INIT_PATCH_SIZE - self.DROP_RATE + z_pos = np.sort(np.unique(patch_grid[:, 2])) + if len(z_pos) > 1: + z_pos = (z_pos[1] - z_pos[0]) / 2 + else: + z_pos = bbox[1, 2] - z_pos[0] - if self.INIT_PATCH_SIZE <= 0: - break_if += 1 + self.INIT_PATCH_SIZE = [ + x_pos + (x_pos * self.STRIDE), + y_pos + (y_pos * self.STRIDE), + z_pos + (z_pos * self.STRIDE), + ] - self.DROP_RATE = drop_rate / 2 - self.INIT_PATCH_SIZE = patch_size - self.DROP_RATE + """Find points in each patch""" + for patch in patch_grid: + point_idx = self.points_in_patch(coord=coord, patch_center=patch) + all_patch.append(point_idx) - if break_if == 3: - print( - "Could not find valid patch size, prediction of full point cloud!" - ) - return [patches_coord[0]], [patch_idx[0]] + """ Combine smaller patches with threshold limit """ + new_patch = [] + while len(all_patch) > 0: + df = all_patch[0] - self.SIZE_EXPAND = self.INIT_PATCH_SIZE * self.EXPAND - self.PATCH_STRIDE = self.INIT_PATCH_SIZE * self.STRIDE + if np.sum(df) >= self.DOWNSAMPLING_TH: + new_patch.append(df) + all_patch.pop(0) + else: + while np.sum(df) <= self.DOWNSAMPLING_TH: + if len(all_patch) == 1: + break + if np.sum(df) + np.sum(all_patch[1]) > self.DOWNSAMPLING_TH: + break + df += all_patch[1] + all_patch.pop(1) + new_patch.append(df) + all_patch.pop(0) + all_patch = new_patch - patches_coord = self._patch_centers(boundary_box=self._boundary_box()) - patch_idx, piv = self._collect_patch_idx(patches=patches_coord) + all_patch = [patch for patch in all_patch if np.sum(patch) > 0] + th = sum([True for p in all_patch if np.sum(p) > self.DOWNSAMPLING_TH]) - return patches_coord, patch_idx + return all_patch - def patched_dataset( - self, coord: np.ndarray, mesh=False, dist_th: Optional[float] = None - ) -> Union[Tuple[list, list, list, list, list], Tuple[list, list, list, list]]: + @staticmethod + def normalize_idx(coord_with_idx: np.ndarray) -> np.ndarray: """ - Main function for processing dataset and return patches. + Utile class function to replace ids with ordered output ID values for + each point in patches. In other words, it produces a standardized ID for each + point, so it can be identified with the source. Args: - coord (np.ndarray): 2D or 3D array of the point cloud. - mesh (boolean): If True, build a graph for meshes, not filaments. - dist_th (float): Distance threshold for graph from meshes. - + coord_with_idx (np.ndarray): Coordinate id value i. Returns: - list[np.ndarray or torch.Tensor]: - List of arrays (N, 2) or (N, 3) with coordinates of points per patch + np.ndarray: An array all points in a patch with corrected ID value. + """ + unique_idx, inverse_idx = np.unique(coord_with_idx[:, 0], return_inverse=True) + norm_idx = np.arange(len(unique_idx)) - List of an array (N, 3) with RGB value for each point peer patch + for _, id_ in enumerate(unique_idx): + mask = coord_with_idx[:, 0] == id_ + coord_with_idx[:, 0][mask] = norm_idx[inverse_idx[mask]] + return coord_with_idx - An optional list of all computed graphs from each coord_patch + def output_format(self, data: np.ndarray) -> Union[np.ndarray, torch.Tensor]: + """ + Utile class function to output array in the correct format (numpy or tensor) - List of array (N, 1) with ordered ID value for each point per patch. - The ordered ID value allows reconstructing point cloud from patches + Args: + data (np.ndarray): Input data for format change. - List of an array (N, 3) with classes id for each point peer patch + Returns: + np.ndarray: Array in file format specified by self.torch_output. """ + if self.TORCH_OUTPUT: + data = torch.from_numpy(data).type(torch.float32) + return data + + def patched_dataset( + self, coord: np.ndarray, label_cls=None, rgb=None, mesh=False + ) -> Union[Tuple[list, list, list, list, list], Tuple[list, list, list, list]]: coord_patch = [] graph_patch = [] output_idx = [] - self._init_parameters() - if self.GRAPH_OUTPUT: if coord.shape[1] not in [3, 4]: TardisError( @@ -386,11 +297,12 @@ def patched_dataset( "If graph True, coord must by of shape" f"[Dim x X x Y x (Z)], but is: {coord.shape}", ) - self.segments_id = coord - self.coord = coord[:, 1:] + segmented_coord = coord + coord = coord[:, 1:] graph_builder = BuildGraph(mesh=mesh) else: + graph_builder = None if coord.shape[1] not in [2, 3]: TardisError( "113", @@ -398,158 +310,98 @@ def patched_dataset( "If graph False, coord must by of shape" f"[X x Y x (Z)], but is: {coord.shape}", ) - self.segments_id = None - self.coord = coord - - # if mesh: - # if dist_th is None: - # TardisError('124', - # 'tardis_pytorch/dist_pytorch/datasets/patches.py', - # 'If mesh, dist_th cannot be None!') - - # Check if point cloud is smaller than max allowed point - if self.coord.shape[0] <= self.DOWNSAMPLING_TH: - """Transform 2D coord to 3D of shape [Z, Y, X]""" - if self.coord.shape[1] == 2: - coord_ds = np.vstack( + segmented_coord = None + if coord.shape[1] == 2: + """Transform 2D coord to 3D of shape [X, Y, Z]""" + coord = np.vstack( ( - self.coord[:, 0], - self.coord[:, 1], - np.zeros((self.coord.shape[0],)), + coord[:, 0], + coord[:, 1], + np.zeros((coord.shape[0],)), ) ).T else: - coord_ds = self.coord - coord_ds = [True for _ in list(range(0, coord_ds.shape[0], 1))] + coord = coord + + if coord.shape[0] <= self.DOWNSAMPLING_TH: + coord_ds = [True for _ in list(range(0, coord.shape[0], 1))] """ Build point cloud for each patch """ - coord_patch.append(self._output_format(self.coord[coord_ds, :])) + coord_patch.append(self.output_format(coord[coord_ds, :])) """ Optionally - Build graph for each patch """ if self.GRAPH_OUTPUT: - coord_label = self.segments_id[coord_ds, :] - coord_label = self._normalize_idx(coord_label) + coord_label = segmented_coord[coord_ds, :] + coord_label = self.normalize_idx(coord_label) - if mesh: - graph_patch.append( - self._output_format( - graph_builder(coord=coord_label, dist_th=dist_th) - ) - ) - else: - graph_patch.append( - self._output_format(graph_builder(coord=coord_label)) - ) + graph_patch.append(self.output_format(graph_builder(coord=coord_label))) """ Build output index for each patch """ output_idx.append(np.where(coord_ds)[0]) """ Build class label index for each patch """ - if self.label_cls is not None: - cls_patch = [self.label_cls] + if label_cls is not None: + cls_patch = [label_cls] else: - cls_patch = [self._output_format(np.zeros((1, 1)))] + cls_patch = [self.output_format(np.zeros((1, 1)))] """ Build rgb node label index for each patch """ - if self.rgb is not None: - rgb_patch = [self.rgb] + if rgb is not None: + rgb_patch = [rgb] else: - rgb_patch = [self._output_format(np.zeros((1, 1)))] + rgb_patch = [self.output_format(np.zeros((1, 1)))] else: # Build patches for PC with max num. of point per patch - """Find optimal patch centers""" - patches_centers, patches_idx = self.optimize_patch_size() - - all_patch = [] cls_patch = [] rgb_patch = [] - """ Find all patches """ - for i in patches_idx: - all_patch.append(self._points_in_patch(patches_centers[i])) - - """ Combine smaller patches with threshold limit """ - new_patch = [] - while len(all_patch) > 0: - df = all_patch[0] - - if df.sum() >= self.DOWNSAMPLING_TH: - new_patch.append(df) - all_patch.pop(0) - else: - while df.sum() <= self.DOWNSAMPLING_TH: - if len(all_patch) == 1: - break - if df.sum() + all_patch[1].sum() > self.DOWNSAMPLING_TH: - break - df += all_patch[1] - all_patch.pop(1) - new_patch.append(df) - all_patch.pop(0) - - all_patch = new_patch + """ Find points index in patches """ + all_patch = self.optimal_patches(coord=coord) - """ Build patches """ + """Build embedded feature per patch""" for i in all_patch: """Find points and optional images for each patch""" df_patch_keep = i - df_patch = self.coord[df_patch_keep, :] + df_patch = coord[df_patch_keep, :] output_df = np.where(df_patch_keep)[0] - - # Transform 2D coord to 3D of shape [Z, Y, X] - if df_patch.shape[1] == 2: - coord_ds = np.vstack( - (np.zeros((df_patch.shape[0],)), df_patch[:, 1], df_patch[:, 0]) - ).T - else: - coord_ds = df_patch - - coord_ds = [True for _ in list(range(0, coord_ds.shape[0], 1))] + coord_ds = [True for _ in list(range(0, df_patch.shape[0], 1))] """ Build point cloud for each patch """ - coord_patch.append(self._output_format(df_patch[coord_ds, :])) + coord_patch.append(self.output_format(df_patch[coord_ds, :])) """ Optionally - Build graph for each patch """ if self.GRAPH_OUTPUT: - segment_patch = self.segments_id[df_patch_keep, :] - segment_patch = self._normalize_idx(segment_patch[coord_ds, :]) - - if mesh: - graph_patch.append( - self._output_format( - graph_builder(coord=segment_patch, dist_th=dist_th) - ) - ) - else: - graph_patch.append( - self._output_format(graph_builder(coord=segment_patch)) - ) + segment_patch = segmented_coord[df_patch_keep, :] + segment_patch = self.normalize_idx(segment_patch[coord_ds, :]) + + graph_patch.append( + self.output_format(graph_builder(coord=segment_patch)) + ) """ Build output index for each patch """ output_idx.append(output_df[coord_ds]) """ Build class label index for each patch """ - if self.label_cls is not None: - cls_df = self.label_cls[df_patch_keep] + if label_cls is not None: + cls_df = label_cls[df_patch_keep] cls_new = np.zeros((cls_df.shape[0], 200)) else: cls_df = [0] cls_new = np.zeros((1, 200)) - for id, j in enumerate(cls_df): + for id_, j in enumerate(cls_df): df = np.zeros((1, 200)) df[0, int(j)] = 1 - cls_new[id, :] = df - - cls_patch.append(self._output_format(cls_new)) + cls_new[id_, :] = df + cls_patch.append(self.output_format(cls_new)) """ Build rbg node label index for each patch """ - if self.rgb is not None: - rgb_df = self.rgb[df_patch_keep] + if rgb is not None: + rgb_df = rgb[df_patch_keep] else: rgb_df = np.zeros((1, 1)) - rgb_patch.append(self._output_format(rgb_df)) + rgb_patch.append(self.output_format(rgb_df)) if self.GRAPH_OUTPUT: return coord_patch, rgb_patch, graph_patch, output_idx, cls_patch diff --git a/tardis_pytorch/dist_pytorch/utils/utils.py b/tardis_pytorch/dist_pytorch/utils/utils.py index 59f84d8a..678748cc 100644 --- a/tardis_pytorch/dist_pytorch/utils/utils.py +++ b/tardis_pytorch/dist_pytorch/utils/utils.py @@ -90,7 +90,7 @@ def pc_median_dist(pc: np.ndarray, avg_over=False, box_size=0.15) -> float: return 1.0 knn_df = [] - for id, i in enumerate(pc): + for id, _ in enumerate(pc): knn, _ = tree.query(pc[id].reshape(1, -1), k=4) knn_df.append(knn[0][1]) diff --git a/tardis_pytorch/utils/metrics.py b/tardis_pytorch/utils/metrics.py index 6b2e42fd..7a14da59 100644 --- a/tardis_pytorch/utils/metrics.py +++ b/tardis_pytorch/utils/metrics.py @@ -15,6 +15,7 @@ from sklearn.metrics import auc, average_precision_score, roc_curve +# AUPR not AUC!!!! def compare_dict_metrics(last_best_dict: dict, new_dict: dict) -> bool: """ Compares two metric dictionaries and returns the one with the highest @@ -299,8 +300,10 @@ def mcov( if weight: mCov = 0 else: - mCov = [] + mCov = 0 + unique_target = np.unique(targets[:, 0]) + G = len(unique_target) unique_input = np.unique(input[:, 0]) if eval: @@ -336,16 +339,16 @@ def mcov( if weight: mCov += w * 1.0 else: - mCov.append(1.0) + mCov += 1.0 else: if weight: mCov += w * 1.0 else: - mCov.append(df) # Pick max IoU for GT instance + mCov += df # Pick max IoU for GT instance if weight: - return mCov / len(unique_target) - return np.mean(mCov) + return mCov + return np.mean(mCov) / G def mwcov( @@ -388,12 +391,12 @@ def mwcov( df = np.max(df) if df > 1.0: mwCov += w * 1.0 - mCov.append(1.0) + mCov += 1.0 else: mCov += w * 1.0 - mCov.append(df) # Pick max IoU for GT instance + mCov += 1.0 # Pick max IoU for GT instance - return mCov, mwCov / len(unique_target) + return mCov / len(unique_target), mwCov def confusion_matrix( diff --git a/tardis_pytorch/utils/predictor.py b/tardis_pytorch/utils/predictor.py index e9b985b1..19d7d297 100644 --- a/tardis_pytorch/utils/predictor.py +++ b/tardis_pytorch/utils/predictor.py @@ -819,7 +819,9 @@ def __call__(self, *args, **kwargs): if self.predict in ["Filament", "Microtubule"]: np.savetxt( - join(self.am_output, f"{i[:-self.in_format]}_Segments_filter.csv"), + join( + self.am_output, f"{i[:-self.in_format]}_Segments_filter.csv" + ), self.filter_splines(segments=self.segments), delimiter=",", )