In [None]:
from google.colab import drive
# Make Google Drive (project code under MyDrive) visible at /content/drive.
drive.mount(mountpoint='/content/drive')

In [None]:
import sys
sys.path.append("/content/drive/MyDrive/Anime GAN Code/losses")
sys.path.append("/content/drive/MyDrive/Anime GAN Code/model")
sys.path.append("/content/drive/MyDrive/Anime GAN Code")

In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import (
    Input, Conv2D, SeparableConv2D, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D,
    BatchNormalization, Activation, Dense, Dropout, Flatten, Multiply, Add, Lambda, SpatialDropout2D, Reshape, GlobalMaxPooling2D, Layer, UpSampling2D
)
from tensorflow.keras.models import Model
from tensorflow.keras.initializers import HeNormal, GlorotNormal
from tensorflow.keras.regularizers import l2
import matplotlib.pyplot as plt
import pathlib
from tensorflow.keras.optimizers.schedules import (
    ExponentialDecay,
    PolynomialDecay,
    PiecewiseConstantDecay,
    CosineDecay,
    CosineDecayRestarts,
    InverseTimeDecay,
    LearningRateSchedule
)
from tensorflow.keras.applications import VGG16, VGG19
from tensorflow.keras.models import Model
from keras.saving import register_keras_serializable

import support_fun

from model.encoder import build_encoder
from model.main_discriminator import build_main_discriminator
from model.main_gan import build_main_gan
from model.main_generator import build_main_generator
from model.support_discriminator import build_support_discriminator
from model.support_gan import build_support_gan
from model.support_generator import build_support_generator

from losses.color import color_loss
from losses.content import content_loss
from losses.dm import dm_loss
from losses.ds import ds_loss
from losses.gray_style import gray_style_loss
from losses.m_adv import m_adv_loss
from losses.per_pixel import per_pixel_loss
from losses.perception import perception_loss
from losses.region_smoothing import region_smoothing_loss
from losses.s_adv import s_adv_loss
from losses.tv import total_loss
from losses.vgg19 import build_vgg19

from support_fun import fine_grained_revision, guided_filter_tf

In [None]:
import gdown

file_id = ''  # Google Drive file ID of the anime archive — fill in before running
if not file_id:
    raise ValueError("Set `file_id` to the Google Drive ID of the anime archive before running this cell.")
downloaded = gdown.download(f'https://drive.google.com/uc?id={file_id}', 'anime_file.zip', quiet=False)
# Some gdown versions report failure by returning None instead of raising.
if downloaded is None:
    raise RuntimeError("gdown could not download 'anime_file.zip'; check file_id and sharing permissions.")

In [None]:
# -o overwrites existing files without prompting, so the cell can be re-run safely.
!unzip -q -o anime_file.zip -d anime_folder

In [None]:
file_id = ''  # Google Drive file ID of the landscape archive — fill in before running
if not file_id:
    raise ValueError("Set `file_id` to the Google Drive ID of the landscape archive before running this cell.")
downloaded = gdown.download(f'https://drive.google.com/uc?id={file_id}', 'landscape_file.zip', quiet=False)
# Some gdown versions report failure by returning None instead of raising.
if downloaded is None:
    raise RuntimeError("gdown could not download 'landscape_file.zip'; check file_id and sharing permissions.")

In [None]:
# -o overwrites existing files without prompting, so the cell can be re-run safely.
!unzip -q -o landscape_file.zip -d landscape_folder

In [None]:
file_id = ''  # Google Drive file ID of the blurred archive — fill in before running
if not file_id:
    raise ValueError("Set `file_id` to the Google Drive ID of the blurred archive before running this cell.")
downloaded = gdown.download(f'https://drive.google.com/uc?id={file_id}', 'blurred_file.zip', quiet=False)
# Some gdown versions report failure by returning None instead of raising.
if downloaded is None:
    raise RuntimeError("gdown could not download 'blurred_file.zip'; check file_id and sharing permissions.")

In [None]:
# -o overwrites existing files without prompting, so the cell can be re-run safely.
!unzip -q -o blurred_file.zip -d blurred_folder

