In [1]:
# Implementation dgcnn.pytorch
# Experiment 12: Testing
# 07.11.2022
# Read txt for description

In [2]:
# Model: 3 input streams (RGB, SWIR, geo) processed separately, then concatenated after the first block
# OA = ?% (overall accuracy; not filled in yet)

# Define model
# from model.py
# COPY THIS BLOCK FOR TESTING !!

import os
import sys
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


def knn(x, k):
    """Return the indices of the k nearest neighbours of every point.

    Args:
        x: tensor of shape (batch_size, num_dims, num_points); points are columns.
        k: number of neighbours to return (a point is its own nearest neighbour).

    Returns:
        Long tensor of shape (batch_size, num_points, k).
    """
    gram = torch.matmul(x.transpose(2, 1), x)          # (B, N, N) inner products
    sq_norm = (x ** 2).sum(dim=1, keepdim=True)        # (B, 1, N) squared lengths
    # -||xi - xj||^2 = 2<xi, xj> - ||xj||^2 - ||xi||^2 ; larger means closer.
    neg_sq_dist = 2 * gram - sq_norm - sq_norm.transpose(2, 1)
    _, neighbours = neg_sq_dist.topk(k=k, dim=-1)
    return neighbours


def get_graph_feature_rgb(x, k=20, idx=None, dim9=False):
    """Build edge features from the RGB channels (6:9) over a kNN graph.

    Args:
        x: tensor viewed as (batch_size, C, num_points); channels follow the
            XYZ, normXYZ, RGB, SWIR, geo layout used in this notebook.
        k: number of neighbours per point.
        idx: optional precomputed (batch_size, num_points, k) neighbour indices.
            When None, neighbours are searched on channels 3: (normXYZ + RGB +
            SWIR + geo), or on channels 6: when ``dim9`` is true.
        dim9: selects the channel range used for the kNN search (see ``idx``).

    Returns:
        Tensor of shape (batch_size, 2*3, num_points, k) holding
        (neighbour - centre, centre) RGB pairs for every edge.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        # Use normXYZ + RGB + SWIR + geo for the kNN search
        idx = knn(x[:, 6:] if dim9 else x[:, 3:], k=k)   # (batch_size, num_points, k)

    # only use RGB for the graph feature
    x = x[:, 6:9, :]

    # Offset each sample's indices so they address rows of the flattened
    # (batch_size*num_points) tensor. Built on x's device instead of a
    # hard-coded 'cuda' so the function also works on CPU.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()   # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    return feature      # (batch_size, 2*num_dims, num_points, k)

def get_graph_feature_swir(x, k=20, idx=None, dim9=False):
    """Build edge features from the SWIR channels (9:153) over a kNN graph.

    Args:
        x: tensor viewed as (batch_size, C, num_points); channels follow the
            XYZ, normXYZ, RGB, SWIR, geo layout used in this notebook.
        k: number of neighbours per point.
        idx: optional precomputed (batch_size, num_points, k) neighbour indices.
            When None, neighbours are searched on channels 3: (normXYZ + RGB +
            SWIR + geo), or on channels 6: when ``dim9`` is true.
        dim9: selects the channel range used for the kNN search (see ``idx``).

    Returns:
        Tensor of shape (batch_size, 2*144, num_points, k) holding
        (neighbour - centre, centre) SWIR pairs for every edge.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        # Use normXYZ + RGB + SWIR + geo for the kNN search
        idx = knn(x[:, 6:] if dim9 else x[:, 3:], k=k)   # (batch_size, num_points, k)

    # only use SWIR for the graph feature
    x = x[:, 9:153, :]

    # Offset each sample's indices so they address rows of the flattened
    # (batch_size*num_points) tensor. Built on x's device instead of a
    # hard-coded 'cuda' so the function also works on CPU.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()   # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    return feature      # (batch_size, 2*num_dims, num_points, k)

