In [16]:
import os
from config import get_opts
from models.ngp_wrapper import NGP_Prop_Art_Wrapper, NGP_Prop_Wrapper, NGP_Prop_Art_Seg_Wrapper
import torch
from torch.utils.data import DataLoader
from dataset.sapien import SapienParisDataset
from tqdm import tqdm
import open3d as o3d
import sys
import traceback
from pose_estimation import PoseEstimator_multipart
from test_ngp import NGPevaluator
from dataset.pose_utils import quaternion_to_axis_angle, get_rotation_axis_angle
from dataset.io_utils import load_gt_from_json, load_multipart_gt
from models.utils import axis_metrics, geodesic_distance, translational_error
import numpy as np
import torch.nn.functional as F
import shutil
from pathlib import Path as P
import torchvision.transforms.functional as tvF
import matplotlib.pyplot as plt

# stapler_path = P("./results_stable/stapler_end_to_start_f16/1713970094")
# fridge_path = P("./results_stable/fridge_end_to_start_f16/1713977542")
# laptop_path = P("./results_stable/laptop_start_to_end_f16/1713997073")
# oven_path = P("./results_stable/oven_start_to_end_f16/1713998823")
# scissor_path = P("./results_stable/scissor_end_to_start_f16/1714000996")

# # prismatic
# storage_path = P("./results_stable/storage_end_to_start_f16/1714002072")
# blade_path = P("./results_stable/blade_start_to_end_f16/1714009298")



stapler_path = P("./results_ablation/stapler_end_to_start_f100/")
fridge_path = P("./results_ablation/fridge_end_to_start_f100/")
laptop_path = P("./results_ablation/laptop_start_to_end_f100/")
oven_path = P("./results_ablation/oven_start_to_end_f100/")
scissor_path = P("./results_ablation/scissor_end_to_start_f100/")

# prismatic
storage_path = P("./results_ablation/storage_end_to_start_f100/")
blade_path = P("./results_ablation/blade_start_to_end_f100/")

# path_list = [stapler_path, fridge_path, laptop_path, oven_path, scissor_path, storage_path, blade_path]
# path_list = [oven_path, scissor_path, storage_path]
path_list = [scissor_path]
# box
box_path = P("./results_stable/box_start_to_end_f16/1714063616")
glasses_path = P("./results_stable/glasses_end_to_start_f16/1714068462")
oven_mp_path = P("./results_ablation/oven_mp_start_to_end_f100/1715607722")

1713527878

1713527878

In [15]:

