# Use this notebook to time the execution of two functions

In [1]:
import timeit

import torch
import torch.nn.functional as F

## Argmin on CPU or GPU

In [13]:
from typing import List


def get_argmin_mapping_list_cpu(timestamps_in_scales: List[torch.Tensor]) -> List[torch.Tensor]:
    """Calculate the mapping between the base scale and other scales.

    A segment from a longer scale is repeatedly mapped to a segment from a shorter scale or the base scale.

    Args:
        timestamps_in_scales (list):
            List containing timestamp tensors for each scale.
            Each tensor has dimensions of (Number of base segments) x 2.

    Returns:
        session_scale_mapping_list (list):
            List containing argmin arrays indexed by scale index.
    """
    scale_list = list(range(len(timestamps_in_scales)))
    segment_anchor_list = [torch.mean(timestamps_in_scales[scale_idx], dim=1) for scale_idx in scale_list]

    base_scale_anchor = segment_anchor_list[max(scale_list)].view(-1, 1)

    session_scale_mapping_list = []
    for scale_idx in scale_list:
        current_scale_anchor = segment_anchor_list[scale_idx].view(1, -1)
        distance = torch.abs(current_scale_anchor - base_scale_anchor)
        argmin_mat = torch.argmin(distance, dim=1)
        session_scale_mapping_list.append(argmin_mat)

    return session_scale_mapping_list

def get_argmin_mapping_list_gpu(timestamps_in_scales: List[torch.Tensor]) -> List[torch.Tensor]:
    """Calculate the mapping between the base scale and other scales.

    A segment from a longer scale is repeatedly mapped to a segment from a shorter scale or the base scale.

    Args:
        timestamps_in_scales (list):
            List containing timestamp tensors for each scale.
            Each tensor has dimensions of (Number of base segments) x 2.

    Returns:
        session_scale_mapping_list (list):
            List containing argmin arrays indexed by scale index.
    """
    timestamps_in_scales = [x.to('cuda') for x in timestamps_in_scales]
    
    scale_list = list(range(len(timestamps_in_scales)))
    segment_anchor_list = [torch.mean(timestamps_in_scales[scale_idx], dim=1) for scale_idx in scale_list]

    base_scale_anchor = segment_anchor_list[max(scale_list)].view(-1, 1)

    session_scale_mapping_list = []
    for scale_idx in scale_list:
        current_scale_anchor = segment_anchor_list[scale_idx].view(1, -1)
        distance = torch.abs(current_scale_anchor - base_scale_anchor)
        argmin_mat = torch.argmin(distance, dim=1)
        session_scale_mapping_list.append(argmin_mat)

    return session_scale_mapping_list

# Generate a random list of timestamps for each scale (each scale has a different number of timestamps)
scales = [52, 89, 123, 196, 254]
timestamps_in_scales = [torch.rand((s, 2)) for s in scales]

num_passes = 10
time_taken_cpu = timeit.timeit(
    lambda: get_argmin_mapping_list_cpu(timestamps_in_scales), number=num_passes,
)
time_taken_gpu = timeit.timeit(
    lambda: get_argmin_mapping_list_gpu(timestamps_in_scales), number=num_passes,
)

print("Time taken on CPU ({} passes): {:.6f} seconds".format(num_passes, time_taken_cpu))
print("Time taken on GPU ({} passes): {:.6f} seconds".format(num_passes, time_taken_gpu))

Time taken on CPU (10 passes): 0.014017 seconds
Time taken on GPU (10 passes): 2.575309 seconds
