In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
torch.manual_seed(42)

<torch._C.Generator at 0x7d54ca8108d0>

In [2]:
def make_model():
    """Build a fresh linear layer mapping 4 input features to 2 outputs, with bias."""
    layer = nn.Linear(in_features=4, out_features=2, bias=True)
    return layer

In [3]:
# Two separately constructed replicas of the same architecture.
models = [make_model() for _ in range(2)]
# Adam optimiser bound to the first replica's parameters only.
opt = optim.Adam(models[0].parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# One dummy (input, target) batch per replica: 5 samples, 4 features in, 2 values out.
x1 = torch.randn(5, 4)
y1 = torch.randn(5, 2)
x2 = torch.randn(5, 4)
y2 = torch.randn(5, 2)

In [4]:
# Dump each replica's parameters: metadata line, current values, then the gradient.
for replica in models:
    print(f"=={replica}===")
    for pname, tensor in replica.named_parameters():
        print(f"  {pname}: shape {tensor.shape}, requires_grad={tensor.requires_grad}")
        print(tensor.data, tensor.grad, sep="\n")


    # print(m.weight)
    # print(m.bias)
    # print(m.weight.grad)
    # print(m.bias.grad)


==Linear(in_features=4, out_features=2, bias=True)===
  weight: shape torch.Size([2, 4]), requires_grad=True
tensor([[ 0.3823,  0.4150, -0.1171,  0.4593],
        [-0.1096,  0.1009, -0.2434,  0.2936]])
None
  bias: shape torch.Size([2]), requires_grad=True
tensor([ 0.4408, -0.3668])
None
==Linear(in_features=4, out_features=2, bias=True)===
  weight: shape torch.Size([2, 4]), requires_grad=True
tensor([[ 0.4346,  0.0936,  0.3694,  0.0677],
        [ 0.2411, -0.0706,  0.3854,  0.0739]])
None
  bias: shape torch.Size([2]), requires_grad=True
tensor([-0.2334,  0.1274])
None


In [5]:
# Run one forward/backward pass per replica on its own batch so .grad gets populated.
for replica, batch in zip(models, ((x1, y1), (x2, y2))):
    inputs, targets = batch
    loss = loss_fn(replica(inputs), targets)
    loss.backward()

In [6]:
# Same dump as before the backward pass; the last line per parameter is its gradient.
for replica in models:
    print(f"=={replica}===")
    for pname, tensor in replica.named_parameters():
        print(f"  {pname}: shape {tensor.shape}, requires_grad={tensor.requires_grad}")
        for view in (tensor.data, tensor.grad):
            print(view)

==Linear(in_features=4, out_features=2, bias=True)===
  weight: shape torch.Size([2, 4]), requires_grad=True
tensor([[ 0.3823,  0.4150, -0.1171,  0.4593],
        [-0.1096,  0.1009, -0.2434,  0.2936]])
tensor([[-0.5501,  0.8377, -0.0246,  1.3850],
        [-0.7417, -0.0662, -0.2327,  0.9729]])
  bias: shape torch.Size([2]), requires_grad=True
tensor([ 0.4408, -0.3668])
tensor([ 0.9726, -0.4391])
==Linear(in_features=4, out_features=2, bias=True)===
  weight: shape torch.Size([2, 4]), requires_grad=True
tensor([[ 0.4346,  0.0936,  0.3694,  0.0677],
        [ 0.2411, -0.0706,  0.3854,  0.0739]])
tensor([[ 0.0956,  0.0095, -0.3197,  0.0189],
        [-0.1313, -0.1836,  0.0369, -0.1920]])
  bias: shape torch.Size([2]), requires_grad=True
tensor([-0.2334,  0.1274])
tensor([-0.2163, -0.3139])


In [7]:
# Walk the first replica's parameters to see the order parameters() yields them in.
for idx, tensor in enumerate(models[0].parameters()):
    print(f"i: {idx} params: {tensor}")

i: 0 params: Parameter containing:
tensor([[ 0.3823,  0.4150, -0.1171,  0.4593],
        [-0.1096,  0.1009, -0.2434,  0.2936]], requires_grad=True)
i: 1 params: Parameter containing:
tensor([ 0.4408, -0.3668], requires_grad=True)


In [8]:
# Number of parameter tensors in the first replica.
num_params = sum(1 for _ in models[0].parameters())
print(num_params)

# Number of replicas taking part in the averaging.
num_models = len(models)

2


In [9]:
# Manual gradient averaging across replicas.
# models_with_parameters[j][i] is the i-th parameter of the j-th replica.
models_with_parameters = [list(m.parameters()) for m in models]

for i in range(num_params):

    # Start from a *copy* of replica 0's gradient. Binding the name directly to
    # `.grad` would alias it, so the in-place `+=` / `/=` below would silently
    # overwrite replica 0's own gradient while it is still being read.
    new_grad = models_with_parameters[0][i].grad.clone()
    for j in range(1, num_models):
        new_grad += models_with_parameters[j][i].grad

    new_grad /= num_models

    # Hand every replica its own copy of the average. Sharing one tensor object
    # between replicas would let any later in-place edit of one replica's
    # gradient leak into all the others.
    for j in range(num_models):
        models_with_parameters[j][i].grad = new_grad.clone()




In [10]:
# Dump the replicas once more to inspect their values and gradients after averaging.
for replica in models:
    print(f"=={replica}===")
    for pname, tensor in dict(replica.named_parameters()).items():
        print(f"  {pname}: shape {tensor.shape}, requires_grad={tensor.requires_grad}")
        print(tensor.data)
        print(tensor.grad)

==Linear(in_features=4, out_features=2, bias=True)===
  weight: shape torch.Size([2, 4]), requires_grad=True
tensor([[ 0.3823,  0.4150, -0.1171,  0.4593],
        [-0.1096,  0.1009, -0.2434,  0.2936]])
tensor([[-0.2273,  0.4236, -0.1722,  0.7019],
        [-0.4365, -0.1249, -0.0979,  0.3905]])
  bias: shape torch.Size([2]), requires_grad=True
tensor([ 0.4408, -0.3668])
tensor([ 0.3782, -0.3765])
==Linear(in_features=4, out_features=2, bias=True)===
  weight: shape torch.Size([2, 4]), requires_grad=True
tensor([[ 0.4346,  0.0936,  0.3694,  0.0677],
        [ 0.2411, -0.0706,  0.3854,  0.0739]])
tensor([[-0.2273,  0.4236, -0.1722,  0.7019],
        [-0.4365, -0.1249, -0.0979,  0.3905]])
  bias: shape torch.Size([2]), requires_grad=True
tensor([-0.2334,  0.1274])
tensor([ 0.3782, -0.3765])


In [11]:
# clone() yields an independent copy, so the in-place update of x must leave y unchanged.
x = torch.randn(2, 4)
y = x.clone()

print(x)
x.add_(1)  # in-place add on the same tensor object

print(x, y, sep="\n")

tensor([[-1.6270, -1.3951, -0.2387, -0.5050],
        [-2.4752, -0.9316, -0.1335,  0.3415]])
tensor([[-0.6270, -0.3951,  0.7613,  0.4950],
        [-1.4752,  0.0684,  0.8665,  1.3415]])
tensor([[-1.6270, -1.3951, -0.2387, -0.5050],
        [-2.4752, -0.9316, -0.1335,  0.3415]])


## Stack Approach - Basic Example

In [12]:
# torch.stack joins a list of same-shaped tensors along a brand-new dimension.

tensor1 = torch.tensor([1.0, 2.0, 3.0])
tensor2 = torch.tensor([4.0, 5.0, 6.0])
tensor3 = torch.tensor([7.0, 8.0, 9.0])

print("Original tensors:")
print(f"tensor1: {tensor1}")
print(f"tensor2: {tensor2}")
print(f"tensor3: {tensor3}")

# Stacking along dimension 0 turns three length-3 vectors into one [3, 3] tensor.
stacked = torch.stack([tensor1, tensor2, tensor3])
print(f"\nStacked shape: {stacked.shape}")  # Should be [3, 3]
print(f"Stacked tensor:\n{stacked}")

# Reducing over dim 0 averages element-wise across the stacked tensors.
averaged = stacked.mean(dim=0)
print(f"\nAveraged (mean along dim 0): {averaged}")

# Derive the manual check from the inputs instead of printing a hard-coded
# answer, so a wrong `averaged` is actually caught.
manual_check = (tensor1 + tensor2 + tensor3) / 3
assert torch.allclose(averaged, manual_check), "stack().mean(dim=0) disagrees with the manual average"
print(f"Manual check: [(1+4+7)/3, (2+5+8)/3, (3+6+9)/3] = {manual_check}")

Original tensors:
tensor1: tensor([1., 2., 3.])
tensor2: tensor([4., 5., 6.])
tensor3: tensor([7., 8., 9.])

Stacked shape: torch.Size([3, 3])
Stacked tensor:
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])