ckpt_path = stapler_path
for obj_path in path_list:
    mious = []
    ckpt_paths = obj_path.glob('*')
    
    for ckpt_path in ckpt_paths:
        
        print(f'evaluating at path: {str(ckpt_path)}')
        config_file = ckpt_path / 'eval' / 'config.json'
        if not config_file.exists():
            continue
        config_argv = ['--config', str(config_file)]
        opts = get_opts(config_argv)
        if torch.cuda.is_available():
            device = 'cuda'
        else:
            device = 'cpu'

        setattr(opts, 'device', device)

        color_map = plt.get_cmap('Set3', 5)
        cmap = color_map.colors[:, :3]
        cmap[0, :] = 0

        opts.pre_trained_weights = str(ckpt_path / 'ckpt' / 'best_ckpt.pth')
        if opts.state == 'start':
            opts.state = 'end'
        else:
            opts.state = 'start'
        model = NGP_Prop_Art_Seg_Wrapper(config=opts, training=False, 
                                        ignore_empty=False, use_timestamp=True, use_se3=opts.use_se3, mkdir=False)
        test_dir = model.output_path / 'test'
        test_dir.mkdir(exist_ok=True)
        test_dataset = SapienParisDataset(
                root_dir = opts.root_dir,
                near = opts.near_plane,
                far = opts.far_plane,
                img_wh = opts.img_wh, 
                batch_size=opts.batch_size,
                split='test',
                render_bkgd='white',
                state=opts.state
            )

        ious = []
        for p in model.pose_module_list:
            p.init_param()
        for idx in tqdm(range(test_dataset.poses.shape[0])):
            batch_data = test_dataset.__getitem__(idx)
            render_batch = model.test(batch_data)
            seg_pred = render_batch['seg_label']
            seg_bg = torch.ones_like(seg_pred[:, :, 0:1])
            seg_bg = seg_bg - seg_pred.sum(dim=-1, keepdim=True)
            seg_pred_bg = torch.cat([seg_bg, seg_pred], dim=-1)
            seg_classes = torch.argmax(seg_pred_bg, dim=-1, keepdim=True)
            seg_classes_np = seg_classes.cpu().numpy()
            test_img_name = str(test_dir / (str(idx).zfill(4) + '.png'))
            
            # np.save(test_seg_name, seg_classes_np)
            gt_img = str(test_dir / (str(idx).zfill(4) + '_gt.png'))
            cur_gt = test_dataset.seg[idx] * 255
            
            if 'stapler' in str(ckpt_path):
                cur_gt = cur_gt - 1
                cur_gt[cur_gt == 3] = 2
            if 'oven' in str(ckpt_path):
                cur_gt[(cur_gt < 9) & (cur_gt > 1)] = 1
                cur_gt[cur_gt == 9] = 2
            if P(test_dataset.root_dir).parent.name == 'storage':
                cur_gt = cur_gt - 1
                cur_gt[cur_gt == 2] = 1
                cur_gt[cur_gt == 3] = 1
                cur_gt[cur_gt == 4] = 2
                cur_gt[cur_gt == 5] = 1
            cur_gt[cur_gt < 0] = 0
            seg_classes[cur_gt == 0] = 0
            valid_pred = seg_classes[cur_gt != 0].view(-1)
            valid_gt = cur_gt[cur_gt != 0].to(valid_pred).view(-1)
            tp = (valid_pred == valid_gt).sum()
            valid_sum = valid_gt.shape[0]
            iou = tp/valid_sum
            ious += [iou]
            pred_seg_img = seg_classes.view(800, 800).cpu().numpy()
            gt_seg_img = cur_gt.view(800, 800).cpu().numpy()
            plt.imsave(test_img_name, cmap[pred_seg_img.astype(np.int16)])
            print(test_img_name)
            plt.imsave(gt_img, cmap[gt_seg_img.astype(np.int16)])
        break
    #     ious = torch.stack(ious)
    #     # print(f'mIoU for ckpt: {ckpt_path}')
    #     std, mean = torch.std_mean(ious, dim=0)
    #     # print(f'mean: {mean.item()}, std: {std.item()}')
    #     mious += [mean]
    # mious = torch.stack(mious)
    # stds, means = torch.std_mean(mious)
    # print(f'evaluating at obj: {str(obj_path)}')
    # print(f'mean: {means.item()}, std: {stds.item()}')

evaluating at path: results_ablation/scissor_end_to_start_f100/1714839015


  2%|▏         | 1/50 [00:00<00:38,  1.27it/s]

results_ablation/scissor_end_to_start_f100/1714839015/test/0000.png


  4%|▍         | 2/50 [00:01<00:36,  1.33it/s]

results_ablation/scissor_end_to_start_f100/1714839015/test/0001.png


  6%|▌         | 3/50 [00:02<00:34,  1.35it/s]

results_ablation/scissor_end_to_start_f100/1714839015/test/0002.png


  8%|▊         | 4/50 [00:02<00:33,  1.36it/s]

results_ablation/scissor_end_to_start_f100/1714839015/test/0003.png


 10%|█         | 5/50 [00:03<00:32,  1.37it/s]

results_ablation/scissor_end_to_start_f100/1714839015/test/0004.png


 12%|█▏        | 6/50 [00:04<00:32,  1.37it/s]

