In [1]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from algorithms.feature_extraction_loading import FeatureDataset
from algorithms.utils import feature_collate_fn
#from learning_based.weighted_features_tracker import WeightedFeaturesTracker, WeightedHeatmapsTracker
from math import ceil
from evaluation.visualization import place_marker_in_frames

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
## Data

# Location of the pre-extracted features that FeatureDataset reads.
FEATURES_ROOT = "features/davis/"

dataset = FeatureDataset(FEATURES_ROOT)

# One dataset item per batch, visited in shuffled order; batches are assembled
# by the project's own collate function.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=feature_collate_fn,
)

In [5]:
def get_query_batch(query_points, target_points, occluded, trackgroup, batch_size=8, shuffle=True, drop_last=False, rng=None):
    """
    Yield mini-batches of (query_points, target_points, occluded, trackgroup).

    All four inputs are indexed along their first axis with the same indices,
    so row ``k`` of every yielded element refers to the same point. The inputs
    are expected to share the same first-axis length; only ``query_points`` is
    used to determine the number of points.

    Args:
        query_points: Array/tensor exposing ``.shape`` and supporting indexing
            by an integer index array and by slices along axis 0.
        target_points: Indexable like ``query_points``.
        occluded: Indexable like ``query_points``.
        trackgroup: Indexable like ``query_points``.
        batch_size: Maximum number of points per batch; must be positive.
        shuffle: If True, one random permutation is drawn and applied to all
            four inputs before batching.
        drop_last: If True, a trailing batch with fewer than ``batch_size``
            points is not yielded.
        rng: Optional object with a ``permutation(n)`` method (for example a
            ``numpy.random.Generator``) used for shuffling. When None, the
            global ``np.random`` state is used, as before.

    Yields:
        Tuple ``(query_points, target_points, occluded, trackgroup)`` holding
        the slices that make up one batch.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error surfaces when iteration starts.
    """
    if batch_size <= 0:
        # Previously 0 produced a ZeroDivisionError and negative values
        # silently produced no batches at all.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    num_points = query_points.shape[0]

    if shuffle:
        sampler = np.random if rng is None else rng
        permutation = sampler.permutation(num_points)
        # The same permutation is applied everywhere to keep rows aligned.
        query_points = query_points[permutation]
        target_points = target_points[permutation]
        occluded = occluded[permutation]
        trackgroup = trackgroup[permutation]

    if drop_last:
        num_batches = num_points // batch_size
    else:
        num_batches = ceil(num_points / batch_size)

    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        # Clamp so the final (possibly partial) batch stops at the last point.
        window = slice(start, min(start + batch_size, num_points))

        yield (
            query_points[window],
            target_points[window],
            occluded[window],
            trackgroup[window],
        )

In [6]:
from learning_based.learn_upsample_tracker import LearnUpsampleTracker

# The tracker constructor takes the feature dict of one sample drawn from the
# dataloader.
# NOTE(review): presumably only used to size the layers — confirm against
# LearnUpsampleTracker.__init__.
sample_features = next(iter(dataloader))[0]["features"]
model = LearnUpsampleTracker(sample_features)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
state_dict = torch.load(
    'trained_upsample_1.pth',
    map_location=torch.device('cpu'),
    weights_only=True,
)
model.load_state_dict(state_dict)

# Switch to evaluation mode; as the last expression this also displays the
# module summary as the cell output.
model.eval()

LearnUpsampleTracker(
  (heatmap_processor): HeatmapProcessor(
    (softmax): Softmax(dim=-1)
    (relu): ReLU()
    (heatmap_processing_layers): ModuleDict(
      (hid1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (hid2): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (hid3): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (hid4): Linear(in_features=1, out_features=16, bias=True)
      (occ_out): Linear(in_features=16, out_features=1, bias=True)
      (regression_hid): Linear(in_features=32, out_features=128, bias=True)
      (regression_out): Linear(in_features=128, out_features=2, bias=True)
    )
  )
  (softmax): Softmax(dim=None)
  (relu): ReLU()
  (conv1d_up8): Conv2d(10, 32, kernel_size=(1, 1), stride=(1, 1))
  (conv1d_up16): Conv2d(10, 32, kernel_size=(1, 1), stride=(1, 1))
  (conv1d_up32_1): Conv2d(10, 32, kernel_size=(1, 1), stride=(1, 1))
  (conv1d_up32_2): Conv2d(10, 32, kernel_size=(1, 1), stride=(1

In [7]:
## Train input
device = "cpu"

# NOTE(review): `tracks` and `occ` are re-created for every video, so once this
# loop finishes only the results of the LAST video remain available to later
# cells. Confirm whether iterating the whole dataloader is intended.
for batch in dataloader:
    data = batch[0]

    ### Overfitting to single data point
    data["query_points"] = data["query_points"][:,:1,:]
    data["target_points"] = data["target_points"][:,:1,:]
    data["occluded"] = data["occluded"][:,:1]
    data["trackgroup"] = data["trackgroup"][:,:1]

    # Cast every feature map to float32 on the target device, in place.
    feature_dict = data['features']
    for block_feat_list in feature_dict.values():
        for feat_idx, feat in enumerate(block_feat_list):
            block_feat_list[feat_idx] = feat.to(device=device, dtype=torch.float32)

    tracks = []
    occ = []

    # No parameters are updated here, so skip autograd bookkeeping: otherwise
    # each stored prediction would keep its whole computation graph alive.
    with torch.no_grad():
        query_batches = get_query_batch(
            data["query_points"][0],
            data["target_points"][0],
            data["occluded"][0],
            data["trackgroup"][0],
        )
        for batch_idx, query_batch in enumerate(query_batches):
            print(batch_idx)

            query_points, target_points, occluded, trackgroup = query_batch

            query_points = torch.tensor(query_points, dtype=torch.float32, device=device)
            # The last axis of the targets is reordered ([1, 0]) before conversion.
            # target_points / trackgroup are not consumed by the forward pass below.
            target_points = torch.tensor(target_points[..., [1, 0]], dtype=torch.float32, device=device)
            occluded = torch.tensor(occluded, dtype=torch.float32, device=device)
            trackgroup = torch.tensor(trackgroup, dtype=torch.float32, device=device)

            # Only the first output of the model is kept.
            pred_points, _ = model(feature_dict, query_points)

            occ.append(occluded)
            tracks.append(pred_points)


0


In [None]:
# Stitch the per-batch outputs collected by the loop cell back together along
# the first axis.
all_pred_points = torch.cat(tracks, dim=0)
all_occluded = torch.cat(occ, dim=0)

# Draw the predictions onto the video of the sample that `data` currently holds.
place_marker_in_frames(data["video"], all_pred_points, all_occluded)

In [6]:
# Print each feature block's name followed by the shape of every map in it.
for name, blocklist in feature_dict.items():
    print(name)
    # Named `feature_map` so the builtin `map` is not rebound in the notebook's
    # global namespace.
    for feature_map in blocklist:
        print(feature_map.shape)

up_block
torch.Size([50, 10, 8, 8])
torch.Size([50, 10, 16, 16])
torch.Size([50, 10, 32, 32])
torch.Size([50, 10, 32, 32])
down_block
torch.Size([50, 10, 16, 16])
torch.Size([50, 10, 8, 8])
torch.Size([50, 10, 4, 4])
torch.Size([50, 10, 4, 4])
mid_block
torch.Size([50, 10, 4, 4])
decoder_block
torch.Size([50, 10, 64, 64])
torch.Size([50, 10, 128, 128])
torch.Size([50, 10, 256, 256])
torch.Size([50, 10, 256, 256])