def get_graph_feature_geo(x, k=20, idx=None, dim9=False):
    """Build edge features from the geo channels (153:181) over a kNN graph.

    Args:
        x: tensor viewed as (batch_size, C, num_points); channels follow the
            XYZ, normXYZ, RGB, SWIR, geo layout used in this notebook.
        k: number of neighbours per point.
        idx: optional precomputed (batch_size, num_points, k) neighbour indices.
            When None, neighbours are searched on channels 3: (normXYZ + RGB +
            SWIR + geo), or on channels 6: when ``dim9`` is true.
        dim9: selects the channel range used for the kNN search (see ``idx``).

    Returns:
        Tensor of shape (batch_size, 2*28, num_points, k) holding
        (neighbour - centre, centre) geo pairs for every edge.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        # Use normXYZ + RGB + SWIR + geo for the kNN search
        idx = knn(x[:, 6:] if dim9 else x[:, 3:], k=k)   # (batch_size, num_points, k)

    # only use geo for the graph feature
    x = x[:, 153:181, :]

    # Offset each sample's indices so they address rows of the flattened
    # (batch_size*num_points) tensor. Built on x's device instead of a
    # hard-coded 'cuda' so the function also works on CPU.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()   # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    return feature      # (batch_size, 2*num_dims, num_points, k)

def get_graph_feature(x, k=20, idx=None, dim9=False):
    """Build edge features from all channels of ``x`` over a kNN graph.

    Args:
        x: tensor viewed as (batch_size, C, num_points).
        k: number of neighbours per point.
        idx: optional precomputed (batch_size, num_points, k) neighbour indices.
            When None, neighbours are searched on all channels, or on
            channels 6: when ``dim9`` is true.
        dim9: selects the channel range used for the kNN search (see ``idx``).

    Returns:
        Tensor of shape (batch_size, 2*C, num_points, k) holding
        (neighbour - centre, centre) feature pairs for every edge.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        idx = knn(x[:, 6:] if dim9 else x, k=k)   # (batch_size, num_points, k)

    # Offset each sample's indices so they address rows of the flattened
    # (batch_size*num_points) tensor. Built on x's device instead of a
    # hard-coded 'cuda' so the function also works on CPU.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()   # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    return feature      # (batch_size, 2*num_dims, num_points, k)