results_ablation/scissor_end_to_start_f100/1714839015/test/0005.png


 14%|█▍        | 7/50 [00:05<00:32,  1.33it/s]

results_ablation/scissor_end_to_start_f100/1714839015/test/0006.png





KeyboardInterrupt: 

In [47]:
std, mean = torch.std_mean(ious, dim=0)
print(f'mean: {mean}, std: {std}')

mean: 0.9906512498855591, std: 0.007663434371352196


In [10]:

ckpt_path = box_path

config_file = ckpt_path / 'eval' / 'config.json'
config_argv = ['--config', str(config_file)]
opts = get_opts(config_argv)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

setattr(opts, 'device', device)

color_map = plt.get_cmap('Set3', 5)
cmap = color_map.colors[:, :3]
cmap[0, :] = 0

opts.pre_trained_weights = str(ckpt_path / 'ckpt' / 'best_ckpt.pth')
if opts.state == 'start':
    opts.state = 'end'
else:
    opts.state = 'start'
model = NGP_Prop_Art_Seg_Wrapper(config=opts, training=False, 
                                ignore_empty=False, use_timestamp=True, use_se3=opts.use_se3, mkdir=False)
test_dir = model.output_path / 'test'
test_dir.mkdir(exist_ok=True)
test_dataset = SapienParisDataset(
        root_dir = opts.root_dir,
        near = opts.near_plane,
        far = opts.far_plane,
        img_wh = opts.img_wh, 
        batch_size=opts.batch_size,
        split='test',
        render_bkgd='white',
        state=opts.state
    )

ious = []
for p in model.pose_module_list:
    p.init_param()
for idx in tqdm(range(test_dataset.poses.shape[0])):
    batch_data = test_dataset.__getitem__(idx)
    render_batch = model.test(batch_data)
    seg_pred = render_batch['seg_label']
    seg_bg = torch.ones_like(seg_pred[:, :, 0:1])
    seg_bg = seg_bg - seg_pred.sum(dim=-1, keepdim=True)
    seg_pred_bg = torch.cat([seg_bg, seg_pred], dim=-1)
    seg_classes = torch.argmax(seg_pred_bg, dim=-1, keepdim=True)
    seg_classes_np = seg_classes.cpu().numpy()
    test_img_name = str(test_dir / (str(idx).zfill(4) + '.png'))
    
    # np.save(test_seg_name, seg_classes_np)
    gt_img = str(test_dir / (str(idx).zfill(4) + '_gt.png'))
    cur_gt = test_dataset.seg[idx] * 255
    
    if 'stapler' in str(ckpt_path):
        cur_gt = cur_gt - 1
        cur_gt[cur_gt == 3] = 2
    if 'oven' in str(ckpt_path):
        cur_gt[(cur_gt < 9) & (cur_gt > 1)] = 1
        cur_gt[cur_gt == 9] = 2
    if P(test_dataset.root_dir).parent.name == 'storage':
        cur_gt = cur_gt - 1
        cur_gt[cur_gt == 2] = 1
        cur_gt[cur_gt == 3] = 1
        cur_gt[cur_gt == 4] = 2
        cur_gt[cur_gt == 5] = 1
    if P(test_dataset.root_dir).parent.name == 'box':
        # cur_gt = cur_gt - 1
        cur_gt[cur_gt==2] = 1
        # cur_gt[cur_gt == 3] = 2
        cur_gt[cur_gt == 4] = 2
    cur_gt[cur_gt < 0] = 0
    seg_classes[cur_gt == 0] = 0
    valid_pred = seg_classes[cur_gt != 0].view(-1)
    valid_gt = cur_gt[cur_gt != 0].to(valid_pred).view(-1)
    tp = (valid_pred == valid_gt).sum()
    valid_sum = valid_gt.shape[0]
    iou = tp/valid_sum
    ious += [iou]
    pred_seg_img = seg_classes.view(800, 800).cpu().numpy()
    gt_seg_img = cur_gt.view(800, 800).cpu().numpy()
    plt.imsave(test_img_name, cmap[pred_seg_img.astype(np.int16)])
    plt.imsave(gt_img, cmap[gt_seg_img.astype(np.int16)])
