In [2]:
import torch

In [3]:
batch_size = 4
num_mark = 2
seq_len = 5

In [4]:
last_event_time = torch.zeros(
            (batch_size, num_mark), dtype=torch.float32
        )
last_event_time

tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])

In [5]:
event_seq = torch.randint(
            0, 2, (batch_size, seq_len), dtype=torch.float32
        )
event_seq

tensor([[1., 1., 0., 0., 0.],
        [0., 1., 1., 0., 1.],
        [1., 0., 1., 0., 1.],
        [0., 1., 1., 1., 1.]])

In [6]:
time_seq = torch.randn(
            (batch_size, seq_len), dtype=torch.float32
        ).abs()
time_seq = time_seq.cumsum(dim=1)
time_seq

tensor([[1.4714, 3.1033, 3.2334, 5.0800, 5.5261],
        [0.0943, 1.5425, 2.3193, 2.4992, 2.7180],
        [0.7317, 0.8511, 1.4375, 1.8325, 3.5591],
        [0.1244, 0.5622, 1.8200, 2.2945, 3.5384]])

In [None]:
for mark in range(num_mark):
            mark_mask = (event_seq == mark)  # [batch_size, seq_len]
            if mark_mask.any():
                # Vectorized operation with efficient masking
                masked_times = time_seq.masked_fill(~mark_mask, float("-inf"))
                max_times, _ = masked_times.max(dim=1)
                valid_mask = (max_times != float("-inf"))
                last_event_time[valid_mask, mark] = max_times[valid_mask]
last_event_time

tensor([[5.5261, 3.1033],
        [2.4992, 2.7180],
        [1.8325, 3.5591],
        [0.1244, 3.5384]])

# Advanced indexing techniques

This section shows practical examples of advanced indexing methods in PyTorch for efficient tensor manipulation.

In [39]:
valid_mask = (last_event_time[:, 1]  > 3)
print("valid_mask: ", valid_mask)
type_pred = torch.randint(0, 2, (batch_size,))
print("type_pred:", type_pred, "shape: ", type_pred.shape)
last_event_time[valid_mask, type_pred]

valid_mask:  tensor([ True, False,  True,  True])
type_pred: tensor([0, 0, 1]) shape:  torch.Size([3])


tensor([5.5261, 1.8325, 3.5384])

## Examples of advanced indexing with PyTorch

This notebook demonstrates several advanced indexing techniques to manipulate PyTorch tensors effectively.

### Quick utilities and printing examples

In [None]:
# Create a 3D tensor for examples
data = torch.randn(3, 4, 5)
print("Original tensor shape:", data.shape)
print("Data:\n", data)

# Basic indexing
print("\n1. First element (batch 0):", data[0].shape)
print("2. Last column:", data[:, :, -1].shape)
print("3. Middle slice:", data[:, 1:3, :].shape)

Tenseur original shape: torch.Size([3, 4, 5])
Donn√©es:
 tensor([[[ 2.1499,  0.6631,  0.1087, -0.4316,  1.3128],
         [ 0.6359,  1.3980,  0.8533,  0.3231, -0.8725],
         [ 1.3129,  0.8505, -0.2898,  0.9754,  0.8462],
         [-1.1346,  1.1589, -1.8925, -0.6487, -1.6561]],

        [[ 0.9511, -1.1479,  0.5666, -1.6383, -1.5595],
         [ 0.5956, -0.0333,  0.2996, -0.4183,  1.6169],
         [ 0.3581, -1.0088, -1.7175,  0.5518,  1.6188],
         [ 1.4274,  2.2002,  0.2962,  0.0888, -0.5887]],

        [[ 0.6040, -0.0827, -0.2848, -0.1057, -0.6897],
         [ 0.0200, -0.1890,  0.3926,  0.3901, -0.4661],
         [ 0.6775, -0.2156,  0.9627,  0.0132,  0.1663],
         [-1.5046,  1.9544,  1.2686, -1.6510, -1.8305]]])

1. Premier √©l√©ment (batch 0): torch.Size([4, 5])
2. Derni√®re colonne: torch.Size([3, 4])
3. Slice du milieu: torch.Size([3, 2, 5])


### Demonstration: Tensor shapes and examples

In [None]:
# torch.gather selects elements according to indices
source = torch.tensor([[1, 2, 3, 4], 
                       [5, 6, 7, 8], 
                       [9, 10, 11, 12]])

# Indices to select for each row
indices = torch.tensor([[0, 2], [1, 3], [0, 3]])

print("Source tensor:\n", source)
print("Indices:\n", indices)

# Gather along dim=1 (columns)
gathered = torch.gather(source, dim=1, index=indices)
print("Gather result:\n", gathered)

# Example with event simulation: select the last event of each type
batch_size, seq_len = 3, 5
events = torch.randint(0, 2, (batch_size, seq_len))
times = torch.randn(batch_size, seq_len).abs().cumsum(dim=1)

print("\nEvent example:")
print("Events:\n", events)
print("Times:\n", times)

