# In [16]:
import cv2
import torch
import numpy as np
import torch.nn as nn
from torch import nn
from torchvision import models
import torchvision.transforms as transforms
from torchvision import transforms
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
from PIL import Image, ImageTk, ImageFilter


# Function to upload an image
def upload_image():
    """Ask the user for an image file and show it in the "original" canvas.

    On success the file is decoded eagerly, stored in the module-level
    ``original_image`` and displayed. Any previous processed result is
    discarded so that "Save Image" cannot write output derived from an
    earlier picture. Files that Pillow cannot open or decode are reported
    with an error dialog instead of raising inside the Tk callback.
    """
    global original_image, processed_image
    file_path = filedialog.askopenfilename()
    if not file_path:
        # Dialog was cancelled.
        return

    try:
        image = Image.open(file_path)
        # Decode now so corrupt/truncated files fail here rather than in a
        # later operation; this also releases the file for single-frame images.
        image.load()
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        messagebox.showerror("Error", f"Could not open the selected image: {exc}")
        return

    original_image = image
    # A new source image invalidates the previous processed result.
    processed_image = None
    processed_canvas.delete("all")
    update_original_image(original_image)


# Function to save the processed image
def save_image():
    """Ask for a destination path and write ``processed_image`` there.

    Shows an error dialog when nothing has been processed yet, or when Pillow
    cannot write the file (for example an unknown extension, or an image mode
    the chosen format cannot store).
    """
    if processed_image is None:
        messagebox.showerror("Error", "No processed image to save. Apply an operation first.")
        return

    file_path = filedialog.asksaveasfilename(defaultextension=".png",
                                             filetypes=[("PNG files", ".png"), ("All files", ".*")])
    if not file_path:
        # Dialog was cancelled.
        return

    try:
        processed_image.save(file_path)
    except (OSError, ValueError) as exc:
        messagebox.showerror("Error", f"Could not save the image: {exc}")



# Function to resize the image
def resize_image():
    """Resize ``original_image`` to the ``WIDTHxHEIGHT`` text chosen in ``resize_combobox``.

    Uses Lanczos resampling. The result becomes ``processed_image`` and is
    shown in the processed canvas. Malformed sizes, or sizes Pillow rejects
    with ValueError, produce an "Invalid Input" dialog.
    """
    global processed_image, displayed_image

    if original_image is None:
        messagebox.showerror("Error", "Please upload an image before resizing.")
        return

    requested = resize_combobox.get()
    if not requested:
        messagebox.showerror("Invalid Input", "Please select a size to resize the image.")
        return

    try:
        new_w, new_h = (int(part) for part in requested.split('x'))
        processed_image = original_image.resize((new_w, new_h), Image.LANCZOS)
        update_processed_image(processed_image)
    except ValueError:
        messagebox.showerror("Invalid Input", "Please select a valid size from the dropdown list.")


# Function to rotate the image
def rotate_image():
    """Rotate ``original_image`` by the integer number of degrees in ``angle_entry``.

    ``expand=True`` enlarges the output so the rotated corners are kept. The
    result becomes ``processed_image`` and is displayed. Empty or
    non-integer input produces an error dialog.
    """
    global processed_image, displayed_image
    if original_image is None:
        messagebox.showerror("Error", "Please upload an image before rotating.")
        return

    raw_angle = angle_entry.get()
    if not raw_angle:
        messagebox.showerror("Invalid Input", "Please enter an angle to rotate the image.")
        return

    try:
        degrees = int(raw_angle)
        processed_image = original_image.rotate(degrees, expand=True)
        update_processed_image(processed_image)
    except ValueError:
        messagebox.showerror("Invalid Input", "Please enter a valid integer for angle.")


# Function to crop the image
def crop_image():
    """Crop ``original_image`` to the box typed into ``crop_entry``.

    The entry must hold four comma-separated integers ``left, top, right,
    bottom`` in pixel coordinates. The box must have a positive width and
    height (``left < right`` and ``top < bottom``); otherwise an error dialog
    is shown. The result becomes ``processed_image`` and is displayed.
    """
    global processed_image, displayed_image

    if original_image is None:
        messagebox.showerror("Error", "Please upload an image before cropping.")
        return

    crop_values = crop_entry.get()
    if not crop_values:
        messagebox.showerror("Invalid Input", "Please enter crop values.")
        return

    try:
        # Parse crop values and convert them to integers
        left, top, right, bottom = map(int, crop_values.split(','))
    except ValueError:
        messagebox.showerror("Invalid Input",
                             "Please enter four integers separated by commas (e.g., 10, 10, 100, 100).")
        return

    # An inverted box makes Pillow raise, and an empty one yields a zero-sized
    # image that cannot be displayed, so reject both with a specific message.
    if right <= left or bottom <= top:
        messagebox.showerror("Invalid Input",
                             "The crop box must satisfy left < right and top < bottom.")
        return

    processed_image = original_image.crop((left, top, right, bottom))
    # Update the processed image display to show the cropped image
    update_processed_image(processed_image)


