In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.utils import save_image
import torchvision
import numpy as np
import matplotlib.pylab as plt
import os
import cv2
import math
from math import log2
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!nvidia-smi

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(DEVICE)

Thu Jan 23 17:38:02 2025       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 522.06       Driver Version: 522.06       CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000   WDDM  | 00000000:01:00.0  On |                  Off |
| 67%   84C    P2   266W / 300W |  26063MiB / 49140MiB |     99%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
# Basic run settings
IMG_CHANNELS = 1  # passed to Generator as img_channels (output channels of the to-RGB convs)
STEP = 6  # number of progressive blocks applied; each doubles the 4x4 constant, so 6 -> 256x256
FIRST_EPOCH = 10  # first checkpoint epoch to sample from
LAST_EPOCH = 160  # last checkpoint epoch to sample from
EPOCH_INTERVAL = 10  # spacing between the checkpoint epochs that are sampled
NUM_IMAGES = 2182  # images generated per checkpoint

# Optional architecture settings
# NOTE(review): presumably these must match the values the checkpoints were trained with -- verify.
Z_DIM = 512  # size of the latent vector fed to the mapping network
W_DIM = 512  # size of the style vector produced by the mapping network
IN_CHANNELS = 512  # channel count of the learned 4x4 constant
# Channel multipliers relative to IN_CHANNELS; consecutive pairs give the
# (in, out) widths of successive generator blocks.
factors = [1,1,1,1/2,1/4,1/8,1/16,1/32]

In [4]:
class EqualizedLinear(nn.Module):
    """Linear layer whose output is multiplied by a fixed constant.

    The wrapped ``nn.Linear`` gets a normally distributed weight with standard
    deviation ``1 / sqrt(in_dim)`` and, when present, a zeroed bias.  The
    forward pass multiplies the complete linear output (bias term included)
    by ``sqrt(2 / in_dim)``.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        bias: Whether the wrapped linear layer has a bias term.
    """

    def __init__(self, in_dim, out_dim, bias=True):
        super().__init__()
        layer = nn.Linear(in_dim, out_dim, bias=bias)
        nn.init.normal_(layer.weight, mean=0.0, std=1.0 / np.sqrt(in_dim))
        if bias:
            nn.init.zeros_(layer.bias)
        # Registered under the name ``linear`` so checkpoint keys stay
        # ``linear.weight`` / ``linear.bias``.
        self.linear = layer
        self.scale = (2 / in_dim) ** 0.5

    def forward(self, x):
        """Return ``linear(x)`` multiplied by the fixed scale."""
        projected = self.linear(x)
        return projected * self.scale

In [5]:
class PixelNorm(nn.Module):
    """Divide the input by the root-mean-square of its values along dim 1.

    Args:
        epsilon: Small constant added to the mean square before the square
            root is taken.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2, dim=1) + epsilon)``."""
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        rms = torch.sqrt(mean_square + self.epsilon)
        return x / rms

In [6]:
class MappingNetwork(nn.Module):
    """Maps a latent vector ``z`` to a style vector ``w``.

    The stack is a ``PixelNorm`` followed by eight ``EqualizedLinear`` layers,
    each followed by ``LeakyReLU(0.2)`` (including the last one).

    Args:
        z_dim: Size of the input latent vector.
        w_dim: Size of every hidden layer and of the output.
    """

    def __init__(self, z_dim, w_dim):
        super().__init__()
        # First layer maps z_dim -> w_dim; the remaining seven are w_dim -> w_dim.
        layers = [PixelNorm(), EqualizedLinear(z_dim, w_dim), nn.LeakyReLU(0.2)]
        for _ in range(7):
            layers.append(EqualizedLinear(w_dim, w_dim))
            layers.append(nn.LeakyReLU(0.2))
        # Positions inside the Sequential (linear layers at odd indices 1..15)
        # determine the checkpoint key names, so the ordering must not change.
        self.mapping = nn.Sequential(*layers)

    def forward(self, z):
        """Run ``z`` through the normalisation and the linear stack."""
        return self.mapping(z)

In [7]:
class AdaIN(nn.Module):
    """Instance-normalise a feature map, then scale and shift it per channel.

    The per-channel scale and shift are produced from ``w`` by two separate
    ``nn.Linear`` layers.

    Args:
        channels: Number of feature-map channels.
        w_dim: Size of the style vector ``w``.
    """

    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = nn.Linear(w_dim, channels)
        self.style_bias = nn.Linear(w_dim, channels)

    def forward(self, x, w):
        """Return ``scale(w) * instance_norm(x) + bias(w)``.

        The linear outputs get two trailing singleton dimensions so that they
        broadcast over the last two dimensions of ``x``.
        """
        normalized = self.instance_norm(x)
        scale = self.style_scale(w)[:, :, None, None]
        shift = self.style_bias(w)[:, :, None, None]
        return scale * normalized + shift

