# Manipulating Tensors in PyTorch

## Reshaping Tensors with the `.view()` and `.reshape()` Methods

In [2]:
import torch

In [4]:
# Two integer matrices: `a` holds 8 elements (2 x 4), `b` holds 6 elements (3 x 2).
a = torch.tensor([
    [2, 3, 5, 6],
    [5, 1, 2, 3],
])
b = torch.tensor([
    [4, 2],
    [1, 2],
    [6, 7],
])

# Re-lay `a` as (4, 2) and `b` as (2, 3) so the inner dimensions agree,
# then take the ordinary 2-D matrix product -> (4, 3).
a_4x2 = a.view(4, 2)
b_2x3 = b.reshape(2, 3)
c = torch.mm(a_4x2, b_2x3)
print(c)

tensor([[14, 22, 23],
        [32, 46, 47],
        [22, 16, 12],
        [14, 22, 23]])


In [7]:
a = torch.rand(3, 3)
b = torch.arange(1, 13)  # the 12 integers 1..12

# View the 12 integers as a (3, 4) matrix and cast to float32 so the
# dtype matches `a` before multiplying: (3, 3) x (3, 4) -> (3, 4).
b_as_float = b.view(3, 4).float()
c = torch.mm(a, b_as_float)

print(c)

tensor([[ 7.7157,  9.8867, 12.0576, 14.2286],
        [ 6.2026,  7.5897,  8.9768, 10.3639],
        [10.0020, 11.7572, 13.5125, 15.2677]])


## Slicing and Indexing

**Indexing tensors**

In [9]:
x = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [9, 1, 2],
])

# Pick the three main-diagonal entries with (row, column) indexing;
# each result is a 0-dimensional tensor.
a, b, c = x[0, 0], x[1, 1], x[2, 2]

print(a, b, c)
print(a + b + c)

tensor(1) tensor(5) tensor(2)
tensor(8)


**Advanced indexing**

In [43]:
x = torch.tensor([
    [5, 1, 2, 4],
    [6, 7, 12, 1],
    [8, 1, 9, 6],
    [1, 9, 5, 7],
])

# Row 0 of `indices` lists the row positions, row 1 the column positions.
indices = torch.tensor([
    [0, 1, 2, 3],
    [0, 0, 2, 2],
])

# Indexing with two index tensors pairs them element-wise:
# the result is x[0, 0], x[1, 0], x[2, 2], x[3, 2].
row_idx, col_idx = indices
print(x[row_idx, col_idx])

tensor([5, 6, 9, 5])


In [44]:
x = torch.tensor([
    [5, 6, 7],
    [7, 1, 2],
    [6, 1, 3],
])

# Column positions to keep (first and last).
indices = torch.tensor([0, 2])

# Gather those columns from every row; equivalent to x[:, indices].
print(torch.index_select(x, 1, indices))

tensor([[5, 7],
        [7, 2],
        [6, 3]])


## Slicing Tensors

In [46]:
x = torch.tensor([
    [4, 5, 6],
    [7, 1, 2],
    [5, 1, 5],
])

# first row (a trailing `:` for the column axis is implied)
print(x[0])

# middle column
print(x[:, 1])

# last column, addressed from the end
print(x[:, -1])


tensor([4, 5, 6])
tensor([5, 1, 1])
tensor([6, 2, 5])


In [61]:
# Random (3, 4) float32 tensor whose values are whole numbers in [4, 12).
x = torch.randint(low=4, high=12, size=(3, 4), dtype=torch.float32)

# Rows 1 onward, restricted to the two middle columns.
middle_cols = [1, 2]
print(x[1:, middle_cols])

# First two rows, restricted to the two outer columns.
outer_cols = [0, 3]
print(x[:2, outer_cols])

tensor([[ 8.,  6.],
        [11.,  5.]])
tensor([[ 9., 11.],
        [ 6.,  5.]])


