In [35]:
# Dependencies: OpenAI CLIP (installed from GitHub), its text-processing requirements, and wav2clip.
# NOTE(review): in Jupyter, `%pip install` targets the running kernel's environment; consider
# pinning versions for reproducibility.
! pip install git+https://github.com/openai/CLIP.git
! pip install ftfy regex tqdm
! pip install wav2clip

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-ep087kv5
  Running command git clone -q https://github.com/openai/CLIP.git /tmp/pip-req-build-ep087kv5
Collecting wav2clip
  Downloading wav2clip-0.1.0-py3-none-any.whl (9.8 kB)
Installing collected packages: wav2clip
Successfully installed wav2clip-0.1.0


In [6]:
# Download the StyleGAN FFHQ 1024x1024 generator checkpoint loaded into `g_all` below
# (`-L` follows redirects, `-C -` resumes a partial download).
!curl -L -o karras2019stylegan-ffhq-1024x1024.for_g_all.pt -C - 'https://github.com/lernapparat/lernapparat/releases/download/v2019-02-01/karras2019stylegan-ffhq-1024x1024.for_g_all.pt'

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   685  100   685    0     0   3186      0 --:--:-- --:--:-- --:--:--  3186
100  100M  100  100M    0     0  55.7M      0  0:00:01  0:00:01 --:--:-- 85.3M


In [3]:
import os
import argparse

import torch
import torchvision
import clip
import numpy as np
from PIL import Image

In [4]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict

In [39]:
import librosa
import wav2clip

In [7]:
class MyLinear(nn.Module):
    """Fully-connected layer with equalized learning rate and a learning-rate multiplier.

    The stored weight is drawn with std ``init_std`` and multiplied by ``w_mul`` on every
    forward pass. With ``use_wscale`` the He-init std is applied at runtime (inside ``w_mul``);
    without it the He std is baked into the initial weights.

    Args:
        input_size: number of input features.
        output_size: number of output features.
        gain: He-init gain (default sqrt(2)).
        use_wscale: apply the He std at runtime instead of at init.
        lrmul: learning-rate multiplier applied to both weight and bias.
        bias: whether to create a bias vector (``self.bias`` is None otherwise).
    """
    def __init__(self, input_size, output_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True):
        super().__init__()
        he_std = gain * input_size**(-0.5)  # He init
        if not use_wscale:
            init_std, self.w_mul = he_std / lrmul, lrmul
        else:
            init_std, self.w_mul = 1.0 / lrmul, he_std * lrmul
        self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
        if not bias:
            self.bias = None
            return
        self.bias = torch.nn.Parameter(torch.zeros(output_size))
        self.b_mul = lrmul

    def forward(self, x):
        scaled_bias = None if self.bias is None else self.bias * self.b_mul
        return F.linear(x, self.weight * self.w_mul, scaled_bias)



class MyConv2d(nn.Module):
    """Conv layer with equalized learning rate and custom learning rate multiplier.

    Optionally upsamples its input 2x first (``upscale=True``, via ``Upscale2d``), and
    optionally runs an ``intermediate`` module on the convolution output before the bias
    is added.

    Args:
        input_channels: number of input channels.
        output_channels: number of output channels.
        kernel_size: side of the square kernel; padding is ``kernel_size // 2``.
        gain: He-init gain.
        use_wscale: if True, weights are stored with std ``1/lrmul`` and scaled by the He
            std at runtime (``w_mul``); otherwise the He std is baked into the init.
        lrmul: learning-rate multiplier applied to weight and bias.
        bias: whether to create a per-output-channel bias.
        intermediate: optional module applied after the conv and before the bias.
        upscale: if True, upsample the input by 2 before convolving.
    """
    def __init__(self, input_channels, output_channels, kernel_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True,
                intermediate=None, upscale=False):
        super().__init__()
        if upscale:
            self.upscale = Upscale2d()
        else:
            self.upscale = None
        he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5) # He init
        self.kernel_size = kernel_size
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_channels))
            self.b_mul = lrmul
        else:
            self.bias = None
        self.intermediate = intermediate

    def forward(self, x):
        # Effective bias with the learning-rate multiplier applied (None when bias=False).
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul

        # Set when the fused upscale+conv path below has already performed the convolution.
        have_convolution = False
        # Fused path: used when the upscaled output would be at least 128 px on its shortest side.
        if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
            # this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way
            # this really needs to be cleaned up and go into the conv...
            w = self.weight * self.w_mul
            # conv_transpose2d expects weights laid out as (in_channels, out_channels, kH, kW).
            w = w.permute(1, 0, 2, 3)
            # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?!
            # Pad by 1 and sum four 1-px-shifted copies: the kernel grows from k to k+1 taps.
            w = F.pad(w, (1,1,1,1))
            w = w[:, :, 1:, 1:]+ w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
            # Stride-2 transposed conv performs the 2x upscale and the convolution in one op.
            x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1)-1)//2)
            have_convolution = True
        elif self.upscale is not None:
            # Smaller resolutions: separate Upscale2d, convolution happens below.
            x = self.upscale(x)

        if not have_convolution and self.intermediate is None:
            # Simple case: the bias can be folded directly into the conv call.
            return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size//2)
        elif not have_convolution:
            # Bias is withheld here so it is added only after ``intermediate`` runs.
            x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size//2)

        if self.intermediate is not None:
            x = self.intermediate(x)
        if bias is not None:
            # Broadcast the per-channel bias over batch and spatial dims.
            x = x + bias.view(1, -1, 1, 1)
        return x


