In [None]:
#| default_exp misc.fc_decomposer

In [None]:
#| include: false
from nbdev.showdoc import *

%config InlineBackend.figure_format = 'retina'

We can factorize large fully-connected layers and replace each of them with an approximation made of two smaller layers. The idea is to compute a singular value decomposition (SVD) of the weight matrix, which expresses the original matrix as a product of three matrices: $W = U \Sigma V^T$,
where $\Sigma$ is a diagonal matrix whose diagonal entries are non-negative (the singular values). We then choose the number $k$ of singular values to keep and truncate $U$, $\Sigma$ and $V^T$ accordingly; the resulting product $U_k \Sigma_k V_k^T$ is an approximation of the initial matrix. In practice, the first new layer gets the weight $\Sigma_k V_k^T$ (no bias), and the second new layer gets the weight $U_k$ together with the original bias.

![](imgs/svd.png "SVD Decomposition")

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export
class FC_Decomposer:
    "Decompose fully-connected layers using SVD to reduce parameters"

    def __init__(self):
        pass

    def decompose(self,
                  model: nn.Module,            # The model to decompose
                  percent_removed: float = 0.5 # Fraction of singular values to remove [0, 1)
    ) -> nn.Module:
        "Return a deep copy of `model` in which every leaf `nn.Linear` is replaced by its SVD factorization"
        if not (0 <= percent_removed < 1):
            raise ValueError(f"percent_removed must be in range [0, 1), got {percent_removed}")

        # Copy once at the top level, then rewrite the copy in place; the input model is never modified.
        new_model = copy.deepcopy(model)
        self._decompose_inplace(new_model, percent_removed)
        return new_model

    def _decompose_inplace(self,
                           module: nn.Module,      # Module whose descendants are rewritten in place
                           percent_removed: float  # Fraction of singular values to remove
    ) -> None:
        "Recursively replace every leaf `nn.Linear` below `module` with its SVD factorization"
        # Snapshot the items because entries are reassigned while iterating.
        for name, child in list(module._modules.items()):
            if child is None:  # `_modules` can contain placeholder `None` entries
                continue
            if len(child._modules) > 0:
                # Container: descend into its children.
                self._decompose_inplace(child, percent_removed)
            elif isinstance(child, nn.Linear):
                module._modules[name] = self.SVD(child, percent_removed)

    def SVD(self,
            layer: nn.Linear,       # The Linear layer to decompose
            percent_removed: float  # Fraction of singular values to remove
    ) -> nn.Sequential:
        "Factor `layer` into `Linear(in, k, bias=False)` followed by `Linear(k, out)` using a truncated SVD of its weight"
        W = layer.weight.detach()
        # torch.linalg.svd may not support half-precision inputs; factor in float32 and cast back on copy.
        W_svd = W.float() if W.dtype in (torch.float16, torch.bfloat16) else W
        # Reduced SVD: W = U @ diag(S) @ Vh, with min(out, in) singular values in descending order.
        U, S, Vh = torch.linalg.svd(W_svd, full_matrices=False)
        # The fraction applies to the number of singular values (which can be smaller than
        # out_features); always keep at least one so the factorized layers are non-empty.
        rank = max(1, int((1. - percent_removed) * S.numel()))

        # Build the new layers on the same device/dtype as the original weight.
        factory = dict(device=W.device, dtype=W.dtype)
        layer_1 = nn.Linear(in_features=layer.in_features, out_features=rank, bias=False, **factory)
        layer_2 = nn.Linear(in_features=rank, out_features=layer.out_features, bias=True, **factory)

        with torch.no_grad():
            # diag(S_k) @ Vh_k, written as a broadcasted row scaling.
            layer_1.weight.copy_(S[:rank, None] * Vh[:rank])
            layer_2.weight.copy_(U[:, :rank])
            # For bias-free layers `layer.bias` itself is None; the second layer then gets a zero bias.
            if layer.bias is None:
                layer_2.bias.zero_()
            else:
                layer_2.bias.copy_(layer.bias)

        return nn.Sequential(layer_1, layer_2)

In [None]:
# Render the API documentation for `FC_Decomposer.decompose`
show_doc(FC_Decomposer.decompose)

---

### FC_Decomposer.decompose

>      FC_Decomposer.decompose (model, percent_removed=0.5)

A tutorial on how to use `FC_Decomposer` can be found [here](https://nathanhubens.github.io/fasterai/tutorial.fc_decomposer.html).