In [None]:
import time
import os
from scipy.interpolate import interp1d
from scipy.io import wavfile
import numpy as np
from glob import glob
from PIL import Image
import PIL
import torch
from torchvision import transforms
from model import Generator
from dataset import MultiResolutionDataset
import moviepy.editor

import matplotlib.pyplot as plt
from IPython.display import clear_output, display

In [None]:
# Preprocessing applied to each label image before it is handed to the generator:
# resize to 256x256, convert to a tensor, then normalise with mean 0.5 / std 0.5
# (which maps ToTensor's [0, 1] range onto [-1, 1]).
transform_label = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,), inplace=True),
])

In [None]:
# Build the SPADE-variant generator, move it to the GPU, then restore the
# 'g_ema' weights from the training checkpoint.
# NOTE(review): the positional arguments (256, 512, 8, 2) are defined by
# model.Generator, which is not visible here -- confirm their meaning there.
generator = Generator(256, 512, 8, 2, architecture='spade')
generator = generator.cuda()
ckpt = torch.load('checkpoint_256_spade_with_noise/250000.pt')
generator.load_state_dict(ckpt['g_ema'])

In [None]:
def generate_with_w(label, w):
    """Run the generator on one label image and one latent, returning a uint8 image.

    Args:
        label: label tensor for a single sample; a batch dimension is added and
            it is moved to the GPU here.
        w: latent passed to the generator as ``[w]`` with ``input_is_latent=True``.

    Returns:
        ``numpy.uint8`` array obtained by taking the first batch item of the
        generator's first output, moving axis 0 to the end, and mapping values
        via ``(x * 0.5 + 0.5) * 255`` clipped to ``[0, 255]``.
        (Presumably H x W x C with the generator emitting values in [-1, 1] --
        TODO confirm against model.Generator.)
    """
    label_batch = label.unsqueeze(0).cuda()
    with torch.no_grad():
        outputs = generator(label_batch, [w], input_is_latent=True)

    # First output of the generator, first sample of the batch, as a NumPy array.
    channels_first = outputs[0].cpu()[0].numpy()
    channels_last = channels_first.transpose(1, 2, 0)
    pixels = (channels_last * 0.5 + 0.5) * 255
    return np.clip(pixels, 0, 255).astype(np.uint8)

In [None]:
# Count the label files on disk, then build their paths by index so that they
# are in numeric order (label_0.jpg, label_1.jpg, ...).
# NOTE(review): this assumes every 'label*' match is one of label_0.jpg ..
# label_{n-1}.jpg with no gaps -- confirm against the dataset directory.
label_files = glob('youtube_512_one_person/label*')
n_labels = len(label_files)
paths = [f'youtube_512_one_person/label_{i}.jpg' for i in range(n_labels)]

## Music

In [None]:
audio = {}  # track name -> per-video-frame loudness envelope, peak-normalised
fps = 30    # video frame rate the envelope is resampled to

# `duration` and `frames` from the LAST processed file are read by later cells,
# so iterate in sorted order to make that choice deterministic (os.listdir
# returns entries in arbitrary order).
for mp3_filename in sorted(f for f in os.listdir('audio_data') if f.endswith('.mp3')):
    mp3_filename = f'audio_data/{mp3_filename}'
    wav_filename = mp3_filename[:-4] + '.wav'
    if not os.path.exists(wav_filename):
        # Decode the MP3 once to 16-bit PCM WAV; later runs reuse the file.
        audio_clip = moviepy.editor.AudioFileClip(mp3_filename)
        audio_clip.write_audiofile(wav_filename, fps=44100, nbytes=2, codec='pcm_s16le')
    # NOTE(review): the [15:-5] slice looks tailored to a longer file-naming
    # scheme. For a short basename such as 'dua_lipa.wav' it yields the empty
    # string, and later cells rely on that by reading audio['']. Left unchanged
    # on purpose; change it together with those cells.
    track_name = os.path.basename(wav_filename)[15:-5]
    rate, signal = wavfile.read(wav_filename)
    if signal.ndim > 1:
        signal = np.mean(signal, axis=1)  # to mono
    # wavfile.read returns a 1-D integer array for mono files; convert to float
    # before abs() so the most negative int16 sample cannot overflow.
    signal = np.abs(signal.astype(np.float64))
    duration = signal.shape[0] / rate
    frames = int(np.ceil(duration * fps))
    samples_per_frame = signal.shape[0] / frames
    audio[track_name] = np.zeros(frames, dtype=signal.dtype)
    for frame in range(frames):
        # Mean absolute amplitude of the samples that fall inside this video frame.
        start = int(round(frame * samples_per_frame))
        stop = int(round((frame + 1) * samples_per_frame))
        audio[track_name][frame] = np.mean(signal[start:stop], axis=0)
    # Normalise to a peak of 1; a completely silent track is left as zeros
    # instead of being turned into NaNs by 0 / 0.
    peak = np.max(audio[track_name])
    if peak > 0:
        audio[track_name] /= peak

