In [1]:
import torch
import torch.nn as nn

# Sanity check: gradients flow from a Linear layer's output back to its input.
torch.manual_seed(0)  # reproducible weights/input so Restart & Run All gives the same numbers

# Create a simple linear layer
linear_layer = nn.Linear(10, 1)  # Example dimensions: 10 input features -> 1 output
x_test = torch.randn(5, 10, requires_grad=True)  # Example input: batch of 5

# Forward pass
output_test = linear_layer(x_test)

# Check requires_grad status
print("Output requires_grad:", output_test.requires_grad)
assert output_test.requires_grad, "output should be attached to the autograd graph"

# Try a dummy backward pass
output_test.sum().backward()
print("Gradient for x_test:", x_test.grad)
# d/dx of sum(x @ W.T + b) is W's single row, repeated for every batch element.
assert x_test.grad is not None, "backward() should populate x_test.grad"
assert torch.allclose(x_test.grad, linear_layer.weight.expand_as(x_test))


Output requires_grad: True
Gradient for x_test: tensor([[-0.2578,  0.2976, -0.0466,  0.1073,  0.0302,  0.2783, -0.0696, -0.1360,
         -0.1836,  0.0608],
        [-0.2578,  0.2976, -0.0466,  0.1073,  0.0302,  0.2783, -0.0696, -0.1360,
         -0.1836,  0.0608],
        [-0.2578,  0.2976, -0.0466,  0.1073,  0.0302,  0.2783, -0.0696, -0.1360,
         -0.1836,  0.0608],
        [-0.2578,  0.2976, -0.0466,  0.1073,  0.0302,  0.2783, -0.0696, -0.1360,
         -0.1836,  0.0608],
        [-0.2578,  0.2976, -0.0466,  0.1073,  0.0302,  0.2783, -0.0696, -0.1360,
         -0.1836,  0.0608]])


In [2]:
class Preprocessor(torch.nn.Module):
    """Affine normalization of values plus a separate scale for distances.

    The statistics are registered as persistent float32 buffers, so they are
    saved in ``state_dict``, follow the module across ``.to(...)`` calls, and
    are not trainable parameters.

    Args:
        mean: Offset subtracted by :meth:`normalize`. Scalar, sequence, array or tensor.
        std: Scale divided out by :meth:`normalize`; must broadcast against the inputs.
        dist_std: Scale divided out by :meth:`normalize_dist`.
    """

    def __init__(self, mean=0., std=1., dist_std=1.):
        super().__init__()
        self.register_buffer('mean', self._as_buffer(mean), persistent=True)
        self.register_buffer('std', self._as_buffer(std), persistent=True)
        self.register_buffer('dist_std', self._as_buffer(dist_std), persistent=True)

    @staticmethod
    def _as_buffer(value):
        """Convert ``value`` into a detached, independently-owned float32 tensor.

        ``torch.tensor(t)`` on an existing tensor emits a UserWarning; ``as_tensor``
        accepts tensors silently, and ``detach().clone()`` ensures the buffer
        neither aliases the caller's tensor nor tracks gradients through it.
        """
        return torch.as_tensor(value, dtype=torch.float32).detach().clone()

    def normalize(self, x):
        """Return ``(x - mean) / std``."""
        return (x - self.mean) / self.std

    def normalize_dist(self, d):
        """Return ``d / dist_std`` (distances are scaled but not shifted)."""
        return d / self.dist_std

    def unnormalize(self, x):
        """Invert :meth:`normalize`: return ``x * std + mean``."""
        return x * self.std + self.mean

    def unnormalize_dist(self, d):
        """Invert :meth:`normalize_dist`: return ``d * dist_std``."""
        return d * self.dist_std

    def get_params(self):
        """Return the statistics as a dict of the live buffers (not copies)."""
        return dict(
            mean=self.mean,
            std=self.std,
            dist_std=self.dist_std
        )

In [3]:
# Default statistics (mean 0, std 1, dist_std 1) make every transform an identity.
proc = Preprocessor(mean=0., std=1., dist_std=1.)

In [5]:
x_test = torch.randn(5, 10, requires_grad=True)  # Example input

# Forward pass
output_test = proc.normalize(x_test)

# Check requires_grad status
print("Output requires_grad:", output_test.requires_grad)
assert output_test.requires_grad, "normalize() should keep the input on the autograd graph"

# Dummy backward pass: d/dx of sum((x - mean) / std) is 1 / std broadcast to x's shape.
output_test.sum().backward()
assert x_test.grad is not None, "backward() should populate x_test.grad"
assert torch.allclose(x_test.grad, (1 / proc.std).expand_as(x_test))


Output requires_grad: True