class NoiseLayer(nn.Module):
    """Adds per-pixel noise (shared across channels) scaled by a learned per-channel weight.

    Noise source, in priority order:
      1. the explicit ``noise`` argument;
      2. the ``.noise`` attribute. Set it on every NoiseLayer to replay fixed noise, which
         is very useful for analysis;
      3. fresh standard-normal noise of shape (batch, 1, H, W) on ``x``'s device and dtype.
    """
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(channels))
        self.noise = None

    def forward(self, x, noise=None):
        chosen = noise if noise is not None else self.noise
        if chosen is None:
            batch, _, height, width = x.size(0), x.size(1), x.size(2), x.size(3)
            chosen = torch.randn(batch, 1, height, width, device=x.device, dtype=x.dtype)
        per_channel = self.weight.view(1, -1, 1, 1)
        return x + per_channel * chosen


class StyleMod(nn.Module):
    """Per-channel style modulation: ``x * (scale + 1) + shift``.

    ``scale`` and ``shift`` are both produced from ``latent`` by one linear layer with
    ``2 * channels`` outputs.
    """
    def __init__(self, latent_size, channels, use_wscale):
        super().__init__()
        self.lin = MyLinear(latent_size, channels * 2, gain=1.0, use_wscale=use_wscale)

    def forward(self, x, latent):
        # (batch, 2*C) -> (batch, 2, C, 1, ..., 1) so it broadcasts over the trailing dims of x.
        trailing_ones = [1] * (x.dim() - 2)
        style = self.lin(latent).view(-1, 2, x.size(1), *trailing_ones)
        scale = style[:, 0]
        shift = style[:, 1]
        return x * (scale + 1.) + shift


class PixelNormLayer(nn.Module):
    """Pixel-wise feature normalisation along dim 1.

    Each feature vector (the values along dim 1 at a fixed position) is divided by its
    root-mean-square, ``x / sqrt(mean(x**2, dim=1) + epsilon)``, so it ends up with unit RMS.
    ``epsilon`` keeps all-zero vectors from dividing by zero.
    """
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        inv_rms = (mean_square + self.epsilon).rsqrt()
        return x * inv_rms


class BlurLayer(nn.Module):
    """Depthwise low-pass filter built from the outer product of a 1-D kernel.

    Args:
        kernel: 1-D filter taps (default (1, 2, 1)). The 2-D kernel is ``outer(kernel, kernel)``.
        normalize: divide the 2-D kernel by its sum.
        flip: flip the 2-D kernel in both spatial dims (only matters for asymmetric taps).
        stride: convolution stride.
    """
    def __init__(self, kernel=(1, 2, 1), normalize=True, flip=False, stride=1):
        super().__init__()
        # The previous version overwrote ``kernel`` with [1, 2, 1] unconditionally, silently
        # ignoring the argument. Honour it instead; the default is unchanged.
        taps = torch.as_tensor(kernel, dtype=torch.float32)
        kernel_2d = taps[:, None] * taps[None, :]
        kernel_2d = kernel_2d[None, None]  # (1, 1, k, k)
        if normalize:
            kernel_2d = kernel_2d / kernel_2d.sum()
        if flip:
            # torch tensors do not support negative-step slicing ([::-1] raises); use torch.flip.
            kernel_2d = torch.flip(kernel_2d, dims=(2, 3))
        self.register_buffer('kernel', kernel_2d)
        self.stride = stride

    def forward(self, x):
        channels = x.size(1)
        # One copy of the kernel per channel; groups=channels makes the convolution depthwise.
        weight = self.kernel.expand(channels, -1, -1, -1)
        return F.conv2d(
            x,
            weight,
            stride=self.stride,
            padding=(self.kernel.size(2) - 1) // 2,
            groups=channels,
        )


