In [25]:
!git clone https://github.com/TheoXiong7/CS445_FinalProject.git

# Move data folder to current directory if not already there
!if [ ! -d "data" ]; then mv CS445_FinalProject/data .; fi

# Move results folder and eval file to current directory
!mv CS445_FinalProject/results .
!mv CS445_FinalProject/eval.csv .

!rm -rf CS445_FinalProject

Cloning into 'CS445_FinalProject'...
remote: Enumerating objects: 1670, done.
remote: Counting objects: 100% (19/19), done.
remote: Compressing objects: 100% (17/17), done.
remote: Total 1670 (delta 4), reused 2 (delta 2), pack-reused 1651 (from 1)
Receiving objects: 100% (1670/1670), 599.85 MiB | 33.39 MiB/s, done.
Resolving deltas: 100% (43/43), done.
Updating files: 100% (1616/1616), done.
mv: cannot stat 'CS445_FinalProject/eval.csv': No such file or directory


In [4]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
from IPython.display import display, clear_output
import ipywidgets as widgets
from PIL import Image

In [29]:
# Set paths for images to load
# Directories holding the input images and the two sets of style-transfer results
content_image_dir = os.path.join('data', 'images')
style_image_dir = os.path.join('data', 'styles')
baseline_image_dir = os.path.join('results', 'baseline')
depth_image_dir = os.path.join('results', 'depth_aware')

# Alphabetical lists of the .jpg filenames found in each input directory
content_images = sorted(
    name for name in os.listdir(content_image_dir) if name.endswith('.jpg')
)
style_images = sorted(
    name for name in os.listdir(style_image_dir) if name.endswith('.jpg')
)


# One dropdown per input kind; 'initial' width keeps the full description label visible
content_dropdown = widgets.Dropdown(
    options = content_images,
    description = 'Content Images:',
    style = dict(description_width = 'initial'),
)

style_dropdown = widgets.Dropdown(
    options = style_images,
    description = 'Style Images:',
    style = dict(description_width = 'initial'),
)

# Output area that the comparison figure and any error messages are rendered into
output = widgets.Output()

def _load_image(path):
  """Open the image at `path`; print an error and return None when the file is missing."""
  try:
    image = Image.open(path)
  except FileNotFoundError:
    print(f"Error - {path} not found.")
    return None
  # Image.open is lazy; load() reads the pixel data now so the underlying
  # file is not left open after this call.
  image.load()
  return image


def _format_metrics(title, row):
  """Return the text shown in a metrics panel for one row of the eval table.

  `title` is the panel heading; `row` is a single row (pandas Series) that must
  provide the ssim, psnr, mse, mae, style_to_content, content_loss and
  style_loss fields. boundary_artifact and lpips are appended only when present
  and not NaN.
  """
  text = f"{title} \n" + "=" * 50 + "\n\n"
  text += f"SSIM                  {row['ssim']:.4f}\n"
  text += f"PSNR                  {row['psnr']:.2f}\n"
  text += f"MSE                   {row['mse']:.6f}\n"
  text += f"MAE                   {row['mae']:.6f}\n"
  text += f"Style -> Content      {row['style_to_content']:.4f}\n"
  text += f"Content Loss:         {row['content_loss']:.4f}\n"
  text += f"Style Loss:           {row['style_loss']:.4f}\n"

  if pd.notna(row.get('boundary_artifact')):
    text += f"Boundary Artifact:      {row['boundary_artifact']:.4f}\n"
  if pd.notna(row.get('lpips')):
    text += f"LPIPS:      {row['lpips']:.4f}\n"
  return text


def display_comparison(change = None):
  """Redraw the comparison figure for the currently selected content/style pair.

  Top row: content image, style image, baseline transfer, depth-aware transfer.
  Bottom row: metrics for the baseline and depth-aware variants, read from
  'eval.csv' in the working directory.

  Missing inputs no longer abort the redraw: a missing image is reported and
  replaced by a placeholder panel, and a missing eval file or missing metrics
  row leaves the corresponding panel empty or annotated.

  Args:
    change: change notification supplied by the dropdown observers; unused,
      the current dropdown values are read directly.
  """
  with output:
    clear_output(wait = True)

    content_name = content_dropdown.value
    style_name = style_dropdown.value
    # Guard against a missing selection (e.g. no .jpg files were found)
    if content_name is None or style_name is None:
      print("Error - no content or style image is selected.")
      return

    # Result files are named '<content>__<style>.jpg', without the inputs' extensions
    content_base_name = os.path.splitext(content_name)[0]
    style_base_name = os.path.splitext(style_name)[0]
    output_name = f"{content_base_name}__{style_base_name}.jpg"

    panels = [
        ('Original Content Image', os.path.join(content_image_dir, content_name)),
        ('Original Style Image', os.path.join(style_image_dir, style_name)),
        ('Baseline Transfer', os.path.join(baseline_image_dir, output_name)),
        ('Depth-Aware Transfer', os.path.join(depth_image_dir, output_name)),
    ]
    # Load everything up front so missing-file errors are reported before the figure
    images = [(title, _load_image(path)) for title, path in panels]

    eval_path = 'eval.csv'
    try:
      eval_df = pd.read_csv(eval_path)
    except FileNotFoundError:
      print(f"Error - {eval_path} not found.")
      # Fall back to an empty table so the images are still shown
      eval_df = pd.DataFrame()

    # Display images
    fig = plt.figure(figsize = (18, 10))
    grid = fig.add_gridspec(2, 4, hspace = 0.3, wspace = 0.3)

    for column, (title, image) in enumerate(images):
      ax = fig.add_subplot(grid[0, column])
      if image is not None:
        ax.imshow(image)
      else:
        ax.text(0.5, 0.5, 'Image not found',
                horizontalalignment = 'center', verticalalignment = 'center')
      ax.set_title(title)
      ax.axis('off')

    # Display evaluation metrics from csv for both baseline and depth-aware style transfers
    metric_panels = [
        (fig.add_subplot(grid[1, :2]), 'Baseline', 'BASELINE METRICS'),
        (fig.add_subplot(grid[1, 2:]), 'Depth-aware', 'DEPTH-AWARE METRICS'),
    ]
    for ax, variant, title in metric_panels:
      ax.axis('off')
      if eval_df.empty:
        continue

      matches = eval_df[
          (eval_df['content'] == content_base_name) &
          (eval_df['style'] == style_base_name) &
          (eval_df['variant'] == variant)
      ]
      if matches.empty:
        # Say so explicitly instead of reusing text from the other panel
        metrics_text = f"{title} \n" + "=" * 50 + "\n\nNo metrics available for this pair.\n"
      else:
        metrics_text = _format_metrics(title, matches.iloc[0])

      ax.text(0.05, 0.95, metrics_text, fontsize = 10, verticalalignment = 'top',
              bbox = dict(boxstyle = 'round', facecolor = 'lightblue', alpha = 0.7))

    fig.suptitle(f"{content_name} + {style_name}", fontsize = 14)
    plt.show()



# Redraw the comparison whenever either selection changes
content_dropdown.observe(display_comparison, names = 'value')
style_dropdown.observe(display_comparison, names = 'value')

# Lay out the viewer: heading, the two selectors, then the figure/output area
display(widgets.VBox([
    widgets.HTML("<h2>Depth-Aware Style Transfer Results Viewer</h2>"),
    content_dropdown,
    style_dropdown,
    output
]))

# Render the initial selection; observers only fire on later changes
display_comparison()








VBox(children=(HTML(value='<h2>Depth-Aware Style Transfer Results Viewer'), Dropdown(description='Content Imag…