In [1]:
import torch
from torch import nn 
import torch.nn.functional as F

StyleGAN is a well-established GAN architecture that can generate highly realistic images. It relies on several specific architectural tricks; for example, some of them make the model pay attention to image style. The basic building block of StyleGAN (in the form introduced by StyleGAN2) is style modulation: it produces a unique set of convolution filters for each element of the data batch. Here is how a style-modulated convolution block can be implemented:

In [2]:
# A classic nn.Conv2d applies the same filters to every image in a batch, but in StyleGAN we want each element of the batch to have its own unique filters. So instead of nn.Conv2d, this block calls F.conv2d directly with a randomly initialized, learnable weight tensor that is updated during training.

class ModulatedCond2d(nn.Module):
    """Style-modulated 2D convolution (StyleGAN2-style weight modulation/demodulation).

    One learnable weight tensor is shared by the whole batch. In ``forward`` it is
    scaled per sample by a style vector (modulation), renormalized per output filter
    (demodulation), and applied with a grouped convolution, so every sample in the
    batch is convolved with its own filters.

    Args:
        in_channels: number of channels of the input feature map ``x``.
        out_channels: number of channels produced by the convolution.
        kernel_size: spatial size of the square convolution kernel.
        cond_dim: size of the style/conditioning vector ``y``.
        verbose: if True (default, kept for the shape walkthrough in this notebook),
            print intermediate tensor shapes on every forward pass. Set it to False
            when the layer is used inside a real model.

    Note:
        The class name is a typo for "Conv2d"; ``ModulatedConv2d`` is defined right
        after the class as a correctly spelled alias.
    """

    def __init__(self, in_channels, out_channels, kernel_size, cond_dim, verbose=True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # Keeps rsqrt finite if a filter's modulated weights are all zero.
        self.eps = 1e-8
        self.verbose = verbose

        # Not a layer, but raw convolution WEIGHTS, learned during training and
        # shared by the whole batch: [1, out_channels, in_channels, k, k].
        # The leading singleton dimension lets them broadcast against a per-sample
        # style of shape [B, 1, in_channels, 1, 1] in forward().
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
        )

        # Affine map from the conditioning vector to one scale per input channel.
        self.style = nn.Linear(cond_dim, in_channels)

    def _log(self, message):
        # Debug shape printing, enabled only when ``verbose`` is True.
        if self.verbose:
            print(message)

    def forward(self, x, y):
        """Apply the per-sample modulated convolution.

        Args:
            x: input feature map of shape [B, in_channels, H, W].
            y: style/conditioning vectors of shape [B, cond_dim] (in StyleGAN this is
               the output of the mapping network, a stack of fully connected layers).

        Returns:
            Tensor of shape [B, out_channels, H_out, W_out]. With
            ``padding = kernel_size // 2`` the spatial size is preserved for odd
            kernel sizes; even kernel sizes give H + 1 and W + 1.

        Raises:
            ValueError: if ``x`` does not have ``in_channels`` channels, or if
                ``x`` and ``y`` have different batch sizes.
        """
        B, C, H, W = x.shape
        if C != self.in_channels:
            raise ValueError(f"expected x with {self.in_channels} channels, got {C}")
        if y.shape[0] != B:
            raise ValueError(
                f"batch size mismatch: x has {B} samples, y has {y.shape[0]}"
            )

        # Modulation: one scale per (sample, input channel), broadcast over
        # output channels and kernel positions.
        style = self.style(y).view(B, 1, C, 1, 1)
        self._log(f"Style shape: {style.shape}")

        weight = self.weight * style  # [B, out_channels, in_channels, k, k]
        self._log(f"Weight shape 1: {weight.shape}")

        # Demodulation: summing over dims [2, 3, 4] (in_channels, k, k) gives one
        # squared L2 norm per (sample, output filter); rsqrt rescales each filter
        # to unit norm.
        demod = torch.rsqrt((weight ** 2).sum([2, 3, 4]) + self.eps)  # [B, out_channels]
        self._log(f"Demodulation shape: {demod.shape}")

        weight = weight * demod.view(B, self.out_channels, 1, 1, 1)
        self._log(f"Weight shape 2: {weight.shape}")

        # Grouped convolution with groups=B: fold the batch into the channel axis
        # so input channels [b*C, (b+1)*C) are convolved only with the filters
        # [b*out_channels, (b+1)*out_channels) of sample b, with no cross-sample mixing.
        # reshape (not view) also accepts non-contiguous inputs, e.g. permuted tensors.
        x = x.reshape(1, B * C, H, W)
        weight = weight.reshape(
            B * self.out_channels, C, self.kernel_size, self.kernel_size
        )
        self._log(f"Weight shape 3: {weight.shape}")
        self._log(f"X shape: {x.shape}")

        out = F.conv2d(x, weight, padding=self.kernel_size // 2, groups=B)

        # Use the conv's actual spatial size: for even kernel sizes the symmetric
        # padding grows the map by one pixel, so reusing (H, W) would fail.
        return out.view(B, self.out_channels, *out.shape[-2:])


# Correctly spelled alias; the original (misspelled) name is kept for compatibility.
ModulatedConv2d = ModulatedCond2d

In [3]:
# Single-sample batch: a 512-channel 16x16 feature map and a 26-dim style vector,
# mapped to 256 output channels by a 3x3 modulated convolution.
tensor1 = torch.randn(1, 512, 16, 16)
tensor2 = torch.randn(1, 26)

block = ModulatedCond2d(512, 256, 3, 26)

out = block(tensor1, tensor2)
# padding = kernel_size // 2 keeps the 16x16 spatial size for an odd kernel.
assert out.shape == (1, 256, 16, 16)
out.shape

Style shape: torch.Size([1, 1, 512, 1, 1])
Weight shape 1: torch.Size([1, 256, 512, 3, 3])
Demodulation shape: torch.Size([1, 256])
Weight shape 2: torch.Size([1, 256, 512, 3, 3])
Weight shape 3: torch.Size([256, 512, 3, 3])
X shape: torch.Size([1, 512, 16, 16])


torch.Size([1, 256, 16, 16])

In [4]:
# Batch of two: each sample gets its own modulated filters, applied in a single
# grouped convolution (groups=2) inside the block.
tensor1 = torch.randn(2, 512, 16, 16)
tensor2 = torch.randn(2, 26)

block = ModulatedCond2d(512, 256, 3, 26)

out = block(tensor1, tensor2)
# The batch dimension is restored after the grouped conv; spatial size is preserved.
assert out.shape == (2, 256, 16, 16)
out.shape

Style shape: torch.Size([2, 1, 512, 1, 1])
Weight shape 1: torch.Size([2, 256, 512, 3, 3])
Demodulation shape: torch.Size([2, 256])
Weight shape 2: torch.Size([2, 256, 512, 3, 3])
Weight shape 3: torch.Size([512, 512, 3, 3])
X shape: torch.Size([1, 1024, 16, 16])


torch.Size([2, 256, 16, 16])