In [62]:
# Display the whole tensor so the two slices printed by the previous cell can be checked against it.
x

tensor([[ 9., 11.,  5., 11.],
        [ 6.,  8.,  6.,  5.],
        [10., 11.,  5., 11.]])

## Boolean Indexing

In [73]:
# Random (3, 3) integer tensor with values in [5, 15).
x = torch.randint(low=5, high=15, size=(3, 3))

# A comparison yields a boolean mask with the same shape as `x`; indexing
# with the mask returns the selected elements as a flat 1-D tensor.
below_ten = x < 10
five_to_eight = (x >= 5) & (x < 9)          # element-wise AND
above_ten_or_below_eight = (x > 10) | (x < 8)  # element-wise OR

print(x[below_ten])
print(x[five_to_eight])
print(x[above_ten_or_below_eight])

tensor([7, 6, 9, 5, 6, 6])
tensor([7, 6, 5, 6, 6])
tensor([ 7,  6, 14,  5,  6,  6])


In [74]:
# Display the full tensor that the boolean masks in the previous cell were applied to.
x

tensor([[ 7, 10,  6],
        [14, 10,  9],
        [ 5,  6,  6]])

## Matrix Transpose

In [76]:
a = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [6, 7, 2],
])
b = torch.tensor([
    [2, 3, 4],
    [5, 6, 7],
])

# (3, 3) x (2, 3) is not a valid product; swapping the two axes of `b`
# gives (3, 2), which lines up with the columns of `a`.
b = b.transpose(0, 1)

c = a.mm(b)

print(c)

tensor([[20, 38],
        [47, 92],
        [41, 86]])


In [77]:
# `b` was reassigned to its transpose in the previous cell; display it to see the (3, 2) layout.
b

tensor([[2, 5],
        [3, 6],
        [4, 7]])

In [81]:
a = torch.randint(low=5, high=10, size=(4, 3, 2))
b = torch.randint(low=8, high=12, size=(3, 2, 4))

# Swap the first and last axes of `b`: (3, 2, 4) -> (4, 2, 3), so the
# leading (batch) axis matches `a` and the inner dimensions agree.
b = b.transpose(0, 2)

# Batched product: (4, 3, 2) @ (4, 2, 3) -> (4, 3, 3).
c = a @ b
c

tensor([[[103,  96, 132],
         [110, 104, 143],
         [117, 112, 154]],

        [[171, 144, 189],
         [132, 112, 148],
         [151, 128, 169]],

        [[153, 149, 167],
         [152, 152, 168],
         [143, 141, 157]],

        [[126, 149, 144],
         [ 99, 116, 111],
         [135, 158, 151]]])

In [84]:
a = torch.randint(low=5, high=10, size=(4, 3, 2))
b = torch.randint(low=8, high=12, size=(3, 2, 4))

# `b` itself stays (3, 2, 4); only the operand handed to bmm is
# transposed to (4, 2, 3). The product has shape (4, 3, 3).
b_batched = b.transpose(0, 2)
c = torch.bmm(a, b_batched)
print(c)

tensor([[[162, 153, 171],
         [ 98,  93, 103],
         [122, 117, 127]],

        [[151, 142, 158],
         [162, 154, 171],
         [122, 114, 127]],

        [[135, 127, 140],
         [131, 139, 140],
         [144, 138, 150]],

        [[138, 135, 129],
         [106,  99,  93],
         [128, 117, 109]]])


In [85]:
# Confirm the shape of the batched product computed in the previous cell.
c.shape

torch.Size([4, 3, 3])

In [86]:
a = torch.randint(5, 10, (4, 3))
b = torch.randint(8, 12, (3, 4))

# torch.bmm only accepts 3-D (batched) operands, so 2-D matrices are rejected.
# The error is the point of this cell: catch it and print it, so that
# "Restart Kernel -> Run All" keeps going past this demonstration.
try:
    torch.bmm(a, b)
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")

