# Tensor Shape Mastery: 30 Practice Problems

**Goal**: Master tensor shape manipulation for ML interviews

**Rules**:
1. Always predict the output shape BEFORE running code
2. Write your prediction as a comment
3. If wrong, debug why
4. Time yourself: aim for <30 seconds per question

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Set seed for reproducibility
torch.manual_seed(42)

def check_shape(tensor, expected_shape, question_num):
    """Helper to check if tensor has expected shape"""
    actual = tuple(tensor.shape)
    if actual == expected_shape:
        print(f"✅ Q{question_num}: Correct! Shape is {actual}")
    else:
        print(f"❌ Q{question_num}: Wrong! Expected {expected_shape}, got {actual}")
    return tensor



## Level 1: Basic Operations (Q1-10)

In [3]:
# Q1: Basic reshape
x = torch.randn(4, 6)
# Reshape to (3, 8). What happens?
# Your prediction: 
# (3,8)
try:
    result = x.reshape(3, 8)
    check_shape(result, (3, 8), 1)
except Exception as e:
    print(f"Q1: Error - {e}")

✅ Q1: Correct! Shape is (3, 8)
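
A reshape only succeeds when the element count is unchanged (4 × 6 = 3 × 8 = 24). Below is a minimal sketch of how a single `-1` gets resolved. `infer_reshape` is a hypothetical helper written for illustration, not a PyTorch API; the real check happens inside `reshape`/`view`.

```python
import math
import torch

def infer_reshape(old_shape, new_shape):
    """Hypothetical helper: resolve one -1 the way reshape/view do."""
    total = math.prod(old_shape)
    known = math.prod(d for d in new_shape if d != -1)
    if new_shape.count(-1) > 1:
        raise ValueError("only one dimension can be -1")
    if -1 in new_shape:
        if total % known != 0:
            raise ValueError(f"cannot infer -1 for {old_shape} -> {new_shape}")
        return tuple(total // known if d == -1 else d for d in new_shape)
    if known != total:
        raise ValueError(f"{old_shape} has {total} elements, {new_shape} needs {known}")
    return tuple(new_shape)

x = torch.randn(4, 6)
print(infer_reshape((4, 6), (3, 8)))    # (3, 8): 24 elements either way
print(infer_reshape((4, 6), (-1, 12)))  # (2, 12): 24 // 12 = 2
try:
    x.reshape(5, 5)                     # 25 != 24 elements, so PyTorch raises a RuntimeError
except RuntimeError as e:
    print("reshape failed:", e)
```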


In [4]:
# Q2: Transpose operations
x = torch.randn(2, 3, 4, 5)
# What's the shape after x.transpose(1, 3)?
# Your prediction: 
# (2, 5, 4, 3)
result = x.transpose(1, 3)
check_shape(result, (2, 5, 4, 3), 2)

✅ Q2: Correct! Shape is (2, 5, 4, 3)
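
`transpose` does not move any data; it swaps two strides and returns a view. That is why a transposed tensor is usually non-contiguous and why `view` can fail on it while `reshape` (which copies when needed) works. A short sketch:

```python
import torch

x = torch.randn(2, 3, 4, 5)
t = x.transpose(1, 3)                 # shape (2, 5, 4, 3), a view sharing x's storage
print(t.is_contiguous())              # False: only the strides were swapped
print(x.stride(), t.stride())         # (60, 20, 5, 1) (60, 1, 5, 20)

try:
    t.view(2, -1)                     # view needs a memory layout compatible with the new shape
except RuntimeError as e:
    print("view failed:", type(e).__name__)

flat = t.reshape(2, -1)               # reshape copies when it has to
flat2 = t.contiguous().view(2, -1)    # the explicit equivalent
print(flat.shape, torch.equal(flat, flat2))  # torch.Size([2, 60]) True
```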



In [5]:
# Q3: Squeeze and unsqueeze
x = torch.randn(1, 3, 1, 4, 1)
# What's x.squeeze().unsqueeze(0).shape?
# Your prediction: 
# 1, 3, 4
result = x.squeeze().unsqueeze(0)
check_shape(result, (1, 3, 4), 3)

✅ Q3: Correct! Shape is (1, 3, 4)
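
A bare `squeeze()` removes every size-1 dimension, including a batch dimension that happens to be 1. Passing an explicit `dim` is safer, and squeezing a dim whose size is not 1 is a silent no-op:

```python
import torch

x = torch.randn(1, 3, 1, 4, 1)
print(x.squeeze().shape)         # torch.Size([3, 4]): every size-1 dim is gone
print(x.squeeze(2).shape)        # torch.Size([1, 3, 4, 1]): only dim 2 removed
print(x.squeeze(0).shape)        # torch.Size([3, 1, 4, 1])

# Pitfall: a batch containing one image loses its batch dimension
batch = torch.randn(1, 3, 32, 32)   # (batch=1, C, H, W)
print(batch.squeeze().shape)        # torch.Size([3, 32, 32]): batch dim lost
print(batch.squeeze(1).shape)       # torch.Size([1, 3, 32, 32]): dim 1 has size 3, so nothing happens
```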



In [None]:
# Q4: Broadcasting addition
x = torch.randn(2, 1, 4)
y = torch.randn(3, 4)
# What's (x + y).shape?
# Your prediction: 

result = x + y
check_shape(result, (2, 3, 4), 4)
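
The broadcasting rule can be written out directly: right-align the shapes, treat missing leading dims as 1, and require each column to be equal or contain a 1. `broadcast_shape` below is a hypothetical helper for practice; it is cross-checked against `torch.broadcast_shapes`.

```python
import torch
from itertools import zip_longest

def broadcast_shape(*shapes):
    """Hypothetical helper: right-align shapes; each dim pair must match or be 1."""
    result = []
    # Walk dims from the right; missing leading dims count as size 1
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"incompatible dims {dims} in {shapes}")
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))

print(broadcast_shape((2, 1, 4), (3, 4)))                 # (2, 3, 4)
print(broadcast_shape((8, 1, 6, 1), (7, 1, 5)))           # (8, 7, 6, 5)
print(tuple(torch.broadcast_shapes((2, 1, 4), (3, 4))))   # (2, 3, 4)
```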

In [None]:
# Q5: Matrix multiplication
a = torch.randn(3, 4, 5)
b = torch.randn(3, 5, 6)
# What's torch.bmm(a, b).shape?
# Your prediction: 

result = torch.bmm(a, b)
check_shape(result, (3, 4, 6), 5)
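
`torch.bmm` is strict: both inputs must be 3-D with the same batch size. `torch.matmul` follows the same inner-dimension rule but also broadcasts batch dimensions, so a single 2-D weight can be applied to every item in the batch:

```python
import torch

a = torch.randn(3, 4, 5)
b = torch.randn(3, 5, 6)
print(torch.bmm(a, b).shape)      # torch.Size([3, 4, 6])

# matmul broadcasts: one (5, 6) matrix shared across the batch
w = torch.randn(5, 6)
print(torch.matmul(a, w).shape)   # torch.Size([3, 4, 6])

try:
    torch.bmm(a, w)               # bmm requires two 3-D tensors
except RuntimeError as e:
    print("bmm failed:", type(e).__name__)

# Batch dims of size 1 broadcast as well
c = torch.randn(1, 5, 6)
print(torch.matmul(a, c).shape)   # torch.Size([3, 4, 6])
```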

In [6]:
# Q6: View with -1
x = torch.randn(2, 3, 4, 5)
# What's x.view(2, -1).shape?
# Your prediction: 

result = x.view(2, -1)
check_shape(result, (2, 60), 6)

✅ Q6: Correct! Shape is (2, 60)



In [None]:
# Q7: Permute dimensions
x = torch.randn(2, 3, 4, 5)
# What's x.permute(3, 1, 0, 2).shape?
# Your prediction: 

result = x.permute(3, 1, 0, 2)
check_shape(result, (5, 3, 2, 4), 7)

In [None]:
# Q8: Cat along different dimensions
x = torch.randn(2, 3, 4)
y = torch.randn(2, 5, 4)
# What's torch.cat([x, y], dim=1).shape?
# Your prediction: 

result = torch.cat([x, y], dim=1)
check_shape(result, (2, 8, 4), 8)

In [None]:
# Q9: Stack vs Cat
x = torch.randn(2, 3)
y = torch.randn(2, 3)
z = torch.randn(2, 3)
# What's torch.stack([x, y, z], dim=1).shape?
# Your prediction: 

result = torch.stack([x, y, z], dim=1)
check_shape(result, (2, 3, 3), 9)
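
The difference between Q8 and Q9 in one sketch: `stack` inserts a new dimension, and it is equivalent to unsqueezing every input at that position and then calling `cat`. `cat` only grows an existing dimension.

```python
import torch

x, y, z = torch.randn(2, 3), torch.randn(2, 3), torch.randn(2, 3)

stacked = torch.stack([x, y, z], dim=1)                          # (2, 3, 3): new dim inserted at 1
via_cat = torch.cat([t.unsqueeze(1) for t in (x, y, z)], dim=1)  # same result, built by hand
print(stacked.shape, torch.equal(stacked, via_cat))              # torch.Size([2, 3, 3]) True

# cat grows an existing dim instead of adding one
print(torch.cat([x, y, z], dim=1).shape)  # torch.Size([2, 9])
print(torch.cat([x, y, z], dim=0).shape)  # torch.Size([6, 3])
```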

In [None]:
# Q10: Advanced indexing
x = torch.randn(4, 5, 6)
# What's x[:, [1, 3], :].shape?
# Your prediction: 

result = x[:, [1, 3], :]
check_shape(result, (4, 2, 6), 10)
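
Indexing one dimension with a list is the same as `torch.index_select`. Two common follow-up traps: adjacent index lists are paired elementwise (not combined as a grid), and index lists separated by a slice move the indexed dimension to the front, following NumPy's advanced-indexing rules.

```python
import torch

x = torch.randn(4, 5, 6)
idx = torch.tensor([1, 3])

a = x[:, [1, 3], :]                 # (4, 2, 6)
b = torch.index_select(x, 1, idx)   # same selection via index_select
print(a.shape, torch.equal(a, b))   # torch.Size([4, 2, 6]) True

# Adjacent index lists are paired elementwise: picks (1, 0) and (3, 2)
print(x[:, [1, 3], [0, 2]].shape)   # torch.Size([4, 2])

# Index lists separated by a slice: the indexed dim moves to the front
print(x[[0, 1], :, [2, 3]].shape)   # torch.Size([2, 5])
```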

## Level 2: Neural Network Operations (Q11-20)

In [None]:
# Q11: Linear layer
batch_size, seq_len, d_model = 4, 10, 512
x = torch.randn(batch_size, seq_len, d_model)
linear = nn.Linear(d_model, 256)
# What's linear(x).shape?
# Your prediction: 

result = linear(x)
check_shape(result, (4, 10, 256), 11)
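
`nn.Linear` only touches the last dimension; every leading dimension is treated as batch. Its weight is stored as `(out_features, in_features)`, so the layer computes `x @ W.T + b`, which can be checked by hand:

```python
import torch
import torch.nn as nn

x = torch.randn(4, 10, 512)
linear = nn.Linear(512, 256)
print(linear.weight.shape, linear.bias.shape)  # torch.Size([256, 512]) torch.Size([256])

manual = x @ linear.weight.T + linear.bias     # (4, 10, 512) @ (512, 256) -> (4, 10, 256)
print(manual.shape)                            # torch.Size([4, 10, 256])
print(torch.allclose(linear(x), manual, atol=1e-4))  # True
```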

In [None]:
# Q12: Embedding layer
vocab_size, embed_dim = 1000, 128
seq_len, batch_size = 20, 8
token_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) # 8, 20
embedding = nn.Embedding(vocab_size, embed_dim)
# What's embedding(token_ids).shape?
# Your prediction: 

result = embedding(token_ids)
check_shape(result, (8, 20, 128), 12)
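
An embedding layer is a row lookup into a `(vocab_size, embed_dim)` table, so the output shape is the index shape with `embed_dim` appended. Equivalently (and far less efficiently), it is a one-hot matrix multiplied by the table:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

embedding = nn.Embedding(1000, 128)
token_ids = torch.randint(0, 1000, (8, 20))

out = embedding(token_ids)                    # (8, 20, 128)
lookup = embedding.weight[token_ids]          # plain indexing of the (1000, 128) table
print(out.shape, torch.equal(out, lookup))    # torch.Size([8, 20, 128]) True

one_hot = F.one_hot(token_ids, num_classes=1000).float()   # (8, 20, 1000)
via_matmul = one_hot @ embedding.weight                    # (8, 20, 1000) @ (1000, 128)
print(torch.allclose(out, via_matmul, atol=1e-6))          # True
```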

In [None]:
# Q13: Multi-head attention setup
batch_size, seq_len, d_model = 2, 16, 512
num_heads = 8
x = torch.randn(batch_size, seq_len, d_model) # 2, 16, 512
# Split for multi-head: reshape to (batch, seq, heads, head_dim)
# What's the head_dim and final shape after transpose?
# Your prediction: 

head_dim = d_model // num_heads
x_heads = x.view(batch_size, seq_len, num_heads, head_dim) # (2, 16, 8, 64)
x_transposed = x_heads.transpose(1, 2)  # (batch, heads, seq, head_dim)
check_shape(x_transposed, (2, 8, 16, 64), 13)
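
After attention, the heads have to be merged back into `(batch, seq, d_model)`. The transpose must be undone first; reshaping directly from `(batch, heads, seq, head_dim)` produces the right shape but scrambles the values. A minimal round-trip check:

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 16, 512, 8
head_dim = d_model // num_heads                                           # 64

x = torch.randn(batch_size, seq_len, d_model)
heads = x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)  # (2, 8, 16, 64)

# Merge: undo the transpose, then fold heads back into d_model.
# contiguous() is needed because transpose produced a non-contiguous view.
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
print(merged.shape, torch.equal(merged, x))     # torch.Size([2, 16, 512]) True

wrong = heads.reshape(batch_size, seq_len, d_model)   # right shape, wrong element order
print(torch.equal(wrong, x))                          # False
```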

In [None]:
# Q14: Attention scores
batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)
# What's torch.matmul(q, k.transpose(-2, -1)).shape?
# Your prediction: q is (2, 8, 16, 64), k.transpose(-2, -1) is (2, 8, 64, 16)
# → (2, 8, 16, 16)

scores = torch.matmul(q, k.transpose(-2, -1))
check_shape(scores, (2, 8, 16, 16), 14)

In [None]:
# Q15: Causal mask
seq_len = 10
mask = torch.tril(torch.ones(seq_len, seq_len))
# You have attention scores (4, 6, 10, 10)
# How does mask broadcast when you do scores.masked_fill(mask == 0, -float('inf'))?
# Your prediction: 

scores = torch.randn(4, 6, 10, 10)
result = scores.masked_fill(mask == 0, -float('inf'))
check_shape(result, (4, 6, 10, 10), 15)
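
The `(10, 10)` causal mask right-aligns against `(4, 6, 10, 10)` and is treated as `(1, 1, 10, 10)`. A padding mask is usually shaped `(batch, 1, 1, seq)` so that it blocks key positions for every head and every query. The `lengths` values below are made up for illustration:

```python
import torch

batch, heads, seq_len = 4, 6, 10
scores = torch.randn(batch, heads, seq_len, seq_len)

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # (10, 10), acts as (1, 1, 10, 10)
lengths = torch.tensor([10, 7, 5, 3])                                # hypothetical valid lengths
padding = torch.arange(seq_len) < lengths[:, None]                   # (4, 10): True = real token
padding = padding[:, None, None, :]                                  # (4, 1, 1, 10): masks key positions

combined = causal & padding                        # (10, 10) & (4, 1, 1, 10) -> (4, 1, 10, 10)
masked = scores.masked_fill(~combined, float('-inf'))
print(combined.shape, masked.shape)   # torch.Size([4, 1, 10, 10]) torch.Size([4, 6, 10, 10])
```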

In [None]:
# Q16: Conv1D for text
batch_size, seq_len, embed_dim = 8, 100, 256
x = torch.randn(batch_size, embed_dim, seq_len)  # (8, 256, 100): Conv1d expects (batch, channels, length)
conv = nn.Conv1d(embed_dim, 512, kernel_size=3, padding=1)
# What's conv(x).shape?
# Your prediction: 

result = conv(x)
check_shape(result, (8, 512, 100), 16)
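
The sequence length in Q16 is unchanged only because `kernel_size=3, padding=1` is a "same" configuration. The general output-length formula, as given in the PyTorch `Conv1d` documentation, is implemented below in a hypothetical helper `conv_out_len` and checked against real layers:

```python
import torch
import torch.nn as nn

def conv_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Hypothetical helper: Conv1d output length, per the PyTorch docs formula."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

x = torch.randn(8, 256, 100)   # (batch, channels, length)
for k, s, p in [(3, 1, 1), (3, 1, 0), (5, 2, 2)]:
    conv = nn.Conv1d(256, 512, kernel_size=k, stride=s, padding=p)
    # Prints: 100 100, then 98 98, then 50 50
    print(conv(x).shape[-1], conv_out_len(100, k, s, p))
```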

In [None]:
# Q17: Global average pooling
x = torch.randn(4, 512, 28, 28)  # Image features
# What's F.adaptive_avg_pool2d(x, (1, 1)).shape?
# Your prediction: 

result = F.adaptive_avg_pool2d(x, (1, 1))
check_shape(result, (4, 512, 1, 1), 17)
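
Global average pooling with output size `(1, 1)` is just a mean over the spatial dimensions. The trailing `1, 1` dims are usually flattened away before a classifier head:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 512, 28, 28)
pooled = F.adaptive_avg_pool2d(x, (1, 1))          # (4, 512, 1, 1)
mean = x.mean(dim=(2, 3), keepdim=True)            # same numbers, same shape
print(torch.allclose(pooled, mean, atol=1e-5))     # True

features = pooled.flatten(1)                       # (4, 512), ready for a linear classifier
print(features.shape)                              # torch.Size([4, 512])
```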

In [None]:
# Q18: Batch normalization
x = torch.randn(32, 256, 14, 14)
bn = nn.BatchNorm2d(256)
# What's bn(x).shape?
# Your prediction: 

result = bn(x)
check_shape(result, (32, 256, 14, 14), 18)
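
BatchNorm never changes the shape, but the interesting part for interviews is which dims the statistics reduce over: for `BatchNorm2d` it is batch, height, and width, leaving one mean and variance per channel. A sketch that reproduces a freshly initialized layer in training mode (weight 1, bias 0, biased variance):

```python
import torch
import torch.nn as nn

x = torch.randn(32, 256, 14, 14)
bn = nn.BatchNorm2d(256)        # training mode by default; weight=1 and bias=0 at init
out = bn(x)

# One statistic per channel, reducing over batch, height, and width
mean = x.mean(dim=(0, 2, 3), keepdim=True)                     # (1, 256, 1, 1)
var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)      # biased variance, as used for normalization
manual = (x - mean) / torch.sqrt(var + bn.eps)
print(mean.shape, torch.allclose(out, manual, atol=1e-4))      # torch.Size([1, 256, 1, 1]) True
```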

In [None]:
# Q19: Layer normalization
x = torch.randn(4, 10, 512)  # (batch, seq, features)
ln = nn.LayerNorm(512)
# What's ln(x).shape?
# Your prediction: 

result = ln(x)
check_shape(result, (4, 10, 512), 19)
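
LayerNorm reduces over the trailing `normalized_shape` dims instead, so each token gets its own mean and variance and the batch size plays no role. A manual check against a freshly initialized `nn.LayerNorm`:

```python
import torch
import torch.nn as nn

x = torch.randn(4, 10, 512)
ln = nn.LayerNorm(512)          # weight=1 and bias=0 at init

mean = x.mean(dim=-1, keepdim=True)                      # (4, 10, 1): one mean per token
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)       # biased variance
manual = (x - mean) / torch.sqrt(var + ln.eps)
print(mean.shape, torch.allclose(ln(x), manual, atol=1e-4))  # torch.Size([4, 10, 1]) True
```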

In [None]:
# Q20: Dropout during training
x = torch.randn(8, 16, 256)
dropout = nn.Dropout(0.1)
# What's dropout(x).shape during training?
# Your prediction: 

dropout.train()
result = dropout(x)
check_shape(result, (8, 16, 256), 20)
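
Dropout keeps the shape in both modes. In training it zeroes each element with probability `p` and scales the survivors by `1 / (1 - p)` so the expected value is unchanged; in eval mode it is the identity. Using an all-ones input makes this visible:

```python
import torch
import torch.nn as nn

x = torch.ones(8, 16, 256)
dropout = nn.Dropout(p=0.1)

dropout.train()
y = dropout(x)
kept = y[y != 0]
print(y.shape)                                                # torch.Size([8, 16, 256])
print(torch.allclose(kept, torch.full_like(kept, 1 / 0.9)))   # True: survivors scaled by 1/(1-p)
print((y == 0).float().mean().item())                         # roughly 0.1

dropout.eval()
print(torch.equal(dropout(x), x))                             # True: identity at inference
```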

## Level 3: Advanced Operations (Q21-30)

In [None]:
# Q21: Einstein summation
a = torch.randn(4, 3, 5)
b = torch.randn(4, 5, 7)
# What's torch.einsum('bij,bjk->bik', a, b).shape?
# Your prediction: 

result = torch.einsum('bij,bjk->bik', a, b)
check_shape(result, (4, 3, 7), 21)
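
The einsum rule for shapes: every letter in the output keeps its size, and every letter that appears in the inputs but not in the output is summed out. A few patterns, including the attention-score product from Q14 and Q22:

```python
import torch

a = torch.randn(4, 3, 5)
b = torch.randn(4, 5, 7)

out = torch.einsum('bij,bjk->bik', a, b)   # j is missing from the output, so it is summed
print(out.shape, torch.allclose(out, torch.bmm(a, b), atol=1e-5))  # torch.Size([4, 3, 7]) True

print(torch.einsum('bij->bji', a).shape)   # transpose: torch.Size([4, 5, 3])
print(torch.einsum('bij->b', a).shape)     # sum over i and j: torch.Size([4])

q = torch.randn(2, 8, 20, 64)
k = torch.randn(2, 8, 30, 64)
print(torch.einsum('bhqd,bhkd->bhqk', q, k).shape)  # torch.Size([2, 8, 20, 30])
```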

In [None]:
# Q22: Cross attention
# Query from decoder, Key/Value from encoder
q = torch.randn(2, 8, 20, 64)  # (batch, heads, tgt_len, head_dim)
k = torch.randn(2, 8, 30, 64)  # (batch, heads, src_len, head_dim)
v = torch.randn(2, 8, 30, 64)
# What's the attention scores shape: q @ k.transpose(-2, -1)?
# Your prediction: 

scores = torch.matmul(q, k.transpose(-2, -1))
check_shape(scores, (2, 8, 20, 30), 22)

In [None]:
# Q23: Grouped convolution
x = torch.randn(8, 64, 32, 32)
conv = nn.Conv2d(64, 128, kernel_size=3, padding=1, groups=8)
# What's conv(x).shape?
# Your prediction: 

result = conv(x)
check_shape(result, (8, 128, 32, 32), 23)
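
Grouping does not change the output shape; it changes the weight shape. With `groups=8`, each group maps 64 / 8 = 8 input channels to 128 / 8 = 16 output channels, so the weight's second dim shrinks by 8× and so does the parameter count. `count_params` is a small hypothetical helper:

```python
import torch
import torch.nn as nn

dense = nn.Conv2d(64, 128, kernel_size=3, padding=1)
grouped = nn.Conv2d(64, 128, kernel_size=3, padding=1, groups=8)

print(dense.weight.shape)    # torch.Size([128, 64, 3, 3])
print(grouped.weight.shape)  # torch.Size([128, 8, 3, 3])

def count_params(module):
    """Hypothetical helper: total number of learnable values."""
    return sum(p.numel() for p in module.parameters())

print(count_params(dense), count_params(grouped))  # 73856 9344 (weights + 128 biases)
```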

In [None]:
# Q24: Positional encoding addition
batch_size, seq_len, d_model = 4, 50, 512
tokens = torch.randn(batch_size, seq_len, d_model)
pos_encoding = torch.randn(1, seq_len, d_model)  # Learned positional encoding
# What's (tokens + pos_encoding).shape?
# Your prediction: 

result = tokens + pos_encoding
check_shape(result, (4, 50, 512), 24)

In [None]:
# Q25: Gather operation
x = torch.randn(3, 5, 4)
indices = torch.tensor([[0, 2], [1, 4], [0, 3]]) # (3, 2)
# What's torch.gather(x, 1, indices.unsqueeze(-1).expand(-1, -1, 4)).shape?
# Your prediction: (3, 2, 4)

expanded_indices = indices.unsqueeze(-1).expand(-1, -1, 4)
result = torch.gather(x, 1, expanded_indices)
check_shape(result, (3, 2, 4), 25)
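
`torch.gather` requires an index tensor with the same number of dims as the input, which is why Q25 has to unsqueeze and expand. The same "pick rows per batch item" operation can be written with advanced indexing, where a `(3, 1)` batch index broadcasts against the `(3, 2)` row indices:

```python
import torch

x = torch.randn(3, 5, 4)
indices = torch.tensor([[0, 2], [1, 4], [0, 3]])   # (3, 2): rows to keep per batch item

gathered = torch.gather(x, 1, indices.unsqueeze(-1).expand(-1, -1, 4))  # (3, 2, 4)

batch_idx = torch.arange(3)[:, None]               # (3, 1), broadcasts against (3, 2)
indexed = x[batch_idx, indices]                    # (3, 2, 4)
print(gathered.shape, torch.equal(gathered, indexed))  # torch.Size([3, 2, 4]) True
```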

In [17]:
# Q26: Masked select
x = torch.randn(4, 6)
mask = torch.randint(0, 2, (4, 6)).bool()
# What's x[mask].shape? (Note: result is 1D!)
# Your prediction: 
result = x[mask]
# Shape will be (N,) where N is number of True values in mask
print(f"Q26: Selected {result.shape[0]} elements from {mask.sum().item()} True mask values")
print(f"✅ Q26: Correct! Shape is {result.shape} (1D with {result.shape[0]} elements)")

Q26: Selected 8 elements from 8 True mask values
✅ Q26: Correct! Shape is torch.Size([8]) (1D with 8 elements)


In [None]:
# Q27: Repeat and tile
x = torch.randn(2, 3)
# What's x.repeat(4, 1, 2).shape?
# Your prediction: 

result = x.repeat(4, 1, 2)
check_shape(result, (4, 2, 6), 27)
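
When `repeat` gets more sizes than the tensor has dims, the tensor is first treated as having leading size-1 dims, so `(2, 3)` becomes `(1, 2, 3)` and the sizes multiply. It always copies and tiles whole blocks; `expand` avoids the copy via a zero stride, and `repeat_interleave` repeats each element in place:

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])           # (2, 3)

r = x.repeat(4, 1, 2)                  # (1*4, 2*1, 3*2)
print(r.shape)                         # torch.Size([4, 2, 6])
print(r[0, 0].tolist())                # [1, 2, 3, 1, 2, 3]: tiles the whole row

e = x.unsqueeze(0).expand(4, 2, 3)     # no copy: stride 0 along the new dim
print(e.shape, e.stride())             # torch.Size([4, 2, 3]) (0, 3, 1)

ri = x.repeat_interleave(2, dim=1)     # repeats each element next to itself
print(ri[0].tolist())                  # [1, 1, 2, 2, 3, 3]
```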

In [None]:
# Q28: Advanced broadcasting
a = torch.randn(8, 1, 6, 1)
b = torch.randn(7, 1, 5)
# What's (a * b).shape?
# Your prediction: 

result = a * b
check_shape(result, (8, 7, 6, 5), 28)

In [None]:
# Q29: Unfold operation (sliding window)
x = torch.randn(1, 1, 10)  # (batch, channels, length)
# What's F.unfold(x.unsqueeze(-1), kernel_size=(3, 1), padding=(1, 0)).shape?
# Your prediction: 

x_2d = x.unsqueeze(-1)  # Make it (1, 1, 10, 1) for unfold
result = F.unfold(x_2d, kernel_size=(3, 1), padding=(1, 0))
check_shape(result, (1, 3, 10), 29)
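
`F.unfold` returns `(N, C × kernel_h × kernel_w, num_windows)`, so here each of the 10 columns is one length-3 window. The same windows can be produced with `Tensor.unfold(dim, size, step)` on a zero-padded copy, which yields `(..., num_windows, size)` instead; a transpose lines the two up:

```python
import torch
import torch.nn.functional as F

x = torch.arange(10.).view(1, 1, 10)        # (batch, channels, length), values 0..9

cols = F.unfold(x.unsqueeze(-1), kernel_size=(3, 1), padding=(1, 0))   # (1, 3, 10)

padded = F.pad(x, (1, 1))                   # (1, 1, 12): one zero on each side
windows = padded.unfold(2, 3, 1)            # (1, 1, 10, 3): one row per window
print(cols.shape, windows.shape)            # torch.Size([1, 3, 10]) torch.Size([1, 1, 10, 3])
print(torch.equal(cols, windows[:, 0].transpose(1, 2)))   # True
print(cols[0, :, 0].tolist())               # [0.0, 0.0, 1.0]: first window includes the pad
```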

In [None]:
# Q30: Complex reshape puzzle
x = torch.randn(2, 3, 4, 5, 6)
# Flatten last 3 dims, then swap first two dims
# What's x.view(2, 3, -1).transpose(0, 1).shape?
# Your prediction: 

flattened = x.view(2, 3, -1)  # Last 3 dims: 4*5*6 = 120
result = flattened.transpose(0, 1)
check_shape(result, (3, 2, 120), 30)

## Bonus: Common Interview Patterns

In [None]:
# Bonus 1: Implement scaled dot-product attention from scratch
def scaled_dot_product_attention(q, k, v, mask=None, dropout=None):
    """
    q, k, v: (batch, heads, seq_len, head_dim)
    mask: broadcastable to (batch, heads, seq_len, seq_len); positions where mask == 0 are blocked
    """
    # Reference implementation: predict every intermediate shape before reading the comments!
    d_k = q.size(-1)
    
    # Step 1: Compute attention scores
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    print(f"Scores shape: {scores.shape}")  # Should be (batch, heads, seq_len, seq_len)
    
    # Step 2: Apply mask if provided
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Step 3: Apply softmax
    attn_weights = F.softmax(scores, dim=-1)
    print(f"Attention weights shape: {attn_weights.shape}")  # Same as scores
    
    # Step 4: Apply dropout if provided
    if dropout is not None:
        attn_weights = dropout(attn_weights)
    
    # Step 5: Apply attention to values
    output = torch.matmul(attn_weights, v)
    print(f"Output shape: {output.shape}")  # Should be (batch, heads, seq_len, head_dim)
    
    return output, attn_weights

# Test it
batch, heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

output, weights = scaled_dot_product_attention(q, k, v)
print(f"\nFinal shapes - Output: {output.shape}, Weights: {weights.shape}")
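
A useful way to verify a from-scratch attention is to compare it with PyTorch's fused `F.scaled_dot_product_attention`, which assumes PyTorch 2.0 or newer. `sdpa_reference` is a compact, hypothetical restatement of the function above without the prints and dropout; note it uses a boolean mask where `True` means "may attend", matching the built-in's boolean `attn_mask` convention:

```python
import math
import torch
import torch.nn.functional as F

def sdpa_reference(q, k, v, mask=None):
    """Hypothetical compact version: mask is boolean, True = may attend."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))   # (B, H, Lq, Lk)
    if mask is not None:
        scores = scores.masked_fill(~mask, float('-inf'))
    return F.softmax(scores, dim=-1) @ v                       # (B, H, Lq, D)

B, H, L, D = 2, 8, 16, 64
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
causal = torch.tril(torch.ones(L, L, dtype=torch.bool))

ours = sdpa_reference(q, k, v, causal)
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # requires PyTorch 2.0+
print(ours.shape, torch.allclose(ours, builtin, atol=1e-4))         # torch.Size([2, 8, 16, 64]) True
```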

In [None]:
# Bonus 2: Batch processing with variable lengths
# Common interview scenario: handle sequences of different lengths

def create_padding_mask(lengths, max_len):
    """Create mask for variable length sequences"""
    batch_size = len(lengths)
    mask = torch.arange(max_len).expand(batch_size, max_len) < torch.tensor(lengths).unsqueeze(1)
    return mask

# Example: batch with sequences of lengths [5, 8, 3]
lengths = [5, 8, 3]
max_len = 10
mask = create_padding_mask(lengths, max_len)

print(f"Mask shape: {mask.shape}")  # Should be (3, 10)
print(f"Mask:\n{mask.int()}")

# Use mask to zero out padded positions
x = torch.randn(3, 10, 256)  # (batch, seq, features)
x_masked = x * mask.unsqueeze(-1).float()
print(f"Masked input shape: {x_masked.shape}")  # Should be (3, 10, 256)
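
A common follow-up is masked mean pooling: average only over the real tokens. Zeroing the padding is not enough, because dividing by `max_len` would still bias the mean toward 0. The sketch below also builds the mask with plain broadcasting, which makes the `expand` call unnecessary:

```python
import torch

lengths = torch.tensor([5, 8, 3])
max_len, features = 10, 256
x = torch.randn(3, max_len, features)                        # (batch, seq, features)

mask = torch.arange(max_len)[None, :] < lengths[:, None]     # (1, 10) < (3, 1) -> (3, 10)
summed = (x * mask.unsqueeze(-1)).sum(dim=1)                 # (3, 256): padded positions add 0
mean = summed / lengths[:, None]                             # divide by the real length, not max_len
print(mean.shape)                                            # torch.Size([3, 256])
print(torch.allclose(mean[2], x[2, :3].mean(dim=0), atol=1e-5))  # True
```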

## Summary

**Key Takeaways:**
1. Always write down tensor shapes as comments
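2. Understand broadcasting rules: right-align the shapes; each pair of dims must be equal or one of them must be 1
3. Matrix multiplication: the last dim of A must equal the second-to-last dim of B; leading batch dims broadcast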
4. Practice predicting shapes before running code
5. Use `.shape` liberally when debugging

**Common Patterns:**
- Multi-head attention: `(batch, seq, dim) → (batch, heads, seq, head_dim)`
- Attention scores: `(batch, heads, seq_q, head_dim) @ (batch, heads, head_dim, seq_k) → (batch, heads, seq_q, seq_k)`
- Broadcasting: missing leading dims and size-1 dims expand to match the other tensor
- Masking: use `.masked_fill()` for attention masks

**Interview Tips:**
- Verbalize your shape reasoning out loud
- Start with simple examples and build up
- Don't be afraid to use `.shape` to verify
- Practice these patterns until they're automatic!