# Regularization Tutorial
NeuroVisKit aims to reduce the burden of explicit regularization by modularizing and automating the manual steps involved.
### Minimal Usage
First, we show a minimal example of adding regularization to a model. A popular choice is energy (L2) regularization, which penalizes the squared magnitude of the weights, so we implement that.

```python
import NeuroVisKit.utils.regularization as reg
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)  # arbitrary model
        self.energy_reg = reg.l2(1e-3, target=self.fc1.weight)  # L2 penalty on fc1.weight, weighted by 1e-3

model = Model()

# calculate the penalty every training step
penalty = model.energy_reg()
print(f"L2 penalty is {penalty.item():.3f}")
```

```
L2 penalty is 0.003
```
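For intuition, here is a plain-PyTorch sketch of an energy (L2) penalty and the traditional, manual way of adding it to a task loss. It assumes the penalty is the coefficient times the sum of squared weights. That assumption is consistent with the 0.003 printed above for a default-initialized `nn.Linear(10, 10)`, but it is not taken from NeuroVisKit's source. `l2_penalty` is a hypothetical helper, not a library function.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def l2_penalty(weight: torch.Tensor, coeff: float) -> torch.Tensor:
    # Hypothetical helper (assumed form): coeff * sum of squared entries
    return coeff * weight.pow(2).sum()

fc1 = nn.Linear(10, 10)
x, y = torch.randn(32, 10), torch.randn(32, 10)

task_loss = nn.functional.mse_loss(fc1(x), y)
penalty = l2_penalty(fc1.weight, 1e-3)
loss = task_loss + penalty  # manual approach: add the penalty to the loss yourself
loss.backward()             # gradients now include 2 * coeff * fc1.weight from the penalty
print(f"task loss {task_loss.item():.3f}, penalty {penalty.item():.3f}")
```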


Traditionally, the regularization penalty is added to the loss by hand. NeuroVisKit can automate this with `PytorchWrapper`, which automatically extracts all regularization modules from a model.

Standard (non-proximal) regularization penalties are then added to the loss automatically. After the optimizer steps, proximal regularization is applied. For more information on proximal regularization, see the [Wikipedia](https://en.wikipedia.org/wiki/Proximal_gradient_methods_for_learning) page on proximal gradient methods for learning.
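The ordering described above can be sketched in plain PyTorch: add smooth penalties to the loss, step the optimizer, then apply a proximal operator. As the proximal step, this sketch uses soft-thresholding, which is the proximal operator of an L1 penalty: `w = sign(w) * max(|w| - lr * lam, 0)`. This illustrates the general technique and is not NeuroVisKit's implementation. `train_step`, `soft_threshold_`, `l2_coeff` and `l1_lam` are hypothetical names.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def soft_threshold_(param: torch.Tensor, threshold: float) -> None:
    # Proximal operator of threshold * ||w||_1, applied in place (hence the trailing underscore)
    with torch.no_grad():
        param.copy_(param.sign() * (param.abs() - threshold).clamp(min=0))

def train_step(model, x, y, opt, l2_coeff=1e-3, l1_lam=1e-2):
    opt.zero_grad()
    task_loss = nn.functional.mse_loss(model(x), y)
    # 1) smooth penalties are added to the loss and differentiated
    loss = task_loss + l2_coeff * model.weight.pow(2).sum()
    loss.backward()
    # 2) the optimizer steps on the penalized loss
    opt.step()
    # 3) the proximal step runs afterwards, outside autograd
    lr = opt.param_groups[0]["lr"]
    soft_threshold_(model.weight, lr * l1_lam)
    return loss.item()

model = nn.Linear(10, 10)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(64, 10), torch.randn(64, 10)
for _ in range(50):
    train_step(model, x, y, opt)
print(f"zeroed weights: {(model.weight == 0).sum().item()} / {model.weight.numel()}")
```

Soft-thresholding sets small weights exactly to zero. Adding an L1 term to the loss and following its gradient generally does not produce exact zeros, which is one reason proximal steps are applied separately.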

```python
from NeuroVisKit.utils import PytorchWrapper
wrapped_model = PytorchWrapper(model)
```

```
initialized modules: [l2()] proximal modules: []
```
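One way a wrapper could discover regularization terms is to walk `model.modules()` and sort the submodules by type. This sketch uses hypothetical stand-in base classes `RegModule` and `ProxRegModule` in place of NeuroVisKit's own. It illustrates the idea and is not the library's actual code.

```python
import torch
import torch.nn as nn

class RegModule(nn.Module):
    """Hypothetical stand-in for a standard regularizer (penalty added to the loss)."""
    def __init__(self, coeff, target):
        super().__init__()
        self.coeff = coeff
        # Assigning a Parameter registers it here too; model.parameters() de-duplicates it
        self.target = target

    def forward(self):
        return self.coeff * self.function()

class ProxRegModule(nn.Module):
    """Hypothetical stand-in for a proximal regularizer (applied after optimizer.step())."""
    def __init__(self, target):
        super().__init__()
        self.target = target

class L2(RegModule):
    def function(self):
        return self.target.pow(2).sum()

class Clip(ProxRegModule):
    def proximal(self):
        with torch.no_grad():
            self.target.clamp_(-1, 1)

def collect_regs(model: nn.Module):
    # model.modules() walks the model recursively, including modules stored as attributes
    standard = [m for m in model.modules() if isinstance(m, RegModule)]
    proximal = [m for m in model.modules() if isinstance(m, ProxRegModule)]
    return standard, proximal

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.energy_reg = L2(1e-3, self.fc1.weight)
        self.clip = Clip(self.fc1.weight)

model = Model()
standard, proximal = collect_regs(model)
print(f"standard: {standard} proximal: {proximal}")
```

In a training loop, the `standard` modules would be summed into the loss before `backward()`, and each module in `proximal` would call `proximal()` after `optimizer.step()`.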


### Adding new modules
Taking a step back, let's look at which regularization modules come with NeuroVisKit.

```python
list(reg.get_regs_dict().keys())  # names of all built-in regularization modules
```

```
['activityL1',
 'activityL1Sum',
 'activityL2',
 'proximalGroupSparsity',
 'proximalSparsityDekel',
 'proximalL1',
 'proximalP05',
 'proximalL2',
 'l1',
 'l2',
 'l4',
 'max',
 'local',
 'glocal',
 'fourierLocal',
 'edge',
 'center',
 'fourierCenter',
 'localConv',
 'laplacian']
```
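Several of these names refer to well-known penalties. For example, a smoothness penalty like `laplacian` can be illustrated by convolving each 2D filter with a discrete Laplacian kernel and summing the squared response. This is a common textbook formulation and an assumption on our part. NeuroVisKit's `laplacian` module may pad, normalize or weight the result differently. `laplacian_penalty` is a hypothetical name.

```python
import torch
import torch.nn.functional as F

def laplacian_penalty(filters: torch.Tensor) -> torch.Tensor:
    """Sum of squared discrete Laplacians of a stack of 2D filters.

    filters: tensor of shape (n_filters, height, width)
    """
    kernel = torch.tensor([[0.0, 1.0, 0.0],
                           [1.0, -4.0, 1.0],
                           [0.0, 1.0, 0.0]], dtype=filters.dtype)
    # conv2d expects (batch, channels, H, W); replicate padding keeps constant filters penalty-free
    padded = F.pad(filters.unsqueeze(1), (1, 1, 1, 1), mode="replicate")
    lap = F.conv2d(padded, kernel.view(1, 1, 3, 3))
    return lap.pow(2).sum()

torch.manual_seed(0)
smooth = torch.ones(1, 8, 8)   # constant filter: no curvature
noisy = torch.randn(1, 8, 8)   # white noise: high curvature everywhere
print(laplacian_penalty(smooth).item(), laplacian_penalty(noisy).item())
```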

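Adding a new regularization module is straightforward: subclass `reg.RegularizationModule` and implement `function`, which returns the penalty.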

```python
class NewRegModule(reg.RegularizationModule):
    def function(self):  # implement your own regularization function (here, the L1 norm of the target)
        return self.target.norm(1)
```
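`self.target.norm(1)` is the L1 norm of the target, meaning the sum of absolute values over all entries. This plain-PyTorch check shows the value that penalty takes and the gradient it contributes. It uses a random tensor `w` as a stand-in for the target and does not exercise NeuroVisKit itself. The raw norm is used without any weighting coefficient.

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, 4, requires_grad=True)  # stand-in for self.target

penalty = w.norm(1)  # what NewRegModule.function returns for this target
penalty.backward()

print(penalty.item(), w.abs().sum().item())  # same value: the sum of |w|
print(torch.allclose(w.grad, w.sign()))       # the gradient is sign(w)
```

Because the gradient has the same magnitude for every nonzero weight, an L1 penalty pulls small weights toward zero as hard as large ones, which is why it encourages sparsity.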

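Adding a new proximal regularization module is just as easy: subclass `reg.ProximalRegularizationModule` and implement `proximal`, which updates the target's data directly and returns a penalty for logging.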

```python
class NewProximalRegModule(reg.ProximalRegularizationModule):
    def proximal(self):  # implement your own proximal step
        # logging penalty: mean amount by which |target| exceeds 1, measured before the step
        penalty = (self.target.data.abs() - 1).clamp(min=0).mean()
        # proximal step: clip every entry of the target to [-1, 1]
        self.target.data = self.target.data.sign() * self.target.data.abs().clamp(max=1)
        return penalty  # return a penalty for logging
```
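The proximal step above projects every entry of the target into [-1, 1], so it is equivalent to `clamp(-1, 1)`. The returned value is the mean amount by which entries exceeded 1 in magnitude before the projection. This plain-PyTorch check on a small hand-picked tensor is independent of NeuroVisKit:

```python
import torch

w = torch.tensor([-3.0, -0.5, 0.0, 0.25, 2.0])

# logging penalty: mean excess of |w| over 1, computed before the projection
penalty = (w.abs() - 1).clamp(min=0).mean()
# the proximal step from the class above
projected = w.sign() * w.abs().clamp(max=1)

print(penalty.item())      # approximately 0.6 = (2 + 0 + 0 + 0 + 1) / 5
print(projected.tolist())  # [-1.0, -0.5, 0.0, 0.25, 1.0]
```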
        