In [1]:
import os
import cv2
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from PIL import Image
from google.colab import files

# Ask TF Hub for the compressed model format before any model is loaded.
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'

# Arbitrary-image-stylization module from TensorFlow Hub; shared by every
# stylization call in this file.
model = hub.load('https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2')

# Working folders: uploads are written to input_dir, results to output_dir.
# Both are created up front and left untouched if they already exist.
input_dir, output_dir = "sample_inputs", "sample_outputs"
for _folder in (input_dir, output_dir):
    os.makedirs(_folder, exist_ok=True)

def load_img(path, max_dim=1024):
    """Load an image file as a normalized float32 batch of one.

    The image is converted to RGB and, if either side exceeds ``max_dim``,
    shrunk with ``Image.thumbnail`` so the aspect ratio is preserved. Images
    that already fit within ``max_dim`` are not enlarged.

    Args:
        path: Path of the image file to read.
        max_dim: Upper bound, in pixels, for both width and height.

    Returns:
        A float32 NumPy array of shape (1, height, width, 3) with values
        scaled to the range [0, 1].
    """
    # Close the underlying file deterministically instead of leaving it to
    # garbage collection. convert() decodes the pixels and returns an
    # independent in-memory image, so it remains valid after the file closes.
    with Image.open(path) as src:
        img = src.convert('RGB')
    img.thumbnail((max_dim, max_dim))
    img_np = np.array(img).astype(np.float32)[np.newaxis, ...] / 255.
    return img_np

def tensor_to_image(tensor):
    """Convert an image tensor scaled to [0, 1] into a PIL Image.

    Args:
        tensor: Array-like image data (TensorFlow tensor or NumPy array) with
            values nominally in [0, 1]. May carry a leading batch dimension
            of size 1, which is dropped.

    Returns:
        A ``PIL.Image.Image`` built from the 8-bit pixel data.

    Raises:
        ValueError: If the input has more than three dimensions and its
            leading (batch) dimension is not 1.
    """
    pixels = np.asarray(tensor * 255)
    # Clamp before the uint8 cast: values slightly outside [0, 1] would
    # otherwise wrap around (e.g. 256 -> 0) and show up as speckle artifacts.
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    if pixels.ndim > 3:
        # Explicit check instead of assert, which is stripped under `python -O`.
        if pixels.shape[0] != 1:
            raise ValueError(
                f"Expected a batch of exactly one image, got batch size {pixels.shape[0]}"
            )
        pixels = pixels[0]
    return Image.fromarray(pixels)

def stylize_image(content_img, style_img):
    """Run the module-level stylization ``model`` on one content/style pair.

    Both inputs are wrapped with ``tf.constant`` before the call, and the
    first element of the model's output is returned.
    """
    content_tensor = tf.constant(content_img)
    style_tensor = tf.constant(style_img)
    outputs = model(content_tensor, style_tensor)
    return outputs[0]

def process_video(input_video_path, style_image_path, output_video_path, max_dim=720):
    """Stylize a video frame by frame and write the result to disk.

    Each frame is downscaled so that neither side exceeds ``max_dim`` (frames
    are never enlarged), stylized with ``stylize_image`` and appended to an
    ``mp4v``-encoded output file.

    Args:
        input_video_path: Path of the video to read.
        style_image_path: Path of the style image; loaded with ``load_img``
            using the same ``max_dim``.
        output_video_path: Path of the video file to create.
        max_dim: Upper bound, in pixels, for the output width and height.

    Returns:
        None. Progress and failures are reported with ``print``; on failure
        to open the input or output the function returns early.
    """
    style_img = load_img(style_image_path, max_dim=max_dim)

    cap = cv2.VideoCapture(input_video_path)
    out = None
    try:
        if not cap.isOpened():
            print(f"Failed to open video: {input_video_path}")
            return

        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Guard against missing metadata, which would otherwise cause a
        # ZeroDivisionError in the scale computation below.
        if width <= 0 or height <= 0:
            print(f"Failed to read frame size from video: {input_video_path}")
            return
        # `not fps > 0` also catches NaN; fall back to a common default so the
        # writer is not opened with an unusable frame rate.
        if not fps > 0:
            fps = 30.0

        scale = min(max_dim / width, max_dim / height, 1)
        new_width = max(1, int(width * scale))
        new_height = max(1, int(height * scale))
        frame_size = (new_width, new_height)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, frame_size)
        if not out.isOpened():
            print(f"Failed to open video writer: {output_video_path}")
            return

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Processing {frame_count} frames...")

        frame_num = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes to BGR; the model pipeline works on RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img_pil = Image.fromarray(frame_rgb)
            img_pil.thumbnail(frame_size)
            content_np = np.array(img_pil).astype(np.float32)[np.newaxis, ...] / 255.

            stylized_frame = stylize_image(content_np, style_img)
            stylized_img = tensor_to_image(stylized_frame)

            stylized_bgr = cv2.cvtColor(np.array(stylized_img), cv2.COLOR_RGB2BGR)
            # The writer expects every frame to match the size it was opened
            # with. thumbnail() rounding can differ from int(width * scale) by
            # a pixel, and the model output is not guaranteed to have exactly
            # the input's spatial size, so normalize before writing.
            if (stylized_bgr.shape[1], stylized_bgr.shape[0]) != frame_size:
                stylized_bgr = cv2.resize(stylized_bgr, frame_size)
            out.write(stylized_bgr)

            frame_num += 1
            if frame_num % 10 == 0 or frame_num == frame_count:
                print(f"Processed {frame_num}/{frame_count} frames")
    finally:
        # Release the capture and writer even if stylization raises, so the
        # output container is finalized and file handles are not leaked.
        cap.release()
        if out is not None:
            out.release()

    print(f"Stylized video saved to: {output_video_path}")

