In [1]:
import os
# Work from the cluster's $WORK directory so the relative paths used later in
# the notebook (e.g. the <DATASET_NAME>.mp4 / .gif files) resolve against it.
WORK = os.environ["WORK"]
# IPython magic: changes the kernel's working directory (not plain Python).
%cd $WORK

/lustre/work/chaselab/malyetama


In [44]:
from glob import glob
from PIL import Image as Img
import os
import multiprocessing
from joblib import Parallel, delayed
from tqdm import tqdm
from pathlib import Path
import imageio
from pygifsicle import optimize
from IPython.display import Markdown, display, Image
import dotenv
import base64
import requests
import json
import shutil


def create_fakes_gif(
    DATASET_NAME,
    subset=None,
    output_dir=None,
    ftype='gif',
    display_output=False,
    verbose=False,
    shift=None,
):
    """Build an animation from a training run's ``fakes*`` snapshot grids.

    The sorted files whose path contains ``fakes`` (one directory below the
    run's folder in ``$WORK/ADA_Project/training_runs/``) are sub-sampled with
    stride ``subset``, cropped in parallel into
    ``$WORK/ADA_Project/datasets/<DATASET_NAME>_history/`` and written as the
    frames of ``<output_dir>/<DATASET_NAME>.<ftype>``. The animation stops at
    the best snapshot (read from ``FID_of_best_snapshots.json``), whose frame
    is then repeated 20 more times.

    Args:
        DATASET_NAME: Dataset name as derived from the training-run folder
            names (the ``AFHQ`` runs are exposed as ``AFHQ-CAT``).
        subset: Stride over the sorted fakes grids; ``None`` keeps all.
        output_dir: Directory for the animation; defaults to the current
            working directory. Created if it does not exist.
        ftype: File extension handed to ``imageio`` (e.g. ``'gif'``,
            ``'mp4'``). GIFs are additionally optimized in place (gifsicle).
        display_output: If true and ``ftype == 'gif'``, upload the result to
            imgbb (API key from ``TOKEN`` in the project ``.env``) and render
            it inline.
        verbose: Print the resulting file size(s).
        shift: ``{'shift_r': int, 'shift_b': int}`` pixel offsets of the
            1024x1024 crop window. Ignored for ``metfaces``, which always crops
            the top-left 2048x2048. Defaults to no shift.

    Returns:
        None.

    Raises:
        KeyError: If ``DATASET_NAME`` is missing from the FID JSON or no
            training run with fakes grids maps to it.
    """
    if shift is None:
        # Fresh dict per call instead of a shared mutable default argument.
        shift = {'shift_r': 0, 'shift_b': 0}

    def process(i):
        """Crop the fakes grid at path ``i`` and save it as PNG in ``history``."""
        if DATASET_NAME == 'metfaces':
            box = (0, 0, (256 * 4) * 2, (256 * 4) * 2)
        else:
            box = (
                shift['shift_r'],
                shift['shift_b'],
                (256 * 4) + shift['shift_r'],
                (256 * 4) + shift['shift_b'],
            )
        # Context manager closes the source file handle inside the worker.
        with Img.open(i) as im:
            im.crop(box).save(f'{history}/{Path(i).stem}.png')

    def upload_img(image, token, timeout=60):
        """Upload ``image`` to imgbb with API key ``token``; return its URL.

        Raises ``RuntimeError`` for a missing token or a reply without
        ``data.url``, and ``requests.HTTPError`` for a non-2xx status.
        """
        if not token:
            raise RuntimeError('No imgbb API key: set TOKEN in the project .env')
        with open(image, "rb") as file:
            parameters = {
                "key": token,
                "image": base64.b64encode(file.read()),
            }
        # A timeout keeps a stalled connection from hanging the notebook.
        res = requests.post("https://api.imgbb.com/1/upload", parameters, timeout=timeout)
        res.raise_for_status()
        try:
            return res.json()['data']['url']
        except (KeyError, TypeError, ValueError) as e:
            raise RuntimeError(f'Unexpected imgbb response: {res.text[:500]}') from e

    def file_size(file):
        """Size of ``file`` in megabytes (10**6 bytes)."""
        return Path(file).stat().st_size / 1e+6

    WORK = os.environ["WORK"]
    PROJ_DIR = f'{WORK}/ADA_Project'

    dotenv.load_dotenv(f'{PROJ_DIR}/.env')
    token = os.getenv('TOKEN')

    # Start from an empty history folder so crops from an earlier call (e.g.
    # with a different `subset`) don't leak into this animation. Only a
    # missing folder is tolerated; real errors (permissions, ...) surface.
    history = f'{PROJ_DIR}/datasets/{DATASET_NAME}_history'
    if Path(history).exists():
        shutil.rmtree(history)

    # Training-run folders are named `<dataset>_training-runs`; derive the
    # dataset name of each (the plain `AFHQ` runs are reported as AFHQ-CAT).
    TRfolders = f'{PROJ_DIR}/training_runs/'
    TRfolders_ = glob(f'{PROJ_DIR}/training_runs/*')
    datasets = [
        x.replace(TRfolders, '').replace('_training-runs', '')
        for x in TRfolders_
    ]
    datasets = ['AFHQ-CAT' if x == 'AFHQ' else x for x in datasets]

    with open(f'{PROJ_DIR}/FID_of_best_snapshots.json') as jf:
        jd = json.load(jf)
    best = 'fakes' + jd[DATASET_NAME]['snapshot'].replace('network-snapshot-', '')

    # Fakes grids per dataset. Without recursive=True, `**` behaves like `*`,
    # so this only matches entries exactly one directory below the run folder.
    d = {}
    for folder, dataset in zip(TRfolders_, datasets):
        files = sorted(glob(folder + "/**/*"))
        fakes = [x for x in files if 'fakes' in x]
        if fakes:
            d[dataset] = {'files': fakes}

    if DATASET_NAME not in d:
        raise KeyError(f'No fakes grids found for {DATASET_NAME!r}; available: {sorted(d)}')
    fakes = d[DATASET_NAME]['files']

    # NOTE(review): drops the last grid when the count is odd; the reason is
    # not visible from here — kept as in the original.
    if len(fakes) % 2 != 0:
        fakes = fakes[:-1]

    # The best snapshot's grid sits next to the last grid. Add it exactly
    # once: a duplicate entry would make two workers write the same PNG.
    last = Path(fakes[-1])
    best_fake = f'{last.parent}/{best}{last.suffix}'
    selected = fakes[::subset]
    if best_fake not in selected:
        selected.append(best_fake)

    Path(history).mkdir(parents=True, exist_ok=True)

    # Leave one core free, but never ask joblib for 0 workers.
    n_jobs = max(1, multiprocessing.cpu_count() - 1)
    Parallel(n_jobs=n_jobs)(delayed(process)(i) for i in tqdm(selected))

    # Frame order: the `init` grid (if one was cropped) first, then the other
    # grids in file-name order (presumably zero-padded training steps).
    crops = sorted(glob(f'{history}/*.png'))
    init_crops = [x for x in crops if 'init' in Path(x).name]
    history_imgs = init_crops[-1:] + [x for x in crops if 'init' not in Path(x).name]

    # Honour `output_dir` (it used to be overwritten by a bare relative name).
    if output_dir is None:
        output_dir = Path.cwd()
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    anim_file = f'{output_dir}/{DATASET_NAME}.{ftype}'

    with imageio.get_writer(anim_file, mode='I') as writer:
        for filename in tqdm(history_imgs):
            writer.append_data(imageio.imread(filename))
            # Snapshots after the best one are not shown.
            if Path(filename).stem == best:
                break
        # Hold the best snapshot at the end of the animation.
        best_frame = imageio.imread(f'{history}/{best}.png')
        for _ in range(20):
            writer.append_data(best_frame)

    if verbose:
        if ftype == 'mp4':
            print(f'file size: {file_size(anim_file):.2f} MB')
        else:
            print(f'file size before optimization: {file_size(anim_file):.2f} MB')

    if ftype == 'gif':
        optimize(source=anim_file, destination=anim_file)
        if verbose:
            print(f'          after optimization: {file_size(anim_file):.2f} MB')

    if display_output:
        if ftype == 'gif':
            print('Loading...')
            img_url = upload_img(anim_file, token)
            print(f'{ftype} url ==> {img_url}')
            display(Markdown(f'![]({img_url})'))
        else:
            print(f'display_output is only supported for gif (got {ftype!r})')