# Find last index of event type 1 for each batch
last_indices = []
for b in range(batch_size):
    mask = events[b] == 1
    if mask.any():
        last_idx = torch.where(mask)[0][-1]
    else:
        last_idx = torch.tensor(0)  # fallback
    last_indices.append(last_idx)

last_indices = torch.stack(last_indices).unsqueeze(1)
print("Last indices for type 1:\n", last_indices)

# Use gather to retrieve corresponding times
last_times = torch.gather(times, dim=1, index=last_indices)
print("Last times for type 1:\n", last_times)

Source tensor:
 tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
Indices:
 tensor([[0, 2],
        [1, 3],
        [0, 3]])
R√©sultat gather:
 tensor([[ 1,  3],
        [ 6,  8],
        [ 9, 12]])

Exemple avec √©v√©nements:
Events:
 tensor([[0, 1, 0, 1, 0],
        [1, 0, 0, 1, 1],
        [1, 1, 1, 0, 1]])
Times:
 tensor([[0.0917, 1.1058, 2.4856, 3.1408, 4.7153],
        [0.3074, 0.8413, 2.3905, 2.6149, 3.0617],
        [1.5503, 1.8348, 3.1659, 4.4744, 6.1252]])
Derniers indices pour type 1:
 tensor([[3],
        [4],
        [4]])
Derniers temps pour type 1:
 tensor([[3.1408],
        [3.0617],
        [6.1252]])


### Selecting the last event index for a given type

This example finds the index of the last occurrence of a specific event type within each batch.

In [None]:
# Boolean masks for filtering and conditional selection
values = torch.randn(4, 6)
print("Values:\n", values)

# Mask for positive values
positive_mask = values > 0
print("\nPositive mask:\n", positive_mask)

# Select only positive values (returns a 1D tensor)
positive_values = values[positive_mask]
print("Positive values:", positive_values[:10], "...")  # Partial display

# Use masked_fill to replace values
values_filled = values.masked_fill(values < 0, 0.0)
print("\nValues with negatives replaced by 0:\n", values_filled)

# masked_select for more complex filtering
mask_complex = (values > 0) & (values < 1)
selected = values.masked_select(mask_complex)
print("\nValues between 0 and 1:", selected)

# Practical example: mask paddings in a sequence
seq_len = torch.tensor([3, 5, 2, 4])  # Actual lengths for each sequence
max_len = 5
batch_size = seq_len.size(0)

# Create a padding mask
padding_mask = torch.arange(max_len).unsqueeze(0) >= seq_len.unsqueeze(1)
print(f"\nPadding mask (True = padding):\n{padding_mask}")

# Apply mask to data
sequence_data = torch.randn(batch_size, max_len)
masked_data = sequence_data.masked_fill(padding_mask, float('-inf'))
print("Data with padding masked:\n", masked_data)

Valeurs:
 tensor([[ 1.4011,  1.5079, -1.0909,  0.4197,  0.3895,  0.5238],
        [-2.1130,  0.4523, -0.9354, -1.7169, -0.6213, -0.8477],
        [-0.8427,  0.7809,  0.7720, -1.2123,  1.2467,  0.7433],
        [-0.5526,  2.0318, -0.9890, -2.5797,  0.3550, -0.7316]])

Masque positif:
 tensor([[ True,  True, False,  True,  True,  True],
        [False,  True, False, False, False, False],
        [False,  True,  True, False,  True,  True],
        [False,  True, False, False,  True, False]])
Valeurs positives: tensor([1.4011, 1.5079, 0.4197, 0.3895, 0.5238, 0.4523, 0.7809, 0.7720, 1.2467,
        0.7433]) ...

Valeurs avec n√©gatifs remplac√©s par 0:
 tensor([[1.4011, 1.5079, 0.0000, 0.4197, 0.3895, 0.5238],
        [0.0000, 0.4523, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.7809, 0.7720, 0.0000, 1.2467, 0.7433],
        [0.0000, 2.0318, 0.0000, 0.0000, 0.3550, 0.0000]])

Valeurs entre 0 et 1: tensor([0.4197, 0.3895, 0.5238, 0.4523, 0.7809, 0.7720, 0.7433, 0.3550])

Masque de pad

## 4. Indexation avec torch.index_select et torch.take

In [None]:
# torch.index_select: select slices along a dimension
matrix = torch.randn(4, 6)
print("Original matrix:\n", matrix)

# Select specific rows
row_indices = torch.tensor([0, 2, 3])
selected_rows = torch.index_select(matrix, dim=0, index=row_indices)
print("\nSelected rows [0, 2, 3]:\n", selected_rows)

# Select specific columns
col_indices = torch.tensor([1, 3, 5])
selected_cols = torch.index_select(matrix, dim=1, index=col_indices)
print("\nSelected columns [1, 3, 5]:\n", selected_cols)

# torch.take treats the tensor as a flat array
flat_tensor = torch.arange(12).reshape(3, 4)
print("\nFlat tensor:\n", flat_tensor)

