In [1]:
import torch

In [4]:
# Create a 2D tensor (2 rows x 3 columns)
t = torch.tensor([[1, 2, 3],
                  [4, 5, 6]], dtype=torch.float32)

print("Original tensor:")
print(t)
print(f"Shape: {t.shape}")
print()

# Flatten to 1D: elements are read row by row, so (2, 3) -> (6,)
t_flat = t.flatten()
print("Flattened tensor:")
print(t_flat)
print(f"Shape: {t_flat.shape}")
print()

# Flatten from an explicit start dimension. start_dim=0 is the default, so this
# is the same as t.flatten() above; a larger start_dim keeps the leading
# dimensions intact (see the 3D example in the next cell).
t_flat_start = torch.flatten(t, start_dim=0)
assert torch.equal(t_flat_start, t_flat)  # confirm the equivalence claimed above
print("Flattened from start_dim=0:")
print(t_flat_start)
print(f"Shape: {t_flat_start.shape}")
print()



Original tensor:
tensor([[1., 2., 3.],
        [4., 5., 6.]])
Shape: torch.Size([2, 3])

Flattened tensor:
tensor([1., 2., 3., 4., 5., 6.])
Shape: torch.Size([6])

Flattened from start_dim=0:
tensor([1., 2., 3., 4., 5., 6.])
Shape: torch.Size([6])



In [5]:
# 3D example: two 2x2 matrices stacked along dim 0 -> shape (2, 2, 2)
t_3d = torch.tensor(
    [
        [[1, 2], [3, 4]],
        [[5, 6], [7, 8]],
    ],
    dtype=torch.float32,
)
print("3D tensor:", t_3d, f"Shape: {t_3d.shape}", sep="\n")
print()

# start_dim=1 keeps dim 0 and merges every later dim into one: (2, 2, 2) -> (2, 4)
t_3d_flat = t_3d.flatten(start_dim=1)
print("Flattened 3D tensor:", t_3d_flat, f"Shape: {t_3d_flat.shape}", sep="\n")

3D tensor:
tensor([[[1., 2.],
         [3., 4.]],

        [[5., 6.],
         [7., 8.]]])
Shape: torch.Size([2, 2, 2])

Flattened 3D tensor:
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
Shape: torch.Size([2, 4])


### Reshape the tensor

In [18]:

# Create a 2x4 tensor (8 elements in total)
t = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8]], dtype=torch.float32)

print("Original tensor:")
print(t)
print(f"Shape: {t.shape} (2*4 = 8 elements)")
print()

# reshape() rearranges the same 8 elements, so the target shape's product
# must also be 8: (2, 4) -> (4, 2).
t_reshaped = t.reshape(4, 2)
print("Reshaped tensor (4x2):")
print(t_reshaped)
print(f"Shape: {t_reshaped.shape}")

Original tensor:
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
Shape: torch.Size([2, 4])(2*4 = 8 elements) 

tensor([[1., 2.],
        [3., 4.],
        [5., 6.],
        [7., 8.]])
Reshaped tensor (4x1): torch.Size([4, 2])


In [21]:
# A single -1 asks reshape() to infer that dimension from the element count,
# so the whole tensor collapses to 1D (here `t` is the 2x4 tensor defined in
# the previous cell, giving length 8).
auto_reshaped = t.reshape((-1,))
print(auto_reshaped)
print("Reshaped tensor:", auto_reshaped.shape)

tensor([1., 2., 3., 4., 5., 6., 7., 8.])
Reshaped tensor: torch.Size([8])


In [23]:
# Reshape 3D to 2D. arange yields the integers 0..23, laid out here as
# 2 blocks of 3 rows x 4 columns.
t_3d = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
print("3D tensor (2×3×4):", t_3d, f"Shape: {t_3d.shape}", sep="\n")
print()

# Merging the first two axes (2*3 = 6) leaves each length-4 row intact.
t_2d_from_3d = t_3d.reshape(2 * 3, 4)
print("Reshaped to 2D (6×4):", t_2d_from_3d, f"Shape: {t_2d_from_3d.shape}", sep="\n")

3D tensor (2×3×4):
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
Shape: torch.Size([2, 3, 4])

Reshaped to 2D (6×4):
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]])
Shape: torch.Size([6, 4])


In [24]:
# Create a tensor
t = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8]], dtype=torch.float32)

print("Original tensor:")
print(t)
print(f"Shape: {t.shape}")
print()


def print_section(title: str, width: int = 60) -> None:
    """Print a banner: a rule of `width` '=' characters, `title`, then another rule."""
    rule = "=" * width
    print(rule)
    print(title)
    print(rule)


print_section("FLATTEN:")
t_flat = t.flatten()
print(f"t.flatten(): {t_flat.shape}")
print()

print_section("RESHAPE:")
t_reshaped = t.reshape(4, 2)
print(f"t.reshape(4, 2): {t_reshaped.shape}")
t_reshaped2 = t.reshape(8)
print(f"t.reshape(8): {t_reshaped2.shape}")
print()

print_section("VIEW:")
t_view = t.view(4, 2)
print(f"t.view(4, 2): {t_view.shape}")
# Verify the claim instead of only stating it: the view starts at the same
# memory address as the original tensor.
assert t_view.data_ptr() == t.data_ptr()
print("(Shares memory with original)")
print()

print_section("SQUEEZE:")
t_with_ones = torch.tensor([[[1, 2, 3]]], dtype=torch.float32)
print(f"Original: {t_with_ones.shape}")
t_squeezed = t_with_ones.squeeze()  # no dim given: drops every size-1 dimension
print(f"t.squeeze(): {t_squeezed.shape}")
print()

print_section("UNSQUEEZE:")
t_1d = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
print(f"Original: {t_1d.shape}")
t_unsqueezed = t_1d.unsqueeze(0)  # insert a new size-1 axis at position 0
print(f"t.unsqueeze(0): {t_unsqueezed.shape}")
print()

print_section("SUMMARY:")
print("flatten(): Make 1D")
print("reshape(): Change shape (may copy)")
# reshape() already returns a view whenever one is possible, so view() is not
# "faster"; its distinguishing traits are guaranteed sharing and the stride requirement.
print("view(): Change shape (always shares memory; needs compatible strides)")
print("squeeze(): Remove size-1 dimensions")
print("unsqueeze(): Add size-1 dimension")

Original tensor:
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
Shape: torch.Size([2, 4])

FLATTEN:
t.flatten(): torch.Size([8])

RESHAPE:
t.reshape(4, 2): torch.Size([4, 2])
t.reshape(8): torch.Size([8])

VIEW:
t.view(4, 2): torch.Size([4, 2])
(Shares memory with original)

SQUEEZE:
Original: torch.Size([1, 1, 3])
t.squeeze(): torch.Size([3])

UNSQUEEZE:
Original: torch.Size([4])
t.unsqueeze(0): torch.Size([1, 4])

SUMMARY:
flatten(): Make 1D
reshape(): Change shape (may copy)
view(): Change shape (shares memory, faster)
squeeze(): Remove size-1 dimensions
unsqueeze(): Add size-1 dimension