In [None]:
def create_dataset(data_dir, batch_size=8, image_size=(256, 256)):
    """Build a shuffled, batched image dataset from every ``*.jpg`` in one directory.

    No train/validation split is made: all matching images go into the single
    returned dataset.

    Args:
        data_dir: Directory holding the images (only its top level is searched).
        batch_size: Images per batch; the final batch may be smaller.
        image_size: ``(height, width)`` that every image is resized to.

    Returns:
        A ``tf.data.Dataset`` of float32 batches shaped ``(batch, height, width, 3)``
        with pixel values scaled from ``[0, 255]`` to ``[-1, 1]``. File order is
        reshuffled on every pass over the dataset.

    Raises:
        ValueError: If ``data_dir`` contains no ``*.jpg`` files.
    """
    data_dir = pathlib.Path(data_dir)
    # Sort first: glob order is filesystem-dependent, which would make the seeded
    # shuffle below non-reproducible across machines.
    all_image_paths = sorted(str(path) for path in data_dir.glob('*.jpg'))
    num_images = len(all_image_paths)
    if num_images == 0:
        # Otherwise tf.random.shuffle([]) yields a float tensor and the failure only
        # surfaces later as an obscure dtype error inside tf.io.read_file.
        raise ValueError(f"No '*.jpg' files found in {data_dir}")

    tf.random.set_seed(42)  # Ensure reproducibility (note: this sets TensorFlow's *global* seed)
    all_image_paths = tf.random.shuffle(all_image_paths)

    def process_image(file_path):
        """Read and decode one JPEG, resize it, and scale pixels to [-1, 1]."""
        image = tf.io.read_file(file_path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, image_size)
        return image / 127.5 - 1

    dataset = tf.data.Dataset.from_tensor_slices(all_image_paths)
    # Shuffle the cheap path strings with a full-size buffer (reshuffled each epoch)
    # instead of decoded images: a 1000-image buffer of 256x256x3 float32 tensors
    # holds roughly 0.8 GB per dataset.
    dataset = dataset.shuffle(buffer_size=num_images)
    dataset = dataset.map(process_image, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

In [None]:
# Define parameters: batch size shared by all three pipelines, plus their image folders.
batch_size = 8
anime_dir = "/content/anime_folder/7000_anime"
landscape_dir = "/content/landscape_folder/7000_landscape"
blurred_dir = "/content/blurred_folder/7000_blurred"

# Built in the order anime, landscape, blurred.
anime_dataset, landscape_dataset, blurred_dataset = (
    create_dataset(directory, batch_size=batch_size)
    for directory in (anime_dir, landscape_dir, blurred_dir)
)

In [None]:
# One encoder instance is passed to both generators, so they share its weights.
encoder = build_encoder()

main_generator, support_generator = (
    build_main_generator(encoder=encoder),
    build_support_generator(encoder=encoder),
)

main_discriminator, support_discriminator = (
    build_main_discriminator(),
    build_support_discriminator(),
)

# Each GAN wrapper pairs a generator with its discriminator.
main_gan, support_gan = (
    build_main_gan(generator=main_generator, discriminator=main_discriminator),
    build_support_gan(generator=support_generator, discriminator=support_discriminator),
)

In [None]:
# Print the Keras model summary of the main GAN wrapper (generator + discriminator).
main_gan.summary()

In [None]:
# ImageNet VGG19 without its classifier head, frozen; passed to the perception,
# content and gray-style losses in the training steps.
model_vgg = VGG19(include_top=False, weights="imagenet", input_shape=(256, 256, 3))
model_vgg.trainable = False

In [None]:
import wandb
wandb.login(key="")

In [None]:
# Initialize a W&B run
wandb.init(
    project="Anime GAN v3",
    name="first attemts",
    config={
        # Must match the learning rates set in the optimizer cell.
        "learning_rate": 1.0e-3,                # generator-side Adam (main/support GAN)
        "discriminator_learning_rate": 5.0e-4,  # discriminator Adam
        "batch_size": batch_size,
    },
)

In [None]:
# Two independent Adam instances per role: the GAN optimizers update the
# main_gan / support_gan trainable variables, the others the discriminators.
main_gan_optimizer, support_gan_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=1.0e-3) for _ in range(2)
)
main_discriminator_optimizer, support_discriminator_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=5e-4) for _ in range(2)
)

