In [None]:
#| default_exp misc.fc_decomposer

In [None]:
#| include: false
from nbdev.showdoc import *

## Overview

The `FC_Decomposer` class reduces model size by factorizing large fully-connected (Linear) layers into two smaller layers using Singular Value Decomposition (SVD). This is particularly effective for models with large FC layers like VGG or older architectures with big classifier heads.

**Key Benefits:**
- Reduces parameter count without changing model architecture externally
- No retraining required (though fine-tuning may improve accuracy)
- Works on any model with Linear layers

### When to Use FC Decomposition

| Scenario | Recommendation |
|----------|----------------|
| Large classifier heads (e.g., VGG's 4096→4096→1000) | **Highly recommended** - significant savings |
| Modern architectures (ResNet, EfficientNet) | Limited benefit - already efficient |
| Transformer attention layers | Use with caution - may hurt performance |
| Pre-deployment optimization | Good complement to pruning/quantization |

### Compression Ratio

For a Linear layer with shape `(out_features, in_features)`:
- **Original parameters**: `out_features × in_features + out_features` (with bias)
- **After decomposition** (keeping `k` singular values): `k × in_features + out_features × k + out_features`
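- **Compression ratio**: `(out_features × in_features) / (k × (in_features + out_features))`, ignoring the bias. For square layers this is roughly `1 / (2 × (1 - percent_removed))`, so more than half of the singular values must be removed before the parameter count actually shrinks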

## How It Works

SVD decomposes a weight matrix into three matrices: $W = U \Sigma V^T$

Where:
- $U$ contains left singular vectors (output features)
- $\Sigma$ is diagonal with singular values (importance scores)
- $V^T$ contains right singular vectors (input features)

By keeping only the top $k$ singular values, we approximate $W$ with two smaller matrices, trading accuracy for compression.

![](../imgs/svd.png "SVD Decomposition")

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

In [None]:
#| export
class FC_Decomposer:
    "Decompose fully-connected layers using SVD to reduce parameters"

    def __init__(self):
        pass

    def decompose(self,
                  model: nn.Module,            # The model to decompose
                  percent_removed: float = 0.5 # Fraction of singular values to remove [0, 1)
    ) -> nn.Module:
        """Return a deep copy of `model` in which every leaf Linear layer is replaced by its SVD factorization.

        The input model is never modified. A bare `nn.Linear` passed as `model` is
        decomposed as well and returned as an `nn.Sequential`.
        Raises `ValueError` when `percent_removed` is outside `[0, 1)`.
        """
        if not (0 <= percent_removed < 1):
            raise ValueError(f"percent_removed must be in range [0, 1), got {percent_removed}")

        # One deep copy up front; every replacement below edits this copy in place
        new_model = copy.deepcopy(model)

        # A top-level Linear has no parent to patch, so return its factorization directly
        if isinstance(new_model, nn.Linear) and not new_model._modules:
            return self.SVD(new_model, percent_removed)

        # Linear layers registered under several names are factorized once and the
        # result is reused, so the tying preserved by deepcopy survives decomposition
        replaced = {}
        # Snapshot the unique modules before editing so the freshly created
        # Sequential(Linear, Linear) pairs are never visited (and re-decomposed)
        for parent in list(new_model.modules()):
            for name, child in list(parent._modules.items()):
                # `None` placeholders and containers fail this test and are left alone;
                # containers are handled when they come up as `parent` themselves
                if isinstance(child, nn.Linear) and not child._modules:
                    if child not in replaced:
                        replaced[child] = self.SVD(child, percent_removed)
                    parent._modules[name] = replaced[child]
        return new_model

    def SVD(self,
            layer: nn.Linear,       # The Linear layer to decompose
            percent_removed: float  # Fraction of singular values to remove
    ) -> nn.Sequential:
        """Factorize one Linear layer into `Sequential(Linear(in, k, bias=False), Linear(k, out))`.

        `k` is the number of singular values kept (always at least 1). The second
        layer always carries a bias: a copy of the original one, or zeros when the
        original layer has none. The new layers own their storage and live on the
        same device and dtype as `layer.weight`.
        """
        W = layer.weight.detach()
        # Reduced SVD: U is (out, r), S is (r,), Vh is (r, in) with r = min(out, in).
        # torch.linalg.svd supersedes the deprecated torch.svd and returns V already transposed.
        U, S, Vh = torch.linalg.svd(W, full_matrices=False)

        n_sv = S.shape[0]
        # The epsilon absorbs float truncation: (1 - 0.8) * 10 == 1.9999999999999996 must keep 2, not 1
        rank = max(1, min(n_sv, int((1. - percent_removed) * n_sv + 1e-9)))

        # Build the layers directly on the source device/dtype so weights and bias always agree
        layer_1 = nn.Linear(in_features=layer.in_features, out_features=rank,
                            bias=False, device=W.device, dtype=W.dtype)
        layer_2 = nn.Linear(in_features=rank, out_features=layer.out_features,
                            bias=True, device=W.device, dtype=W.dtype)

        # copy_ writes into the layers' own compact storage: no view onto the full U
        # (which torch.save would serialize entirely) and no aliasing of layer.bias
        with torch.no_grad():
            layer_1.weight.copy_(S[:rank].unsqueeze(1) * Vh[:rank])  # diag(S_k) @ Vh_k
            layer_2.weight.copy_(U[:, :rank])
            if layer.bias is None:
                layer_2.bias.zero_()
            else:
                layer_2.bias.copy_(layer.bias)

        return nn.Sequential(layer_1, layer_2)

In [None]:
# Render the signature and docstring of `FC_Decomposer.decompose` in the generated docs
show_doc(FC_Decomposer.decompose)

Found permutation search CUDA kernels
[ASP][Info] permutation_search_kernels can be imported.


---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/misc/fc_decomposer.py#L19){target="_blank" style="float:right; font-size:smaller"}

### FC_Decomposer.decompose

```python

def decompose(
    model:Module, # The model to decompose
    percent_removed:float=0.5, # Fraction of singular values to remove [0, 1)
)->Module:


```

*Recursively decompose all Linear layers in the model using SVD*

---

## Usage Example

```python
from fasterai.misc.fc_decomposer import FC_Decomposer
from torchvision.models import vgg16

# Load a model with large FC layers
model = vgg16(pretrained=True)

# Decompose, removing 50% of singular values
decomposer = FC_Decomposer()
compressed_model = decomposer.decompose(model, percent_removed=0.5)

# Check parameter reduction
original_params = sum(p.numel() for p in model.parameters())
compressed_params = sum(p.numel() for p in compressed_model.parameters())
print(f"Compression: {original_params/compressed_params:.2f}x")
```

In [None]:
#| hide
from fastcore.test import *

# --- Case 1: dropping half of the singular values keeps outputs in the same ballpark ---
model = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
x = torch.randn(4, 32)
out_orig = model(x)

decomposer = FC_Decomposer()
model_dec = decomposer.decompose(model, percent_removed=0.5)
out_dec = model_dec(x)
# Loose tolerance on purpose: a 50% truncation introduces a noticeable reconstruction error
test_close(out_orig, out_dec, eps=1.0)

# Each Linear must have been swapped for a two-layer Sequential
assert isinstance(model_dec[0], nn.Sequential)
assert len(model_dec[0]) == 2

# --- Case 2: keeping every singular value reproduces the layer almost exactly ---
m2 = nn.Sequential(nn.Linear(32, 64))
x2 = torch.randn(4, 32)
out2 = m2(x2)
m2_dec = decomposer.decompose(m2, percent_removed=0.0)
test_close(out2, m2_dec(x2), eps=1e-4)

# --- Case 3: aggressive removal still leaves at least one singular value ---
m3 = nn.Sequential(nn.Linear(10, 20))
m3_dec = decomposer.decompose(m3, percent_removed=0.95)
assert m3_dec[0][0].out_features >= 1

# --- Case 4: fractions outside [0, 1) are rejected with ValueError ---
for invalid_fraction in (1.0, -0.1):
    with ExceptionExpected(ValueError):
        decomposer.decompose(nn.Sequential(nn.Linear(10, 10)), percent_removed=invalid_fraction)

---

## See Also

- [FC Decomposer Tutorial](../tutorials/misc/fc_decomposer.html) - Step-by-step walkthrough with examples
- [BN Folding](bn_folding.html) - Another optimization technique to reduce inference overhead
- [Pruner](../prune/pruner.html) - Remove entire filters for structured compression