In [1]:
import torch


In [2]:
# Two small 2x2 float32 tensors used throughout the concatenation examples.
t1 = torch.tensor([[1, 2], 
                   [3, 4]], dtype=torch.float32)

t2 = torch.tensor([[5, 6], 
                   [7, 8]], dtype=torch.float32)

print("Tensor t1:")
print(t1)
print(f"Shape: {t1.shape}")
print()

print("Tensor t2:")
print(t2)
print(f"Shape: {t2.shape}")
print()

# torch.cat joins tensors along an EXISTING dimension, so the result keeps the
# same rank as the inputs (2-D here). With dim=0 the rows of t2 are appended
# below the rows of t1.
t_cat0 = torch.cat([t1, t2], dim=0)
print("torch.cat([t1, t2], dim=0) - concatenate along rows:")
print(t_cat0)
print(f"Shape: {t_cat0.shape}")
print()

# With dim=1 the columns of t2 are appended to the right of the columns of t1.
t_cat1 = torch.cat([t1, t2], dim=1)
print("torch.cat([t1, t2], dim=1) - concatenate along columns:")
print(t_cat1)
print(f"Shape: {t_cat1.shape}")
print()

# cat takes any number of tensors; all sizes except the one along `dim` must match.
t3 = torch.tensor([[9, 10], 
                   [11, 12]], dtype=torch.float32)
t_cat_multiple = torch.cat([t1, t2, t3], dim=0)
print("Concatenate 3 tensors:")
print(t_cat_multiple)
print(f"Shape: {t_cat_multiple.shape}")

# Sanity checks: sizes add up along the concatenation dim and the rank is unchanged.
assert t_cat0.shape == (t1.shape[0] + t2.shape[0], t1.shape[1])
assert t_cat1.shape == (t1.shape[0], t1.shape[1] + t2.shape[1])
assert t_cat_multiple.shape == (t1.shape[0] + t2.shape[0] + t3.shape[0], t1.shape[1])
assert t_cat0.dim() == t1.dim()

Tensor t1:
tensor([[1., 2.],
        [3., 4.]])
Shape: torch.Size([2, 2])

Tensor t2:
tensor([[5., 6.],
        [7., 8.]])
Shape: torch.Size([2, 2])

torch.cat([t1, t2], dim=0) - concatenate along rows:
tensor([[1., 2.],
        [3., 4.],
        [5., 6.],
        [7., 8.]])
Shape: torch.Size([4, 2])

torch.cat([t1, t2], dim=1) - concatenate along columns:
tensor([[1., 2., 5., 6.],
        [3., 4., 7., 8.]])
Shape: torch.Size([2, 4])

Concatenate 3 tensors:
tensor([[ 1.,  2.],
        [ 3.,  4.],
        [ 5.,  6.],
        [ 7.,  8.],
        [ 9., 10.],
        [11., 12.]])
Shape: torch.Size([6, 2])


In [3]:
# Rebuild the two 2x2 float32 inputs so this cell does not depend on the previous one.
t1 = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
t2 = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32)

for _label, _tensor in (("t1", t1), ("t2", t2)):
    print(f"Tensor {_label}:")
    print(_tensor)
    print(f"Shape: {_tensor.shape}")
    print()

# torch.stack inserts a NEW dimension (of size = number of inputs) at position
# `dim`; for two 2-D inputs the non-negative insertion points are 0, 1 and 2.
t_stack0, t_stack1, t_stack2 = (torch.stack([t1, t2], dim=d) for d in range(3))

for _dim, _stacked in enumerate((t_stack0, t_stack1, t_stack2)):
    print(f"torch.stack([t1, t2], dim={_dim}):")
    print(_stacked)
    print(f"Shape: {_stacked.shape}")
    if _dim == 0:
        print("(New dimension 0 created!)")
    print()

# Side-by-side comparison: cat keeps the rank of its inputs, stack raises it by one.
print("=" * 60, "COMPARISON: cat vs stack", "=" * 60, sep="\n")
print("cat: Concatenates along existing dimension")
print(f"  cat shape: {torch.cat([t1, t2], dim=0).shape}")
print()
print("stack: Creates new dimension and stacks")
print(f"  stack shape: {t_stack0.shape}")
print()
print("Difference: stack adds a new dimension!")

Tensor t1:
tensor([[1., 2.],
        [3., 4.]])
Shape: torch.Size([2, 2])

Tensor t2:
tensor([[5., 6.],
        [7., 8.]])
Shape: torch.Size([2, 2])

