In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch

# pandas display setting; 0 lets pandas decide how many columns to show
# (auto-detected when running in a terminal — see pandas `display.max_columns`).
pd.set_option("display.max_columns", 0)

## Stacking and Concatenation

In [2]:
# Three 2x2 example tensors holding 0..11: slice a (3, 2, 2) block along dim 0.
# `unbind(dim=0)` removes exactly that dimension. The earlier
# `split(..., 1)` + bare `squeeze()` would also drop any OTHER size-1 dim.
x_2d = list(torch.arange(0, 12).reshape(3, 2, 2).unbind(dim=0))
x_2d

[tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[ 8,  9],
         [10, 11]])]

In [3]:
# Each piece is a plain 2-D tensor.
x_2d[0].size()

torch.Size([2, 2])

In [8]:
# cat joins along an EXISTING axis (dim defaults to 0): the rows of all
# three 2x2 inputs are laid end to end.
cat0 = torch.cat(x_2d)
print(cat0.size())
cat0

torch.Size([6, 2])


tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11]])

In [9]:
# Joining along dim=1 places the three inputs side by side, column-wise.
cat1 = torch.cat(x_2d, dim=1)
print(cat1.size())
cat1

torch.Size([2, 6])


tensor([[ 0,  1,  4,  5,  8,  9],
        [ 2,  3,  6,  7, 10, 11]])

In [10]:
# CANNOT cat along a dimension that doesn't exist: the inputs are 2-D, so
# torch.cat raises IndexError for dim=2. The error is the point of this cell,
# so catch it and show the message. This keeps Restart & Run All working.
try:
    cat2 = torch.cat(x_2d, dim=2)
    print(cat2.shape)
except IndexError as err:
    print(f"{type(err).__name__}: {err}")

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

In [4]:
# stack inserts a NEW axis (dim defaults to 0), so dim0[k] is exactly x_2d[k].
dim0 = torch.stack(x_2d)
print(dim0.size())
dim0

torch.Size([3, 2, 2])


tensor([[[ 0,  1],
         [ 2,  3]],

        [[ 4,  5],
         [ 6,  7]],

        [[ 8,  9],
         [10, 11]]])

In [5]:
# New axis in the middle: dim1[i, k] is row i of input k, so each outer
# slice collects the same row from every input.
dim1 = torch.stack(x_2d, dim=1)
print(dim1.size())
dim1

torch.Size([2, 3, 2])


tensor([[[ 0,  1],
         [ 4,  5],
         [ 8,  9]],

        [[ 2,  3],
         [ 6,  7],
         [10, 11]]])

In [6]:
# New axis last: dim2[i, j, k] is x_2d[k][i, j], i.e. matching elements
# of the three inputs end up next to each other.
dim2 = torch.stack(x_2d, dim=2)
print(dim2.size())
dim2

torch.Size([2, 2, 3])


tensor([[[ 0,  4,  8],
         [ 1,  5,  9]],

        [[ 2,  6, 10],
         [ 3,  7, 11]]])

## Reshaping

In [15]:
# Despite the name this is 3-D: the three 2x2 inputs stacked on a new
# leading axis (same result as dim0 above).
tensor_2d = torch.stack(x_2d, dim=0)
print(tensor_2d.size())
tensor_2d

torch.Size([3, 2, 2])


tensor([[[ 0,  1],
         [ 2,  3]],

        [[ 4,  5],
         [ 6,  7]],

        [[ 8,  9],
         [10, 11]]])

In [16]:
# Same shape as torch.stack(x_2d, dim=1) above, but different contents: view
# moves no data, it just re-reads the row-major sequence 0..11 with a new shape.
# So each outer slice holds 3 consecutive rows, not one row from each input.
tensor_2d.view(2, 3, 2)

tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]]])

In [17]:
# Same 12 values in the same order; the last axis now groups 3 consecutive elements.
tensor_2d.view(2, 2, 3)

tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],

        [[ 6,  7,  8],
         [ 9, 10, 11]]])

In [19]:
# A lone -1 flattens everything: view infers that single size from the
# total element count.
tensor_2dm1 = tensor_2d.view(-1)
print(tensor_2dm1.size())
tensor_2dm1

torch.Size([12])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

In [23]:
# Explicit size instead of -1; identical result because tensor_2d holds 12 elements.
# Note the trailing comma: `(12)` is just the int 12, while `(12,)` is a 1-tuple shape.
tensor_2d.view((12,))

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

In [22]:
# Fix 3 rows and let view infer the other size: 12 elements / 3 rows = 4 columns.
tensor_2d.view(3, -1)

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

In [24]:
# Spelling out the inferred size gives exactly the same tensor as view(3, -1).
tensor_2d.view(3, 4)

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])