In [1]:
import torch
from torch import nn
from helper import supress_tracer_warnings, assert_shape, is_list_of_str, normalise_2nd_moment

from typing import Optional, Any, List
import numpy as np

from shared import FullyConnectedLayers

In [None]:
import sys
import os

# Make the project's ``torch_utils`` package (and its ``ops`` sub-package), which
# live one directory above the notebook's working directory, importable.
# ``os.path.dirname`` keeps the path separators intact; joining the split parts
# with an empty string would collapse "/a/b/c" into "ab".
cur_path = os.path.dirname(os.getcwd())
for _extra_path in (os.path.join(cur_path, 'torch_utils', 'ops'),
                    os.path.join(cur_path, 'torch_utils')):
    # Guard keeps the cell idempotent: re-running it does not grow sys.path.
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)

In [3]:
import conv2d_resample

In [4]:
def modulated_conv2d(x: torch.Tensor, weight: torch.Tensor, styles: torch.Tensor, noise: Optional[torch.Tensor] = None,
                     up: int = 1, down: int = 1, padding: int = 0, resample_filter: Optional[List[int]] = None,
                     demodulate: bool = True, flip_weight: bool = True, fused_mod_cov: bool = True) -> torch.Tensor:
    """Style-modulated 2D convolution with optional weight demodulation.

    Each sample's convolution weights are scaled per input channel by its style
    vector and, when ``demodulate`` is set, rescaled so every output channel's
    modulated weights have unit L2 norm.

    Args:
        x:               Input activations, ``[B, inC, H, W]``.
        weight:          Convolution weights, ``[outC, inC, kh, kw]``.
        styles:          Per-sample, per-input-channel scales, ``[B, inC]``.
        noise:           Optional tensor added to the output (must broadcast to it).
        up, down, padding, resample_filter, flip_weight:
                         Forwarded unchanged to ``conv2d_resample.conv2d_resample``.
        demodulate:      Apply weight demodulation.
        fused_mod_cov:   If True, fold modulation/demodulation into the weights and
                         run one grouped convolution over the whole batch; if False,
                         scale the activations before and after a shared-weight
                         convolution instead. Both give the same result up to
                         floating-point rounding.

    Returns:
        Output activations, ``[B, outC, H', W']``.
    """
    # x:      [B, inC, H, W]
    # weight: [outC, inC, kh, kw]
    # styles: [B, inC]

    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    assert_shape(weight, [out_channels, in_channels, kh, kw])
    assert_shape(x, [batch_size, in_channels, None, None])    # x's & weight's batch_size and In channels must remain same
    assert_shape(styles, [batch_size, in_channels])

    # Pre-normalise weights and styles by their max-abs value so the squared sums
    # used for demodulation stay within float16 range. Demodulation cancels the
    # rescaling, so the result is unchanged.
    if x.dtype == torch.float16 and demodulate:
        a = float(1 / np.sqrt(in_channels * kh * kw))
        b = weight.norm(p=float("inf"), dim=[1, 2, 3], keepdim=True)          # max of inC, kh, kw
        weight = weight * (a / b)

        styles = styles / styles.norm(p=float("inf"), dim=1, keepdim=True)    # max of inC

    # Per-sample modulated weights and demodulation coefficients.
    w = None
    dcoef = None
    if demodulate or fused_mod_cov:
        w = weight.unsqueeze(0)                                    # w:      [1, outC, inC, kh, kw]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)            # styles: [B,  1,   inC, 1,  1]
                                                                   # w:      [B, outC, inC, kh, kw]
    if demodulate:
        dcoef = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()     # dcoef:  [B, outC]

    if demodulate and fused_mod_cov:
        w = w * dcoef.reshape(batch_size, -1, 1, 1, 1)             # w:      [B, outC, inC, kh, kw]

    # Non-fused: scale the activations before and after a shared-weight convolution.
    if not fused_mod_cov:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)   # styles: [B, inC, 1, 1]
        x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down,
                                            padding=padding, flip_weight=flip_weight)
        if demodulate:
            x = x * dcoef.to(x.dtype).reshape(batch_size, -1, 1, 1)  # dcoef: [B, outC, 1, 1]
        if noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Fused: fold the batch into the channel axis and use one group per sample so
    # every sample is convolved with its own modulated weights.
    # NOTE(review): relies on conv2d_resample accepting a ``groups`` argument, as
    # the StyleGAN reference implementation does - confirm against the local module.
    x = x.reshape(1, -1, *x.shape[2:])                             # x: [1, B*inC, H, W]
    w = w.reshape(-1, in_channels, kh, kw)                         # w: [B*outC, inC, kh, kw]
    x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down,
                                        padding=padding, groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])                    # x: [B, outC, H', W']
    if noise is not None:
        x = x.add_(noise.to(x.dtype))
    return x

