In [32]:
import torch
from torch import nn
from typing import Callable
import numpy as np
import bias_act

In [42]:
bias_act.bias_act(torch.rand([1,2]), torch.rand([2]), act='lrelu')

tensor([[0.1304, 0.8804]])

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, layer: Callable):
        super().__init__()
        self.layer = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x + self.layer(x)) / np.sqrt(2)

In [66]:
class FullConnectedLayers(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, activation: str = 'linear',
                       lr_multiplier: float = 1.0, weight_init: float = 1.0, bias_init: float = 0.0):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation

        self.weights = nn.Parameter(
            torch.randn([out_features, in_features])
            ) * (weight_init / lr_multiplier)
        
        bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), shape=[out_features])
        self.bias = nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weights.to(x.dtype) * self.weight_gain
        b = self.bias

        if b is not None:
            if self.bias_gain != 1: b = b * self.bias_gain

        if self.activation == "linear" and b is not None:
            torch.addmm(b.unsqueeze(0), x, w.t())  # b + x @ Wᵀ
            # x: [batch_size, in_features]
	        # W: [out_features, in_features]
        else:
            x = torch.matmul(x, w.t())
            bias_act.bias_act(x = x, b = b, act = self.activation)
        
        return x
    
    def extra_repr(self):
        return f"In Features: {self.in_features}\nOut Features: {self.out_features}\nActivation Function: {self.activation}"


In [69]:
fcl = FullConnectedLayers(in_features=3, out_features=5, activation="lrelu")
fcl

FullConnectedLayers(
  In Features: 3
  Out Features: 5
  Activation Function: lrelu
)

In [65]:
fcl = FullConnectedLayers(in_features=3, out_features=5, activation="lrelu")
fcl(torch.rand([1, 3]))

tensor([[ 0.3582, -0.0680,  0.0544, -0.2280, -0.1576]], grad_fn=<MmBackward0>)