Averaged (mean along dim 0): tensor([4., 5., 6.])
Manual check: [(1+4+7)/3, (2+5+8)/3, (3+6+9)/3] = tensor([4., 5., 6.])


In [13]:
# Reseed before constructing the new replicas.
torch.manual_seed(42)

# Fresh pair of replicas whose gradients will be averaged with the stack approach.
models_stack = [make_model() for _ in range(2)]

# Populate .grad on each replica from its own batch.
for replica, (inputs, targets) in zip(models_stack, ((x1, y1), (x2, y2))):
    loss_fn(replica(inputs), targets).backward()

print(
    "Before averaging:",
    f"Model 0 weight.grad shape: {models_stack[0].weight.grad.shape}",
    f"Model 0 weight.grad[0, :3]: {models_stack[0].weight.grad[0, :3]}",
    f"Model 1 weight.grad[0, :3]: {models_stack[1].weight.grad[0, :3]}",
    f"\nModel 0 bias.grad: {models_stack[0].bias.grad}",
    f"Model 1 bias.grad: {models_stack[1].bias.grad}",
    sep="\n",
)

Before averaging:
Model 0 weight.grad shape: torch.Size([2, 4])
Model 0 weight.grad[0, :3]: tensor([-0.5501,  0.8377, -0.0246])
Model 1 weight.grad[0, :3]: tensor([ 0.0956,  0.0095, -0.3197])