ious = torch.stack(ious)

std, mean = torch.std_mean(ious, dim=0)
print(f'mean: {mean.item()}, std: {std.item()}')

100%|██████████| 50/50 [00:42<00:00,  1.17it/s]

mean: 0.9889721870422363, std: 0.005943655967712402





In [26]:
# save segmentation results for stapler

    # break

### glasses

In [17]:
ckpt_path = glasses_path

config_file = ckpt_path / 'eval' / 'config.json'
config_argv = ['--config', str(config_file)]
opts = get_opts(config_argv)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

setattr(opts, 'device', device)

color_map = plt.get_cmap('Set3', 5)
cmap = color_map.colors[:, :3]
cmap[0, :] = 0

opts.pre_trained_weights = str(ckpt_path / 'ckpt' / 'best_ckpt.pth')
if opts.state == 'start':
    opts.state = 'end'
else:
    opts.state = 'start'
model = NGP_Prop_Art_Seg_Wrapper(config=opts, training=False, 
                                ignore_empty=False, use_timestamp=True, use_se3=opts.use_se3, mkdir=False)
test_dir = model.output_path / 'test'
test_dir.mkdir(exist_ok=True)
test_dataset = SapienParisDataset(
        root_dir = opts.root_dir,
        near = opts.near_plane,
        far = opts.far_plane,
        img_wh = opts.img_wh, 
        batch_size=opts.batch_size,
        split='test',
        render_bkgd='white',
        state=opts.state
    )

ious = []
for p in model.pose_module_list:
    p.init_param()
for idx in tqdm(range(test_dataset.poses.shape[0])):
    batch_data = test_dataset.__getitem__(idx)
    render_batch = model.test(batch_data)
    seg_pred = render_batch['seg_label']
    seg_bg = torch.ones_like(seg_pred[:, :, 0:1])
    seg_bg = seg_bg - seg_pred.sum(dim=-1, keepdim=True)
    seg_pred_bg = torch.cat([seg_bg, seg_pred], dim=-1)
    seg_classes = torch.argmax(seg_pred_bg, dim=-1, keepdim=True)
    seg_classes_np = seg_classes.cpu().numpy()
    test_img_name = str(test_dir / (str(idx).zfill(4) + '.png'))
    
    # np.save(test_seg_name, seg_classes_np)
    gt_img = str(test_dir / (str(idx).zfill(4) + '_gt.png'))
    cur_gt = test_dataset.seg[idx] * 255
    
    if 'stapler' in str(ckpt_path):
        cur_gt = cur_gt - 1
        cur_gt[cur_gt == 3] = 2
    if 'oven' in str(ckpt_path):
        cur_gt[(cur_gt < 9) & (cur_gt > 1)] = 1
        cur_gt[cur_gt == 9] = 2
    if P(test_dataset.root_dir).parent.name == 'storage':
        cur_gt = cur_gt - 1
        cur_gt[cur_gt == 2] = 1
        cur_gt[cur_gt == 3] = 1
        cur_gt[cur_gt == 4] = 2
        cur_gt[cur_gt == 5] = 1
    if P(test_dataset.root_dir).parent.name == 'box':
        # cur_gt = cur_gt - 1
        cur_gt[cur_gt==2] = 1
        # cur_gt[cur_gt == 3] = 2
        cur_gt[cur_gt == 4] = 2
    if P(test_dataset.root_dir).parent.name == 'glasses':
        cur_gt[cur_gt==2] = 1
        cur_gt[cur_gt==3] = 2
        cur_gt[cur_gt==4] = 3
        
    cur_gt[cur_gt < 0] = 0
    seg_classes[cur_gt == 0] = 0
    valid_pred = seg_classes[cur_gt != 0].view(-1)
    valid_gt = cur_gt[cur_gt != 0].to(valid_pred).view(-1)
    tp = (valid_pred == valid_gt).sum()
    valid_sum = valid_gt.shape[0]
    iou = tp/valid_sum
    ious += [iou]
    pred_seg_img = seg_classes.view(800, 800).cpu().numpy()
    gt_seg_img = cur_gt.view(800, 800).cpu().numpy()
    plt.imsave(test_img_name, cmap[pred_seg_img.astype(np.int16)])
    plt.imsave(gt_img, cmap[gt_seg_img.astype(np.int16)])
