* https://blog.ezyang.com/2019/05/pytorch-internals/
* https://web.mit.edu/~ezyang/Public/pytorch-internals.pdf


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt

print(torch.__version__)
# !nvidia-smi

2.9.1


In [3]:
def info(t):
	print("\n--------------------------------")
	print(f'{t} \nshape={t.shape} dtype={t.dtype} device={t.device} ndim={t.ndim}\n')

	for i in range(t.ndim):
		print(f't.sum(dim={i}): {t.sum(dim=i)}') # Sum of the values
		print(f'mean dim={i}: {t.mean(dim=i)}') # Find the mean value
		print(f"t.max(dim={i}): {t.max(dim=i)}") # Find the maximum value
		print(f"t.min(dim={i}): {t.min(dim=i)}") # Find the minimum value
		print(f"t.argmax(dim={i}): {t.argmax(dim=i)}") # Find the index of the maximum value
		print(f"t.argmin(dim={i}): {t.argmin(dim=i)}") # Find the index of the minimum value
		print(f"t.std(dim={i}): {t.std(dim=i)}") # Standard deviation
		print(f"t.var(dim=-1): {t.var(dim=-1)}") # Variance
		print("")

In [4]:

info(torch.tensor(
	[
		[1.0,2.0,3.0], 
		[4.0,5.0,6.0]
	]
))


--------------------------------
tensor([[1., 2., 3.],
        [4., 5., 6.]]) 
shape=torch.Size([2, 3]) dtype=torch.float32 device=cpu ndim=2

t.sum(dim=0): tensor([5., 7., 9.])
mean dim=0: tensor([2.5000, 3.5000, 4.5000])
t.max(dim=0): torch.return_types.max(
values=tensor([4., 5., 6.]),
indices=tensor([1, 1, 1]))
t.min(dim=0): torch.return_types.min(
values=tensor([1., 2., 3.]),
indices=tensor([0, 0, 0]))
t.argmax(dim=0): tensor([1, 1, 1])
t.argmin(dim=0): tensor([0, 0, 0])
t.std(dim=0): tensor([2.1213, 2.1213, 2.1213])
t.var(dim=-1): tensor([1., 1.])

t.sum(dim=1): tensor([ 6., 15.])
mean dim=1: tensor([2., 5.])
t.max(dim=1): torch.return_types.max(
values=tensor([3., 6.]),
indices=tensor([2, 2]))
t.min(dim=1): torch.return_types.min(
values=tensor([1., 4.]),
indices=tensor([0, 0]))
t.argmax(dim=1): tensor([2, 2])
t.argmin(dim=1): tensor([0, 0])
t.std(dim=1): tensor([1., 1.])
t.var(dim=-1): tensor([1., 1.])



In [5]:
info(torch.tensor(
	[
		[
			[1.0,2.0,3.0,4.0], 
			[5.0,6.0,7.0,8.0],
			[9.0,10.0,11.0,12.0], 
		],
		[
			[13.0,14.0,15.0,16.0],
			[17.0,18.0,19.0,20.0],
			[21.0,22.0,23.0,24.0],
		]
	]
))


--------------------------------
tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.]],

        [[13., 14., 15., 16.],
         [17., 18., 19., 20.],
         [21., 22., 23., 24.]]]) 
shape=torch.Size([2, 3, 4]) dtype=torch.float32 device=cpu ndim=3

t.sum(dim=0): tensor([[14., 16., 18., 20.],
        [22., 24., 26., 28.],
        [30., 32., 34., 36.]])
mean dim=0: tensor([[ 7.,  8.,  9., 10.],
        [11., 12., 13., 14.],
        [15., 16., 17., 18.]])
t.max(dim=0): torch.return_types.max(
values=tensor([[13., 14., 15., 16.],
        [17., 18., 19., 20.],
        [21., 22., 23., 24.]]),
indices=tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]]))
t.min(dim=0): torch.return_types.min(
values=tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]]),
indices=tensor([[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]]))
t.argmax(dim=0): tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1

In [6]:
one_to_ten = torch.arange(1, 11)
print(one_to_ten)

t = torch.tensor([1,2,3])
print(t * t)
print(t.matmul(t))
print(t @ t)


tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
tensor([1, 4, 9])
tensor(14)
tensor(14)


In [7]:
# Stack : concatenate tensors along a new dimension	
# Squeeze : remove all dimensions of size 1
# Unsqueeze : add a new dimension of size 1
# Reshape : change the shape of the tensor
# Permute : permute the dimensions of the tensor
# View : reshape tensor to different shape
# Transpose : swap two dimensions of the tensor

t = torch.tensor([
	[1,2,3],
	[4,5,6]
])
print(f"t.unsqueeze(0): {t.unsqueeze(0)}")
print(f"t.unsqueeze(1): {t.unsqueeze(1)}")
print(f"t.unsqueeze(-1): {t.unsqueeze(-1)}")
print("--------------------------------")
print(f"t.view(6): {t.view(6)}")
print("--------------------------------")
print(f"t.permute(1, 0): {t.permute(1, 0)}")
print(f"t.transpose(0, 1): {t.transpose(0, 1)}")
print("--------------------------------")
print(f"torch.stack((t, t)): {torch.stack((t, t))}")
print("--------------------------------")
print(f"t.squeeze(): {t.squeeze()}")
print(f"t.unsqueeze(0).squeeze(): {t.unsqueeze(0).squeeze()}")
print("--------------------------------")
print(f"t.reshape(2, 3): {t.reshape(2, 3)}")
print(f"t.reshape(6): {t.reshape(6)}")
print("--------------------------------")
print(f"t.permute(1, 0): {t.permute(1, 0)}")
print(f"t.transpose(0, 1): {t.transpose(0, 1)}")
print("--------------------------------")
print(f"t.view(2, 3): {t.view(2, 3)}")
print(f"t.view(6): {t.view(6)}")