def upscale2d(x, factor=2, gain=1):
    """Nearest-neighbour upsampling of an (N, C, H, W) tensor by an integer ``factor``.

    Every pixel is replicated into a ``factor x factor`` square. The input is multiplied
    by ``gain`` first when ``gain != 1``. Returns ``x`` (possibly scaled) when ``factor == 1``.
    """
    assert x.dim() == 4
    if gain != 1:
        x = x * gain
    if factor == 1:
        return x
    n, c, h, w = x.shape
    # Insert singleton axes after H and W, broadcast them to `factor`, then fold them back in.
    tiled = x.view(n, c, h, 1, w, 1).expand(-1, -1, -1, factor, -1, factor)
    return tiled.contiguous().view(n, c, h * factor, w * factor)


class Upscale2d(nn.Module):
    """Module wrapper around ``upscale2d`` with a fixed integer factor and gain."""
    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.factor = factor
        self.gain = gain

    def forward(self, x):
        return upscale2d(x, self.factor, self.gain)


class G_mapping(nn.Sequential):
    """StyleGAN mapping network: PixelNorm followed by 8 (dense 512->512, activation) pairs.

    The submodule names (``pixel_norm``, ``dense{i}``, ``dense{i}_act``) are the keys used by
    the pretrained state dict, so they must stay exactly as built here.

    Args:
        nonlinearity: ``'relu'`` or ``'lrelu'`` (leaky ReLU, slope 0.2).
        use_wscale: equalized learning rate flag forwarded to every ``MyLinear``.
        dlatent_broadcast: number of copies of the output latent along dim 1 (one per style input).
    """
    def __init__(self, nonlinearity='lrelu', use_wscale=True, dlatent_broadcast=18):
        # nn.Sequential only accepts nn.Module instances, so 'relu' must map to nn.ReLU();
        # the bare function torch.relu made construction fail.
        act, gain = {'relu': (nn.ReLU(), np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        layers = [('pixel_norm', PixelNormLayer())]
        for i in range(8):
            layers.append((f'dense{i}', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)))
            layers.append((f'dense{i}_act', act))
        super().__init__(OrderedDict(layers))
        self.dlatent_broadcast = dlatent_broadcast

    def forward(self, x):
        x = super().forward(x)
        if x.dim() == 2:
            # (batch, 512) -> (batch, 1, 512); expand() cannot create a leading -1 dim,
            # so the 2-D output previously raised here.
            x = x.unsqueeze(1)
        # Broadcast the latent to every style input of the synthesis network.
        return x.expand(-1, self.dlatent_broadcast, -1)


class Truncation(nn.Module):
    """Truncation trick: pull the first ``max_layer`` latents towards ``avg_latent``.

    For layer indices < ``max_layer`` the output is ``lerp(avg_latent, x, threshold)``;
    later layers are passed through unchanged.

    Args:
        avg_latent: average latent, broadcastable against the (batch, layers, dim) input.
        max_layer: number of leading layers to truncate.
        threshold: interpolation weight towards ``x`` (1.0 = no truncation).
    """
    def __init__(self, avg_latent, max_layer=8, threshold=0.7):
        super().__init__()
        self.max_layer = max_layer
        self.threshold = threshold
        # Buffer: saved in state_dict and moved by .to(), but not returned by parameters(),
        # so it is never passed to an optimizer.
        self.register_buffer('avg_latent', avg_latent)

    def forward(self, x):
        assert x.dim() == 3
        interp = torch.lerp(self.avg_latent, x, self.threshold)
        # Build the layer mask on x's device: a CPU mask combined with CUDA tensors in
        # torch.where raises a device-mismatch error.
        do_trunc = (torch.arange(x.size(1), device=x.device) < self.max_layer).view(1, -1, 1)
        return torch.where(do_trunc, interp, x)


class LayerEpilogue(nn.Module):
    """Things to do at the end of each synthesis layer.

    ``top_epi`` applies, in order: optional per-pixel noise, the activation, optional pixel
    norm, and optional instance norm. Then, if styles are enabled, ``style_mod`` modulates
    the result with ``dlatents_in_slice``. When styles are disabled, ``dlatents_in_slice``
    must be None.
    """
    def __init__(self, channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        stages = OrderedDict()
        if use_noise:
            stages['noise'] = NoiseLayer(channels)
        stages['activation'] = activation_layer
        if use_pixel_norm:
            stages['pixel_norm'] = PixelNormLayer()
        if use_instance_norm:
            stages['instance_norm'] = nn.InstanceNorm2d(channels)
        self.top_epi = nn.Sequential(stages)
        self.style_mod = StyleMod(dlatent_size, channels, use_wscale=use_wscale) if use_styles else None

    def forward(self, x, dlatents_in_slice=None):
        x = self.top_epi(x)
        if self.style_mod is None:
            assert dlatents_in_slice is None
            return x
        return self.style_mod(x, dlatents_in_slice)


class InputBlock(nn.Module):
    """First (4x4) synthesis block.

    The starting feature map is either a learned constant plus bias
    (``const_input_layer=True``) or a dense projection of the first latent. It then goes
    through epilogue -> 3x3 conv -> epilogue, consuming two latents
    (``dlatents_in_range[:, 0]`` and ``[:, 1]``).
    """
    def __init__(self, nf, dlatent_size, const_input_layer, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        self.const_input_layer = const_input_layer
        self.nf = nf
        if not const_input_layer:
            # tweak gain to match the official implementation of Progressing GAN
            self.dense = MyLinear(dlatent_size, nf * 16, gain=gain / 4, use_wscale=use_wscale)
        else:
            # called 'const' in tf
            self.const = nn.Parameter(torch.ones(1, nf, 4, 4))
            self.bias = nn.Parameter(torch.ones(nf))
        epilogue_args = (dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
        self.epi1 = LayerEpilogue(nf, *epilogue_args)
        self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(nf, *epilogue_args)

    def forward(self, dlatents_in_range):
        n = dlatents_in_range.size(0)
        first_style = dlatents_in_range[:, 0]
        if self.const_input_layer:
            start = self.const.expand(n, -1, -1, -1) + self.bias.view(1, -1, 1, 1)
        else:
            start = self.dense(first_style).view(n, self.nf, 4, 4)
        hidden = self.epi1(start, first_style)
        hidden = self.conv(hidden)
        return self.epi2(hidden, dlatents_in_range[:, 1])



class GSynthesisBlock(nn.Module):
    """Synthesis block for resolutions 2**res x 2**res, res = 3..resolution_log2.

    Upscale+3x3 conv (with an optional blur applied inside the conv) -> epilogue ->
    3x3 conv -> epilogue. Consumes two latents (``dlatents_in_range[:, 0]`` and ``[:, 1]``).
    """
    def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        blur = BlurLayer(blur_filter) if blur_filter else None
        epilogue_args = (out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
        self.conv0_up = MyConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,
                                 intermediate=blur, upscale=True)
        self.epi1 = LayerEpilogue(*epilogue_args)
        self.conv1 = MyConv2d(out_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(*epilogue_args)

    def forward(self, x, dlatents_in_range):
        stages = ((self.conv0_up, self.epi1, 0), (self.conv1, self.epi2, 1))
        for conv, epilogue, style_idx in stages:
            x = epilogue(conv(x), dlatents_in_range[:, style_idx])
        return x



class G_synthesis(nn.Module):
    """StyleGAN synthesis network: a 4x4 InputBlock followed by upsampling blocks up to
    ``resolution``, then a 1x1 ``torgb`` conv.

    Each block consumes two consecutive latents from the (batch, num_layers, dlatent_size)
    input. ``randomize_noise`` and ``dtype`` are accepted for config compatibility but are
    not used by this implementation.
    """
    def __init__(self,
        dlatent_size        = 512,          # Disentangled latent (W) dimensionality.
        num_channels        = 3,            # Number of output color channels.
        resolution          = 1024,         # Output resolution.
        fmap_base           = 8192,         # Overall multiplier for the number of feature maps.
        fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.
        fmap_max            = 512,          # Maximum number of feature maps in any layer.
        use_styles          = True,         # Enable style inputs?
        const_input_layer   = True,         # First layer is a learned constant?
        use_noise           = True,         # Enable noise inputs?
        randomize_noise     = True,         # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
        nonlinearity        = 'lrelu',      # Activation function: 'relu', 'lrelu'
        use_wscale          = True,         # Enable equalized learning rate?
        use_pixel_norm      = False,        # Enable pixelwise feature vector normalization?
        use_instance_norm   = True,         # Enable instance normalization?
        dtype               = torch.float32,  # Data type to use for activations and outputs.
        blur_filter         = [1,2,1],      # Low-pass filter to apply when resampling activations. None = no filtering.
        ):

        super().__init__()

        def nf(stage):
            # Feature maps at a given stage, halving per stage (for fmap_decay=1) and capped at fmap_max.
            return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

        self.dlatent_size = dlatent_size
        resolution_log2 = int(np.log2(resolution))
        assert resolution == 2**resolution_log2 and resolution >= 4

        # The activation is placed inside nn.Sequential (LayerEpilogue), so it must be an
        # nn.Module: nn.ReLU() rather than the bare function torch.relu, which raised.
        act, gain = {'relu': (nn.ReLU(), np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        blocks = []
        last_channels = None
        for res in range(2, resolution_log2 + 1):
            channels = nf(res - 1)
            name = '{s}x{s}'.format(s=2**res)
            if res == 2:
                block = InputBlock(channels, dlatent_size, const_input_layer, gain, use_wscale,
                                   use_noise, use_pixel_norm, use_instance_norm, use_styles, act)
            else:
                block = GSynthesisBlock(last_channels, channels, blur_filter, dlatent_size, gain, use_wscale,
                                        use_noise, use_pixel_norm, use_instance_norm, use_styles, act)
            blocks.append((name, block))
            last_channels = channels
        self.torgb = MyConv2d(channels, num_channels, 1, gain=1, use_wscale=use_wscale)
        self.blocks = nn.ModuleDict(OrderedDict(blocks))

    def forward(self, dlatents_in):
        # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
        x = None
        for i, block in enumerate(self.blocks.values()):
            # Block i consumes latents 2*i and 2*i + 1.
            styles = dlatents_in[:, 2 * i:2 * i + 2]
            x = block(styles) if i == 0 else block(x, styles)
        return self.torgb(x)


# Average W latent for the (currently disabled) truncation module; all zeros here.
avg_latent = torch.zeros(1, 18, 512)

# Full generator = mapping network followed by synthesis network. The submodule names
# ('g_mapping', 'g_synthesis') must match the keys stored in the pretrained checkpoint.
g_all = nn.Sequential(OrderedDict([
    ('g_mapping', G_mapping()),
    # ('truncation', Truncation(avg_latent)),
    ('g_synthesis', G_synthesis())
]))
# map_location='cpu' makes loading independent of the device the checkpoint was saved
# from (e.g. a CUDA-saved file on a CPU-only machine); the synthesis network is moved to
# `device` in a later cell.
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
g_all.load_state_dict(torch.load('karras2019stylegan-ffhq-1024x1024.for_g_all.pt', map_location='cpu'))

g_synthesis = g_all.g_synthesis
g_mapping = g_all.g_mapping

In [8]:
def saveFeatureMaps(model, layer=1):
    """Debug helper: print the ten largest weights of the first ``nn.Linear`` in ``model.linear``.

    Iterates ``model.linear``, printing every module until the first ``nn.Linear``. That
    layer's weight is viewed as (18, -1, 512), and for every middle index ``s`` the ten
    largest values of each of the 18 rows of ``weight[:, s, :]`` are printed. Iteration
    stops after that first Linear layer.

    Args:
        model: object with an iterable ``linear`` attribute of modules.
        layer: currently unused.

    Side effects:
        Creates ``./feature_maps`` if it does not exist (nothing is written into it yet)
        and prints to stdout.
    """
    root = './feature_maps'
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(root, exist_ok=True)
    for idx, m in enumerate(model.linear):
        print(m)
        if not isinstance(m, nn.Linear):
            continue
        print('Linear ', idx)
        weight = m.weight.view(18, -1, 512)
        print(weight.size())
        for s in range(weight.size(1)):
            for row in weight[:, s, :]:
                top = sorted(row.detach().cpu().numpy(), reverse=True)
                print(top[0:10])
            print('\n')
        # Only the first Linear layer is inspected.
        break


class GetFeatureMaps(nn.Module):
    """Truncated copy of a sequential model that returns flattened activations.

    Keeps ``model[0 .. layer]`` (inclusive) and flattens the output of that prefix to
    (batch, -1).

    Args:
        model: a sliceable container of modules (e.g. ``vgg16.features``). Must not be None.
        layer: index of the last module kept.
    """
    def __init__(self, model=None, layer=45):
        super().__init__()
        prefix = model[:layer + 1]
        self.feature_map = nn.Sequential(*prefix)

    def forward(self, imgs):
        batch = imgs.size(0)
        features = self.feature_map(imgs)
        return features.view(batch, -1)


def transform_img(img, img_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Resize an image to a square, convert it to a tensor, normalise it and add a batch dim.

    Args:
        img: image accepted by torchvision's ``Resize`` and ``ToTensor`` (e.g. a PIL image).
        img_size: side length of the square output.
        mean: per-channel normalisation mean (default: ImageNet statistics).
        std: per-channel normalisation std (default: ImageNet statistics).

    Returns:
        Tensor of shape (1, C, img_size, img_size).
    """
    # `T` was never imported in this notebook (NameError on first call); use the
    # `torchvision` module that is already imported instead.
    transforms = torchvision.transforms
    pipeline = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(mean), std=list(std)),
    ])
    return pipeline(img).unsqueeze(0)


def compute_loss(f_maps, f_maps_ref, lambdas):
    """Weighted mean L1 distance between paired feature maps.

    Returns ``sum_i lambdas[i] * l1(f_maps[i], f_maps_ref[i]) / len(f_maps)``, or the
    integer 0 when the lists are empty. ``f_maps`` and ``f_maps_ref`` must have equal
    length, and ``lambdas`` must have at least that many entries.
    """
    assert len(f_maps) == len(f_maps_ref)
    count = len(f_maps)
    return sum(
        (nn.functional.l1_loss(gen, ref) / count) * lambdas[i]
        for i, (gen, ref) in enumerate(zip(f_maps, f_maps_ref))
    )

In [28]:
output_path = 'generations'  # directory the training loop writes PNG snapshots into
batch_size = 1  # leading dimension of the optimised latent tensor
# NOTE(review): the prompt is passed to compute_clip_loss, but that function builds its
# target from the audio prompt, so this string does not affect the result.
prompt = "The image of an young asian lady"
lr = 1e-2  # Adam learning rate for the latent
img_save_freq = 100  # save an image and print the loss every N iterations
ref_img_path = None  # reference image for the (WIP) perceptual loss; None disables it
max_iter = 500  # last iteration index; the loop runs max_iter + 1 steps

In [37]:
# Audio file whose wav2clip embedding is used as the CLIP-space target.
audio_prompt = "girl_talking.wav"

In [40]:
# Pretrained wav2clip audio encoder, plus the audio prompt loaded and resampled to 16 kHz.
wav2clip_model = wav2clip.get_model()
audio, sr = librosa.load(audio_prompt, sr=16000)

Downloading: "https://github.com/descriptinc/lyrebird-wav2clip/releases/download/v0.1.0-alpha/Wav2CLIP.pt" to /root/.cache/torch/hub/checkpoints/Wav2CLIP.pt


  0%|          | 0.00/46.7M [00:00<?, ?B/s]

In [29]:
output_dir = output_path
# The training loop saves PNGs here; create the directory up front so img.save() does not
# fail with FileNotFoundError on a fresh run.
os.makedirs(output_dir, exist_ok=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("USING ", device)

USING  cuda


In [42]:
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
# Only the latents are optimised. Freezing CLIP stops loss.backward() from computing (and,
# since the optimizer never zeroes them, accumulating) gradients for its weights.
# Gradients still flow through CLIP to the generated image.
clip_model.requires_grad_(False)

vgg16 = torchvision.models.vgg16(pretrained=True).to(device)
# Fixed feature extractor for the perceptual loss: inference mode, frozen weights.
vgg16.eval().requires_grad_(False)
vgg_layers = vgg16.features

# Keys are string indices into `vgg16.features`; the perceptual loss compares features
# after each listed module. The values are human-readable layer names.
vgg_layer_name_mapping = {
    '1': "relu1_1",
    '3': "relu1_2",
    '6': "relu2_1",
    '8': "relu2_2",
    # '15': "relu3_3",
    # '22': "relu4_3"
}

In [43]:
# The pretrained generator is a fixed decoder: eval mode, frozen weights, moved to `device`.
# Freezing keeps loss.backward() from computing gradients for every synthesis weight. The
# optimizer only zeroes the latents, so those gradients would otherwise keep accumulating.
# Gradients still flow through the generator to the latents.
g_synthesis.eval()
g_synthesis.requires_grad_(False)
g_synthesis.to(device)

G_synthesis(
  (torgb): MyConv2d()
  (blocks): ModuleDict(
    (4x4): InputBlock(
      (epi1): LayerEpilogue(
        (top_epi): Sequential(
          (noise): NoiseLayer()
          (activation): LeakyReLU(negative_slope=0.2)
          (instance_norm): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        )
        (style_mod): StyleMod(
          (lin): MyLinear()
        )
      )
      (conv): MyConv2d()
      (epi2): LayerEpilogue(
        (top_epi): Sequential(
          (noise): NoiseLayer()
          (activation): LeakyReLU(negative_slope=0.2)
          (instance_norm): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        )
        (style_mod): StyleMod(
          (lin): MyLinear()
        )
      )
    )
    (8x8): GSynthesisBlock(
      (conv0_up): MyConv2d(
        (upscale): Upscale2d()
        (intermediate): BlurLayer()
      )
      (epi1): LayerEpilogue(
        (top_epi): Sequential(
    

In [44]:
# A single W latent per batch item; the training loop repeats it to all 18 style inputs.
latent_shape = (batch_size, 1, 512)

# Standard-normal sampler, only used by the commented-out random initialisation below.
normal_generator = torch.distributions.normal.Normal(
    torch.tensor([0.0]),
    torch.tensor([1.]),
)

# init_latents = normal_generator.sample(latent_shape).squeeze(-1).to(device)
# Start from the all-zero latent (squeeze(-1) is a no-op here: the last dim is 512).
latents_init = torch.zeros(latent_shape).squeeze(-1).to(device)
latents = torch.nn.Parameter(latents_init, requires_grad=True)

# Adam over the latent only; no generator or CLIP weights are passed to the optimizer.
optimizer = torch.optim.Adam(
    params=[latents],
    lr=lr,
    betas=(0.9, 0.999),
)


In [57]:
def truncation(x, threshold=0.7, max_layer=8, avg_latent=None):
    """Truncation trick on a stack of W latents.

    Layers with index < ``max_layer`` become ``lerp(avg_latent, x, threshold)``; the remaining
    layers are returned unchanged.

    Args:
        x: latents of shape (batch, num_layers, latent_dim).
        threshold: interpolation weight towards ``x`` (1.0 = no truncation).
        max_layer: number of leading layers to truncate.
        avg_latent: average latent broadcastable to ``x``; defaults to zeros.

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """
    if avg_latent is None:
        # Previously hard-coded to zeros of (batch_size, n, 512) in float32 on the global
        # `device`; deriving it from x makes any batch size, width, dtype and device work.
        avg_latent = torch.zeros_like(x)
    else:
        avg_latent = avg_latent.to(device=x.device, dtype=x.dtype)
    interp = torch.lerp(avg_latent, x, threshold)
    do_trunc = (torch.arange(x.size(1), device=x.device) < max_layer).view(1, -1, 1)
    return torch.where(do_trunc, interp, x)

def tensor_to_pil_img(img):
    """Convert the first image of a batch from a [-1, 1] CHW tensor to an 8-bit PIL image.

    Values are clipped to [-1, 1], mapped to [0, 255], and truncated (not rounded) to uint8.
    Only ``img[0]`` is converted.
    """
    unit_range = (img.clamp(-1, 1) + 1) / 2.0
    hwc = unit_range[0].permute(1, 2, 0).detach().cpu().numpy() * 255
    return Image.fromarray(hwc.astype('uint8'))


# Subset of CLIP's own preprocessing pipeline. Index 4 is presumably CLIP's Normalize step
# (TODO confirm against clip_preprocess.transforms); index 2 is commented out.
# Not used by any active code in this notebook.
clip_transform = torchvision.transforms.Compose([
    # clip_preprocess.transforms[2],
    clip_preprocess.transforms[4],
])

# Optional reference image for the (WIP) perceptual loss, run through CLIP's preprocessing
# and given a batch dimension. None when ref_img_path is not set.
# NOTE(review): this applies CLIP's normalisation, while the perceptual loss feeds it to
# ImageNet-trained VGG layers. Confirm the normalisation mismatch is intended.
if ref_img_path is None:
    ref_img = None
else:
    ref_img = clip_preprocess(Image.open(ref_img_path)).unsqueeze(0).to(device)

# CLIP's per-channel image normalisation statistics. Only referenced by commented-out code
# in this notebook.
clip_normalize = torchvision.transforms.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)

def compute_clip_loss(img, text, audio_features=None):
    """CLIP-space loss between generated images and the audio prompt.

    Despite its name, ``text`` is not used. The target embedding is the wav2clip embedding
    of the global ``audio`` (or ``audio_features`` when given).

    Args:
        img: generated image batch (N, C, H, W); resized to 224x224 for CLIP.
        text: unused, kept for backward compatibility.
        audio_features: optional precomputed wav2clip embedding (array or tensor). Pass it
            to avoid re-embedding the same audio on every optimisation step.

    Returns:
        ``100 / logits``, where logits are CLIP-logit-scaled dot products between image and
        audio embeddings. Presumably (batch, 1) for a single audio clip; the training loop
        indexes it as ``[0][0]``.
    """
    # F.upsample_bilinear is deprecated; it is exactly interpolate(bilinear, align_corners=True).
    img = torch.nn.functional.interpolate(img, size=(224, 224), mode='bilinear', align_corners=True)
    image_features = clip_model.encode_image(img)

    if audio_features is None:
        audio_features = wav2clip.embed_audio(audio, wav2clip_model)
    # Match CLIP's output device/dtype: fp16 on CUDA, fp32 on CPU. The previous hard
    # float16 cast caused a dtype mismatch in the matmul on CPU.
    audio_features = torch.as_tensor(audio_features).to(device=image_features.device, dtype=image_features.dtype)

    logit_scale = clip_model.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ audio_features.t()
    # NOTE(review): the features are not L2-normalised, so these are scaled dot products
    # rather than cosine similarities, and 100 / logits blows up or flips sign as a logit
    # approaches or crosses zero. Consider 1 - cosine similarity instead.
    return 1 / logits_per_image * 100

def compute_perceptual_loss(gen_img, ref_img, layers=None, layer_mapping=None):
    """Perceptual (feature-space MSE) loss between a generated and a reference image.

    Both images are run through ``layers`` module by module, up to and including the
    highest index listed in ``layer_mapping``. The MSE between the two feature maps is
    summed after every module whose name is a key of ``layer_mapping``.

    Args:
        gen_img: generated image batch (N, C, H, W); resized to 224x224.
        ref_img: reference image batch, expected at 224x224 already (e.g. from ``clip_preprocess``).
        layers: sequential container of modules; defaults to the global ``vgg_layers``.
        layer_mapping: dict keyed by the string index of the modules to compare; defaults to
            the global ``vgg_layer_name_mapping``.

    Returns:
        Summed MSE divided by the highest selected layer index (at least 1), matching the
        original normalisation.

    Raises:
        ValueError: if ``ref_img`` is None (``ref_img_path`` was not set).
    """
    if ref_img is None:
        raise ValueError("compute_perceptual_loss needs a reference image; set ref_img_path")
    layers = vgg_layers if layers is None else layers
    layer_mapping = vgg_layer_name_mapping if layer_mapping is None else layer_mapping

    # Bug fix: the original resized the global `img` instead of the `gen_img` argument.
    gen_feats = torch.nn.functional.interpolate(gen_img, size=(224, 224), mode='bilinear', align_corners=True)
    ref_feats = ref_img
    # Bug fix: compare indices numerically; max() over the string keys is lexicographic
    # ('8' > '15'), so enabling deeper layers would have stopped too early.
    last_idx = max(int(k) for k in layer_mapping)

    loss = 0
    # _modules keeps every registered child in order (including repeated instances).
    for idx, (name, module) in enumerate(layers._modules.items()):
        ref_feats = module(ref_feats)
        gen_feats = module(gen_feats)
        if name in layer_mapping:
            loss += torch.nn.functional.mse_loss(ref_feats, gen_feats)
        if idx >= last_idx:
            break

    return loss / max(last_idx, 1)

In [58]:
# Optimise the latent so the generated image matches the audio prompt in CLIP space.
# Runs max_iter + 1 steps (counter = 0 .. max_iter).
for counter in range(max_iter + 1):
    # Broadcast the single latent to all 18 style inputs of the synthesis network.
    dlatents = latents.repeat(1, 18, 1)
    img = g_synthesis(dlatents)

    # NOTE: clip normalization did not seem to have much effect
    # img = clip_normalize(img)

    # compute_clip_loss returns one value per image; averaging gives a scalar so backward()
    # also works for batch_size > 1 (identical to before for batch_size == 1).
    loss = compute_clip_loss(img, prompt).mean()

    # NOTE: uncomment to use perceptual loss. Still WIP. You will need to define
    # the `ref_img_path` to use it. The image referenced will be the one
    # used to condition the generation.
    # perceptual_loss = compute_perceptual_loss(img, ref_img)
    # loss = loss + perceptual_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if counter % img_save_freq == 0:
        tensor_to_pil_img(img).save(os.path.join(output_dir, f'{counter}.png'))

        print(f'Step {counter}')
        print(f'Loss {loss.item()}')



Step 0
Loss 0.047210693359375
Step 100
Loss 0.042572021484375
Step 200
Loss 0.036407470703125
Step 300
Loss 0.036407470703125
Step 400
Loss 0.033538818359375
Step 500
Loss 0.03253173828125