RuntimeError: batch1 must be a 3D tensor

## Concatenation and Splitting Tensors

**Concatenation**

In [91]:
a = torch.randint(low=3, high=8, size=(2, 2))
b = torch.randint(low=4, high=9, size=(2, 2))
c = torch.randint(low=6, high=10, size=(3, 4))

# Join along dim 0: the rows of `b` go underneath the rows of `a` -> (4, 2).
c1 = torch.cat([a, b], dim=0)

# (3, 4) @ (4, 2) -> (3, 2)
print(c @ c1)

tensor([[191, 179],
        [170, 160],
        [156, 147]])


In [94]:
# Re-lay the 12 elements of `c` as (6, 2).
c2 = c.reshape(6, 2)

# Join along dim 1: the columns of `b` go to the right of the columns of `a` -> (2, 4).
c3 = torch.cat([a, b], dim=1)

# (6, 2) @ (2, 4) -> (6, 4)
print(c2 @ c3)

tensor([[ 64, 103, 119,  71],
        [ 68, 110, 127,  76],
        [ 56,  91, 105,  63],
        [ 60,  99, 114,  69],
        [ 52,  85,  98,  59],
        [ 56,  90, 104,  62]])


**Splitting tensors**

In [108]:
x = torch.randint(low=3, high=10, size=(4, 3))

# Pieces of size 2 along the rows: 4 rows -> two (2, 3) blocks.
a1, a2 = x.split(2, dim=0)
# Pieces of size 2 along the columns: 3 columns -> a (4, 2) block and a (4, 1) remainder.
b1, b2 = x.split(2, dim=1)

print(a1.shape)
print(b1.shape)

# (3, 2) @ (2, 4) -> (3, 4)
print(a1.T @ b1.T)

torch.Size([2, 3])
torch.Size([4, 2])
tensor([[ 45,  36,  51,  43],
        [ 81,  72, 117,  87],
        [ 48,  40,  60,  48]])


**Stacking tensors**

In [109]:
x = torch.tensor([5, 6, 4])
y = torch.tensor([1, 4, 5])
z = torch.tensor([3, 2, 1])

# Stacking adds a new leading axis: three length-3 vectors become the rows of a (3, 3) matrix.
torch.stack([x, y, z], dim=0)

tensor([[5, 6, 4],
        [1, 4, 5],
        [3, 2, 1]])

In [112]:
a = torch.tensor([1, 2])
b = torch.tensor([1, 2, 3])

# torch.stack needs every input to have the same size; lengths 2 and 3 do not match.
# The error is the point of this cell: catch it and print it, so that
# "Restart Kernel -> Run All" keeps going past this demonstration.
try:
    torch.stack((a, b))
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")

RuntimeError: stack expects each tensor to be equal size, but got [2] at entry 0 and [3] at entry 1

**Chunking tensors**

In [113]:
x = torch.randint(low=3, high=10, size=(4, 3))

# Three chunks along the column axis: each one is a single (4, 1) column.
z1, z2, z3 = x.chunk(3, dim=1)

print(z1, z2, z3, sep="\n")

tensor([[8],
        [9],
        [8],
        [3]])
tensor([[9],
        [9],
        [8],
        [4]])
tensor([[6],
        [6],
        [5],
        [8]])


In [114]:
# Display the full tensor to compare with the column chunks printed by the previous cell.
x

tensor([[8, 9, 6],
        [9, 9, 6],
        [8, 8, 5],
        [3, 4, 8]])

In [115]:
# Four chunks along the row axis of the same `x`: each one is a single (1, 3) row.
z1, z2, z3, z4 = x.chunk(4, dim=0)

print(z1, z2, z3, z4, sep="\n")

tensor([[8, 9, 6]])
tensor([[9, 9, 6]])
tensor([[8, 8, 5]])
tensor([[3, 4, 8]])
