In [35]:
import torch
import torch.nn as nn
import copy
torch.manual_seed(0)

<torch._C.Generator at 0x7fd400058eb0>

In [43]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv0 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        
        
    def forward(self, x):
        
        aux = x.detach().clone()
        aux0 = self.conv0(aux)
        x = self.conv0(x)
        
        aux = x.detach().clone()
        aux1 = self.conv1(aux)
        x = self.conv1(x)
        
        aux = x.detach().clone()
        aux2 = self.conv2(aux)
        x = self.conv2(x)
        
        
        aux_losses = [torch.mean(a) for a in [aux0, aux1, aux2]]
        primary_loss = torch.mean(x)
        
        return primary_loss, aux_losses
    
    def print_norms(self):
        norms = ""
        for i, conv in enumerate([self.conv0, self.conv1, self.conv2]):
            norms += "  Conv{}: ".format(i) + str(torch.norm(conv.weight, p='fro').item())
        print(norms)
    
    def print_grad_norms(self):
        norms = ""
        for i, conv in enumerate([self.conv0, self.conv1, self.conv2]):
            norm = torch.norm(conv.weight.grad, p='fro').item() if conv.weight.grad is not None else "None"
            norms += "  Conv{}: ".format(i) + str(norm)
        print(norms)
            
        

In [44]:
input = torch.rand((128, 3, 32, 32))

In [45]:
net1 = Net()
net2 = copy.deepcopy(net1)

optimizer1 = torch.optim.SGD(net1.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
optimizer2 = torch.optim.SGD(net2.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)

In [46]:

for i in range(3):
    primary_loss1, aux_losses1 = net1(input)
    primary_loss2, aux_losses2 = net2(input)
    
    
    trained_loss1 = torch.mean(torch.stack(aux_losses1))
    optimizer1.zero_grad()
    trained_loss1.backward()
    optimizer1.step()
    
    
    num_losses = len(aux_losses2)
    for j, loss in enumerate(aux_losses2):
        loss = loss / num_losses
        
        print("Stepping loss {}".format(j))
        
        optimizer2.zero_grad()
        loss.backward()
        optimizer2.step()
        
    print("NET 1:")
    net1.print_norms()
    net1.print_grad_norms()
    
    print("NET 2:")
    net2.print_norms()
    net2.print_grad_norms()
    
    
    

Stepping loss {}
Stepping loss {}
Stepping loss {}
NET 1:
  Conv0: 2.3477413654327393  Conv1: 3.243257761001587  Conv2: 4.63417387008667
  Conv0: 0.20765216648578644  Conv1: 0.18426339328289032  Conv2: 0.09687289595603943
NET 2:
  Conv0: 2.3463075160980225  Conv1: 3.2427210807800293  Conv2: 4.63417387008667
  Conv0: 0.0  Conv1: 0.0  Conv2: 0.09687289595603943
Stepping loss {}
Stepping loss {}


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [32, 16, 3, 3]] is at version 3; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).