In [None]:
def kaggle_setup(
    working_dir="/kaggle/working",
    repo_url="https://github.com/anthropikos/kaggle_kernels.git",
    package_subdir="gan/src",
):
    """Fetch the project repository and make its source directory importable.

    If a checkout already exists under ``working_dir`` it is updated with
    ``git pull``; otherwise ``repo_url`` is cloned there. The directory
    ``<checkout>/<package_subdir>`` is then appended to ``sys.path`` (once)
    so that packages inside it can be imported by later cells.

    Args:
        working_dir: Directory that holds the checkout. Defaults to Kaggle's
            writable working directory.
        repo_url: Git URL of the repository. The checkout directory name is
            the last path component of the URL without its extension.
        package_subdir: Path, relative to the checkout root, that is added
            to ``sys.path``.

    Raises:
        RuntimeError: If no checkout exists and cloning fails (or git cannot
            be run), since nothing would be importable afterwards.
    """
    import os
    import subprocess
    import sys
    from pathlib import Path

    repo_dir = Path(working_dir) / Path(repo_url).stem

    if repo_dir.exists():
        # An existing (possibly stale) checkout is still usable, e.g. when the
        # kernel has no internet access, so a failed pull only warns.
        try:
            pull_ok = subprocess.run(["git", "-C", str(repo_dir), "pull"]).returncode == 0
        except OSError:  # git executable missing / not runnable
            pull_ok = False
        if pull_ok:
            print("git pulled")
        else:
            print(f"warning: git pull failed in {repo_dir}; using existing checkout")
    else:
        # Clone to an explicit absolute target instead of relying on the
        # notebook's current working directory.
        try:
            result = subprocess.run(["git", "clone", repo_url, str(repo_dir)])
        except OSError as exc:
            raise RuntimeError(f"could not run git to clone {repo_url}") from exc
        if result.returncode != 0:
            raise RuntimeError(
                f"git clone of {repo_url} failed (exit code {result.returncode})"
            )
        print("git cloned")

    # Note: a `cd` in a subprocess cannot change this process's directory, so
    # the import path is registered explicitly instead.
    package_dir_path = os.path.abspath(repo_dir / package_subdir)
    if package_dir_path not in sys.path:
        sys.path.append(package_dir_path)

# Clone (or update) the project repository and add its source directory to
# sys.path; this must run before the `gan` imports in the following cells.
kaggle_setup()

In [None]:
from gan.utility_training import training_loop

# Competition input folders: Monet paintings and the photos to be stylized.
MONET_DATA_DIR = "/kaggle/input/gan-getting-started/monet_jpg"
PHOTO_DATA_DIR = "/kaggle/input/gan-getting-started/photo_jpg"

# Train on both folders; the returned object is what later cells call
# `generate_monet` on.
model = training_loop(
    monet_data_dir=MONET_DATA_DIR,
    photo_data_dir=PHOTO_DATA_DIR,
)

In [None]:
from gan.data import ImageDataset
from gan.utility_data import map_rgb_to_tanh, map_tanh_to_rgb
from gan.utility_plotting import plot_before_after
import numpy as np
import torch


# model = model.to(torch.device("cpu"))

# photo_dataset = ImageDataset("/kaggle/input/gan-getting-started/photo_jpg")
# img_idx = np.random.randint(0, len(photo_dataset), 1)[0]

# img_rgb = photo_dataset[img_idx]
# img_tanh = map_rgb_to_tanh(img_rgb)

# img_gen_tanh = model.generate_monet(img_tanh)
# img_gen_rgb = map_tanh_to_rgb(img_gen_tanh)

# fig = plot_before_after(img_rgb, img_gen_rgb)

In [None]:
# # Predict through all the photos

# from gan.utility_plotting import plot_single_RGB_tensor
# from pathlib import Path
# from tqdm import tqdm

# model.to(torch.device("cpu"))

# for idx, photo_rgb in tqdm(enumerate(photo_dataset)): 
#     photo_tanh = map_rgb_to_tanh(photo_rgb)

#     gen_monet_tanh = model.generate_monet(photo_tanh)
#     gen_monet_rgb = map_tanh_to_rgb(gen_monet_tanh)

#     ax = plot_single_RGB_tensor(gen_monet_rgb)
#     fig = ax.get_figure()
#     (Path("/kaggle/working") / Path("images")).mkdir(exist_ok=True)
#     fig.savefig(f"images/gen_monet_{idx}.jpg")
#     plt.close()

# plt.close()

In [None]:
# Predict through all the photos

from gan.utility_plotting import plot_single_RGB_tensor
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt


# Generate one Monet-style image per photo and write each one to disk as a JPEG.
# `photo_dataset` was otherwise only defined in a commented-out cell, which made
# this cell fail with NameError on a fresh kernel; build it explicitly here.
photo_dataset = ImageDataset("/kaggle/input/gan-getting-started/photo_jpg")

# Fall back to CPU when no GPU is present (the model cannot live on CUDA then).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create and write to the same absolute directory, instead of creating an
# absolute path but saving to a cwd-relative "images/" path. Done once, not per image.
output_dir = Path("/kaggle/working") / "images"
output_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): if `model` is a torch nn.Module, it may need model.eval() before
# inference (dropout/norm layers) — TODO confirm against gan.utility_training.
# NOTE(review): the JPEG size is whatever figure plot_single_RGB_tensor produces;
# confirm it matches the image size the submission expects.
with torch.no_grad():  # inference only: don't build autograd graphs per image
    # Wrap the dataset itself (not enumerate) so tqdm can use len() for a total.
    for idx, photo_rgb in enumerate(tqdm(photo_dataset)):
        photo_tanh = map_rgb_to_tanh(photo_rgb).to(device)

        gen_monet_tanh = model.generate_monet(photo_tanh)
        gen_monet_rgb = map_tanh_to_rgb(gen_monet_tanh).to(torch.device("cpu"))

        ax = plot_single_RGB_tensor(gen_monet_rgb)
        fig = ax.get_figure()
        fig.savefig(output_dir / f"gen_monet_{idx}.jpg")
        plt.close(fig)  # close this specific figure so open figures don't pile up


In [None]:
from shutil import make_archive

# Zip the generated images into /kaggle/working/images.zip.
# make_archive(base_name, format, root_dir): `root_dir` is the directory whose
# contents are archived. Passing /kaggle/working there would archive the whole
# working directory (including the cloned repo), not just the images; rooting
# the archive at the images directory puts the JPEGs at the top of the zip.
src = Path("/kaggle/working/images")  # directory whose contents are zipped
dst = Path("/kaggle/working")  # directory the archive is written into

# ".zip" is appended to base_name by make_archive; the returned path is displayed.
make_archive(str(dst / src.name), "zip", root_dir=str(src))