# Function to show image properties
def show_image_properties():
    """Show the format, size and mode of ``original_image`` in an info dialog.

    Shows an error dialog instead when no image has been uploaded.
    """
    if not original_image:
        messagebox.showerror("Error", "No image uploaded to show properties.")
        return

    details = "\n".join([
        f"Format: {original_image.format}",
        f"Size: {original_image.size}",
        f"Mode: {original_image.mode}",
    ])
    messagebox.showinfo("Image Properties", details)


# Function to update the original image display
def update_original_image(image):
    """Draw ``image`` at the top-left of ``original_canvas`` and refresh its scroll region.

    Tk keeps no Python reference to a ``PhotoImage``: once the last Python
    reference disappears, the canvas shows an empty area. The photo is
    therefore pinned on the canvas widget itself, not only in the shared
    global ``displayed_image``, which other display updates overwrite.
    Earlier canvas items are deleted so successive images do not stack.
    """
    global displayed_image
    photo = ImageTk.PhotoImage(image)
    original_canvas.delete("all")
    original_canvas.create_image(0, 0, anchor="nw", image=photo)
    original_canvas.image = photo  # keep the PhotoImage alive for Tk
    displayed_image = photo  # global kept for backward compatibility
    original_canvas.config(scrollregion=original_canvas.bbox(tk.ALL))


# Function to update the processed image display
def update_processed_image(image):
    """Draw ``image`` at the top-left of ``processed_canvas`` and refresh its scroll region.

    The ``PhotoImage`` is pinned on the canvas widget rather than stored in
    the shared global ``displayed_image``. Overwriting that global could drop
    the last reference to the original preview's photo and blank the original
    canvas, because Tk itself holds no Python reference to it. Earlier items
    are deleted so a smaller result does not sit on top of a larger previous one.
    """
    photo = ImageTk.PhotoImage(image)
    processed_canvas.delete("all")
    processed_canvas.create_image(0, 0, anchor="nw", image=photo)
    processed_canvas.image = photo  # keep the PhotoImage alive for Tk
    processed_canvas.config(scrollregion=processed_canvas.bbox(tk.ALL))


# Module-level state read by the image callbacks. ``original_image`` and
# ``processed_image`` are bound to None again in the "Global Variables"
# section near the end of the file. ``filter_entry`` is only used by
# ``apply_image_filter`` and is never bound to a widget in the visible code.
original_image = None  
processed_image = None
filter_entry = None

def apply_image_filter():
    """Apply the PIL filter named in ``filter_entry`` to ``original_image``.

    Recognised names are BLUR, CONTOUR, DETAIL, EDGE_ENHANCE, EMBOSS and
    SHARPEN. The result becomes ``processed_image`` and is displayed. All
    problems are reported in dialogs rather than raised.

    NOTE(review): in the visible code no widget calls this function and
    ``filter_entry`` stays ``None``. The GUI uses ``apply_filter`` with
    ``filter_combobox`` instead; consider removing one of the two.
    """
    global processed_image, original_image, filter_entry

    if original_image is None:
        messagebox.showerror("Error", "Please upload an image before applying a filter.")
        return

    if filter_entry is None:
        # Without this guard ``None.get()`` raises AttributeError in the callback.
        messagebox.showerror("Error", "No filter input is available.")
        return

    selected_filter = filter_entry.get().strip()
    if not selected_filter:
        messagebox.showerror("Invalid Input", "Please select a filter to apply.")
        return

    # Map filter names to PIL ImageFilter attributes
    filter_map = {
        "BLUR": ImageFilter.BLUR,
        "CONTOUR": ImageFilter.CONTOUR,
        "DETAIL": ImageFilter.DETAIL,
        "EDGE_ENHANCE": ImageFilter.EDGE_ENHANCE,
        "EMBOSS": ImageFilter.EMBOSS,
        "SHARPEN": ImageFilter.SHARPEN,
    }

    image_filter = filter_map.get(selected_filter)
    if image_filter is None:
        messagebox.showerror("Invalid Filter", "Please select a valid filter.")
        return

    try:
        # PIL images expose ``filter()``; the previous ``filter_entry()`` call
        # does not exist on Image and always raised AttributeError.
        processed_image = original_image.filter(image_filter)
        update_processed_image(processed_image)
    except Exception as e:
        messagebox.showerror("Error", f"An error occurred while applying the filter: {str(e)}")
        