ious = torch.stack(ious)

std, mean = torch.std_mean(ious, dim=0)
print(f'mean: {mean.item()}, std: {std.item()}')

100%|██████████| 50/50 [00:42<00:00,  1.17it/s]

mean: 0.9844302535057068, std: 0.002639641985297203





### oven_mp

In [17]:
ckpt_path = oven_mp_path

config_file = ckpt_path / 'eval' / 'config.json'
config_argv = ['--config', str(config_file)]
opts = get_opts(config_argv)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

setattr(opts, 'device', device)

color_map = plt.get_cmap('Set3', 5)
cmap = color_map.colors[:, :3]
cmap[0, :] = 0

opts.pre_trained_weights = str(ckpt_path / 'ckpt' / 'best_ckpt.pth')
if opts.state == 'start':
    opts.state = 'end'
else:
    opts.state = 'start'
model = NGP_Prop_Art_Seg_Wrapper(config=opts, training=False, 
                                ignore_empty=False, use_timestamp=True, use_se3=opts.use_se3, mkdir=False)
test_dir = model.output_path / 'test'
test_dir.mkdir(exist_ok=True)
test_dataset = SapienParisDataset(
        root_dir = opts.root_dir,
        near = opts.near_plane,
        far = opts.far_plane,
        img_wh = opts.img_wh, 
        batch_size=opts.batch_size,
        split='test',
        render_bkgd='white',
        state=opts.state
    )

ious = []
for p in model.pose_module_list:
    p.init_param()
for idx in tqdm(range(test_dataset.poses.shape[0])):
    batch_data = test_dataset.__getitem__(idx)
    render_batch = model.test(batch_data)
    seg_pred = render_batch['seg_label']
    seg_bg = torch.ones_like(seg_pred[:, :, 0:1])
    seg_bg = seg_bg - seg_pred.sum(dim=-1, keepdim=True)
    seg_pred_bg = torch.cat([seg_bg, seg_pred], dim=-1)
    seg_classes = torch.argmax(seg_pred_bg, dim=-1, keepdim=True)
    seg_classes_np = seg_classes.cpu().numpy()
    test_img_name = str(test_dir / (str(idx).zfill(4) + '.png'))
    
    # np.save(test_seg_name, seg_classes_np)
    gt_img = str(test_dir / (str(idx).zfill(4) + '_gt.png'))
    cur_gt = test_dataset.seg[idx] * 255
    
    if 'stapler' in str(ckpt_path):
        cur_gt = cur_gt - 1
        cur_gt[cur_gt == 3] = 2
    # if 'oven' in str(ckpt_path):
    if P(test_dataset.root_dir).parent.name == 'oven':
        cur_gt[(cur_gt < 9) & (cur_gt > 1)] = 1
        cur_gt[cur_gt == 9] = 2
    if P(test_dataset.root_dir).parent.name == 'storage':
        cur_gt = cur_gt - 1
        cur_gt[cur_gt == 2] = 1
        cur_gt[cur_gt == 3] = 1
        cur_gt[cur_gt == 4] = 2
        cur_gt[cur_gt == 5] = 1
    if P(test_dataset.root_dir).parent.name == 'box':
        # cur_gt = cur_gt - 1
        cur_gt[cur_gt==2] = 1
        # cur_gt[cur_gt == 3] = 2
        cur_gt[cur_gt == 4] = 2
    if P(test_dataset.root_dir).parent.name == 'glasses':
        cur_gt[cur_gt==2] = 1
        cur_gt[cur_gt==3] = 2
        cur_gt[cur_gt==4] = 3
    if P(test_dataset.root_dir).parent.name == 'oven_mp':
        cur_gt[cur_gt==2] = 1
        # cur_gt[cur_gt==3] = 2
        cur_gt[cur_gt==4] = 2
        cur_gt[cur_gt>3] = 1
    cur_gt[cur_gt < 0] = 0
    seg_classes[cur_gt == 0] = 0
    valid_pred = seg_classes[cur_gt != 0].view(-1)
    valid_gt = cur_gt[cur_gt != 0].to(valid_pred).view(-1)
    tp = (valid_pred == valid_gt).sum()
    valid_sum = valid_gt.shape[0]
    iou = tp/valid_sum
    ious += [iou]
    pred_seg_img = seg_classes.view(800, 800).cpu().numpy()
    gt_seg_img = cur_gt.view(800, 800).cpu().numpy()
    plt.imsave(test_img_name, cmap[pred_seg_img.astype(np.int16)])
    plt.imsave(gt_img, cmap[gt_seg_img.astype(np.int16)])
