# VQGAN Download and Caching Framework

This is a framework for downloading and caching VQGAN models. Many Jupyter Notebook platforms, such as Google Colaboratory, have fast local ephemeral storage along with slower persistent storage. This framework works by downloading the files to local storage and then copying them to persistent storage. Once the cache is primed, subsequent executions of the function copy the files from the cache rather than downloading them again.

## Business Continuity

Sometimes models that were previously available go missing or become temporarily unavailable. Because this framework saves a copy, it can be a useful business continuity mechanism that reduces the risk of interruptions to research.

## Cache Invalidation

This framework **does not implement** any form of cache invalidation. If the original configs and checkpoints are updated, the cache will continue to serve the old models.

In [None]:
from os.path import isdir, abspath, join
from os import mkdir

# Locations that must always be on a local disk.
working_dir = abspath('./content')
model_dir = join(working_dir, 'models')

# Cache locations; in general these should be on a persistent disk.
cache_dir = join(working_dir, 'cache')
model_cache_dir = join(cache_dir, 'models')

# Create whichever directories are missing.  Parents are listed before their
# children because mkdir only creates a single directory level.
for _directory in (working_dir, model_dir, cache_dir, model_cache_dir):
    if not isdir(_directory):
        mkdir(_directory)

In [None]:
from urllib.request import urlretrieve
from os.path import isdir, isfile, join
from shutil import copy

# Fail fast when any directory the download helper relies on is missing,
# which normally means the directory-setup cell has not been run yet.
if not all(isdir(path) for path in (working_dir, model_dir, cache_dir, model_cache_dir)):
    raise RuntimeError('Directory not found, has the cell above been executed?')

def _write_atomically(produce, source, destination):
    """Create ``destination`` from ``source`` without exposing a partial file.

    ``produce(source, path)`` is called with a temporary sibling path
    (``destination`` plus a ``.part`` suffix).  Only after it returns is the
    temporary file moved onto ``destination``.  If ``produce`` raises, or the
    run is interrupted, the temporary file is removed and ``destination`` is
    left as it was.

    Args:
        produce: Callable taking ``(source, path)`` that writes ``path``;
            ``urlretrieve`` and ``shutil.copy`` both fit.
        source: URL or file path handed to ``produce`` unchanged.
        destination: Final path of the file to create.
    """
    # Imported here so this cell does not rely on names imported by other cells.
    from os import remove, replace

    partial = f'{destination}.part'
    try:
        produce(source, partial)
        # The partial file sits next to the destination, so this is a rename
        # within one directory rather than a second copy of the data.
        replace(partial, destination)
    finally:
        # After a successful replace the partial file no longer exists, so
        # this only cleans up after a failure or an interrupt.
        if isfile(partial):
            remove(partial)


def download_vqgan_model(name, config_url, checkpoint_url):
    """Make a VQGAN model available in ``model_dir``, using the cache if possible.

    The files are looked for in this order:

    1. ``model_dir`` - if both files are already there nothing is done.
    2. ``model_cache_dir/<name>`` - if both files are cached they are copied
       into ``model_dir``.
    3. Otherwise both files are downloaded into ``model_dir`` and then copied
       into the cache.

    Every file is written under a temporary name and moved into place only
    once it is complete, so an interrupted download or copy can never be
    mistaken for a usable model on a later call.

    Args:
        name: Model name; used as the base file name (``<name>.yaml`` and
            ``<name>.ckpt``) and as the name of the model's cache directory.
        config_url: URL of the model's YAML config.
        checkpoint_url: URL of the model's checkpoint.

    Returns:
        ``name``, unchanged.

    Raises:
        OSError: If a download or copy fails (``urllib.error.URLError`` is a
            subclass of ``OSError``).
    """
    # Imported here so this cell does not rely on names imported by other cells.
    from os import makedirs

    config_file = join(model_dir, f'{name}.yaml')
    checkpoint_file = join(model_dir, f'{name}.ckpt')

    model_cache = join(model_cache_dir, name)
    makedirs(model_cache, exist_ok=True)
    config_cache = join(model_cache, f'{name}.yaml')
    checkpoint_cache = join(model_cache, f'{name}.ckpt')

    # A model only counts as present when BOTH of its files exist.
    is_available = isfile(config_file) and isfile(checkpoint_file)
    is_cached = isfile(config_cache) and isfile(checkpoint_cache)

    if is_available:
        print(f'The model ({name}) is already available locally.')
        return name

    if is_cached:
        print(f'The model ({name}) is available in the cache, copying locally.')
        _write_atomically(copy, config_cache, config_file)
        _write_atomically(copy, checkpoint_cache, checkpoint_file)
        return name

    print(f'The model ({name}) was not found, downloading and caching.')
    _write_atomically(urlretrieve, config_url, config_file)
    _write_atomically(urlretrieve, checkpoint_url, checkpoint_file)
    _write_atomically(copy, config_file, config_cache)
    _write_atomically(copy, checkpoint_file, checkpoint_cache)
    return name

# Example: fetch the ImageNet f16/1024 VQGAN (or reuse a local/cached copy)
# and show the name returned by the helper.
model = download_vqgan_model(
    'vqgan_imagenet_f16_1024',
    'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1',
    'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1',
)

print(model)