Model 0 bias.grad: tensor([ 0.9726, -0.4391])
Model 1 bias.grad: tensor([-0.2163, -0.3139])


In [14]:
# Stack approach, step 1: pull the weight gradient out of each of the two replicas.
grad0, grad1 = (replica.weight.grad for replica in models_stack)

print("Individual gradients:")
for label, grad in (("grad0", grad0), ("grad1", grad1)):
    print(f"{label} shape: {grad.shape}")

Individual gradients:
grad0 shape: torch.Size([2, 4])
grad1 shape: torch.Size([2, 4])


In [15]:
# Step 2: stack the two gradients along a new leading dimension (one slot per replica).
stacked_grads = torch.stack((grad0, grad1), dim=0)
print(f"\nStacked shape: {stacked_grads.shape}")
# Each slice along dim 0 should reproduce the gradient that was stacked into it.
print(
    f"Stacked[0] is grad0: {torch.equal(stacked_grads[0], grad0)}",
    f"Stacked[1] is grad1: {torch.equal(stacked_grads[1], grad1)}",
    sep="\n",
)




Stacked shape: torch.Size([2, 2, 4])
Stacked[0] is grad0: True
Stacked[1] is grad1: True


In [28]:
# Display the whole stacked tensor; each slice along the first dimension is one stacked gradient.
stacked_grads

tensor([[[-0.5501,  0.8377, -0.0246,  1.3850],
         [-0.7417, -0.0662, -0.2327,  0.9729]],

        [[ 0.0956,  0.0095, -0.3197,  0.0189],
         [-0.1313, -0.1836,  0.0369, -0.1920]]])

In [16]:
# Step 3: reduce over dim 0 (the replica axis) to get the element-wise mean gradient.
avg_grad = torch.mean(stacked_grads, dim=0)
print(f"\nAveraged gradient shape: {avg_grad.shape}")
print(f"Averaged gradient[0, :3]: {avg_grad[0, :3]}")

# Cross-check against the plain arithmetic mean of the two gradients.
manual_avg = torch.div(grad0 + grad1, 2)
print(f"\nManual average[0, :3]: {manual_avg[0, :3]}")
print(f"Results match: {torch.allclose(avg_grad, manual_avg)}")


Averaged gradient shape: torch.Size([2, 4])
Averaged gradient[0, :3]: tensor([-0.2273,  0.4236, -0.1722])

Manual average[0, :3]: tensor([-0.2273,  0.4236, -0.1722])
Results match: True


