Skip to content

Commit

Permalink
refactor: latent_codec more natural initialization
Browse files Browse the repository at this point in the history
Use self-documenting keyword arguments instead of `_setdefault` et al.
  • Loading branch information
YodaEmbedding committed Apr 12, 2023
1 parent 312d62b commit 063a2f2
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 43 deletions.
26 changes: 20 additions & 6 deletions compressai/latent_codecs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,29 @@
class _SetDefaultMixin:
"""Convenience functions for initializing classes with defaults."""

_kwargs: Dict[str, Any]

def _setdefault(self, k, f):
v = self._kwargs.get(k, None) or f()
def _setdefault(self, k, v, f):
"""Initialize attribute ``k`` with value ``v`` or ``f()``."""
v = v or f()
setattr(self, k, v)

# TODO instead of save_direct, override load_state_dict() and state_dict()
def _set_group_defaults(self, group_key, defaults, save_direct=False):
group_dict = self._kwargs.get(group_key, {})
def _set_group_defaults(self, group_key, group_dict, defaults, save_direct=False):
"""Initialize attribute ``group_key`` with items from
``group_dict``, using defaults for missing keys.
Ensures ``nn.Module`` attributes are properly registered.
Args:
- group_key:
Name of attribute.
- group_dict:
Dict of items to initialize ``group_key`` with.
- defaults:
Dict of defaults for items not in ``group_dict``.
- save_direct:
If ``True``, save items directly as attributes of ``self``.
If ``False``, save items in a ``nn.ModuleDict``.
"""
group_dict = group_dict if group_dict is not None else {}
for k, f in defaults.items():
if k in group_dict:
continue
Expand Down
9 changes: 6 additions & 3 deletions compressai/latent_codecs/entropy_bottleneck.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,13 @@ class EntropyBottleneckLatentCodec(LatentCodec):

entropy_bottleneck: EntropyBottleneck

def __init__(self, channels: Optional[int] = None, **kwargs):
def __init__(
self,
entropy_bottleneck: Optional[EntropyBottleneck] = None,
**kwargs,
):
super().__init__()
self._kwargs = kwargs
self._setdefault("entropy_bottleneck", lambda: EntropyBottleneck(channels))
self.entropy_bottleneck = entropy_bottleneck or EntropyBottleneck(**kwargs)

def forward(self, y: Tensor) -> Dict[str, Any]:
y_hat, y_likelihoods = self.entropy_bottleneck(y)
Expand Down
15 changes: 10 additions & 5 deletions compressai/latent_codecs/gain/hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,17 @@ class GainHyperLatentCodec(LatentCodec):
h_a: nn.Module
h_s: nn.Module

def __init__(self, z_channels: Optional[int] = None, **kwargs):
def __init__(
self,
entropy_bottleneck: Optional[EntropyBottleneck] = None,
h_a: Optional[nn.Module] = None,
h_s: Optional[nn.Module] = None,
**kwargs,
):
super().__init__()
self._kwargs = kwargs
self._setdefault("entropy_bottleneck", lambda: EntropyBottleneck(z_channels))
self._setdefault("h_a", nn.Identity)
self._setdefault("h_s", nn.Identity)
self.entropy_bottleneck = entropy_bottleneck or EntropyBottleneck()
self.h_a = h_a or nn.Identity()
self.h_s = h_s or nn.Identity()

def forward(self, y: Tensor, gain: Tensor, gain_inv: Tensor) -> Dict[str, Any]:
z = self.h_a(y)
Expand Down
9 changes: 6 additions & 3 deletions compressai/latent_codecs/gain/hyperprior.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,18 @@ class GainHyperpriorLatentCodec(LatentCodec):

latent_codec: Mapping[str, LatentCodec]

def __init__(self, z_channels: Optional[int] = None, **kwargs):
def __init__(
self, latent_codec: Optional[Mapping[str, LatentCodec]] = None, **kwargs
):
super().__init__()
self._kwargs = kwargs
self._set_group_defaults(
"latent_codec",
latent_codec,
defaults={
"y": GaussianConditionalLatentCodec,
"hyper": lambda: GainHyperLatentCodec(z_channels=z_channels),
"hyper": GainHyperLatentCodec,
},
save_direct=True,
)

