In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  


In [2]:
import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("Using GPUs:", gpus)
    except RuntimeError as e:
        print("Error initializing GPUs:", e)
else:
    print("No GPUs available.")


Using GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]


In [3]:
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)


In [12]:
import os
import cv2
import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow as tf
import keras
from keras import optimizers
from tqdm import tqdm
import gc
import psutil

class PhotoSketchDataset:
    def __init__(self, color_dir, sketch_dir):
        self.color_dir = color_dir
        self.sketch_dir = sketch_dir
        self.color_files = []
        self.sketch_files = []
        
    def load_image_paths(self, preprocess_fraction=0.1):
        self.color_files = sorted(os.listdir(self.color_dir))
        self.sketch_files = sorted(os.listdir(self.sketch_dir))
        
        n_samples = int(len(self.color_files) * preprocess_fraction)
        self.color_files = self.color_files[:n_samples]
        self.sketch_files = self.sketch_files[:n_samples]
        
        train_color, temp_color, train_sketch, temp_sketch = train_test_split(
            self.color_files, self.sketch_files, train_size=0.7, random_state=42
        )
        
        val_ratio = 0.5
        val_color, test_color, val_sketch, test_sketch = train_test_split(
            temp_color, temp_sketch, train_size=val_ratio, random_state=42
        )
        
        return (train_color, train_sketch), (val_color, val_sketch), (test_color, test_sketch)
    
    @staticmethod
    def load_and_preprocess_image(image_path, target_size=(256, 256)):
        img = cv2.imread(image_path)
        if img is not None:
            img = cv2.resize(img, target_size)
            img = img.astype(np.float32) / 127.5 - 1
            return img
        return None

    def data_generator(self, color_files, sketch_files, batch_size=4):
        num_samples = len(color_files)
        while True:
            indices = np.random.permutation(num_samples)
            for start_idx in range(0, num_samples, batch_size):
                batch_indices = indices[start_idx:start_idx + batch_size]
                
                batch_color = []
                batch_sketch = []
                
                for idx in batch_indices:
                    color_path = os.path.join(self.color_dir, color_files[idx])
                    sketch_path = os.path.join(self.sketch_dir, sketch_files[idx])
                    
                    color_img = self.load_and_preprocess_image(color_path)
                    sketch_img = self.load_and_preprocess_image(sketch_path)
                    
                    if color_img is not None and sketch_img is not None:
                        batch_color.append(color_img)
                        batch_sketch.append(sketch_img)
                
                if batch_color and batch_sketch:
                    yield np.array(batch_color), np.array(batch_sketch)