# Save one plot of the per-frame loudness envelope for every track, next to the audio.
for track in sorted(audio):
    fig, ax = plt.subplots(figsize=(8, 3))
    ax.set_title(track)
    ax.plot(audio[track])
    fig.savefig(f'audio_data/{track}.png')

In [None]:
# Base random seed for the latent paths: the render-setup cell passes `seed` and
# `seed + 1` to get_ws and seeds the `mouth_open` direction with `seed + 2`.
seed = 2

In [None]:
def get_ws(n, frames, seed):
    """Return a ``(frames, 512)`` latent path through ``n`` random keyframes.

    ``n`` keyframe vectors are drawn with ``np.random.RandomState(seed).randn``,
    spaced evenly over ``frames`` steps and joined by a quadratic spline.
    The result is cached in ``audio_data/ws_{n}_{frames}_{seed}.npy`` and read
    back from there on later calls.

    Args:
        n: number of random keyframes.
        frames: number of rows (time steps) in the returned path.
        seed: seed for the keyframe RNG; also part of the cache file name.

    Returns:
        ``numpy.ndarray`` of shape ``(frames, 512)``.
    """
    filename = f'audio_data/ws_{n}_{frames}_{seed}.npy'
    if os.path.exists(filename):
        return np.load(filename)

    src_ws = np.random.RandomState(seed).randn(n, 512)
    # Lay the keyframes out three times in a row and keep only the middle copy:
    # the spline then sees the same neighbours before the first and after the
    # last keyframe, so the end of the path leads back (approximately) into its
    # start instead of showing boundary artefacts.
    x = np.linspace(0, 3 * frames, 3 * len(src_ws), endpoint=False)
    y = np.tile(src_ws, (3, 1))
    # Only the middle third is kept, so only evaluate there; all 512 dimensions
    # are interpolated in one call along axis 0.
    x_mid = np.arange(frames, 2 * frames)
    ws = interp1d(x, y, kind='quadratic', axis=0, fill_value='extrapolate')(x_mid)

    os.makedirs(os.path.dirname(filename), exist_ok=True)
    np.save(filename, ws)
    return ws

def mix_styles(wa, wb, ivs):
    """Blend selected entries of ``wa`` towards ``wb``.

    Args:
        wa: base array; it is copied, not modified.
        wb: array to blend in, indexed the same way as ``wa``.
        ivs: iterable of ``(index, weight)`` pairs. Entry ``index`` of the
            result becomes ``wa[index] * (1 - weight) + wb[index] * weight``.

    Returns:
        A copy of ``wa`` with the listed entries replaced by the blend.
    """
    mixed = np.array(wa, copy=True)
    for index, weight in ivs:
        keep = 1 - weight
        mixed[index] = wa[index] * keep + wb[index] * weight
    return mixed

def normalize_vector(v):
    """Rescale and shift ``v`` using the statistics of the global ``w_avg``.

    ``v`` is multiplied by ``std(w_avg) / std(v)`` and then offset by
    ``mean(w_avg) - mean(v)``.

    NOTE(review): the offset uses the mean of the *unscaled* ``v``, so the
    result's mean equals ``mean(w_avg)`` only when the two standard deviations
    match -- confirm whether that is intended.
    """
    rescaled = v * np.std(w_avg) / np.std(v)
    return rescaled + np.mean(w_avg) - np.mean(v)

