In [1]:
import torch
import torch.nn as  nn
from models.involution import * 
from models.unet import * 
from models.unet_long import *

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the model with 3 input/output channels and the listed feature widths,
# then move it to the selected device.
unet_long = UNET_Long(
    in_channels=3,
    out_channels=3,
    features=[32, 64, 128, 256],
    device=device,
)
unet_long.to(device)

# Smoke test: push one random 3x256x256 input through the model and display
# the shape of what comes back.
x = torch.randn(1, 3, 256, 256)
x = x.to(device)
unet_long(x).shape

Encoder
torch.Size([1, 32, 256, 256])
Involution Block: torch.Size([1, 32, 256, 256])
Downsample: torch.Size([1, 64, 129, 129])
Involution Block: torch.Size([1, 64, 129, 129])
Downsample: torch.Size([1, 128, 66, 66])
Involution Block: torch.Size([1, 128, 66, 66])
Downsample: torch.Size([1, 256, 34, 34])
Involution Block: torch.Size([1, 256, 34, 34])
After Bottleneck:  torch.Size([1, 512, 34, 34])
Decoder
Upsample:  torch.Size([1, 256, 68, 68])
Involution Block:  torch.Size([1, 256, 34, 34])
Upsample:  torch.Size([1, 128, 68, 68])
Involution Block:  torch.Size([1, 128, 66, 66])
Upsample:  torch.Size([1, 64, 132, 132])
Involution Block:  torch.Size([1, 64, 129, 129])
Upsample:  torch.Size([1, 32, 258, 258])
Involution Block:  torch.Size([1, 32, 256, 256])


torch.Size([1, 3, 256, 256])

In [8]:
import torch
import torch.nn.functional as F

# Random (2, 3, 4) tensor used to illustrate F.normalize.
x = torch.randn(2, 3, 4)

# L2-normalise along dim=1; the result keeps the input's shape.
normalized_x = F.normalize(x, p=2.0, dim=1)
normalized_x.shape

torch.Size([2, 3, 4])

In [9]:
# L2 norm taken over the last dimension, which is reduced away in the result.
norms = torch.norm(x, p=2, dim=-1)
norms.shape

torch.Size([2, 3])

### LayerNorm

In [13]:
# Small 2x4 float32 tensor for checking how `None` indexing inserts axes.
a = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7]]).float()

# Two singleton axes are inserted right after the first dimension.
a[:, None, None, :].shape

torch.Size([2, 1, 1, 4])

In [14]:
class LayerNorm(nn.Module):
    """LayerNorm over the channel axis for two data layouts.

    ``channels_last`` (default) expects the channels on the last axis, e.g.
    (batch_size, height, width, channels), and delegates to ``F.layer_norm``.
    ``channels_first`` expects the channels on axis 1, e.g.
    (batch_size, channels, height, width), and normalises over that axis
    by hand. Any number of trailing dimensions after the channel axis is
    accepted in ``channels_first`` mode.

    Args:
        normalized_shape: Number of channels (an int). It sizes the learnable
            ``weight`` and ``bias`` and is wrapped into a 1-tuple for
            ``F.layer_norm``.
        eps: Value added to the variance before the square root.
        data_format: Either ``"channels_last"`` or ``"channels_first"``.

    Raises:
        NotImplementedError: If ``data_format`` is not one of the two
            supported values.
    """
    def __init__(self, normalized_shape, eps = 1e-6, data_format = "channels_last"):
        super().__init__()

        # Validate before allocating parameters so a bad layout fails fast
        # and with a message that names the offending value.
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(
                f"Unsupported data_format {data_format!r}; "
                "expected 'channels_last' or 'channels_first'."
            )

        # Learnable per-channel scale and shift, initialised to the identity.
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format

        # F.layer_norm wants the normalised shape as a sequence.
        self.normalized_shape = (normalized_shape, )


    def forward(self, x):
        """Normalise ``x`` over its channel axis and apply the affine transform.

        Returns a tensor with the same shape as ``x``.
        """
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight,
                                self.bias, self.eps)

        # channels_first: mean and (biased) variance over axis 1.
        u = x.mean(1, keepdim = True)
        s = (x - u).pow(2).mean(1, keepdim = True)
        x = (x - u) / torch.sqrt(s + self.eps)

        # Reshape the per-channel parameters to (1, C, 1, ..., 1) so they
        # broadcast against axis 1 whatever the number of trailing dims is.
        # For 4-D input this is equivalent to weight[:, None, None].
        affine_shape = (1, -1) + (1, ) * (x.dim() - 2)
        return self.weight.view(affine_shape) * x + self.bias.view(affine_shape)

In [None]:
# Random channels-first input: (batch, channels, height, width).
x = torch.randn(1, 3, 4, 4)

# Instantiate the layer for the 3 channels on axis 1 of x; the original
# `layernorm = layernorm` raised NameError because the name was never defined.
layernorm = LayerNorm(3, data_format = "channels_first")
layernorm(x).shape