In [None]:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset
import torch.nn.functional as F
from einops import rearrange

from models.multi_task_model import MultiTaskModel
from data.taskonomy_replica_gso_dataset import TaskonomyReplicaGsoDataset

In [None]:
# Pretrained checkpoints, keyed by the training configuration they come from.
# Pick one via CHECKPOINT_NAME instead of commenting paths in and out.
# The chosen checkpoint must match `tasks` and the model's `n_channels` below
# (the active settings feed rgb + normal + edge_occlusion = 7 input channels).
EXPERIMENTS_ROOT = '/scratch/ainaz/omnidata2/experiments'  # TODO: make configurable (e.g. env var)

CHECKPOINTS = {
    'replica baseline':
        EXPERIMENTS_ROOT + '/semseg/checkpoints/omnidata/3mevrwvo/epoch=51.ckpt',
    'replica semseg + normal (GT)':
        EXPERIMENTS_ROOT + '/semseg/checkpoints/omnidata/2vts59o9/epoch=52.ckpt',
    'replica semseg + normal + edge3d (GT)':
        EXPERIMENTS_ROOT + '/semseg/checkpoints/omnidata/1w8z04sn/epoch=52.ckpt',
    'replica semseg + normal + edge3d + edge2d + keypoints3d + depth (GT)':
        EXPERIMENTS_ROOT + '/semseg/checkpoints/omnidata/4sk98iky/epoch=51.ckpt',
    'replica semseg + normal (PADNET)':
        EXPERIMENTS_ROOT + '/multitask/checkpoints/omnidata/1ypt36cz/last.ckpt',
    'replica semseg + normal + edge3d (PADNET)':
        EXPERIMENTS_ROOT + '/multitask/checkpoints/omnidata/29sqmzk2/last.ckpt',
    'replica semseg + normal + edge3d + edge2d + keypoints3d + depth (PADNET)':
        EXPERIMENTS_ROOT + '/multitask/checkpoints/omnidata/2rg5cep8/last.ckpt',
}

CHECKPOINT_NAME = 'replica semseg + normal + edge3d (GT)'
pretrained_weights_path = CHECKPOINTS[CHECKPOINT_NAME]


In [None]:
# Evaluation settings.
# Domains loaded per test sample: rgb, normal and edge_occlusion form the model
# input, segment_semantic is the target, mask_valid marks the valid pixels.
tasks = [
    'rgb',
    'segment_semantic',
    'normal',
    'edge_occlusion',
    'mask_valid',
]
# Full-input variant (pair it with a checkpoint trained on all of these inputs):
# tasks = ['rgb', 'normal', 'segment_semantic', 'edge_occlusion', 'edge_texture', 'keypoints3d', 'depth_zbuffer', 'mask_valid']

taskonomy_variant = 'tiny'
image_size = 256        # resolution passed to the dataset loader
batch_size = 16
normalize_rgb = True

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Single-task segmentation model. Its 7 input channels must match the
# torch.cat([rgb, normal, edge_occlusion]) input built in the evaluation cell.
model = MultiTaskModel(tasks=['segment_semantic'], n_channels=7,
                       backbone='hrnet_w18', head='hrnet', pretrained=False, dilated=False)

# Map tensors onto the selected device. Hard-coding 'cuda:0' breaks on CPU-only hosts.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(pretrained_weights_path, map_location=device)
if 'state_dict' in checkpoint:
    # Presumably a training-framework checkpoint whose parameter names carry a
    # leading 'model.' prefix. Strip only that prefix, never inner occurrences.
    prefix = 'model.'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint['state_dict'].items()
    }
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model.to(device)

