# Art BMSG-GAN

Notebook authored by Jeremy Webb

BMSG-GAN package by [Animesh Karnewar](https://akanimax.github.io/) can be found in this repository: https://github.com/akanimax/BMSG-GAN.

This notebook gathers a dataset of images of art in a specific genre and trains a GAN to produce similar art.

Below are some examples of landscapes generated by a GAN trained with this notebook.

<div align=center>
  <img src="https://gitlab.com/iota-lab/ai-marketplace/raw/master/GAN_landscapes.png" alt="Generated Landscape Examples" width="50%"/>
</div>

---

**To train a GAN, read and run each cell below in order by clicking on the play button on the left side.**

## License

MIT License

Copyright (c) 2019 Jeremy Webb

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

## Setup

Check that you are using a GPU for training by choosing *Runtime->Change runtime type->Hardware accelerator = GPU*.

If something goes wrong while running these instructions, you can start with a clean slate by doing *Runtime->Reset all runtimes...* This will delete everything and give you a fresh server instance.

Define some utility functions and imports.

In [0]:
from IPython.display import display
from ipywidgets import IntProgress
import os
import re

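class ProgressBar():
  """Progress bar widget that redraws at most about 100 times."""
  def __init__(self, maxValue = 100, startValue = 0):
    self.pBar = IntProgress(min=0, max=maxValue, value=startValue)
    self.maxValue = maxValue
    # for long loops, only redraw roughly every 1% of progress
    self.updateInterval = 1
    if self.maxValue > 150:
      self.updateInterval = int(self.maxValue / 100)
    display(self.pBar) # display the bar

  def update(self, value):
    # always draw the final value, even if it isn't a multiple of the interval
    if value % self.updateInterval == 0 or value == self.maxValue:
      self.pBar.value = value
    if value == self.maxValue:
      self.pBar.bar_style = 'success'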

def count_items(directory):
  dirs = os.listdir(directory)
  return len(dirs)

def getEpochReached(search_dir):
  """Return the highest epoch number found in the model filenames."""
  model_files = os.listdir(search_dir)
  max_epoch = 0
  for f in model_files:
    if os.path.isfile(os.path.join(search_dir, f)):
      # the epoch is the only number in a model filename,
      # e.g. "GAN_GEN_1.pth" or "GAN_DIS_OPTIM_1.pth"
      num = re.search(r'\d+', f)
      if num:
        num = int(num.group())
        if num > max_epoch:
          max_epoch = num
  return max_epoch

### Clone the BMSG-GAN Repository
Clone the BMSG-GAN repository with git.

In [0]:
!git clone https://github.com/akanimax/BMSG-GAN
# The current master branch of the repository uses defaults set up for
# AWS SageMaker. Those paths don't exist on Colab and cause errors when the
# arguments are parsed during setup, so check out an earlier commit that
# doesn't have these defaults.
!cd BMSG-GAN/ && git checkout 1d3a719910504714438f71e92714567e491f488e

### Mounting Google Drive

Using Google Drive to save the trained models automatically is recommended, but not required. To use Google Drive, you need at least 4 GB of free space for the model files and, optionally, the samples generated during training.

The advantage of Google Drive is that the model files are saved there automatically, so if your notebook is disconnected during training you will not lose much progress. Additionally, if you save the samples generated during training to Google Drive, you can watch how the GAN is progressing.

If you do not use Google Drive, resuming training is much more difficult: you need to download the latest model files before your notebook server is restarted (the server restarts automatically after 12 hours) and upload them again when you want to resume. Without Drive, the largest image size that can be trained to decent results within a single session is 64x64 pixels.

If you decide not to use Google Drive, uncheck the box below (that is, set `use_drive = False`).

In [0]:
#@title { form-width: "40%" }

#@markdown If checked, training data will be stored in Google Drive.
use_drive = True #@param {type: "boolean"}

from google.colab import drive

if use_drive:
  drive.mount('/content/drive')
  if not os.path.exists("/content/drive/My Drive/Colab Notebooks/"):
    print("Note: Create 'Colab Notebooks' directory in drive root to proceed.")
else:
  print(
        "Warning: Your models will not be saved automatically. "
        "Make sure to download the models periodically to avoid "
        "loss of progress."
       )

### Choose Genre

If you don't have your own dataset, choose a genre from [this list on Wikiart](https://www.wikiart.org/en/paintings-by-genre), entered as it appears in that page's URLs (for example `landscape`). A dataset of public-domain artwork in that genre will be scraped from Wikiart. Prefer a genre with many works (ideally thousands) so the GAN has enough training data.

In [0]:
#@title { run: "auto", display-mode: "form" }

genre = 'landscape' #@param {type: "string"}

### Define Default Directories

Define the default directories for saving the training data and results.

If you are using Google Drive and want to store the samples generated during training in Google Drive, set `samples_in_drive` to `True`. This will require additional storage space in Google Drive, but will enable you to see how the GAN is progressing as it trains.

In [0]:
#@title Store Samples in Google Drive? { run: "auto", form-width: "40%" }

#@markdown If checked, samples generated from training will be stored in Drive
samples_in_drive = True #@param {type: "boolean"}

import os

if use_drive:
  main_dir = "/content/drive/My Drive/Colab Notebooks/"
else:
  main_dir = "/content/BMSG-GAN/"

data_dir = main_dir + "Data/"
bmsggan_dir = data_dir + "BMSG-GAN/"

base_dir = "/content/BMSG-GAN/data/"
if not os.path.exists(base_dir):
  print("Clone git repository above first!")
else:
  # original images directory
  train_dir = base_dir + genre + "/original/"
  # processed images directory
  processed_dir = base_dir + genre + "/processed/"
  # saved models directory
  models_dir = bmsggan_dir + genre + "/Models/"
  # samples directory
  if samples_in_drive and use_drive:
    samples_dir = bmsggan_dir + genre + "/samples/"
  else:
    samples_dir = base_dir + genre + "/samples/"
  # logs directory
  logs_dir = base_dir + genre + "/logs/"

  # create any directories that don't exist yet
  for d in [data_dir, train_dir, processed_dir, models_dir, samples_dir, logs_dir]:
    os.makedirs(d, exist_ok=True)

  print("Directories created successfully")

### Set Target Image Size

Images produced by the GAN are square. The `target_size` parameter defines their width and height in pixels. It must be a power of 2 between 4 and 512 (4, 8, 16, 32, 64, 128, 256, or 512); 512 was the largest size at which BMSG-GAN could still be trained on Google Colab at the time of writing.

Choose `target_size` carefully. Bigger pictures are attractive, but they take much longer to train. For example, with a `target_size` of 512, training on a decently sized dataset (thousands of images) can take a few weeks on Google Colab, whereas with a `target_size` of 128, recognizable images can be obtained within a couple of days. If you are not using Google Drive to store intermediate results, keep `target_size` at 64 or below so that training can finish within a single Colab session (at most 12 hours).
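The power-of-2 requirement comes from how the generator grows: it produces one image per level, starting at 4x4 and doubling until it reaches `target_size`. The sketch below lists those resolutions using the same depth formula as the training-parameters cell (`depth = log2(target_size) - 1`) and the same `2**d` sizes the sample-cleanup cell iterates over; `resolutions_for` is a hypothetical helper, not part of BMSG-GAN.

```python
import math

def resolutions_for(target_size):
    """Return the square resolutions generated for a given target size.

    Hypothetical helper mirroring the notebook's formulas:
    depth = log2(target_size) - 1, sizes 2**d for d in 2..depth+1.
    """
    # check the lower bound first so log2 is never called on 0 or negatives
    if target_size < 4 or math.log2(target_size) % 1 != 0:
        raise ValueError("target_size must be a power of 2 and at least 4")
    depth = int(math.log2(target_size)) - 1
    # one output per depth level, starting at 4x4 and doubling each time
    return [2 ** d for d in range(2, depth + 2)]

print(resolutions_for(64))        # -> [4, 8, 16, 32, 64]
print(len(resolutions_for(512)))  # -> 8
```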

In [0]:
#@title  { run: "auto", vertical-output: true }

import math

# size in pixels of the generated images
target_size = "512" #@param [4, 8, 16, 32, 64, 128, 256, 512]
target_size = int(target_size)

if math.log2(target_size) % 1 != 0:
  raise ValueError('target_size must be a power of 2')

print(f'target_size set to {target_size}')

## Setup the Training Data

### Unzip the data if it exists

If a zip file whose name starts with the genre and ends in `.zip` (for example `landscape.zip`) exists in `data_dir`, this cell extracts it into `train_dir`. If you have your own dataset, upload it to `data_dir` as such a zip file before running this cell. Otherwise, the next section scrapes Wikiart for a dataset.
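Because the preprocessing step later skips anything in `train_dir` that is not a regular file, your own images should sit at the top level of the archive rather than in subfolders. A minimal sketch for building such an archive from a local folder follows; `make_dataset_zip` and the folder names are hypothetical, and the accepted extensions are an assumption you can adjust.

```python
import os
import zipfile

# assumed set of image extensions to include
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}

def make_dataset_zip(image_folder, zip_path):
    """Zip every image under image_folder (recursively) into a flat archive.

    Hypothetical helper. Files are stored without their folder structure so
    they extract directly into train_dir. Returns the number of files added.
    """
    added = 0
    seen = set()
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, names in os.walk(image_folder):
            for name in sorted(names):
                if os.path.splitext(name)[1].lower() not in IMAGE_EXTENSIONS:
                    continue
                # flattening can create name clashes; keep the first file only
                if name in seen:
                    continue
                seen.add(name)
                zf.write(os.path.join(root, name), arcname=name)
                added += 1
    return added
```

For example, `make_dataset_zip("my_landscapes/", "landscape.zip")` run on your own machine produces an archive you can upload to `data_dir`.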

In [0]:
found_training_data = False
files = os.listdir(data_dir)
for f in files:
  filename = os.path.join(data_dir, f)
  if os.path.isfile(filename) and f.startswith(genre) and f.endswith('.zip'):
    found_training_data = True
    !unzip -o "{filename}" -d "{train_dir}"

### Scrape Wikiart for Artwork in Genre

This cell scrapes Wikiart for artwork in the specified genre, but only if no zip file was found in the previous step. If `use_drive` is `True`, a zip file of the images is also saved to Drive, so next time the images can simply be extracted (see the previous cell) instead of downloaded again.
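Two string-processing steps in the scraper are easy to check in isolation: it cuts each item's `image` URL at the first `!` (presumably dropping a thumbnail or size suffix), and it turns each painting title into a safe filename. The sketch below reimplements those two steps as standalone functions; the function names and the example URL are hypothetical.

```python
import re

MAX_NAME_LEN = 180  # same limit as PaintingInfo.getValidFilename

def full_size_url(image_url):
    """Drop everything from the first '!' onward, as get_painting_list does."""
    loc = image_url.find("!")
    return image_url[:loc] if loc >= 0 else image_url

def valid_filename(title):
    """Lower-case, replace spaces with '_', truncate, then sanitize."""
    s = title.strip().replace(" ", "_").lower()[:MAX_NAME_LEN]
    # any run of characters other than letters, digits, '.', '-' or '_' becomes '.'
    return re.sub(r"[^A-Za-z0-9.\-_]+", ".", s)

print(full_size_url("https://example.org/images/view.jpg!Large.jpg"))
# -> https://example.org/images/view.jpg
print(valid_filename("  Mont Sainte-Victoire (1904)  "))
# -> mont_sainte-victoire_.1904.
```

Note that two paintings with the same title map to the same filename, so the scraper downloads only one of them.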

In [0]:
if not found_training_data:
  # scrape Wikiart for the dataset since it wasn't found in previous step
  import os
  import urllib.request
  import json
  import requests
  import re
  import itertools
  import multiprocessing
  from multiprocessing.dummy import Pool
  import zipfile

  # total pages to scrape;
  # can be larger than the number of available pages (if you don't know how many there are)
  pages = 100

  # `genre` can be any one of the values found on the
  # following page: https://www.wikiart.org/en/paintings-by-genre
  # choose one that has enough paintings to train with;
  # a good amount is in the 1000s

  class PaintingInfo:
    def __init__(self, url, title):
      self.url = url
      self.title = title

    def getValidFilename(self):
      maxNameLen = 180
      s = self.title.strip().replace(' ', '_').lower()
      if len(s) > maxNameLen:
          s = s[:maxNameLen]
      s = re.sub(r'[^A-Za-z0-9\.\-_]+', '.', s)
      return s

  # get list of all links to paintings of the specified genre
  def get_painting_list(count, genre):
    url = "https://www.wikiart.org/en/paintings-by-genre/" \
        + genre+ "/" + str(count) + "?json=1"
    response = requests.get(url, allow_redirects=False)
    response.raise_for_status()
    if(response.status_code >= 300 or len(response.json()) == 0):
        # when the page number is invalid, a redirect (status 3xx) is returned
        # and there is nothing more to process
        return None, False
    data = response.json()
    url_list = []
    print("Processing {} elements".format(len(data)))
    for item in data:
        name = item['image']
        exclamationLoc = name.find('!')
        if exclamationLoc >= 0:
            name = name[:exclamationLoc]
        url_list.append(PaintingInfo(name, item['title']))
    return url_list, True

  def downloader(paintingInfo, genre, save_directory):
    item, painting = paintingInfo	
    name = save_directory + painting.getValidFilename() + ".jpg"

    if item != 0 and item % 100 == 0:
        print("Downloaded {} images".format(item))

    # skip if the file is already downloaded
    if not os.path.exists(name):
        try:
            urllib.request.urlretrieve(painting.url, name)
        except Exception as e:
            print(e)
            try:
                s = "Failed to download: " + painting.url
                print(s)
            except UnicodeEncodeError:
                # avoid crashing because characters are unicode encoded
                print("Failed to download file")	


  def run_scraper(genre, save_directory):
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    results = []

    print("Compiling list of paintings to download...")
    for i in range(1, pages + 1):
        result, isValid = get_painting_list(i, genre)
        if isValid:
            results.extend(result)
        else:
            break
    print("Starting to download {} images...".format(len(results)))
    pool_of_threads = Pool(max(1, multiprocessing.cpu_count() - 1))
    pool_of_threads.starmap(
        downloader,
        zip(enumerate(results),
            itertools.repeat(genre),
            itertools.repeat(save_directory)
           )
    )
    pool_of_threads.close()

  print("Starting scraping...")
  run_scraper(genre, train_dir)
  
  # if using drive,
  # create a zip file of the images and store them for future use
  if use_drive:
    print("Saving data")
    zipfile_name = data_dir + genre + ".zip"
    files_to_store = os.listdir(train_dir)
    with zipfile.ZipFile(zipfile_name, "w", zipfile.ZIP_DEFLATED) as zipdir:
      for f in files_to_store:
        zipdir.write(train_dir + f, f)

  print("All done!")

### Preprocess the Training Images
Convert the images to RGB, which removes any alpha channel (it interferes with training), and resize them to the target size. Every image is resized to `target_size` x `target_size` regardless of its aspect ratio, so some images become distorted; in practice this still gives pretty good results. If you are so inclined, you could write an algorithm that does intelligent cropping, centering, and resizing instead.
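One simple alternative to stretching is to center-crop each image to a square before resizing. This avoids distortion at the cost of losing the edges of wide or tall paintings. It is a swapped-in technique, not what the cell below does; `center_crop_box` and `crop_and_resize` are hypothetical helpers, and the second one expects a PIL image.

```python
def center_crop_box(width, height):
    """Return the (left, upper, right, lower) box of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    upper = (height - side) // 2
    return (left, upper, left + side, upper + side)

def crop_and_resize(im, target_size):
    """Center-crop a PIL image to a square, then resize it to target_size."""
    box = center_crop_box(*im.size)  # PIL's im.size is (width, height)
    return im.convert("RGB").crop(box).resize((target_size, target_size))
```

To try it, replace the `RGB.resize(...)` call in the loop below with `crop_and_resize(im, target_size)`.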

In [0]:
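from PIL import Image
import os

path = train_dir
out_path = processed_dir

dirs = os.listdir(path)
print(f"Found {len(dirs)} images.")
progress = ProgressBar(maxValue = len(dirs) - 1)

for i, item in enumerate(dirs):
    if os.path.isfile(path + item):
        try:
            with Image.open(path + item) as im:
                # convert to RGB (drops any alpha channel), then resize to a square
                processed_im = im.convert('RGB').resize((target_size, target_size))
                # saved as JPEG data under the original filename
                processed_im.save(out_path + item, 'JPEG', quality=90)
        except OSError as e:
            # skip files Pillow cannot read, e.g. corrupt or truncated downloads
            print(f"Skipping {item}: {e}")
    progress.update(i)

print("Processed all images.")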

Delete any images that can't be opened.
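The cell below uses `verify_images` from fastai's v1 API (`fastai.vision`). Newer fastai releases reorganized the library, so that import may fail in your environment. A rough Pillow-only substitute is sketched here; `remove_unreadable_images` is a hypothetical helper, and unlike fastai's version it only checks that each file decodes, not its channel count.

```python
import os
from PIL import Image

def remove_unreadable_images(directory):
    """Delete files in `directory` that Pillow cannot fully decode.

    Returns the sorted list of removed filenames.
    """
    removed = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if not os.path.isfile(path):
            continue
        try:
            with Image.open(path) as im:
                im.load()  # force a full decode, not just a header read
        except (OSError, SyntaxError, ValueError):
            os.remove(path)
            removed.append(name)
    return removed
```

Usage: `remove_unreadable_images(processed_dir)`.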

In [0]:
#@markdown Verify that training images are all valid
verify = True #@param {type: "boolean"}
if verify:
  from fastai.vision import verify_images
  verify_images(processed_dir, delete=True)

### Check Training Images

Show some images and check that they look good.

In [0]:
# show some of the images
from matplotlib import pyplot
import matplotlib.image as mpimg
import os

# plot an n x n grid of the first n*n images in img_directory
def plot_images(img_directory, n):
  imgs = os.listdir(img_directory)
  if len(imgs) < n * n:
    print("Not enough images to display")
    return
  for i in range(n * n):
    img = mpimg.imread(img_directory + imgs[i])
    # define subplot
    pyplot.subplot(n, n, 1 + i)
    # turn off axis
    pyplot.axis('off')
    # plot raw pixel data
    pyplot.imshow(img)
  pyplot.show()

# display a 5 x 5 grid of the processed images
print(f"Directory contains {count_items(processed_dir)} items.")
plot_images(processed_dir, 5)

## Configure Training Parameters

Define the training parameters.

The `batch_size` parameter sets how many images are used for each update of the GAN. A larger batch trains faster on a GPU, but too large a `batch_size` causes an out-of-memory error.

The `feedback_factor` sets how many times per epoch the training script writes to the logs and generates samples.

`num_epochs` specifies how many passes over the training data to make before stopping. This parameter is not that important: as long as the model files are saved, you can resume training later and train for more epochs. To get decent results, though, you usually need more than 150 epochs, and this depends heavily on your dataset.

If you are unsure what to choose, just use the defaults below.
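The defaults in the cell below come from three small formulas: `target_depth = log2(target_size) - 1`, a batch size of `2**(11 - target_depth)`, and a feedback factor capped at 10 and at the number of full batches per epoch. The sketch reproduces them so you can compare sizes before committing; `default_training_params` is a hypothetical name and the 5000-image dataset is only an example.

```python
import math

def default_training_params(target_size, total_images):
    """Reproduce the notebook's default depth, batch size and feedback factor."""
    target_depth = int(math.log2(target_size) - 1)
    batch_size = 2 ** (11 - target_depth)
    total_batches = total_images // batch_size
    feedback_factor = min(total_batches, 10)
    return target_depth, batch_size, feedback_factor

for size in [64, 128, 256, 512]:
    depth, batch, feedback = default_training_params(size, total_images=5000)
    print(f"{size}x{size}: depth={depth}, batch_size={batch}, feedback={feedback}")
```

Equivalently, the default `batch_size` is `4096 / target_size`, so each doubling of the image side halves the batch. With 5000 images, every size above gets a feedback factor of 10.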

In [0]:
import math

target_depth = int(math.log2(target_size) - 1)
print(f"Target depth = {target_depth}")

# how many images to process at once
batch_size = 2**(11 - target_depth)
print(f"Batch size = {batch_size}")

# calculate feedback factor
total_images = count_items(processed_dir)
total_batches = total_images // batch_size
feedback_factor = min(total_batches, 10)
print(f"Feedback factor = {feedback_factor}")

# number of iterations used for training
# increase this number to train for longer
num_epochs = 180

Check if there are any previous models in the models_dir. If there are, assume training should be resumed instead of restarted.

If you would like to start training over, make sure the models folder does not contain previous models.
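Resuming passes five checkpoint files from the same epoch to the training script (generator, discriminator, shadow generator, and the two optimizer states). If training was interrupted while checkpoints were being written, the latest epoch may be incomplete. A quick pre-flight check, assuming the file naming used in the cell below, might look like this; `missing_checkpoint_files` is a hypothetical helper.

```python
import os

# checkpoint filename prefixes, as used when resuming in the cell below
CHECKPOINT_PREFIXES = ["GAN_GEN_", "GAN_DIS_", "GAN_GEN_SHADOW_",
                       "GAN_GEN_OPTIM_", "GAN_DIS_OPTIM_"]

def missing_checkpoint_files(models_dir, epoch):
    """Return the checkpoint files for `epoch` that are not in models_dir."""
    names = [f"{prefix}{epoch}.pth" for prefix in CHECKPOINT_PREFIXES]
    return [n for n in names if not os.path.isfile(os.path.join(models_dir, n))]
```

For example, `missing_checkpoint_files(models_dir, getEpochReached(models_dir))` should return an empty list before you resume; if it doesn't, delete the incomplete epoch's files so the previous epoch is picked up instead.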

In [0]:
# check if there are model files in the models_dir
model_files = os.listdir(models_dir)
epoch_reached = 0

if len(model_files) > 0:
  # we want to resume training
  # first find what epoch we are up to
  epoch_reached = getEpochReached(models_dir)
  gen_file = os.path.join(models_dir, "GAN_GEN_" + str(epoch_reached) + ".pth")
  dis_file = os.path.join(models_dir, "GAN_DIS_" + str(epoch_reached) + ".pth")
  gen_shadow_file = os.path.join(
      models_dir, "GAN_GEN_SHADOW_" + str(epoch_reached) + ".pth"
  )
  gen_optim_file = os.path.join(
      models_dir, "GAN_GEN_OPTIM_" + str(epoch_reached) + ".pth"
  )
  dis_optim_file = os.path.join(
      models_dir, "GAN_DIS_OPTIM_" + str(epoch_reached) + ".pth"
  )
  print(f"Found model files. Starting from epoch: {epoch_reached + 1}.")
else:
  print("Model files not found. Starting from epoch 0.")

## Train the GAN

To see all the options available for training the GAN, set `show_help` to `True` and run the cell below.

In [0]:
#@title  { run: "auto" }
show_help = False #@param {type: "boolean"}
if show_help:
  !cd /content/BMSG-GAN/sourcecode/ && python train.py --help

### Start Training!
The cell below runs the actual training of the GAN. It runs until it finishes, is stopped, or times out. If you lose your connection to this page, don't worry: the notebook keeps running on the Google Colab server for about 40 minutes without a connection, so you can reconnect within that time without losing any progress.

Note that your Google Colab backend is reset after about 12 hours, so a single training session lasts at most 12 hours; use the saved model files to resume training afterwards.

In [0]:
%cd /content/BMSG-GAN/sourcecode/
# if epoch_reached is greater than 0,
# then we are resuming a training session
if epoch_reached > 0:
  !python train.py \
               --depth="{target_depth}" \
               --num_epochs="{num_epochs}" \
               --flip_augment=True \
               --images_dir="{processed_dir}" \
               --sample_dir="{samples_dir}" \
               --model_dir="{models_dir}" \
               --batch_size="{batch_size}" \
               --feedback_factor="{feedback_factor}" \
               --start="{epoch_reached + 1}" \
               --generator_file="{gen_file}" \
               --discriminator_file="{dis_file}" \
               --shadow_generator_file="{gen_shadow_file}" \
               --generator_optim_file="{gen_optim_file}" \
               --discriminator_optim_file="{dis_optim_file}"
else:
  !python train.py \
               --depth="{target_depth}" \
               --num_epochs="{num_epochs}" \
               --flip_augment=True \
               --images_dir="{processed_dir}" \
               --sample_dir="{samples_dir}" \
               --model_dir="{models_dir}" \
               --batch_size="{batch_size}" \
               --feedback_factor="{feedback_factor}"

## Cleanup Samples and Models

### Cleanup Image Samples

You only need to run this if you are storing samples in Google Drive. It deletes all generated image samples except the last one from each epoch. Make sure to "Empty Trash" in Drive afterwards to recover the space.

Keeping the last sample from each epoch is useful if you would like to make a GIF of the GAN's progress over time (see the last section).
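The cleanup rule is: for each epoch, keep the sample with the highest iteration number. A side-effect-free version of that rule is handy for previewing what would be kept before anything is deleted. The sketch assumes the `gen_<epoch>_<iteration>.png` naming described in the code below and matches it more strictly than that code does; `samples_to_keep` is a hypothetical helper.

```python
import re

SAMPLE_NAME = re.compile(r"^gen_(\d+)_(\d+)\.png$")

def samples_to_keep(filenames):
    """Map each epoch to the filename with the largest iteration number.

    Filenames that don't match gen_<epoch>_<iteration>.png are ignored.
    """
    best = {}  # epoch -> (iteration, filename)
    for name in filenames:
        match = SAMPLE_NAME.match(name)
        if not match:
            continue
        epoch, iteration = int(match.group(1)), int(match.group(2))
        if epoch not in best or iteration > best[epoch][0]:
            best[epoch] = (iteration, name)
    return {epoch: name for epoch, (_, name) in best.items()}
```

Usage: `samples_to_keep(os.listdir(os.path.join(samples_dir, "64_x_64")))` lists, per epoch, the file the cleanup would keep for the 64x64 samples.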

In [0]:
import os
import re

def cleanupSamples(total_depth, max_epochs):
  for depth in range(2, total_depth + 2):
    imsize = 2**depth
    current_dir = os.path.join(samples_dir, f"{imsize}_x_{imsize}")
    print(f"Processing {imsize}_x_{imsize} images")
    deleteExtraImages(current_dir, max_epochs)
    if len(os.listdir(current_dir)) > max_epochs:
      print(f"Something went wrong for {imsize}_x_{imsize} images")
      return
  
def deleteExtraImages(current_dir, max_epochs):
  # imgs_to_keep[epoch] holds the largest iteration number seen for that epoch
  imgs_to_keep = [0]*(max_epochs + 1)
  imgs = os.listdir(current_dir)
  progress = ProgressBar(maxValue = len(imgs) - 1)
  for i, f in enumerate(imgs):
    progress.update(i)
    # format of file name is "gen_1_100.png"
    # where the first number is the epoch, and the second is the iteration
    # we want to keep the file with the largest iteration per epoch
    nums = re.findall(r'\d+', f)
    if len(nums) != 2:
      # not exactly two numbers found, so the file is something else
      continue
    epoch_num, iter_num = int(nums[0]), int(nums[1])
    if epoch_num > max_epochs:
      continue
    if iter_num > imgs_to_keep[epoch_num]:
      # current file has a higher iteration number than the one in imgs_to_keep
      # keep the current one and delete the one in imgs_to_keep
      file_to_remove = f"gen_{epoch_num}_{imgs_to_keep[epoch_num]}.png"
      imgs_to_keep[epoch_num] = iter_num
    else:
      # the current iter num is smaller than the one to keep, delete the file
      file_to_remove = f
    full_remove_path = os.path.join(current_dir, file_to_remove)
    if os.path.exists(full_remove_path):
      # check if the path exists before removing because imgs_to_keep
      # is initialized to 0, which may not be a valid iteration number
      os.remove(full_remove_path)

if use_drive and samples_in_drive:
  max_epoch = getEpochReached(models_dir)
  # add 1 to max_epoch because the epoch that was interrupted
  # should also be processed
  cleanupSamples(target_depth, max_epoch + 1)

### Cleanup Models

This is only necessary if you are using Google Drive (`use_drive` is `True`). It deletes all non-generator model files except those saved in the last two epochs. If you would also like to delete old generator files, set `delete_generator` to `True`. Doing this is important to free up Drive storage; make sure to "Empty Trash" after running it.

In [0]:
delete_generator = False  #@param {type: "boolean"}

import os

# Delete all models up to max_epochs except generator files
def deleteModels(delete_dir, max_epochs):
  progress = ProgressBar(maxValue = max_epochs)
  for epoch_num in range(1, max_epochs + 1):
    # format of file name is "GAN_GEN_2.pth"
    # where the number is the epoch
    files_to_remove = [f"GAN_GEN_SHADOW_{epoch_num}.pth",
                       f"GAN_GEN_OPTIM_{epoch_num}.pth",
                       f"GAN_DIS_OPTIM_{epoch_num}.pth",
                       f"GAN_DIS_{epoch_num}.pth"]
    if delete_generator:
      files_to_remove.append(f"GAN_GEN_{epoch_num}.pth")
    for file_to_remove in files_to_remove:
      full_remove_path = os.path.join(delete_dir, file_to_remove)
      if os.path.exists(full_remove_path):
        # check if the path exists before removing
        os.remove(full_remove_path)
    progress.update(epoch_num)

if use_drive:
  max_epoch = getEpochReached(models_dir)
  # keep models for the last 2 epochs just in case;
  # skip if there are no older epochs to delete
  if max_epoch > 2:
    deleteModels(models_dir, max_epoch - 2)

## Generate Images

### Generate a Single Image

Generate an image from the most-trained generator and display it. If you would also like to save the generated image, set `save_image` to `True`.

In [0]:
#@title Generate an Image { form-width: "40%" }

#@markdown Set the checkbox below if you want to save the image.
save_image = True  #@param {type: "boolean"}
#@markdown If you aren't using Google Drive, the name of the generated directory
#@markdown doesn't really matter
generated_dir_name = "generated"  #@param {type: "string"}
generated_dir = samples_dir + f"{generated_dir_name}/"

%cd /content/BMSG-GAN/sourcecode/

from pathlib import Path
import uuid
import torch as th
from torch.nn.functional import interpolate
import numpy as np
from matplotlib import pyplot as plt
from MSG_GAN.GAN import Generator

# The functions below are licensed with the MIT license.
# They have been modified slightly.
# Original author: Animesh Karnewar
# Source: https://github.com/akanimax/BMSG-GAN

def adjust_dynamic_range(data, drange_in=(-1, 1), drange_out=(0, 1)):
    """
    adjust the dynamic colour range of the given input data
    :param data: input image data
    :param drange_in: original range of input
    :param drange_out: required range of output
    :return: img => colour range adjusted images
    """
    if drange_in != drange_out:
        scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (
                np.float32(drange_in[1]) - np.float32(drange_in[0]))
        bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
        data = data * scale + bias
    return th.clamp(data, min=0, max=1)

def progressive_upscaling(images):
    """
    upsamples all images to the highest size ones
    :param images: list of images with progressively growing resolutions
    :return: images => images upscaled to same size
    """
    with th.no_grad():
        for factor in range(1, len(images)):
            images[len(images) - 1 - factor] = interpolate(
                images[len(images) - 1 - factor],
                scale_factor=pow(2, factor)
            )

    return images

# latent_size must be equal to size generator was trained with
# depth must be equal to depth generator was trained with
def generate_image(generator_file, depth, latent_size=512):
# create the generator object
    gen = th.nn.DataParallel(Generator(
        depth=depth,
        latent_size=latent_size
    ))

    print("Loading the generator weights from:", generator_file)
    # load the weights into it
    device = th.device("cuda" if th.cuda.is_available() 
                   else "cpu")
    gen.load_state_dict(
        th.load(generator_file, map_location=str(device))
    )

    print("Generating image...")
    # generate the images:
    with th.no_grad():
        point = th.randn(1, latent_size)
        point = (point / point.norm()) * (latent_size ** 0.5)
        ss_images = gen(point)

    # resize the images:
    ss_images = [adjust_dynamic_range(ss_image) for ss_image in ss_images]
    ss_images = progressive_upscaling(ss_images)
    ss_image = ss_images[depth - 1]

    img = ss_image.squeeze(0).permute(1, 2, 0).cpu()
    return img

epoch_reached = getEpochReached(models_dir)
gen_file = Path(models_dir) / ("GAN_GEN_" + str(epoch_reached) + ".pth")
img = generate_image(gen_file, target_depth)

plt.imshow(img)
plt.show()

if save_image:
    # save as PNG under a short random name
    filename = uuid.uuid4().hex[:8] + ".png"
    Path(generated_dir).mkdir(parents=True, exist_ok=True)
    plt.imsave(generated_dir + filename, img)


### Download Generated Images

Download the generated images if you wish. Set `num_files` to the maximum number of images you want to download. Note that `save_image` in the previous cell must be `True` in order to download images.



In [0]:
#@markdown Max number of images to download
num_files = 10  #@param {type: "number"}

import os
import zipfile
from google.colab import files

im_files = os.listdir(generated_dir)
max_files = min(num_files, len(im_files))
zip_name = "gen_images.zip"
# bundle up to max_files images into one zip file and download it
with zipfile.ZipFile(zip_name, "w", zipfile.ZIP_DEFLATED) as zf:
  for name in im_files[:max_files]:
    zf.write(generated_dir + name, name)
files.download(zip_name)

### Generate GIF of training progress

Set `download_image` to `True` to download the GIF. Choose which image in the samples grid to use by setting `image_y` and `image_x`, which are 0-based indices counted from the top-left corner (the top-left image is `image_y` = 0, `image_x` = 0). `duration` is how long each frame is shown, in milliseconds; lower it to speed up the GIF or raise it to slow it down.
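The cell below crops cell (`image_x`, `image_y`) out of each sample grid by assuming the cells are packed edge to edge with no gap. If your sample grids have padding between cells (torchvision's `make_grid` and `save_image`, for example, default to `padding=2`), the offsets must include it. Whether and how much padding BMSG-GAN's saved grids use is an assumption to check against an actual sample image; `grid_cell_box` is a hypothetical helper that follows `make_grid`'s layout.

```python
def grid_cell_box(image_x, image_y, cell_size, padding=0):
    """Return the (left, upper, right, lower) crop box of one grid cell.

    Assumes square cells of cell_size pixels separated and surrounded by
    `padding` pixels, as in torchvision's make_grid layout. With padding=0
    this gives the same coordinates as the cell below.
    """
    left = padding + image_x * (cell_size + padding)
    upper = padding + image_y * (cell_size + padding)
    return (left, upper, left + cell_size, upper + cell_size)
```

To use it, set `coords = grid_cell_box(image_x, image_y, target_size, padding)` with the padding you measured.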

In [0]:
#@markdown Check the box below if you would like to download the 
#@markdown generated gif. If you are not using Google Drive you
#@markdown should download the image because it is the only way to
#@markdown view the gif.
download_image = False  #@param {type: "boolean"}
#@markdown Indices of the image in the samples grid to use. The
#@markdown top-left corner is `image_y` = 0, `image_x` = 0.
image_y = 0  #@param {type: "number"}
image_x = 0  #@param {type: "number"}
#@markdown Set the amount of time in milliseconds to display each
#@markdown image in the gif
duration = 250  #@param {type: "number"}
gif_name = 'train_time_lapse.gif'  #@param {type: "string"}

import re
from pathlib import Path
from PIL import Image
from google.colab import files
from tqdm import tqdm

def natural_key(string_):
    """See http://www.codinghorror.com/blog/archives/001018.html"""
    string_ = str(string_)
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_)]

sample_image_dir = Path(samples_dir) / f"{target_size}_x_{target_size}"
start_x = target_size * image_x
end_x = start_x + target_size
start_y = target_size * image_y
end_y = start_y + target_size
coords = (start_x, start_y, end_x, end_y)

images = []
for img in tqdm(sorted(sample_image_dir.iterdir(), key=natural_key)):
  oimg = Image.open(img)
  cimg = oimg.crop(coords)
  images.append(cimg)
images[0].save(samples_dir + gif_name,
            save_all=True,
            append_images=images[1:],
            duration=duration,
            loop=0)
if download_image:
  files.download(samples_dir + gif_name)