def _save_upload(uploaded, target_path, label):
    """Write uploaded file contents to ``target_path``.

    Args:
        uploaded: Mapping of uploaded file name to file bytes, as returned by
            ``files.upload()``.
        target_path: Fixed destination path. If several files were uploaded,
            each overwrites the previous one, so the last one wins.
        label: Human-readable name used in the confirmation message.

    Returns:
        ``target_path`` if at least one file was written, otherwise ``None``.
    """
    saved_path = None
    for data in uploaded.values():
        with open(target_path, 'wb') as f:
            f.write(data)
        print(f"{label} saved as {target_path}")
        saved_path = target_path
    return saved_path


# User upload prompts
print("Upload your content image file:")
content_uploaded = files.upload()

print("Upload your style image file:")
style_uploaded = files.upload()

print("Optionally, upload a video file (or skip):")
video_uploaded = files.upload()

# Save uploaded files with fixed names. Each path is None when the matching
# upload was skipped, so the steps below can skip cleanly instead of raising
# NameError on an unassigned variable.
content_path = _save_upload(
    content_uploaded, os.path.join(input_dir, "content.jpg"), "Content image"
)
style_path = _save_upload(
    style_uploaded, os.path.join(input_dir, "style.jpg"), "Style image"
)
video_path = _save_upload(
    video_uploaded, os.path.join(input_dir, "input_video.mp4"), "Video file"
)
if video_path is None:
    print("No video uploaded.")

# Stylize image
if (content_path and style_path
        and os.path.exists(content_path) and os.path.exists(style_path)):
    content_img = load_img(content_path, max_dim=1024)  # Use HD max dimension for image
    style_img = load_img(style_path, max_dim=1024)
    stylized_tensor = stylize_image(content_img, style_img)
    output_img = tensor_to_image(stylized_tensor)
    output_img_path = os.path.join(output_dir, "stylized_image_hd.jpg")
    output_img.save(output_img_path)
    print(f"Stylized image saved to: {output_img_path}")
else:
    print("Content or Style image missing, skipping image stylization.")

# Stylize video if available
if (video_path and style_path
        and os.path.exists(video_path) and os.path.exists(style_path)):
    output_video_path = os.path.join(output_dir, "stylized_video.mp4")
    process_video(video_path, style_path, output_video_path, max_dim=720)
else:
    print("Video or style image missing, skipping video stylization.")

Upload your content image file:


Saving stata.jpg to stata.jpg
Upload your style image file:


Saving bango.jpg to bango.jpg
Optionally, upload a video file (or skip):


Saving fox.mp4 to fox.mp4
Content image saved as sample_inputs/content.jpg
Style image saved as sample_inputs/style.jpg
Video file saved as sample_inputs/input_video.mp4
Stylized image saved to: sample_outputs/stylized_image_hd.jpg
Processing 420 frames...
Processed 10/420 frames
Processed 20/420 frames
Processed 30/420 frames
Processed 40/420 frames
Processed 50/420 frames
Processed 60/420 frames
Processed 70/420 frames
Processed 80/420 frames
Processed 90/420 frames
Processed 100/420 frames
Processed 110/420 frames
Processed 120/420 frames
Processed 130/420 frames
Processed 140/420 frames
Processed 150/420 frames
Processed 160/420 frames
Processed 170/420 frames
Processed 180/420 frames
Processed 190/420 frames
Processed 200/420 frames
Processed 210/420 frames
Processed 220/420 frames
Processed 230/420 frames
Processed 240/420 frames
Processed 250/420 frames
Processed 260/420 frames
Processed 270/420 frames
Processed 280/420 frames
Processed 290/420 frames
Processed 300/420 frames
Pr