In [13]:
class MemoryEfficientCycleGAN:
    def __init__(self):
        self.lambda_cycle = 10.0
        self.checkpoint_dir = './training_checkpoints'
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        
        self.generator_g = self._build_generator()
        self.generator_f = self._build_generator()
        self.discriminator_x = self._build_discriminator()
        self.discriminator_y = self._build_discriminator()
        
        self.generator_g_optimizer = optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
        self.generator_f_optimizer = optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
        self.discriminator_x_optimizer = optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
        self.discriminator_y_optimizer = optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
    
    def _build_generator(self):
        inputs = keras.layers.Input(shape=[256, 256, 3])
        
        x = keras.layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
        x = keras.layers.LeakyReLU()(x)
        
        x = keras.layers.Conv2D(128, 4, strides=2, padding='same')(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.LeakyReLU()(x)
        
        x = keras.layers.Conv2DTranspose(64, 4, strides=2, padding='same')(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.ReLU()(x)
        
        outputs = keras.layers.Conv2DTranspose(3, 4, strides=2, padding='same', activation='tanh')(x)
        
        return keras.Model(inputs=inputs, outputs=outputs)
    
    def _build_discriminator(self):
        inputs = keras.layers.Input(shape=[256, 256, 3])
        
        x = keras.layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
        x = keras.layers.LeakyReLU()(x)
        
        x = keras.layers.Conv2D(128, 4, strides=2, padding='same')(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.LeakyReLU()(x)
        
        outputs = keras.layers.Conv2D(1, 4, strides=1, padding='same')(x)
        
        return keras.Model(inputs=inputs, outputs=outputs)

    def _generator_loss(self, generated_output, real_target, source, cycled_source):
        gen_loss = tf.reduce_mean(keras.losses.binary_crossentropy(
            tf.ones_like(generated_output), generated_output))
        
        cycle_loss = tf.reduce_mean(tf.abs(source - cycled_source))
        total_loss = gen_loss + self.lambda_cycle * cycle_loss
        
        return total_loss, gen_loss, cycle_loss

    def _discriminator_loss(self, real_output, generated_output):
        real_loss = keras.losses.binary_crossentropy(tf.ones_like(real_output), real_output)
        generated_loss = keras.losses.binary_crossentropy(tf.zeros_like(generated_output), generated_output)
        total_loss = tf.reduce_mean(real_loss + generated_loss)
        return total_loss

    @tf.function
    def train_step(self, real_x, real_y):
        with tf.GradientTape(persistent=True) as tape:
            fake_y = self.generator_g(real_x, training=True)
            fake_x = self.generator_f(real_y, training=True)
            
            cycled_x = self.generator_f(fake_y, training=True)
            cycled_y = self.generator_g(fake_x, training=True)
            
            disc_real_x = self.discriminator_x(real_x, training=True)
            disc_fake_x = self.discriminator_x(fake_x, training=True)
            disc_real_y = self.discriminator_y(real_y, training=True)
            disc_fake_y = self.discriminator_y(fake_y, training=True)
            
            gen_g_total_loss, gen_g_loss, cycle_g_loss = self._generator_loss(
                disc_fake_y, real_y, real_x, cycled_x)
            gen_f_total_loss, gen_f_loss, cycle_f_loss = self._generator_loss(
                disc_fake_x, real_x, real_y, cycled_y)
            
            disc_x_loss = self._discriminator_loss(disc_real_x, disc_fake_x)
            disc_y_loss = self._discriminator_loss(disc_real_y, disc_fake_y)
        
        generator_g_gradients = tape.gradient(gen_g_total_loss, self.generator_g.trainable_variables)
        generator_f_gradients = tape.gradient(gen_f_total_loss, self.generator_f.trainable_variables)
        discriminator_x_gradients = tape.gradient(disc_x_loss, self.discriminator_x.trainable_variables)
        discriminator_y_gradients = tape.gradient(disc_y_loss, self.discriminator_y.trainable_variables)
        
        self.generator_g_optimizer.apply_gradients(
            zip(generator_g_gradients, self.generator_g.trainable_variables))
        self.generator_f_optimizer.apply_gradients(
            zip(generator_f_gradients, self.generator_f.trainable_variables))
        self.discriminator_x_optimizer.apply_gradients(
            zip(discriminator_x_gradients, self.discriminator_x.trainable_variables))
        self.discriminator_y_optimizer.apply_gradients(
            zip(discriminator_y_gradients, self.discriminator_y.trainable_variables))
        
        return {
            'gen_g_loss': gen_g_loss,
            'gen_f_loss': gen_f_loss,
            'disc_x_loss': disc_x_loss,
            'disc_y_loss': disc_y_loss,
            'cycle_loss': (cycle_g_loss + cycle_f_loss) / 2
        }

    def train(self, dataset, train_files, epochs=200, batch_size=4, save_interval=10):
        steps_per_epoch = len(train_files[0]) // batch_size
        
        for epoch in range(epochs):
            print(f"\nEpoch {epoch + 1}/{epochs}")
            generator = dataset.data_generator(train_files[0], train_files[1], batch_size)
            
            progress_bar = tqdm(range(steps_per_epoch))
            for _ in progress_bar:
                batch_x, batch_y = next(generator)
                losses = self.train_step(batch_x, batch_y)
                
                progress_bar.set_description(
                    f"G_loss: {(losses['gen_g_loss'] + losses['gen_f_loss'])/2:.4f}, "
                    f"D_loss: {(losses['disc_x_loss'] + losses['disc_y_loss'])/2:.4f}, "
                    f"Cycle_loss: {losses['cycle_loss']:.4f}")
                
                current_memory = psutil.Process().memory_info().rss / 1024 / 1024 / 1024
                if current_memory > 25:  
                    gc.collect()
                    tf.keras.backend.clear_session()
            
            if (epoch + 1) % save_interval == 0:
                self.save_models(epoch + 1)
            
            gc.collect()
    
    def save_models(self, epoch):
        self.generator_g.save(f'generator_g_epoch_{epoch}.keras')
        self.generator_f.save(f'generator_f_epoch_{epoch}.keras')

In [14]:
if __name__ == "__main__":
    dataset = PhotoSketchDataset(
        color_dir='/kaggle/input/anime-colorization/color',
        sketch_dir='/kaggle/input/anime-colorization/sketch'
    )
    
    train_files, val_files, test_files = dataset.load_image_paths(preprocess_fraction=0.1)
    
    cycle_gan = MemoryEfficientCycleGAN()
    cycle_gan.train(dataset, train_files, epochs=5, batch_size=16)


Epoch 1/5


G_loss: 6.2228, D_loss: 9.0390, Cycle_loss: 0.6662: 100%|██████████| 294/294 [02:57<00:00,  1.66it/s] 



Epoch 2/5


G_loss: 4.0426, D_loss: 9.7134, Cycle_loss: 0.5269: 100%|██████████| 294/294 [02:24<00:00,  2.03it/s] 



Epoch 3/5


G_loss: 2.3677, D_loss: 3.9286, Cycle_loss: 0.4792: 100%|██████████| 294/294 [02:15<00:00,  2.16it/s]



Epoch 4/5


G_loss: 1.6448, D_loss: 1.2303, Cycle_loss: 0.4014: 100%|██████████| 294/294 [02:16<00:00,  2.16it/s]



Epoch 5/5


G_loss: 1.1475, D_loss: 1.6121, Cycle_loss: 0.3711: 100%|██████████| 294/294 [02:25<00:00,  2.02it/s]


In [15]:
cycle_gan.generator_g.save('generator_g_final.keras')
cycle_gan.generator_f.save('generator_f_final.keras')


In [16]:
def generate_sketch(color_image_path, generator):
    color_img = PhotoSketchDataset.load_and_preprocess_image(color_image_path)
    
    if color_img is not None:
        color_img = np.expand_dims(color_img, axis=0)
        
        sketch_img = generator.predict(color_img)
        
        sketch_img = (sketch_img[0] + 1) / 2.0  
        sketch_img = (sketch_img * 255).astype(np.uint8)  
        
        return sketch_img
    return None

if __name__ == "__main__":
    generator_g = keras.models.load_model('generator_g_final.keras')
    
    sketch_image = generate_sketch('/kaggle/input/anime-colorization/color/1000241.png', generator_g)
    
    cv2.imwrite('generated_sketch.jpg', sketch_image)


I0000 00:00:1730487285.458234     252 service.cc:145] XLA service 0x7f9b08005c90 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1730487285.458311     252 service.cc:153]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1730487285.458320     252 service.cc:153]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step


I0000 00:00:1730487286.481143     252 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
