In [2]:
import torch
import numpy as np

def extract_fixed_random_windows(x, w, k):
    """
    Extracts k random windows of size w using the same k starting points for all batch samples.

    Args:
        x (torch.Tensor): Input time-series tensor of shape (batch_size, n_timestamps, n_features).
        w (int): Window size.
        k (int): Number of windows per sample.

    Returns:
        torch.Tensor: Extracted windows of shape (batch_size, k, w, n_features).
    """
    batch_size, n_timestamps, n_features = x.shape

    # Generate k random start indices (same for all samples)
    crop_start = np.random.choice(range(0, n_timestamps - w), size=k, replace=False)   # Shape: (k,)
    print("crop start")
    print(crop_start)

    # Create index tensors
    idx = torch.arange(w).repeat(k, 1)  # Shape: (k, w)
    print(idx)
    crop_start_tensor = torch.tensor(crop_start).unsqueeze(-1)  # Shape: (k, 1)
    print("crop start")
    print(crop_start_tensor)

    # Compute the actual indices
    indices = crop_start_tensor + idx  # Shape: (k, w)
    print("indices")
    print(indices)

    # Expand indices to match batch size
    indices = indices.unsqueeze(0).expand(batch_size, -1, -1)  # Shape: (batch_size, k, w)
    print("indices")
    print(indices)

    # Use advanced indexing to extract the windows
    batch_idx = torch.arange(batch_size).view(batch_size, 1, 1).expand(-1, k, w)
    print(batch_idx)

    extracted_windows = x[batch_idx, indices]  # Shape: (batch_size, k, w, n_features)

    return extracted_windows


import torch
batch_size = 3
n_timestamps = 20
n_features = 5
w = 4  # Window size
k = 2  # Number of windows per sample

# Create a random tensor
x = torch.randn(batch_size, n_timestamps, n_features)

# print("x")
# print(x)

# Extract k windows per sample (same across batch)
windows = extract_fixed_random_windows(x, w, k)
print("windows")
print(windows.shape)  # Expected output: (batch_size, k, w, n_features)




crop start
[10  6]
tensor([[0, 1, 2, 3],
        [0, 1, 2, 3]])
crop start
tensor([[10],
        [ 6]])
indices
tensor([[10, 11, 12, 13],
        [ 6,  7,  8,  9]])
indices
tensor([[[10, 11, 12, 13],
         [ 6,  7,  8,  9]],

        [[10, 11, 12, 13],
         [ 6,  7,  8,  9]],

        [[10, 11, 12, 13],
         [ 6,  7,  8,  9]]])
tensor([[[0, 0, 0, 0],
         [0, 0, 0, 0]],

        [[1, 1, 1, 1],
         [1, 1, 1, 1]],

        [[2, 2, 2, 2],
         [2, 2, 2, 2]]])
windows
torch.Size([3, 2, 4, 5])