# Absolute position indices
indices = torch.tensor([0, 5, 7, 11])
taken = torch.take(flat_tensor, indices)
print("Elements at positions [0, 5, 7, 11]:", taken)

# Practical example: select embeddings
vocab_size, embedding_dim = 1000, 128
embeddings = torch.randn(vocab_size, embedding_dim)

# Word index sequences
word_indices = torch.tensor([[1, 5, 23, 7], [45, 2, 8, 12]])
batch_size, seq_length = word_indices.shape

# Select corresponding embeddings
selected_embeddings = torch.index_select(embeddings, dim=0, 
                                        index=word_indices.flatten())
selected_embeddings = selected_embeddings.view(batch_size, seq_length, embedding_dim)

print(f"\nSelected embeddings shape: {selected_embeddings.shape}")
print("First embedding of the first batch:", selected_embeddings[0, 0, :5])

Matrice originale:
 tensor([[ 1.4924, -1.5938, -0.9453,  1.2735,  0.4558,  0.8859],
        [-0.3927,  0.4843,  0.2566, -0.1102, -1.2561,  0.9141],
        [ 0.6275,  0.5636,  1.0326, -0.5624,  0.0315,  1.9188],
        [-0.3732, -0.0578,  0.3002,  2.2216, -0.0849, -0.6441]])

Lignes s√©lectionn√©es [0, 2, 3]:
 tensor([[ 1.4924, -1.5938, -0.9453,  1.2735,  0.4558,  0.8859],
        [ 0.6275,  0.5636,  1.0326, -0.5624,  0.0315,  1.9188],
        [-0.3732, -0.0578,  0.3002,  2.2216, -0.0849, -0.6441]])

Colonnes s√©lectionn√©es [1, 3, 5]:
 tensor([[-1.5938,  1.2735,  0.8859],
        [ 0.4843, -0.1102,  0.9141],
        [ 0.5636, -0.5624,  1.9188],
        [-0.0578,  2.2216, -0.6441]])

Tenseur plat:
 tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
√âl√©ments pris aux positions [0, 5, 7, 11]: tensor([ 0,  5,  7, 11])