ious = torch.stack(ious)

std, mean = torch.std_mean(ious, dim=0)
print(f'mean: {mean.item()}, std: {std.item()}')

100%|██████████| 50/50 [00:53<00:00,  1.07s/it]

mean: 0.9869259595870972, std: 0.007746274583041668





In [21]:
cur_gt.unique()

tensor([0., 2., 3., 4.])

In [36]:
ckpt_folder = ckpt_path / 'ckpt'

In [38]:
list(ckpt_folder.glob('*'))

[PosixPath('/home/dj/Downloads/project/nerfacc_ngp/results_stable/stapler_end_to_start_f16/1713970094/ckpt/002000.pth'),
 PosixPath('/home/dj/Downloads/project/nerfacc_ngp/results_stable/stapler_end_to_start_f16/1713970094/ckpt/004000.pth'),
 PosixPath('/home/dj/Downloads/project/nerfacc_ngp/results_stable/stapler_end_to_start_f16/1713970094/ckpt/006000.pth'),
 PosixPath('/home/dj/Downloads/project/nerfacc_ngp/results_stable/stapler_end_to_start_f16/1713970094/ckpt/008000.pth'),
 PosixPath('/home/dj/Downloads/project/nerfacc_ngp/results_stable/stapler_end_to_start_f16/1713970094/ckpt/best_ckpt.pth'),
 PosixPath('/home/dj/Downloads/project/nerfacc_ngp/results_stable/stapler_end_to_start_f16/1713970094/ckpt/010000.pth')]

### storage MP

In [30]:
storage_mp_path = P("/home/dj/Downloads/project/nerfacc_ngp/results_stable/storage_mp_start_to_end_f16/1713527878")

ckpt_path = storage_mp_path
config_file = ckpt_path / 'eval' / 'config.json'
config_argv = ['--config', str(config_file)]
opts = get_opts(config_argv)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

setattr(opts, 'device', device)

color_map = plt.get_cmap('Set3', 5)
cmap = color_map.colors[:, :3]
cmap[0, :] = 0

opts.pre_trained_weights = str(ckpt_path / 'ckpt' / 'best_ckpt.pth')
if opts.state == 'start':
    opts.state = 'end'
else:
    opts.state = 'start'
model = NGP_Prop_Art_Seg_Wrapper(config=opts, training=False, 
                                ignore_empty=False, use_timestamp=True, use_se3=opts.use_se3, mkdir=False)
test_dir = model.output_path / 'test'
test_dir.mkdir(exist_ok=True)
test_dataset = SapienParisDataset(
        root_dir = opts.root_dir,
        near = opts.near_plane,
        far = opts.far_plane,
        img_wh = opts.img_wh, 
        batch_size=opts.batch_size,
        split='test',
        render_bkgd='white',
        state=opts.state
    )