# Function to apply filter to the image
def apply_filter():
    """Apply the PIL filter selected in ``filter_combobox`` to ``original_image``.

    The combobox is not created read-only, so it may contain arbitrary text.
    The stripped text is looked up in the known filter names. An unknown name
    produces an error dialog instead of re-displaying a stale (or ``None``)
    ``processed_image``. The result becomes ``processed_image`` and is displayed.
    """
    global processed_image, displayed_image  # global variables
    if original_image is None:
        messagebox.showerror("Error", "Please upload an image before applying a filter.")
        return

    selected_filter = filter_combobox.get().strip()
    if not selected_filter:
        messagebox.showerror("Invalid Input", "Please select a filter to apply.")
        return

    filters = {
        "BLUR": ImageFilter.BLUR,
        "CONTOUR": ImageFilter.CONTOUR,
        "DETAIL": ImageFilter.DETAIL,
        "EDGE_ENHANCE": ImageFilter.EDGE_ENHANCE,
        "EMBOSS": ImageFilter.EMBOSS,
        "SHARPEN": ImageFilter.SHARPEN,
    }
    image_filter = filters.get(selected_filter)
    if image_filter is None:
        messagebox.showerror("Invalid Filter", "Please select a valid filter.")
        return

    try:
        processed_image = original_image.filter(image_filter)
        update_processed_image(processed_image)
    except Exception as e:
        messagebox.showerror("Error", f"An error occurred while applying the filter: {str(e)}")


from PIL import ImageEnhance


# Intensity manipulation: scale brightness by a user-supplied factor.

intensity_entry = None


def adjust_intensity():
    """Scale the brightness of ``original_image`` by the factor in ``intensity_entry``.

    Uses ``ImageEnhance.Brightness``. The enhanced copy becomes
    ``processed_image`` and is displayed. Non-numeric input (ValueError) and
    any other error while enhancing or displaying are reported in dialogs.
    """
    global processed_image, displayed_image, original_image, intensity_entry

    if original_image is None:
        messagebox.showerror("Error", "Please upload an image before adjusting intensity.")
        return

    raw = intensity_entry.get().strip()
    if not raw:
        messagebox.showerror("Invalid Input", "Please enter an intensity value.")
        return

    try:
        factor = float(raw)
        processed_image = ImageEnhance.Brightness(original_image).enhance(factor)
        update_processed_image(processed_image)
    except ValueError:
        messagebox.showerror("Invalid Input", "Please enter a valid numeric value for intensity.")
    except Exception as e:
        messagebox.showerror("Error", f"An error occurred: {str(e)}")


# Tonal transformation: scale contrast by a user-supplied factor.

tonal_entry = None


def adjust_tonal_range():
    """Scale the contrast of ``original_image`` by the factor in ``tonal_entry``.

    Uses ``ImageEnhance.Contrast``. The enhanced copy becomes
    ``processed_image`` and is displayed. Non-numeric input (ValueError) and
    any other error while enhancing or displaying are reported in dialogs.
    """
    global processed_image, displayed_image, original_image, tonal_entry

    if original_image is None:
        messagebox.showerror("Error", "Please upload an image before adjusting tonal range.")
        return

    raw = tonal_entry.get().strip()
    if not raw:
        messagebox.showerror("Invalid Input", "Please enter a tonal value.")
        return

    try:
        factor = float(raw)
        processed_image = ImageEnhance.Contrast(original_image).enhance(factor)
        update_processed_image(processed_image)
    except ValueError:
        messagebox.showerror("Invalid Input", "Please enter a valid numeric value for tonal adjustment.")
    except Exception as e:
        messagebox.showerror("Error", f"An error occurred: {str(e)}")



# Color balancing: scale colour saturation by a user-supplied factor.
color_balance_entry = None