Embeddings s√©lectionn√©s shape: torch.Size([2, 4, 128])
Premier embedding du premier batch: tensor([ 1.5345, -0.7401,  0.4479,  0.8496,  0.0

### Boolean masks for filtering and conditional selection

Use boolean masks to filter values or select subsets without explicit Python loops.

In [None]:
# Indexing with multiple dimensions at once
batch_size, num_classes, seq_len = 2, 3, 4
tensor_3d = torch.randn(batch_size, num_classes, seq_len)
print("3D tensor shape:", tensor_3d.shape)
print("Data:\n", tensor_3d)

# Fancy indexing with tensors
batch_idx = torch.tensor([0, 1, 0, 1])  # batch indices
class_idx = torch.tensor([1, 2, 0, 1])  # class indices
seq_idx = torch.tensor([0, 1, 2, 3])    # sequence indices

# Select specific elements
selected_elements = tensor_3d[batch_idx, class_idx, seq_idx]
print("\nSelected elements:", selected_elements)

# Use arange for indexing
batch_range = torch.arange(batch_size)
print(f"Batch range: {batch_range}")

# Get max per batch/class across sequence length
max_values, max_indices = tensor_3d.max(dim=2)  # max over seq_len
print(f"\nMax values shape: {max_values.shape}")
print("Max values:\n", max_values)
print("Max indices:\n", max_indices)

# Use indices to retrieve original values
batch_expanded = batch_range.unsqueeze(1).expand(-1, num_classes)
class_expanded = torch.arange(num_classes).unsqueeze(0).expand(batch_size, -1)

recovered_values = tensor_3d[batch_expanded, class_expanded, max_indices]
print("\nRecovered values (verification):\n", recovered_values)
print("Equality:", torch.allclose(max_values, recovered_values))

Tenseur 3D shape: torch.Size([2, 3, 4])
Donn√©es:
 tensor([[[ 1.2225,  0.0716,  0.3245, -0.3487],
         [-1.2330,  0.8190,  0.9548,  1.0756],
         [ 0.3341,  0.4133, -0.0394,  0.3766]],

        [[ 0.6936, -0.5983, -0.5106,  0.8924],
         [ 0.6499,  0.7171, -0.9614,  0.5521],
         [-0.2948,  0.6610, -0.2536, -0.1455]]])

√âl√©ments s√©lectionn√©s: tensor([-1.2330,  0.6610,  0.3245,  0.5521])
Range batch: tensor([0, 1])

Max values shape: torch.Size([2, 3])
Max values:
 tensor([[1.2225, 1.0756, 0.4133],
        [0.8924, 0.7171, 0.6610]])
Max indices:
 tensor([[0, 3, 1],
        [3, 1, 1]])

Valeurs r√©cup√©r√©es (v√©rification):
 tensor([[1.2225, 1.0756, 0.4133],
        [0.8924, 0.7171, 0.6610]])
√âgalit√©: True


### Practical example: masking padding in sequences

In [None]:
# Practical applications for point processes
batch_size, max_seq_len, num_event_types = 3, 8, 4

# Simulate events with variable lengths
seq_lengths = torch.tensor([5, 7, 6])
event_types = torch.randint(0, num_event_types, (batch_size, max_seq_len))
event_times = torch.randn(batch_size, max_seq_len).abs().cumsum(dim=1)

print("Event types:\n", event_types)
print("Event times:\n", event_times)

# 1. Mask padding events
padding_mask = torch.arange(max_seq_len).unsqueeze(0) < seq_lengths.unsqueeze(1)
valid_events = event_types * padding_mask.long()  # 0 for paddings
valid_times = event_times.masked_fill(~padding_mask, 0.0)

print("\nValid events (with mask):\n", valid_events)
print("Valid times:\n", valid_times)

# 2. Find last event of each type for each sequence
last_event_times = torch.zeros(batch_size, num_event_types)

for batch_idx in range(batch_size):
    for event_type in range(num_event_types):
        # Mask for this event type in this sequence
        type_mask = (event_types[batch_idx] == event_type) & padding_mask[batch_idx]
        
        if type_mask.any():
            # Find the index of the last event of this type
            last_idx = torch.where(type_mask)[0][-1]
            last_event_times[batch_idx, event_type] = event_times[batch_idx, last_idx]

print("\nLast times per event type:\n", last_event_times)

# 3. More efficient vectorized version
def get_last_event_times_vectorized(event_types, event_times, padding_mask, num_event_types):
    batch_size, seq_len = event_types.shape
    last_times = torch.zeros(batch_size, num_event_types)
    
    for event_type in range(num_event_types):
        # Mask for this event type
        type_mask = (event_types == event_type) & padding_mask
        
        # Create a tensor with -inf where there are no events of this type
        masked_times = event_times.masked_fill(~type_mask, float('-inf'))
        
        # Take the maximum (last time) per batch
        max_times, _ = masked_times.max(dim=1)
        
        # Replace -inf by 0 where there was no event of that type
        valid_mask = max_times != float('-inf')
        last_times[valid_mask, event_type] = max_times[valid_mask]
    
    return last_times

last_times_vec = get_last_event_times_vectorized(event_types, event_times, padding_mask, num_event_types)
print("\nVectorized version:\n", last_times_vec)
print("Equality with naive version:", torch.allclose(last_event_times, last_times_vec))

Types d'√©v√©nements:
 tensor([[1, 3, 1, 2, 2, 3, 3, 3],
        [3, 1, 3, 1, 2, 0, 3, 0],
        [2, 1, 3, 2, 1, 2, 0, 1]])
Temps d'√©v√©nements:
 tensor([[1.0120, 1.7871, 3.0567, 4.7263, 5.7100, 6.1736, 6.5166, 6.6035],
        [0.2779, 0.9679, 4.2464, 4.7372, 4.9811, 7.8192, 8.9795, 9.7267],
        [0.2971, 1.2268, 2.4551, 2.5715, 4.0258, 4.2429, 5.2187, 5.3796]])

√âv√©nements valides (avec masque):
 tensor([[1, 3, 1, 2, 2, 0, 0, 0],
        [3, 1, 3, 1, 2, 0, 3, 0],
        [2, 1, 3, 2, 1, 2, 0, 0]])
Temps valides:
 tensor([[1.0120, 1.7871, 3.0567, 4.7263, 5.7100, 0.0000, 0.0000, 0.0000],
        [0.2779, 0.9679, 4.2464, 4.7372, 4.9811, 7.8192, 8.9795, 0.0000],
        [0.2971, 1.2268, 2.4551, 2.5715, 4.0258, 4.2429, 0.0000, 0.0000]])

Derniers temps par type d'√©v√©nement:
 tensor([[0.0000, 3.0567, 5.7100, 1.7871],
        [7.8192, 4.7372, 4.9811, 8.9795],
        [0.0000, 4.0258, 4.2429, 2.4551]])

Version vectoris√©e:
 tensor([[0.0000, 3.0567, 5.7100, 1.7871],
        [7.8192

### Applied examples for point processes

Examples applying the indexing techniques to temporal point processes (TPPs).

In [None]:
# Advanced techniques to optimize performance

# 1. Avoid Python loops using broadcasting
def naive_distance_matrix(points):
    """Naive version with loops (SLOW)"""
    n = points.shape[0]
    distances = torch.zeros(n, n)
    for i in range(n):
        for j in range(n):
            distances[i, j] = torch.norm(points[i] - points[j])
    return distances


def vectorized_distance_matrix(points):
    """Vectorized version (FAST)"""
    # points shape: [n, d]
    # Using broadcasting
    diff = points.unsqueeze(1) - points.unsqueeze(0)  # [n, n, d]
    distances = torch.norm(diff, dim=2)  # [n, n]
    return distances

# Test with random points
points = torch.randn(100, 3)

# Timing
import time

start = time.time()
dist_vectorized = vectorized_distance_matrix(points)
time_vectorized = time.time() - start

print(f"Vectorized time: {time_vectorized:.4f}s")
print(f"Result shape: {dist_vectorized.shape}")

# 2. Efficient indexing with scatter operations
batch_size, num_bins = 4, 10
values = torch.randn(batch_size, 20)  # 20 values per batch
bin_indices = torch.randint(0, num_bins, (batch_size, 20))

print("\nValues to bin:", values[0, :5])
print("Bin indices:", bin_indices[0, :5])

# Sum values per bin with scatter_add
bin_sums = torch.zeros(batch_size, num_bins)
bin_sums.scatter_add_(1, bin_indices, values)

print("Sums per bin:\n", bin_sums)

# 3. Memory optimization with view and squeeze/unsqueeze
large_tensor = torch.randn(1000, 1000)
print(f"Original tensor shape: {large_tensor.shape}")

# View to reshape without copying
reshaped = large_tensor.view(100, 10, 100, 10)
print(f"Reshaped shape: {reshaped.shape}")
print(f"Same memory: {reshaped.data_ptr() == large_tensor.data_ptr()}")

# Permute + contiguous to rearrange efficiently
permuted = reshaped.permute(0, 2, 1, 3)  # Change dimension order
flattened = permuted.contiguous().view(100, 100, 100)  # Needs contiguous()
print(f"Final shape: {flattened.shape}")

# 4. Indexing with topk for efficient selection
scores = torch.randn(5, 1000)  # 5 samples, 1000 features each
k = 10

# Select top-k scores for each sample
top_values, top_indices = torch.topk(scores, k, dim=1)
print(f"\nTop-{k} values shape: {top_values.shape}")
print(f"Top-{k} indices shape: {top_indices.shape}")

# Use indices to retrieve other information
features = torch.randn(5, 1000, 64)  # Associated features
top_features = torch.gather(features, 1, 
                           top_indices.unsqueeze(-1).expand(-1, -1, 64))
print(f"Top features shape: {top_features.shape}")

print("\n=== Optimization tips ===")
print("1. Use view() instead of reshape() when possible")
print("2. Avoid Python loops; prefer vectorized operations")
print("3. Use scatter/gather for grouped operations")
print("4. masked_fill is faster than conditional indexing")
print("5. topk is optimized for selecting top elements")

Temps vectoris√©: 0.0014s
Shape r√©sultat: torch.Size([100, 100])

Valeurs √† grouper par bins: tensor([ 0.9677,  0.8307, -0.6004, -0.0969, -0.9670])
Indices de bins: tensor([2, 3, 2, 9, 3])
Sommes par bin:
 tensor([[ 2.1389e+00,  0.0000e+00,  4.1902e+00,  1.9977e-03,  1.0681e+00,
          2.3102e-01, -7.2568e-01,  5.6631e-01,  0.0000e+00, -1.8103e-01],
        [ 0.0000e+00,  2.7782e+00, -1.4263e+00,  0.0000e+00,  6.3193e-01,
          3.1381e+00, -1.1786e-01,  4.7861e-01,  4.2729e-01,  2.0775e+00],
        [-2.1165e+00,  1.6351e+00,  4.4144e-01, -5.4236e-01,  5.4693e-01,
          0.0000e+00,  4.1478e-01, -1.1112e-01,  3.2264e+00, -3.7155e-01],
        [ 1.4515e+00, -1.0545e+00,  8.9650e-01,  2.7579e+00, -1.7292e-01,
         -8.0008e-01,  1.1697e-01, -1.7198e+00, -1.1725e+00, -1.5216e-01]])
Tenseur original shape: torch.Size([1000, 1000])
Reshaped shape: torch.Size([100, 10, 100, 10])
M√™me m√©moire: True
Final shape: torch.Size([100, 100, 100])

Top-10 values shape: torch.Size([5, 

## Vectorized version (fast)

This section contains an optimized, vectorized implementation and a timing comparison with a naive version.

In [16]:
# Comparaison d√©taill√©e entre gather, index_select et take
import torch

# Cr√©ons un tenseur de r√©f√©rence pour tous les exemples
data = torch.tensor([
    [10, 11, 12, 13, 14],
    [20, 21, 22, 23, 24],
    [30, 31, 32, 33, 34],
    [40, 41, 42, 43, 44]
])
print("Tenseur de r√©f√©rence (4x5):")
print(data)
print("Shape:", data.shape)

print("\n" + "="*60)
print("1. TORCH.GATHER - Indexation flexible par dimension")
print("="*60)

# gather: s√©lectionne des √©l√©ments selon des indices, en pr√©servant la structure
indices_gather = torch.tensor([
    [1, 3, 0, 4],  # Pour la ligne 0: colonnes 1,3,0,4
    [4, 2, 1, 0],  # Pour la ligne 1: colonnes 4,2,1,0  
    [2, 2, 2, 2],  # Pour la ligne 2: colonne 2 r√©p√©t√©e
    [0, 1, 2, 3]   # Pour la ligne 3: colonnes 0,1,2,3
])

print("Indices pour gather:")
print(indices_gather)

gathered = torch.gather(data, dim=1, index=indices_gather)
print("R√©sultat gather (dim=1):")
print(gathered)
print("Shape:", gathered.shape)

print("\nCaract√©ristiques de gather:")
print("- Pr√©serve le nombre de dimensions")
print("- Les indices peuvent √™tre diff√©rents pour chaque 'ligne' (batch)")
print("- Shape du r√©sultat = shape des indices")
print("- Tr√®s flexible pour l'indexation par batch")

Tenseur de r√©f√©rence (4x5):
tensor([[10, 11, 12, 13, 14],
        [20, 21, 22, 23, 24],
        [30, 31, 32, 33, 34],
        [40, 41, 42, 43, 44]])
Shape: torch.Size([4, 5])

1. TORCH.GATHER - Indexation flexible par dimension
Indices pour gather:
tensor([[1, 3, 0, 4],
        [4, 2, 1, 0],
        [2, 2, 2, 2],
        [0, 1, 2, 3]])
R√©sultat gather (dim=1):
tensor([[11, 13, 10, 14],
        [24, 22, 21, 20],
        [32, 32, 32, 32],
        [40, 41, 42, 43]])
Shape: torch.Size([4, 4])

Caract√©ristiques de gather:
- Pr√©serve le nombre de dimensions
- Les indices peuvent √™tre diff√©rents pour chaque 'ligne' (batch)
- Shape du r√©sultat = shape des indices
- Tr√®s flexible pour l'indexation par batch


In [17]:
print("\n" + "="*60)
print("2. TORCH.INDEX_SELECT - S√©lection de tranches enti√®res")
print("="*60)

# index_select: s√©lectionne des lignes/colonnes ENTI√àRES selon des indices
indices_rows = torch.tensor([0, 2, 3])  # S√©lectionner lignes 0, 2, 3
selected_rows = torch.index_select(data, dim=0, index=indices_rows)
print("S√©lection de lignes [0, 2, 3]:")
print(selected_rows)
print("Shape:", selected_rows.shape)

indices_cols = torch.tensor([1, 4, 2])  # S√©lectionner colonnes 1, 4, 2  
selected_cols = torch.index_select(data, dim=1, index=indices_cols)
print("\nS√©lection de colonnes [1, 4, 2]:")
print(selected_cols)
print("Shape:", selected_cols.shape)

print("\nCaract√©ristiques d'index_select:")
print("- S√©lectionne des tranches COMPL√àTES selon une dimension")
print("- Les indices s'appliquent √† TOUTES les lignes/colonnes uniform√©ment")
print("- Plus simple que gather mais moins flexible")
print("- √âquivalent au slicing avanc√©: data[[0,2,3], :] ou data[:, [1,4,2]]")


2. TORCH.INDEX_SELECT - S√©lection de tranches enti√®res
S√©lection de lignes [0, 2, 3]:
tensor([[10, 11, 12, 13, 14],
        [30, 31, 32, 33, 34],
        [40, 41, 42, 43, 44]])
Shape: torch.Size([3, 5])

S√©lection de colonnes [1, 4, 2]:
tensor([[11, 14, 12],
        [21, 24, 22],
        [31, 34, 32],
        [41, 44, 42]])
Shape: torch.Size([4, 3])

Caract√©ristiques d'index_select:
- S√©lectionne des tranches COMPL√àTES selon une dimension
- Les indices s'appliquent √† TOUTES les lignes/colonnes uniform√©ment
- Plus simple que gather mais moins flexible
- √âquivalent au slicing avanc√©: data[[0,2,3], :] ou data[:, [1,4,2]]


In [18]:
print("\n" + "="*60)
print("3. TORCH.TAKE - Indexation en tableau plat")
print("="*60)

# take: traite le tenseur comme un vecteur plat (1D)
print("Tenseur aplati conceptuellement:")
flat_view = data.flatten()
print("Position: ", list(range(len(flat_view))))
print("Valeur:   ", flat_view.tolist())

indices_take = torch.tensor([0, 5, 7, 12, 19])  # Positions absolues dans le tableau plat
taken = torch.take(data, indices_take)
print(f"\nIndices take: {indices_take.tolist()}")
print("R√©sultat take:", taken)
print("Shape:", taken.shape)

print("\nV√©rification manuelle:")
for i, idx in enumerate(indices_take):
    row = idx // 5  # Division enti√®re pour trouver la ligne
    col = idx % 5   # Modulo pour trouver la colonne
    print(f"Index {idx} -> position [{row}, {col}] -> valeur {data[row, col]}")

print("\nCaract√©ristiques de take:")
print("- Traite le tenseur comme un tableau 1D (ordre row-major)")
print("- Indices en positions absolutes")
print("- R√©sultat toujours 1D")
print("- Utile pour indexation sparse ou √©chantillonnage al√©atoire")


3. TORCH.TAKE - Indexation en tableau plat
Tenseur aplati conceptuellement:
Position:  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
Valeur:    [10, 11, 12, 13, 14, 20, 21, 22, 23, 24, 30, 31, 32, 33, 34, 40, 41, 42, 43, 44]

Indices take: [0, 5, 7, 12, 19]
R√©sultat take: tensor([10, 20, 22, 32, 44])
Shape: torch.Size([5])

V√©rification manuelle:
Index 0 -> position [0, 0] -> valeur 10
Index 5 -> position [1, 0] -> valeur 20
Index 7 -> position [1, 2] -> valeur 22
Index 12 -> position [2, 2] -> valeur 32
Index 19 -> position [3, 4] -> valeur 44

Caract√©ristiques de take:
- Traite le tenseur comme un tableau 1D (ordre row-major)
- Indices en positions absolutes
- R√©sultat toujours 1D
- Utile pour indexation sparse ou √©chantillonnage al√©atoire


In [19]:
print("\n" + "="*60)
print("4. COMPARAISON PRATIQUE - M√™me objectif, approches diff√©rentes")
print("="*60)

print("OBJECTIF: S√©lectionner la 2e et 4e colonne de toutes les lignes\n")

# M√©thode 1: avec index_select (LE PLUS SIMPLE)
cols_wanted = torch.tensor([1, 3])
result_index_select = torch.index_select(data, dim=1, index=cols_wanted)
print("1. Avec index_select:")
print(result_index_select)

# M√©thode 2: avec gather (PLUS VERBEUX mais plus flexible)
# Il faut r√©p√©ter les indices pour chaque ligne
gather_indices = cols_wanted.unsqueeze(0).expand(data.shape[0], -1)
print(f"\nIndices pour gather (r√©p√©t√©s): \n{gather_indices}")
result_gather = torch.gather(data, dim=1, index=gather_indices)
print("2. Avec gather (√©quivalent):")
print(result_gather)

# M√©thode 3: avec take (COMPLIQU√â)
# Il faut calculer les positions absolutes
take_indices = []
for row in range(data.shape[0]):
    for col in [1, 3]:  # colonnes 1 et 3
        take_indices.append(row * data.shape[1] + col)
take_indices = torch.tensor(take_indices)
result_take = torch.take(data, take_indices).reshape(data.shape[0], 2)
print("3. Avec take (plus complexe):")
print(result_take)

print(f"\nV√©rification - tous √©gaux: {torch.equal(result_index_select, result_gather) and torch.equal(result_gather, result_take)}")


4. COMPARAISON PRATIQUE - M√™me objectif, approches diff√©rentes
OBJECTIF: S√©lectionner la 2e et 4e colonne de toutes les lignes

1. Avec index_select:
tensor([[11, 13],
        [21, 23],
        [31, 33],
        [41, 43]])

Indices pour gather (r√©p√©t√©s): 
tensor([[1, 3],
        [1, 3],
        [1, 3],
        [1, 3]])
2. Avec gather (√©quivalent):
tensor([[11, 13],
        [21, 23],
        [31, 33],
        [41, 43]])
3. Avec take (plus complexe):
tensor([[11, 13],
        [21, 23],
        [31, 33],
        [41, 43]])

V√©rification - tous √©gaux: True


In [None]:
print("\n" + "="*60)
print("5. COMMON USAGE CASES")
print("="*60)

print("üìå GATHER - When to use:")
print("‚úì Different indexing per batch (e.g., last events)")
print("‚úì Select elements based on dynamic criteria")
print("‚úì Extract values using computed indices")
print("‚úì Batch-wise processing with variable indices")

print("\nüìå INDEX_SELECT - When to use:")
print("‚úì Select entire rows/columns")
print("‚úì Uniform subsampling")
print("‚úì Reorganizing dimensions")
print("‚úì When all batches share the same indices")

print("\nüìå TAKE - When to use:")
print("‚úì Random sampling of positions")
print("‚úì Sparse indexing on flattened tensors")
print("‚úì Convert 2D indices to 1D")
print("‚úì Select non-structured elements")

# Practical example for point processes
print("\n" + "="*60)
print("6. POINT PROCESS EXAMPLE")
print("="*60)

batch_size, seq_len, num_types = 3, 5, 2
events = torch.randint(0, num_types, (batch_size, seq_len))
times = torch.randn(batch_size, seq_len).abs().cumsum(dim=1)

print("Events:")
print(events)
print("Times:")
print(times)

# Case 1: GATHER - Get the time of the 3rd event of each batch
indices_3rd = torch.tensor([[2], [2], [2]])  # 3rd position for each batch
times_3rd_gather = torch.gather(times, dim=1, index=indices_3rd)
print(f"\n1. GATHER - 3rd time of each batch: {times_3rd_gather.flatten()}")

# Case 2: INDEX_SELECT - Take the 2nd and 4th times for all batches
positions = torch.tensor([1, 3])  # positions 2 and 4
times_selected = torch.index_select(times, dim=1, index=positions)
print("2. INDEX_SELECT - positions 2 and 4 for all:")
print(times_selected)

# Case 3: TAKE - Retrieve specific sparse elements
sparse_indices = torch.tensor([1, 7, 12])  # absolute positions in the flat array
times_sparse = torch.take(times, sparse_indices)
print(f"3. TAKE - elements at absolute positions {sparse_indices.tolist()}: {times_sparse}")


5. CAS D'USAGE TYPIQUES
üìå GATHER - Quand utiliser:
‚úì Indexation diff√©rente par batch (ex: derniers √©v√©nements)
‚úì S√©lection de √©l√©ments selon des crit√®res dynamiques
‚úì Extraction de valeurs selon des indices calcul√©s
‚úì Traitement par batch avec indices variables

üìå INDEX_SELECT - Quand utiliser:
‚úì S√©lection de lignes/colonnes compl√®tes
‚úì Sous-√©chantillonnage uniforme
‚úì R√©organisation de dimensions
‚úì Quand tous les batches ont les m√™mes indices

üìå TAKE - Quand utiliser:
‚úì √âchantillonnage al√©atoire de positions
‚úì Indexation sparse sur tenseurs aplatis
‚úì Conversion d'indices 2D vers 1D
‚úì S√©lection d'√©l√©ments non-structur√©e

6. EXEMPLE PROCESSUS PONCTUELS
√âv√©nements:
tensor([[0, 1, 0, 0, 1],
        [1, 0, 1, 1, 0],
        [1, 0, 0, 1, 0]])
Temps:
tensor([[1.7946, 2.5892, 3.1408, 4.6366, 4.9139],
        [1.1127, 1.3986, 3.1231, 4.1274, 5.9714],
        [0.2518, 0.3949, 1.9774, 2.6786, 3.4462]])

1. GATHER - 3e temps de chaque batch: t

In [None]:
print("\n" + "="*60)
print("7. SUMMARY TABLE")
print("="*60)

print("""
‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê
‚îÇ Characteristic  ‚îÇ    GATHER    ‚îÇ INDEX_SELECT ‚îÇ     TAKE     ‚îÇ
‚îú‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î§
‚îÇ Flexibility     ‚îÇ    High      ‚îÇ   Medium     ‚îÇ    Low       ‚îÇ
‚îÇ Complexity      ‚îÇ   Medium     ‚îÇ   Low        ‚îÇ    High      ‚îÇ
‚îÇ Result shape    ‚îÇ = shape idx  ‚îÇ Preserved    ‚îÇ  Always 1D   ‚îÇ
‚îÇ Indices/batch   ‚îÇ Varying      ‚îÇ  Identical   ‚îÇ   Absolute   ‚îÇ
‚îÇ Dimensions      ‚îÇ Preserved    ‚îÇ Dim-1 free   ‚îÇ  Flattened   ‚îÇ
‚îÇ Performance     ‚îÇ   Fast       ‚îÇ   Fast       ‚îÇ   Fast       ‚îÇ
‚îÇ Use cases       ‚îÇ Batch varied ‚îÇ Advanced     ‚îÇ Sparse/1D    ‚îÇ
‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò
""")

print("üéØ GENERAL RULE:")
print("‚Ä¢ Use INDEX_SELECT for simple uniform selections")
print("‚Ä¢ Use GATHER for batch-variable selections") 
print("‚Ä¢ Use TAKE for 1D or very specific indexing")

print("\n‚ú® PERFORMANCE TIP:")
print("‚Ä¢ For simple selections: normal slicing [indices] > index_select > gather")
print("‚Ä¢ For complex batch-wise selections: gather is optimal")
print("‚Ä¢ For sparse access: take can be useful but often there's a better approach")


7. TABLEAU R√âCAPITULATIF

‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê
‚îÇ Caract√©ristique ‚îÇ    GATHER    ‚îÇ INDEX_SELECT ‚îÇ     TAKE     ‚îÇ
‚îú‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î§
‚îÇ Flexibilit√©     ‚îÇ    Haute     ‚îÇ   Moyenne    ‚îÇ    Faible    ‚îÇ
‚îÇ Complexit√©      ‚îÇ   Moyenne    ‚îÇ   Faible     ‚îÇ    Haute     ‚îÇ
‚îÇ Shape r√©sultat  ‚îÇ = shape idx  ‚îÇ Pr√©serv√©e    ‚îÇ  Toujours 1D ‚îÇ
‚îÇ Indices/batch   ‚îÇ Diff√©rents   ‚îÇ  Identiques  ‚îÇ   Absolus    ‚îÇ
‚îÇ Dimensions      ‚îÇ Pr√©serv√©es   ‚îÇ Dim-1 libre  ‚îÇ  Aplaties    ‚îÇ
‚îÇ Performance     ‚îÇ   Rapide     ‚îÇ   Rapide     ‚îÇ   Rapide     ‚îÇ
‚îÇ Cas d'usage     ‚îÇ Batch vari√©  ‚îÇ Slicing++ 