# View reshape

In [1]:
import torch

In [4]:
# 均匀随机初始化
a = torch.rand(4, 1, 28, 28)

In [3]:
a.shape

torch.Size([4, 1, 28, 28])

In [4]:
a.view(4, 28 * 28)

tensor([[0.6433, 0.8808, 0.0511,  ..., 0.3921, 0.1033, 0.3169],
        [0.6885, 0.0379, 0.8065,  ..., 0.1191, 0.0319, 0.7153],
        [0.2190, 0.6243, 0.3316,  ..., 0.3841, 0.5056, 0.0302],
        [0.5766, 0.4673, 0.6388,  ..., 0.2591, 0.1887, 0.2810]])

In [6]:
a.view(4, 28 * 28).shape

torch.Size([4, 784])

In [7]:
a.view(4*28, 28).shape

torch.Size([112, 28])

In [8]:
a.view(4*1, 28, 28).shape

torch.Size([4, 28, 28])

In [9]:
b = a.view(4, 784)

In [10]:
#维度顺序改变使得数据被破坏
b.view(4, 28, 28, 1)

tensor([[[[6.4335e-01],
          [8.8077e-01],
          [5.1108e-02],
          ...,
          [4.5895e-01],
          [5.6652e-01],
          [3.7206e-02]],

         [[1.5342e-01],
          [4.0835e-01],
          [5.9318e-01],
          ...,
          [2.0960e-01],
          [4.1806e-01],
          [1.2811e-01]],

         [[7.1674e-01],
          [8.2527e-01],
          [1.6097e-01],
          ...,
          [7.6300e-01],
          [9.4519e-01],
          [8.6614e-01]],

         ...,

         [[1.6746e-01],
          [3.9573e-02],
          [9.3442e-01],
          ...,
          [5.5010e-01],
          [4.1531e-02],
          [1.7911e-01]],

         [[2.4814e-01],
          [4.3082e-01],
          [8.8916e-01],
          ...,
          [4.5898e-01],
          [1.3169e-02],
          [8.8321e-01]],

         [[2.2563e-01],
          [7.0149e-01],
          [8.6961e-01],
          ...,
          [3.9212e-01],
          [1.0327e-01],
          [3.1692e-01]]],


        [[[6.8852

# Flexible but prone to corrupt

In [5]:
# 维度不一致，出错
a.view(4, 783)

RuntimeError: shape '[4, 783]' is invalid for input of size 3136

# Squeeze vs unsqueeze

## unsqueeze
维度增加

In [6]:
a.shape

torch.Size([4, 1, 28, 28])

In [7]:
a.unsqueeze(0).shape

torch.Size([1, 4, 1, 28, 28])

In [8]:
a.unsqueeze(-1).shape

torch.Size([4, 1, 28, 28, 1])

In [9]:
a.unsqueeze(4).shape

torch.Size([4, 1, 28, 28, 1])

In [10]:
a.unsqueeze(-4).shape

torch.Size([4, 1, 1, 28, 28])

In [11]:
a.unsqueeze(-5).shape

torch.Size([1, 4, 1, 28, 28])

In [12]:
a.unsqueeze(5).shape

IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)

In [13]:
a = torch.tensor([1.2, 2.3])

In [14]:
a.unsqueeze(-1)

tensor([[1.2000],
        [2.3000]])

In [15]:
a.unsqueeze(0)

tensor([[1.2000, 2.3000]])

## For example

In [16]:
b = torch.rand(32)

In [17]:
f = torch.rand(4, 32, 14, 14)

In [18]:
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)

In [19]:
b.shape

torch.Size([1, 32, 1, 1])

## squeeze
维度删减

In [20]:
b.shape

torch.Size([1, 32, 1, 1])

In [21]:
# 删除所有维度为1的维度
b.squeeze().shape

torch.Size([32])

In [22]:
b.squeeze(0).shape

torch.Size([32, 1, 1])

In [23]:
b.squeeze(-1).shape

torch.Size([1, 32, 1])

In [24]:
# 不能挤压的维度则返回原数据
b.squeeze(1).shape

torch.Size([1, 32, 1, 1])

In [25]:
b.squeeze(-4).shape

torch.Size([32, 1, 1])

# Expand / repeat
维度扩展

## expend
用时才扩展

In [26]:
a = torch.rand(4, 32, 14, 14)

In [27]:
b.shape

torch.Size([1, 32, 1, 1])

In [28]:
b.expand(4, 32, 14, 14).shape

torch.Size([4, 32, 14, 14])

In [29]:
# 指定维度上不想变化则为-1
b.expand(-1, 32, -1, -1).shape

torch.Size([1, 32, 1, 1])

In [30]:
b.expand(-1, 32, -1, -4).shape

torch.Size([1, 32, 1, -4])

## repeat
主动复制，改变了原有数据

In [31]:
b.shape

torch.Size([1, 32, 1, 1])

In [32]:
# 每个维度上需要拷贝的次数
b.repeat(4, 32, 1, 1).shape

torch.Size([4, 1024, 1, 1])

In [33]:
b.repeat(4, 1, 1, 1).shape

torch.Size([4, 32, 1, 1])

In [34]:
b.repeat(4, 1, 32, 32).shape

torch.Size([4, 32, 32, 32])

# .t
转置

In [35]:
b.t()

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D

In [37]:
a = torch.randn(3, 4)

In [38]:
a.t()

tensor([[ 0.3701,  0.9084,  3.3079],
        [ 0.5992,  2.5736, -0.7211],
        [ 0.5155,  0.8203, -0.4993],
        [ 0.9709,  0.6069, -1.1415]])

# Transpose
维度交换

In [39]:
a.shape

torch.Size([3, 4])

In [42]:
a.transpose(0, 1)

tensor([[ 0.3701,  0.9084,  3.3079],
        [ 0.5992,  2.5736, -0.7211],
        [ 0.5155,  0.8203, -0.4993],
        [ 0.9709,  0.6069, -1.1415]])

In [None]:
# contiguous()使得数据连续

In [None]:
torch.all(torch.eq(a, a1))

# permute
指定维度进行交换

In [44]:
a = torch.rand(4, 3, 28, 28)

In [45]:
a.transpose(1, 3).shape

torch.Size([4, 28, 28, 3])

In [46]:
b = torch.rand(4, 3, 28, 32)

In [47]:
b.transpose(1, 3).shape

torch.Size([4, 32, 28, 3])

In [48]:
b.transpose(1, 3).transpose(1, 2).shape

torch.Size([4, 28, 32, 3])

In [49]:
b.permute(0, 2, 3, 1).shape

torch.Size([4, 28, 32, 3])