In [None]:
from arg_utils import is_notebook, get_cfg

# Run configuration; every entry becomes a top-level variable at the end of this cell.
cfg = get_cfg()
# interactive runs: override a few settings to experiment in the notebook
if is_notebook():
    cfg.update({"gpu": 0, "song_name": "songs/gnossi_1.mp3"})

# expose config keys (e.g. `gpu`, `song_name`) as notebook variables
locals().update(cfg)

In [None]:
# Folder for temporary computations (needs several GB of free space).
# NOTE(review): machine-specific absolute path; adjust when running on another host.
base_folder = "/raid/8wiehe/"

In [None]:
import os
# Select the GPU through the environment; this must be set before importing torch.
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}"

import torch
# total memory of the (only) visible device, in GiB
gpu_avail = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)

In [None]:
import sys
import gc
import shutil

import torch
import numpy as np

PYTHONPATH = sys.executable

# Choose resolution, CLIP batch size and GPT variant from the GPU memory (GiB) measured above.
clip_batch_size = 4  # same for every memory tier
if gpu_avail >= 10.5:
    gpt_name = "gptj"  # "neo2.7"
    sideX, sideY = 880 - 16 * 2, 490 - 16
else:
    # >= 7.5 GB and below share these settings; neo2.7 possible if it is also int8 quantized
    # (7.6 GB for prompting, 8Gb for clip generation with (880, 490) and bs 8)
    gpt_name = "neo1.3"
    sideX, sideY = 656, 368

# the "image" backend uses its own fixed resolution and batch size
if net == "image":
    sideX, sideY = (1920, 1080) if hq else (1280, 720)
    upscale = False
    clip_batch_size = 32

In [None]:
import soundfile
import librosa
import os

resampled_path = "tmp/resampled.wav"
os.makedirs("tmp", exist_ok=True)

# load song and resample to 16k Hz (the rate the tagger below is fed with)
sr = 16000
raw_song, old_sr = librosa.load(song_name, offset=offset, duration=duration)
# orig_sr / target_sr are keyword-only since librosa 0.10 (positional use raises TypeError);
# the keyword form also works on older versions
song = librosa.resample(raw_song, orig_sr=old_sr, target_sr=sr)
soundfile.write(resampled_path, song, sr)

In [None]:
import pandas as pd
from mustovi_utils import get_taggram

# The musicnn taggram is cached on disk, keyed by song name and analysis settings.
tag_dfs_folder = "./tmp/tag_dfs"
os.makedirs(tag_dfs_folder, exist_ok=True)
key_song_name = song_name.split("/")[-1].split(".")[0]
tag_df_name = "{}_{}_{}_{}_{}.csv".format(
    key_song_name, input_length, int(1 / input_overlap), offset, duration
)
tag_df_path = os.path.join(tag_dfs_folder, tag_df_name)
if not os.path.exists(tag_df_path):
    normed_tag_df = get_taggram(resampled_path, input_overlap, input_length)
    normed_tag_df.to_csv(tag_df_path)
    # release cached GPU memory after tagging
    torch.cuda.empty_cache()
    gc.collect()
else:
    normed_tag_df = pd.read_csv(tag_df_path, index_col=0)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Optional overview of every musicnn tag over time; rows sorted by ascending mean activation.
show_all_labels = False
if show_all_labels:
    plt.figure(figsize=(12, 20))
    show_df = (
        normed_tag_df.T
        .assign(mean=normed_tag_df.mean(axis=0))
        .sort_values("mean")
        .drop(columns=["mean"])
    )
    sns.heatmap(show_df)
    plt.tight_layout()

In [None]:
instruments = ["violin", "strings", "sitar", "piano", "harpsichord", 
               "harp", "guitar", "drums", "flute", "synth", "cello"]
genres = ["techno", "soul", "rock", "rnb", "punk", "pop", "opera", 
          "oldies", "new age", "metal", "jazz", "indie rock", "indie pop",
          "indie", "indian", "heavy metal", "hard rock", "funk", "folk", 
          "electronica", "electronic", "country", "classical", "classic rock",
          "classic", "choral", "blues", "alternative rock", "alternative",
          "Progressive rock", "House", "Hip-hop"]
eras = ["60s", "70s", "80s", "90s", "00s"]

speed_tags = ["fast", "slow"]
feeling_tags = ["weird", "soft", "happy", "sad", "catchy", "easy listening", "sexy", "chillout", "beautiful", "chill"]
loudness_tags = ["quiet", "loud"]
vibe_tags = ["ambient", "party", "dance", "Mellow", "experimental"]

genre_like_tags = ["solo", "blues", "Beat"]

# final mood vocabulary: tempo, feelings, loudness and vibe tags, in that order
feeling_tags = [*speed_tags, *feeling_tags, *loudness_tags, *vibe_tags]

# Heatmap of the mood tags over time; at each time step, values below that step's mean are zeroed.
plt.figure(figsize=(10, 7))
show_df = normed_tag_df[feeling_tags].T.copy()
show_df = show_df.mask(show_df < show_df.mean(axis=0), 0)
show_df = show_df.assign(mean=show_df.mean(axis=1)).sort_values("mean").drop(columns=["mean"])
sns.heatmap(show_df)
plt.tight_layout()

In [None]:
# reduce taggram to fit fps
musicnn_fps = 62.5
#averaging_window = int(musicnn_fps / fps) # == 2 - 30fps
# taggram frames averaged into one output frame; at least 1 because a window of 0
# (fps > 125) would make the stride slicing below raise
averaging_window = max(1, int(np.round(musicnn_fps / fps))) # == 3 - 20fps
# take step average taggram (rolling mean, then keep every averaging_window-th row).
# `axis=0` was dropped from rolling(): it is the default and the keyword is deprecated in pandas >= 2.1
fps_taggram = normed_tag_df.rolling(averaging_window, min_periods=1).mean()
fps_taggram = fps_taggram.iloc[::averaging_window, :]

In [None]:
# decide on using only subset
used_tag_df = fps_taggram.copy()
if taggram_mode == "feelings":
    # explicit copy: a column selection is flagged by pandas as a possible view, so the
    # column writes below would emit SettingWithCopyWarning
    used_tag_df = used_tag_df[feeling_tags].copy()

# merge some columns: average the source tag into the target tag, then drop the source
merge_dict = {"chill": "chillout"}
for key in merge_dict:
    val = merge_dict[key]
    used_tag_df[val] = (used_tag_df[key] + used_tag_df[val]) / 2
    del used_tag_df[key]
# rename some columns to the mood words used in prompts
rename_dict = {"sexy": "sensual",
               "party": "energetic",
               "dance": "moving",
               "easy listening": "harmonious",
               "catchy": "captivating",
               "chillout": "relaxed"}
used_tag_df = used_tag_df.rename(columns=rename_dict)

# per-tag mean over the whole song; values below their tag's mean are zeroed
tag_df_means = used_tag_df.mean()
used_tag_df[used_tag_df < tag_df_means] = 0

In [None]:
def apply_ema(arr, ema_val=0.9):
    """Exponential moving average along the first axis of `arr`.

    The average is seeded with the first element, so the output starts at arr[0].

    Args:
        arr: sequence (list, 1-D or 2-D numpy array) whose items support `*` and `+`.
        ema_val: weight of the running average; higher values smooth more.

    Returns:
        List with one smoothed item per item of `arr` (empty list for empty input).
    """
    if len(arr) == 0:
        # nothing to smooth; arr[0] below would raise IndexError
        return []
    ema = arr[0]
    out = []
    for item in arr:
        ema = ema * ema_val + item * (1 - ema_val)
        out.append(ema)
    return out

In [None]:
import matplotlib.pyplot as plt