In [17]:
# Complete implementation using stack approach
def average_gradients_stack(models):
    """Average gradients across models using torch.stack.

    For every parameter position (in ``parameters()`` order of ``models[0]``)
    the ``.grad`` tensors of all models are stacked along a new leading
    dimension, reduced with ``mean(dim=0)``, and written back to each model.

    Args:
        models: Sequence of modules with the same parameter layout as
            ``models[0]``. Their gradients are read from ``.grad``.

    Returns:
        None. The ``.grad`` attribute of each parameter is replaced.

    Notes:
        * An empty ``models`` sequence is a no-op.
        * Parameter positions where no model has a gradient are left untouched.
        * Each model receives its own copy of the averaged tensor, so a later
          in-place change to one model's gradient cannot affect another model.
    """
    if not models:
        return

    # Build each model's parameter list once instead of re-materialising it
    # for every parameter index.
    per_model_params = [list(m.parameters()) for m in models]

    for param_idx in range(len(per_model_params[0])):
        # The same-position parameter from every model.
        replicas = [params[param_idx] for params in per_model_params]
        grads = [p.grad for p in replicas]

        # Nothing to average when no model produced a gradient here.
        if all(g is None for g in grads):
            continue

        # Stack into [num_models, *param_shape] and average over the model axis.
        avg_grad = torch.stack(grads).mean(dim=0)

        # Assign an independent copy to every model (no shared storage).
        for p in replicas:
            p.grad = avg_grad.clone()

# Run the stack-based averaging on the fresh replicas, then show both replicas' gradients.
average_gradients_stack(models_stack)

print(
    "After averaging (stack approach):",
    f"Model 0 weight.grad[0, :3]: {models_stack[0].weight.grad[0, :3]}",
    f"Model 1 weight.grad[0, :3]: {models_stack[1].weight.grad[0, :3]}",
    f"\nModel 0 bias.grad: {models_stack[0].bias.grad}",
    f"Model 1 bias.grad: {models_stack[1].bias.grad}",
    "\nBoth models now have identical gradients!",
    sep="\n",
)

After averaging (stack approach):
Model 0 weight.grad[0, :3]: tensor([-0.2273,  0.4236, -0.1722])
Model 1 weight.grad[0, :3]: tensor([-0.2273,  0.4236, -0.1722])

Model 0 bias.grad: tensor([ 0.3782, -0.3765])
Model 1 bias.grad: tensor([ 0.3782, -0.3765])

Both models now have identical gradients!


In [25]:
# Reseed and rebuild the replicas so the next averaging routine starts from
# un-averaged gradients.
torch.manual_seed(42)

models_stack = [make_model() for _ in range(2)]

# One forward/backward pass per replica on its own batch.
batch_per_replica = ((x1, y1), (x2, y2))
for replica, (inputs, targets) in zip(models_stack, batch_per_replica):
    prediction = replica(inputs)
    loss_fn(prediction, targets).backward()

print("Before averaging:")
print(f"Model 0 weight.grad shape: {models_stack[0].weight.grad.shape}")
print(
    f"Model 0 weight.grad[0, :3]: {models_stack[0].weight.grad[0, :3]}",
    f"Model 1 weight.grad[0, :3]: {models_stack[1].weight.grad[0, :3]}",
    sep="\n",
)
print(
    f"\nModel 0 bias.grad: {models_stack[0].bias.grad}",
    f"Model 1 bias.grad: {models_stack[1].bias.grad}",
    sep="\n",
)

Before averaging:
Model 0 weight.grad shape: torch.Size([2, 4])
Model 0 weight.grad[0, :3]: tensor([-0.5501,  0.8377, -0.0246])
Model 1 weight.grad[0, :3]: tensor([ 0.0956,  0.0095, -0.3197])

Model 0 bias.grad: tensor([ 0.9726, -0.4391])
Model 1 bias.grad: tensor([-0.2163, -0.3139])


