In [1]:
from clip_model import CLIP_model
from Generator import Generator
from Discriminator import ProjectedDiscriminator

# Width of the conditioning vector handed to the discriminator
# (presumably the CLIP text-embedding size — TODO confirm against CLIP_model).
c_dim = 768
# Size of the latent vector z passed to the generator.
z_dim = 64

# Resolution the generator is constructed with; run_Discriminator also
# downsamples larger inputs to generator.img_resolution.
img_resolution = 64
# Sample count used by the commented-out smoke tests further down.
batch = 5

# Module-level singletons shared by every helper function in this notebook.
clip = CLIP_model()
generator     = Generator(z_dim = z_dim, conditional=True, img_resolution = img_resolution)
discriminator = ProjectedDiscriminator(c_dim = c_dim)

  model = create_fn(


In [2]:
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import RandomCrop
# from torch_utils import training_stats

from helper import show_one

import sys
import os

from typing import List

# Parent of the current working directory; the torch_utils package is expected
# to live next to the notebook's folder. os.path is used instead of splitting
# on '/' so the lookup also works with non-POSIX path separators.
cur_path = os.path.dirname(os.getcwd())
# Inserted at the front of sys.path so the bare `import upfirdn2d`-style
# imports below resolve to these directories first.
sys.path.insert(0, os.path.join(cur_path, 'torch_utils', 'ops'))
sys.path.insert(0, os.path.join(cur_path, 'torch_utils'))

import conv2d_resample
import upfirdn2d
import bias_act
import fma


In [3]:
def spherical_distance(x: torch.Tensor, y: torch.Tensor):
    """Squared angle (in radians) between ``x`` and ``y`` along the last dim.

    Both inputs are L2-normalised along ``dim=-1`` first, so only their
    direction matters. A smaller value means the vectors are more similar.

    Args:
        x: Tensor whose last dimension holds the feature vector.
        y: Tensor broadcastable against ``x`` with the same last dimension.

    Returns:
        Tensor with the last dimension reduced away, holding ``angle ** 2``.
    """
    x = F.normalize(x, dim = -1)
    y = F.normalize(y, dim = -1)

    cos_sim = (x * y).sum(-1)

    # Floating-point rounding can push the cosine of two unit vectors slightly
    # outside [-1, 1], where arccos returns NaN; at exactly +/-1 the gradient
    # of arccos is infinite, which turns into NaN during backward. Clamping one
    # machine epsilon inside the interval keeps both value and gradient finite.
    eps = torch.finfo(cos_sim.dtype).eps
    cos_sim = cos_sim.clamp(-1 + eps, 1 - eps)

    # Smaller angle -> more similar
    # Larger angle  -> more dissimilar
    return cos_sim.arccos().pow(2)

# x = torch.rand(5, 10)
# y = torch.rand(5, 10)
# spherical_distance(x, y), spherical_distance(x, x)

In [4]:
blur_fade_kimg = 2   # length of the blur fade-out, in thousands of images (2 -> 2,000 images)
blur_init_sigma = 2  # blur sigma at the very start of training

def set_blur_sigma(cur_nimg: int):
    """Return the blur sigma to use after ``cur_nimg`` images have been seen.

    The sigma decays linearly from ``blur_init_sigma`` at 0 images down to 0
    at ``blur_fade_kimg * 1000`` images and stays at 0 afterwards. When
    ``blur_fade_kimg`` is not greater than 1 the blur is disabled (returns 0).
    """
    if not blur_fade_kimg > 1:
        return 0

    # Fraction of the fade-out still remaining; goes negative once it is over.
    remaining = 1 - cur_nimg / (blur_fade_kimg  * 1000)
    return max(remaining, 0) * blur_init_sigma

# set_blur_sigma(0), set_blur_sigma(1000), set_blur_sigma(2000)

In [5]:
def blur(img: torch.Tensor, blur_sigma: float) -> torch.Tensor:
    """Low-pass filter ``img`` with a normalised, Gaussian-shaped kernel.

    The kernel radius is ``floor(3 * blur_sigma)``. If that radius is not
    positive the input tensor is returned untouched (same object).

    Args:
        img: Image tensor; the filter is created on ``img.device``.
        blur_sigma: Width parameter of the kernel.

    Returns:
        The filtered tensor produced by ``upfirdn2d.filter2d``, or ``img``
        itself when no blur is applied.
    """
    radius = np.floor(blur_sigma * 3)
    if not radius > 0:
        return img

    # Integer tap positions, e.g. [-3, -2, ..., 3]
    taps = torch.arange(-radius, radius + 1, device=img.device, dtype = torch.float32)
    # Tap weights are 2 ** (-(x / sigma) ** 2): base-2, not the textbook
    # exp(-x^2 / (2 sigma^2)) Gaussian.
    kernel = taps.div(blur_sigma).square().neg().exp2()
    # Normalise so the weights sum to 1 before filtering.
    return upfirdn2d.filter2d(img, kernel / kernel.sum())

# img = torch.rand(5, 3, 224, 224)
# blur(img, 3).shape

In [6]:
def run_Generator(z: torch.Tensor, c: torch.Tensor):
    """Generate images from latents ``z`` under conditioning ``c``.

    ``z`` and ``c`` go through the module-level ``generator``'s mapping
    network, and the result is fed to its synthesis network.

    NOTE(review): callers in this notebook pass a list of raw text prompts for
    ``c`` rather than a tensor — the annotation looks stale; confirm against
    ``Generator.mapping``.
    """
    return generator.synthesis(generator.mapping(z, c))

# z = torch.rand(batch, z_dim)
# c = ["cat", "dog", "tiger", "elephant", "zebra"]
# imgs = run_Generator(z, c)
# imgs.shape

In [7]:
# show_one(imgs)

In [8]:
def run_Discriminator(imgs: torch.Tensor, c: torch.Tensor, cur_nimg: int = 200):
    """Score ``imgs`` with the module-level ``discriminator`` under conditioning ``c``.

    Args:
        imgs: Image batch. If its last dimension exceeds
            ``generator.img_resolution`` it is area-downsampled to that size.
        c: Conditioning passed straight through to the discriminator.
        cur_nimg: Number of images seen so far; selects the point on the blur
            fade-out schedule (see ``set_blur_sigma``). Defaults to 200, the
            value that used to be hard-coded here.

    Returns:
        Whatever ``discriminator(imgs, c)`` returns for the preprocessed images.
    """
    if imgs.shape[-1] > generator.img_resolution:
        imgs = F.interpolate(imgs, generator.img_resolution, mode='area')
    # Blur strength follows the training-progress schedule.
    imgs = blur(imgs, blur_sigma = set_blur_sigma(cur_nimg))
    return discriminator(imgs, c)

# disc_out = run_Discriminator(imgs, generator.mapping.clip.encode_texts(["cat", "dog", "tiger", "elephant", "zebra"]))
# disc_out.shape

In [9]:
clip_weight = 3  # weight of the CLIP text/image alignment term in the generator loss


def _discriminator_logits(imgs: torch.Tensor, c_enc: torch.Tensor, cur_nimg: int) -> torch.Tensor:
    """Downsample oversized images, blur them for training step ``cur_nimg`` and score them."""
    if imgs.shape[-1] > generator.img_resolution:
        imgs = F.interpolate(imgs, generator.img_resolution, mode='area')
    # The blur strength must follow the training progress, so the schedule is
    # evaluated at the caller's cur_nimg rather than at a fixed image count.
    imgs = blur(imgs, blur_sigma = set_blur_sigma(cur_nimg))
    return discriminator(imgs, c_enc)


def accumulate_gradients(phase: str, 
                         real_imgs: torch.Tensor, 
                         c_raw: List[str], 
                         gen_z: torch.Tensor,
                         cur_nimg: int,
                         verbose: bool = False):
    """Run one loss computation and call ``backward()`` for the given phase.

    Args:
        phase: ``'D'`` for the discriminator hinge loss, ``'G'`` for the
            generator loss (adversarial term plus weighted CLIP term).
        real_imgs: Batch of real images; its first dimension is used as the
            batch size that the adversarial losses are divided by.
        c_raw: Text prompts. They are encoded with ``clip.encode_texts`` when
            the first element is a ``str``; otherwise the encoded conditioning
            stays ``None``.
        gen_z: Latent vectors from which the fake images are generated.
        cur_nimg: Number of images seen so far; drives the blur fade-out.
        verbose: Print the dictionary of rounded loss values when true.

    Raises:
        ValueError: If ``phase`` is neither ``'D'`` nor ``'G'``.
    """
    if phase not in ("D", "G"):
        # Previously an unknown phase silently computed nothing.
        raise ValueError(f"phase must be 'D' or 'G', got {phase!r}")

    batch_size = real_imgs.shape[0]

    c_enc = None
    if isinstance(c_raw[0], str):
        c_enc = clip.encode_texts(c_raw)

    if phase == 'D':
        # Hinge loss, fake half: relu(1 + logit).
        #   logit <= -1 (confidently "fake")  -> relu(...) = 0, no penalty
        #   logit  =  0 (undecided)            -> relu(1)   = 1, penalised
        fake_images      = run_Generator(gen_z, c_raw)
        fake_images_disc = _discriminator_logits(fake_images, c_enc, cur_nimg)
        fake_images_loss = F.relu(1 + fake_images_disc).mean() / batch_size

        # Hinge loss, real half: relu(1 - logit).
        #   logit >= +1 (confidently "real")  -> relu(...) = 0, no penalty
        #   logit  =  0 (undecided)            -> relu(1)   = 1, penalised
        real_images      = real_imgs.detach().requires_grad_(False)
        real_images_disc = _discriminator_logits(real_images, c_enc, cur_nimg)
        real_images_loss = F.relu(1 - real_images_disc).mean() / batch_size

        (fake_images_loss + real_images_loss).backward()

        stats = {
            "Discriminator Score for Fake Images": round(fake_images_loss.item(),4),
            "Discriminator Score for Real Images": round(real_images_loss.item(),4),
            "Discriminator Total Loss"           : round(fake_images_loss.item() + real_images_loss.item(),4)
        }

    else:  # phase == "G"
        gen_img          = run_Generator(gen_z, c_raw)
        fake_images_disc = _discriminator_logits(gen_img, c_enc, cur_nimg)

        # The generator wants high logits for its images, so it minimises -logit.
        generator_loss = (-fake_images_disc).mean() / batch_size

        # Minimize spherical distance between image and text features
        clip_loss = 0
        if clip_weight > 0:
            if gen_img.shape[-1] > 64:
                gen_img = RandomCrop(64)(gen_img)
            gen_img = F.interpolate(gen_img, 224, mode='area')
            # add(1).div(2): presumably maps generator output from [-1, 1] to
            # [0, 1] for CLIP — TODO confirm the generator's output range.
            gen_img_features = clip.encode_image(gen_img.add(1).div(2))
            clip_loss = spherical_distance(gen_img_features, c_enc).mean()

        total_generator_loss = generator_loss + clip_weight * clip_loss
        total_generator_loss.backward()

        stats = {
            "Generator Loss"       : round(generator_loss.item(),4),
            # clip_loss is the plain int 0 when the CLIP term is disabled, so
            # float() is used instead of .item() to handle both cases.
            "CLIP Loss"            : round(float(clip_loss),4),
            "Generator Total Loss" : round(total_generator_loss.item(),4)
        }

    if verbose: print(stats)

In [None]:
c_raw = ["cat", "dog", "tiger", "elephant", "zebra"]

# One real image and one latent per caption, so the real batch, the fake batch
# and the text conditioning all have the same batch size.
real_imgs = torch.rand(len(c_raw), 3, 128, 128)
gen_z = torch.rand(len(c_raw), z_dim)

accumulate_gradients(phase = "D", 
                     cur_nimg = 500,
                     real_imgs = real_imgs,
                     c_raw = c_raw, 
                     gen_z = gen_z,
                     verbose = True)

In [12]:
# Generator-phase smoke test. Relies on real_imgs, c_raw and gen_z defined in
# the discriminator-phase cell, so that cell must be run first.
accumulate_gradients(phase = "G", 
                     cur_nimg = 500,
                     real_imgs = real_imgs,
                     c_raw = c_raw, 
                     gen_z = gen_z,
                     verbose = True)

  return torch._native_multi_head_attention(


{'Generator Loss': 0.0008, 'CLIP Loss': 2.0306, 'Generator Total Loss': 6.0926}
