In [1]:
import torch
from torch import nn
import numpy as np

In [33]:
class LearnableLogitScaling(nn.Module):
    """Multiply the input by a positive, optionally learnable, scale factor.

    The scale is stored in log space as ``log_logit_scale``. At forward time
    it is exponentiated, which guarantees a positive factor. It is then
    clipped from above at ``max_logit_scale`` before multiplying ``x``.

    Args:
        logit_scale_init: initial value of the scale itself (not its log).
            Must be strictly positive.
        learnable: if True, ``log_logit_scale`` is an ``nn.Parameter``.
            Otherwise it is registered as a buffer, so it is saved in the
            state dict but not returned by ``parameters()``.
        max_logit_scale: upper bound applied to the exponentiated scale.

    Raises:
        ValueError: if ``logit_scale_init`` is not strictly positive.
    """

    def __init__(
        self,
        logit_scale_init: float = 1 / 0.07,
        learnable: bool = True,
        max_logit_scale: float = 100,
    ):
        super().__init__()
        # log() of a non-positive value is -inf/nan. That value would be
        # stored silently and would corrupt every output and gradient, so
        # reject it up front. `not (> 0)` also catches NaN.
        if not logit_scale_init > 0:
            raise ValueError(
                f"logit_scale_init must be > 0, got {logit_scale_init}"
            )
        self.logit_scale_init = logit_scale_init
        self.learnable = learnable
        self.max_logit_scale = max_logit_scale

        # 0-dim tensor (default dtype) holding log(logit_scale_init).
        log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)

        if self.learnable:
            self.log_logit_scale = nn.Parameter(log_logit_scale)
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, x):
        """Return ``min(exp(log_logit_scale), max_logit_scale) * x``."""
        # Only an upper bound is applied; exp() already keeps the scale > 0.
        clipped = torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale)
        return clipped * x

    def extra_repr(self):
        # Shown inside the module repr, e.g. when printing a model.
        st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}," \
             f" max_logit_scale={self.max_logit_scale}"
        return st

In [34]:
# Fixed seed so the displayed demo output is reproducible on Restart & Run All.
torch.manual_seed(0)
learnablelogitscale = LearnableLogitScaling()
learnablelogitscale(torch.rand(2, 2))

tensor([[ 2.4878, 13.3646],
        [12.9722,  6.9665]], grad_fn=<MulBackward0>)

In [35]:
# Display the module; the fields shown come from LearnableLogitScaling.extra_repr.
learnablelogitscale

LearnableLogitScaling(logit_scale_init=14.285714285714285,learnable=True, max_logit_scale=100)

In [47]:
class SelectElement(nn.Module):
    """Pick a single position along dimension 1 of the input.

    ``x[:, index, ...]`` removes dimension 1, so an input of shape
    ``(B, N, *rest)`` becomes ``(B, *rest)``.

    Args:
        index: position to select along dim 1. Negative values count from
            the end, as with ordinary tensor indexing.
    """

    def __init__(self, index):
        super().__init__()
        self.index = index

    def forward(self, x):
        # NOTE(review): at least 3 dims are required, although the indexing
        # itself would also work on 2-D input. The original contract and
        # exception type (AssertionError) are kept.
        assert x.ndim >= 3, (
            f"SelectElement expects an input with at least 3 dims, got {x.ndim}"
        )
        return x[:, self.index, ...]

    def extra_repr(self):
        # Matches the extra_repr convention of the other modules so printed
        # models show which index is selected.
        return f"index={self.index}"

x = torch.rand(2, 3, 224, 224)
selector = SelectElement(index=0)
# Sanity check: the module is equivalent to plain indexing along dim 1.
assert torch.equal(selector(x), x[:, 0])
selector(x).shape

torch.Size([2, 224, 224])

In [52]:
class Normalize(nn.Module):
    """Lp-normalize the input along a fixed dimension.

    Thin module wrapper around ``torch.nn.functional.normalize``. The
    output has the same shape as the input.

    Args:
        dim: dimension along which vectors are normalized.
        p: exponent of the norm. Defaults to 2 (L2), which matches the
            original hard-coded behavior.
    """

    def __init__(self, dim, p: float = 2.0):
        super().__init__()
        self.dim = dim
        self.p = p

    def forward(self, x):
        return nn.functional.normalize(x, p=self.p, dim=self.dim)

    def extra_repr(self):
        # Matches the extra_repr convention of the other modules in this file.
        return f"dim={self.dim}, p={self.p}"

x = torch.rand(2, 3, 224, 224)
normalize = Normalize(2)
x_normalized = normalize(x)
# Every vector along dim 2 should now have unit L2 norm.
assert torch.allclose(x_normalized.norm(p=2, dim=2), torch.ones(2, 3, 224))
x_normalized.shape

torch.Size([2, 3, 224, 224])