In [27]:
# Complete implementation using manual accumulation (no torch.stack)
def average_gradients_stack_2(models):
    """Average gradients across models by accumulating into a zero tensor.

    Despite the name, this variant does not call ``torch.stack``. For every
    parameter position (in ``parameters()`` order of ``models[0]``) it sums
    the ``.grad`` tensors of all models into an accumulator created with
    ``torch.zeros_like``, divides by the number of models, and writes the
    result back to each model.

    Args:
        models: Sequence of modules with the same parameter layout as
            ``models[0]``. Their gradients are read from ``.grad``.

    Returns:
        None. The ``.grad`` attribute of each parameter is replaced.

    Notes:
        * An empty ``models`` sequence is a no-op.
        * Parameter positions where no model has a gradient are left untouched.
        * Each model receives its own copy of the averaged tensor, so a later
          in-place change to one model's gradient cannot affect another model.
    """
    if not models:
        return

    # Build each model's parameter list once instead of re-materialising it
    # for every parameter index.
    per_model_params = [list(m.parameters()) for m in models]

    for param_idx in range(len(per_model_params[0])):
        # The same-position parameter from every model.
        replicas = [params[param_idx] for params in per_model_params]

        # Nothing to average when no model produced a gradient here.
        if all(p.grad is None for p in replicas):
            continue

        # zeros_like gives an accumulator matching the first gradient's shape and dtype.
        init = torch.zeros_like(replicas[0].grad)

        for p in replicas:
            init += p.grad

        init /= len(models)

        # Assign an independent copy to every model (no shared storage).
        for p in replicas:
            p.grad = init.clone()

# Apply to our models
average_gradients_stack_2(models_stack)

print("After averaging (stack approach):")
print(f"Model 0 weight.grad[0, :3]: {models_stack[0].weight.grad[0, :3]}")
print(f"Model 1 weight.grad[0, :3]: {models_stack[1].weight.grad[0, :3]}")
print(f"\nModel 0 bias.grad: {models_stack[0].bias.grad}")
print(f"Model 1 bias.grad: {models_stack[1].bias.grad}")

# Compare the gradient of every parameter (weight and bias), not only the bias,
# before reporting that the two models agree.
is_equal = all(
    torch.equal(p0.grad, p1.grad)
    for p0, p1 in zip(models_stack[0].parameters(), models_stack[1].parameters())
)
print(f"models have identical gradients?: {is_equal}")

After averaging (stack approach):
Model 0 weight.grad[0, :3]: tensor([-0.2273,  0.4236, -0.1722])
Model 1 weight.grad[0, :3]: tensor([-0.2273,  0.4236, -0.1722])

Model 0 bias.grad: tensor([ 0.3782, -0.3765])
Model 1 bias.grad: tensor([ 0.3782, -0.3765])
models have identical gradients?: True


## torch.equal() Example

In [20]:
# torch.equal(t1, t2) returns one Python bool: True only when both the shapes
# and every element match.

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
c = torch.tensor([1, 2, 4])  # last element differs from a
d = torch.tensor([1, 2])     # fewer elements than a

print("Tensors:", f"a = {a}", f"b = {b}", f"c = {c}", f"d = {d}", sep="\n")

print("\nComparisons:")
print(
    f"torch.equal(a, b) = {torch.equal(a, b)}",  # same shape, same values
    f"torch.equal(a, c) = {torch.equal(a, c)}",  # same shape, one value differs
    f"torch.equal(a, d) = {torch.equal(a, d)}",  # shapes differ
    sep="\n",
)

# The check is not limited to 1-D tensors.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
z = x.clone()  # independent copy holding the same values

print(
    f"\ntorch.equal(x, y) = {torch.equal(x, y)}",
    f"torch.equal(x, z) = {torch.equal(x, z)}",
    sep="\n",
)

Tensors:
a = tensor([1, 2, 3])
b = tensor([1, 2, 3])
c = tensor([1, 2, 4])
d = tensor([1, 2])

Comparisons:
torch.equal(a, b) = True
torch.equal(a, c) = False
torch.equal(a, d) = False

torch.equal(x, y) = True
torch.equal(x, z) = True


In [21]:
# `==` compares element by element and yields a boolean tensor;
# torch.equal collapses the whole comparison into a single bool.

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 4])

print("Using == (element-wise):")
print(f"a == b = {torch.eq(a, b)}")

print("\nUsing torch.equal (whole tensor):")
print(f"torch.equal(a, b) = {torch.equal(a, b)}")

# Reducing the element-wise result with .all() gives the same verdict here.
print(f"\n(a == b).all() = {torch.eq(a, b).all()}")

Using == (element-wise):
a == b = tensor([ True,  True, False])

