Trying to figure out how backpropagation works when parts of the model are no_grad

In [25]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
torch.autograd.set_detect_anomaly(True)
rng = np.random.default_rng()

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
dtype = torch.double

Using cpu device


In [26]:
class FeatExtract(nn.Module):
    def __init__(self): 
        super().__init__()  
        self.layers = nn.Sequential(nn.Conv1d(6, 10, 3, padding='same'),
                                    nn.ReLU(),
                                    nn.Conv1d(10, 12, 3, padding='same'),
                                    nn.ReLU())
    def forward(self, data):
        features = self.layers(data)
        return features

In [27]:
class Classifier(nn.Module):
    def __init__(self): 
        super().__init__()  
        self.layers = nn.Sequential(nn.LazyLinear(7),
                                   nn.ReLU(),
                                   nn.Linear(7,1),
                                   nn.ReLU())
    def forward(self, features):
        pred = self.layers(features)
        return pred

In [28]:
# training loop
def train(dataloader, feat_extractor, classifier, loss_fn, optimizer_feat, optimizer_cl):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        features = feat_extractor(X)
        pred = classifier(features.detach())
        loss = loss_fn(pred, y)

        # no minibatching for this example but must do for real model
        # train the discriminator
        print("training discriminator")
        optimizer_cl.zero_grad()
        loss = loss_fn(pred,y)
        loss.backward()
        optimizer_cl.step()
        for name, param in classifier.named_parameters():
            if param.requires_grad:
                print(name, param.data)
        # Backpropagation- feat extractor
        print("training feature extractor")
        optimizer_feat.zero_grad()
        pred = classifier(features)
        reverse_loss = loss_fn(pred, 1-y)  # loss with reversed labels (doesnt make sense with some 'real' data mixed in)
        reverse_loss.backward()
        optimizer_feat.step() 
        for name, param in classifier.named_parameters():
            if param.requires_grad:
                print(name, param.data)
        

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [29]:
# mock data
dummy_data_np = rng.random((15,6,1))
dummy_cats_np = rng.integers(0, 1, size=(15,1), endpoint=True)
dummy_data = torch.from_numpy(dummy_data_np).to(device, dtype=dtype)
dummy_cats = torch.from_numpy(dummy_cats_np).to(device, dtype=torch.long)
dummy_dset = TensorDataset(dummy_data, dummy_cats)
dummy_dl = DataLoader(dummy_dset, batch_size=1)

In [30]:
# models
feat = FeatExtract().to(device=device, dtype=torch.double)
clas = Classifier().to(device=device, dtype=torch.double)

In [31]:
# loss function and optimizers
loss_fn = nn.CrossEntropyLoss()
opt_feat = torch.optim.Adam(feat.parameters(),lr=0.01)
opt_clas = torch.optim.Adam(clas.parameters(),lr=0.01)

In [32]:
# train
train(dummy_dl, feat, clas, loss_fn, opt_feat, opt_clas)

training discriminator
layers.0.weight tensor([[-0.9151],
        [-0.2637],
        [ 0.9647],
        [ 0.6351],
        [ 0.4071],
        [-0.8933],
        [ 0.6762]], dtype=torch.float64)
layers.0.bias tensor([ 0.4316, -0.4340, -0.7893,  0.8036,  0.7087, -0.6917, -0.6832],
       dtype=torch.float64)
layers.2.weight tensor([[ 0.3641,  0.3440,  0.2790,  0.0856, -0.2187,  0.1508, -0.2734]],
       dtype=torch.float64)
layers.2.bias tensor([0.3318], dtype=torch.float64)
training feature extractor
layers.0.weight tensor([[-0.9151],
        [-0.2637],
        [ 0.9647],
        [ 0.6351],
        [ 0.4071],
        [-0.8933],
        [ 0.6762]], dtype=torch.float64)
layers.0.bias tensor([ 0.4316, -0.4340, -0.7893,  0.8036,  0.7087, -0.6917, -0.6832],
       dtype=torch.float64)
layers.2.weight tensor([[ 0.3641,  0.3440,  0.2790,  0.0856, -0.2187,  0.1508, -0.2734]],
       dtype=torch.float64)
layers.2.bias tensor([0.3318], dtype=torch.float64)
loss: 2.474764  [    1/   15]
training d