### Available datasets:
 ['AFHQ-WILD', 'AFHQ-CAT', 'StyleGAN2_FFHQ_30K', 'metfaces', 'POKEMON', 'StanfordDogs',
  'StyleGAN2_AFHQ-DOG', 'cars196', 'AFHQ-DOG', 'FFHQ_5K', 'FFHQ_2K', 'ANIME-FACES',
  'conditional_CIFAR-10', 'StyleGAN2_FFHQ_5K', 'StyleGAN2_WILD-AFHQ',
  'unconditional_CIFAR-10', '102flowers', 'StyleGAN2_FFHQ', 'FFHQ_30K', 'FFHQ', 'StyleGAN2_FFHQ_2K']

In [45]:
DATASET_NAME = 'AFHQ-WILD'

# Animate every 10th fakes grid of this run as an MP4. The crop window is
# moved 256 * 4 px down with no horizontal shift — presumably selecting a
# different row of samples in the grid; TODO confirm.
# display_output only uploads/renders GIFs; the GIF is made by ffmpeg below.
create_fakes_gif(
    DATASET_NAME,
    subset=10,
    ftype='mp4',
    shift=dict(shift_r=256 * 0, shift_b=256 * 4),
    verbose=True,
    display_output=True,
)



  0%|          | 0/70 [00:00<?, ?it/s][A[A

100%|██████████| 70/70 [00:00<00:00, 209.17it/s]A[A


  0%|          | 0/63 [00:00<?, ?it/s][A[A

  5%|▍         | 3/63 [00:00<00:02, 20.44it/s][A[A

Subset size: 63 image




 10%|▉         | 6/63 [00:00<00:02, 21.03it/s][A[A

 14%|█▍        | 9/63 [00:00<00:02, 21.19it/s][A[A

 19%|█▉        | 12/63 [00:00<00:02, 21.40it/s][A[A

 24%|██▍       | 15/63 [00:00<00:02, 21.47it/s][A[A

 29%|██▊       | 18/63 [00:00<00:02, 21.49it/s][A[A

 33%|███▎      | 21/63 [00:00<00:01, 21.52it/s][A[A

 38%|███▊      | 24/63 [00:01<00:01, 21.34it/s][A[A

 43%|████▎     | 27/63 [00:01<00:01, 21.30it/s][A[A

 48%|████▊     | 30/63 [00:01<00:01, 21.35it/s][A[A

 52%|█████▏    | 33/63 [00:01<00:01, 21.35it/s][A[A

 57%|█████▋    | 36/63 [00:01<00:01, 21.32it/s][A[A

 62%|██████▏   | 39/63 [00:01<00:01, 21.34it/s][A[A

 67%|██████▋   | 42/63 [00:01<00:00, 21.42it/s][A[A

 71%|███████▏  | 45/63 [00:02<00:00, 21.34it/s][A[A

 76%|███████▌  | 48/63 [00:02<00:00, 21.32it/s][A[A

 81%|████████  | 51/63 [00:02<00:00, 21.09it/s][A[A

 86%|████████▌ | 54/63 [00:02<00:00, 20.87it/s][A[A

 90%|█████████ | 57/63 [00:02<00:00, 20.69it/s][A[A

 95%|█████

file size: 10.30 MB


In [49]:
import subprocess

# Convert the MP4 from create_fakes_gif() into a looping GIF: resample to
# 10 fps, scale to 512 px wide (Lanczos, aspect kept) and build a dedicated
# palette (palettegen -> paletteuse) for better GIF colours.
_in = DATASET_NAME + '.mp4'
_out = DATASET_NAME + '.gif'
FFMPEG = '/work/chaselab/malyetama/.conda/envs/ada-env/bin/ffmpeg'  # NOTE(review): machine-specific absolute path
ffmpeg_cmd = [
    FFMPEG, '-y', '-i', _in,
    '-vf', 'fps=10,scale=512:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse',
    '-loop', '0', _out,
]
# Run without a shell (no quoting issues with file names) and fail loudly:
# `!ffmpeg ...` ignored the exit status, so a failed conversion went unnoticed
# and a stale or missing GIF was handed to the upload cell.
ffmpeg_proc = subprocess.run(
    ffmpeg_cmd,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    universal_newlines=True,
)
print(ffmpeg_proc.stdout)
ffmpeg_proc.check_returncode()

ffmpeg version 4.3.1 Copyright (c) 2000-2020 the FFmpeg developers
  built with gcc 7.5.0 (crosstool-NG 1.24.0.131_87df0e6_dirty)
  configuration: --prefix=/work/chaselab/malyetama/.conda/envs/ada-env --cc=/home/conda/feedstock_root/build_artifacts/ffmpeg_1596712246804/_build_env/bin/x86_64-conda-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-gpl --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-libx264 --enable-pic --enable-pthreads --enable-shared --enable-static --enable-version3 --enable-zlib --enable-libmp3lame
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  libavdevice    58. 10.100 / 58. 10.100
  libavfilter     7. 85.100 /  7. 85.100
  libavresample   4.  0.  0 /  4.  0.  0
  libswscale      5.  7.100 /  5.  7.100
  libswresample   3.  7.100 /  3.  7.100
  libpostproc    55.  7.100 / 55.  7.100
Input #0, mov,mp4,m4a,3gp,3g2,mj2, from

In [50]:
def upload_img(image, token, timeout=60):
    """Upload the file at path ``image`` to imgbb and return its public URL.

    Args:
        image: Path of the image file to upload.
        token: imgbb API key.
        timeout: Seconds passed to ``requests.post`` so a stalled connection
            cannot hang the notebook indefinitely.

    Returns:
        The ``data.url`` field of imgbb's JSON reply.

    Raises:
        RuntimeError: If ``token`` is empty/None, or the reply has no
            ``data.url``.
        requests.HTTPError: If imgbb answers with a non-2xx status.
    """
    if not token:
        raise RuntimeError('No imgbb API key: set TOKEN in the project .env')
    # Read (and close) the file before the potentially slow network call.
    with open(image, "rb") as file:
        parameters = {
            "key": token,
            "image": base64.b64encode(file.read()),
        }
    res = requests.post("https://api.imgbb.com/1/upload", parameters, timeout=timeout)
    res.raise_for_status()
    try:
        return res.json()['data']['url']
    except (KeyError, TypeError, ValueError) as e:
        raise RuntimeError(f'Unexpected imgbb response: {res.text[:500]}') from e

In [51]:
# Load the imgbb key explicitly: otherwise TOKEN is only in os.environ if
# create_fakes_gif() (which calls load_dotenv) already ran in this kernel.
# load_dotenv does not override variables that are already set.
dotenv.load_dotenv(f'{os.environ["WORK"]}/ADA_Project/.env')
link = upload_img(_out, os.getenv('TOKEN'))
print(link)

https://i.ibb.co/hHSVh77/fa5d8d47e524.gif


In [5]:
# display(Image(_out))