def adjust_color_balance():
    """Scale the colour balance of ``original_image`` by the factor in ``color_balance_entry``.

    Uses ``ImageEnhance.Color``. The enhanced copy becomes ``processed_image``
    and is displayed. Non-numeric input (ValueError) and any other error while
    enhancing or displaying are reported in dialogs.
    """
    global processed_image, displayed_image, original_image, color_balance_entry

    if original_image is None:
        messagebox.showerror("Error", "Please upload an image before adjusting color balance.")
        return

    raw = color_balance_entry.get().strip()
    if not raw:
        messagebox.showerror("Invalid Input", "Please enter a color value.")
        return

    try:
        factor = float(raw)
        processed_image = ImageEnhance.Color(original_image).enhance(factor)
        update_processed_image(processed_image)
    except ValueError:
        messagebox.showerror("Invalid Input", "Please enter a valid numeric value for color balance.")
    except Exception as e:
        messagebox.showerror("Error", f"An error occurred: {str(e)}")

# Function for Image Segmentation
def apply_segmentation():
    """Binarise ``original_image`` using the threshold from ``threshold_slider``.

    The image is reduced to single-channel 8-bit grayscale and passed to
    ``cv2.threshold`` with ``THRESH_BINARY`` and a max value of 255. The
    binary result becomes ``processed_image`` and is displayed. Image modes
    Pillow cannot convert to RGB are reported in an error dialog.
    """
    global processed_image, displayed_image
    if original_image is None:
        messagebox.showerror("Error", "Please upload an image before applying segmentation.")
        return

    # np.array() of modes other than "L"/"RGB" can give palette indices ("P"),
    # booleans ("1"), two channels ("LA") or non-8-bit data ("I", "F"), which
    # OpenCV either rejects or thresholds meaninglessly. Normalise those (and
    # RGBA/CMYK) to 3-channel RGB first.
    source = original_image
    if source.mode not in ("L", "RGB"):
        try:
            source = source.convert("RGB")
        except ValueError as exc:
            messagebox.showerror("Error", f"Cannot segment an image of mode {original_image.mode}: {exc}")
            return

    # Convert PIL image to OpenCV format: H x W for "L", H x W x 3 for "RGB"
    open_cv_image = np.array(source)

    # Convert image to grayscale
    if open_cv_image.ndim == 3:
        gray_image = cv2.cvtColor(open_cv_image, cv2.COLOR_RGB2GRAY)
    else:
        gray_image = open_cv_image

    # Get threshold value from slider
    threshold_value = int(threshold_slider.get())

    # Apply binary threshold
    _, segmented_image = cv2.threshold(gray_image, threshold_value, 255, cv2.THRESH_BINARY)

    # Convert back to PIL image and update
    processed_image = Image.fromarray(segmented_image)
    update_processed_image(processed_image)


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a ReLU between them, added back onto the input.

    Both convolutions keep ``in_channels`` channels and use padding 1, so
    ``forward`` returns a tensor shaped like its input. The attribute names
    ``conv1``/``relu``/``conv2`` define the ``state_dict`` keys and must stay
    stable for checkpoint loading.
    """

    def __init__(self, in_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        # Keep this creation order: parameter init draws from the global RNG in sequence.
        self.conv1 = nn.Conv2d(in_channels, in_channels, **conv_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels, in_channels, **conv_kwargs)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        return x + residual

class ESRGANModel(nn.Module):
    """Convolutional stack used as the "ESRGAN" enhancer in this app.

    Layout: ``conv_first`` (3->64), then 16 ``ResidualBlock(64)`` in
    ``RRDB_trunk``, then ``trunk_conv``, ``upconv1``, ``upconv2`` and
    ``final_conv`` (64->3). Every top-level convolution is 3x3 with padding 1,
    so the output has the same height and width as the input.

    NOTE(review): despite the names, nothing here upsamples, and there is no
    activation between the top-level convolutions. The trunk is made of simple
    residual blocks, so reference ESRGAN checkpoints are unlikely to match
    these parameter names. Confirm the intended checkpoint format.
    """

    def __init__(self):
        super().__init__()
        width = 64
        # Attribute names are state_dict keys; creation order fixes RNG consumption.
        self.conv_first = nn.Conv2d(3, width, kernel_size=3, padding=1)
        self.RRDB_trunk = nn.Sequential(*(ResidualBlock(width) for _ in range(16)))
        self.trunk_conv = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.upconv1 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.upconv2 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.final_conv = nn.Conv2d(width, 3, kernel_size=3, padding=1)

    def forward(self, x):
        stages = (self.conv_first, self.RRDB_trunk, self.trunk_conv,
                  self.upconv1, self.upconv2, self.final_conv)
        for stage in stages:
            x = stage(x)
        return x

# Load ESRGAN model
def load_esrgan_model():
    """Ask for a ``.pth`` state-dict file and load it into a new ``ESRGANModel``.

    The global ``model`` is replaced only after the weights load without an
    exception, so a failed load keeps any previously loaded model. Loading is
    non-strict. When checkpoint keys do not line up with the model's
    parameters, the mismatch counts are shown so an effectively empty load
    is visible.

    NOTE(review): ``torch.load`` unpickles the file and can execute arbitrary
    code. Only open trusted checkpoints; on torch >= 1.13 consider
    ``weights_only=True``.
    """
    global model
    model_path = filedialog.askopenfilename(title="Select ESRGAN model", filetypes=[("Pytorch Model", "*.pth")])
    if not model_path:
        return

    try:
        candidate = ESRGANModel()  # Instantiate the ESRGAN model
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        load_result = candidate.load_state_dict(state_dict, strict=False)
    except Exception as e:
        messagebox.showerror("Error", f"Could not load the ESRGAN model: {str(e)}")
        return

    candidate.eval()
    model = candidate

    if load_result.missing_keys or load_result.unexpected_keys:
        messagebox.showwarning(
            "Model Loaded With Mismatches",
            f"ESRGAN model loaded, but {len(load_result.missing_keys)} model parameter(s) "
            f"were missing from the checkpoint and {len(load_result.unexpected_keys)} "
            "checkpoint key(s) were ignored.",
        )
    else:
        messagebox.showinfo("Model Loaded", "ESRGAN model successfully loaded!")

# Apply deep learning enhancement (ESRGAN)
def enhance_image_with_deep_learning():
    """Run ``original_image`` through the loaded ESRGAN ``model`` and display the result.

    The image is converted to RGB first because ``ESRGANModel``'s first
    convolution takes 3 input channels; grayscale, palette or RGBA uploads
    would not match. The output is clamped to [0, 1] before conversion back to
    a PIL image. It becomes ``processed_image`` and is displayed. Inference
    errors are reported in a dialog.
    """
    global processed_image, original_image

    if original_image is None:
        messagebox.showerror("Error", "Please upload an image before enhancing.")
        return

    # ``model`` is only bound by load_esrgan_model, so look it up without
    # risking NameError when no model has been loaded yet.
    current_model = globals().get("model")
    if current_model is None:
        messagebox.showerror("Error", "Please load a deep learning model first.")
        return

    # Convert PIL Image to Torch Tensor and add a batch dimension
    input_image = transforms.ToTensor()(original_image.convert("RGB")).unsqueeze(0)

    try:
        with torch.no_grad():
            enhanced_image_tensor = current_model(input_image).clamp(0, 1)
    except RuntimeError as e:
        messagebox.showerror("Error", f"An error occurred while enhancing the image: {str(e)}")
        return

    # Convert the output tensor back to PIL image
    processed_image = transforms.ToPILImage()(enhanced_image_tensor.squeeze(0))
    update_processed_image(processed_image)  # Update display with the enhanced image
    
# Style transfer model
class YourStyleTransferModel(nn.Module):
    """Placeholder style-transfer network: two 3x3 convolutions (3 -> 64 -> 3).

    Padding 1 keeps the spatial size. ``conv1`` and ``final_conv`` define the
    ``state_dict`` keys a checkpoint must provide.

    NOTE(review): there is no activation between the two convolutions, so the
    network is one linear filter. A real style-transfer checkpoint likely has
    more layers; confirm the intended architecture.
    """

    def __init__(self):
        super().__init__()
        hidden = 64
        self.conv1 = nn.Conv2d(3, hidden, kernel_size=3, padding=1)
        self.final_conv = nn.Conv2d(hidden, 3, kernel_size=3, padding=1)  # 3 output channels

    def forward(self, x):
        return self.final_conv(self.conv1(x))

# Load the model
def load_style_transfer_model():
    """Ask for a ``.pth`` state-dict file and load it into a new ``YourStyleTransferModel``.

    The global ``style_model`` is replaced only after the weights load
    without an exception, so a failed load keeps any previously loaded model.
    Loading is non-strict. When checkpoint keys do not line up with the
    model's parameters, the mismatch counts are shown.

    NOTE(review): ``torch.load`` unpickles the file and can execute arbitrary
    code. Only open trusted checkpoints; on torch >= 1.13 consider
    ``weights_only=True``.
    """
    global style_model
    model_path = filedialog.askopenfilename(title="Select Style Transfer Model", filetypes=[("Pytorch Model", "*.pth")])
    if not model_path:
        return

    try:
        candidate = YourStyleTransferModel()  # Initialize your model class
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        load_result = candidate.load_state_dict(state_dict, strict=False)  # Load weights
    except Exception as e:
        messagebox.showerror("Error", f"Could not load the style transfer model: {str(e)}")
        return

    candidate.eval()
    style_model = candidate

    if load_result.missing_keys or load_result.unexpected_keys:
        messagebox.showwarning(
            "Model Loaded With Mismatches",
            f"Style transfer model loaded, but {len(load_result.missing_keys)} model parameter(s) "
            f"were missing from the checkpoint and {len(load_result.unexpected_keys)} "
            "checkpoint key(s) were ignored.",
        )
    else:
        messagebox.showinfo("Model Loaded", "Style transfer model successfully loaded!")

# Function to convert the image to a tensor
def load_image_to_tensor(image, size=(256, 256)):
    """Return ``image`` as a batched float tensor of shape (1, 3, H, W).

    Args:
        image: PIL image of any mode. It is converted to RGB so grayscale,
            palette and RGBA inputs all yield 3 channels, matching the
            3-channel first layer of ``YourStyleTransferModel``.
        size: ``(height, width)`` passed to ``transforms.Resize``. It defaults
            to the previously hard-coded 256x256.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),  # Convert to tensor
    ])
    return transform(image.convert("RGB")).unsqueeze(0)