def render_frame(t):
    """Render the video frame for time ``t`` (in seconds).

    Reads the notebook globals ``fps``, ``frames``, ``audio``, ``base_ws``,
    ``base_speed``, ``w_avg``, ``mouth_open``, ``paths`` and ``size``, and
    advances the global ``base_index`` on every call, so the output depends on
    the order in which frames are requested.

    Args:
        t: time in seconds; mapped to the nearest video frame and clamped to
            ``[0, frames - 1]``.

    Returns:
        ``numpy`` array of the generated image resized to ``size`` x ``size``.
    """
    global base_index
    # Builtin int() replaces the np.int alias, which was removed in NumPy 1.24.
    frame = int(np.clip(np.round(t * fps), 0, frames - 1))
    loudness = audio[''][frame]

    # Move along the base latent path faster when the audio is louder.
    base_index += base_speed * loudness**2
    base_w = base_ws[int(round(base_index)) % len(base_ws)]

    # Pull the latent towards the average latent by a factor psi.
    psi = 0.7
    base_w = w_avg + (base_w - w_avg) * psi
    # Push the latent along the `mouth_open` direction in proportion to loudness.
    w = base_w + mouth_open * loudness * 0.5

    # Spread the label images evenly over the duration of the video.
    label_frame = int(np.clip(np.round(len(paths) / frames * frame), 0, len(paths) - 1))
    with Image.open(paths[label_frame]) as label_image:
        label = transform_label(label_image)
    image = generate_with_w(label, torch.tensor(w, dtype=torch.float32).cuda())
    image = Image.fromarray(image).resize((size, size), PIL.Image.LANCZOS)
    return np.array(image)

In [None]:
# Estimate the average latent by pushing 16384 random z vectors through
# `generator.style` (presumably the z -> w mapping network -- confirm in
# model.Generator) and averaging the outputs.
# A dedicated, seeded torch.Generator keeps w_avg -- and therefore the rendered
# video -- reproducible across runs without touching the global torch RNG.
latent_rng = torch.Generator(device='cuda')
latent_rng.manual_seed(seed)
z = torch.randn(16384, 512, generator=latent_rng, device='cuda')

with torch.no_grad():
    w = generator.style(z)

# keepdim=True keeps a leading axis of length 1 on the averaged latent.
w_avg = w.mean(0, keepdim=True).cpu().numpy()

In [None]:
# --- Render configuration and per-run state (uses `duration`, `frames` and
# --- audio[''] left behind by the audio cell).
size = 512        # edge length of the output frames; render_frame resizes to size x size
resolution = 10   # the base latent path has `resolution` steps per video frame

seconds = int(np.ceil(duration))
base_frames = frames * resolution

# Latent paths with `seconds` keyframes each, i.e. about one keyframe per second.
base_ws = get_ws(seconds, base_frames, seed)
mix_ws = get_ws(seconds, frames, seed + 1)  # not read by any code visible in this notebook section

# Random latent direction that render_frame adds in proportion to loudness.
mouth_open = normalize_vector(np.random.RandomState(seed + 2).randn(512))

# Chosen so that the per-frame increments base_speed * audio['']**2 add up to
# base_frames over the whole track, i.e. one full pass through base_ws.
base_speed = base_frames / sum(audio[''] ** 2)
base_index = 0    # running position in base_ws; advanced by render_frame

In [None]:
# Render every frame with render_frame, attach the soundtrack and encode to MP4.
mp4_filename = 'audio_data/dua_lipa.mp4'

audio_clip_i = moviepy.editor.AudioFileClip('audio_data/dua_lipa.wav')
audio_clip = moviepy.editor.CompositeAudioClip([audio_clip_i])

video_clip = moviepy.editor.VideoClip(render_frame, duration=duration).set_audio(audio_clip)
video_clip.write_videofile(
    mp4_filename,
    fps=fps,
    codec='libx264',
    audio_codec='aac',
    bitrate='8M',
)