class DGCNN_semseg(nn.Module):
    """DGCNN semantic-segmentation network with separate RGB, SWIR and geo streams.

    ``forward`` takes x of shape (batch_size, C, num_points) whose channels are
    laid out as XYZ (0:3), normXYZ (3:6), RGB (6:9), SWIR (9:153) and geo
    (153:181), and returns per-point class scores of shape
    (batch_size, args_num_class, num_points).

    Each stream builds edge features from its own channels over one shared kNN
    graph (searched on channels 3:). The stream outputs (64 RGB + 4 SWIR-3D +
    64 SWIR + 64 geo = 196 channels) are concatenated and passed through three
    further edge-feature blocks whose max-pooled outputs feed a point-wise MLP.

    Hyper-parameters come from module-level globals (args_k, dim_rgb, dim_swir,
    dim_geo, args_dropout, args_num_class); ``args`` is only stored. Sub-module
    names must not change, otherwise saved state dicts no longer load.
    """

    def __init__(self, args):
        super(DGCNN_semseg, self).__init__()
        self.args = args
        self.k = args_k

        self.bn1_rgb = nn.BatchNorm2d(64)
        self.bn2_rgb = nn.BatchNorm2d(64)

        self.bn1_swir = nn.BatchNorm3d(4)
        self.bn2_swir = nn.BatchNorm3d(4)
        self.bn3_swir = nn.BatchNorm2d(64)
        self.bn4_swir = nn.BatchNorm2d(64)

        self.bn1_geo = nn.BatchNorm2d(64)
        self.bn2_geo = nn.BatchNorm2d(64)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(64)
        self.bn5 = nn.BatchNorm2d(64)
        self.bn6 = nn.BatchNorm2d(64)

        self.bn7 = nn.BatchNorm1d(512)
        self.bn8 = nn.BatchNorm1d(256)

        # RGB
        self.conv1_rgb = nn.Sequential(nn.Conv2d(dim_rgb*2, 64, kernel_size=1, bias=False),
                                   self.bn1_rgb,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2_rgb = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn2_rgb,
                                   nn.LeakyReLU(negative_slope=0.2))

        # SWIR: 3D convolutions slide a 32-wide kernel along the edge-feature axis
        self.conv1_swir = nn.Sequential(nn.Conv3d(1, 4, kernel_size=(32,1,1), bias=False),
                                   self.bn1_swir,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2_swir = nn.Sequential(nn.Conv3d(4, 4, kernel_size=(32,1,1), bias=False),
                                   self.bn2_swir,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3_swir = nn.Sequential(nn.Conv2d(dim_swir*2, 64, kernel_size=1, bias=False),
                                   self.bn3_swir,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4_swir = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn4_swir,
                                   nn.LeakyReLU(negative_slope=0.2))

        # geo
        self.conv1_geo = nn.Sequential(nn.Conv2d(dim_geo*2, 64, kernel_size=1, bias=False),
                                   self.bn1_geo,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2_geo = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn2_geo,
                                   nn.LeakyReLU(negative_slope=0.2))

        # all streams: 64 (RGB) + 4 (SWIR 3D) + 64 (SWIR) + 64 (geo) = 196 channels, doubled by the edge features
        self.conv1 = nn.Sequential(nn.Conv2d((64 + 4 + 64 + 64)*2, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv6 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn6,
                                   nn.LeakyReLU(negative_slope=0.2))

        self.conv7 = nn.Sequential(nn.Conv1d(192, 512, kernel_size=1, bias=False),
                                   self.bn7,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv8 = nn.Sequential(nn.Conv1d(512, 256, kernel_size=1, bias=False),
                                   self.bn8,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp1 = nn.Dropout(p=args_dropout)
        self.conv9 = nn.Conv1d(256, args_num_class, kernel_size=1, bias=False) # CHANGE TO NUMBER OF CLASSES

    def forward(self, x):
        batch_size = x.size(0)
        num_points = x.size(2)
        x = x.view(batch_size, -1, num_points)

        # The three stream helpers (with dim9=False) all search neighbours on
        # channels 3: of the same input, so run the kNN search once and share it.
        idx = knn(x[:, 3:], k=self.k)                                          # (batch_size, num_points, k)

        # RGB
        x_rgb = get_graph_feature_rgb(x, k=self.k, idx=idx, dim9=False)     # -> (batch_size, 2*dim_rgb, num_points, k)
        x_rgb = self.conv1_rgb(x_rgb)                                       # -> (batch_size, 64, num_points, k)
        x_rgb = self.conv2_rgb(x_rgb)                                       # -> (batch_size, 64, num_points, k)
        x_rgb = x_rgb.max(dim=-1, keepdim=False)[0]                         # -> (batch_size, 64, num_points)

        # SWIR: one set of edge features feeds both the 3D-conv and the 2D-conv branch,
        # so it is built once instead of twice.
        x_swir_graph = get_graph_feature_swir(x, k=self.k, idx=idx, dim9=False)   # -> (batch_size, 2*dim_swir, num_points, k)

        x_swir3d = self.conv1_swir(x_swir_graph.unsqueeze(1))   # -> (batch_size, 4, 2*dim_swir-31, num_points, k)
        x_swir3d = self.conv2_swir(x_swir3d)                    # -> (batch_size, 4, 2*dim_swir-62, num_points, k)
        x_swir3d = x_swir3d.max(dim=2, keepdim=False)[0]        # -> (batch_size, 4, num_points, k)
        x_swir3d = x_swir3d.max(dim=-1, keepdim=False)[0]       # -> (batch_size, 4, num_points)

        x_swir = self.conv3_swir(x_swir_graph)                  # -> (batch_size, 64, num_points, k)
        x_swir = self.conv4_swir(x_swir)                        # -> (batch_size, 64, num_points, k)
        x_swir = x_swir.max(dim=-1, keepdim=False)[0]           # -> (batch_size, 64, num_points)

        # geo
        x_geo = get_graph_feature_geo(x, k=self.k, idx=idx, dim9=False)     # -> (batch_size, 2*dim_geo, num_points, k)
        x_geo = self.conv1_geo(x_geo)                                       # -> (batch_size, 64, num_points, k)
        x_geo = self.conv2_geo(x_geo)                                       # -> (batch_size, 64, num_points, k)
        x_geo = x_geo.max(dim=-1, keepdim=False)[0]                         # -> (batch_size, 64, num_points)

        x_all_cat = torch.cat((x_rgb, x_swir3d, x_swir, x_geo), dim=1)      # -> (batch_size, 196, num_points)

        # Three edge-feature blocks on the fused features; each builds its own kNN graph.
        x_all_graph_0 = get_graph_feature(x_all_cat, k=self.k)              # -> (batch_size, 392, num_points, k)
        x_all_1 = self.conv1(x_all_graph_0)
        x_all_2 = self.conv2(x_all_1)
        x_all_2_max = x_all_2.max(dim=-1, keepdim=False)[0]                 # -> (batch_size, 64, num_points)

        x_all_graph_1 = get_graph_feature(x_all_2_max, k=self.k)
        x_all_3 = self.conv3(x_all_graph_1)
        x_all_4 = self.conv4(x_all_3)
        x_all_4_max = x_all_4.max(dim=-1, keepdim=False)[0]                 # -> (batch_size, 64, num_points)

        x_all_graph_2 = get_graph_feature(x_all_4_max, k=self.k)
        x_all_5 = self.conv5(x_all_graph_2)
        x_all_6 = self.conv6(x_all_5)
        x_all_6_max = x_all_6.max(dim=-1, keepdim=False)[0]                 # -> (batch_size, 64, num_points)

        x_cat = torch.cat((x_all_2_max, x_all_4_max, x_all_6_max), dim=1)   # -> (batch_size, 192, num_points)

        x_fc1 = self.conv7(x_cat)                       # -> (batch_size, 512, num_points)
        x_fc2 = self.conv8(x_fc1)                       # -> (batch_size, 256, num_points)
        x_dp = self.dp1(x_fc2)
        x_end = self.conv9(x_dp)                        # -> (batch_size, num_class, num_points)

        return x_end

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import numpy as np
import glob
import os
import sys

def room2blocks(data, label, num_point, block_size=1.0, stride=1.0,
                random_sample=False, sample_num=None, sample_aug=1, use_all_points=False,
                min_points=20):
    """ Cut a room point cloud into square XY blocks.

    Args:
        data: N x D numpy array; columns 0-2 are XYZ in meters (assumed shifted
            so the minimum point is near the origin, and axis-aligned); the
            remaining columns are per-point features.
        label: length-N numpy array of per-point labels.
        num_point: int, how many points to sample in each block (ignored when
            use_all_points is True)
        block_size: float, physical side length of the block in meters
        stride: float, stride for block sweeping (must be <= block_size)
        random_sample: bool, if True, randomly sample block origins in the room
        sample_num: int, if random sample, how many blocks to sample
            [default: room block count * sample_aug]
        sample_aug: if random sample, multiplier for the default sample_num
        use_all_points: bool, keep every point of a block instead of
            subsampling it to num_point
        min_points: int, blocks with fewer points are discarded [default: 20]
    Returns:
        (block_datas, block_labels).
        With use_all_points=True both are 1-D object arrays of length K whose
        items are the (n_i, D) / (n_i,) arrays of each block (blocks differ in
        size). Otherwise the per-block samples from sample_data_label are
        stacked along a new leading block axis.
    """
    assert (stride <= block_size)

    limit = np.amax(data, 0)[0:3]

    # Get the corner location for our sampling blocks
    xbeg_list = []
    ybeg_list = []
    if not random_sample:
        num_block_x = int(np.ceil((limit[0] - block_size) / stride)) + 1
        num_block_y = int(np.ceil((limit[1] - block_size) / stride)) + 1
        for i in range(num_block_x):
            for j in range(num_block_y):
                xbeg_list.append(i * stride)
                ybeg_list.append(j * stride)
    else:
        num_block_x = int(np.ceil(limit[0] / block_size))
        num_block_y = int(np.ceil(limit[1] / block_size))
        if sample_num is None:
            sample_num = num_block_x * num_block_y * sample_aug
        for _ in range(sample_num):
            xbeg = np.random.uniform(-block_size, limit[0])
            ybeg = np.random.uniform(-block_size, limit[1])
            xbeg_list.append(xbeg)
            ybeg_list.append(ybeg)

    # Collect blocks
    block_data_list = []
    block_label_list = []
    for xbeg, ybeg in zip(xbeg_list, ybeg_list):
        xcond = (data[:, 0] <= xbeg + block_size) & (data[:, 0] >= xbeg)
        ycond = (data[:, 1] <= ybeg + block_size) & (data[:, 1] >= ybeg)
        cond = xcond & ycond
        n_in_block = int(np.sum(cond))
        if n_in_block < min_points:  # discard sparse blocks
            if n_in_block > 0:
                print('Discard block!! Number of points = ', n_in_block)
            continue

        block_data = data[cond, :]
        block_label = label[cond]

        if use_all_points:
            block_data_list.append(block_data)
            block_label_list.append(block_label)
        else:
            # randomly subsample data
            # NOTE(review): sample_data_label is not defined in the visible cells — assumed provided elsewhere.
            block_data_sampled, block_label_sampled = \
                sample_data_label(block_data, block_label, num_point)
            block_data_list.append(np.expand_dims(block_data_sampled, 0))
            block_label_list.append(np.expand_dims(block_label_sampled, 0))

    if use_all_points:
        # Blocks have different point counts, so store them in an explicit 1-D
        # object array. np.array(list_of_ragged_arrays) is deprecated (an error
        # since NumPy 1.24) and, when all blocks happen to be the same size,
        # silently produces a 3-D array instead of a list of blocks.
        block_data_return = np.empty(len(block_data_list), dtype=object)
        block_label_return = np.empty(len(block_label_list), dtype=object)
        for i, (block_data, block_label) in enumerate(zip(block_data_list, block_label_list)):
            block_data_return[i] = block_data
            block_label_return[i] = block_label
    else:
        block_data_return, block_label_return = np.concatenate(block_data_list, 0), np.concatenate(block_label_list, 0)
    print('block_data_return_size:')
    print(block_data_return.shape)

    return block_data_return, block_label_return

In [4]:
import os
import sys
import glob
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset

class S3DISDataset_eval(Dataset):  # load data block by block, without using h5 files
    """Evaluation dataset that serves whole XY blocks of rooms.

    Every entry of ``data_root`` whose name contains 'Area_' is loaded with
    ``np.load``; the last column is the label and the remaining columns are
    XYZ followed by RGB, SWIR and geo features. Each room is cut into blocks
    with ``room2blocks``. ``split == 'train'`` keeps the rooms NOT in
    ``test_area``; any other split keeps only the rooms of ``test_area``.

    ``num_class`` is accepted but unused; ``num_thre`` is only stored.
    """

    def __init__(self, split='train', data_root='trainval_fullarea', num_point=4096, test_area='5', block_size=1.0, stride=1.0, num_class=20, use_all_points=False, num_thre = 1024):
        super().__init__()
        self.num_point = num_point
        self.block_size = block_size
        self.use_all_points = use_all_points
        self.stride = stride
        self.num_thre = num_thre
        rooms = sorted(os.listdir(data_root))
        rooms = [room for room in rooms if 'Area_' in room]
        # NOTE(review): substring match, so test_area='1' would also select 'Area_10...' files.
        if split == 'train':
            rooms_split = [room for room in rooms if not 'Area_{}'.format(test_area) in room]
        else:
            rooms_split = [room for room in rooms if 'Area_{}'.format(test_area) in room]
        self.room_points, self.room_labels = [], []
        self.room_coord_min, self.room_coord_max = [], []

        room_idxs = []
        for index, room_name in enumerate(rooms_split):
            room_path = os.path.join(data_root, room_name)
            room_data = np.load(room_path)
            points, labels = room_data[:, 0:-1], room_data[:, -1]  # label is assumed to be the last column
            coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]
            self.room_coord_min.append(coord_min)
            self.room_coord_max.append(coord_max)
            block_points, block_labels = room2blocks(points, labels, self.num_point, block_size=self.block_size,
                                                       stride=self.stride, random_sample=False, sample_num=None, use_all_points=self.use_all_points)
            room_idxs.extend([index] * int(block_points.shape[0]))  # one entry per block: block -> room index
            self.room_points.append(block_points)
            self.room_labels.append(block_labels)
        self.room_points = np.concatenate(self.room_points)
        self.room_labels = np.concatenate(self.room_labels)

        self.room_idxs = np.array(room_idxs)
        print("Totally {} samples in {} set.".format(len(self.room_idxs), split))

    def __getitem__(self, idx):  # get items in one block
        """Return (current_points, current_labels) for block ``idx``.

        current_points has data_dimension + 6 columns:
        XYZ (X and Y re-centred on the block mean), normXYZ (coordinate divided
        by the room's maximum coordinate), RGB / 255, SWIR, geo.
        """
        room_idx = self.room_idxs[idx]
        # Work on a copy: the stored block must stay untouched, otherwise every
        # further read of the same index would re-centre and re-scale it again.
        selected_points = self.room_points[idx].copy()   # num_point * XYZ RGB SWIR geo
        current_labels = self.room_labels[idx]   # num_point
        center = np.mean(selected_points, axis=0)
        N_points = selected_points.shape[0]
        coord_max = self.room_coord_max[room_idx]

        current_points = np.zeros((N_points, data_dimension + 6))  # data dimension + XYZ + normXYZ
        # add normalized XYZ (column 3,4,5)
        current_points[:, 3:6] = selected_points[:, 0:3] / coord_max[0:3]
        # recenter X and Y for each block
        selected_points[:, 0] -= center[0]
        selected_points[:, 1] -= center[1]
        # normalize RGB
        selected_points[:, 3:6] /= 255.0
        # SWIR already normalized
        # Column layout: XYZ, normXYZ, RGB, SWIR, geo
        current_points[:, 0:3] = selected_points[:, 0:3]            # XYZ
        current_points[:, 6:9] = selected_points[:, 3:6]            # RGB
        current_points[:, 9:153] = selected_points[:, 6:150]        # SWIR
        current_points[:, 153:181] = selected_points[:, 150:178]    # geo

        return current_points, current_labels

    def __len__(self):
        return len(self.room_idxs)

In [5]:
def calculate_sem_IoU(pred_np, seg_np, num_classes):  # num_classes: S3DIS 13
    """Per-class intersection-over-union accumulated over all blocks.

    Args:
        pred_np: sequence of predicted label arrays, one per block.
        seg_np: sequence of ground-truth label arrays, aligned with pred_np.
        num_classes: number of classes; labels are 0 .. num_classes-1.

    Returns:
        Float array of length num_classes with total intersection / total
        union per class (nan for a class that never appears).
    """
    intersection = np.zeros(num_classes)
    union = np.zeros(num_classes)
    for block_idx, truth in enumerate(seg_np):
        predicted = pred_np[block_idx]
        for cls in range(num_classes):
            in_pred = predicted == cls
            in_true = truth == cls
            intersection[cls] += np.sum(np.logical_and(in_pred, in_true))
            union[cls] += np.sum(np.logical_or(in_pred, in_true))
    return intersection / union

In [6]:
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
# from data import S3DISDataset
# from data import S3DISDataset_eval
# from model import DGCNN_semseg
import numpy as np
from torch.utils.data import DataLoader
# from util import cal_loss, IOStream
import sklearn.metrics as metrics
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def _predict_in_chunks(model, data, device, chunk_size):
    """Predict a class for every point of a (B, n, C) batch, at most chunk_size points per forward pass.

    Returns a (B, n) int64 numpy array. Splitting bounds GPU memory for dense
    blocks; each chunk is passed to the model on its own, so its kNN graph
    only contains points of that chunk.
    """
    batch_size, n_points = data.shape[0], data.shape[1]
    n_subblocks = int(np.ceil(n_points / chunk_size))
    if n_subblocks > 1:
        print('Too many points in the block. Split the block!!')
        print('N subblocks', n_subblocks)

    pred_np = np.empty((batch_size, n_points), dtype=np.int64)
    with torch.no_grad():  # inference only: do not keep activations for backprop
        for split, start in enumerate(range(0, n_points, chunk_size)):
            stop = min(start + chunk_size, n_points)
            if n_subblocks > 1:
                print('Working on subblock = ', split + 1)
            chunk = data[:, start:stop, :].float().to(device).permute(0, 2, 1)   # (B, C, n_chunk)
            seg_pred = model(chunk)                                               # (B, num_class, n_chunk)
            pred_np[:, start:stop] = seg_pred.max(dim=1)[1].cpu().numpy()
    return pred_np


def _format_prediction_rows(pts, pred, labels, coord_max):
    """Assemble the per-point output rows: x y z R G B pred gt.

    pts is one block of shape (n, C) in the XYZ, normXYZ, RGB, ... column
    layout built by S3DISDataset_eval; coord_max is that room's max XYZ.
    """
    # normXYZ * room max gives back the absolute coordinates
    xyz = pts[:, 3:6] * np.asarray(coord_max, dtype=np.float64)[:3]
    # RGB was divided by 255; round so '%d' does not truncate e.g. 254.9999 to 254
    rgb = np.rint(pts[:, 6:9] * 255.0)
    return np.column_stack((xyz, rgb, np.asarray(pred), np.asarray(labels)))


def test():
    """Evaluate the trained model on the test area and write per-point predictions.

    Configuration is read from module-level globals (args_*, test_area,
    path_dump_dir, path_model). Writes
    ``<path_dump_dir>/Area_<test_area>_pred_gt_<args_predict_name>.txt`` with one
    "x y z R G B pred gt" line per point and ``<path_dump_dir>/con_mat.txt``
    with the confusion matrix; prints the confusion matrix and overall accuracy.
    """
    DUMP_DIR = path_dump_dir
    th_subblock = 20000  # max points per forward pass; denser blocks are split (CUDA OOM otherwise)

    dataset = S3DISDataset_eval(split='test', data_root=args_data_dir, num_point=args_num_points, test_area=args_test_area,
                           block_size=args_block_size, stride=args_block_size, num_class=args_num_classes, num_thre=100, use_all_points=True)
    test_loader = DataLoader(dataset, batch_size=args_test_batch_size, shuffle=False, drop_last=False)
    room_idx = np.array(dataset.room_idxs)

    device = torch.device("cuda" if args_cuda else "cpu")

    # Try to load models
    if args_model == 'dgcnn':
        # DGCNN_semseg reads its hyper-parameters from globals; the argument is only stored.
        model = DGCNN_semseg(nn.Module).to(device)
    else:
        raise Exception("Not implemented")

    model = nn.DataParallel(model)
    # map_location lets a checkpoint saved on GPU be restored on the selected device.
    model.load_state_dict(torch.load(path_model, map_location=device))
    model = model.eval()

    print('model restored')

    os.makedirs(DUMP_DIR, exist_ok=True)
    # A single handle shared by all rooms of the area: opening one handle per
    # room on the same path would make them truncate and overwrite each other.
    out_data_label_filename = os.path.join(DUMP_DIR, 'Area_%s_pred_gt_%s.txt' % (test_area, args_predict_name))

    test_true_cls = []
    test_pred_cls = []

    print('Start testing ...')
    num_batch = 0
    with open(out_data_label_filename, 'w') as fout:
        for data, seg in tqdm(test_loader):
            batch_size = data.shape[0]
            data_np = data.numpy()   # (B, n, C)
            seg_np = seg.numpy()     # (B, n)
            pred_np = _predict_in_chunks(model, data, device, th_subblock)

            # write prediction results
            for batch_id in range(batch_size):
                room_id = room_idx[num_batch + batch_id]
                rows = _format_prediction_rows(data_np[batch_id], pred_np[batch_id], seg_np[batch_id],
                                               dataset.room_coord_max[room_id])
                np.savetxt(fout, rows, fmt='%f %f %f %d %d %d %d %d')  # xyzRGB pred gt

            test_true_cls.append(seg_np.reshape(-1))
            test_pred_cls.append(pred_np.reshape(-1))
            num_batch += batch_size
            if device.type == 'cuda':
                torch.cuda.empty_cache()

    test_true_cls = np.concatenate(test_true_cls)
    test_pred_cls = np.concatenate(test_pred_cls)

    # calculate confusion matrix
    conf_mat = metrics.confusion_matrix(test_true_cls, test_pred_cls)
    print('Confusion matrix:')
    print(conf_mat)
    np.savetxt(os.path.join(DUMP_DIR, 'con_mat.txt'), conf_mat)

    # calculate overall accuracy
    OA = metrics.accuracy_score(test_true_cls, test_pred_cls)
    print('Overall Accuracy')
    print(OA)

In [7]:
# Test configuration; values marked CHANGE are experiment-specific.
args_data_dir = 'data/lithonet_sem_seg_data_Experiment_12' # CHANGE
args_num_points = 4096
args_test_area = '2'
test_area = int(args_test_area)  # numeric alias derived from args_test_area so the two cannot diverge
args_block_size = 50 # CHANGE # CUDA out of memory for 100 m block size
args_num_classes = 10 # CHANGE
args_test_batch_size = 1
args_cuda = True
args_model = 'dgcnn'
args_k = 20
args_emb_dims = 1024
args_dropout = 0.5
args_num_class = args_num_classes  # alias read by DGCNN_semseg; must equal args_num_classes
args_predict_name = 'Experiment_12'

dim_rgb = 3
dim_swir = 144
dim_geo = 28
data_dimension = dim_rgb + dim_swir + dim_geo  # 175 feature columns: RGB (3) SWIR (144) geo (28)

path_dump_dir = 'predict_3DCNN'
path_model = 'model_3DCNN/Experiment_12_best_acc.t7'

if __name__ == '__main__':
    test()

Discard block!! Number of points =  19


  block_data_return, block_label_return = np.array(block_data_list), np.array(block_label_list)


block_data_return_size:
(198,)
Totally 198 samples in test set.
model restored
Start testing ...


 28%|██▊       | 55/198 [01:53<04:50,  2.03s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1


 28%|██▊       | 56/198 [01:57<06:09,  2.60s/it]

Working on subblock =  2


 31%|███       | 61/198 [02:15<07:41,  3.37s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 38%|███▊      | 75/198 [02:54<03:57,  1.93s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 39%|███▉      | 77/198 [03:02<05:51,  2.90s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 41%|████▏     | 82/198 [03:13<04:17,  2.22s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 44%|████▍     | 88/198 [03:30<04:18,  2.35s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 51%|█████     | 100/198 [03:58<02:35,  1.58s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 56%|█████▌    | 110/198 [04:24<02:44,  1.87s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 60%|██████    | 119/198 [04:43<02:10,  1.65s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 62%|██████▏   | 122/198 [04:51<02:55,  2.31s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1


 62%|██████▏   | 123/198 [04:55<03:24,  2.72s/it]

Working on subblock =  2


 64%|██████▎   | 126/198 [05:00<02:12,  1.84s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 66%|██████▌   | 131/198 [05:18<03:29,  3.12s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 67%|██████▋   | 132/198 [05:23<04:01,  3.66s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 69%|██████▊   | 136/198 [05:34<02:38,  2.56s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 69%|██████▉   | 137/198 [05:40<03:32,  3.48s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 70%|██████▉   | 138/198 [05:45<03:48,  3.81s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 71%|███████   | 141/198 [05:54<03:01,  3.18s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1


 72%|███████▏  | 142/198 [05:58<03:10,  3.40s/it]

Working on subblock =  2
Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 72%|███████▏  | 143/198 [06:04<03:42,  4.05s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 76%|███████▌  | 150/198 [06:18<01:04,  1.35s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 76%|███████▋  | 151/198 [06:24<02:06,  2.69s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 77%|███████▋  | 152/198 [06:28<02:28,  3.24s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 79%|███████▉  | 157/198 [06:47<02:24,  3.52s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 80%|███████▉  | 158/198 [06:52<02:37,  3.94s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 80%|████████  | 159/198 [06:57<02:52,  4.43s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 81%|████████▏ | 161/198 [07:05<02:37,  4.25s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1


 82%|████████▏ | 162/198 [07:09<02:30,  4.17s/it]

Working on subblock =  2


 84%|████████▍ | 166/198 [07:19<01:35,  2.97s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 84%|████████▍ | 167/198 [07:24<01:48,  3.50s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 85%|████████▌ | 169/198 [07:32<01:47,  3.72s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 88%|████████▊ | 174/198 [07:44<01:01,  2.57s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 88%|████████▊ | 175/198 [07:48<01:10,  3.06s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1
Working on subblock =  2


 93%|█████████▎| 185/198 [08:17<00:35,  2.75s/it]

Too many points in the block. Split the block!!
N subblocks 2
Working on subblock =  1


 94%|█████████▍| 186/198 [08:21<00:36,  3.08s/it]

Working on subblock =  2


100%|██████████| 198/198 [08:34<00:00,  2.60s/it]


Confusion matrix:
[[194579   3092     92   9067   2331  56913    311     37    512   1062]
 [  2877  17053      8    773     80     86      0      3    925    629]
 [  8966    356  90092   7571    780   6949     26   1103   1459   4735]
 [ 35751  25415  65596 583715   3183  41610   4375   3253   2276   3489]
 [ 10377    106   1163   7543   8216    875     24     43     45    221]
 [ 11990    230   1364   9155      7 599965  39457   3413    491   1780]
 [   240    342    214   3858      4  87023 196130   1345    142    566]
 [  1933   2699    690  15603    565   8170   1388 100844   9380   9026]
 [  2900   1962    482   9007    582    706     36   6562  45174  16128]
 [   640   3944   3481   9294    333    344     64  12073   9568 103896]]
Overall Accuracy
0.7621670197211479