In [5]:
# Smoke test of the non-fused path: batch of 5, 3 -> 6 channels, 5x5 kernel,
# 3x upsampling. Only the resulting shape is displayed.
modulated_conv2d(torch.rand(5, 3, 10, 10), torch.rand(6, 3, 5, 5), torch.rand(5, 3), up=3, fused_mod_cov=False).shape

torch.Size([5, 6, 26, 26])

In [6]:
class GroupNorm_float32(nn.GroupNorm):
    """GroupNorm that always computes in float32 and returns the caller's dtype."""

    def forward(self, x: torch.Tensor):
        # Remember the incoming dtype so the result can be cast back afterwards.
        input_dtype = x.dtype
        # Normalise in float32, then restore the original dtype.
        normalised = super().forward(x.to(torch.float32))
        return normalised.to(input_dtype)

# Check that a half-precision input comes back as half precision.
x = torch.randn(8, 32, 64, 64).to(torch.float16)  # float16 tensor

gn = GroupNorm_float32(8, 32)  # 8 groups over 32 channels
y = gn(x)
y.dtype

torch.float16

In [7]:
class StyleSplit(nn.Module):
    """Projects the input to three equally sized chunks and combines them as ``m1 * m2 + m3``.

    A single ``FullyConnectedLayers`` projection of width ``3 * out_channels``
    produces all three terms; extra keyword arguments are forwarded to it.
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs):
        super().__init__()

        # One projection yields the three terms at once.
        projection_width = 3 * out_channels
        self.fcl = FullyConnectedLayers(in_features=in_channels, out_features=projection_width, **kwargs)

    def forward(self, x: torch.Tensor):
        projected = self.fcl(x)
        # Split along the feature axis into three chunks.
        first, second, offset = torch.chunk(projected, 3, dim=1)
        return torch.mul(first, second) + offset

In [8]:
# Shape check: 10 samples with 2 features each should map to 5 features each.
ss = StyleSplit(2, 5)

x = torch.rand(10, 2)
ss(x).shape

torch.Size([10, 5])

In [None]:
class SynthesisInput(nn.Module):
    """
    Latent w → SynthesisInput → Feature map [B, 64, 64, 64]

    sin(2pi (f . x + phi))

    Where:
      • f   = 2D frequency vector, one per channel (drawn at random at construction
              and stored in the ``freqs`` buffer)
      • x   = pixel location
      • phi = phase offset (learned)
      • All of this is modulated by latent vector w through an affine transformation

    NOTE(review): only the frequency initialisation is implemented so far; the
    phases, the affine layer and ``forward`` are still to be written. ``freqs``
    is registered as a buffer here - switch to ``nn.Parameter`` if the
    frequencies are meant to be trained directly.

    Args:
        w_dim:          Stored as ``self.w_dim``.
        channels:       Number of frequency vectors to draw (rows of ``freqs``).
        size:           Scalar or pair, broadcast to a length-2 array.
        sampling_rate:  Stored as ``self.sampling_rate``.
        bandwidth:      Radius of the disc the frequencies are drawn from.
    """

    def __init__(self, w_dim: int, channels: int, size: int, sampling_rate: int, bandwidth: int):
        super().__init__()

        self.w_dim = w_dim
        self.channels = channels
        self.size = np.broadcast_to(size, [2])
        self.sampling_rate = sampling_rate
        self.bandwidth = bandwidth

        # Draw random frequencies from uniform 2D disc.
        # Start from a Gaussian cloud (randn, not rand: rand is uniform on [0, 1)
        # and would only cover the positive quadrant), one row per channel.
        freqs = torch.randn([self.channels, 2])                     # Gaussian cloud in 2D space
        radii = freqs.square().sum(dim=1, keepdim=True).sqrt()      #  $r = \sqrt{x^2 + y^2}$
        # Dividing by r * exp(r^2 / 4) maps each point to radius exp(-r^2 / 4),
        # which turns the 2D Gaussian into a uniform distribution on the unit disc.
        freqs /= radii * radii.square().exp().pow(0.25)
        freqs *= bandwidth                                          # disc of radius `bandwidth`
        self.register_buffer('freqs', freqs)
        



In [49]:
# Scratch check: draw three samples; torch.rand samples uniformly from [0, 1).
torch.rand([3])

tensor([0.0120, 0.0989, 0.9029])

In [None]:
freqs = torch.rand([3, 2]) # 3 points in 2D space. NOTE: torch.rand is uniform on [0, 1); a Gaussian cloud would need torch.randn
freqs

tensor([[0.4917, 0.4293],
        [0.2333, 0.6255],
        [0.3529, 0.9713]])

In [31]:
# Euclidean norm of each row of freqs; keepdim=True keeps a trailing axis of size 1
# so the result broadcasts against freqs.
radii = freqs.square().sum(dim = 1, keepdim=True).sqrt()
radii

tensor([[0.6527],
        [0.6676],
        [1.0335]])

In [32]:
# Element-wise exp(r^2) ** 0.25, i.e. exp(r^2 / 4), for each radius r.
radii.square().exp().pow(0.25)

tensor([[1.1124],
        [1.1179],
        [1.3061]])