In [35]:
import torch
import torch.nn as nn
import copy
# Seed the global torch RNG so Net() weight initialisation and the torch.rand
# input batch below are reproducible after a kernel restart.
torch.manual_seed(0)

<torch._C.Generator at 0x7fd400058eb0>

In [63]:
class Net(nn.Module):
    """Three-layer conv stack that produces one auxiliary loss per layer.

    Each conv layer is applied twice:
      * on the running activation (the "primary" path), and
      * on a detached copy of that activation (the "auxiliary" path).
    Because the auxiliary input is detached, each auxiliary loss only produces
    gradients for the conv layer that computed it, never for earlier layers.
    """

    def __init__(self):
        super().__init__()

        self.conv0 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

    def _convs(self):
        """Return the conv layers in forward order (single source of truth for the layer list)."""
        return (self.conv0, self.conv1, self.conv2)

    def forward(self, x):
        """Run the primary path and collect one auxiliary loss per layer.

        Args:
            x: input to ``conv0``; must have 3 channels (presumably shaped
               (batch, 3, H, W) -- TODO confirm against callers).

        Returns:
            (primary_loss, aux_losses): ``primary_loss`` is the mean of the final
            conv output; ``aux_losses`` is a list with, for each layer, the mean
            of that layer's output on its detached input.
        """
        aux_losses = []
        for conv in self._convs():
            # Detached copy: this aux loss's gradient stops at `conv` and does
            # not flow back into earlier layers.
            aux = x.detach().clone()
            aux_losses.append(torch.mean(conv(aux)))
            x = conv(x)

        primary_loss = torch.mean(x)
        return primary_loss, aux_losses

    def forward_iterate(self, x):
        """Lazily yield each layer's auxiliary loss, in forward order.

        Yields the same values as ``forward(x)[1]``, one at a time. The primary
        activation for a layer is computed *before* its aux loss is yielded, so
        the input handed to the next layer is already fixed even if the caller
        runs ``optimizer.step()`` between yields.
        """
        for conv in self._convs():
            aux = x.detach().clone()
            aux_loss = torch.mean(conv(aux))
            x = conv(x)
            yield aux_loss

    def print_norms(self):
        """Print the Frobenius norm of every conv layer's weight on one line."""
        norms = ""
        for i, conv in enumerate(self._convs()):
            norms += "  Conv{}: ".format(i) + str(torch.norm(conv.weight, p='fro').item())
        print(norms)

    def print_grad_norms(self):
        """Print the Frobenius norm of every conv weight's gradient ("None" if unset)."""
        norms = ""
        for i, conv in enumerate(self._convs()):
            norm = torch.norm(conv.weight.grad, p='fro').item() if conv.weight.grad is not None else "None"
            norms += "  Conv{}: ".format(i) + str(norm)
        print(norms)
            
        

In [48]:
# Fixed uniform-[0, 1) batch of shape (128, 3, 32, 32), reused for every step
# of the training loop. NOTE: this name shadows the built-in `input()`; later
# cells refer to it by this name.
input = torch.rand(128, 3, 32, 32)

In [64]:
net1 = Net()
# Deep copy so both networks start from exactly the same weights.
net2 = copy.deepcopy(net1)

# The net1/net2 comparison is only meaningful if both optimizers are configured
# identically, so the SGD hyperparameters are defined once and shared.
SGD_KWARGS = dict(lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
optimizer1 = torch.optim.SGD(net1.parameters(), **SGD_KWARGS)
optimizer2 = torch.optim.SGD(net2.parameters(), **SGD_KWARGS)

In [65]:

# Train two identically-initialised copies of Net on their auxiliary losses:
#   net1 -- one optimizer step on the mean of all aux losses (Net.forward)
#   net2 -- one optimizer step per aux loss, layer by layer (Net.forward_iterate)
# Each aux loss is computed on a detached input and so only reaches its own conv
# layer; the two schemes are therefore expected to keep identical weights.

num_iterations = 3
num_losses = 3  # aux losses yielded per forward_iterate pass (one per conv layer)

# Anomaly detection slows down every backward() call. Used as a context manager
# it is active only for this loop, instead of staying on for the whole session.
with torch.autograd.set_detect_anomaly(True):
    for i in range(num_iterations):
        # --- net1: single step on the averaged aux losses ---
        primary_loss1, aux_losses1 = net1(input)

        trained_loss1 = torch.mean(torch.stack(aux_losses1))
        optimizer1.zero_grad()
        trained_loss1.backward()
        optimizer1.step()

        # --- net2: one step per layer-wise aux loss ---
        for j, loss in enumerate(net2.forward_iterate(input)):
            # Same 1/num_losses weighting that torch.mean applies for net1.
            loss = loss / num_losses

            print("Stepping loss {}".format(j))

            # Reset grads to None rather than zero: SGD skips parameters whose
            # .grad is None, so layers this loss does not reach receive no
            # weight-decay or momentum update on this step.
            for p in net2.parameters():
                p.grad = None

            loss.backward()
            optimizer2.step()

        print("NET 1:")
        net1.print_norms()
        net1.print_grad_norms()

        print("NET 2:")
        net2.print_norms()
        net2.print_grad_norms()

        # Check the equivalence explicitly instead of eyeballing printed norms.
        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            assert torch.allclose(p1, p2), "net1/net2 parameters diverged at iteration {}".format(i)
    
    
    

Stepping loss 0
Stepping loss 1
Stepping loss 2
NET 1:
  Conv0: 2.326932191848755  Conv1: 3.2559821605682373  Conv2: 4.590390205383301
  Conv0: 0.20728228986263275  Conv1: 0.18254031240940094  Conv2: 0.10456957668066025
NET 2:
  Conv0: 2.326932191848755  Conv1: 3.2559821605682373  Conv2: 4.590390205383301
  Conv0: None  Conv1: None  Conv2: 0.10456957668066025
Stepping loss 0
Stepping loss 1
Stepping loss 2
NET 1:
  Conv0: 2.327664613723755  Conv1: 3.255280017852783  Conv2: 4.589842319488525
  Conv0: 0.20728228986263275  Conv1: 0.18561454117298126  Conv2: 0.10101107507944107
NET 2:
  Conv0: 2.327664613723755  Conv1: 3.255280017852783  Conv2: 4.589842319488525
  Conv0: None  Conv1: None  Conv2: 0.10101107507944107
Stepping loss 0
Stepping loss 1
Stepping loss 2
NET 1:
  Conv0: 2.3305435180664062  Conv1: 3.2556495666503906  Conv2: 4.58942985534668
  Conv0: 0.20728228986263275  Conv1: 0.19335055351257324  Conv2: 0.10289294272661209
NET 2:
  Conv0: 2.3305435180664062  Conv1: 3.2556495666503

In [None]:
No aux path: 3443