Using torch.equal (whole tensor):
torch.equal(a, b) = False

(a == b).all() = False


## torch.zeros() vs torch.zeros_like()

In [22]:
# torch.zeros(...)      -> the caller spells out the shape.
# torch.zeros_like(t)   -> the shape is taken from an existing tensor t.

# Tensor whose shape and dtype the zero tensors should match.
reference = torch.tensor([[1.0, 2.0, 3.0],
                          [4.0, 5.0, 6.0]])
print(
    f"Reference tensor shape: {reference.shape}",
    f"Reference tensor dtype: {reference.dtype}",
    f"Reference:\n{reference}\n",
    sep="\n",
)

# Method 1: explicit shape passed to torch.zeros.
zeros_manual = torch.zeros(2, 3)
print(
    "torch.zeros(2, 3):",
    f"  Shape: {zeros_manual.shape}",
    f"  Dtype: {zeros_manual.dtype}",
    f"  Values:\n{zeros_manual}\n",
    sep="\n",
)

# Method 2: shape and dtype inherited from `reference`.
zeros_auto = torch.zeros_like(reference)
print(
    "torch.zeros_like(reference):",
    f"  Shape: {zeros_auto.shape}",
    f"  Dtype: {zeros_auto.dtype}",
    f"  Values:\n{zeros_auto}",
    sep="\n",
)

Reference tensor shape: torch.Size([2, 3])
Reference tensor dtype: torch.float32
Reference:
tensor([[1., 2., 3.],
        [4., 5., 6.]])

torch.zeros(2, 3):
  Shape: torch.Size([2, 3])
  Dtype: torch.float32
  Values:
tensor([[0., 0., 0.],
        [0., 0., 0.]])

torch.zeros_like(reference):
  Shape: torch.Size([2, 3])
  Dtype: torch.float32
  Values:
tensor([[0., 0., 0.],
        [0., 0., 0.]])


In [24]:
# torch.zeros also accepts an existing shape object: this builds a zero tensor with reference's shape.
torch.zeros(reference.shape)

tensor([[0., 0., 0.],
        [0., 0., 0.]])

In [None]:
# zeros_like() copies more than the shape: the dtype follows the source tensor too.

# int32 source -> int32 zeros
int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int32)
zeros_int = torch.zeros_like(int_tensor)
print(
    f"int_tensor dtype: {int_tensor.dtype}",
    f"zeros_like dtype: {zeros_int.dtype}",
    f"zeros_int: {zeros_int}\n",
    sep="\n",
)

# float16 (half precision) source -> float16 zeros
half_tensor = torch.tensor([[1.0, 2.0]], dtype=torch.float16)
zeros_half = torch.zeros_like(half_tensor)
print(
    f"half_tensor dtype: {half_tensor.dtype}",
    f"zeros_like dtype: {zeros_half.dtype}",
    sep="\n",
)

# Plain zeros() without a dtype falls back to float32, which does not match half_tensor.
zeros_manual_wrong = torch.zeros(1, 2)
print(f"\ntorch.zeros(1, 2) dtype: {zeros_manual_wrong.dtype}")

# Matching it requires passing the dtype explicitly.
zeros_manual_correct = torch.zeros(1, 2, dtype=torch.float16)
print(f"torch.zeros(1, 2, dtype=torch.float16) dtype: {zeros_manual_correct.dtype}")

In [None]:
# Why zeros_like() suits gradient averaging: the accumulator inherits the
# gradient's shape and dtype without either being written out by hand.

grad_example = models_stack[0].weight.grad
print(
    f"Gradient shape: {grad_example.shape}",
    f"Gradient dtype: {grad_example.dtype}",
    sep="\n",
)

# Accumulator created from the gradient itself.
accumulator = torch.zeros_like(grad_example)
print(
    f"\nAccumulator (zeros_like) shape: {accumulator.shape}",
    f"Accumulator (zeros_like) dtype: {accumulator.dtype}",
    sep="\n",
)

# In-place add into the accumulator; grad_example itself is only read.
accumulator.add_(grad_example)
print("\nAfter adding gradient:")
print(f"Accumulator[0, :3]: {accumulator[0, :3]}")

# The hand-written alternative would hard-code both properties:
# accumulator = torch.zeros(2, 4)  # breaks if the shape changes; dtype is not matched