In [None]:
import os
import os.path as osp
import argparse
from datetime import date
import json
import random
import time
from pathlib import Path
import numpy as np
import numpy.linalg as LA
from tqdm import tqdm
import matplotlib as mpl
import matplotlib.pyplot as plt
import cv2
import csv
import warnings

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torch

import datasets
import util.misc as utils
from datasets import build_matterport_dataset,build_scannet_dataset,build_su3_dataset
from models import build_model
from models.matchers import build_matcher
from config import cfg

In [None]:
def AA(x, y, threshold):
    """Normalised area under a left-step curve, integrated from x[0] to `threshold`.

    Args:
        x: ascending 1-D sample positions (np.searchsorted requires sorted input).
        y: curve heights; y[k] is treated as constant on [x[k], x[k+1]).
        threshold: upper integration limit; also the normaliser.

    Returns:
        sum_k (x[k+1] - x[k]) * y[k] over the samples below `threshold`
        (the last interval is closed at `threshold`), divided by `threshold`.
        Returns 0 when no sample lies below `threshold`.

    Presumably used as an "AA@threshold" metric with x = sorted angular
    errors and y = cumulative fraction of samples (confirm with caller).
    """
    cut = np.searchsorted(x, threshold)
    # Interval edges: every sample below the threshold, closed by the threshold itself.
    edges = np.append(x[:cut], threshold)
    widths = np.diff(edges)
    heights = y[:cut]
    return (widths * heights).sum() / threshold

def get_src_permutation_idx(indices):
    """Flatten matcher output into (batch, prediction) index tensors.

    Args:
        indices: sequence of ``(src, tgt)`` index-tensor pairs, one pair per
            batch element, as returned by the matcher.

    Returns:
        ``(batch_idx, src_idx)``: two 1-D tensors of equal length, usable for
        advanced indexing ``tensor[batch_idx, src_idx]``. ``batch_idx[k]`` is
        the batch element that ``src_idx[k]`` belongs to.
    """
    batch_parts, src_parts = [], []
    for batch_no, (src, _) in enumerate(indices):
        batch_parts.append(torch.full_like(src, batch_no))
        src_parts.append(src)
    return torch.cat(batch_parts), torch.cat(src_parts)

def get_tgt_permutation_idx(indices):
    """Flatten matcher output into (batch, target) index tensors.

    Args:
        indices: sequence of ``(src, tgt)`` index-tensor pairs, one pair per
            batch element, as returned by the matcher.

    Returns:
        ``(batch_idx, tgt_idx)``: two 1-D tensors of equal length; ``batch_idx[k]``
        is the batch element that target index ``tgt_idx[k]`` belongs to.
    """
    tgt_tensors = [tgt for _, tgt in indices]
    batch_ids = [torch.full_like(tgt, batch_no) for batch_no, tgt in enumerate(tgt_tensors)]
    return torch.cat(batch_ids), torch.cat(tgt_tensors)
    
def to_device(data, device):
    """Move a batch to `device`.

    NOTE(review): `to_device` is defined again later in this cell, and that
    later definition shadows this one, so this copy is never called and can
    be deleted.

    If `data` is exactly a ``dict``, every value is moved with ``.to(device)``.
    Otherwise `data` is treated as an iterable of dicts: tensor values are
    moved and all other values are passed through unchanged. Always returns
    a new dict / list.
    """
    if type(data) is dict:
        return dict((key, val.to(device)) for key, val in data.items())
    moved = []
    for item in data:
        moved.append({key: (val.to(device) if isinstance(val, torch.Tensor) else val)
                      for key, val in item.items()})
    return moved

def compute_error(vps_pd, vps_gt):
    """Angular error (degrees) from each ground-truth VP to its closest prediction.

    VPs are compared as undirected directions (|cos| is used), so v and -v
    count as the same vanishing point.

    Args:
        vps_pd: predicted VP directions, shape (num_pd, d) or a single (d,) vector.
        vps_gt: ground-truth VP directions, shape (num_gt, d) or a single (d,) vector.
            NOTE(review): both are assumed to be unit-norm, since the dot product
            is used directly as a cosine. Confirm that callers normalise.

    Returns:
        1-D array of length num_gt: for each ground-truth VP, the smallest
        angle in degrees to any predicted VP.
    """
    # Accept a single VP vector for either argument; a 1-D input would
    # otherwise collapse the similarity matrix and break min(axis=1).
    vps_pd = np.atleast_2d(vps_pd)
    vps_gt = np.atleast_2d(vps_gt)
    # |cos| matrix of shape (num_gt, num_pd). The clip keeps arccos defined
    # when round-off pushes |cos| slightly above 1.
    error = np.arccos(np.abs(vps_gt @ vps_pd.transpose()).clip(max=1))
    # Best prediction (axis=1) for every ground-truth row, in degrees.
    error = error.min(axis=1) / np.pi * 180.0
    return error.flatten()