In [8]:
class EqualizedConv2d(nn.Module):
    """2-D convolution whose weight is rescaled on every forward pass.

    The weight is stored as standard-normal samples and multiplied by
    ``1 / sqrt(in_channels * kernel_size**2)`` each time the layer is applied.
    The bias, when enabled, starts at zero and is not rescaled.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Side length of the square kernel.
        stride: Convolution stride.
        padding: Padding passed to the convolution.
        bias: Whether to learn a per-channel bias.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1,bias = True
    ):
        super().__init__()
        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(weight_shape))
        self.scale = 1 / math.sqrt(in_channels * kernel_size * kernel_size)
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None
        self.stride = stride
        self.padding = padding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with the rescaled weight."""
        scaled_weight = self.weight * self.scale
        return F.conv2d(x, scaled_weight, self.bias, self.stride, self.padding)

In [9]:
class WSConv2d(nn.Module):
    """``nn.Conv2d`` applied to a rescaled input, with the bias added afterwards.

    The input is multiplied by ``sqrt(2 / (in_channels * kernel_size**2))``
    before the convolution.  The convolution's own bias is moved onto this
    module so that it is added after the convolution instead of inside it.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Side length of the square kernel.
        stride: Convolution stride.
        padding: Padding passed to the convolution.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
    ):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        fan_in = in_channels * (kernel_size ** 2)
        self.scale = (2 / fan_in) ** 0.5

        # Take ownership of the bias parameter and remove it from the conv.
        self.bias = self.conv.bias
        self.conv.bias = None

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Return ``conv(x * scale) + bias`` with the bias broadcast per channel."""
        scaled_input = x * self.scale
        channel_bias = self.bias.view(1, -1, 1, 1)
        return self.conv(scaled_input) + channel_bias

In [10]:
class NoiseInjection(nn.Module):
    """Add per-channel weighted noise to a feature map.

    Args:
        channels: Number of feature-map channels; one weight is learned per
            channel and initialised to zero.
    """

    def __init__(self, channels):
        super(NoiseInjection, self).__init__()
        self.weight = nn.Parameter(torch.zeros(1,channels,1,1))

    def forward(self, x, noise=None):
        """Return ``x + weight * noise``.

        Args:
            x: 4-D input; dims 0, 2 and 3 size the sampled noise.
            noise: Optional noise tensor broadcastable against ``x``.  When
                omitted, fresh standard-normal noise with a single channel is
                sampled on ``x``'s device.
        """
        # Previously a caller-supplied ``noise`` was silently overwritten;
        # only sample when none was passed in.
        if noise is None:
            noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device = x.device)
        return x + self.weight * noise