# Cluster the taggram frames into song sections. Later cells use `labels` (cluster id per
# frame), `dist_to_centers` (per-frame distance-like score to each cluster, lower = closer,
# torch tensor) and `centers` (cluster centers in tag space, without the time feature).
if create_clusters:
    # import the submodule explicitly: a bare `import sklearn` does not guarantee that
    # `sklearn.cluster` is loaded (it only worked because yellowbrick imports it)
    import sklearn.cluster
    
    clustering_feats = used_tag_df.to_numpy()
    # add index to give some time continuity
    cluster_time = True
    if cluster_time:
        idx_arr = np.expand_dims(np.arange(len(clustering_feats)), 1)
        # scale the frame index relative to the tag magnitudes; cluster_time_weight sets its influence
        idx_arr = idx_arr / idx_arr.sum() * used_tag_df.sum().max() * cluster_time_weight
        clustering_feats = np.concatenate([clustering_feats, idx_arr], axis=1)
    # create high dim umap embeddings for clustering (flag currently unused)
    cluster_on_umap_high_d = False
    # smooth the features over time before clustering
    clusterable_embedding = np.array(apply_ema(clustering_feats, ema_val=ema_val_clustering))
    
    if use_k_means:
        # pick the number of clusters with the elbow method
        from yellowbrick.cluster import KElbowVisualizer

        # Instantiate the clustering model and visualizer
        model = clusterer = sklearn.cluster.KMeans(n_clusters=5, n_init=20, max_iter=500)
        visualizer = KElbowVisualizer(
            model, k=(min_clusters, max_clusters), metric='calinski_harabasz', #'silhouette', #'calinski_harabasz', 
            timings=False
        )

        visualizer.fit(clusterable_embedding)        # Fit the data to the visualizer
        visualizer.show()  
        num_clusters = visualizer.elbow_value_
        # make final clustering
        clusterer = sklearn.cluster.KMeans(n_clusters=num_clusters, n_init=50, max_iter=500)
        labels = clusterer.fit_predict(clusterable_embedding)
        # determine centers
        real_centers = clusterer.cluster_centers_
        # mean squared distance of every frame to every center -> (num_frames, num_clusters)
        dist_to_centers = np.array([np.mean((emb - real_centers) ** 2, axis=-1)
                                    for emb in clusterable_embedding[:, :]])
        dist_to_centers = torch.from_numpy(dist_to_centers)
        if cluster_time:
            # drop the appended time feature so centers live in tag space
            centers = real_centers[:, :-1]
        else:
            centers = real_centers
        
    else:
        # cluster using hdbscan: halve min_samples until at least min_clusters labels appear
        # or min_samples gets too small. Do-while form so HDBSCAN is fitted at least once even
        # when min_clusters <= 1 (a plain while-loop would never fit and leave the model undefined).
        from hdbscan import HDBSCAN
        min_samples = 72
        while True:
            hdbscan_clusterer = HDBSCAN(min_samples=min_samples,  #35, 
                                        cluster_selection_epsilon=0., 
                                        min_cluster_size=min(100, len(clusterable_embedding) // 10))
            hdbscan_labels = hdbscan_clusterer.fit_predict(clusterable_embedding)
            used_min_samples = min_samples
            min_samples = min_samples // 2
            if len(np.unique(hdbscan_labels)) >= min_clusters or min_samples <= 2:
                break
        # report the value used by the final fit (not the already-halved one)
        print("Min samples: ", used_min_samples)
        # assign outliers (labels == -1) to the previous non-outlier label;
        # leading outliers take the first non-outlier label
        clean_hdbscan_labels = []
        last_label = np.array(hdbscan_labels[hdbscan_labels != -1])[0]
        for label in hdbscan_labels:
            if label == -1:
                label = last_label
            else:
                last_label = label
            clean_hdbscan_labels.append(label)
        labels = np.array(clean_hdbscan_labels)
        num_labels = len(np.unique(labels))

        
        # assign distances to cluster for each point
        # (star import provides exemplars, max_lambda_val, points_in_cluster, combined_membership_vector)
        from hdbscan_utils import *
        data = clusterable_embedding
        tree = hdbscan_clusterer.condensed_tree_
        exemplar_dict = {c: exemplars(c, tree) for c in tree._select_clusters()}
        cluster_ids = tree._select_clusters()
        raw_tree = tree._raw_tree
        all_possible_clusters = np.arange(data.shape[0], raw_tree['parent'].max() + 1).astype(np.float64)
        max_lambda_dict = {c:max_lambda_val(c, raw_tree) for c in all_possible_clusters}

        point_dict = {c:set(points_in_cluster(c, raw_tree)) for c in all_possible_clusters}
        cluster_distances = np.array([combined_membership_vector(x, data, tree, exemplar_dict, cluster_ids,
                                                       max_lambda_dict, point_dict, False) for x in range(len(data))])
        # membership -> distance-like score (lower = closer), matching the k-means branch
        dist_to_centers =  1 - torch.from_numpy(cluster_distances)
        # find cluster representatives by averaging the points per cluster
        real_centers = np.array([data[labels == i].mean(axis=0) for i in range(num_labels)])
        if cluster_time:
            centers = real_centers[:, :-1]
        else:
            centers = real_centers

    show_2d_umap = 0
    if show_2d_umap:
        from umap import UMAP
        # create 2D UMAP embedding to plot
        mapper = UMAP(
            n_neighbors=30,
            min_dist=0.0,
            n_components=2,
            random_state=42,
            metric="cosine",
        ).fit(clustering_feats)
        # make plot
        import umap.plot
        umap.plot.output_notebook()
        df = pd.DataFrame({"step": list(range(len(labels))),
                           "cluster": labels,
                           })
        p = umap.plot.interactive(mapper, 
                                  labels=df["cluster"], 
                                  #values = df["step"],
                                  hover_data=df, point_size=10)
        umap.plot.show(p)
    # show clusters over time
    plt.figure(figsize=(10, 7))
    plt.scatter(range(len(labels)), labels, s=1.5)
    plt.show()
    plt.close()
    # show cluster dists over time
    plt.figure(figsize=(10, 7))
    for i in range(dist_to_centers.shape[1]):
        plt.plot(dist_to_centers[:, i], label=str(i))
    l = plt.legend()
    plt.show()
    plt.close()
    # show heatmap 
    plt.figure(figsize=(10, 7))
    show_df = used_tag_df.T.copy()
    show_df[show_df < show_df.mean(axis=0)] = 0
    show_df["mean"] = show_df.mean(axis=1)
    show_df = show_df.sort_values("mean").drop(columns=["mean"])
    sns.heatmap(show_df)
    plt.tight_layout()
    plt.show()
    plt.close()

In [None]:
import IPython

# listen to clusters: play back only the audio samples whose frame belongs to `cluster_idx`.
# Uses the cleaned `labels`, which both clustering branches define. The raw `hdbscan_labels`
# exist only for HDBSCAN and still contain -1 outliers.
cluster_idx = 0

# stretch per-frame labels to per-sample resolution (rounded up so they cover the whole song)
samples_per_step = int(len(song) / len(labels)) + 1
frame_assignments = np.repeat(np.asarray(labels), samples_per_step)
sections = frame_assignments == cluster_idx
song_section = song[sections[:len(song)]]

IPython.display.Audio(song_section, rate=sr, autoplay=0)

In [None]:
def get_gpt_stories_and_weights(cluster_gpt_stories, n_start_prompts, dist_to_centers, gpt_story_top_k, idx):
    """Pick the GPT cluster stories closest to frame `idx` and weight them by closeness.

    Args:
        cluster_gpt_stories: list with one story per cluster, or None if stories are disabled.
        n_start_prompts: number of filler prompts; frames before it map to story frame 0.
        dist_to_centers: torch tensor (num_frames, num_clusters), lower = closer.
        gpt_story_top_k: maximum number of stories to return; clamped to the number of
            clusters because `topk` raises when k exceeds the dimension size.
        idx: index of the prompt frame.

    Returns:
        (stories, weights): the closest stories (nearest first) and tensor weights
        ``(1 - d / max_d) ** 2``; ``([""], [1])`` when `cluster_gpt_stories` is None.
    """
    if cluster_gpt_stories is None:
        return [""], [1]
    story_idx = max(idx - n_start_prompts, 0)
    frame_dists = dist_to_centers[story_idx]
    k = min(gpt_story_top_k, frame_dists.shape[-1])
    top_k = frame_dists.topk(k=k, largest=False)
    story_weights = (1 - (top_k.values / frame_dists.max())) ** 2
    gpt_stories = [cluster_gpt_stories[i] for i in top_k.indices]
    return gpt_stories, story_weights

In [None]:
# per-tag mean activation over the whole song, in ascending order
used_tag_df.mean().sort_values()

In [None]:
# highest_classification

def filter_theme_labels(series, n=6, threshold=0.5, factor=0.9, reference_means=None):
    """Select the most salient tags from one tag-activation vector.

    Tags are sorted by activation (descending) and kept only if they exceed `threshold`
    and `factor` times their reference mean.

    Args:
        series: pandas Series of activations indexed by tag name.
        n: maximum number of tags to return.
        threshold: absolute minimum activation.
        factor: multiplier on the per-tag reference mean a tag must exceed.
        reference_means: per-tag means (Series indexed by tag). Defaults to the global
            `used_tag_df.mean()`.

    Returns:
        The selected slice of `series`, highest activation first.
    """
    if reference_means is None:
        # computed once instead of once per tag
        reference_means = used_tag_df.mean()
    # sort by logits
    cluster_theme = series.sort_values(ascending=False)
    # filter logits below certain value
    cluster_theme = cluster_theme[cluster_theme > threshold]
    # filter logits that are not clearly above their song-wide mean
    cluster_theme = cluster_theme[cluster_theme > factor * reference_means.loc[cluster_theme.index]]
    # pick top N
    return cluster_theme.iloc[:n]


center_df = pd.DataFrame(centers, columns=used_tag_df.columns)

# Summarize every cluster center as a lower-cased, comma-separated list of its salient tags.
cluster_themes = []
for i in range(len(center_df)):
    cluster_theme = filter_theme_labels(center_df.iloc[i])
    cluster_vals = cluster_theme.round(2).to_list()
    cluster_theme_names = ", ".join(cluster_theme.index.to_list()).lower()
    print(f"{i}:", cluster_theme_names, cluster_vals)
    cluster_themes.append(cluster_theme_names)

In [None]:
# The 5 tags with the highest mean activation over the song; main_theme_words is later
# passed to gpt_create_theme to ask for a matching art style.
main_theme = used_tag_df.mean().sort_values(ascending=False).iloc[:5]
main_theme_words = ", ".join(main_theme.index.to_list())
main_theme

In [None]:
# main distinctive features: the 5 tags whose values vary most across cluster centers
print(", ".join(center_df.std().sort_values(ascending=False)[:5].index.to_list()))

In [None]:
import IPython

# Listen to a single cluster: play back only the samples whose frame belongs to `cluster_idx`.
cluster_idx = 0

# cluster assignment of every frame over time
plt.scatter(range(len(labels)), labels, s=1.5)
plt.show()

# stretch per-frame labels to per-sample resolution (rounded up so they cover the whole song)
samples_per_step = int(len(song) / len(labels)) + 1
frame_assignments = np.repeat(np.asarray(labels), samples_per_step)
sections = frame_assignments == cluster_idx
song_section = song[sections[:len(song)]]

IPython.display.Audio(song_section, rate=sr, autoplay=0)

In [None]:
from tqdm.auto import tqdm
import numpy as np

# create musicnn prompts: one CLIP prompt per taggram frame, built from its active tags
clip_prompts = []
pbar = tqdm(list(used_tag_df.iterrows()))

for i, row in pbar:
    # keep only tags above their song-wide mean
    row = row[row > tag_df_means]
    sorted_row = row.sort_values(ascending=False)

    # generate clip prompt for current musicnn targets
    if prompt_mode == "top_k":
        # comma-separated list of the k strongest tags
        top_tag_names = list(sorted_row.iloc[:k].index)
        pbar.set_description(", ".join(top_tag_names))
        clip_prompt = ", ".join(top_tag_names)
    elif prompt_mode == "weighted_top_k":
        # {tag: activation} dict, filtered the same way as the cluster themes
        sorted_row = filter_theme_labels(sorted_row, n=k)#, factor=0.8, threshold=0.1)
        top_tag_names = list(sorted_row.index)
        top_tag_vals = list(sorted_row)
        clip_prompt = {name: val for name, val in zip(top_tag_names, top_tag_vals)}
    elif prompt_mode == "gpt":
        # NOTE(review): this calls gpt_create_prompt(model, tokenizer, text), but the
        # gpt_create_prompt defined later in the notebook takes (cluster_words_list, gpt_name,
        # clip_model, ...) and returns a list; prompt_hash_table is not defined in the visible
        # code either. This mode looks stale — confirm before using it.
        sorted_row = row.sort_values(ascending=False)
        top_tags = sorted_row.iloc[:k]
        top_tag_names = list(top_tags.index)
        if len(top_tag_names) == 0:
            top_tag_names = ["Undecided emptiness"]
        merged_top_tags = ", ".join(top_tag_names)
        if merged_top_tags in prompt_hash_table:
            clip_prompt = prompt_hash_table[merged_top_tags]
        else:
            clip_prompt = gpt_create_prompt(gpt_model, gpt_tokenizer, merged_top_tags)
            pbar.set_description("Tags: " + merged_top_tags + " Prompt: " + clip_prompt)
            #clip_encoding = imagine.create_text_encoding(clip_prompt)
            prompt_hash_table[merged_top_tags] = clip_prompt
    else:
        # fail loudly instead of a NameError or a stale clip_prompt from a previous run
        raise ValueError(f"Unknown prompt_mode: {prompt_mode!r}")

    clip_prompts.append(clip_prompt)
    
# how many steps are there to fill at the start of the song (256 is the size of the fft-windows of musicnn)
# NOTE(review): the filler prompts are appended at the END of clip_prompts, while
# get_gpt_stories_and_weights offsets frames by n_start_prompts as if they were at the start — confirm.
start_prompt = clip_prompts[0]
n_start_prompts = int(np.round((len(song) / (256 * averaging_window) - len(used_tag_df))))
clip_prompts.extend([start_prompt] * n_start_prompts)

In [None]:
# True: lyrics come from a single-line text file; False: from the .srt next to the song
text_file is not None

In [None]:
if use_lyrics:
    import re

    def time_str_to_seconds(time_str):
        """Convert an SRT timestamp ``HH:MM:SS,mmm`` into seconds (float)."""
        hrs_str, min_str, sec_str = time_str.split(":")[:3]
        seconds_str, ms_str = sec_str.split(",")[:2]
        # 3600 seconds per hour (360 mistimed every subtitle after the first hour)
        return int(hrs_str) * 3600 + int(min_str) * 60 + int(seconds_str) + int(ms_str) / 1000

    if text_file is None:
        # subtitles are expected next to the song: <song path without extension>.srt
        lyrics_path = ".".join(song_name.split(".")[:-1]) + ".srt"

        # read file (read-only; write access is not needed)
        with open(lyrics_path, "r") as f:
            lines = f.readlines()
        # Parse SRT entries: index line, "start --> end" line, one to four text lines,
        # then a blank separator line.
        texts = []
        start_times = []
        end_times = []
        while len(lines) > 2:
            if lines[0] == "\n":
                del lines[0]
                continue
            # lines consumed by this entry: index, times, first text line, separator
            count = 4
            # read times
            start_time_str = lines[1].split(" ")[0]
            end_time_str = lines[1].split(" ")[-1]
            # convert to seconds
            start_time = time_str_to_seconds(start_time_str)
            end_time = time_str_to_seconds(end_time_str)

            # read text: first line plus up to three continuation lines, bounds-checked so
            # an entry at the very end of the file does not raise IndexError
            text = lines[2].strip("♪.\n")
            for extra_idx in range(3, 6):
                if extra_idx >= len(lines) or lines[extra_idx] == "\n":
                    break
                text += " " + lines[extra_idx].strip("♪.\n")
                count += 1
            # remove formatting tags such as <i>...</i>; a regex cannot loop forever on a
            # "<" without a closing ">"
            text = re.sub(r"<[^>]*>", "", text)
            # remove spaces and turn to lower-case
            text = text.strip("♪ .\n").lower()

            # save
            if len(text) > 2:
                texts.append(text)
                start_times.append(start_time)
                end_times.append(end_time)

            # delete lines to move on in file
            del lines[:count]
    else:
        with open(text_file, "r") as f:
            lines = f.readlines()
        assert len(lines) == 1, "text_file must contain the whole text on a single line"
        # one prompt per sentence
        texts = [s for s in lines[0].split(". ") if len(s) > 2]

        # spread the sentences evenly over the song: N sentences need N + 1 boundaries,
        # otherwise the last sentence never gets a time slot
        num_seconds = len(song) / sr
        times = np.linspace(0, num_seconds, len(texts) + 1)
        start_times = times[:-1]
        end_times = times[1:]

    from deep_translator import  GoogleTranslator

    #!python3 -m pip install deep-translator

    # translate every line to English (source language auto-detected)
    translated_texts = [GoogleTranslator(source='auto', target='en').translate(text=text)
                        for text in texts]


In [None]:
# generate .srt file if read from .txt file
if use_lyrics and text_file is not None:
    #!python3 -m pip install srt
    import srt
    from srt import Subtitle
    from datetime import timedelta

    subtitles = []
    for i in range(len(start_times)):
        content = texts[i]
        start = timedelta(seconds=start_times[i])
        end = timedelta(seconds=end_times[i])
        subtitle = srt.Subtitle(i, start, end, content)
        subtitles.append(subtitle)
    subtitle_text = srt.compose(subtitles)
    
    # write subtitle text to <text_file without .txt>_<song file stem>.srt
    text_stem = text_file[:-len(".txt")] if text_file.endswith(".txt") else text_file
    # take the file name first, then cut at "." (cutting the whole path first gives ""
    # for relative paths like "../songs/x.mp3"); same convention as key_song_name
    song_stem = song_name.split("/")[-1].split(".")[0]
    file_path = text_stem + "_" + song_stem + ".srt"
    with open(file_path, "w+") as f:
        f.write(subtitle_text)


In [None]:
# match texts with start and end time to frames
if use_lyrics:
    num_seconds = len(song) / sr
    seconds_per_step = num_seconds / len(clip_prompts)

    # Walk frames and lyric lines in parallel: `count` indexes the current/next line,
    # `prompt` is the text active at the current frame ("" between lines).
    prompt = ""
    count = 0
    num_lyrics = len(end_times)
    lyric_prompts = []
    for i in range(len(clip_prompts)):
        current_time = i * seconds_per_step
        # skip finished lines; bounded so frames after the last line do not raise IndexError
        while count < num_lyrics and current_time >= end_times[count]:
            prompt = ""
            count += 1
        if count < num_lyrics and current_time >= start_times[count]:
            prompt = translated_texts[count]
        lyric_prompts.append(prompt)


In [None]:
import sys
sys.path.append("../StyleCLIP_modular")
sys.path.append("../CLIPGuidance")


import argparse
import torch
import gc

from style_clip import Imagine, create_text_path

# Settings shared by every backend; per-backend overrides follow.
args = {
    "lr_schedule": 0,
    "seed": 1,
    # alternatives: 'text, signature, watermarks' / 'incoherent, confusing, cropped, watermarks'
    # / 'text, signature, watermarks, writings, scribblings'
    "neg_text": None,
    "clip_names": ["ViT-B/16", "ViT-B/32", "RN50"],
    "averaging_weight": 0,
    "early_stopping_steps": 0,
    "tv_loss_scale": 0.0,
    "lpips_weight": lpips_weight,
    "lpips_batch_size": 4,
    "lpips_net": "squeeze",
    "use_russell_transform": 1,
}

if net == "vqgan":
    args.update(model_type="vqgan", lr=0.03, batch_size=clip_batch_size)
elif net == "conv":
    # NOTE: the sideX/sideY given here are overwritten by the generic resolution below
    args.update(act_func="gelu", stride=1, num_layers=5, downsample=False, norm_type="layer",
                num_channels=64, sideX=1080, sideY=720, lr=0.005, stack_size=4)
elif net == "stylegan":
    args.update(style=style, lr=0.005, opt_all_layers=1)
elif net == "image":
    args.update(lr=0.005, batch_size=32, stack_size=1)
elif net == "dip":
    args.update(lr=0.00005, batch_size=16, stack_size=1, optimizer="madgrad")
args["model_type"] = net

# output resolution chosen from GPU memory earlier. Measured alternatives:
# 688x384 - 7.792GB, 34s/it
# 720x400 - 7.948GB, 41.3s/it - crashes after a bit
# 624x352 - 6.9GB, 29.8s/it
# 544x304 - 5850MB, 24s/it at 100its per step
args["sideX"] = sideX
args["sideY"] = sideY
args["circular"] = 0

imagine = Imagine(
    save_progress=False,
    open_folder=False,
    save_video=False,
    verbose=False,
    use_mixed_precision=True,
    **args,
)

torch.cuda.empty_cache()
gc.collect()

In [None]:
# Disabled cell: the code below is wrapped in a string literal and never executed.
# It hot-reloaded the helper modules and loaded the GPT model for interactive experiments.
"""
import importlib
import mustovi_utils
import gpt_j_low_prec
importlib.reload(mustovi_utils)
importlib.reload(gpt_j_low_prec)

from mustovi_utils import load_gpt_model, gen_sent
from clip import tokenize
clip_model = imagine.perceptor.models[0]
gpt_model, gpt_tokenizer = load_gpt_model(gpt_name)
"""
pass

In [None]:
# Disabled cell: few-shot prompt experiment kept as an unexecuted string literal.
"""
prefix = "The following are adjectives describing a song, followed by a description of the corresponding image:\n "
prefix = "The following are adjectives describing an image, listed in the order of importance. They are followed by a full description of the corresponding image:\n "
prompter = ". Full description:"
examples = {"sad, dark, fast": " A man is running through dark woods while crying.",
            "sad, beautiful, soft, quiet, slow": " An old woman is sitting on a chair in a beautiful garden with her hands folded in front of her. She is looking at you with a sad expression on her face.",
            "electronic, loud, happy, abstract": " Dynamic and vibrant colors forming strong geometric shapes that resemble a rave.",
            "weird, happy, fast": " A man is experiencing a strange dream. He is struggling to feel his feelings, his emotions as they rush too quickly through his body. He is an a state of ecstacy.",
            "harmonious, mellow": " An electric light begins to dim at a distant point in the sky. You feel complete and at one with your environment.",
            #"slow, quiet": " A lion's roar stops in front of him. The lion is slowly moving forward and approaching you. He is silent."
           }

target_text = cluster_themes[3]#"fast, brutal"
target_clip_feats = clip_model.encode_text(tokenize(target_text).to("cuda"))
target_text"""
pass

In [None]:
# Scratch cell: sample GPT outputs kept as bare strings; only the last one is displayed.
"An ancient dream starts to calm down. It is quiet and peaceful"
"An ordinary man appears before you. He is listening to you. He is a crazed crazy crazy crazy crazy crazy. He is beating on your legs"

In [None]:
#texts, df = gen_sent(gpt_model, gpt_tokenizer, clip_model, target_clip_feats, 
#             start_text=" Sweet dance candy", p=0.94, 
#            prefix=prefix, examples=examples, prompter=prompter, target_text=target_text,
#            clip_weight=0.7, 
#            clip_temp=0.45, gpt_temp=0.75, out_len=50, v=1, num_beams=10, return_num=5)
#print(texts)

In [None]:
#df = df.sort_values("post_score", ascending=False)

In [None]:
#for i, row in df.iterrows():
#    print(row[1:])#.drop(columns=["sent"]))
#    print(row["sent"].strip())
#    
#    print()

In [None]:
from tqdm.auto import tqdm
from mustovi_utils import load_gpt_model, gen_sent
from clip import tokenize


def gpt_create_prompt(cluster_words_list, gpt_name, clip_model, gpt_model=None, gpt_tokenizer=None, gpt_prefix=""):
    """Generate one GPT image description per comma-separated tag list.

    Each tag list is encoded with `clip_model.encode_text` and passed, with a few-shot
    prefix/examples/prompter, to `gen_sent`; the first returned text is kept.

    Args:
        cluster_words_list: iterable of tag strings (e.g. the cluster themes).
        gpt_name: model name for `load_gpt_model`, used only when `gpt_model` is None.
        clip_model: CLIP model providing `encode_text`.
        gpt_model, gpt_tokenizer: optional preloaded GPT model and tokenizer.
        gpt_prefix: start text for every generation.

    Returns:
        List of generated texts, one per entry of `cluster_words_list`.
    Side effects: prints every text and moves the GPT model to the CPU at the end.
    """
    if gpt_model is None:
        gpt_model, gpt_tokenizer = load_gpt_model(gpt_name)

    # few-shot prompt pieces for gen_sent
    # (alternative prefix: "The following are adjectives describing a song, followed by a description of the corresponding image:\n ")
    prefix = "The following are adjectives describing an image, listed in the order of importance. They are followed by a full description of the corresponding image:\n "
    prompter = ". Full description:"
    #examples = {"sad, dark, fast": " A man is running through dark woods while crying.",
    #            "sad, beautiful, soft, quiet, slow": " An old woman is sitting on a chair in a beautiful garden with her hands folded in front of her. She is looking at you with a sad expression on her face.",
    #            "electronic, loud, happy, abstract": " Dynamic and vibrant colors forming strong geometric shapes that resemble a rave.",
    #           }
    examples = {
        "sad, dark, fast": " Running through dark woods while crying.",
        "sad, beautiful, soft, quiet, slow": " An old widow is sitting on a chair in a beautiful garden with her hands folded in front of her. She is looking at you with a sad expression on her face.",
        "electronic, loud, happy, abstract": " Dynamic and vibrant colors are forming strong geometric shapes that resemble a rave.",
        "weird, happy, fast": " A man is experiencing a strange dream. He is struggling to feel his emotions as they rush too quickly through his body. He is an a state of ecstacy.",
        "harmonious, mellow": " An electric light begins to dim at a distant point in the sky. You feel complete and at one with your environment.",
        "slow, quiet": " A lion's roar stops in front of him. The lion is slowly moving forward and approaching you. He is silent.",
    }

    stories = []
    for words in tqdm(cluster_words_list):
        words_clip_feats = clip_model.encode_text(tokenize(words).to("cuda"))
        generated = gen_sent(gpt_model, gpt_tokenizer, clip_model, words_clip_feats,
                             start_text=gpt_prefix, p=0.93,
                             prefix=prefix, examples=examples, prompter=prompter, target_text=words,
                             clip_weight=0.7,
                             clip_temp=0.45, gpt_temp=0.75, out_len=50, v=-1, num_beams=50, return_num=1)
        story = generated[0]
        stories.append(story)
        print(story)
        print()
    gpt_model.to("cpu")
    return stories

# Optionally generate one GPT story per cluster theme (None when disabled).
used_gpt_stories = (
    gpt_create_prompt(cluster_themes, gpt_name, imagine.perceptor.models[0], gpt_prefix=gpt_cluster_prefix)
    if do_create_gpt_cluster_stories
    else None
)
torch.cuda.empty_cache()
gc.collect()

In [None]:
# inspect the tag theme of every cluster
cluster_themes

In [None]:
# inspect the generated per-cluster stories (None when story generation is disabled)
used_gpt_stories

In [None]:
from mustovi_utils import load_gpt_model, gen_sent
from clip import tokenize


def gpt_create_theme(theme_words, gpt_name, clip_model, gpt_model=None, gpt_tokenizer=None):
    """Ask GPT for an art style matching a comma-separated list of theme words.

    Args:
        theme_words: comma-separated words (e.g. the song's main tags).
        gpt_name: model name for `load_gpt_model`, used only when `gpt_model` is None.
        clip_model: CLIP model providing `encode_text`.
        gpt_model, gpt_tokenizer: optional preloaded GPT model and tokenizer.

    Returns:
        The first candidate returned by `gen_sent` (all candidates are printed).
    Side effect: moves the GPT model to the CPU at the end.
    """
    if gpt_model is None:
        gpt_model, gpt_tokenizer = load_gpt_model(gpt_name)

    # Alternative phrasings that were tried:
    #   prefix "The following are adjectives, followed by a matching artstyle:\n "
    #   prefix "The following are lists of words describing art, followed by the name of the artist:\n "
    #   prompter "The name of a matching painter is:" / "Matching visual artist:"
    prefix = "The following are lists of adjectives, listed in order of importance. They are followed by a name of an artstyle that matches them:\n "
    prompter = ". Matching artstyle:"

    # NOTE: an earlier literal listed "introspective, beautiful, sad" twice; Python keeps the
    # key's first position with the LAST value, so " A moody, ambient painting." was never
    # used. This dict spells out exactly what that literal evaluated to.
    examples = {
        "introspective, beautiful, sad": " Impressionism.",
        "expressive, wild, colourful": " An expressionist piece of art.",
        "epic, fantasy, stunning, moody": " Illustrated by Greg Rutkowski.",
        "realistic, beautiful, landscapes, forgotten civilizations": " By James Gurney.",
    }
    # popular, internet{prompter} Trending on artstation.
    # rendered, detailed, high-quality{prompter} Rendered in unreal engine.
    # expressionist, beautiful, vibrant. {prompter} Van Gogh.
    # happy, dreamy, romantic, sensual. {prompter} Gustav Klimt.

    theme_clip_feats = clip_model.encode_text(tokenize(theme_words).to("cuda"))
    candidates = gen_sent(gpt_model, gpt_tokenizer, clip_model, theme_clip_feats,
                          start_text="", p=0.9,
                          prefix=prefix, examples=examples, prompter=prompter, target_text=theme_words,
                          clip_weight=0.2,
                          clip_temp=0.45, gpt_temp=0.75, out_len=50, v=-1, num_beams=50, return_num=5)
    best = candidates[0]
    print(candidates)
    gpt_model.to("cpu")
    return best

# Optionally let GPT suggest an art style for the song's main theme words ("" when disabled).
gpt_theme = (
    gpt_create_theme(main_theme_words.lower(), gpt_name, imagine.perceptor.models[0])
    if create_gpt_artstyle
    else ""
)
torch.cuda.empty_cache()
gc.collect()
gpt_theme

In [None]:
# display the art-style suffix ("" when GPT art-style generation is disabled)
gpt_theme

In [None]:
# peek at the first five per-frame prompts
clip_prompts[0:5]

In [None]:
# move the Imagine model to the GPU (the same call is repeated two cells below)
imagine = imagine.cuda()

In [None]:
# number of prompts, i.e. output steps (one prompt per step)
len(clip_prompts)

In [None]:
# duplicate of the earlier .cuda() cell; harmless if the model is already on the GPU
imagine = imagine.cuda()

In [None]:
# Calculate encodings based on prompts

clip_target_encodings = []
# cache: fully decorated prompt text -> CLIP encoding
clip_feature_hash_table = {}
# optional GPT art-style suffix appended to every prompt
gpt_suffix = f" {gpt_theme}" if len(gpt_theme) > 0 else ""

# list used as a mutable counter of newly computed encodings
count = []

def encode(prompt):
    """Return the cached CLIP encoding of `prompt` decorated with prefix, theme and GPT suffix.

    The full text is `prefix + prompt [+ general_theme] + gpt_suffix`. New texts are encoded
    with `imagine.create_clip_encoding(text=..., img=img_theme)` and cached in
    `clip_feature_hash_table`; every 50th newly encoded text is printed.
    """
    full_prompt = prefix + prompt
    if general_theme is not None:
        full_prompt += general_theme
    full_prompt += gpt_suffix
    if full_prompt in clip_feature_hash_table:
        return clip_feature_hash_table[full_prompt]
    # `count` is a module-level list used as a counter of new encodings
    count.append(0)
    if len(count) % 50 == 0:
        print(full_prompt)
    encoding = imagine.create_clip_encoding(text=full_prompt, img=img_theme)
    #encoding = imagine.create_text_encoding(full_prompt)
    clip_feature_hash_table[full_prompt] = encoding
    return encoding


def weighted_average_encoding(encodings, weights):
    clip_encoding = [norm(torch.stack([norm(encoding[j]) * weight for encoding, weight in zip(encodings, weights)]).sum(dim=0))
                         for j in range(len(encodings[0]))]
    return clip_encoding


def norm(a):
    return a / a.norm(dim=-1, keepdim=True)


def clip_mean_direction(direction_prompt, base_prompts, imagine):
    """Estimate a CLIP text direction for `direction_prompt`.

    For every base prompt p, the normalized difference between the encodings of
    p + direction_prompt and p is taken per encoding part. These differences are
    averaged over all base prompts and renormalized, giving one direction per part.
    """
    base_encs = [imagine.create_clip_encoding(text=p) for p in base_prompts]
    extended_encs = [imagine.create_clip_encoding(text=p + direction_prompt) for p in base_prompts]

    diff_encs = []
    for base_enc, ext_enc in zip(base_encs, extended_encs):
        diff_encs.append([norm(norm(ext_enc[k]) - norm(base_enc[k])) for k in range(len(base_enc))])

    num_parts = len(diff_encs[0])
    return [norm(torch.stack([diffs[k] for diffs in diff_encs]).mean(dim=0))
            for k in range(num_parts)]



if use_mean_dirs:
    # Neutral sentence openers; each tag's direction is averaged over all of them.
    # NOTE(review): "He  is " has a double space; kept verbatim since it changes the directions.
    base_prompts = [
        "A photo of ", " ", "A painting of ", "This painting is: ", "This photo looks ",
        "I feel ", "I feel: ", "The sky is ", "This is ", "The ground is ",
        "This person is ", "She is ", "He  is ", "A ",
    ]
    # tag (column of used_tag_df) -> per-part CLIP direction
    dir_dict = dict((tag, clip_mean_direction(tag, base_prompts, imagine)) for tag in used_tag_df)


# For each step, encode the step's top-k GPT stories (plus optional lyrics and mood
# tags) and blend them by their story weights. Results are moved to the CPU so the
# list of targets does not occupy GPU memory.
for idx, prompt in enumerate(tqdm(clip_prompts)):
    gpt_stories, story_weights = get_gpt_stories_and_weights(used_gpt_stories, n_start_prompts, 
                                                             dist_to_centers, gpt_story_top_k, idx)
    
    if use_lyrics:
        lyrics_prompt = lyric_prompts[idx] + ". "
    else:
        lyrics_prompt = ""
    
    story_encodings = []
    for gpt_story, story_weight in zip(gpt_stories, story_weights):
        if isinstance(prompt, dict):
            # `prompt` maps mood tag -> weight for this step.
            if use_mean_dirs:
                take_avg = True
                
                base_encoding = encode(lyrics_prompt + gpt_story + " ")
                # Start from the plain story encoding. A step without active mood tags
                # then falls back to it; previously clip_encoding stayed unset (NameError)
                # or silently reused the encoding of an earlier step.
                clip_encoding = base_encoding
                
                if take_avg:
                    dir_encodings = [dir_dict[tag] for tag in prompt]
                    weights = list(prompt.values())
                    if len(dir_encodings) > 0:
                        dir_mean_encoding = weighted_average_encoding(dir_encodings, weights)
                        clip_encoding = weighted_average_encoding([base_encoding, dir_mean_encoding], [1 - mood_weight, mood_weight])
                else:
                    # Add every weighted tag direction. Previously each tag restarted
                    # from base_encoding, so only the last tag had any effect.
                    for tag, weight in prompt.items():
                        clip_encoding = [clip_encoding[i] + dir_dict[tag][i] * weight for i in range(len(clip_encoding))]
                clip_encoding = [norm(enc) for enc in clip_encoding]
            else:
                if len(prompt) > 0:
                    # Phrase each mood both as suffix and prefix and average the two.
                    encs_suff = [encode(lyrics_prompt + gpt_story + " It feels " + prompt_key + ".")
                                        for prompt_key in prompt]
                    encs_pre = [encode("It feels " + prompt_key + ". " + lyrics_prompt + gpt_story) 
                                        for prompt_key in prompt]
                    encodings = [weighted_average_encoding([encs_suff[i], encs_pre[i]], [1.0, 1.0])
                                for i in range(len(encs_suff))]
                    weights = list(prompt.values())
                    mood_weighted_mean_encoding = weighted_average_encoding(encodings, weights)
                    base_encoding = encode(lyrics_prompt + gpt_story + " ")
                    clip_encoding = weighted_average_encoding([base_encoding, mood_weighted_mean_encoding],
                                                              [1 - mood_weight, mood_weight])
                else:
                    clip_encoding = encode(lyrics_prompt + gpt_story + " ")

        else:
            story_prompt = lyrics_prompt + gpt_story + " " + prompt + "."
            clip_encoding = encode(story_prompt)
        story_encodings.append(clip_encoding)
    clip_encoding = weighted_average_encoding(story_encodings, story_weights)

    clip_encoding = [enc.to("cpu") for enc in clip_encoding]
    clip_target_encodings.append(clip_encoding)

# Number of prompts that actually had to be encoded (cache misses).
print(len(count))

In [None]:
# test directions
# Disabled scratch experiment, kept as an inert string literal. It optimizes a single
# image for `steps` steps towards `base_text` pushed along the mood directions in `dirs`.
# NOTE(review): it needs `dir_dict`, which only exists when use_mean_dirs is set;
# `weights` has 4 entries for 2 dirs (zip truncates).
"""
print(dir_dict.keys())
base_text = "Home."
steps = 150
text_weight = 0.4
dirs = ["slow", "experimental"]
weights = [1.0, 1.0, 1.0, 1.0]

text_enc = imagine.create_text_encoding(base_text)


dir_encodings = [dir_dict[tag] for tag in dirs]
print(dir_encodings[0][0][0][:10])
dir_mean_encoding = weighted_average_encoding(dir_encodings, weights)
clip_encoding = weighted_average_encoding([text_enc, dir_mean_encoding], [text_weight, 1 - text_weight])
print(clip_encoding[0][0][:10])

#for p, w in zip(dirs, weights):
#    dir_enc = dir_dict[p]
#    clip_encoding = [norm(text_enc[i] * text_weight + dir_enc[i] * (1 - text_weight)) for i in range(len(dir_enc))]

    
imagine.set_clip_encoding(encoding=clip_encoding)
imagine.reset()
for _ in tqdm(range(steps)):
    img, loss = imagine.train_step(0, 0)
"""
# to_pil(img.squeeze())

In [None]:
# Inspect the first step's prompt (a string, or a mood-tag -> weight dict).
clip_prompts[0]

In [None]:
clip_prompts[:3]

In [None]:
# Debug: first values of one target encoding part.
#clip_target_encodings[10][0][0][:10]

In [None]:
# take ema of encodings to smoothen
# Exponential moving average over steps, per encoding part:
#   ema_t = ema_val * ema_{t-1} + (1 - ema_val) * enc_t
# It is seeded with the first step's encoding; a larger ema_val gives smoother targets.
ema_encodings = []
ema = clip_target_encodings[0]

for encoding in clip_target_encodings:
    ema = [
        ema_val * ema[i].to("cpu") + (1 - ema_val) * part.to("cpu")
        for i, part in enumerate(encoding)
    ]
    ema_encodings.append(ema)

In [None]:
import torchvision.transforms as T
from scipy.interpolate import NearestNDInterpolator
from mustovi_utils import get_spec_norm
import librosa

# create zoom, rotate, shift effects
# An effect dict maps Effect keywords ("zoom", "rotate", "shiftX", "shiftY", "shear")
# to a signed strength; it is later scaled by an audio amplitude (see merge_dicts).
effects = ["zoom", "rotate", "shiftX", "shiftY", "shear"]
harm_effect_dict =dict()  #{"rotate": 0.0}
perc_effect_dict = dict() #{"zoom": -0.5}
# NOTE(review): old_cqt_effect_dict is not used below; kept for reference only.
old_cqt_effect_dict = [{"zoom": 1.0}, 
                   {"rotate": 1.0},
                   {"shiftX": 1.0}, 
                   {"shiftY": 1.0},
                   {"shiftY": -1.0},
                   {"shiftX": -1.0},
                   {"rotate": -1.0},
                   {"zoom": -1.0},
                  ]
# cqt_effect_dict[k] is driven by the energy of chroma bin k (n_chroma equals its length).
cqt_effect_dict = [{"zoom": -1.00}, 
                   {"zoom": -0.5},
                   {"rotate": 0.75}, 
                   {"rotate": 0.5},
                   {"rotate": -0.5},
                   {"rotate": -0.75},
                   {"zoom": 0.5},
                   {"zoom": 1.0},
                  ]
                    
# divide song in percussion and harm (might divide in pitches later)
song_harm, song_perc = librosa.effects.hpss(song)
spec_norm_harm = get_spec_norm(song_harm)
spec_norm_perc = get_spec_norm(song_perc)
# get cqt spec
n_chroma = len(cqt_effect_dict)
cqt_spec = librosa.feature.chroma_cqt(y=song, sr=sr,hop_length=256, 
                                      n_chroma=n_chroma, n_octaves=7, 
                                      bins_per_octave=n_chroma * 4, norm=None)
# NOTE(review): sns / plt are not imported in this cell; they must come from an earlier cell.
sns.heatmap(cqt_spec)
plt.show()
# take window averages to match video fps
# Moving average over N frames, then keep every N-th value (presumably one value per
# generation step — TODO confirm get_spec_norm and chroma_cqt share the same hop).
N = averaging_window
spec_norm_harm = np.convolve(spec_norm_harm, np.ones(N) / N , mode='valid')[::N]
spec_norm_perc = np.convolve(spec_norm_perc, np.ones(N) /N, mode='valid')[::N]
cqt_spec = np.array([np.convolve(cqt_line, np.ones(N) / N, mode='valid')[::N] 
                     for cqt_line in cqt_spec])
# min-max norm
# NOTE(review): yields NaN if a curve is constant (max == min), e.g. for silent input.
spec_norm_harm = (spec_norm_harm - spec_norm_harm.min()) / (spec_norm_harm.max() - spec_norm_harm.min())
spec_norm_perc = (spec_norm_perc - spec_norm_perc.min()) / (spec_norm_perc.max() - spec_norm_perc.min())
cqt_spec = (cqt_spec - cqt_spec.min()) / (cqt_spec.max() - cqt_spec.min())
# create effects
import kornia
    
class Effect:
    """Audio-driven affine transform (zoom / rotate / shift) for an image tensor.

    Each keyword is a signed amount relative to a maximum effect (zoom 0.22, rotate
    10, shift 12) and is multiplied by `strength`.
    NOTE(review): `shear` is accepted but ignored — shear is always 0.
    """

    def __init__(self, strength, zoom=0, rotate=0, 
                 shiftX=0, shiftY=0, shear=0):
        max_zoom, max_rotate, max_shift = 0.22, 10, 12
        self.zoom = 1 + max_zoom * zoom * strength
        self.rotate = max_rotate * rotate * strength
        self.shift_x = max_shift * shiftX * strength
        self.shift_y = max_shift * shiftY * strength
        # Shear is currently disabled regardless of the `shear` argument.
        self.shear_x = self.shear_y = 0

    def __call__(self, img):
        """Warp `img` around its center; a missing batch dim is added, so the result is 4-D."""
        height, width = img.shape[-2], img.shape[-1]

        def as_tensor(values):
            return torch.tensor(values).float()

        matrix = kornia.geometry.transform.get_affine_matrix2d(
            translations=as_tensor([(self.shift_x, self.shift_y)]),
            center=as_tensor([(width // 2, height // 2)]),
            scale=as_tensor([(self.zoom, self.zoom)]),
            angle=as_tensor([self.rotate]),
            sx=as_tensor([self.shear_x]),
            sy=as_tensor([self.shear_y]),
        )
        batch = img if img.ndim >= 4 else img.unsqueeze(0)
        return kornia.geometry.transform.warp_perspective(
            batch, matrix.to(batch.dtype).to(batch.device), batch.shape[-2:],
            mode='bilinear', padding_mode='reflection', align_corners=True)
        

def merge_dicts(effect_dict, effect_strength_dict, amplitude):
    """Add `amplitude`-scaled strengths into `effect_dict` in place.

    Keys missing from `effect_dict` are created with the scaled value. Returns None.
    """
    for key, strength in effect_strength_dict.items():
        scaled = strength * amplitude
        effect_dict[key] = effect_dict[key] + scaled if key in effect_dict else scaled

# create effects that directly alter the image
# One list of Effect objects per step, driven by that step's harmonic and percussive
# amplitude and its chroma (CQT) bin energies.
if args["model_type"] in ("vqgan", "image", "dip"):
    effects_list = []
    for i in range(len(spec_norm_harm)):
        harm = spec_norm_harm[i]
        perc = spec_norm_perc[i]
        cqt = cqt_spec[:, i]

        effect_dict = {}
        merge_dicts(effect_dict, harm_effect_dict, harm)
        merge_dicts(effect_dict, perc_effect_dict, perc)
        for cqt_effect, cqt_amplitude in zip(cqt_effect_dict, cqt):
            merge_dicts(effect_dict, cqt_effect, cqt_amplitude)

        effect = Effect(total_effect_strength, **effect_dict)
        effects_list.append([effect])
else:
    # create effects that alter the clip target shortly
    # these effects should have a different name
    # Independent lists per step: `[[]] * n` would alias one shared list, so appending
    # an effect to one step would add it to every step.
    effects_list = [[] for _ in range(len(spec_norm_harm))]

In [None]:
# Release cached GPU memory and run Python garbage collection.
torch.cuda.empty_cache()
gc.collect()

In [None]:
# Sanity check: one target per step; should match len(effects_list) (the generation
# cell trims the longer list by one element and asserts equality).
len(ema_encodings)

In [None]:
len(effects_list)

In [None]:
def minmax(a):
    """Linearly rescale array-like `a` so its minimum maps to 0 and its maximum to 1."""
    lowest = a.min()
    value_range = a.max() - lowest
    return (a - lowest) / value_range

In [None]:
# Number of optimization substeps for each generation step (constant by default).
substeps_per_step = [sub_steps] * len(ema_encodings)
if use_variable_substeps:
    # vary number of substeps to take depending on amplitude
    use_spec_norm = 0
    if use_spec_norm:
        spec_norm_df = pd.DataFrame(get_spec_norm(song))
        # take step average taggram
        # Smooth with a rolling mean, then subsample the *smoothed* frame once per step.
        # Previously the rolling mean was computed and discarded because the raw
        # frame was subsampled.
        fps_spec_norm = spec_norm_df.rolling(averaging_window, min_periods=1).mean()
        fps_spec_norm = fps_spec_norm.iloc[::averaging_window, :]
        spec_norm = fps_spec_norm.to_numpy()[:, 0]
        spec_norm = minmax(np.array(apply_ema(spec_norm, 0.99)))
        substeps_per_step = spec_norm * (max_substeps - min_substeps) + min_substeps 
    else:
        # Allot more substeps to steps where the mood tags change strongly:
        # mean squared difference between consecutive tag rows (first entry is 0).
        last = used_tag_df.iloc[0]
        mood_changes = []
        for i, row in used_tag_df.iterrows():
            mood_change = np.mean((row - last) ** 2)
            mood_changes.append(mood_change)
            last = row
        plt.plot(np.concatenate([[min_substeps] * n_start_prompts, minmax(np.array(mood_changes)) * (max_substeps - min_substeps) + min_substeps]))
        # NOTE(review): minmax yields NaN if all mood changes are equal.
        spec_norm = minmax(np.array(apply_ema(mood_changes, 0.95)))
        substeps_per_step = spec_norm * (max_substeps - min_substeps) + min_substeps 
        # add steps for first N steps
        substeps_per_step = np.concatenate([[substeps_per_step[0]] * n_start_prompts, substeps_per_step])
    plt.plot(substeps_per_step)
    substeps_per_step = np.round(substeps_per_step).astype(int)
    print(substeps_per_step.mean())
    #plt.plot(substeps_per_step)

In [None]:
import multiprocessing

import torchvision
from PIL import Image


# jpg path -> background process writing that file (see save_latent / load_latent).
_pending_saves = {}


def _wait_for_save(jpg_path):
    """Block until a pending background save of `jpg_path` (if any) has finished."""
    proc = _pending_saves.pop(jpg_path, None)
    if proc is not None:
        proc.join()


def _reap_finished_saves():
    """Join background save processes that have already exited."""
    finished = [p for p, proc in _pending_saves.items() if not proc.is_alive()]
    for done_path in finished:
        _pending_saves.pop(done_path).join()


def save_latent(latent, path, net):
    """Persist one frame latent to `path`.

    For net == "image" the latent is an image tensor. It is clamped to [0, 1] and
    written as a JPEG (".pt" in `path` replaced by ".jpg") by a background process,
    so the caller is not blocked. load_latent waits for that process before reading,
    which prevents reading half-written files. For all other nets the latent is
    written with torch.save.
    """
    if net == "image":
        latent = torch.clamp(latent, 0, 1)
        pil_img = torchvision.transforms.ToPILImage()(latent.squeeze())
        jpg_path = path.replace(".pt", ".jpg")
        _reap_finished_saves()
        # Never let two processes write the same file concurrently.
        _wait_for_save(jpg_path)
        # Pillow's JPEG keyword is `subsampling` (0 = 4:4:4 chroma); the previous
        # `subsample` key was silently ignored.
        proc = multiprocessing.Process(target=pil_img.save, args=(jpg_path,),
                                       kwargs={"quality": 95, "subsampling": 0})
        proc.start()
        _pending_saves[jpg_path] = proc
    else:
        torch.save(latent, path)


def load_latent(path, net):
    """Load a latent written by save_latent, waiting for a pending JPEG write first."""
    if net == "image":
        jpg_path = path.replace(".pt", ".jpg")
        _wait_for_save(jpg_path)
        img = Image.open(jpg_path)
        latent = torchvision.transforms.ToTensor()(img)
    else:
        latent = torch.load(path)
    return latent

In [None]:
import time

# Current Unix time in seconds (quick sanity check).
time.time()

In [None]:
# Current UTC time; the generation cell formats time.gmtime() into the temp-folder suffix.
time.gmtime()

In [None]:
import shutil

from tqdm.auto import tqdm
import torchvision
import tensorflow as tf
from IPython.display import display, clear_output


to_pil = torchvision.transforms.ToPILImage()


# Targets (ema_encodings) and audio effects (effects_list) may differ in length by one
# step (presumably rounding at the end of the song); trim the longer list.
if len(ema_encodings) > len(effects_list):
    ema_encodings = ema_encodings[:-1]
if len(ema_encodings) < len(effects_list):
    effects_list = effects_list[:-1]
assert len(ema_encodings) == len(effects_list), f"{len(ema_encodings)}, {len(effects_list)}"


# Per-step latents go into a fresh, timestamped folder under base_folder.
img_latents = []
time_str = time.strftime("%Y_%m_%d__%H_%M_%S", time.gmtime())
save_folder = base_folder + f"tmp/vid_latents{time_str}/"
if os.path.exists(save_folder):
    shutil.rmtree(save_folder)
os.makedirs(save_folder, exist_ok=True)


# Produce a start frame for the first target.
imagine.to("cuda")
imagine.reset()
imagine.set_clip_encoding(encoding=[item.to(imagine.device) for item in ema_encodings[0]])
img, loss = imagine.train_step(0, 0)
img = img.detach().cpu()

imgs = []
save_image = (net == "dip")
# Show about 20 previews over the run. max(1, ...) avoids a modulo-by-zero crash
# when there are fewer than 20 steps.
preview_every = max(1, len(ema_encodings) // 20)

pbar = tqdm(list(range(len(ema_encodings))))
for i in pbar:
    clip_encoding, effects = ema_encodings[i], effects_list[i]
    # apply effects: warp the previous frame (or the DIP latents) before optimizing.
    # `img` is always a tensor here; the old `img is not None` check came after
    # img.float() and could never matter.
    transformed_img = img.float()
    if len(effects) > 0:
        if net == "dip":
            effect_img = imagine.model.model.latents.squeeze(0).cpu() + 1
            effect_img = effect_img.float().clip(0, 1)
        else:
            effect_img = transformed_img.cpu()

        for effect in effects:
            effect_img = effect(effect_img)

        if net == "dip":
            imagine.model.model.latents = effect_img.unsqueeze(0).to(imagine.device) - 1
        else:
            if net == "vqgan":
                # Map [0, 1] pixels to [-1, 1] before re-encoding with VQGAN.
                effect_img_normed = effect_img.mul(2).sub(1).to(imagine.device)
                latent, _, [_, _, indices] = imagine.model.model.model.encode(effect_img_normed)
            else:
                latent = effect_img
            imagine.set_latent(latent)
    # set target encoding in CLIP
    clip_encoding = [part.to(imagine.device) for part in clip_encoding]
    imagine.set_clip_encoding(encoding=clip_encoding)
    # optimize for some steps
    for _ in range(substeps_per_step[i]):
        img, loss = imagine.train_step(0, 0, lpips_img=transformed_img.to(imagine.device))
    img = img.detach().cpu()

    # get latent of img
    latent = imagine.model.model.get_latent()
    if net != "dip":
        latent = latent.detach().cpu()
    if save_folder is None:
        img_latents.append(latent)
    else:
        latent_name = str(i) + ".pt"
        save_latent(latent, save_folder + latent_name, net)
    # save final img
    if save_image:
        imgs.append(to_pil(img.squeeze(0)))
    if i % preview_every == 0:
        pil_img = to_pil(img.squeeze())
        display(pil_img)

#img_latents = sequential_gen(ema_encodings, effects_list)
#img_latents = parallel_gen(clip_prompts)

# TODO: instead of saving the latent in every step, fill up buffer and save in batch

In [None]:
def merge_latents(a, b, w_a=1, w_b=1):
    """Return the weighted sum w_a * a + w_b * b of two latents.

    For net == "dip" the latents are indexable and element 1 is a mapping; the sum
    is taken per key and returned as a list in the key order of a[1].
    """
    if net != "dip":
        return a * w_a + b * w_b
    a_params, b_params = a[1], b[1]
    return [a_params[key] * w_a + b_params[key] * w_b for key in a_params]

In [None]:
def apply_ema_disk(in_folder, out_folder, ema_val=0.9):
    """Smooth a folder of per-step latents with an exponential moving average.

    Latents "{i}.pt" are read from `in_folder` in index order via load_latent.
    ema_i = ema_val * ema_{i-1} + (1 - ema_val) * latent_i is written to `out_folder`
    under the same name, seeded with latent 0. `out_folder` is recreated from
    scratch and `in_folder` is deleted afterwards. Returns None.
    """
    if os.path.exists(out_folder):
        shutil.rmtree(out_folder)
    os.makedirs(out_folder, exist_ok=True)

    num_items = sum(1 for name in os.listdir(in_folder) if name.endswith((".pt", ".jpg")))
    ema = load_latent(in_folder + "0.pt", net)
    for idx in tqdm(range(num_items)):
        name = f"{idx}.pt"
        current = load_latent(in_folder + name, net)
        ema = merge_latents(ema, current, ema_val, 1 - ema_val)
        save_latent(ema, out_folder + name, net)

    if os.path.exists(in_folder):
        shutil.rmtree(in_folder)

In [None]:
if net != "dip":
    # take ema of latents to smoothen (disk-backed: save_folder -> ema_folder)
    in_folder = save_folder
    ema_folder = base_folder + f"tmp/vid_ema_latents_{time_str}/"
    # NOTE(review): apply_ema_disk writes its result to ema_folder and returns None,
    # so ema_latents is always None here.
    ema_latents = apply_ema_disk(in_folder, ema_folder, ema_val=ema_val_latent)
    #ema_latents = apply_ema(img_latents, ema_val=ema_val_latent)

In [None]:
# interpolate between latents to increase fps and make video smoother
from mustovi_utils import slerp

def boost_frames(ema_latents, boost_fps, song):
    """Slerp between consecutive latents to reach roughly `boost_fps` frames per second.

    `song` is taken to be 16 kHz audio (duration = len(song) / 16000 s). Each latent
    is expanded into `frames_to_add` evenly spaced frames towards the next latent;
    the last latent is simply repeated. The input is returned unchanged when no
    extra frames are needed.
    """
    goal_frame_count = boost_fps * len(song) / 16000
    current_frame_count = len(ema_latents)
    frames_to_add = np.ceil(goal_frame_count / current_frame_count)
    if frames_to_add <= 1:
        return ema_latents

    fracs = np.arange(frames_to_add) / frames_to_add
    video_latents = []
    for i, latent in enumerate(ema_latents):
        # The last latent has no successor: interpolate it with itself (matching
        # boost_frames_disk). The old code indexed ema_latents[i + 1] here -> IndexError.
        next_latent = latent if i + 1 == len(ema_latents) else ema_latents[i + 1]
        if net == "dip":
            latents_to_add = [[slerp(latent[k], next_latent[k], frac) for k in range(len(latent))]
                              for frac in fracs]
        else:
            latents_to_add = [slerp(latent, next_latent, frac) for frac in fracs]
        video_latents.extend(latents_to_add)
    return video_latents
    
def boost_frames_disk(in_folder, out_folder, boost_fps, song):
    """Disk-backed variant of boost_frames.

    Reads "{i}.pt" latents from `in_folder` and writes the slerp-interpolated frames
    as consecutively numbered latents into `out_folder`, then removes `in_folder`.
    When no extra frames are needed, `in_folder` is simply moved to `out_folder`.
    `song` is taken to be 16 kHz audio.
    """
    if os.path.exists(out_folder):
        shutil.rmtree(out_folder)
    os.makedirs(out_folder, exist_ok=True)

    num_items = len([item for item in os.listdir(in_folder) if item.endswith(".pt") or item.endswith(".jpg")])

    goal_frame_count = boost_fps * len(song) / 16000
    current_frame_count = num_items
    frames_to_add = np.ceil(goal_frame_count / current_frame_count)
    if frames_to_add > 1:
        fracs = np.arange(frames_to_add) / frames_to_add
        count = 0
        for i in tqdm(range(num_items)):
            latent = load_latent(in_folder + f"{i}.pt", net)
            if i + 1 == num_items:
                # Last latent has no successor: repeat it.
                next_latent = latent
            else:
                next_latent = load_latent(in_folder + f"{i + 1}.pt", net)
            for frac in fracs:
                save_latent(slerp(latent, next_latent, frac), out_folder + f"{count}.pt", net)
                count += 1
    else:
        # Nothing to interpolate: move the folder. The old code used the undefined
        # (or unrelated global) names input_folder/output_folder here.
        shutil.rmtree(out_folder)
        shutil.move(in_folder, out_folder)

    if os.path.exists(in_folder):
        shutil.rmtree(in_folder)
    
if net != "dip":
    # Interpolate extra frames between the smoothed latents (ema_folder -> boost_folder).
    #video_latents = boost_fps(ema_latents, boost_fps, song)
    in_folder = ema_folder
    boost_folder = base_folder + f"tmp/vid_boosted_latents_{time_str}/"
    boost_frames_disk(in_folder, boost_folder, boost_fps, song)
    # NOTE(review): a bare expression inside an `if` block is not displayed; no effect.
    in_folder

In [None]:
import os

def load_img_paths(root):
    """Return the .png/.jpg files in `root`, sorted by the integer their name starts with.

    Names must begin with an integer index, e.g. "12.jpg" or "12_x.png".
    """
    def frame_index(path):
        file_name = path.split("/")[-1]
        return int(file_name.split("_")[0].split(".")[0])

    image_names = [name for name in os.listdir(root) if name.endswith((".png", ".jpg"))]
    return sorted((os.path.join(root, name) for name in image_names), key=frame_index)


In [None]:
# Move the whole `imagine` wrapper to the CPU and release cached GPU memory.
imagine = imagine.to("cpu")
torch.cuda.empty_cache()
gc.collect()

In [None]:
# create images and save to disk:
import torchvision
from tqdm.auto import tqdm
import numpy as np
import shutil

if net != "dip":
    to_pil = torchvision.transforms.ToPILImage()
    extension = ".jpg"
    tmp_folder = base_folder + f"tmp/vid_imgs_{time_str}/" 
    final_latents_folder = boost_folder
    #vid_boosted_latents

    if os.path.exists(tmp_folder):
        shutil.rmtree(tmp_folder)
    os.makedirs(tmp_folder, exist_ok=True)

    # Only the generator is needed to decode latents: keep it on the GPU, the rest on CPU.
    gen_model = imagine.model.model
    imagine.cpu()
    torch.cuda.empty_cache()
    gc.collect()

    device = torch.device("cuda")
    gen_model.to(device)

    num_items = len([item for item in os.listdir(final_latents_folder) if item.endswith(".pt") or item.endswith(".jpg")])
    # Decoding is pure inference. no_grad avoids building an autograd graph for every
    # frame, which lowers GPU memory use and speeds up rendering.
    with torch.no_grad():
        for i in tqdm(range(num_items)):
            latent = load_latent(final_latents_folder + str(i) + ".pt", net)
            if net == "image":
                img = torch.clamp(latent, 0, 1)
            else:
                img = gen_model(latents=latent.to(device)).to("cpu")
            pil_img = to_pil(img.squeeze())
            pil_img.save(os.path.join(tmp_folder, f"{i}{extension}"), subsampling=0, quality=95)

    if os.path.exists(final_latents_folder):
        shutil.rmtree(final_latents_folder)

In [None]:
import shutil
import subprocess

from PIL import Image
import soundfile
import moviepy.editor as mpy
from datetime import datetime


def create_video(imgs, song, song_name, vid_name, sr, duration, bitrate="5000k", offset=0):
    """Encode frames into an H.264 video and mux in the audio of `song_name`.

    Args:
        imgs: list of frames/paths accepted by moviepy's ImageSequenceClip, or a
            folder containing .jpg/.png frames.
        song: audio samples; only their count is used to derive the frame rate.
        song_name: path of the audio file muxed in with ffmpeg.
        vid_name: output file name inside "video_gens/".
        sr: sample rate of `song`.
        duration: if truthy, the muxed audio is cut to this many seconds (the same
            value the song was loaded with); previously this argument was unused.
        bitrate: video bitrate passed to moviepy.
        offset: seconds to skip at the start of `song_name` (default 0, which is the
            previous behavior).
    """
    os.makedirs("video_gens", exist_ok=True)
    video_path = f"video_gens/{vid_name}"

    if isinstance(imgs, list):
        img_len = len(imgs)
    else:
        img_len = len([img for img in os.listdir(imgs)
                       if img.endswith(".jpg") or img.endswith(".png")])

    # Generate final video
    # Spread all frames evenly over the song. Use img_len: len(imgs) is the length of
    # the path string when a folder is passed.
    vid_fps = img_len / (len(song) / sr)
    print(vid_fps)
    video = mpy.ImageSequenceClip(imgs, fps=vid_fps)
    temp_vid_path = "tmp/video.mp4"
    video.write_videofile(temp_vid_path,
                          codec="libx264",
                          fps=vid_fps,
                          threads=10,
                          bitrate=bitrate,
                          preset="slow",
                          )

    # Copy the video stream and encode the song as AAC.
    # - An argument list (no shell) keeps file names with quotes/spaces intact.
    # - -y overwrites an existing output instead of blocking on ffmpeg's prompt.
    audio_input = []
    if offset:
        audio_input += ["-ss", str(offset)]
    if duration:
        audio_input += ["-t", str(duration)]
    audio_input += ["-i", song_name]
    command = (["ffmpeg", "-y", "-i", temp_vid_path] + audio_input +
               ["-c:v", "copy", "-map", "0:v:0", "-map", "1:a:0",
                "-c:a", "aac", "-b:a", "512k", video_path])
    result = subprocess.run(command)
    if result.returncode != 0:
        print(f"ffmpeg failed with exit code {result.returncode}; {video_path} was not written")



In [None]:
import re

date_time = datetime.now().strftime("%m_%d_%H:%M")  #("%m/%d/%Y, %H:%M:%S")

# Name: {song}_{prompt_mode}_ema{ema_val}_steps{sub_steps}[_{gpt theme}]_{model}_{date}.mp4
video_name = f"{song_name.split('/')[-1].split('.')[0]}_{prompt_mode}_ema{ema_val}_steps{sub_steps}"
if create_gpt_artstyle:
    # The GPT theme is free text. Replace every character that is not a word
    # character, "." or "-" so it cannot inject path separators or quotes into the
    # file name (a "/" would also break the "HD_" prefixing in the upscale cell).
    video_name += "_" + re.sub(r"[^\w.\-]", "_", gpt_theme)
video_name += f"_{args['model_type']}_{date_time}.mp4"
#video_name = "loud_pipes_kiefer.mp4"

if net == "dip":
    paths = [np.array(img) for img in imgs]
else:
    paths = load_img_paths(tmp_folder)
create_video(paths, raw_song, song_name, video_name, old_sr, duration)

In [None]:
# Clone Real-ESRGAN and enter the Real-ESRGAN
# One-time setup, skipped when the folder already exists: clone, install into the
# kernel's interpreter (PYTHONPATH holds sys.executable, not a module search path),
# download the x4plus and x2plus weights, and return to the original directory.
if not os.path.exists("Real-ESRGAN"):
    !git clone https://github.com/xinntao/Real-ESRGAN.git
    %cd Real-ESRGAN
    # Set up the environment
    !$PYTHONPATH -m pip install basicsr facexlib gfpgan
    !$PYTHONPATH -m pip install -r requirements.txt
    !$PYTHONPATH setup.py develop
    # Download the pre-trained model
    !wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
    !wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P experiments/pretrained_models
    %cd ..

In [None]:
# Keep `imagine` on the CPU so the upscaler gets the GPU memory.
imagine = imagine.cpu()
torch.cuda.empty_cache()
gc.collect()

In [None]:
import sys
sys.path.append("Real-ESRGAN")
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from tqdm import tqdm
import torch
import numpy as np

import matplotlib.pyplot as plt

@torch.inference_mode()
def upscale_imgs(imgs, out_folder=None, scale=4, tile=0):
    """Super-resolve frames with Real-ESRGAN via the RealESRGANer wrapper.

    Args:
        imgs: image file paths, or BGR uint8 arrays as expected by RealESRGANer.enhance.
        out_folder: if given, frames are saved there as "{i}.jpg" (the folder is
            created if needed) and an empty list is returned.
        scale: upscaling factor; loads the matching RealESRGAN_x{scale}plus weights
            (the setup cell downloads x4 and x2).
        tile: tile size for tiled inference (0 = whole image).

    Returns:
        List of PIL images when out_folder is falsy, otherwise [].
    """
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, 
                    num_block=23, num_grow_ch=32, scale=scale)
    upsampler = RealESRGANer(
        scale=scale,
        # Weights must match the RRDBNet scale; previously the x4 weights were always loaded.
        model_path=f"Real-ESRGAN/experiments/pretrained_models/RealESRGAN_x{scale}plus.pth",
        model=model,
        tile=tile,
        tile_pad=10,
        pre_pad=0,
        half=1)
    if out_folder:
        os.makedirs(out_folder, exist_ok=True)

    outs = []
    for i, img in enumerate(tqdm(imgs)):
        if isinstance(img, str):
            # PIL loads RGB; RealESRGANer expects OpenCV-style BGR.
            img = np.array(Image.open(img))[:,:,::-1]

        output, _ = upsampler.enhance(img, outscale=scale)
        # enhance returns BGR; flip back to RGB for PIL.
        pil_img = Image.fromarray(output[:,:,::-1])

        if out_folder:
            # Pillow's JPEG keyword is `subsampling` (0 = 4:4:4); `subsample` was ignored.
            pil_img.save(os.path.join(out_folder, f"{i}.jpg"), subsampling=0, quality=95)
        else:
            outs.append(pil_img)
    return outs


import torchvision

# NOTE(review): to_tensor is not used by upscale_imgs_custom below.
to_tensor = torchvision.transforms.ToTensor()

@torch.inference_mode()
def upscale_imgs_custom(imgs, out_folder=None, scale=4, tile=0):
    """Super-resolve frames by running the RRDBNet directly (no RealESRGANer wrapper).

    Pre- and post-processing mirror RealESRGANer: the RGB input is scaled to [0, 1]
    and the output is clamped to [0, 1] before conversion to uint8. The old code fed
    BGR values in [0, 255], which saturates the network, and let uint8 overflow wrap.

    Args:
        imgs: image file paths, or float tensors presumably shaped (1, 3, H, W) with
            RGB values in [0, 1] — TODO confirm with callers.
        out_folder: if given, frames are saved there as "{i}.jpg" and [] is returned.
        scale: upscaling factor; selects RealESRGAN_x{scale}plus weights. For scale 2
            the image sides presumably must be even (no padding is done here).
        tile: unused; kept for signature parity with upscale_imgs.
    """
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, 
                    num_block=23, num_grow_ch=32, scale=scale)
    loadnet = torch.load(f"Real-ESRGAN/experiments/pretrained_models/RealESRGAN_x{scale}plus.pth",
                         map_location="cpu")
    keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
    model.load_state_dict(loadnet[keyname], strict=True)
    model.eval().cuda().half()
    if out_folder:
        os.makedirs(out_folder, exist_ok=True)

    outs = []
    for i, img in enumerate(tqdm(imgs)):
        if isinstance(img, str):
            rgb = np.array(Image.open(img).convert("RGB"), dtype=np.float32) / 255.0
            img = torch.from_numpy(np.transpose(rgb, (2, 0, 1))).unsqueeze(0)

        output = model(img.to("cuda").half()).float().clamp_(0, 1)
        output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
        output = (output * 255.0).round().astype(np.uint8)
        pil_img = Image.fromarray(output)

        if out_folder:
            # Pillow's JPEG keyword is `subsampling` (0 = 4:4:4); `subsample` was ignored.
            pil_img.save(os.path.join(out_folder, f"{i}.jpg"), subsampling=0, quality=95)
        else:
            outs.append(pil_img)
    return outs

In [None]:
import imageio

def load_images_from_mp4(path):
    """Decode every frame of the video at `path` into a list of numpy arrays (imageio ffmpeg plugin)."""
    # Use the reader as a context manager so the ffmpeg process / file handle is released.
    with imageio.get_reader(path, 'ffmpeg') as vid:
        return [np.array(image) for image in vid.iter_data()]

In [None]:
import os
import shutil

if upscale:
    # 4x super-resolve the rendered frames and build an "HD_" version of the video.
    # NOTE(review): tmp_folder only exists when net != "dip".
    input_folder = tmp_folder
    output_folder = base_folder + f"tmp/upscaled_vid_imgs_{time_str}"
    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder, exist_ok=True)
    
    print("Upscaling...")
    
    #video_path = "video_gens/any_colour_you_like_pink_floyd_hd_studio_quality_7032261705832661515_top_k_ema0.2_steps100Weird_and_beautiful._vqgan_11_20_14:02.mp4"
    #video_name = video_path.split("/")[0]
    #input_paths = load_images_from_mp4(video_name)
    
    input_paths = load_img_paths(input_folder)
    upscale_imgs(input_paths, out_folder=output_folder, scale=4)
    #!CUDA_VISIBLE_DEVICES=0 python Real-ESRGAN/inference_realesrgan.py --model_path Real-ESRGAN/experiments/pretrained_models/RealESRGAN_x4plus.pth --input $input_folder --output $output_folder  --netscale 4 --outscale 4 --half --face_enhance > /dev/null
    print("Done!")
    # edit name
    # Prefix only the file-name component of video_name with "HD_".
    upscaled_video_name = video_name.split("/")
    upscaled_video_name[-1] = "HD_" + upscaled_video_name[-1]
    upscaled_video_name = "/".join(upscaled_video_name)
    # create video
    paths = load_img_paths(output_folder)    
    create_video(paths, raw_song, song_name, upscaled_video_name, old_sr, duration, bitrate="12000k")
    # Delete the low-resolution frames; the upscaled frames are kept in output_folder.
    if os.path.exists(input_folder):
        shutil.rmtree(input_folder)

In [None]:
if twice_upscale:
    torch.cuda.empty_cache()
    gc.collect()
    
    # Upscale the already upscaled frames once more.
    # - Folder names follow the timestamped naming of the upscale cell; the old
    #   hard-coded names never matched.
    # - The output folder is recreated before writing into it.
    # NOTE(review): needs the upscale cell to have run (upscaled frames and
    # upscaled_video_name).
    input_folder = base_folder + f"tmp/upscaled_vid_imgs_{time_str}"
    output_folder = base_folder + f"tmp/twice_upscaled_vid_imgs_{time_str}"
    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder, exist_ok=True)
    print("Upscaling...")
    input_paths = load_img_paths(input_folder)
    upscale_imgs(input_paths, out_folder=output_folder, scale=4)
    #!CUDA_VISIBLE_DEVICES=0 python Real-ESRGAN/inference_realesrgan.py --model_path Real-ESRGAN/experiments/pretrained_models/RealESRGAN_x2plus.pth --input  $input_folder --output $output_folder  --netscale 2 --outscale 2 --half --face_enhance > /dev/null
    print("Done!")
    # edit name
    twice_upscaled_video_name = upscaled_video_name.split("/")
    twice_upscaled_video_name[-1] = "Full" + twice_upscaled_video_name[-1]
    twice_upscaled_video_name = "/".join(twice_upscaled_video_name)
    # create video
    paths = load_img_paths(output_folder)
    # Pass the same arguments as the other create_video calls; the old call omitted
    # raw_song and duration and raised TypeError.
    create_video(paths, raw_song, song_name, twice_upscaled_video_name, old_sr, duration)