torch.stack([t1, t2], dim=0):
tensor([[[1., 2.],
         [3., 4.]],

        [[5., 6.],
         [7., 8.]]])
Shape: torch.Size([2, 2, 2])
(New dimension 0 created!)

torch.stack([t1, t2], dim=1):
tensor([[[1., 2.],
         [5., 6.]],

        [[3., 4.],
         [7., 8.]]])
Shape: torch.Size([2, 2, 2])

torch.stack([t1, t2], dim=2):
tensor([[[1., 5.],
         [2., 6.]],

        [[3., 7.],
         [4., 8.]]])
Shape: torch.Size([2, 2, 2])

COMPARISON: cat vs stack
cat: Concatenates along existing dimension
  cat shape: torch.Size([4, 2])

stack: Creates new dimension and stacks
  stack shape: torch.Size([2, 2, 2])

Difference: stack adds a new dimension!


In [4]:
# Seed the global torch RNG so the torch.randn() values below are the same on
# every Restart & Run All (unseeded, the printed features change each run).
SEED = 0
torch.manual_seed(SEED)

# Use case 1: stack equal-length 1-D feature vectors into a 2-D batch.
# torch.stack adds a new leading dimension: three (3,) vectors -> (3, 3).
feature1 = torch.tensor([1, 2, 3], dtype=torch.float32)
feature2 = torch.tensor([4, 5, 6], dtype=torch.float32)
feature3 = torch.tensor([7, 8, 9], dtype=torch.float32)

print("Feature vectors:")
print(f"feature1: {feature1}")
print(f"feature2: {feature2}")
print(f"feature3: {feature3}")
print()

batch = torch.stack([feature1, feature2, feature3], dim=0)
print("Stacked into batch:")
print(batch)
print(f"Shape: {batch.shape}  (batch_size=3, features=3)")
print()

# Use case 2: concatenate two layer outputs along the feature axis (dim=1).
# Random (2, 3) tensors stand in for real activations here.
layer1 = torch.randn(2, 3)
layer2 = torch.randn(2, 3)

print("Layer outputs:")
print(f"layer1 shape: {layer1.shape}")
print(f"layer2 shape: {layer2.shape}")
print()

features_concat = torch.cat([layer1, layer2], dim=1)
print("Concatenated features:")
print(features_concat)
print(f"Shape: {features_concat.shape}  (same rows, doubled columns)")
print()

# Use case 3: combine several equally-shaped batches, first with cat, then with stack.
batch1 = torch.randn(2, 3)
batch2 = torch.randn(2, 3)
batch3 = torch.randn(2, 3)
print("Multiple batches:")
print(f"batch1 shape: {batch1.shape}")
print(f"batch2 shape: {batch2.shape}")
print(f"batch3 shape: {batch3.shape}")
print()

# cat along dim 0 merges the batches into one larger batch (no new dimension).
large_batch = torch.cat([batch1, batch2, batch3], dim=0)
print("Concatenated batches:")
print(f"Shape: {large_batch.shape}  (6 samples, 3 features)")
print()

# stack along dim 0 instead keeps the batches separate under a new leading
# dimension, e.g. to treat them as a sequence of batches.
sequence = torch.stack([batch1, batch2, batch3], dim=0)
print("Stacked as sequence:")
print(f"Shape: {sequence.shape}  (sequence_length=3, batch_size=2, features=3)")

# Sanity checks on the shapes claimed in the printed labels above.
assert batch.shape == (3, 3)
assert features_concat.shape == (2, 6)
assert large_batch.shape == (6, 3)
assert sequence.shape == (3, 2, 3)

Feature vectors:
feature1: tensor([1., 2., 3.])
feature2: tensor([4., 5., 6.])
feature3: tensor([7., 8., 9.])

Stacked into batch:
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
Shape: torch.Size([3, 3])  (batch_size=3, features=3)

Layer outputs:
layer1 shape: torch.Size([2, 3])
layer2 shape: torch.Size([2, 3])

Concatenated features:
tensor([[ 1.8862, -0.6076, -0.4660,  1.5749, -0.4939, -0.4293],
        [ 1.6981,  0.0429,  0.6983, -1.6400, -1.8981, -0.4417]])
Shape: torch.Size([2, 6])  (same rows, doubled columns)

Multiple batches:
batch1 shape: torch.Size([2, 3])
batch2 shape: torch.Size([2, 3])
batch3 shape: torch.Size([2, 3])

Concatenated batches:
Shape: torch.Size([6, 3])  (6 samples, 3 features)

Stacked as sequence:
Shape: torch.Size([3, 2, 3])  (sequence_length=3, batch_size=2, features=3)