# Function to apply style transfer
def apply_style_transfer(input_image_tensor):
    """Return the output of the global ``style_model`` for ``input_image_tensor``.

    The forward pass runs under ``torch.no_grad()``, so the result carries no
    autograd history.
    """
    with torch.no_grad():
        return style_model(input_image_tensor)

# Function to handle the apply button
def on_apply_style_transfer():
    """Run the loaded style model on ``original_image`` and show the result.

    The output becomes the global ``processed_image`` and is drawn in the
    processed canvas, like every other operation, so "Save Image" can write
    it. Previously it was a local shown in an external viewer and could not
    be saved. A missing image or model, and inference errors, are reported
    in dialogs.
    """
    global original_image, processed_image
    if original_image is None:
        messagebox.showwarning("No Image", "Please upload an image first.")
        return

    if style_model is None:
        messagebox.showwarning("No Model", "Please load a style transfer model first.")
        return

    try:
        # RGB guarantees the 3 channels the style model's first layer expects.
        input_image_tensor = load_image_to_tensor(original_image.convert("RGB"))
        processed_image_tensor = apply_style_transfer(input_image_tensor)
    except RuntimeError as e:
        messagebox.showerror("Error", f"An error occurred during style transfer: {str(e)}")
        return

    # Post-process the output tensor to a PIL Image and display it
    processed_image = postprocess_output(processed_image_tensor)
    update_processed_image(processed_image)

