In [1]:
import os
import sys
import pdb
import rospy
import numpy as np
import torch
from visualization_msgs.msg import *

# Adds parent dir to path
current_dir = os.getcwd()
base_dir = os.path.dirname(os.path.dirname(current_dir))
sys.path.insert(0, base_dir)

from Models.DiscreteBKI import *
from Data.dataset import Rellis3dDataset, ray_trace_batch
from Data.utils import *
from model_utils import *
from sklearn.metrics import jaccard_score

NUM_CLASSES=21
DATASET_DIR = "/home/arthurzhang/Data/Rellis-3D"
sbki_folder = "/home/arthurzhang/CURLY/Baselines/catkin_ws/src/BKISemanticMapping/data/rellis3d/test"

gt_all = np.array([])
pred_all = np.array([])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

rospy.init_node('talker', anonymous=True)
map_pub = rospy.Publisher('SemMap', MarkerArray, queue_size=10)
bki_map = DiscreteBKI(
    torch.tensor([256, 256, 16]).to(device), # Grid size
    torch.tensor([-25.6, -25.6, -2.0]).to(device), # Lower bound
    torch.tensor([25.6, 25.6, 1.2]).to(device), # Upper bound
    device=device
)

rellis_ds = Rellis3dDataset(directory=DATASET_DIR, device=device, 
    num_frames=10, remap=True, use_aug=False, use_gt=True, model_setting="test")

num_correct = 0
num_total = 0
# running_inter = np.zeros((NUM_CLASSES,))
# running_union =  np.zeros((NUM_CLASSES,))
running_iou = np.zeros((NUM_CLASSES,))
running_counter = 0
for idx in range(len(rellis_ds)):
    sem_evaluation_dir = os.path.join(sbki_folder, "evaluations", str(idx).zfill(6)+".txt")
    ssc_evaluation_dir = os.path.join(sbki_folder, "ssc", str(idx).zfill(6)+".txt")
    
    sem_evaluation = np.loadtxt(sem_evaluation_dir, dtype=np.uint8).reshape(-1)
    ssc_evaluation = np.loadtxt(ssc_evaluation_dir, dtype=np.uint8)

    # Load rellis3D dataset into np arrays
    gt_pc, gt_labels, gt_voxels, _, _ = rellis_ds[idx]
    gt_pc       = np.vstack(np.array(gt_pc))
    gt_labels   = np.vstack(np.array(gt_labels))
    gt_voxels_np = np.array(gt_voxels).astype(np.uint8)
    gt_voxels_np = gt_voxels_np.reshape(-1)

    # Process free space labels
    fs_pc           = ray_trace_batch(gt_pc, gt_labels, 0.3, device)
    gt_pc_np        = np.vstack( (gt_pc, fs_pc[:, :3].reshape(-1, 3)))
    gt_labels_np    = np.vstack( (gt_labels, fs_pc[:, 3].reshape(-1, 1))).reshape(-1)


    # Ignore void labels from ground truth
    void_mask = gt_labels_np!=0
    gt_pc_np = gt_pc_np[void_mask]
    gt_labels_np = gt_labels_np[void_mask]

    if gt_pc_np.shape[0] <= 0:
        # Zero pad in case all labels are 0
        gt_pc_np = np.zeros((1, 3))
        gt_labels_np = np.zeros((1,))

    iou_score       = jaccard_score(gt_labels_np, sem_evaluation, 
        labels=np.arange(0, NUM_CLASSES), average=None, zero_division=0)
    running_iou     += iou_score
    running_counter += 1
    num_correct += np.sum(sem_evaluation==gt_labels_np)
    num_total   += gt_labels_np.shape[0]

    print("IoU ", running_iou / running_counter)
    print("mIoU ", np.mean(running_iou / running_counter))
    print("Accuracy ", num_correct / num_total)

    # Visualization for Debugging
    # gt_pc_torch = torch.from_numpy(gt_pc_np)
    # preds = torch.from_numpy(sem_evaluation.astype(np.uint8)).reshape(-1)
    # publish_pc(gt_pc_torch, preds, map_pub, 
    #         bki_map.min_bound.reshape(-1),
    #         bki_map.max_bound.reshape(-1),
    #         bki_map.grid_size.reshape(-1))
    # pdb.set_trace()
    # gt_labels_torch = torch.from_numpy(gt_labels_np.astype(np.uint8)).reshape(-1)
    # publish_pc(gt_pc_torch, gt_labels_torch, map_pub, 
    #         bki_map.min_bound.reshape(-1),
    #         bki_map.max_bound.reshape(-1),
    #         bki_map.grid_size.reshape(-1))
    # pdb.set_trace()

print("Final IoU ", running_iou / running_counter)
print("Final mIoU ", np.mean(running_iou / running_counter))
print("Final Accuracy ", num_correct / num_total)


IoU  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
mIoU  0.0
Accuracy  0.0
IoU  [0.         0.         0.337594   0.3445828  0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.42799462 0.25754599 0.37765801 0.         0.37133353
 0.28409091 0.         0.48940577]
mIoU  0.13762883975404638
Accuracy  0.9792257479243497
IoU  [0.         0.         0.44748453 0.44824639 0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.54817945 0.32126368 0.50139922 0.         0.49078837
 0.38568041 0.         0.65227115]
mIoU  0.1807292005852089
Accuracy  0.9785702733907865
IoU  [0.         0.         0.50094427 0.49697117 0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.60290379 0.35124021 0.56206753 0.         0.55052237
 0.43372398 0.         0.73358898]
mIoU  0.20152201368572684
Accuracy  0.9781909260151003
IoU  [0.         0.         0.53241009 0.52385167 0.         0.
 0.