def to_device(data, device):
    """Move a batch to `device`.

    NOTE(review): this redefines `to_device` from earlier in this cell and
    shadows it; only one copy should be kept.

    Args:
        data: either a mapping of name -> value, or an iterable of such
            mappings (e.g. a per-sample list of target dicts).
        device: anything accepted by ``.to`` (``torch.device`` or a string).

    Returns:
        A new dict (mapping input) or list of dicts (iterable input).
        For a mapping, every value that has a ``.to`` method is moved and
        other values (e.g. strings, ids) are passed through. For an iterable
        of dicts, ``torch.Tensor`` values are moved and everything else is
        passed through.
    """
    if isinstance(data, dict):
        # Duck-typed so tensor-like containers with their own .to() still move,
        # while plain metadata no longer raises AttributeError.
        return {k: v.to(device) if callable(getattr(v, 'to', None)) else v
                for k, v in data.items()}
    return [{k: v.to(device) if isinstance(v, torch.Tensor) else v
             for k, v in t.items()} for t in data]

In [None]:
# Device for the model and every batch, taken from the config. If CUDA is
# requested on a machine without it, fall back to CPU (with a warning)
# instead of failing later in model.to(device).
device = torch.device(cfg.DEVICE)
if device.type == 'cuda' and not torch.cuda.is_available():
    warnings.warn(f'cfg.DEVICE={cfg.DEVICE!r} requested but CUDA is unavailable; '
                  'falling back to CPU.')
    device = torch.device('cpu')

In [None]:
# Trained checkpoint to evaluate. Override with the CTRLC_CHECKPOINT
# environment variable rather than editing the absolute path.
CHECKPOINT_PATH = Path(os.environ.get(
    'CTRLC_CHECKPOINT', '/home/kmuvcl/CTRL-C/su3log/checkpoint0078.pth'))

model, _ = build_model(cfg)
model.to(device)
# map_location='cpu' loads tensors on the CPU first; load_state_dict then
# copies them into the parameters already on `device`.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint['model'])
# Evaluation: switch layers such as dropout to inference behaviour.
model.eval()
matcher = build_matcher(cfg)

In [None]:
# SU3 validation split, iterated in a fixed order one image per batch.
# The evaluation loop indexes element [0] of each batch, so it relies on
# batch_size staying 1.
dataset_test = build_su3_dataset(image_set='val', cfg=cfg)
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
data_loader_test = DataLoader(
    dataset_test,
    batch_size=1,
    sampler=sampler_test,
    drop_last=False,
    collate_fn=utils.collate_fn,
    num_workers=2,
)

In [None]:
# Evaluate every validation image: run the model, match its three predicted
# VPs to the ground truth with the matcher, and record |cos| of matched pairs.
model.eval()  # keep this cell correct on its own: inference-mode layers
cos_sims = []
with torch.no_grad():  # inference only; do not build autograd graphs
    for samples, extra_samples, targets in tqdm(data_loader_test):
        pred_vp = {}
        target_vp = {}
        # Only `extra_samples` is fed to the model; `samples` is unused here,
        # so it is not copied to the device.
        extra_samples = to_device(extra_samples, device)
        outputs, extra_info = model(extra_samples)
        # batch_size is 1, so [0] selects the single sample.
        pred_vp1 = outputs['pred_vp1'].to('cpu')[0]
        pred_vp2 = outputs['pred_vp2'].to('cpu')[0]
        pred_vp3 = outputs['pred_vp3'].to('cpu')[0]
        target_vp1 = targets[0]['vp1']
        target_vp2 = targets[0]['vp2']
        target_vp3 = targets[0]['vp3']
        target_vp4 = targets[0]['vp']

        pred_vp['pred_vp1'] = outputs['pred_vp1'].to('cpu')
        pred_vp['pred_vp2'] = outputs['pred_vp2'].to('cpu')
        pred_vp['pred_vp3'] = outputs['pred_vp3'].to('cpu')

        target_vp['vp1'] = target_vp1
        target_vp['vp2'] = target_vp2
        target_vp['vp3'] = target_vp3
        target_vp['vp'] = target_vp4
        target_vp = (target_vp,)  # one target dict per batch element

        indices = matcher(pred_vp, target_vp)
        src_idx = get_src_permutation_idx(indices)
        tgt_idx = get_tgt_permutation_idx(indices)

        # Stack along a new leading axis so row k is predicted VP k, giving
        # (1, 3, d); the matcher's prediction ids then select whole VPs.
        pred_vpts = torch.stack([pred_vp1, pred_vp2, pred_vp3], dim=0).reshape(1, 3, -1)
        # Matched ground truth. A per-sample (num_gt, d) 'vp' tensor has no
        # batch axis, so it is indexed with the per-sample target ids only.
        # A tensor with a leading batch axis is indexed with (batch, target) ids.
        # NOTE(review): confirm the actual shape of targets[0]['vp'].
        if target_vp4.dim() == 2:
            matched_tgt = target_vp4[tgt_idx[1]]
        else:
            matched_tgt = target_vp4[tgt_idx]
        cos_sim = F.cosine_similarity(pred_vpts[src_idx], matched_tgt, dim=-1).abs()
        cos_sims.append(cos_sim)

# One summary instead of a print per sample.
all_cos_sim = torch.cat(cos_sims) if cos_sims else torch.empty(0)
print(f'matched VP pairs: {all_cos_sim.numel()}, mean |cos|: {all_cos_sim.mean().item():.4f}')
