In [5]:
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import RandomCrop
# from torch_utils import training_stats
from clip import CLIP

import sys
import os

cur_path = '/'.join(os.getcwd().split('/')[:-1])
sys.path.insert(0, f'{cur_path}/torch_utils/ops')
sys.path.insert(0, f'{cur_path}/torch_utils')

import upfirdn2d

In [19]:
def spherical_distance(x: torch.Tensor, y: torch.Tensor):
    x = F.normalize(x, dim = -1)
    y = F.normalize(y, dim = -1)

    # Smaller angle -> more similar
    # Larger angle  -> more dissimilar
    return (x * y).sum(-1).arccos().pow(2)

x = torch.rand(5, 10)
y = torch.rand(5, 10)
spherical_distance(x, y), spherical_distance(x, x)

(tensor([0.6201, 0.4719, 0.2988, 0.6245, 0.7970]),
 tensor([0.0000e+00,        nan, 0.0000e+00, 1.1921e-07, 1.1921e-07]))

In [39]:
def set_blur_sigma(cur_nimg: int):
    blur_fade_kimg = 2 # fade out after 2,000 images
    blur_init_sigma = 2

    if blur_fade_kimg > 1:
        blur_curr_sigma = max(1 - cur_nimg / (blur_fade_kimg  * 1000), 0) * blur_init_sigma
    else: 
        blur_curr_sigma = 0

    return blur_curr_sigma

set_blur_sigma(0), set_blur_sigma(1000), set_blur_sigma(2000)

(2.0, 1.0, 0.0)

In [None]:
def blur(img: torch.Tensor, blur_sigma: float) -> torch.Tensor:
    # Applies Blur
    blur_size = np.floor(blur_sigma * 3)
    if blur_size > 0:
        f = torch.arange(-blur_size, blur_size + 1, device=img.device, dtype = torch.float32) # e.g., [-3, -2, ..., 3]
        f = f.div(blur_sigma).square().neg().exp2()                                           # exp(-x^2 / (2σ^2))
        img = upfirdn2d.filter2d(img, f / f.sum())
    return img

img = torch.rand(5, 3, 224, 224)
blur(img, 3).shape

torch.Size([5, 3, 224, 224])