In [16]:
import sys
sys.path.insert(1, '/home/yannik/vssil')

import torch
from torch.nn.functional import interpolate

from src.utils.kpt_utils import get_image_patches, kpts_2_img_coordinates
from src.losses.kpt_tracking_metric import kpt_tracking_metric
from src.losses.kpt_visual_metric import kpt_visual_metric
from src.losses.spatial_consistency_loss import spatial_consistency_loss
from contrastive_loss_test_data.test_keypoints import get_perfect_keypoints, get_bad_keypoints, get_random_keypoints
from contrastive_loss_test_data.test_data import load_sample_images 


# --- Experiment configuration -------------------------------------------
# Sizes of the test data: `sample_size` is handed to the sample-image loader
# and (as `T`) to the key-point generators; `batch_size` is the repeat count
# along the leading axis that is added to each loaded tensor.
sample_size, batch_size = 4, 16

# Settings forwarded unchanged to both key-point metrics.
patch_size = (3, 3)
n_bins = 200
p = float("inf")  # presumably the order of the norm used by the metrics - confirm

In [17]:
# Seed torch's global random number generator.
torch.manual_seed(123)

# Load the example frames, prepend a new leading axis, and tile the clip
# batch_size times along it (all other axes are repeated once, i.e. kept).
img_tensor = (
    load_sample_images(
        path="/home/yannik/vssil/contrastive_loss_test_data/990000.mp4",
        sample_size=sample_size,
    )
    .unsqueeze(0)
    .repeat((batch_size, 1, 1, 1, 1))
)

# For a 5-D input, interpolate() resizes the trailing three axes, so
# size=(3, 64, 64) targets the axes unpacked below as (C, H, W).
# NOTE(review): assumes the loaded frames already have 3 channels, so that
# only H and W actually change - confirm.
img_tensor = interpolate(img_tensor, size=(3, 64, 64))
N, T, C, H, W = img_tensor.shape

In [18]:
# Build the three example key-point sets (perfect, bad, random).
# Each generator is called with T=sample_size; its result gets a new leading
# axis and is then tiled batch_size times along that axis.
# The generators are evaluated in the order listed (perfect, bad, random).
perfect_kpt_coordinates, bad_kpt_coordinates, random_kpt_coordinates = (
    make_keypoints(T=sample_size).unsqueeze(0).repeat((batch_size, 1, 1, 1))
    for make_keypoints in (get_perfect_keypoints, get_bad_keypoints, get_random_keypoints)
)

In [19]:
# Tracking metric, evaluated once per key-point set.
# A lower value means the image patches around the key-points change less
# visually over time.
# NOTE(review): in the recorded output of this cell the "bad" key-points
# score lower than the "perfect" ones - confirm whether that is expected.
M_track_perfect, M_track_bad, M_track_random = (
    kpt_tracking_metric(kpt_coordinates, img_tensor, patch_size, n_bins, p)
    for kpt_coordinates in (
        perfect_kpt_coordinates,
        bad_kpt_coordinates,
        random_kpt_coordinates,
    )
)
print(M_track_perfect, M_track_bad, M_track_random, sep="\n")

tensor(0.0726)
tensor(0.0549)
tensor(0.1646)


In [20]:
# Visual metric, evaluated once per key-point set.
# A higher value means the image patches differ more visually across
# key-points.
# NOTE(review): in the recorded output of this cell the random key-points
# score highest of the three sets - confirm whether that is expected.
M_vis_perfect, M_vis_bad, M_vis_random = (
    kpt_visual_metric(kpt_coordinates, img_tensor, patch_size, n_bins, p)
    for kpt_coordinates in (
        perfect_kpt_coordinates,
        bad_kpt_coordinates,
        random_kpt_coordinates,
    )
)
print(M_vis_perfect, M_vis_bad, M_vis_random, sep="\n")

tensor(0.0099)
tensor(0.0070)
tensor(0.0268)