# Function to post-process the output image tensor
def postprocess_output(output_image_tensor):
    """Convert a model output tensor to a PIL image.

    The tensor is detached, moved to the CPU and clamped to [0, 1] first.
    ``ToPILImage`` scales float tensors by 255 and casts them to 8-bit, so
    out-of-range network outputs would otherwise be corrupted by the cast.
    A leading batch dimension of size 1 is removed.
    """
    output_image = output_image_tensor.detach().cpu().clamp(0, 1)
    output_image = output_image.squeeze(0)  # Remove batch dimension
    return transforms.ToPILImage()(output_image)  # Convert to PIL Image

# Initialize main window
root = tk.Tk()
root.title("Modern Image Processing App")
root.geometry("900x700")
root.configure(bg="#f0f0f0")

# Shared ttk style settings, one entry per widget style.
style = ttk.Style()
for _style_name, _style_options in (
    ("TButton", {"font": ("Helvetica", 10), "padding": 10}),
    ("TLabel", {"font": ("Helvetica", 10), "background": "#f0f0f0"}),
    ("TFrame", {"background": "#f0f0f0"}),
    ("TCombobox", {"font": ("Helvetica", 10)}),
):
    style.configure(_style_name, **_style_options)

# Create frames for original and processed images
original_frame = ttk.Frame(root, width=400, height=400, padding=10, relief=tk.GROOVE)
original_frame.grid(row=1, column=0, padx=10, pady=10, sticky="nsew")

processed_frame = ttk.Frame(root, width=400, height=400, padding=10, relief=tk.GROOVE)
processed_frame.grid(row=1, column=1, padx=10, pady=10, sticky="nsew")