ious = []
for p in model.pose_module_list:
    p.init_param()
for idx in tqdm(range(test_dataset.poses.shape[0])):
    batch_data = test_dataset.__getitem__(idx)
    render_batch = model.test(batch_data)
    seg_pred = render_batch['seg_label']
    seg_bg = torch.ones_like(seg_pred[:, :, 0:1])
    seg_bg = seg_bg - seg_pred.sum(dim=-1, keepdim=True)
    seg_pred_bg = torch.cat([seg_bg, seg_pred], dim=-1)
    seg_classes = torch.argmax(seg_pred_bg, dim=-1, keepdim=True)
    seg_classes_np = seg_classes.cpu().numpy()
    test_img_name = str(test_dir / (str(idx).zfill(4) + '.png'))
    
    # np.save(test_seg_name, seg_classes_np)
    gt_img = str(test_dir / (str(idx).zfill(4) + '_gt.png'))
    cur_gt = test_dataset.seg[idx] * 255
    
    if P(test_dataset.root_dir).parent.name ==  'stapler':
        cur_gt = cur_gt - 1
        cur_gt[cur_gt == 3] = 2
    # if 'oven' in str(ckpt_path):
    if P(test_dataset.root_dir).parent.name == 'oven':
        cur_gt[(cur_gt < 9) & (cur_gt > 1)] = 1
        cur_gt[cur_gt == 9] = 2
    if P(test_dataset.root_dir).parent.name == 'storage':
        cur_gt = cur_gt - 1
        cur_gt[cur_gt == 2] = 1
        cur_gt[cur_gt == 3] = 1
        cur_gt[cur_gt == 4] = 2
        cur_gt[cur_gt == 5] = 1
    if P(test_dataset.root_dir).parent.name == 'box':
        # cur_gt = cur_gt - 1
        cur_gt[cur_gt==2] = 1
        # cur_gt[cur_gt == 3] = 2
        cur_gt[cur_gt == 4] = 2
    if P(test_dataset.root_dir).parent.name == 'glasses':
        cur_gt[cur_gt==2] = 1
        cur_gt[cur_gt==3] = 2
        cur_gt[cur_gt==4] = 3
    if P(test_dataset.root_dir).parent.name == 'oven_mp':
        cur_gt[cur_gt==2] = 1
        # cur_gt[cur_gt==3] = 2
        cur_gt[cur_gt==4] = 2
        cur_gt[cur_gt>3] = 1
    if P(test_dataset.root_dir).parent.name == 'storage_mp':
        cur_gt[cur_gt==2] = 1
        cur_gt[cur_gt==3] = 2
        cur_gt[cur_gt==4] = 3
        cur_gt[cur_gt>3] = 1
    cur_gt[cur_gt < 0] = 0
    seg_classes[cur_gt == 0] = 0
    valid_pred = seg_classes[cur_gt != 0].view(-1)
    valid_gt = cur_gt[cur_gt != 0].to(valid_pred).view(-1)
    tp = (valid_pred == valid_gt).sum()
    valid_sum = valid_gt.shape[0]
    iou = tp/valid_sum
    ious += [iou]
    pred_seg_img = seg_classes.view(800, 800).cpu().numpy()
    gt_seg_img = cur_gt.view(800, 800).cpu().numpy()
    plt.imsave(test_img_name, cmap[pred_seg_img.astype(np.int16)])
    plt.imsave(gt_img, cmap[gt_seg_img.astype(np.int16)])
ious = torch.stack(ious)

std, mean = torch.std_mean(ious, dim=0)
print(f'mean: {mean.item()}, std: {std.item()}')

100%|██████████| 50/50 [00:42<00:00,  1.17it/s]

mean: 0.9425734877586365, std: 0.03102875128388405





In [28]:
cur_gt.unique()

tensor([0., 2., 3., 4., 5., 6.])