In [2]:
import torch 
import torch.nn as nn
import math

In [None]:
class WeightStandardizedConv(nn.Conv2d):
    """
        Conv2d whose kernel is standardized (z-score normalized) on every forward pass.

        For each output channel (dim 0 of ``self.weight``), the kernel is shifted to
        zero mean and divided by sqrt(population variance + eps), with statistics taken
        over dims 1-3. The stored parameter ``self.weight`` itself is never modified;
        only the tensor handed to ``conv2d`` is standardized.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
            Standardize the kernel, then convolve ``x`` with it.

            Args:
                x: input tensor, passed straight to ``torch.nn.functional.conv2d``.

            Returns:
                Result of ``conv2d`` with the standardized kernel, ``self.bias`` and this
                layer's stride / padding / dilation / groups.
        """
        w = self.weight
        # Per-output-channel statistics over every remaining kernel axis.
        w_mean = w.mean(dim=[1, 2, 3], keepdim=True)
        w_var = w.var(dim=[1, 2, 3], keepdim=True, correction=0)  # divide by N instead of N-1
        # Small eps for full-precision inputs, a larger one otherwise (e.g. half precision).
        # Must inspect ``x.dtype``: ``x.type`` is a bound method and never equals a dtype,
        # which previously forced eps=1e-3 for every input.
        eps = 1e-5 if x.dtype in (torch.float32, torch.float64) else 1e-3
        weights = (w - w_mean) / torch.sqrt(w_var + eps)
        # NOTE(review): F.conv2d is called directly, so a non-default ``padding_mode``
        # (e.g. 'reflect') set on this layer is not applied here — confirm it is unused.
        return torch.nn.functional.conv2d(
            x,
            weights,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

In [21]:
class block(nn.Module):
    """
        Basic conv unit used inside the resnet class.

        Pipeline: 3x3 weight-standardized conv (c_in -> c_out, padding 1, no bias)
        -> GroupNorm(gn_groups, c_out) -> optional modulation x * (scale + 1) + shift
        -> SiLU.
    """
    def __init__(self, c_in: int, c_out: int, gn_groups: int):
        super().__init__()
        assert c_out % gn_groups == 0, "c_out must be divisible by gn_groups"
        # Submodules are registered in the same order as before (norm, projection,
        # activation) so parameter / state_dict ordering is unchanged.
        self.norm = nn.GroupNorm(num_groups=gn_groups, num_channels=c_out)
        self.projection = WeightStandardizedConv(
            c_in, c_out, kernel_size=3, padding=1, bias=False
        )
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor, scale_shift: torch.Tensor | None = None) -> torch.Tensor:
        """
            Apply conv -> norm -> (optional scale/shift) -> SiLU.

            Args:
                x: input feature map fed to the projection conv.
                scale_shift: optional tensor split in half along dim 1 into (scale, shift);
                    presumably shaped (B, 2*c_out, 1, 1) so each half broadcasts over the
                    spatial dims — confirm against the caller.

            Returns:
                The activated feature map.
        """
        h = self.norm(self.projection(x))
        if scale_shift is None:
            return self.activation(h)
        halves = scale_shift.chunk(chunks=2, dim=1)
        scale, shift = (part.to(h.device) for part in halves)
        return self.activation(h * (scale + 1) + shift)

In [None]:
class resnet(nn.Module):
    """ 
        resnet block that applies resitual connection to multiple regular blocks
    """