In [None]:
@tf.function
def train_main_generator(original_image):
    """Run one optimisation step on ``main_gan``'s trainable variables.

    The support generator's output, passed through ``fine_grained_revision``, is the
    reference for the per-pixel and perception terms. ``m_adv_loss`` is evaluated on
    ``main_gan``, and a total-variation term is applied to the main output.

    Returns:
        ``[main_image, per_pixel, perception, m_adv, main_tv]``: the main generator's
        output followed by the weighted loss terms, in the order the loop unpacks them.

    NOTE(review): the generators are called without ``training=True``; if the model
    builders use BatchNormalization/Dropout they run in inference mode here — confirm.
    The variable list is read while tracing, so it relies on the training loop setting
    ``main_discriminator.trainable = False`` before the first call.
    """
    with tf.GradientTape() as tape:
        generated = main_generator(original_image)
        reference = fine_grained_revision(support_generator(original_image))

        pixel_term = 2000 * per_pixel_loss(reference, generated)
        percep_term = 3 * perception_loss(reference, generated, model_vgg)
        adv_term = 200 * m_adv_loss(main_gan, original_image)
        tv_term = 0.005 * total_loss(generated)

        objective = pixel_term + percep_term + adv_term + tv_term

    variables = main_gan.trainable_variables
    # Each gradient tensor is clipped to norm 5.0 individually (not a global-norm clip).
    grads = [tf.clip_by_norm(grad, 5.0) for grad in tape.gradient(objective, variables)]
    main_gan_optimizer.apply_gradients(zip(grads, variables))

    return [generated, pixel_term, percep_term, adv_term, tv_term]


In [None]:
@tf.function
def train_support_generator(original_image, anime_image):
    """Run one optimisation step on ``support_gan``'s trainable variables.

    The support output is passed through ``guided_filter_tf``. The filtered image feeds
    the content and colour losses (with the input batch), the gray-style loss (with
    the anime batch) and a total-variation term; ``s_adv_loss`` is evaluated on
    ``support_gan``.

    Returns:
        ``[support_image, content, support_tv, color, s_adv, gray_style]``: the
        unfiltered support output followed by the weighted loss terms, in the order
        the training loop unpacks them.
    """
    with tf.GradientTape() as tape:
        raw_output = support_generator(original_image)
        filtered = guided_filter_tf(raw_output)

        content_term = 2 * content_loss(original_image, filtered, model_vgg)
        # Region smoothing is disabled (contributes 0); to re-enable, use
        # region_smoothing_loss(original_image, filtered, model_vgg).
        smoothing_term = 0
        tv_term = 0.004 * total_loss(filtered)
        color_term = 500 * color_loss(original_image, filtered)
        adv_term = 50 * s_adv_loss(support_gan, original_image)
        style_term = 0.05 * gray_style_loss(anime_image, filtered, model_vgg)

        objective = content_term + style_term + smoothing_term + adv_term + color_term + tv_term

    variables = support_gan.trainable_variables
    # Each gradient tensor is clipped to norm 5.0 individually.
    grads = [tf.clip_by_norm(grad, 5.0) for grad in tape.gradient(objective, variables)]
    support_gan_optimizer.apply_gradients(zip(grads, variables))

    return [raw_output, content_term, tv_term, color_term, adv_term, style_term]


In [None]:
@tf.function
def train_main_discriminator(original_image):
    """Run one optimisation step on ``main_discriminator``.

    ``dm_loss`` receives the discriminator, the main generator's output and the
    ``fine_grained_revision`` of the support generator's output. Only the
    discriminator's variables are updated; each gradient tensor is clipped to norm 3.0.

    Returns:
        The weighted discriminator loss, ``40 * dm_loss(...)``.
    """
    # Generator forward passes stay outside the tape: only discriminator variables are
    # differentiated, so recording these ops would just keep their activations alive.
    main_image = main_generator(original_image)
    revised_image = fine_grained_revision(support_generator(original_image))

    with tf.GradientTape() as tape:
        dm_value = 40 * dm_loss(main_discriminator, main_image, revised_image)

    # Read while tracing; the training loop sets main_discriminator.trainable = True first.
    variables = main_discriminator.trainable_variables
    gradients = tape.gradient(dm_value, variables)
    clipped_gradients = [tf.clip_by_norm(g, 3.0) for g in gradients]
    main_discriminator_optimizer.apply_gradients(zip(clipped_gradients, variables))

    return dm_value