# The packer hands out space in pack() order. If the canvas is packed first
# (side=LEFT, expand), it claims the full frame height and the horizontal
# scrollbar is squeezed into the leftover gap beside it instead of sitting
# underneath. Packing the scrollbars first also keeps them visible when the
# window shrinks.

# Original Image Canvas and Scrollbars
original_canvas = tk.Canvas(original_frame, width=380, height=380, bg="white")
original_v_scrollbar = ttk.Scrollbar(original_frame, orient="vertical", command=original_canvas.yview)
original_h_scrollbar = ttk.Scrollbar(original_frame, orient="horizontal", command=original_canvas.xview)
original_canvas.config(yscrollcommand=original_v_scrollbar.set, xscrollcommand=original_h_scrollbar.set)

original_h_scrollbar.pack(side=tk.BOTTOM, fill=tk.X)
original_v_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
original_canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

# Processed Image Canvas and Scrollbars
processed_canvas = tk.Canvas(processed_frame, width=380, height=380, bg="white")
processed_v_scrollbar = ttk.Scrollbar(processed_frame, orient="vertical", command=processed_canvas.yview)
processed_h_scrollbar = ttk.Scrollbar(processed_frame, orient="horizontal", command=processed_canvas.xview)
processed_canvas.config(yscrollcommand=processed_v_scrollbar.set, xscrollcommand=processed_h_scrollbar.set)

processed_h_scrollbar.pack(side=tk.BOTTOM, fill=tk.X)
processed_v_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
processed_canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

# ---- Control panel ----
# All controls are placed in ``root`` with grid(): rows 0 and 2-12 hold
# controls, while row 1 holds the two image panels. The entry, combobox and
# slider names below are module-level globals read by the callbacks
# (resize_combobox, angle_entry, crop_entry, filter_combobox, intensity_entry,
# tonal_entry, color_balance_entry, threshold_slider), so they must not be renamed.

# Upload Image Button
upload_button = ttk.Button(root, text="Upload Image", command=upload_image)
upload_button.grid(row=0, column=0, padx=10, pady=10, sticky="ew")

# Save Image Button
save_button = ttk.Button(root, text="Save Image", command=save_image)
save_button.grid(row=0, column=1, padx=10, pady=10, sticky="ew")

# Image Properties Button
properties_button = ttk.Button(root, text="Image Properties", command=show_image_properties)
properties_button.grid(row=2, column=0, padx=10, pady=10, sticky="ew")

# Resize Image Combobox: entries are "WIDTHxHEIGHT" strings parsed by resize_image
resize_label = ttk.Label(root, text="Resize Image:")
resize_label.grid(row=3, column=0, padx=10, pady=5, sticky="w")
resize_combobox = ttk.Combobox(root, values=["100x100", "200x200", "300x300"])
resize_combobox.grid(row=3, column=1, padx=10, pady=5, sticky="ew")
resize_button = ttk.Button(root, text="Apply", command=resize_image)
resize_button.grid(row=3, column=2, padx=10, pady=5, sticky="ew")

# Rotate Image Entry and Button
angle_label = ttk.Label(root, text="Angle:")
angle_label.grid(row=4, column=0, padx=10, pady=5, sticky="w")
angle_entry = ttk.Entry(root)
angle_entry.grid(row=4, column=1, padx=10, pady=5, sticky="ew")
rotate_button = ttk.Button(root, text="Rotate Image", command=rotate_image)
rotate_button.grid(row=4, column=2, padx=10, pady=5, sticky="ew")

# Crop Image Entry and Button
crop_label = ttk.Label(root, text="Crop (left, top, right, bottom):")
crop_label.grid(row=5, column=0, padx=10, pady=5, sticky="w")
crop_entry = ttk.Entry(root)
crop_entry.grid(row=5, column=1, padx=10, pady=5, sticky="ew")
crop_button = ttk.Button(root, text="Crop Image", command=crop_image)
crop_button.grid(row=5, column=2, padx=10, pady=5, sticky="ew")

# Filter dropdown and button
# NOTE(review): the combobox is not created with state="readonly", so users
# can type names outside this list; apply_filter has to validate them.
filter_label = ttk.Label(root, text="Apply Filter:")
filter_label.grid(row=6, column=0, padx=10, pady=5, sticky="w")

filter_combobox = ttk.Combobox(root, values=["BLUR", "CONTOUR", "DETAIL", "EDGE_ENHANCE", "EMBOSS", "SHARPEN"])
filter_combobox.grid(row=6, column=1, padx=10, pady=5, sticky="ew")

