In [11]:
import sys
sys.path.insert(1, '/home/yannik/vssil')

import torch
from torch.nn.functional import interpolate

from src.utils.kpt_utils import get_image_patches, kpts_2_img_coordinates
from src.losses.kpt_tracking_metric import kpt_tracking_metric
from src.losses.kpt_visual_metric import kpt_visual_metric
from src.losses.spatial_consistency_loss import spatial_consistency_loss
from contrastive_loss_test_data.test_keypoints import get_perfect_keypoints, get_bad_keypoints, get_random_keypoints
from contrastive_loss_test_data.test_data import load_sample_images 


sample_size = 4
batch_size = 16
patch_size = (7, 7)  # (32, 32)
n_bins = 50  # 200
p = float('inf')  # 0, 1, 2, float('inf')

In [12]:
torch.manual_seed(123)

# Load example image frames
img_tensor = load_sample_images(sample_size=sample_size,
                                path="/home/yannik/vssil/contrastive_loss_test_data/990000.mp4").unsqueeze(0)
img_tensor = img_tensor.repeat((batch_size, 1, 1, 1, 1))
img_tensor = interpolate(img_tensor, size=(3, 128, 128))
N, T, C, H, W = img_tensor.shape

In [13]:
# Load example key-points
perfect_kpt_coordinates = get_perfect_keypoints(T=sample_size).unsqueeze(0)
perfect_kpt_coordinates = perfect_kpt_coordinates.repeat((batch_size, 1, 1, 1))

bad_kpt_coordinates = get_bad_keypoints(T=sample_size).unsqueeze(0)
bad_kpt_coordinates = bad_kpt_coordinates.repeat((batch_size, 1, 1, 1))

random_kpt_coordinates = get_random_keypoints(T=sample_size).unsqueeze(0)
random_kpt_coordinates = random_kpt_coordinates.repeat((batch_size, 1, 1, 1))

In [14]:
# Evaluating tracking metric
# The lower the result, the less visual difference of image patches around key-points over time
M_track_perfect, c, g = kpt_tracking_metric(perfect_kpt_coordinates, img_tensor, patch_size, n_bins, p)
M_track_bad, c2, g2 = kpt_tracking_metric(bad_kpt_coordinates, img_tensor, patch_size, n_bins, p)
M_track_random, c3, g3 = kpt_tracking_metric(random_kpt_coordinates, img_tensor, patch_size, n_bins, p)
print(M_track_perfect)
print(M_track_bad)
print(M_track_random)
print()
print(c)
print(c2)
print(c3)
print()
print(g)
print(g2)
print(g3)

tensor(0.0384)
tensor(0.0351)
tensor(0.0094)

tensor(0.1504)
tensor(0.1735)
tensor(0.7189)

tensor(0.1583)
tensor(0.1745)
tensor(0.1485)


In [15]:
# Evaluating visual metric
# The higher the result, the higher the visual differences of image patches across key-points
M_vis_perfect, c1, g1 = kpt_visual_metric(perfect_kpt_coordinates, img_tensor, patch_size, n_bins, p)
M_vis_bad, c2, g2 = kpt_visual_metric(bad_kpt_coordinates, img_tensor, patch_size, n_bins, p)
M_vis_random, c3, g3 = kpt_visual_metric(random_kpt_coordinates, img_tensor, patch_size, n_bins, p)
print(M_vis_perfect)
print(M_vis_bad)
print(M_vis_random)
print()
print(c)
print(c2)
print(c3)
print()
print(g)
print(g2)
print(g3)

tensor(0.0545)
tensor(0.0532)
tensor(0.0096)

tensor(0.1504)
tensor(0.2504)
tensor(0.6659)

tensor(0.1583)
tensor(0.2850)
tensor(0.1716)