t.unsqueeze(0): tensor([[[1, 2, 3],
         [4, 5, 6]]])
t.unsqueeze(1): tensor([[[1, 2, 3]],

        [[4, 5, 6]]])
t.unsqueeze(-1): tensor([[[1],
         [2],
         [3]],

        [[4],
         [5],
         [6]]])
--------------------------------
t.view(6): tensor([1, 2, 3, 4, 5, 6])
--------------------------------
t.permute(1, 0): tensor([[1, 4],
        [2, 5],
        [3, 6]])
t.transpose(0, 1): tensor([[1, 4],
        [2, 5],
        [3, 6]])
--------------------------------
torch.stack((t, t)): tensor([[[1, 2, 3],
         [4, 5, 6]],

        [[1, 2, 3],
         [4, 5, 6]]])
--------------------------------
t.squeeze(): tensor([[1, 2, 3],
        [4, 5, 6]])
t.unsqueeze(0).squeeze(): tensor([[1, 2, 3],
        [4, 5, 6]])
--------------------------------
t.reshape(2, 3): tensor([[1, 2, 3],
        [4, 5, 6]])
t.reshape(6): tensor([1, 2, 3, 4, 5, 6])
--------------------------------
t.permute(1, 0): tensor([[1, 4],
        [2, 5],
        [3, 6]])
t.transpose(0, 1): ten

## Broadcasting in PyTorch/NumPy

https://numpy.org/doc/stable/user/basics.broadcasting.html

Broadcasting lets you do operations on tensors with different shapes without explicit copying.
The Rules
Two tensors are broadcastable if, comparing dimensions right-to-left:
Dimensions are equal, OR
One of them is 1, OR
One of them doesn't exist (missing)

Compare shapes RIGHT to LEFT:
A: [8, 1, 6, 1]
B:    [7, 1, 5]
     ↑  ↑  ↑  ↑
     8  7  6  5  → Result shape: [8, 7, 6, 5]

In [8]:
# Simple: scalar broadcasts to any shape
a = torch.tensor([[1, 2], [3, 4]])  # [2, 2]
b = torch.tensor(10)                 # scalar []
a + b  # [[11, 12], [13, 14]]

# Row broadcasts across rows
a = torch.tensor([
	[1, 2, 3],
	[4, 5, 6]])        # [2, 3]
b = torch.tensor([10, 20, 30])       # [3]

print(a + b)

# Column broadcasts across columns (need shape [2, 1])
b = torch.tensor([[10], [20]])       # [2, 1]
print(a + b)


tensor([[11, 22, 33],
        [14, 25, 36]])
tensor([[11, 12, 13],
        [24, 25, 26]])


In [9]:
a = torch.tensor([[1],    # shape [3, 1]
                  [2],
                  [3]])

b = torch.tensor([10, 20, 30])  # shape [3]

# Step 1: Align shapes right-to-left
# a: [3, 1]
# b:    [3]  → treated as [1, 3]

# Step 2: Expand dims of size 1
# a: [3, 1] → [[1, 1, 1],    (row repeated 3x)
#              [2, 2, 2],
#              [3, 3, 3]]
#
# b: [1, 3] → [[10, 20, 30], (col repeated 3x)
#              [10, 20, 30],
#              [10, 20, 30]]

# Step 3: Element-wise operation
a + b  # [[11, 21, 31],
       #  [12, 22, 32],
       #  [13, 23, 33]]

tensor([[11, 21, 31],
        [12, 22, 32],
        [13, 23, 33]])

In [10]:
# 1. Normalize columns (subtract mean per column)
x = torch.randn(100, 5)          # [100, 5]
mean = x.mean(dim=0)             # [5]
x_centered = x - mean            # broadcasts [5] → [100, 5]

# 2. Scale rows differently
x = torch.randn(100, 5)              # [100, 5]
weights = torch.rand(100, 1)         # [100, 1] - random weights per row
x_scaled = x * weights               # broadcasts [100,1] → [100, 5]

# Or create specific weights:
weights = torch.linspace(0.1, 1.0, 100).unsqueeze(1)  # [100, 1]

# 3. Outer product
a = torch.tensor([1, 2, 3])      # [3]
b = torch.tensor([10, 20])       # [2]
outer = a[:, None] * b[None, :]  # [3,1] * [1,2] → [3, 2]
# [[10, 20],
#  [20, 40],
#  [30, 60]]

# 4. Pairwise distances
x = torch.randn(100, 3)          # 100 points in 3D
# x[:, None, :] shape: [100, 1, 3]
# x[None, :, :] shape: [1, 100, 3]
diff = x[:, None, :] - x[None, :, :]  # [100, 100, 3]
dist = (diff ** 2).sum(-1).sqrt()     # [100, 100]

In [11]:
# Broadcasting doesn't copy data - it's a view trick. 
# The tensor acts as if it's larger but uses the same memory:
a = torch.tensor([1, 2, 3])
expanded = a.expand(1000, 3)  # "looks" like [1000, 3]
expanded.storage().size()     # still only 3 elements in memory!

  expanded.storage().size()     # still only 3 elements in memory!


3

In [24]:
a = torch.arange(18)
print(a)
print(a.view(2, 3, 3))
print(a.storage().size())
print(a.storage().dtype)


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17])
tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]]])
18
torch.int64