In [None]:
def _support_discriminator_step(images, kind):
    """Apply one clipped update to ``support_discriminator`` for ``ds_loss(..., kind)``.

    Plain Python helper, traced inline by the ``tf.function`` callers below.
    Returns the weighted loss ``10 * ds_loss(support_discriminator, images, kind)``.
    """
    with tf.GradientTape() as tape:
        ds_value = 10 * ds_loss(support_discriminator, images, kind)
    variables = support_discriminator.trainable_variables
    gradients = tape.gradient(ds_value, variables)
    clipped_gradients = [tf.clip_by_norm(g, 3.0) for g in gradients]  # per-tensor clip at 3.0
    support_discriminator_optimizer.apply_gradients(zip(clipped_gradients, variables))
    return ds_value


@tf.function
def train_support_discriminator_gp(original_image):
    """Support-discriminator step on the guided-filtered support output (``"gsp"``)."""
    # Computed outside the tape: no gradient flows to the generator in this step.
    guided_image = guided_filter_tf(support_generator(original_image))
    return _support_discriminator_step(guided_image, "gsp")


@tf.function
def train_support_discriminator_a(a):
    """Support-discriminator step on batch ``a`` with loss kind ``"a"`` (the loop passes anime images)."""
    return _support_discriminator_step(a, "a")


@tf.function
def train_support_discriminator_e(e):
    """Support-discriminator step on batch ``e`` with loss kind ``"e"`` (the loop passes blurred images)."""
    return _support_discriminator_step(e, "e")


In [None]:
def denorm(image):
    """Map values from [-1, 1] (the scaling create_dataset applies) back to [0, 1]."""
    shifted = image + 1
    return shifted / 2

In [None]:
epochs = 15
discriminator_update_every = 1  # train the discriminators every N steps (1 = every step)
log_every = 5                   # send losses and sample images to W&B every N steps

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    # zip() stops as soon as the shortest dataset runs out, like the old
    # next()/StopIteration loop, but without swallowing StopIteration raised elsewhere.
    batches = zip(landscape_dataset, anime_dataset, blurred_dataset)
    for step, (original_image, anime_image, blurred_image) in enumerate(batches):
        all_losses = {}

        # Freeze the discriminators for the generator steps. Inside the tf.function
        # steps, trainable_variables is evaluated when each function is traced, so
        # these flags must be correct before the first call of each step function.
        main_discriminator.trainable = False
        support_discriminator.trainable = False

        (main_image, all_losses["per pixel"], all_losses["perception"],
         all_losses["m_adv"], all_losses["main_tv"]) = train_main_generator(original_image)
        (support_image, all_losses["content value"], all_losses["support tv"],
         all_losses["color value"], all_losses["s_adv"], all_losses["gray style"]) = (
            train_support_generator(original_image, anime_image))

        if step % discriminator_update_every == 0:
            main_discriminator.trainable = True
            support_discriminator.trainable = True

            all_losses["dm"] = train_main_discriminator(original_image)
            # Three separate support-discriminator updates; their losses are summed for logging.
            all_losses["ds"] = (
                train_support_discriminator_gp(original_image)
                + train_support_discriminator_a(anime_image)
                + train_support_discriminator_e(blurred_image)
            )

        if step % log_every == 0:
            wandb.log({
                "main_image": wandb.Image(denorm(main_image).numpy()[0], caption=f"Step {step + 1}"),
                "support_image": wandb.Image(denorm(support_image).numpy()[0], caption=f"Step {step + 1}"),
                "epoch": epoch + 1,  # step restarts every epoch, so record which epoch this is
                **all_losses,
            })
