In [1]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from algorithms.feature_extraction_loading import FeatureDataset
from algorithms.utils import feature_collate_fn
from learning_based.weighted_features_tracker import WeightedFeaturesTracker, WeightedHeatmapsTracker
from math import ceil

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## Data

dataset = FeatureDataset("features/davis/")
dataloader = DataLoader(dataset, 1, shuffle=True, collate_fn=feature_collate_fn)

In [3]:
def get_query_batch(query_points, target_points, occluded, trackgroup, batch_size=8, shuffle=True, drop_last=False):
    """
    Yields a tuple containing one batch of query_points, target_points, occluded, trackgroup
    """
    num_points = query_points.shape[0]

    if shuffle:
        permutation = np.random.permutation(num_points)
        query_points = query_points[permutation]
        target_points = target_points[permutation]
        occluded = occluded[permutation]
        trackgroup = trackgroup[permutation]

    if drop_last:
        num_batches = num_points // batch_size
    else:
        num_batches = ceil(num_points / batch_size)
    
    for i in range(num_batches):
        start = i*batch_size
        end = min((i+1)*batch_size, num_points)

        yield (
            query_points[start:end],
            target_points[start:end],
            occluded[start:end],
            trackgroup[start:end],
        )

In [4]:
## Train input
device = "cpu"

for i, data in enumerate(dataloader):
    data = data[0]

    ### Overfitting to single data point
    data["query_points"] = data["query_points"][:,:1,:]
    data["target_points"] = data["target_points"][:,:1,:]
    data["occluded"] = data["occluded"][:,:1]
    data["trackgroup"] = data["trackgroup"][:,:1]

    feature_dict = data['features']
    for block_name, block_feat_list in feature_dict.items():
        for i in range(len(block_feat_list)):
            feature_dict[block_name][i] = feature_dict[block_name][i].to(dtype=torch.float64).to(device)

    for query_batch in get_query_batch(data["query_points"][0], data["target_points"][0], data["occluded"][0], data["trackgroup"][0]):

        query_points, target_points, occluded, trackgroup = query_batch

        query_points = torch.tensor(query_points, dtype=torch.float32, device=device)
        target_points = torch.tensor(target_points[..., [1, 0]], dtype=torch.float32, device=device)
        occluded = torch.tensor(occluded, dtype=torch.float32, device=device)
        trackgroup = torch.tensor(trackgroup, dtype=torch.float32, device=device)


In [6]:
for name, blocklist in feature_dict.items():
    print(name)
    for map in blocklist:
        print(map.shape)

up_block
torch.Size([50, 10, 8, 8])
torch.Size([50, 10, 16, 16])
torch.Size([50, 10, 32, 32])
torch.Size([50, 10, 32, 32])
down_block
torch.Size([50, 10, 16, 16])
torch.Size([50, 10, 8, 8])
torch.Size([50, 10, 4, 4])
torch.Size([50, 10, 4, 4])
mid_block
torch.Size([50, 10, 4, 4])
decoder_block
torch.Size([50, 10, 64, 64])
torch.Size([50, 10, 128, 128])
torch.Size([50, 10, 256, 256])
torch.Size([50, 10, 256, 256])


In [43]:
## Function to check gradients of
from torchvision.transforms.functional import resize 
from algorithms.heatmap_generator import HeatmapGenerator
from algorithms.zero_shot_tracker import ZeroShotTracker
from algorithms.feature_extraction_loading import concatenate_video_features

def fwd_func(w1, w2, w3, w4):

    # list_up = feature_dict["up_block"]
    # list_down = feature_dict["down_block"]
    # list_mid = feature_dict["mid_block"]
    # list_decoder = feature_dict["decoder_block"]
      
    # list_scaled_up = [weight * feature_maps for weight, feature_maps in zip(w1, list_up)]
    # list_scaled_down = [weight * feature_maps for weight, feature_maps in zip(w2, list_down)]
    # list_scaled_mid = [weight * feature_maps for weight, feature_maps in zip(w3, list_mid)]
    # list_scaled_decoder = [weight * feature_maps for weight, feature_maps in zip(w4, list_decoder)]

    # output = list_scaled_up[0]

    # feature_dict_output = {"up_block": list_scaled_up}

    
    feature_dict_output = feature_dict.copy()

    w = (w1, w2, w3, w4)
    for i, (block_name, block_feature_list) in enumerate(feature_dict_output.items()):
       feature_dict_output[block_name] = [weight * feature_maps for weight, feature_maps in zip(w[i], block_feature_list)]

    concat_features = concatenate_video_features(feature_dict)
    
    heatmap_generator = HeatmapGenerator()
    query_points_copy = query_points.clone()
    hmps = heatmap_generator.generate(concat_features, query_points_copy)

    tracker = ZeroShotTracker()
    tracks = tracker.track(hmps)

    target_points_copy = target_points.clone()

    loss = torch.nn.MSELoss()
    output = loss(tracks, target_points_copy)

    print(output)

    return output

In [44]:
## Check gradient
from torch.autograd import gradcheck

w1 = torch.tensor([1, 1, 1, 1], dtype=torch.float64, device=device, requires_grad=True)
w2 = torch.tensor([1, 1, 1, 1], dtype=torch.float64, device=device, requires_grad=True)
w3 = torch.tensor([1,], dtype=torch.float64, device=device, requires_grad=True)
w4 = torch.tensor([1, 1, 1, 1], dtype=torch.float64, device=device, requires_grad=True)


test = gradcheck(fwd_func, (w1, w2, w3, w4))
test


tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449.1927, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(449

True