In [2]:
import torch
import torch.nn as nn

class FinalLayer(nn.Module):
    """
    Final layer of the backbone.

    Normalizes the token features ``x`` with a non-affine LayerNorm, applies a
    shift/scale predicted from the conditioning vector ``y`` (adaptive LayerNorm),
    then linearly projects every token to
    ``patch_size * patch_size * out_channels`` features.

    Args:
        hidden_size: feature width of ``x`` (last dim) and of ``y``.
        patch_size: the projection emits ``patch_size * patch_size * out_channels``
            values per token (presumably a flattened square patch — TODO confirm).
        out_channels: channel multiplier of the projection output.

    Shapes (as used in the demo cell): x is (B, T, hidden_size), y is
    (B, hidden_size); the output is (B, T, patch_size * patch_size * out_channels).
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        out_features = patch_size * patch_size * out_channels
        # No learnable affine here: shift/scale come from adaLN_modulation instead.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_features)
        # SiLU -> Linear emitting 2 * hidden_size values, later split into two halves.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size),
        )

    def forward(self, x, y):
        # First half of the modulation output is the scale, second half the shift.
        # NOTE(review): the DiT reference unpacks (shift, scale); the order only
        # matters when loading external checkpoints.
        scale, shift = torch.chunk(self.adaLN_modulation(y), 2, dim=-1)
        normed = self.norm_final(x)
        # modulate: normed * (1 + scale) + shift, broadcast over the token axis.
        return self.linear(modulate(normed, shift, scale))

def modulate(x, shift, scale):
    """Affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a singleton axis inserted at dim 1 before
    broadcasting, so e.g. a (B, C) conditioning tensor applies to every token
    of a (B, T, C) input. Zero ``shift`` and ``scale`` leave ``x`` unchanged.
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset

In [3]:
layer = FinalLayer(hidden_size=7, patch_size=2, out_channels=4)  # projects to 2*2*4 = 16 features

In [4]:
# Smoke test: 1 sample with 3 tokens of width 7, conditioned on a (1, 7) vector.
out_shape = layer(torch.randn(1, 3, 7), torch.randn(1, 7)).shape
# Each token is projected to patch_size * patch_size * out_channels = 2 * 2 * 4 features.
assert out_shape == (1, 3, 2 * 2 * 4), out_shape
out_shape

torch.Size([1, 3, 16])

In [5]:
import torch
import torch.nn as nn

class activation(nn.Module):
    """Thin nn.Module wrapper that applies a configurable activation.

    Args:
        activation_layer: either an ``nn.Module`` *subclass* (e.g. the default
            ``nn.GELU``), which is instantiated with no arguments, or an already
            built callable such as ``nn.GELU(approximate="tanh")``, which is
            used as-is.
    """

    def __init__(self, activation_layer=nn.GELU):
        super().__init__()
        # Instantiate module classes. Storing the bare class would make forward()
        # call e.g. nn.GELU(x), which constructs a new module instead of
        # returning a tensor.
        if isinstance(activation_layer, type) and issubclass(activation_layer, nn.Module):
            activation_layer = activation_layer()
        self.act = activation_layer

    def forward(self, x):
        return self.act(x)

# Pass a pre-built GELU (tanh approximation) instance; the wrapper calls it directly.
x = torch.randn(4)
a = activation(nn.GELU(approximate="tanh"))
a(x)


tensor([-0.0447, -0.0520, -0.0628,  0.8552])

In [7]:
from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
class GatedMLP(nn.Module):
    """Gated feed-forward block: ``fc2(act(a) * b)`` where ``[a, b] = fc1(x)``.

    ``fc1`` projects to ``2 * fan_h`` features, which are split in half along the
    last dim. The activated first half is multiplied elementwise by the second
    half (the gate) and the result is projected by ``fc2``.

    Args:
        fan_in: input feature size (last dim of ``x``).
        fan_h: hidden size of each half; falls back to ``fan_in`` when falsy.
        fan_out: output feature size; falls back to ``fan_in`` when falsy.
        act_layer: zero-argument callable returning the activation module, e.g.
            ``lambda: nn.GELU(approximate="tanh")`` or the class ``nn.GELU``;
            called once in ``__init__``.
        drop: dropout probability applied after the gating product and after
            ``fc2`` (only active in training mode; ``0.0`` is a no-op).
        bias: whether ``fc1`` and ``fc2`` carry bias terms.
    """

    def __init__(
            self,
            fan_in: int,
            fan_h: "int | None" = None,
            fan_out: "int | None" = None,
            act_layer = lambda: nn.GELU(approximate="tanh"),
            drop: float = 0.0,
            bias: bool = True,
    ) -> None:
        super().__init__()
        # `or` keeps the first truthy value: None (or 0) falls back to fan_in.
        fan_out = fan_out or fan_in
        fan_h = fan_h or fan_in
        self.fc1 = nn.Linear(fan_in, 2 * fan_h, bias=bias)
        self.fc2 = nn.Linear(fan_h, fan_out, bias=bias)
        self.act_layer = act_layer()
        # `drop` was previously accepted but silently ignored. Dropout has no
        # parameters, so state_dict keys are unchanged.
        self.dropout = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x, gate = x.chunk(2, dim=-1)
        x = self.dropout(self.act_layer(x) * gate)
        x = self.fc2(x)
        return self.dropout(x)

# A zero-argument factory works as act_layer: GatedMLP.__init__ calls act_layer()
# once, so each GatedMLP gets its own GELU(approximate="tanh") instance.
approx_gelu = lambda: nn.GELU(approximate="tanh")
m = GatedMLP(fan_in=1, fan_h=2, act_layer=approx_gelu, drop=0, bias=False)
x = torch.randn(2, 1)
m(x)

tensor([[-4.7365e-05],
        [-1.8660e-03]], grad_fn=<MmBackward0>)

In [6]:
# Contrast the factory function with the nn.GELU class; both can be called with no arguments.
(approx_gelu, nn.GELU)

(<function __main__.<lambda>()>, torch.nn.modules.activation.GELU)