def forward(
Expand Down
17 changes: 12 additions & 5 deletions compressai/latent_codecs/gaussian_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

import torch.nn as nn

Expand Down Expand Up @@ -79,12 +79,19 @@ class GaussianConditionalLatentCodec(LatentCodec):
gaussian_conditional: GaussianConditional
entropy_parameters: nn.Module

def __init__(self, quantizer: str = "noise", **kwargs):
def __init__(
self,
gaussian_conditional: Optional[GaussianConditional] = None,
entropy_parameters: Optional[nn.Module] = None,
quantizer: str = "noise",
**kwargs,
):
super().__init__()
self._kwargs = kwargs
self.quantizer = quantizer
self._setdefault("gaussian_conditional", lambda: GaussianConditional(None))
self._setdefault("entropy_parameters", nn.Identity)
self.gaussian_conditional = gaussian_conditional or GaussianConditional(
**kwargs
)
self.entropy_parameters = entropy_parameters or nn.Identity()

def forward(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
gaussian_params = self.entropy_parameters(ctx_params)
Expand Down
15 changes: 10 additions & 5 deletions compressai/latent_codecs/hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,17 @@ class HyperLatentCodec(LatentCodec):
h_a: nn.Module
h_s: nn.Module

def __init__(self, z_channels: Optional[int] = None, **kwargs):
def __init__(
self,
entropy_bottleneck: Optional[EntropyBottleneck] = None,
h_a: Optional[nn.Module] = None,
h_s: Optional[nn.Module] = None,
**kwargs,
):
super().__init__()
self._kwargs = kwargs
self._setdefault("entropy_bottleneck", lambda: EntropyBottleneck(z_channels))
self._setdefault("h_a", nn.Identity)
self._setdefault("h_s", nn.Identity)
self.entropy_bottleneck = entropy_bottleneck or EntropyBottleneck(0)
self.h_a = h_a or nn.Identity()
self.h_s = h_s or nn.Identity()

def forward(self, y: Tensor) -> Dict[str, Any]:
z = self.h_a(y)
Expand Down
8 changes: 5 additions & 3 deletions compressai/latent_codecs/hyperprior.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,16 @@ class HyperpriorLatentCodec(LatentCodec):

latent_codec: Mapping[str, LatentCodec]

def __init__(self, z_channels: Optional[int] = None, **kwargs):
def __init__(
self, latent_codec: Optional[Mapping[str, LatentCodec]] = None, **kwargs
):
super().__init__()
self._kwargs = kwargs
self._set_group_defaults(
"latent_codec",
latent_codec,
defaults={
"y": GaussianConditionalLatentCodec,
"hyper": lambda: HyperLatentCodec(z_channels=z_channels),
"hyper": HyperLatentCodec,
},
save_direct=True,
)
Expand Down
17 changes: 11 additions & 6 deletions compressai/latent_codecs/rasterscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -85,12 +85,17 @@ class RasterScanLatentCodec(LatentCodec):
entropy_parameters: nn.Module
context_prediction: MaskedConv2d

def __init__(self, **kwargs):
def __init__(
self,
gaussian_conditional: Optional[GaussianConditional] = None,
entropy_parameters: Optional[nn.Module] = None,
context_prediction: Optional[MaskedConv2d] = None,
**kwargs,
):
super().__init__()
self._kwargs = kwargs
self._setdefault("gaussian_conditional", lambda: GaussianConditional(None))
self._setdefault("entropy_parameters", nn.Identity)
self._setdefault("context_prediction", lambda: None)
self.gaussian_conditional = gaussian_conditional or GaussianConditional()
self.entropy_parameters = entropy_parameters or nn.Identity()
self.context_prediction = context_prediction or MaskedConv2d()
self.kernel_size = _reduce_seq(self.context_prediction.kernel_size)
self.padding = (self.kernel_size - 1) // 2

Expand Down
8 changes: 1 addition & 7 deletions docs/source/latent_codecs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ Using :py:class:`~compressai.models.base.SimpleVAECompressionModel`, some Google
self.g_a = nn.Sequential(...)
self.g_s = nn.Sequential(...)
self.latent_codec = EntropyBottleneckLatentCodec(N)
self.latent_codec = EntropyBottleneckLatentCodec(channels=M)
.. code-block:: python
Expand All @@ -190,13 +190,10 @@ Using :py:class:`~compressai.models.base.SimpleVAECompressionModel`, some Google
h_s = nn.Sequential(...)
self.latent_codec = HyperpriorLatentCodec(
N=N,
# A HyperpriorLatentCodec is made of "hyper" and "y" latent codecs.
latent_codec={
# Side-information branch with entropy bottleneck for "z":
"hyper": HyperLatentCodec(
N,
h_a=h_a,
h_s=h_s,
entropy_bottleneck=EntropyBottleneck(N),
Expand All @@ -220,13 +217,10 @@ Using :py:class:`~compressai.models.base.SimpleVAECompressionModel`, some Google
h_s = nn.Sequential(...)
self.latent_codec = HyperpriorLatentCodec(
N=N,
# A HyperpriorLatentCodec is made of "hyper" and "y" latent codecs.
latent_codec={
# Side-information branch with entropy bottleneck for "z":
"hyper": HyperLatentCodec(
N,
h_a=h_a,
h_s=h_s,
entropy_bottleneck=EntropyBottleneck(N),
Expand Down

0 comments on commit 063a2f2

Please sign in to comment.