<a href="https://colab.research.google.com/github/MeghanaShanthappa/TensorRT_features/blob/main/Parametrization_regularization_tech.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

torch.manual_seed(42)

# -------------------------------
# Helper: simple training step
# -------------------------------
def train_step(layer, x, target, optimizer):
    optimizer.zero_grad()
    output = layer(x)
    loss = ((output - target) ** 2).mean()
    loss.backward()
    optimizer.step()

    # Determine which weight and gradient to return
    if hasattr(layer, 'parametrizations') and 'weight' in layer.parametrizations:
        # For parametrized layers, return the gradient of the 'original' parameter
        # and the effective weight (layer.weight)
        # FIX: Access .original on the ParametrizationList, not the WeightNorm instance
        param_to_track = layer.parametrizations.weight.original
        grad_val = param_to_track.grad.clone() if param_to_track.grad is not None else None
        weight_val = layer.weight.clone() # This is the effective weight (output of parametrization)
    else:
        # For standard layers, return the gradient of layer.weight and layer.weight itself
        grad_val = layer.weight.grad.clone() if layer.weight.grad is not None else None
        weight_val = layer.weight.clone()

    return grad_val, weight_val

# -------------------------------
# Inputs
# -------------------------------
# Shared synthetic regression data: 64 samples, 5 input features, 3 targets.
# Statement order matters: x, target and the nn.Linear initialisations below
# all draw from the torch RNG seeded at the top of the cell. x and target are
# reused by the weight-normalized run further down.
x = torch.randn(64, 5)
target = torch.randn(64, 3)

# -------------------------------
# 1. Standard Linear layer
# -------------------------------
# Baseline: an unconstrained Linear(5 -> 3) trained with plain SGD.
layer_plain = nn.Linear(5, 3)
optimizer_plain = torch.optim.SGD(layer_plain.parameters(), lr=0.1)

# One MSE step; returns the weight gradient and the weight after the update.
grad_plain, weight_plain = train_step(layer_plain, x, target, optimizer_plain)

print("Standard Linear Layer")
print("Gradient:\n", grad_plain)
print("Weight after step:\n", weight_plain)

# -------------------------------
# 2. Weight-normalized Linear layer
# -------------------------------
class WeightNorm(nn.Module):
    """Parametrization that rescales a weight tensor to (near) unit L2 norm.

    Despite the name, this has no learnable gain: the output is simply
    ``W / (||W||_dim + eps)``. For an ``nn.Linear`` weight of shape
    ``(out_features, in_features)`` with the default ``dim=1``, each output
    row is normalized.

    Args:
        dim: dimension(s) along which the L2 norm is taken (default ``1``).
        eps: added to the norm to avoid division by zero (default ``1e-6``).
            Because of it, rows come out with norm slightly below 1, and
            all-zero rows stay zero.
    """

    def __init__(self, dim=1, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, W):
        norm = W.norm(dim=self.dim, keepdim=True) + self.eps
        return W / norm

    def extra_repr(self):
        return f"dim={self.dim}, eps={self.eps}"

# Same architecture as the baseline, but layer_wn.weight is now produced by
# WeightNorm from the stored tensor layer_wn.parametrizations.weight.original
# on every access.
layer_wn = nn.Linear(5, 3)
parametrize.register_parametrization(layer_wn, "weight", WeightNorm())
optimizer_wn = torch.optim.SGD(layer_wn.parameters(), lr=0.1)

# Reuses the same x / target as the baseline. grad_wn is the gradient of the
# stored `original` tensor; weight_wn is the normalized weight after the step.
grad_wn, weight_wn = train_step(layer_wn, x, target, optimizer_wn)

print("\nWeight-normalized Linear Layer")
print("Gradient:\n", grad_wn)
print("Weight after step:\n", weight_wn)

# Effective weight actually used in forward pass
# Manually applies the parametrization to the stored tensor to show it
# reproduces weight_wn. NOTE(review): [0] applies only the first registered
# parametrization; reading layer_wn.weight is the canonical way to get the
# effective weight if more are ever registered.
effective_weight = layer_wn.parametrizations.weight[0](layer_wn.parametrizations.weight.original)
print("\nEffective weight (after normalization):\n", effective_weight)

Standard Linear Layer
Gradient:
 tensor([[ 0.1911, -0.2900, -0.0938, -0.2961,  0.2978],
        [-0.2728,  0.1792, -0.1200,  0.0270,  0.2592],
        [ 0.0795,  0.3142,  0.1466, -0.1204,  0.2307]])
Weight after step:
 tensor([[ 0.1201, -0.3122, -0.2736, -0.3422,  0.3599],
        [-0.3962,  0.3194, -0.0030, -0.0547,  0.2538],
        [-0.0493,  0.2491,  0.3087, -0.3762,  0.1490]],
       grad_fn=<CloneBackward0>)

Weight-normalized Linear Layer
Gradient:
 tensor([[ 0.0719, -0.0212,  0.2506, -0.0023,  0.0066],
        [-0.1356, -0.0713, -0.2224,  0.0454, -0.0388],
        [-0.0923,  0.1113,  0.0346,  0.2936, -0.0983]])
Weight after step:
 tensor([[ 0.1986, -0.3675, -0.1419,  0.7334,  0.5172],
        [-0.0051, -0.6511,  0.0659, -0.4867,  0.5787],
        [-0.5055, -0.0837, -0.1980,  0.0968,  0.8300]],
       grad_fn=<CloneBackward0>)

Effective weight (after normalization):
 tensor([[ 0.1986, -0.3675, -0.1419,  0.7334,  0.5172],
        [-0.0051, -0.6511,  0.0659, -0.4867,  0.5787],
  

In [16]:
import torch.nn as nn
import torch.nn.utils.parametrizations as parametrize

m = parametrize.weight_norm(nn.Linear(20, 40), name='weight')
print(m)

print("\nAttributes of m:")
print(dir(m))

print("\nm.parametrizations.weight:")
print(m.parametrizations.weight)

print("\nNamed parameters of m:")
for name, param in m.named_parameters():
    print(f"  {name}: {param.size()}")

# Based on the output of named_parameters, we can then correctly access g and v.
weight_g = m.parametrizations.weight.original0
weight_v = m.parametrizations.weight.original1

print("\nm.weight_g.size():", weight_g.size())
print("m.weight_v.size():", weight_v.size())

ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _WeightNorm()
    )
  )
)

Attributes of m:
['T_destination', '__annotations__', '__call__', '__class__', '__constants__', '__deepcopy__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_