MultiTaskModel(
  (backbone): HighResolutionNet(
    (conv1): Conv2d(7, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU()
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, t

# Dataloaders

In [None]:
def make_test_options(dataset_name):
    '''
        Builds test-split dataset options for one dataset.

        Reads the notebook-wide `taskonomy_variant`, `tasks`, `image_size` and
        `normalize_rgb` from the settings cell at call time, so every test set
        shares them.

        Args:
            dataset_name: Dataset to load, e.g. 'taskonomy', 'replica' or 'hypersim'.
    '''
    return TaskonomyReplicaGsoDataset.Options(
        split='test',
        taskonomy_variant=taskonomy_variant,
        tasks=tasks,
        datasets=[dataset_name],
        transform='DEFAULT',
        image_size=image_size,
        normalize_rgb=normalize_rgb,
        randomize_views=False,
    )


opt_test_taskonomy = make_test_options('taskonomy')
testset_taskonomy = TaskonomyReplicaGsoDataset(options=opt_test_taskonomy)

# Additional test sets. Enable these together with the matching loaders below.
# opt_test_replica = make_test_options('replica')
# testset_replica = TaskonomyReplicaGsoDataset(options=opt_test_replica)
# opt_test_hypersim = make_test_options('hypersim')
# testset_hypersim = TaskonomyReplicaGsoDataset(options=opt_test_hypersim)


!!!!!!!!!!!!!!!!!!!!!!!!!!!!!  ./tmp/taskonomy_rgb-normal-segment_semantic-keypoints2d-keypoints3d-depth_zbuffer-edge_texture-edge_occlusion-mask_valid_tiny-test.pkl
!! here
Loaded taskonomy with 54514 images from tmp.
!!!!!!!!!!!! rgb :  54514
!!!!!!!!!!!! semantic segmentation :  54514


Loaded 53386 images in 1.06 seconds
	 (5 buildings) (8759 points) (53386 images) for domains ['rgb', 'segment_semantic', 'normal', 'edge_occlusion', 'mask_valid']


In [None]:
# 64 is the worker count this notebook originally hard-coded. Cap it at the
# host's CPU count so smaller machines don't spawn more workers than cores.
NUM_WORKERS = min(64, os.cpu_count() or 1)

# shuffle=False keeps the evaluation order deterministic.
test_dataloader_taskonomy = DataLoader(
    testset_taskonomy, batch_size=batch_size, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=False,
)
# test_dataloader_replica = DataLoader(
#     testset_replica, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS, pin_memory=False
# )
# test_dataloader_hypersim = DataLoader(
#     testset_hypersim, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=False
# )
# test_dataloader_combined = DataLoader(
#     ConcatDataset([testset_taskonomy, testset_replica, testset_hypersim]),
#     batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS, pin_memory=False
# )

In [None]:
# Number of test batches. drop_last is left at False, so this is
# ceil(len(testset_taskonomy) / batch_size).
len(test_dataloader_taskonomy)
# Other loaders: len(test_dataloader_replica), len(test_dataloader_hypersim), len(test_dataloader_combined)

3337


# Utils

In [None]:
def make_valid_mask(mask_float, max_pool_size=4, return_small_mask=False, output_size=None):
    '''
        Creates a boolean mask of the valid parts of the image(s).

        The masked (invalid) area is enlarged with a max pooling operation.
        The inverted mask is max-pooled with a `max_pool_size` window (stride
        equals the window), so a whole pooling cell becomes invalid if it
        touches any invalid pixel. Unless `return_small_mask` is set, the
        pooled mask is then upsampled with nearest-neighbour interpolation.

        Args:
            mask_float: Float mask as loaded from the Taskonomy loader, expected
                to be 1 on valid pixels. Shape (C, H, W), (B, C, H, W) or
                (B, P, C, H, W).
            max_pool_size: Parameter to choose how much to enlarge masked area.
            return_small_mask: Set to true to return the pooled mask for the
                aggregated (downsampled) image instead of upsampling it.
            output_size: (H, W) to upsample to. Defaults to the spatial size of
                `mask_float` itself.

        Returns:
            Boolean tensor, True where valid. A 3-D input comes back with a
            leading batch dimension of 1. A 5-D input keeps its (B, P) leading
            dimensions.
    '''
    if mask_float.dim() == 3:
        mask_float = mask_float.unsqueeze(dim=0)

    # Fold (B, P) into one batch axis for the 2-D ops and remember both sizes,
    # so the original P is restored afterwards (not a hard-coded 1).
    leading_shape = None
    if mask_float.dim() == 5:
        leading_shape = tuple(mask_float.shape[:2])
        mask_float = mask_float.reshape(-1, *mask_float.shape[2:])

    if output_size is None:
        output_size = tuple(mask_float.shape[-2:])

    invalid = F.max_pool2d(1 - mask_float, kernel_size=max_pool_size)
    if not return_small_mask:
        invalid = F.interpolate(invalid, size=output_size, mode='nearest')
    mask_valid = invalid == 0

    if leading_shape is not None:
        mask_valid = mask_valid.reshape(*leading_shape, *mask_valid.shape[1:])

    return mask_valid

In [None]:
# Evaluate on the Taskonomy test split. Collects the per-batch cross-entropy
# loss and a global pixel accuracy over valid, labelled pixels.
model.eval()
criterion = nn.CrossEntropyLoss(ignore_index=-1)  # label -1 = background/undefined/invalid (see below)

LOG_EVERY = 100  # print progress every N batches instead of once per batch

losses = []      # per-batch mean loss, as Python floats
accuracies = []  # only used by the commented-out per-batch accuracy cell

all_pixels = 0   # valid, labelled pixels scored so far
pos_pixels = 0   # of those, pixels whose predicted class matches the label

with torch.no_grad():
    n_batches = len(test_dataloader_taskonomy)
    for step, batch in enumerate(test_dataloader_taskonomy):
        if step % LOG_EVERY == 0:
            print(f'batch {step}/{n_batches}')
        sample = batch['positive']
        rgb = sample['rgb'].to(device)
        semantic = sample['segment_semantic'].to(device)
        normal = sample['normal'].to(device)
        edge_occlusion = sample['edge_occlusion'].to(device)
        # 6-input variant also loads depth_zbuffer, edge_texture and keypoints3d here.
        mesh_valid = make_valid_mask(sample['mask_valid']).squeeze(1).to(device)

        # Class ids are read from index 0 of the last axis. Copy them (and cast
        # to long, as CrossEntropyLoss targets require) so the in-place edits
        # below do not write back into `semantic`.
        labels_gt = semantic[:, :, :, 0].clone().long()

        # background and undefined classes are labeled as 0
        is_background = (
            (semantic[:, :, :, 0] == 255)
            & (semantic[:, :, :, 1] == 255)
            & (semantic[:, :, :, 2] == 255)
        )  # background in taskonomy
        labels_gt[is_background] = 0
        labels_gt[labels_gt == -1] = 0  # undefined class in hypersim

        # Invalid mesh regions also get label 0. Shifting by one then maps
        # 0 -> -1 (ignored by the loss), so the model never predicts those classes.
        labels_gt *= mesh_valid
        labels_gt -= 1

        # 6-input variant: torch.cat([rgb, normal, depth, edge_texture, edge_occlusion, keypoints3d], dim=1)
        combo_input = torch.cat([rgb, normal, edge_occlusion], dim=1)
        logits = model(combo_input)['segment_semantic']

        batch_loss = criterion(logits, labels_gt)

        # Pixel accuracy over labelled pixels (label != -1) only.
        labelled = labels_gt != -1
        if not labelled.any():
            continue  # nothing to score; the loss over zero targets is NaN
        # argmax of the logits equals argmax of their softmax. The +1 maps the
        # prediction back to the 1-based label space (background/undefined is not predicted).
        preds = logits.argmax(dim=1) + 1
        gt = labels_gt + 1
        pos_pixels += int((preds[labelled] == gt[labelled]).sum().item())
        all_pixels += int(labelled.sum().item())

        losses.append(batch_loss.item())

  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))


  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))


  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))
  img = torch.from_numpy(np.array(pic, np.int32, copy=False))


In [None]:
# Scored batches vs. all test batches. The evaluation loop skips batches with
# no labelled pixels, so the first number can be smaller than the second.
len(losses), len(test_dataloader_taskonomy)

3337

In [None]:
# torch.tensor(accuracies).mean()

In [None]:
# Mean of the per-batch mean losses. float() accepts Python floats and 0-d
# tensors on any device, so this works whichever form the loop stored.
torch.tensor([float(loss) for loss in losses]).mean()

tensor(6.1244)

In [None]:
# Global pixel accuracy over all scored (valid, labelled) pixels. NaN if no
# pixel was scored, instead of raising ZeroDivisionError.
float(pos_pixels) / all_pixels if all_pixels else float('nan')

tensor(0.1290, device='cuda:0')