In [2]:
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.parameter import Parameter
import torchvision
import cv2 as cv
from typing import Optional

**Computer Vision Modules to Minimize Chance of Exploding / Vanishing Gradients**
---

In [3]:
# Equalized / Scaled 2D Convolution module
class EqualizedLR_Conv2d(nn.Module):
    """
    2D convolution with a runtime weight scale ("equalized learning rate").

    The weight is drawn from a standard normal distribution and the bias is
    zero-initialised. On every forward pass the weight is multiplied by the
    constant ``sqrt(2 / (in_ch * kernel_h * kernel_w))`` before being handed
    to ``F.conv2d``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: either an ``(height, width)`` pair or a single int, which
            is treated as a square kernel.
        stride: stride forwarded unchanged to ``F.conv2d``.
        padding: padding forwarded unchanged to ``F.conv2d``.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0):
        super().__init__()
        # Accept a bare int for square kernels, mirroring how stride/padding
        # may already be given as ints.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        kernel_h, kernel_w = kernel_size

        self.padding = padding
        self.stride = stride
        # Constant multiplier applied to the weight at call time (not baked
        # into the stored parameter).
        self.scale = np.sqrt(2 / (in_ch * kernel_h * kernel_w))

        # T.empty allocates uninitialised storage; both parameters are
        # explicitly initialised right below.
        self.weight = Parameter(T.empty(out_ch, in_ch, kernel_h, kernel_w))
        self.bias = Parameter(T.empty(out_ch))

        nn.init.normal_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Convolve ``x`` with the scaled weight and the bias."""
        return F.conv2d(
            x, self.weight * self.scale, self.bias, self.stride, self.padding
        )

In [None]:
# Pixel-wise normalization of image
class Pixel_norm(nn.Module):
    """
    Pixel-wise normalization.

    Each element of the input is divided by the square root of the sum of
    squares taken along dim 1 (kept as a size-1 dim for broadcasting), with a
    small constant added under the root to avoid division by zero.
    """

    def __init__(self):
        super().__init__()

    def forward(self, a):
        """Return ``a`` normalised along dim 1; output has the shape of ``a``."""
        # NOTE(review): this reduces with a sum, not a mean, over dim 1 --
        # confirm that is the intended normalisation.
        squared_sum = a.pow(2).sum(dim=1, keepdim=True)
        # 10e-8 is numerically 1e-7.
        denom = (squared_sum + 10e-8).sqrt()
        return a / denom

In [None]:
# batch-wise standardization
class Minibatch_std(nn.Module):
    """
    Append a minibatch standard-deviation statistic as one extra channel.

    The standard deviation of the input is computed across dim 0, averaged to
    a single scalar, and that scalar is broadcast into a map that has the
    shape of the input with dim 1 set to 1. The map is concatenated to the
    input along dim 1, so the output has one more channel than the input.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with the scalar batch-std map concatenated on dim 1."""
        batch = x.size(0)
        if batch > 1:
            stat = T.std(x, dim=0).mean()
        else:
            # T.std is the unbiased estimator and is NaN for fewer than two
            # samples; a NaN channel would poison everything downstream, so
            # report "no variation" instead.
            stat = x.new_zeros(())

        # Same shape as x except for a single channel.
        stat_map = stat.expand(batch, 1, *x.shape[2:])
        return T.cat((x, stat_map), dim=1)

In [None]:
# Learned 1x1-conv projection from in_c input channels (e.g. an RGB image) to out_c channels, followed by LeakyReLU
class fromRGB(nn.Module):
    """
    Learned channel projection followed by a leaky ReLU.

    Applies a 1x1, stride-1 ``EqualizedLR_Conv2d`` that maps ``in_c`` input
    channels to ``out_c`` output channels, then ``LeakyReLU(0.2)``.

    Args:
        in_c: number of channels of the incoming tensor.
        out_c: number of channels produced.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.cvt = EqualizedLR_Conv2d(in_c, out_c, kernel_size=(1, 1), stride=(1, 1))
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        """Project ``x`` through the 1x1 convolution and activate it."""
        return self.relu(self.cvt(x))

# Learned 1x1-conv projection from in_c feature channels to out_c output channels (e.g. an RGB image), no activation
class toRGB(nn.Module):
    """
    Learned channel projection with no activation.

    Applies a 1x1, stride-1 ``EqualizedLR_Conv2d`` that maps ``in_c`` input
    channels to ``out_c`` output channels and returns the result as is.

    Args:
        in_c: number of channels of the incoming tensor.
        out_c: number of channels produced.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.cvt = EqualizedLR_Conv2d(in_c, out_c, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x):
        """Project ``x`` through the 1x1 convolution."""
        projected = self.cvt(x)
        return projected