In [11]:
class GeneratorBlock(nn.Module):
    """Two rounds of conv -> noise injection -> LeakyReLU -> AdaIN.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        w_dim: Size of the style vector consumed by the AdaIN layers.
    """

    def __init__(self, in_channels, out_channels, w_dim):
        super().__init__()
        self.conv1 = EqualizedConv2d(in_channels, out_channels)
        self.conv2 = EqualizedConv2d(out_channels, out_channels)
        self.activation = nn.LeakyReLU(0.2, inplace=True)
        self.noise_injection1 = NoiseInjection(out_channels)
        self.noise_injection2 = NoiseInjection(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        """Apply both stages to ``x``, styling each with the same ``w``."""
        # Stage 1
        x = self.conv1(x)
        x = self.noise_injection1(x)
        x = self.activation(x)
        x = self.adain1(x, w)
        # Stage 2
        x = self.conv2(x)
        x = self.noise_injection2(x)
        x = self.activation(x)
        x = self.adain2(x, w)
        return x

In [12]:
class Generator(nn.Module):
    """Progressive, style-modulated image generator.

    Starts from a learned 4x4 constant and, for each requested step, doubles
    the spatial size with bilinear interpolation and applies one
    ``GeneratorBlock``.  Block widths come from the module-level ``factors``
    list (multipliers of ``in_channels``).

    Args:
        z_dim: Size of the latent vector fed to the mapping network.
        w_dim: Size of the style vector.
        in_channels: Channel count of the learned constant.
        img_channels: Channel count of the generated image.
    """

    def __init__(self, z_dim, w_dim, in_channels, img_channels):
        super().__init__()
        self.constant = nn.Parameter(torch.ones(1, in_channels, 4, 4))
        self.mapping_network = MappingNetwork(z_dim, w_dim)
        self.first_adain1 = AdaIN(in_channels, w_dim)
        self.first_adain2 = AdaIN(in_channels, w_dim)
        self.first_noise1 = NoiseInjection(in_channels)
        self.first_noise2 = NoiseInjection(in_channels)
        self.first_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.activation = nn.LeakyReLU(0.2, inplace=True)

        self.first_torgb = EqualizedConv2d(in_channels, img_channels, kernel_size=1, stride=1, padding=0)

        # to_rgbs[0] is the same module object as first_torgb.
        self.prog_blocks = nn.ModuleList([])
        self.to_rgbs = nn.ModuleList([self.first_torgb])

        # Each consecutive pair of factors gives one block's (in, out) widths;
        # every block gets its own 1x1 to-RGB convolution.
        for factor_in, factor_out in zip(factors[:-1], factors[1:]):
            block_in = int(in_channels * factor_in)
            block_out = int(in_channels * factor_out)
            self.prog_blocks.append(GeneratorBlock(block_in, block_out, w_dim))
            self.to_rgbs.append(
                EqualizedConv2d(block_out, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        """Blend the two images with weight ``alpha`` and squash with tanh."""
        blended = alpha * generated + (1 - alpha) * upscaled
        return torch.tanh(blended)

    def forward(self, z, alpha, steps):
        """Generate an image from latent ``z``.

        Args:
            z: Latent input handed to the mapping network.
            alpha: Blend weight; 1.0 uses only the newest block's image.
            steps: Number of progressive blocks to apply.

        Returns:
            For ``steps >= 1``, the tanh-squashed blend of the upscaled and
            newly generated images.  For ``steps == 0``, the raw to-RGB
            output (no tanh).
        """
        w = self.mapping_network(z)

        x = self.first_noise1(self.constant)
        x = self.first_adain1(x, w)
        x = self.first_conv(x)
        out = self.first_adain2(self.activation(self.first_noise2(x)), w)

        if steps == 0:
            # NOTE(review): this converts ``x`` (the plain conv output), not
            # ``out``, and skips tanh.  Left as-is because checkpoints may
            # have been trained against this exact forward pass -- confirm.
            return self.first_torgb(x)

        for block_index in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode='bilinear')
            out = self.prog_blocks[block_index](upscaled, w)

        # Image from the previous block's features (upscaled) vs. the newest block.
        rgb_from_upscaled = self.to_rgbs[steps - 1](upscaled)
        rgb_from_block = self.to_rgbs[steps](out)

        return self.fade_in(alpha, rgb_from_upscaled, rgb_from_block)

In [13]:
def generate_from_checkpoint(checkpoint_path, step, epoch, num_images=10):
    """Load generator weights and write sampled images to disk.

    Images are written to ``synthesis_images/step{step}/epoch{epoch}/`` as
    ``synthesis_image_{i}.tif``, one latent sample per image.

    Args:
        checkpoint_path: Path to a saved generator ``state_dict``.
        step: Number of progressive blocks to run; also part of the output path.
        epoch: Epoch label used only to name the output directory.
        num_images: How many images to generate.
    """
    output_dir = f"synthesis_images/step{step}/epoch{epoch}"

    generator = Generator(Z_DIM, W_DIM, IN_CHANNELS, IMG_CHANNELS).to(DEVICE)
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location=DEVICE)
    generator.load_state_dict(state_dict)
    generator.eval()

    # alpha = 1 makes fade_in use only the newest block's image.
    alpha = 1.0
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for index in range(num_images):
            latent = torch.randn(1, Z_DIM).to(DEVICE)
            image = generator(latent, alpha, step)
            # Shift the tanh range [-1, 1] to [0, 1] before saving.
            save_image(image * 0.5 + 0.5, f"{output_dir}/synthesis_image_{index}.tif")
    print(f"Saved {num_images} images to 'synthesis_images/'")


def main(epoch):
    """Sample NUM_IMAGES images from the STEP checkpoint saved at ``epoch``."""
    checkpoint_path = f"./saved_models/step{STEP}/generator_step{STEP}_epoch{epoch}.pth"
    generate_from_checkpoint(checkpoint_path, STEP, epoch, NUM_IMAGES)

if __name__ == "__main__":
    # Sample every checkpoint from FIRST_EPOCH to LAST_EPOCH inclusive.
    # The bound is LAST_EPOCH + 1 (not + 10) so the loop never runs past
    # LAST_EPOCH when EPOCH_INTERVAL is changed to something other than 10.
    for epoch in range(FIRST_EPOCH, LAST_EPOCH + 1, EPOCH_INTERVAL):
        main(epoch)

Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
Saved 2182 images to 'synthesis_images/'
