StyleGAN3 Music Video

This Colab lets you try out **StyleGAN3** (aka Alias-Free GAN), released by NVIDIA in [this repo](https://github.com/NVlabs/stylegan3). Colab produced by [crimeacs](https://twitter.com/EarthML1).

**[UPD 17.10.2021]** Added Music Video Generation

[UPD 14.10.2021] Added the Cosplay Faces model (StyleGAN2) trained by [@l4rz](https://twitter.com/l4rz)

In [None]:
# StyleGAN checkpoint to load. afhq = animal faces, ffhq = Flickr faces,
# metfaces = faces from paintings. The cosplay entry is a StyleGAN2 network
# trained on cosplay (anime-style) faces and is downloaded from a different host.
model = "stylegan2-cosplay-faces-512x512-px" #@param ["stylegan2-cosplay-faces-512x512-px", "stylegan3-r-afhqv2-512x512.pkl", "stylegan3-r-ffhq-1024x1024.pkl", "stylegan3-r-ffhqu-1024x1024.pkl","stylegan3-r-ffhqu-256x256.pkl","stylegan3-r-metfaces-1024x1024.pkl","stylegan3-r-metfacesu-1024x1024.pkl","stylegan3-t-afhqv2-512x512.pkl","stylegan3-t-ffhq-1024x1024.pkl","stylegan3-t-ffhqu-1024x1024.pkl","stylegan3-t-ffhqu-256x256.pkl","stylegan3-t-metfaces-1024x1024.pkl","stylegan3-t-metfacesu-1024x1024.pkl"]

# Input audio file (wav or mp3; a .mpeg file is also accepted and handled as mp3)
audio_file = '' #@param {type: "string"}

# Random seed for the random latents mixed into each frame.
# NOTE(review): confirm the generation cell actually seeds torch with this
# value; if it does not, runs are not reproducible.
seed =  100500 #@param {type:"number"}

#@markdown How variable should the video be? (lower values = less variable)
# Passed as `truncation_psi` to every G.mapping call in the generation cell.
truncation_psi = 0.8 #@param {type:"number"}

#@markdown Blend between the audio-driven latent (1.0) and a random per-frame latent (0.0)
single_image_vs_random =  0.9 #@param {type:"number"}

#@markdown Cut audio to N seconds (-1 = use the whole file)
cut_len = -1 #@param {type:"number"}

#@markdown Number of rendered frames between two consecutive audio keyframes
interp_frames =  5 #@param {type:"number"}

# Directory the finished video_audio.mp4 is copied to.
output_path = "/content"


In [None]:
#@title Install dependencies
from IPython.display import clear_output

!git clone https://github.com/NVlabs/stylegan3.git
%cd stylegan3
!wget -N -O mini.sh https://repo.anaconda.com/miniconda/Miniconda3-py38_4.8.2-Linux-x86_64.sh
!chmod +x mini.sh
!bash ./mini.sh -b -f -p /usr/local
!conda install -q -y --prefix /usr/local jupyter
!python -m ipykernel install --name "py38" --user
!pip install click -q
!pip install numpy -q
!pip install pillow -q
!pip install torch -q
!pip install scipy -q
!pip install Ninja -q
!pip install imageio -q
!pip install imageio-ffmpeg -q
clear_output()


In [None]:
#@title # Generate 🎵 music video
#@markdown ##**Choose your settings**
from IPython.display import clear_output
# Work inside the cloned repo.
# NOTE(review): presumably needed so that unpickling the network can import
# modules shipped with the repo — confirm.
%cd /content/stylegan3

import requests
import pickle
import torch 
import os
import numpy as np  # NOTE(review): np is not referenced in this cell
import matplotlib.pyplot as plt  # NOTE(review): plt is not referenced in this cell

# librosa, torchvision, tqdm and requests are not in the install cell's pip
# list; presumably the Colab runtime already provides them.
import librosa
from scipy.io import wavfile

import time
import torchvision.transforms.functional as TF
from tqdm.notebook import tqdm

def fetch(url_or_path):
    """Open ``url_or_path`` for binary reading.

    HTTP(S) URLs are downloaded fully into memory and returned as an
    ``io.BytesIO`` positioned at the start; anything else is treated as a
    local path and opened with ``open(..., 'rb')``. The caller is responsible
    for closing the returned object.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    # Local import: `io` is not imported at the top of the notebook, which
    # previously made the URL branch fail with NameError.
    import io

    if str(url_or_path).startswith(('http://', 'https://')):
        # A timeout keeps a stalled server from hanging the cell forever.
        r = requests.get(url_or_path, timeout=60)
        r.raise_for_status()
        return io.BytesIO(r.content)
    return open(url_or_path, 'rb')

def fetch_model(url_or_path):
    """Download ``url_or_path`` into the current directory unless already present.

    Returns the local file name (the basename of ``url_or_path``). The data is
    downloaded to ``<basename>.part`` and only renamed to ``<basename>`` after
    wget succeeds. An interrupted download is therefore resumed (``wget -c``)
    on the next call instead of being mistaken for a complete cached file.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
    """
    import subprocess

    basename = os.path.basename(url_or_path)
    if os.path.exists(basename):
        return basename
    partial = basename + '.part'
    # Argument list, no shell: quotes or spaces in the URL cannot break the command.
    subprocess.run(['wget', '-c', '-O', partial, url_or_path], check=True)
    os.replace(partial, basename)
    return basename

# Resolve the checkpoint URL. The cosplay model is hosted by its author under a
# file name that differs from its form label. `model` itself is left untouched
# so re-running this cell without re-running the settings cell resolves the
# same URL again.
baselink = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/'
model_file = model

if model == "stylegan2-cosplay-faces-512x512-px":
    baselink = 'https://l4rz.net/'
    model_file = 'cosplayface-snapshot-004000-18160-FID367.pkl'

network_url = baselink + model_file
device = torch.device('cuda:0')

# SECURITY: unpickling executes arbitrary code contained in the file; only load
# network pickles from sources you trust.
with open(fetch_model(network_url), 'rb') as fp:
    G = pickle.load(fp)['G_ema'].to(device)
# Inference only: disable parameter gradients so no G call builds an autograd
# graph, even when made outside torch.no_grad().
G.requires_grad_(False)

# Per-dimension std of the mapped latents for 10000 random z.
# NOTE(review): w_stds is not read anywhere in this cell.
with torch.no_grad():
    zs = torch.randn([10000, G.mapping.z_dim], device=device)
    w_stds = G.mapping(zs, None).std(0)

# Load the input audio, optionally cut it, and write it to audio.wav for muxing.
import shutil

if not audio_file:
    raise ValueError('Set `audio_file` in the settings cell to a wav or mp3 file.')

# Copy (not move) `.mpeg` uploads so the user's original stays in place.
# NOTE(review): renaming to .mp3 presumably works around the decoder not
# recognising the .mpeg extension — confirm.
audio_path = audio_file
if audio_path.endswith('.mpeg'):
    shutil.copyfile(audio_path, 'audio.mp3')
    audio_path = 'audio.mp3'

arr, fr = librosa.load(audio_path)
# cut_len <= 0 (the form default is -1) keeps the whole file; otherwise keep the
# first cut_len seconds. int() allows fractional values from the number field.
if cut_len > 0:
    arr = arr[:int(fr * cut_len)]
if len(arr) == 0:
    raise ValueError(f'No audio samples loaded from {audio_path!r}.')

# Store the real clip length in seconds (float, not floored, and never longer
# than the clip). The video step divides the frame count by it to get fps, so a
# floored or oversized value makes the video drift from, or get cut short
# against, the audio.
cut_len = len(arr) / fr

wavfile.write('audio.wav', fr, arr)

# Spectrogram of the waveform. With n_fft = 2*z_dim - 1 and onesided=True there
# are n_fft//2 + 1 = z_dim frequency bins, so every time column has exactly
# z_dim values and can be fed to G.mapping as a latent z.
# (pad_mode only takes effect when center=True, so it is ignored here.)
stft = torch.stft(torch.tensor(arr),
                  G.mapping.z_dim * 2 - 1,
                  center=False,
                  pad_mode='reflect',
                  normalized=True,
                  onesided=True,
                  return_complex=True)

# Log-magnitude of every 10th time frame. The clamp keeps digitally silent
# frames (|X| == 0) finite: log(0) = -inf would otherwise be fed into G.mapping
# and corrupt the corresponding latents and rendered frames.
stft = torch.log(stft.abs().clamp_min(1e-8))[:, ::10]

clear_output()

#FRAMES
# Turn the audio into latent keyframes (one per decimated STFT column), then
# render `interp_frames` images between each pair of consecutive keyframes into
# samples/<timestring>/.
n_steps = int(interp_frames)
if n_steps < 1:
    raise ValueError('interp_frames must be at least 1.')
if stft.size(-1) < 2:
    raise ValueError('Audio is too short: at least two STFT frames are needed to interpolate.')

# Seed torch so the same `seed` reproduces the same random latents.
torch.manual_seed(int(seed))

with torch.no_grad():
    timestring = time.strftime('%Y%m%d%H%M%S')
    rand_z = torch.randn(stft.size(-1), G.mapping.z_dim).to(device)
    q = G.mapping(rand_z, None, truncation_psi=truncation_psi)

    # Each STFT column holds z_dim values, so the rows of stft.T are mapped as
    # one batch of z vectors (one per keyframe), just like rand_z above.
    audio_w = G.mapping(stft.T.contiguous().to(device), None, truncation_psi=truncation_psi)
    mixed_w = audio_w * single_image_vs_random + q * (1 - single_image_vs_random)
    # One entry per keyframe, each keeping a leading batch dim of 1 for G.synthesis.
    zq = list(mixed_w.unsqueeze(1))

    out_dir = f'samples/{timestring}'
    os.makedirs(out_dir, exist_ok=True)
    # Weights in [0, 1): t = 1 of segment k is the same latent as t = 0 of
    # segment k + 1, so including both endpoints would render each keyframe twice.
    i_val = torch.linspace(0, 1, n_steps + 1, device=device)[:-1]

    count = 0
    for k in tqdm(range(len(zq) - 1)):
        for t in tqdm(i_val, leave=False):
            interp = torch.lerp(zq[k], zq[k + 1], t)
            images = ((G.synthesis(interp) + 1) / 2).clamp(0, 1)
            # 6-digit zero padding keeps plain string sorting correct up to 1e6 frames.
            TF.to_pil_image(images[0].cpu()).save(f'{out_dir}/{count:06}.png')
            count += 1


#VIDEO
# Encode the rendered PNG frames into video.mp4, mux in audio.wav, and copy the
# result to `output_path`.
import shutil
from IPython import display
from base64 import b64encode
from PIL import Image
from subprocess import PIPE, Popen, run

frame_dir = f'samples/{timestring}'
if count <= 0 or cut_len <= 0:
    raise ValueError('No frames were rendered or the audio length is zero; nothing to encode.')
# Frame rate chosen so the video lasts exactly cut_len seconds (the audio length).
fps = count / cut_len

# Sort numerically rather than as strings: once frame numbers outgrow their
# zero padding, string order misplaces them (e.g. '10000.png' < '1001.png').
frame_paths = sorted(
    (os.path.join(frame_dir, name) for name in os.listdir(frame_dir) if name.endswith('.png')),
    key=lambda path: int(os.path.splitext(os.path.basename(path))[0]),
)

encoder = Popen(['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-',
                 '-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '17',
                 '-preset', 'veryslow', 'video.mp4'], stdin=PIPE)
# The frames are already PNG files: stream their bytes to ffmpeg one file at a
# time, so neither decoded images nor open file handles accumulate.
for path in tqdm(frame_paths):
    with open(path, 'rb') as fh:
        encoder.stdin.write(fh.read())
encoder.stdin.close()
if encoder.wait() != 0:
    raise RuntimeError('ffmpeg failed to encode video.mp4')

# Add the audio track: video stream copied as-is, output stops at the shorter stream.
run(['ffmpeg', '-y', '-i', 'video.mp4', '-i', 'audio.wav', '-map', '0', '-map', '1:a',
     '-c:v', 'copy', '-shortest', 'video_audio.mp4'], check=True)
shutil.copy('video_audio.mp4', output_path)
clear_output()

# Optional inline preview (uncomment to use):
# mp4 = open('video_audio.mp4', 'rb').read()
# data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# display.HTML("""
# <video width=400 controls>
#       <source src="%s" type="video/mp4">
# </video>
# """ % data_url)