filter_button = ttk.Button(root, text="Apply Filter", command=apply_filter)
filter_button.grid(row=6, column=2, padx=10, pady=5, sticky="ew")

# Intensity Manipulation Entry and Button
# NOTE(review): the ranges shown in this and the next two labels are hints
# only; the handlers in this file parse the value with float() without range checks.
intensity_label = ttk.Label(root, text="Intensity (0.0-6.0):")
intensity_label.grid(row=7, column=0, padx=10, pady=5, sticky="w")
intensity_entry = ttk.Entry(root)
intensity_entry.grid(row=7, column=1, padx=10, pady=5, sticky="ew")
intensity_button = ttk.Button(root, text="Adjust Intensity", command=adjust_intensity)
intensity_button.grid(row=7, column=2, padx=10, pady=5, sticky="ew")

# Tonal Transformation Entry and Button
tonal_label = ttk.Label(root, text="Tonal Range (0.0-2.0):")
tonal_label.grid(row=8, column=0, padx=10, pady=5, sticky="w")
tonal_entry = ttk.Entry(root)
tonal_entry.grid(row=8, column=1, padx=10, pady=5, sticky="ew")
tonal_button = ttk.Button(root, text="Adjust Tonal Range", command=adjust_tonal_range)
tonal_button.grid(row=8, column=2, padx=10, pady=5, sticky="ew")

# Color Balancing Entry and Button
color_label = ttk.Label(root, text="Color Balance (0.0-2.0):")
color_label.grid(row=9, column=0, padx=10, pady=5, sticky="w")
color_balance_entry = ttk.Entry(root)
color_balance_entry.grid(row=9, column=1, padx=10, pady=5, sticky="ew")
color_button = ttk.Button(root, text="Adjust Color Balance", command=adjust_color_balance)
color_button.grid(row=9, column=2, padx=10, pady=5, sticky="ew")

# Segmentation label and threshold slider
# NOTE(review): no initial ``value`` is passed, so the threshold starts at the
# widget default. ``to_`` presumably maps to Tk's ``to`` option the same way
# ``from_`` does; confirm.
segmentation_label = ttk.Label(root, text="Image Segmentation:")
segmentation_label.grid(row=10, column=0, padx=10, pady=5, sticky="w")

threshold_slider = ttk.Scale(root, from_=0, to_=255, orient="horizontal")
threshold_slider.grid(row=10, column=1, padx=10, pady=5, sticky="ew")

segmentation_button = ttk.Button(root, text="Apply", command=apply_segmentation)
segmentation_button.grid(row=10, column=2, padx=10, pady=5, sticky="ew")

# Add button for loading the deep learning model (weights go into ESRGANModel)
load_model_button = ttk.Button(root, text="Load Deep Learning Model", command=load_esrgan_model)
load_model_button.grid(row=11, column=0, padx=10, pady=10, sticky="ew")

# Add button to enhance the image using the deep learning model
enhance_button = ttk.Button(root, text="Enhance Image (Deep Learning)", command=enhance_image_with_deep_learning)
enhance_button.grid(row=11, column=1, padx=10, pady=10, sticky="ew")


# Load Style Transfer Model Button (weights go into YourStyleTransferModel)
load_style_button = ttk.Button(root, text="Load Style Transfer Model", command=load_style_transfer_model)
load_style_button.grid(row=12, column=0, padx=10, pady=10, sticky="ew")

# Apply Style Transfer Button
apply_style_button = ttk.Button(root, text="Apply Style Transfer", command=on_apply_style_transfer)
apply_style_button.grid(row=12, column=1, padx=10, pady=10, sticky="ew")


# Configure column and row resizing: the three columns share extra width
# equally, and only row 1 (the two image panels) grows vertically.
for _column_index in range(3):
    root.grid_columnconfigure(_column_index, weight=1)
root.grid_rowconfigure(1, weight=1)

# Global Variables
# These run after the widgets are built and before mainloop(), so the app
# starts with nothing loaded. ``img_tensor`` is not read anywhere in the
# visible code.
original_image = None
processed_image = None
displayed_image = None
img_tensor = None
style_model = None
# ESRGAN network set by load_esrgan_model. Without this initial binding,
# the "model is None" check in enhance_image_with_deep_learning raises
# NameError when no model has been loaded yet.
model = None

# Run the application (blocks until the main window is closed)
root.mainloop()
