In [1]:
import cv2
import numpy as np
import math
from skimage.segmentation import slic
from skimage.measure import label, regionprops
from skimage import color as skcolor
from sklearn.cluster import KMeans

##############################################################################
# Utility Functions
##############################################################################

def rgb_to_hsv_np(img):
    """Return the HSV version of an RGB image via scikit-image.

    Input and output channels are expected in [0, 1]. The input is promoted
    to float64 before conversion.
    """
    as_float64 = img.astype(np.float64)
    hsv = skcolor.rgb2hsv(as_float64)
    return hsv

def hsv_to_rgb_np(img):
    """Return the RGB version of an HSV image (channels in [0, 1]) via scikit-image."""
    rgb = skcolor.hsv2rgb(img)
    return rgb

def bgr_to_lab(img_bgr):
    """Map a BGR image to OpenCV's Lab encoding.

    For 8-bit input, OpenCV scales L to [0, 255] and offsets a/b by 128, so
    every channel of the result lies in [0, 255].
    """
    conversion = cv2.COLOR_BGR2LAB
    return cv2.cvtColor(img_bgr, conversion)

def lab_to_bgr(img_lab):
    """Map an OpenCV-encoded Lab image (8-bit: [0, 255] per channel) back to BGR."""
    conversion = cv2.COLOR_LAB2BGR
    return cv2.cvtColor(img_lab, conversion)

def delta_e94(lab1, lab2):
    """
    Return an approximate CIE94 colour difference between two Lab triples.

    ``lab1`` and ``lab2`` are ``(L, a, b)`` sequences. The textile weights
    (kL = 2, K1 = 0.048, K2 = 0.014) are used. The chroma C1 of ``lab1``
    serves as the reference chroma in both weighting terms:

        dE = sqrt((dL/kL)^2 + (dC/(1 + K1*C1))^2 + (dH/(1 + K2*C1))^2)

    a/b must be signed CIELAB values for the chroma terms to be meaningful.
    """
    l_ref, a_ref, b_ref = lab1
    l_cmp, a_cmp, b_cmp = lab2

    k_l, k_1, k_2 = 2.0, 0.048, 0.014

    chroma_ref = math.sqrt(a_ref * a_ref + b_ref * b_ref)
    chroma_cmp = math.sqrt(a_cmp * a_cmp + b_cmp * b_cmp)

    delta_l = l_ref - l_cmp
    delta_c = chroma_ref - chroma_cmp
    delta_a = a_ref - a_cmp
    delta_b = b_ref - b_cmp

    # dH^2 = da^2 + db^2 - dC^2. Rounding can make it slightly negative;
    # in that case it is treated as zero.
    ab_sq = delta_a * delta_a + delta_b * delta_b
    c_sq = delta_c * delta_c
    if c_sq <= ab_sq:
        delta_h_sq = ab_sq - c_sq
    else:
        delta_h_sq = 0.0

    s_c = 1 + k_1 * chroma_ref
    s_h = 1 + k_2 * chroma_ref

    term_l = (delta_l / k_l) ** 2
    term_c = (delta_c / s_c) ** 2
    term_h = (math.sqrt(abs(delta_h_sq)) / s_h) ** 2
    return math.sqrt(term_l + term_c + term_h)

def compute_saliency_opencv(img_bgr):
    """
    Compute a per-pixel saliency map with OpenCV's spectral-residual detector.

    The detector lives in ``cv2.saliency``, which only opencv-contrib ships.
    The normalised grayscale image is returned instead when:
      * that module is unavailable, or
      * ``computeSaliency`` reports failure.
    Either way, callers always receive a float32 map.

    Args:
        img_bgr: BGR image as loaded by ``cv2.imread``.

    Returns:
        float32 array with one value per pixel.
    """
    def _grayscale_fallback():
        gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
        return gray.astype(np.float32) / 255.0

    saliency_module = getattr(cv2, "saliency", None)
    if saliency_module is None:
        # Plain opencv-python has no saliency module. Use the same fallback
        # as the failure path instead of crashing with AttributeError.
        return _grayscale_fallback()

    detector = saliency_module.StaticSaliencySpectralResidual_create()
    success, saliency_map = detector.computeSaliency(img_bgr)
    if not success:
        return _grayscale_fallback()
    return saliency_map.astype(np.float32)

def guided_filter_color(I, p, radius=5, eps=1e-3):
    """
    Edge-preserving smoothing of ``p`` guided by the image ``I``.

    Uses ``cv2.ximgproc.guidedFilter`` (opencv-contrib) when present. Without
    contrib, it falls back to a bilateral filter on ``p`` alone. That is a
    rough approximation which ignores the guide ``I``.

    Args:
        I: guidance image (converted to float32).
        p: single-channel map to filter (converted to float32).
        radius: filter radius in pixels. The bilateral fallback uses a
            ``2 * radius + 1`` diameter.
        eps: guided-filter regularisation (unused by the fallback).

    Returns:
        The filtered float32 map.
    """
    # Both OpenCV filters accept 8-bit or float32 input. Normalise before
    # branching so the bilateral fallback never receives float64 maps.
    p = np.asarray(p, dtype=np.float32)
    I = np.asarray(I, dtype=np.float32)

    # In contrib builds ximgproc is an attribute of cv2; probe it directly and
    # only then try the submodule import.
    xip = getattr(cv2, "ximgproc", None)
    if xip is None:
        try:
            import cv2.ximgproc as xip
        except ImportError:
            xip = None

    if xip is not None and hasattr(xip, "guidedFilter"):
        return xip.guidedFilter(I, p, radius, eps)

    # Fallback: rough approximation using a bilateral filter on p itself.
    return cv2.bilateralFilter(p, radius * 2 + 1, 75, 75)

##############################################################################
# 1. Dominant Color Estimation (Grid-based in HSV)
##############################################################################

def estimate_dominant_colors(img_bgr, grid_size=8, min_pixels=50):
    """
    Grid-based dominant colour estimation in HSV (simplified Sec. 3.1.1).

    Steps:
      1. Convert the image to OpenCV HSV (H in [0, 179], S and V in [0, 255]).
      2. Build a ``grid_size**3`` histogram over (H, S, V).
      3. Discard bins holding fewer than ``min_pixels`` pixels as outliers.
      4. Merge 26-connected surviving bins. Return the pixel-weighted bin
         centre of each merged group.

    If no bin reaches ``min_pixels`` (e.g. a very small image), the single
    most populated bin is kept so at least one colour is returned.

    Hue is treated as linear, so bins at the two ends of the hue axis are not
    considered neighbours.

    Args:
        img_bgr: 8-bit BGR image.
        grid_size: number of bins per HSV channel.
        min_pixels: minimum pixel count for a bin to survive.

    Returns:
        List of ``[H, S, V]`` float triples, with H in [0, 179] and S/V in
        [0, 255].
    """
    # Imported locally so a notebook cell that rebinds the global name
    # ``label`` cannot break this function.
    from skimage.measure import label as cc_label, regionprops as cc_regionprops

    img_hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)

    # Bin widths per channel (OpenCV 8-bit hue spans 0..179).
    bin_widths = np.array([180 / grid_size, 256 / grid_size, 256 / grid_size])

    # Vectorised binning: floor-divide each channel, clamp into the last bin.
    pixels = img_hsv.reshape(-1, 3)
    bin_idx = np.floor_divide(pixels, bin_widths).astype(np.int64)
    np.minimum(bin_idx, grid_size - 1, out=bin_idx)
    dims = (grid_size, grid_size, grid_size)
    flat_idx = np.ravel_multi_index(tuple(bin_idx.T), dims)
    hist_3d = np.bincount(flat_idx, minlength=grid_size ** 3).reshape(dims)

    # Outlier removal: keep only well-populated bins.
    mask = hist_3d >= min_pixels
    if not mask.any() and hist_3d.any():
        mask[np.unravel_index(np.argmax(hist_3d), dims)] = True

    # Merge 26-connected bins (connectivity=3 in 3-D).
    labeled = cc_label(mask.astype(np.uint8), connectivity=3)
    dominant_hsv = []
    for region in cc_regionprops(labeled):
        coords = region.coords  # (n_bins, 3) rows of (hi, si, vi)
        counts = hist_3d[coords[:, 0], coords[:, 1], coords[:, 2]].astype(np.float64)
        total = counts.sum()
        if total <= 0:
            continue
        centres = (coords + 0.5) * bin_widths  # bin centres in HSV units
        cH, cS, cV = (centres * counts[:, None]).sum(axis=0) / total
        dominant_hsv.append([
            np.clip(cH, 0, 179),
            np.clip(cS, 0, 255),
            np.clip(cV, 0, 255),
        ])
    return dominant_hsv

##############################################################################
# 2. Soft Segmentation via Cost Volume + Guided Filtering
##############################################################################

def _hsv_to_true_lab(hsv):
    """Convert one OpenCV 8-bit HSV triple to a CIELAB float32 triple (L in [0, 100])."""
    hsv_px = np.clip(np.rint(np.asarray(hsv, dtype=np.float64)), 0, 255)
    hsv_px = hsv_px.astype(np.uint8).reshape(1, 1, 3)
    bgr_px = cv2.cvtColor(hsv_px, cv2.COLOR_HSV2BGR).astype(np.float32) / 255.0
    return cv2.cvtColor(bgr_px, cv2.COLOR_BGR2LAB)[0, 0, :]


def _delta_e94_map(lab_img, ref_lab):
    """Vectorised ``delta_e94(pixel, ref_lab)`` for every pixel of an (H, W, 3) Lab image.

    Uses the same textile weights (kL=2, K1=0.048, K2=0.014) as ``delta_e94``.
    As there, the pixel's chroma is the reference chroma.
    """
    L1 = lab_img[..., 0]
    a1 = lab_img[..., 1]
    b1 = lab_img[..., 2]
    L2, a2, b2 = (float(v) for v in ref_lab)

    c1 = np.sqrt(a1 * a1 + b1 * b1)
    c2 = math.hypot(a2, b2)
    d_l = L1 - L2
    d_c = c1 - c2
    d_ab_sq = (a1 - a2) ** 2 + (b1 - b2) ** 2
    # dH^2 = da^2 + db^2 - dC^2, clamped at 0 against rounding.
    d_h_sq = np.maximum(d_ab_sq - d_c * d_c, 0.0)

    s_c = 1.0 + 0.048 * c1
    s_h = 1.0 + 0.014 * c1
    de = np.sqrt((d_l / 2.0) ** 2 + (d_c / s_c) ** 2 + d_h_sq / (s_h * s_h))
    return de.astype(np.float32)


def soft_segmentation(img_bgr, dom_hsv_list):
    """
    Soft-segment an image against a list of dominant colours (Sec. 3.1.2).

    Steps:
      1) Convert the image and each dominant colour to CIELAB.
      2) For each pixel, cost = ΔE94(pixel, dominant colour).
      3) Smooth each cost slice with ``guided_filter_color``, using the image
         as guidance (an approximation of cost-volume filtering).
      4) Turn costs into memberships via a normalised exp(-cost / alpha).

    Colours are converted to true CIELAB (L in [0, 100], signed a/b) via
    OpenCV's float32 path. The 8-bit OpenCV encoding offsets a and b by 128,
    which makes the chroma terms of ΔE94 meaningless.

    Args:
        img_bgr: 8-bit BGR image.
        dom_hsv_list: dominant colours as OpenCV HSV triples
            (H in [0, 179], S/V in [0, 255]).

    Returns:
        float32 array of shape (H, W, len(dom_hsv_list)). Memberships sum to 1
        along the last axis.

    Raises:
        ValueError: if ``dom_hsv_list`` is empty.
    """
    if len(dom_hsv_list) == 0:
        raise ValueError("soft_segmentation needs at least one dominant colour.")

    # Float BGR in [0, 1]: guidance image and input for true CIELAB.
    guide = img_bgr.astype(np.float32) / 255.0
    img_lab = cv2.cvtColor(guide, cv2.COLOR_BGR2LAB)

    costs = []
    for hsv in dom_hsv_list:
        dom_lab = _hsv_to_true_lab(hsv)
        cost_map = _delta_e94_map(img_lab, dom_lab)
        costs.append(guided_filter_color(guide, cost_map, radius=5, eps=1e-3))

    alpha = 10.0  # scale factor controlling softness
    cost_stack = np.stack(costs, axis=-1)
    # Subtracting the per-pixel minimum leaves the normalised result unchanged
    # but stops exp() from underflowing to 0 for every colour at once.
    cost_stack = cost_stack - cost_stack.min(axis=-1, keepdims=True)
    seg_stack = np.exp(-cost_stack / alpha)
    seg_stack /= np.clip(seg_stack.sum(axis=-1, keepdims=True), 1e-8, None)
    return seg_stack

##############################################################################
# 3. Region Matching
##############################################################################

def compute_region_features(img_bgr, seg_stack):
    """
    Summarise each dominant-colour region as (saliency, luminance, area ratio).

    Each pixel is hard-assigned to the region with the largest soft membership
    (winner-take-all). Then, for every region index ``0..C-1``, the function
    reports:
      * the mean saliency from ``compute_saliency_opencv``,
      * the mean OpenCV 8-bit L channel value,
      * the fraction of image pixels in the region.
    A region that receives no pixels contributes ``(0, 0, 0)``.

    Returns:
        list of C tuples ``(mean_saliency, mean_luminance, pixel_ratio)``.
    """
    height, width, n_regions = seg_stack.shape
    hard_labels = np.argmax(seg_stack, axis=-1).astype(np.int32)

    saliency = compute_saliency_opencv(img_bgr)
    luminance = bgr_to_lab(img_bgr)[:, :, 0].astype(np.float32)
    n_pixels = height * width

    features = []
    for region in range(n_regions):
        in_region = hard_labels == region
        count = np.sum(in_region)
        if count >= 1:
            features.append((
                np.mean(saliency[in_region]),
                np.mean(luminance[in_region]),
                float(count) / n_pixels,
            ))
        else:
            features.append((0, 0, 0))  # empty-region fallback
    return features

def match_regions(src_feats, tgt_feats, wS=1.0, wL=1.0, wR=1.0):
    """
    Match every source region to its nearest target region (paper eq. (6)).

    For source region i, the chosen target j minimises

        sqrt(wS*(S_i - S_j)^2 + wL*(L_i - L_j)^2 + wR*(R_i - R_j)^2)

    where S, L and R are saliency, luminance and pixel ratio. Ties go to the
    lowest target index.

    Args:
        src_feats: iterable of ``(S, L, R)`` triples.
        tgt_feats: sequence of ``(S, L, R)`` triples (iterated once per source).
        wS, wL, wR: per-feature weights.

    Returns:
        dict mapping each source index to a target index.

    Raises:
        ValueError: if there is a source region but ``tgt_feats`` is empty,
            so nothing can be matched.
    """
    match_dict = {}
    for i, (sS, sL, sR) in enumerate(src_feats):
        best_j = None
        best_dist = math.inf  # no artificial cap on the distance
        for j, (tS, tL, tR) in enumerate(tgt_feats):
            dist = math.sqrt(
                wS * (sS - tS) ** 2 + wL * (sL - tL) ** 2 + wR * (sR - tR) ** 2
            )
            if best_j is None or dist < best_dist:
                best_dist = dist
                best_j = j
        if best_j is None:
            # A None match would only fail later with an obscure lookup error.
            raise ValueError("Cannot match regions: target has no regions.")
        match_dict[i] = best_j
    return match_dict

##############################################################################
# 4. Local Color Transfer with Modified Reinhard + Local Gamma
##############################################################################

def local_color_transfer(src_bgr, tgt_bgr, src_seg, tgt_seg, match_dict):
    """
    Region-wise colour transfer (simplified eq. (10)-(12) of the paper).

    Every source pixel is hard-assigned to its arg-max region i in
    ``src_seg``. ``match_dict[i]`` names the target region j (an arg-max
    region of ``tgt_seg``) whose statistics it receives:
      * a, b channels: Reinhard mean/std transfer from region i to region j.
      * L channel: local gamma correction, L' = 255 * (L / 255) ** gamma_i.
    The soft memberships are only used for the arg-max; no soft blending is
    performed. All arithmetic happens in OpenCV's 8-bit Lab encoding.

    Args:
        src_bgr, tgt_bgr: 8-bit BGR images.
        src_seg, tgt_seg: soft segmentations of shape (H, W, C_src) and
            (H_t, W_t, C_tgt). The region counts may differ.
        match_dict: maps each source region index present in the image to a
            target region index. A target region with no pixels gets
            neutral statistics (mean 0, std 1).

    Returns:
        uint8 Lab image (OpenCV 8-bit encoding) with the source's height and
        width.
    """
    src_lab = bgr_to_lab(src_bgr).astype(np.float32)
    tgt_lab = bgr_to_lab(tgt_bgr).astype(np.float32)

    # Winner-take-all region labels.
    src_labels = np.argmax(src_seg, axis=-1)
    tgt_labels = np.argmax(tgt_seg, axis=-1)

    def region_stats(img_lab, labels, region_idx):
        """Return (meanL, stdL, meanA, stdA, meanB, stdB) of one region.

        An empty region yields the neutral (0, 1, 0, 1, 0, 1).
        """
        values = img_lab[labels == region_idx]
        if values.shape[0] == 0:
            return (0.0, 1.0, 0.0, 1.0, 0.0, 1.0)
        means = values.mean(axis=0)
        stds = values.std(axis=0) + 1e-6  # guards the std ratios below
        return (float(means[0]), float(stds[0]),
                float(means[1]), float(stds[1]),
                float(means[2]), float(stds[2]))

    # Global mean luminances, clamped away from 0 because they are divisors.
    Ls = max(float(np.mean(src_lab[:, :, 0])), 1e-6)
    Lt = max(float(np.mean(tgt_lab[:, :, 0])), 1e-6)

    # One row of coefficients per source region, gathered per pixel below:
    # [sMeanA, tStdA/sStdA, tMeanA, sMeanB, tStdB/sStdB, tMeanB, gamma]
    coeffs = np.zeros((src_seg.shape[-1], 7), dtype=np.float64)
    tgt_cache = {}
    for ridx in np.unique(src_labels):
        ridx = int(ridx)
        matched = match_dict[ridx]
        if matched not in tgt_cache:
            # Computed on demand, so any target index works even when the
            # target has a different number of regions than the source.
            tgt_cache[matched] = region_stats(tgt_lab, tgt_labels, matched)
        sMeanL, _sStdL, sMeanA, sStdA, sMeanB, sStdB = region_stats(src_lab, src_labels, ridx)
        tMeanL, _tStdL, tMeanA, tStdA, tMeanB, tStdB = tgt_cache[matched]

        # eq. (12), best-guess reading of the paper:
        #   beta_i  = 1 + 2 (mu_li - mu_lj) / 255
        #   alpha_i = exp(|Ls - Lt| * (mu_li / Ls - mu_lj / Lt))
        #   gamma_i = |beta_i + alpha_i (Ls - Lt) / 255|
        # NOTE(review): math.exp raises OverflowError if the exponent is very
        # large (> ~709).
        beta_i = 1.0 + 2.0 * (sMeanL - tMeanL) / 255.0
        alpha_i = math.exp(abs(Ls - Lt) * ((sMeanL / Ls) - (tMeanL / Lt)))
        gamma_i = abs(beta_i + alpha_i * (Ls - Lt) / 255.0)

        coeffs[ridx] = (sMeanA, tStdA / sStdA, tMeanA,
                        sMeanB, tStdB / sStdB, tMeanB, gamma_i)

    per_px = coeffs[src_labels]  # (H, W, 7)
    L_chan = src_lab[..., 0]
    A_chan = src_lab[..., 1]
    B_chan = src_lab[..., 2]

    # eq. (10), (11): Reinhard transfer on a and b.
    a_new = (A_chan - per_px[..., 0]) * per_px[..., 1] + per_px[..., 2]
    b_new = (B_chan - per_px[..., 3]) * per_px[..., 4] + per_px[..., 5]
    # eq. (12): per-region gamma on normalised L.
    L_new = ((L_chan / 255.0) ** per_px[..., 6]) * 255.0

    combined = np.stack([L_new, a_new, b_new], axis=-1)
    return np.clip(combined, 0, 255).astype(np.uint8)

##############################################################################
# Main Pipeline
##############################################################################

def _keep_largest_segments(seg, count):
    """Keep the ``count`` channels of ``seg`` with the largest total membership, in their original order."""
    mass = seg.reshape(-1, seg.shape[-1]).sum(axis=0)
    keep = np.sort(np.argsort(-mass, kind="stable")[:count])
    return seg[:, :, keep]


def local_color_transfer_pipeline(source_path, target_path):
    """
    Full pipeline from the Yoo et al. paper (approximate version).

    1) Estimate dominant colours in both images.
    2) Soft-segment both images.
    3) Match regions by saliency, luminance and pixel ratio.
    4) Local colour transfer with local gamma correction.

    Args:
        source_path: path of the image to be recoloured.
        target_path: path of the image supplying the colour statistics.

    Returns:
        The recoloured source image as an 8-bit BGR array.

    Raises:
        IOError: if either image cannot be read.
    """
    src_bgr = cv2.imread(source_path)
    tgt_bgr = cv2.imread(target_path)
    if src_bgr is None or tgt_bgr is None:
        raise IOError("Could not load source or target image.")

    # 1) Dominant colours
    src_dom = estimate_dominant_colors(src_bgr, grid_size=8, min_pixels=50)
    tgt_dom = estimate_dominant_colors(tgt_bgr, grid_size=8, min_pixels=50)

    # 2) Soft segmentation
    src_seg = soft_segmentation(src_bgr, src_dom)  # shape (H, W, Cs)
    tgt_seg = soft_segmentation(tgt_bgr, tgt_dom)  # shape (H, W, Ct)

    # Use the same number of regions on both sides. Keep the C_min segments
    # with the largest total soft membership rather than the first C_min
    # channels, whose order is just the connected-component label order.
    c_min = min(src_seg.shape[-1], tgt_seg.shape[-1])
    src_seg = _keep_largest_segments(src_seg, c_min)
    tgt_seg = _keep_largest_segments(tgt_seg, c_min)

    # 3) Region matching
    src_feats = compute_region_features(src_bgr, src_seg)
    tgt_feats = compute_region_features(tgt_bgr, tgt_seg)
    match_dict = match_regions(src_feats, tgt_feats, wS=1.0, wL=1.0, wR=1.0)

    # 4) Local colour transfer
    out_lab = local_color_transfer(src_bgr, tgt_bgr, src_seg, tgt_seg, match_dict)
    return lab_to_bgr(out_lab)




In [None]:
##############################################################################
# Demo Main
##############################################################################

if __name__ == "__main__":
    # Example usage with the two beach images:
    source_path = "reference.jpeg"  # The image to be recolored
    target_path = "input.jpg"     # The reference style image

    result_bgr = local_color_transfer_pipeline(source_path, target_path)
    out_name = "local_color_result.jpg"
    # Persist the result; cv2.imwrite signals failure by returning False.
    if not cv2.imwrite(out_name, result_bgr):
        raise IOError(f"Could not write result image to {out_name!r}.")


In [8]:
import matplotlib.pyplot as plt

# Reload both inputs so they can be shown next to the result.
source_img = cv2.imread(source_path)
target_img = cv2.imread(target_path)

# matplotlib expects RGB channel order; OpenCV delivers BGR.
source_rgb, target_rgb, result_rgb = (
    cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    for bgr_image in (source_img, target_img, result_bgr)
)

# Side-by-side comparison: source, target, transferred result.
plt.figure(figsize=(15, 5))
panels = (
    (source_rgb, "Source Image"),
    (target_rgb, "Target Image"),
    (result_rgb, "Color Transferred Image"),
)
for position, (shown, caption) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(shown)
    plt.axis('off')
    plt.title(caption)

plt.show()




In [7]:
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from skimage.segmentation import slic
from skimage.measure import regionprops
from skimage import color as skcolor
from scipy.spatial.distance import cdist

def _load_pretrained_resnet50():
    """Return an ImageNet-pretrained ResNet-50, across torchvision versions."""
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of
        # ``weights=``; IMAGENET1K_V1 is the checkpoint ``pretrained=True`` loads.
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1)
    return models.resnet50(pretrained=True)


# Pre-trained ResNet used for feature extraction
class FeatureExtractor(nn.Module):
    """ResNet-50 backbone (ImageNet weights) without its final FC layer.

    ``forward`` maps an image batch to one flattened feature vector per image.
    For ResNet-50 this is the pooled feature vector (2048-d).
    """

    def __init__(self):
        super().__init__()
        resnet = _load_pretrained_resnet50()
        # Every child module except the final fully connected classifier.
        self.feature_layer = nn.Sequential(*list(resnet.children())[:-1])

    def forward(self, x):
        x = self.feature_layer(x)
        # flatten also works on non-contiguous tensors, unlike .view().
        return torch.flatten(x, start_dim=1)

# Preprocessing for ResNet input
def preprocess_image(image):
    """Turn an image array into a normalised 1x3x224x224 ResNet input batch.

    Uses torchvision transforms in this order:
      1. convert to PIL,
      2. resize to 224x224,
      3. convert to a [0, 1] tensor,
      4. apply the ImageNet channel mean/std normalisation.
    NOTE(review): assumes an HxWx3 uint8 RGB array, as produced by
    compute_deep_features.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )
    tensor = pipeline(image)
    return tensor.unsqueeze(0)  # add the batch dimension

# Compute deep features for regions
def compute_deep_features(img_bgr, seg_map, model, device):
    """
    Describe every segmentation region with a deep CNN feature vector.

    Each pixel is assigned to the region with the largest soft membership,
    the same winner-take-all rule used by ``compute_region_features`` and
    ``local_color_transfer``. A fixed 0.5 threshold can leave many regions
    empty when there are more than two colours.

    For each region:
      * the image is masked to that region (other pixels set to black),
      * the masked image is passed through ``model``,
      * the flattened output is kept.
    Regions with no pixels get a zero vector. All non-empty regions go through
    the model in a single batch.

    Args:
        img_bgr: 8-bit BGR image.
        seg_map: soft segmentation of shape (H, W, num_regions).
        model: module mapping an (N, 3, 224, 224) batch to (N, D) features,
            e.g. ``FeatureExtractor``. NOTE(review): assumes eval mode, so
            batching does not change per-region results.
        device: torch device the model runs on.

    Returns:
        float32 array of shape (num_regions, D).
    """
    H, W, num_regions = seg_map.shape
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    hard_labels = np.argmax(seg_map, axis=-1)

    batch = []
    filled = []
    for ridx in range(num_regions):
        mask = hard_labels == ridx
        if not mask.any():
            continue
        region = np.where(mask[..., None], img_rgb, 0).astype(np.uint8)
        region_resized = cv2.resize(region, (224, 224))
        batch.append(preprocess_image(region_resized))
        filled.append(ridx)

    feat_dim = 2048  # ResNet-50 feature size, used only if no region has pixels
    if not batch:
        return np.zeros((num_regions, feat_dim), dtype=np.float32)

    with torch.no_grad():
        feats = model(torch.cat(batch, dim=0).to(device)).cpu().numpy()
    feats = feats.reshape(len(filled), -1)

    region_features = np.zeros((num_regions, feats.shape[1]), dtype=np.float32)
    region_features[filled] = feats
    return region_features

# Match regions using deep features and cosine distance
def match_regions_deep(src_feats, tgt_feats):
    """
    Match each source region to the target region with the most similar
    deep feature vector.

    Similarity is measured as cosine distance (1 - cosine similarity, in
    [0, 2]); smaller is more similar. Zero feature vectors, used for empty
    regions, make the cosine undefined. Such NaN distances are treated as the
    worst value (2.0); otherwise ``argmin`` would return the NaN position and
    prefer an empty region.

    Args:
        src_feats: (N_src, D) array of source features.
        tgt_feats: (N_tgt, D) array of target features.

    Returns:
        dict mapping source index -> target index (plain Python ints).
    """
    distances = cdist(src_feats, tgt_feats, metric='cosine')
    distances = np.nan_to_num(distances, nan=2.0)
    best = np.argmin(distances, axis=1)
    return {i: int(j) for i, j in enumerate(best)}

# Local Color Transfer Pipeline (deep-feature region matching)
def local_color_transfer_pipeline(source_path, target_path, output_path="output.png"):
    """
    Local colour transfer with deep-feature region matching.

    Same stages as the heuristic pipeline: dominant colours, soft
    segmentation, then local transfer. Here, regions are paired by cosine
    distance between ResNet-50 features of the masked regions. This
    definition shares its name with the earlier heuristic pipeline in this
    notebook; whichever cell ran last is the one in effect.

    The result is written to ``output_path`` and also returned.

    Args:
        source_path: image to be recoloured.
        target_path: image providing the colour statistics.
        output_path: where to save the result (default "output.png").

    Returns:
        The recoloured image as an 8-bit BGR array.

    Raises:
        IOError: if an input cannot be read or the output cannot be written.
    """
    src_bgr = cv2.imread(source_path)
    tgt_bgr = cv2.imread(target_path)
    if src_bgr is None or tgt_bgr is None:
        raise IOError("Could not load source or target image.")

    # Step 1: Estimate dominant colours
    src_dominant_hsv = estimate_dominant_colors(src_bgr)
    tgt_dominant_hsv = estimate_dominant_colors(tgt_bgr)

    # Step 2: Soft segmentation
    src_seg = soft_segmentation(src_bgr, src_dominant_hsv)
    tgt_seg = soft_segmentation(tgt_bgr, tgt_dominant_hsv)

    # Step 3: Deep-learning-based region matching
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = FeatureExtractor().to(device).eval()

    src_feats = compute_deep_features(src_bgr, src_seg, model, device)
    tgt_feats = compute_deep_features(tgt_bgr, tgt_seg, model, device)
    match_dict = match_regions_deep(src_feats, tgt_feats)

    # Step 4: Local colour transfer
    output_lab = local_color_transfer(src_bgr, tgt_bgr, src_seg, tgt_seg, match_dict)

    # Convert to BGR, save, and hand the image back to the caller.
    output_bgr = lab_to_bgr(output_lab)
    if not cv2.imwrite(output_path, output_bgr):
        raise IOError(f"Could not write output image to {output_path!r}.")
    print(f"Color transfer completed. Output saved as {output_path}")
    return output_bgr

# Run the deep-matching pipeline on the example image pair.
source_img, target_img = "input.jpg", "reference.jpeg"
local_color_transfer_pipeline(source_img, target_img)






In [None]:
import cv2
import numpy as np
from sklearn.cluster import KMeans

def kmeans_segmentation(image, k=4):
    """
    Cluster the pixels of a BGR image in OpenCV HSV space with k-means.

    Hue is clustered as a linear value (its wrap-around is not modelled).

    Returns:
        (label_map, centers): an (H, W) array of cluster ids, and the
        k cluster centres in HSV units.
    """
    height, width = image.shape[:2]
    hsv_pixels = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).reshape((-1, 3))

    model = KMeans(n_clusters=k, random_state=0)
    model.fit(hsv_pixels)

    label_map = model.labels_.reshape(height, width)
    return label_map, model.cluster_centers_

def apply_color_transfer(source_image, target_image, segmented_image, cluster_centers):
    """
    Flatten each segmented cluster of ``target_image`` to one HSV colour.

    Every pixel whose label in ``segmented_image`` equals ``i`` is replaced by
    ``cluster_centers[i]``, interpreted as an OpenCV 8-bit HSV triple. The
    result is converted back to BGR.

    NOTE(review): ``source_image`` is accepted for interface compatibility but
    is not used; no colour information comes from the source. When the
    centres are computed from the target itself (as in the calling script),
    this posterises the target.

    Args:
        source_image: unused (see note).
        target_image: 8-bit BGR image to recolour.
        segmented_image: (H, W) integer label map for ``target_image``.
        cluster_centers: sequence of HSV triples, one per label 0..N-1.

    Returns:
        8-bit BGR image.
    """
    target_hsv = cv2.cvtColor(target_image, cv2.COLOR_BGR2HSV)

    for i, center in enumerate(cluster_centers):
        # Round instead of letting the uint8 assignment truncate the float
        # centre (e.g. 12.9 -> 12); clip keeps values in 8-bit range.
        color = np.clip(np.rint(np.asarray(center, dtype=np.float64)), 0, 255)
        target_hsv[segmented_image == i] = color.astype(np.uint8)

    return cv2.cvtColor(target_hsv, cv2.COLOR_HSV2BGR)

# Load source and target images; cv2.imread returns None instead of raising.
source_image = cv2.imread('input.jpg')
target_image = cv2.imread('reference.jpeg')
if source_image is None or target_image is None:
    raise IOError("Could not load source or target image.")

# Apply K-means segmentation to the target image
segmented_image, cluster_centers = kmeans_segmentation(target_image, k=2)

# Recolour the target's clusters with the K-means centres
result_image = apply_color_transfer(source_image, target_image, segmented_image, cluster_centers)

# Save and show the result (cv2.imshow needs a GUI backend; waitKey(0) blocks
# until a key is pressed).
if not cv2.imwrite('result_image.jpg', result_image):
    raise IOError("Could not write result_image.jpg.")
cv2.imshow('Result Image', result_image)
cv2.waitKey(0)
cv2.destroyAllWindows()


In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN

def dbscan_segmentation(image, eps=20, min_samples=100):
    """
    Cluster the pixels of a BGR image in OpenCV HSV space with DBSCAN.

    NOTE(review): running DBSCAN on every pixel can be slow and memory-hungry
    for large images; consider downscaling first.

    Returns:
        (label_map, flat_labels): the cluster ids reshaped to (H, W), and the
        same ids as a flat array. DBSCAN labels noise pixels -1.
    """
    height, width = image.shape[0], image.shape[1]
    hsv_pixels = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).reshape((-1, 3))

    clusterer = DBSCAN(eps=eps, min_samples=min_samples)
    flat_labels = clusterer.fit_predict(hsv_pixels)

    label_map = flat_labels.reshape(height, width)
    return label_map, flat_labels

def apply_color_transfer(source_image, target_image, segmented_image, cluster_centers):
    """
    Replace each labelled cluster of ``target_image`` with one HSV colour.

    Pixels whose label in ``segmented_image`` equals ``i`` become
    ``cluster_centers[i]``, interpreted as an OpenCV 8-bit HSV triple. Labels
    without a centre (e.g. DBSCAN noise, -1) are left unchanged. The result is
    converted back to BGR.

    NOTE(review): ``source_image`` is accepted for interface compatibility but
    is not used; no colour information comes from the source.

    Args:
        source_image: unused (see note).
        target_image: 8-bit BGR image to recolour.
        segmented_image: (H, W) integer label map for ``target_image``.
        cluster_centers: sequence of HSV triples, one per label 0..N-1.

    Returns:
        8-bit BGR image.
    """
    target_hsv = cv2.cvtColor(target_image, cv2.COLOR_BGR2HSV)

    for i, center in enumerate(cluster_centers):
        # Round instead of letting the uint8 assignment truncate the float
        # centre; clip keeps values in 8-bit range.
        color = np.clip(np.rint(np.asarray(center, dtype=np.float64)), 0, 255)
        target_hsv[segmented_image == i] = color.astype(np.uint8)

    return cv2.cvtColor(target_hsv, cv2.COLOR_HSV2BGR)

# Load source and target images; cv2.imread returns None instead of raising.
source_image = cv2.imread('input.jpg')
target_image = cv2.imread('reference.jpeg')
if source_image is None or target_image is None:
    raise IOError("Could not load source or target image.")

# Apply DBSCAN segmentation to the target image
segmented_image, dbscan_labels = dbscan_segmentation(target_image, eps=20, min_samples=100)

# Cluster centres = mean HSV colour of each non-noise cluster. Notes:
# - They must be HSV because apply_color_transfer writes them into an HSV
#   image.
# - The 2-D label map is used as the mask so it matches the (H, W, 3) image;
#   the flat labels would raise IndexError.
# - Hue is averaged linearly (no wrap-around handling).
target_hsv = cv2.cvtColor(target_image, cv2.COLOR_BGR2HSV)
unique_labels = np.unique(dbscan_labels)
cluster_centers = []

# Loop variable is not named ``label``: that would shadow skimage's ``label``.
for cluster_id in unique_labels:
    if cluster_id != -1:  # Ignore noise (-1)
        cluster_pixels = target_hsv[segmented_image == cluster_id]
        cluster_centers.append(np.mean(cluster_pixels, axis=0))

# Apply color transfer using DBSCAN segmented regions
result_image = apply_color_transfer(source_image, target_image, segmented_image, cluster_centers)

# Show the result using Matplotlib (no OpenCV windows)
plt.figure(figsize=(10, 5))

# Plot the original target image and the result
plt.subplot(1, 2, 1)
plt.title('Original Target Image')
plt.imshow(cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB))
plt.axis('off')

plt.subplot(1, 2, 2)
plt.title('Result Image with Color Transfer')
plt.imshow(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
plt.axis('off')

plt.show()


In [None]:
import cv2
import numpy as np
from skimage.measure import label, regionprops
import matplotlib.pyplot as plt

#############################################
# Mean-Shift Segmentation
#############################################

def mean_shift_segmentation(img_bgr, sp=21, sr=51):
    """Smooth ``img_bgr`` with OpenCV mean-shift filtering.

    ``sp`` is the spatial window radius and ``sr`` the colour window radius.
    """
    return cv2.pyrMeanShiftFiltering(img_bgr, sp, sr)

#############################################
# Edge Detection and Processing
#############################################

def compute_edge_mask(img_bgr, low_threshold=50, high_threshold=150, dilation_iters=1):
    """Return a boolean mask of Canny edges, thickened by dilation.

    Steps:
      1. Convert the image to grayscale.
      2. Detect edges with the given hysteresis thresholds.
      3. Dilate ``dilation_iters`` times with a 3x3 square kernel.
    """
    grayscale = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    edge_map = cv2.Canny(grayscale, low_threshold, high_threshold)
    square = np.ones((3, 3), np.uint8)
    thick_edges = cv2.dilate(edge_map, square, iterations=dilation_iters)
    return thick_edges > 0

#############################################
# Region Labeling from Mean-Shift Output
#############################################

def label_regions(filtered_img):
    """Label the connected bright regions of a (mean-shift filtered) BGR image.

    Steps:
      1. Convert the image to grayscale.
      2. Apply an Otsu threshold.
      3. Number the connected above-threshold components with
         ``skimage.measure.label``.
    Below-threshold pixels are background (label 0).
    """
    gray = cv2.cvtColor(filtered_img, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return label(binary)

#############################################
# Apply Tint to a Region (Basic Blend)
#############################################

def apply_tint_to_region(img_bgr, region_mask, tint_color, blend=0.5):
    """Blend ``tint_color`` into the pixels of ``img_bgr`` selected by ``region_mask``.

    Selected pixels become ``(1 - blend) * pixel + blend * tint``, computed by
    ``cv2.addWeighted``. Other pixels are untouched, and the input array is
    not modified.
    """
    result = img_bgr.copy()
    selected = img_bgr[region_mask]
    # A tint block sized like the selection instead of a full-image buffer.
    tint_block = np.full_like(selected, tint_color)
    result[region_mask] = cv2.addWeighted(selected, 1 - blend, tint_block, blend, 0)
    return result

#############################################
# Color Transfer Methods
#############################################

def compute_mean_and_cov(image):
    """Return the per-channel mean and the 3x3 channel covariance of ``image``.

    Pixels are flattened to an (N, 3) float32 matrix. The covariance uses
    ``np.cov`` with each pixel as an observation (N - 1 normalisation).
    """
    samples = image.reshape(-1, 3).astype(np.float32)
    channel_mean = samples.mean(axis=0)
    channel_cov = np.cov(samples.T)
    return channel_mean, channel_cov

def sqrtm(matrix, method="svd"):
    """
    Square-root factor of a symmetric positive semi-definite matrix.

    Methods:
      * "svd": U diag(sqrt(S)) Vt. For symmetric PSD input this is the
        principal square root (R @ R == matrix).
      * "eigen": V diag(sqrt(w)) V.T from ``eigh``. Eigenvalues made slightly
        negative by rounding are clipped to 0, so the result stays real
        instead of containing NaN.
      * "cholesky": the lower Cholesky factor L with L @ L.T == matrix.
        NOTE: this is not a matrix square root (L @ L != matrix in
        general), and it requires a positive-definite matrix.

    Raises:
        ValueError: for an unknown ``method``.
        numpy.linalg.LinAlgError: from "cholesky" if the matrix is not
            positive definite.
    """
    if method == "svd":
        U, S, Vt = np.linalg.svd(matrix)
        return U @ (np.diag(np.sqrt(S)) @ Vt)
    if method == "eigen":
        eigvals, eigvecs = np.linalg.eigh(matrix)
        eigvals = np.clip(eigvals, 0.0, None)
        return eigvecs @ np.diag(np.sqrt(eigvals)) @ eigvecs.T
    if method == "cholesky":
        return np.linalg.cholesky(matrix)
    raise ValueError("Invalid method for matrix square root")

def separable_transfer(target, reference):
    """
    Per-channel (diagonal) mean/std colour transfer from ``reference`` to ``target``.

    Each channel of ``target`` is shifted and scaled so its mean and standard
    deviation match ``reference``'s; channels are treated independently.

    A constant target channel (zero variance) is mapped to the reference mean
    instead of producing NaN from a division by zero.

    Args:
        target, reference: 8-bit 3-channel images in the same channel order.

    Returns:
        uint8 image shaped like ``target``.
    """
    target = target.astype(np.float32) / 255.0
    reference = reference.astype(np.float32) / 255.0
    mu_t, cov_t = compute_mean_and_cov(target)
    mu_r, cov_r = compute_mean_and_cov(reference)
    std_t = np.sqrt(np.diag(cov_t))
    std_r = np.sqrt(np.diag(cov_r))
    # Scale 0 for constant channels: their deviations are zero anyway, and
    # this avoids inf * 0 = NaN.
    scale = np.divide(std_r, std_t, out=np.zeros_like(std_r), where=std_t > 0)
    transformed = (target.reshape(-1, 3) - mu_t) * scale + mu_r
    transformed = np.clip(transformed, 0, 1)
    # Round rather than truncate when converting back to 8-bit.
    return np.rint(transformed.reshape(target.shape) * 255).astype(np.uint8)

def cholesky_transfer(target, reference):
    """
    Linear colour transfer using Cholesky factors of the channel covariances.

    With C = L L.T, pixels are mapped by x' = L_r L_t^{-1} (x - mu_t) + mu_r.
    Before clipping, this gives the output the reference's mean and
    covariance.

    A tiny ridge (1e-6 * I, on [0, 1]-scaled values) is added to both
    covariances. Otherwise rank-deficient inputs, such as grayscale images
    whose channels are identical, make ``np.linalg.cholesky`` fail.

    Args:
        target, reference: 8-bit 3-channel images in the same channel order.

    Returns:
        uint8 image shaped like ``target``.
    """
    target = target.astype(np.float32) / 255.0
    reference = reference.astype(np.float32) / 255.0
    mu_t, cov_t = compute_mean_and_cov(target)
    mu_r, cov_r = compute_mean_and_cov(reference)
    ridge = 1e-6 * np.eye(3)
    L_t = np.linalg.cholesky(cov_t + ridge)
    L_r = np.linalg.cholesky(cov_r + ridge)
    # L_r @ inv(L_t), solved instead of forming the inverse explicitly.
    transform = np.linalg.solve(L_t.T, L_r.T).T
    transformed = np.dot((target.reshape(-1, 3) - mu_t), transform.T) + mu_r
    transformed = np.clip(transformed, 0, 1)
    # Round rather than truncate when converting back to 8-bit.
    return np.rint(transformed.reshape(target.shape) * 255).astype(np.uint8)

def pca_transfer(target, reference):
    """
    Linear colour transfer using eigen-decomposition square roots.

    Pixels are mapped by x' = C_r^{1/2} C_t^{-1/2} (x - mu_t) + mu_r, where
    the square roots come from ``sqrtm(..., method="eigen")``.

    A tiny ridge (1e-6 * I, on [0, 1]-scaled values) is added to the target
    covariance so it stays invertible for rank-deficient (e.g. grayscale)
    targets.

    Args:
        target, reference: 8-bit 3-channel images in the same channel order.

    Returns:
        uint8 image shaped like ``target``.
    """
    target = target.astype(np.float32) / 255.0
    reference = reference.astype(np.float32) / 255.0
    mu_t, cov_t = compute_mean_and_cov(target)
    mu_r, cov_r = compute_mean_and_cov(reference)
    sqrt_cov_t = sqrtm(cov_t + 1e-6 * np.eye(3), method="eigen")
    sqrt_cov_r = sqrtm(cov_r, method="eigen")
    # sqrt_cov_r @ inv(sqrt_cov_t), solved instead of forming the inverse.
    transform = np.linalg.solve(sqrt_cov_t.T, sqrt_cov_r.T).T
    transformed = np.dot((target.reshape(-1, 3) - mu_t), transform.T) + mu_r
    transformed = np.clip(transformed, 0, 1)
    # Round rather than truncate when converting back to 8-bit.
    return np.rint(transformed.reshape(target.shape) * 255).astype(np.uint8)

def monge_kantorovitch_transfer(target, reference):
    """
    Linear Monge-Kantorovitch colour transfer.

    Pixels are mapped by x' = T (x - mu_t) + mu_r, with

        T = C_t^{-1/2} (C_t^{1/2} C_r C_t^{1/2})^{1/2} C_t^{-1/2}

    where the square roots come from ``sqrtm(..., method="svd")``.

    A tiny ridge (1e-6 * I, on [0, 1]-scaled values) is added to the target
    covariance so its square root stays invertible for rank-deficient
    (e.g. grayscale) targets.

    Args:
        target, reference: 8-bit 3-channel images in the same channel order.

    Returns:
        uint8 image shaped like ``target``.
    """
    target = target.astype(np.float32) / 255.0
    reference = reference.astype(np.float32) / 255.0
    mu_t, cov_t = compute_mean_and_cov(target)
    mu_r, cov_r = compute_mean_and_cov(reference)
    cov_t = cov_t + 1e-6 * np.eye(3)
    sqrt_cov_t = sqrtm(cov_t, method="svd")
    inv_sqrt_cov_t = np.linalg.inv(sqrt_cov_t)
    middle = sqrtm(sqrt_cov_t @ cov_r @ sqrt_cov_t, method="svd")
    mk_transform = inv_sqrt_cov_t @ middle @ inv_sqrt_cov_t
    transformed = np.dot((target.reshape(-1, 3) - mu_t), mk_transform.T) + mu_r
    transformed = np.clip(transformed, 0, 1)
    # Round rather than truncate when converting back to 8-bit.
    return np.rint(transformed.reshape(target.shape) * 255).astype(np.uint8)

# Dispatch to one of the colour-transfer implementations by name
def color_transfer(target, reference, method="pca"):
    """Recolour ``target`` with ``reference``'s statistics using ``method``.

    ``method`` is one of "separable", "cholesky", "pca" or
    "monge_kantorovitch"; any other value raises ValueError.
    """
    candidates = (
        ("separable", lambda: separable_transfer(target, reference)),
        ("cholesky", lambda: cholesky_transfer(target, reference)),
        ("pca", lambda: pca_transfer(target, reference)),
        ("monge_kantorovitch", lambda: monge_kantorovitch_transfer(target, reference)),
    )
    for name, run in candidates:
        if method == name:
            return run()
    raise ValueError("Unknown color transfer method.")

#############################################
# Integrated Processing Pipeline
#############################################

def process_image(src_path, tgt_path, transfer_method=None, region_based_transfer=False):
    """
    Recolour the target image using colours from the source image.

    Global mode (`transfer_method` given and `region_based_transfer` False):
      the selected colour transfer is applied to the whole target image.
    Region mode (otherwise):
      1. Mean-shift filter the target and label its regions.
      2. Detect edges in the target.
      3. For every region of at least 50 pixels, either
           a. blend in the source's mean colour over that region (basic tint,
              when `transfer_method` is None), or
           b. run the selected colour transfer on that region's pixels only
              (target region vs. the same pixels of the resized source),
         and blend the result in with an edge-aware weight.

    Parameters:
      src_path: path of the image whose colours are borrowed.
      tgt_path: path of the image to recolour.
      transfer_method: One of "separable", "cholesky", "pca", "monge_kantorovitch".
                       If None, the basic tint (mean blending) is used.
      region_based_transfer: If True, apply the transfer method on each segmented region.
                             If False, apply a global transfer to the whole image.
    Returns:
      The recoloured target as a BGR uint8 array.
    Raises:
      IOError: if either image cannot be read.
    """
    src = cv2.imread(src_path)
    tgt = cv2.imread(tgt_path)
    if src is None or tgt is None:
        raise IOError("Could not load source or target image.")

    # Resize source to target size so the same region masks apply to both.
    if src.shape != tgt.shape:
        src = cv2.resize(src, (tgt.shape[1], tgt.shape[0]))

    # Global transfer needs no segmentation: return before the (expensive)
    # mean-shift filtering and edge detection.
    if transfer_method is not None and not region_based_transfer:
        return color_transfer(tgt, src, method=transfer_method)

    ms_tgt = mean_shift_segmentation(tgt, sp=21, sr=51)
    edge_mask = compute_edge_mask(tgt, low_threshold=50, high_threshold=150, dilation_iters=1)
    labels = label_regions(ms_tgt)

    out_img = tgt.copy()
    for prop in regionprops(labels):
        region_mask = labels == prop.label
        region_size = np.count_nonzero(region_mask)
        if region_size < 50:
            continue  # Skip tiny regions

        # Lower blend weight where the region contains more edges, never below 0.3.
        edge_ratio = np.count_nonzero(edge_mask & region_mask) / region_size
        blend_factor = max(0.3, 1 - edge_ratio)

        tgt_pixels = out_img[region_mask]  # (N, 3) uint8
        if transfer_method is None:
            # Basic tint: the source's mean colour over this region (rounded).
            mean_color = cv2.mean(src, mask=region_mask.astype(np.uint8) * 255)[:3]
            new_pixels = np.full_like(tgt_pixels, np.rint(mean_color).astype(np.uint8))
        else:
            # Feed only the region's pixels (as N x 1 images) to the transfer so
            # its statistics are not dominated by out-of-region pixels.
            src_pixels = src[region_mask]
            new_pixels = color_transfer(tgt_pixels.reshape(-1, 1, 3), src_pixels.reshape(-1, 1, 3),
                                        method=transfer_method).reshape(-1, 3)

        out_img[region_mask] = cv2.addWeighted(tgt_pixels, 1 - blend_factor,
                                               new_pixels, blend_factor, 0)

    return out_img

#############################################
# Example Usage
#############################################

if __name__ == '__main__':

    # Example 2: Region-based Monge-Kantorovitch colour transfer
    # (colours of input2.jpg applied region by region to reference2.jpeg).
    output_global = process_image( 'input2.jpg', 'reference2.jpeg', transfer_method="monge_kantorovitch", region_based_transfer=True)
    # cv2.imwrite can return False instead of raising, so check it explicitly.
    if not cv2.imwrite('output_global_2.jpg', output_global):
        raise IOError("Could not write output_global_2.jpg")







In [2]:
import cv2
import numpy as np
from skimage.measure import label, regionprops
import matplotlib.pyplot as plt

#############################################
# Mean-Shift Segmentation
#############################################

def mean_shift_segmentation(img_bgr, sp=21, sr=51):
    """
    Smooth `img_bgr` with OpenCV's pyramid mean-shift filter.

    `sp` (spatial window radius) and `sr` (colour window radius) are passed
    straight to cv2.pyrMeanShiftFiltering; the filtered image is returned.
    """
    return cv2.pyrMeanShiftFiltering(img_bgr, sp, sr)

#############################################
# Edge Detection and Processing
#############################################

def compute_edge_mask(img_bgr, low_threshold=50, high_threshold=150, dilation_iters=1):
    """
    Boolean edge map of a BGR image.

    Runs Canny on the grayscale image with the given hysteresis thresholds,
    thickens the edges with `dilation_iters` passes of a 3x3 dilation, and
    returns True wherever a (dilated) edge pixel is set.
    """
    structuring_element = np.ones((3, 3), dtype=np.uint8)
    grayscale = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    canny_edges = cv2.Canny(grayscale, low_threshold, high_threshold)
    thick_edges = cv2.dilate(canny_edges, structuring_element, iterations=dilation_iters)
    return thick_edges > 0

#############################################
# Region Labeling from Mean-Shift Output
#############################################

def label_regions(filtered_img):
    """
    Label connected components of the Otsu-binarised filtered image.

    The image is converted to grayscale and thresholded with Otsu's method,
    then skimage's `label` numbers the connected 255-valued areas.  With
    skimage's default background of 0, pixels below the threshold stay at
    label 0 (background).
    """
    grayscale = cv2.cvtColor(filtered_img, cv2.COLOR_BGR2GRAY)
    _otsu_level, binary = cv2.threshold(
        grayscale, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
    )
    return label(binary)

#############################################
# Apply Tint to a Region (Basic Blend)
#############################################

def apply_tint_to_region(img_bgr, region_mask, tint_color, blend=0.5):
    """
    Return a copy of `img_bgr` with `tint_color` alpha-blended into the masked pixels.

    Inside `region_mask` each pixel becomes
    (1 - blend) * original + blend * tint_color, computed by cv2.addWeighted;
    pixels outside the mask are copied unchanged.
    """
    solid_tint = np.full_like(img_bgr, tint_color)
    result = img_bgr.copy()
    result[region_mask] = cv2.addWeighted(
        img_bgr[region_mask], 1 - blend, solid_tint[region_mask], blend, 0
    )
    return result

#############################################
# Color Transfer Methods
#############################################

def compute_mean_and_cov(image):
    """
    Per-channel mean and 3x3 channel covariance of a 3-channel image.

    Pixels are flattened to an (N, 3) float32 table; returns the length-3
    float32 mean and the sample covariance computed by np.cov (each column
    treated as one variable).
    """
    pixels = image.reshape(-1, 3).astype(np.float32)
    channel_means = pixels.mean(axis=0)
    channel_cov = np.cov(pixels, rowvar=False)
    return channel_means, channel_cov

def sqrtm(matrix, method="svd"):
    """
    Return a square-root factor of a symmetric positive semi-definite matrix.

    Parameters:
      matrix: square 2-D array (here, a covariance matrix).
      method: "svd"      -> U diag(sqrt(S)) Vt from the SVD,
              "eigen"    -> V diag(sqrt(w)) V^T from the symmetric eigendecomposition,
              "cholesky" -> lower-triangular L with L @ L.T == matrix.
    Raises:
      ValueError: if `method` is not one of the names above.
      numpy.linalg.LinAlgError: from np.linalg.cholesky when `matrix` is not
        positive definite.
    """
    if method == "svd":
        U, S, Vt = np.linalg.svd(matrix)
        return U @ np.diag(np.sqrt(S)) @ Vt
    elif method == "eigen":
        eigvals, eigvecs = np.linalg.eigh(matrix)
        # Round-off can leave eigenvalues of a PSD matrix slightly negative;
        # np.sqrt would turn those into NaN, so clamp them to zero.
        eigvals = np.clip(eigvals, 0.0, None)
        return eigvecs @ np.diag(np.sqrt(eigvals)) @ eigvecs.T
    elif method == "cholesky":
        return np.linalg.cholesky(matrix)
    else:
        raise ValueError("Invalid method for matrix square root")

def separable_transfer(target, reference):
    """
    Per-channel (separable) colour transfer.

    Each channel of `target` is shifted and scaled independently so that its
    mean and standard deviation match the corresponding channel of
    `reference`; cross-channel correlations are ignored.

    Parameters:
      target, reference: uint8 arrays holding 3 values per pixel (last axis).
    Returns:
      uint8 array with the same shape as `target`.
    """
    target = target.astype(np.float32) / 255.0
    reference = reference.astype(np.float32) / 255.0
    mu_t, cov_t = compute_mean_and_cov(target)
    mu_r, cov_r = compute_mean_and_cov(reference)
    std_t = np.sqrt(np.diag(cov_t))
    std_r = np.sqrt(np.diag(cov_r))
    # A constant target channel has zero spread: dividing would give inf and
    # then NaN (inf * 0).  That channel has no deviation to rescale, so use 1.
    scale = np.divide(std_r, std_t, out=np.ones_like(std_r), where=std_t > 0)
    # Diagonal transform == element-wise scaling of each channel.
    transformed = (target.reshape(-1, 3) - mu_t) * scale + mu_r
    transformed = np.clip(transformed, 0, 1)
    # Round (not truncate) back to 8-bit so values are not biased downward.
    return np.rint(transformed.reshape(target.shape) * 255).astype(np.uint8)

def cholesky_transfer(target, reference):
    """
    Cholesky-based colour transfer in the image's 3-channel space.

    Applies x -> T (x - mu_t) + mu_r with T = L_r L_t^{-1}, where L L^T is
    the Cholesky factorisation of each channel covariance; before clipping,
    the output's mean and covariance equal the reference's.

    Parameters:
      target, reference: uint8 arrays holding 3 values per pixel (last axis).
    Returns:
      uint8 array with the same shape as `target`.
    """
    target = target.astype(np.float32) / 255.0
    reference = reference.astype(np.float32) / 255.0
    mu_t, cov_t = compute_mean_and_cov(target)
    mu_r, cov_r = compute_mean_and_cov(reference)
    # Cholesky needs a positive-definite matrix; covariances of grayscale
    # images or flat regions are only semi-definite, so add a small ridge.
    ridge = 1e-8 * np.eye(3)
    L_t = np.linalg.cholesky(cov_t + ridge)
    L_r = np.linalg.cholesky(cov_r + ridge)
    transform = L_r @ np.linalg.inv(L_t)
    transformed = (target.reshape(-1, 3) - mu_t) @ transform.T + mu_r
    transformed = np.clip(transformed, 0, 1)
    # Round (not truncate) back to 8-bit so values are not biased downward.
    return np.rint(transformed.reshape(target.shape) * 255).astype(np.uint8)

def pca_transfer(target, reference):
    """
    PCA-based colour transfer in the image's 3-channel space.

    Applies x -> T (x - mu_t) + mu_r with T = C_r^{1/2} C_t^{-1/2}, where the
    symmetric square roots of the channel covariances C are taken from their
    eigendecompositions; before clipping, the output's mean and covariance
    equal the reference's.

    Parameters:
      target, reference: uint8 arrays holding 3 values per pixel (last axis).
    Returns:
      uint8 array with the same shape as `target`.
    """
    target = target.astype(np.float32) / 255.0
    reference = reference.astype(np.float32) / 255.0
    mu_t, cov_t = compute_mean_and_cov(target)
    mu_r, cov_r = compute_mean_and_cov(reference)
    # Small ridge: a singular covariance (grayscale image, flat region) has a
    # zero eigenvalue, so its square root could not be inverted.
    ridge = 1e-8 * np.eye(3)
    sqrt_cov_t = sqrtm(cov_t + ridge, method="eigen")
    sqrt_cov_r = sqrtm(cov_r + ridge, method="eigen")
    transform = sqrt_cov_r @ np.linalg.inv(sqrt_cov_t)
    transformed = (target.reshape(-1, 3) - mu_t) @ transform.T + mu_r
    transformed = np.clip(transformed, 0, 1)
    # Round (not truncate) back to 8-bit so values are not biased downward.
    return np.rint(transformed.reshape(target.shape) * 255).astype(np.uint8)

def monge_kantorovitch_transfer(target, reference):
    """
    Linear Monge-Kantorovitch colour transfer in the image's 3-channel space.

    Applies x -> T (x - mu_t) + mu_r with
        T = C_t^{-1/2} (C_t^{1/2} C_r C_t^{1/2})^{1/2} C_t^{-1/2},
    the closed-form linear Monge-Kantorovitch map between the Gaussian
    approximations (means mu, covariances C) of the two colour distributions.

    Parameters:
      target, reference: uint8 arrays holding 3 values per pixel (last axis).
    Returns:
      uint8 array with the same shape as `target`.
    """
    target = target.astype(np.float32) / 255.0
    reference = reference.astype(np.float32) / 255.0
    mu_t, cov_t = compute_mean_and_cov(target)
    mu_r, cov_r = compute_mean_and_cov(reference)
    # Small ridge: singular covariances (grayscale images, flat regions)
    # would otherwise make inv() fail.
    ridge = 1e-8 * np.eye(3)
    cov_t = cov_t + ridge
    cov_r = cov_r + ridge
    sqrt_cov_t = sqrtm(cov_t, method="svd")
    inv_sqrt_cov_t = np.linalg.inv(sqrt_cov_t)
    middle = sqrtm(sqrt_cov_t @ cov_r @ sqrt_cov_t, method="svd")
    mk_transform = inv_sqrt_cov_t @ middle @ inv_sqrt_cov_t
    transformed = (target.reshape(-1, 3) - mu_t) @ mk_transform.T + mu_r
    transformed = np.clip(transformed, 0, 1)
    # Round (not truncate) back to 8-bit so values are not biased downward.
    return np.rint(transformed.reshape(target.shape) * 255).astype(np.uint8)

# A helper to choose the transfer method
def color_transfer(target, reference, method="pca"):
    """
    Dispatch to one of the global colour-transfer implementations by name.

    Parameters:
      target, reference: images handed unchanged to the selected method.
      method: "separable", "cholesky", "pca" (default) or "monge_kantorovitch".
    Returns:
      Whatever the selected transfer function returns.
    Raises:
      ValueError: for any other `method`.
    """
    if method == "separable":
        return separable_transfer(target, reference)
    if method == "cholesky":
        return cholesky_transfer(target, reference)
    if method == "pca":
        return pca_transfer(target, reference)
    if method == "monge_kantorovitch":
        return monge_kantorovitch_transfer(target, reference)
    raise ValueError("Unknown color transfer method.")

#############################################
# Integrated Processing Pipeline
#############################################

def process_image(src_path, tgt_path, transfer_method=None, region_based_transfer=False):
    """
    Recolour the target image using colours from the source image.

    Global mode (`transfer_method` given and `region_based_transfer` False):
      the selected colour transfer is applied to the whole target image.
    Region mode (otherwise):
      1. Mean-shift filter the target and label its regions.
      2. Detect edges in the target.
      3. For every region of at least 50 pixels, either
           a. blend in the source's mean colour over that region (basic tint,
              when `transfer_method` is None), or
           b. run the selected colour transfer on that region's pixels only
              (target region vs. the same pixels of the resized source),
         and blend the result in with an edge-aware weight.

    Parameters:
      src_path: path of the image whose colours are borrowed.
      tgt_path: path of the image to recolour.
      transfer_method: One of "separable", "cholesky", "pca", "monge_kantorovitch".
                       If None, the basic tint (mean blending) is used.
      region_based_transfer: If True, apply the transfer method on each segmented region.
                             If False, apply a global transfer to the whole image.
    Returns:
      The recoloured target as a BGR uint8 array.
    Raises:
      IOError: if either image cannot be read.
    """
    src = cv2.imread(src_path)
    tgt = cv2.imread(tgt_path)
    if src is None or tgt is None:
        raise IOError("Could not load source or target image.")

    # Resize source to target size so the same region masks apply to both.
    if src.shape != tgt.shape:
        src = cv2.resize(src, (tgt.shape[1], tgt.shape[0]))

    # Global transfer needs no segmentation: return before the (expensive)
    # mean-shift filtering and edge detection.
    if transfer_method is not None and not region_based_transfer:
        return color_transfer(tgt, src, method=transfer_method)

    ms_tgt = mean_shift_segmentation(tgt, sp=21, sr=51)
    edge_mask = compute_edge_mask(tgt, low_threshold=50, high_threshold=150, dilation_iters=1)
    labels = label_regions(ms_tgt)

    out_img = tgt.copy()
    for prop in regionprops(labels):
        region_mask = labels == prop.label
        region_size = np.count_nonzero(region_mask)
        if region_size < 50:
            continue  # Skip tiny regions

        # Lower blend weight where the region contains more edges, never below 0.3.
        edge_ratio = np.count_nonzero(edge_mask & region_mask) / region_size
        blend_factor = max(0.3, 1 - edge_ratio)

        tgt_pixels = out_img[region_mask]  # (N, 3) uint8
        if transfer_method is None:
            # Basic tint: the source's mean colour over this region (rounded).
            mean_color = cv2.mean(src, mask=region_mask.astype(np.uint8) * 255)[:3]
            new_pixels = np.full_like(tgt_pixels, np.rint(mean_color).astype(np.uint8))
        else:
            # Feed only the region's pixels (as N x 1 images) to the transfer so
            # its statistics are not dominated by out-of-region pixels.
            src_pixels = src[region_mask]
            new_pixels = color_transfer(tgt_pixels.reshape(-1, 1, 3), src_pixels.reshape(-1, 1, 3),
                                        method=transfer_method).reshape(-1, 3)

        out_img[region_mask] = cv2.addWeighted(tgt_pixels, 1 - blend_factor,
                                               new_pixels, blend_factor, 0)

    return out_img

#############################################
# Example Usage
#############################################

if __name__ == '__main__':

    # Example 2: Global PCA color transfer applied on the whole image.
    output_global = process_image('reference4.jpeg', 'input4.jpg', transfer_method="pca", region_based_transfer=False)
    # cv2.imwrite can return False instead of raising, so check it explicitly.
    if not cv2.imwrite('output_global_4.jpg', output_global):
        raise IOError("Could not write output_global_4.jpg")

In [3]:
import cv2
import numpy as np
from skimage.measure import label, regionprops
import matplotlib.pyplot as plt

#############################################
# Mean-Shift Segmentation
#############################################

def mean_shift_segmentation(img_bgr, sp=21, sr=51):
    """
    Smooth `img_bgr` with OpenCV's pyramid mean-shift filter.

    `sp` (spatial window radius) and `sr` (colour window radius) are passed
    straight to cv2.pyrMeanShiftFiltering; the filtered image is returned.
    """
    return cv2.pyrMeanShiftFiltering(img_bgr, sp, sr)

#############################################
# Edge Detection and Processing
#############################################

def compute_edge_mask(img_bgr, low_threshold=50, high_threshold=150, dilation_iters=1):
    """
    Boolean edge map of a BGR image.

    Runs Canny on the grayscale image with the given hysteresis thresholds,
    thickens the edges with `dilation_iters` passes of a 3x3 dilation, and
    returns True wherever a (dilated) edge pixel is set.
    """
    structuring_element = np.ones((3, 3), dtype=np.uint8)
    grayscale = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    canny_edges = cv2.Canny(grayscale, low_threshold, high_threshold)
    thick_edges = cv2.dilate(canny_edges, structuring_element, iterations=dilation_iters)
    return thick_edges > 0

#############################################
# Region Labeling from Mean-Shift Output
#############################################

def label_regions(filtered_img):
    """
    Label connected components of the Otsu-binarised filtered image.

    The image is converted to grayscale and thresholded with Otsu's method,
    then skimage's `label` numbers the connected 255-valued areas.  With
    skimage's default background of 0, pixels below the threshold stay at
    label 0 (background).
    """
    grayscale = cv2.cvtColor(filtered_img, cv2.COLOR_BGR2GRAY)
    _otsu_level, binary = cv2.threshold(
        grayscale, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
    )
    return label(binary)

#############################################
# Apply Tint to a Region (Basic Blend)
#############################################

def apply_tint_to_region(img_bgr, region_mask, tint_color, blend=0.5):
    """
    Return a copy of `img_bgr` with `tint_color` alpha-blended into the masked pixels.

    Inside `region_mask` each pixel becomes
    (1 - blend) * original + blend * tint_color, computed by cv2.addWeighted;
    pixels outside the mask are copied unchanged.
    """
    solid_tint = np.full_like(img_bgr, tint_color)
    result = img_bgr.copy()
    result[region_mask] = cv2.addWeighted(
        img_bgr[region_mask], 1 - blend, solid_tint[region_mask], blend, 0
    )
    return result

#############################################
# Color Transfer Methods (Global)
#############################################

def compute_mean_and_cov(image):
    """
    Per-channel mean and 3x3 channel covariance of a 3-channel image.

    Pixels are flattened to an (N, 3) float32 table; returns the length-3
    float32 mean and the sample covariance computed by np.cov (each column
    treated as one variable).
    """
    pixels = image.reshape(-1, 3).astype(np.float32)
    channel_means = pixels.mean(axis=0)
    channel_cov = np.cov(pixels, rowvar=False)
    return channel_means, channel_cov

def sqrtm(matrix, method="svd"):
    """
    Return a square-root factor of a symmetric positive semi-definite matrix.

    Parameters:
      matrix: square 2-D array (here, a covariance matrix).
      method: "svd"      -> U diag(sqrt(S)) Vt from the SVD,
              "eigen"    -> V diag(sqrt(w)) V^T from the symmetric eigendecomposition,
              "cholesky" -> lower-triangular L with L @ L.T == matrix.
    Raises:
      ValueError: if `method` is not one of the names above.
      numpy.linalg.LinAlgError: from np.linalg.cholesky when `matrix` is not
        positive definite.
    """
    if method == "svd":
        U, S, Vt = np.linalg.svd(matrix)
        return U @ np.diag(np.sqrt(S)) @ Vt
    elif method == "eigen":
        eigvals, eigvecs = np.linalg.eigh(matrix)
        # Round-off can leave eigenvalues of a PSD matrix slightly negative;
        # np.sqrt would turn those into NaN, so clamp them to zero.
        eigvals = np.clip(eigvals, 0.0, None)
        return eigvecs @ np.diag(np.sqrt(eigvals)) @ eigvecs.T
    elif method == "cholesky":
        return np.linalg.cholesky(matrix)
    else:
        raise ValueError("Invalid method for matrix square root")

def separable_transfer(target, reference):
    """
    Per-channel (separable) colour transfer.

    Each channel of `target` is shifted and scaled independently so that its
    mean and standard deviation match the corresponding channel of
    `reference`; cross-channel correlations are ignored.

    Parameters:
      target, reference: uint8 arrays holding 3 values per pixel (last axis).
    Returns:
      uint8 array with the same shape as `target`.
    """
    target = target.astype(np.float32) / 255.0
    reference = reference.astype(np.float32) / 255.0
    mu_t, cov_t = compute_mean_and_cov(target)
    mu_r, cov_r = compute_mean_and_cov(reference)
    std_t = np.sqrt(np.diag(cov_t))
    std_r = np.sqrt(np.diag(cov_r))
    # A constant target channel has zero spread: dividing would give inf and
    # then NaN (inf * 0).  That channel has no deviation to rescale, so use 1.
    scale = np.divide(std_r, std_t, out=np.ones_like(std_r), where=std_t > 0)
    # Diagonal transform == element-wise scaling of each channel.
    transformed = (target.reshape(-1, 3) - mu_t) * scale + mu_r
    transformed = np.clip(transformed, 0, 1)
    # Round (not truncate) back to 8-bit so values are not biased downward.
    return np.rint(transformed.reshape(target.shape) * 255).astype(np.uint8)

def cholesky_transfer(target, reference):
    """
    Cholesky-based colour transfer in the image's 3-channel space.

    Applies x -> T (x - mu_t) + mu_r with T = L_r L_t^{-1}, where L L^T is
    the Cholesky factorisation of each channel covariance; before clipping,
    the output's mean and covariance equal the reference's.

    Parameters:
      target, reference: uint8 arrays holding 3 values per pixel (last axis).
    Returns:
      uint8 array with the same shape as `target`.
    """
    target = target.astype(np.float32) / 255.0
    reference = reference.astype(np.float32) / 255.0
    mu_t, cov_t = compute_mean_and_cov(target)
    mu_r, cov_r = compute_mean_and_cov(reference)
    # Cholesky needs a positive-definite matrix; covariances of grayscale
    # images or flat regions are only semi-definite, so add a small ridge.
    ridge = 1e-8 * np.eye(3)
    L_t = np.linalg.cholesky(cov_t + ridge)
    L_r = np.linalg.cholesky(cov_r + ridge)
    transform = L_r @ np.linalg.inv(L_t)
    transformed = (target.reshape(-1, 3) - mu_t) @ transform.T + mu_r
    transformed = np.clip(transformed, 0, 1)
    # Round (not truncate) back to 8-bit so values are not biased downward.
    return np.rint(transformed.reshape(target.shape) * 255).astype(np.uint8)

def pca_transfer(target, reference):
    """
    PCA-based colour transfer in the image's 3-channel space.

    Applies x -> T (x - mu_t) + mu_r with T = C_r^{1/2} C_t^{-1/2}, where the
    symmetric square roots of the channel covariances C are taken from their
    eigendecompositions; before clipping, the output's mean and covariance
    equal the reference's.

    Parameters:
      target, reference: uint8 arrays holding 3 values per pixel (last axis).
    Returns:
      uint8 array with the same shape as `target`.
    """
    target = target.astype(np.float32) / 255.0
    reference = reference.astype(np.float32) / 255.0
    mu_t, cov_t = compute_mean_and_cov(target)
    mu_r, cov_r = compute_mean_and_cov(reference)
    # Small ridge: a singular covariance (grayscale image, flat region) has a
    # zero eigenvalue, so its square root could not be inverted.
    ridge = 1e-8 * np.eye(3)
    sqrt_cov_t = sqrtm(cov_t + ridge, method="eigen")
    sqrt_cov_r = sqrtm(cov_r + ridge, method="eigen")
    transform = sqrt_cov_r @ np.linalg.inv(sqrt_cov_t)
    transformed = (target.reshape(-1, 3) - mu_t) @ transform.T + mu_r
    transformed = np.clip(transformed, 0, 1)
    # Round (not truncate) back to 8-bit so values are not biased downward.
    return np.rint(transformed.reshape(target.shape) * 255).astype(np.uint8)

def monge_kantorovitch_transfer(target, reference):
    """
    Linear Monge-Kantorovitch colour transfer in the image's 3-channel space.

    Applies x -> T (x - mu_t) + mu_r with
        T = C_t^{-1/2} (C_t^{1/2} C_r C_t^{1/2})^{1/2} C_t^{-1/2},
    the closed-form linear Monge-Kantorovitch map between the Gaussian
    approximations (means mu, covariances C) of the two colour distributions.

    Parameters:
      target, reference: uint8 arrays holding 3 values per pixel (last axis).
    Returns:
      uint8 array with the same shape as `target`.
    """
    target = target.astype(np.float32) / 255.0
    reference = reference.astype(np.float32) / 255.0
    mu_t, cov_t = compute_mean_and_cov(target)
    mu_r, cov_r = compute_mean_and_cov(reference)
    # Small ridge: singular covariances (grayscale images, flat regions)
    # would otherwise make inv() fail.
    ridge = 1e-8 * np.eye(3)
    cov_t = cov_t + ridge
    cov_r = cov_r + ridge
    sqrt_cov_t = sqrtm(cov_t, method="svd")
    inv_sqrt_cov_t = np.linalg.inv(sqrt_cov_t)
    middle = sqrtm(sqrt_cov_t @ cov_r @ sqrt_cov_t, method="svd")
    mk_transform = inv_sqrt_cov_t @ middle @ inv_sqrt_cov_t
    transformed = (target.reshape(-1, 3) - mu_t) @ mk_transform.T + mu_r
    transformed = np.clip(transformed, 0, 1)
    # Round (not truncate) back to 8-bit so values are not biased downward.
    return np.rint(transformed.reshape(target.shape) * 255).astype(np.uint8)

#############################################
# LAB-Based PCA Transfer (Preserve Luminance)
#############################################

def pca_transfer_lab(target, reference):
    """
    Applies PCA-based transfer on the chrominance (A and B) channels in LAB.
    Luminance (L) is preserved from the target.

    Parameters:
      target, reference: BGR uint8 images (their sizes may differ; only the
                         reference's chroma statistics are used).
    Returns:
      BGR uint8 image with the target's shape: the target's L channel combined
      with A/B remapped so that (before clipping) their mean and covariance
      match the reference's.
    """
    # Convert images to LAB color space (8-bit OpenCV encoding)
    target_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
    ref_lab = cv2.cvtColor(reference, cv2.COLOR_BGR2LAB).astype(np.float32)

    # Luminance is kept as-is; A/B are normalised to [0, 1] as (num_pixels x 2)
    L_t = target_lab[:, :, 0].astype(np.uint8)
    AB_t_flat = target_lab[:, :, 1:3].reshape(-1, 2) / 255.0
    AB_r_flat = ref_lab[:, :, 1:3].reshape(-1, 2) / 255.0

    mu_t = AB_t_flat.mean(axis=0)
    mu_r = AB_r_flat.mean(axis=0)
    # Small ridge: a grayscale target has constant A/B, i.e. a zero chroma
    # covariance, whose square root cannot be inverted.
    ridge = 1e-8 * np.eye(2)
    cov_t = np.cov(AB_t_flat, rowvar=False) + ridge
    cov_r = np.cov(AB_r_flat, rowvar=False) + ridge

    # T = C_r^{1/2} C_t^{-1/2} with symmetric square roots
    sqrt_cov_t = sqrtm(cov_t, method="eigen")
    sqrt_cov_r = sqrtm(cov_r, method="eigen")
    transform = sqrt_cov_r @ np.linalg.inv(sqrt_cov_t)

    AB_transformed = (AB_t_flat - mu_t) @ transform.T + mu_r
    AB_transformed = np.clip(AB_transformed, 0, 1).reshape(target_lab.shape[:2] + (2,))
    # Round (not truncate) back to 8-bit so chroma is not biased downward.
    AB_transformed = np.rint(AB_transformed * 255).astype(np.uint8)

    # Combine original L with transformed AB channels and convert back to BGR
    lab_transferred = cv2.merge((L_t, AB_transformed[:, :, 0], AB_transformed[:, :, 1]))
    return cv2.cvtColor(lab_transferred, cv2.COLOR_LAB2BGR)

#############################################
# Helper for Color Transfer Choice
#############################################

def color_transfer(target, reference, method="pca"):
    """
    Dispatch to one of the colour-transfer implementations by name.

    Parameters:
      target, reference: images handed unchanged to the selected method.
      method: "separable", "cholesky", "pca" (default), "monge_kantorovitch"
              or "pca_lab".
    Returns:
      Whatever the selected transfer function returns.
    Raises:
      ValueError: for any other `method`.
    """
    if method == "separable":
        return separable_transfer(target, reference)
    if method == "cholesky":
        return cholesky_transfer(target, reference)
    if method == "pca":
        return pca_transfer(target, reference)
    if method == "monge_kantorovitch":
        return monge_kantorovitch_transfer(target, reference)
    if method == "pca_lab":
        return pca_transfer_lab(target, reference)
    raise ValueError("Unknown color transfer method.")

#############################################
# Integrated Processing Pipeline
#############################################

def process_image(src_path, tgt_path, transfer_method=None, region_based_transfer=False):
    """
    Recolour the target image using colours from the source image.

    Global mode (`transfer_method` given and `region_based_transfer` False):
      the selected colour transfer is applied to the whole target image.
    Region mode (otherwise):
      1. Mean-shift filter the target and label its regions.
      2. Detect edges in the target.
      3. For every region of at least 50 pixels, either
           a. blend in the source's mean colour over that region (basic tint,
              when `transfer_method` is None), or
           b. run the selected colour transfer on that region's pixels only
              (target region vs. the same pixels of the resized source),
         and blend the result in with an edge-aware weight.

    Parameters:
      src_path: path of the image whose colours are borrowed.
      tgt_path: path of the image to recolour.
      transfer_method: One of "separable", "cholesky", "pca", "monge_kantorovitch", "pca_lab".
                       If None, the basic tint (mean blending) is used.
      region_based_transfer: If True, apply the transfer method on each segmented region.
                             If False, apply a global transfer to the whole image.
    Returns:
      The recoloured target as a BGR uint8 array.
    Raises:
      IOError: if either image cannot be read.
    """
    src = cv2.imread(src_path)
    tgt = cv2.imread(tgt_path)
    if src is None or tgt is None:
        raise IOError("Could not load source or target image.")

    # Resize source to target size so the same region masks apply to both.
    if src.shape != tgt.shape:
        src = cv2.resize(src, (tgt.shape[1], tgt.shape[0]))

    # Global transfer needs no segmentation: return before the (expensive)
    # mean-shift filtering and edge detection.
    if transfer_method is not None and not region_based_transfer:
        return color_transfer(tgt, src, method=transfer_method)

    ms_tgt = mean_shift_segmentation(tgt, sp=21, sr=51)
    edge_mask = compute_edge_mask(tgt, low_threshold=50, high_threshold=150, dilation_iters=1)
    labels = label_regions(ms_tgt)

    out_img = tgt.copy()
    for prop in regionprops(labels):
        region_mask = labels == prop.label
        region_size = np.count_nonzero(region_mask)
        if region_size < 50:
            continue  # Skip tiny regions

        # Lower blend weight where the region contains more edges, never below 0.3.
        edge_ratio = np.count_nonzero(edge_mask & region_mask) / region_size
        blend_factor = max(0.3, 1 - edge_ratio)

        tgt_pixels = out_img[region_mask]  # (N, 3) uint8
        if transfer_method is None:
            # Basic tint: the source's mean colour over this region (rounded).
            mean_color = cv2.mean(src, mask=region_mask.astype(np.uint8) * 255)[:3]
            new_pixels = np.full_like(tgt_pixels, np.rint(mean_color).astype(np.uint8))
        else:
            # Feed only the region's pixels (as N x 1 images) to the transfer so
            # its statistics are not dominated by out-of-region pixels.
            src_pixels = src[region_mask]
            new_pixels = color_transfer(tgt_pixels.reshape(-1, 1, 3), src_pixels.reshape(-1, 1, 3),
                                        method=transfer_method).reshape(-1, 3)

        out_img[region_mask] = cv2.addWeighted(tgt_pixels, 1 - blend_factor,
                                               new_pixels, blend_factor, 0)

    return out_img

#############################################
# Example Usage
#############################################

if __name__ == '__main__':
    # Example: Global LAB-based PCA color transfer applied on the whole image.
    output_global = process_image('input4.jpg', 'reference4.jpeg', transfer_method="pca_lab", region_based_transfer=False)
    # cv2.imwrite can return False instead of raising, so check it explicitly.
    if not cv2.imwrite('output_global_2.jpg', output_global):
        raise IOError("Could not write output_global_2.jpg")


In [None]:
import cv2
import numpy as np
from skimage.measure import label, regionprops
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
from PIL import Image, ImageTk

#############################################
# Processing Functions (Example: LAB-based PCA Transfer)
#############################################

def sqrtm(matrix, method="svd"):
    """
    Return the symmetric square root of a symmetric positive semi-definite matrix.

    The root is always computed from the symmetric eigendecomposition as
    V diag(sqrt(w)) V^T.  `method` is accepted only so that calls written for
    the multi-method version keep working; it is ignored.
    """
    eigvals, eigvecs = np.linalg.eigh(matrix)
    # Round-off can leave eigenvalues of a PSD matrix slightly negative;
    # np.sqrt would turn those into NaN, so clamp them to zero.
    eigvals = np.clip(eigvals, 0.0, None)
    return eigvecs @ np.diag(np.sqrt(eigvals)) @ eigvecs.T

def compute_mean_and_cov(image):
    """
    Per-channel mean and 3x3 channel covariance of a 3-channel image.

    Pixels are flattened to an (N, 3) float32 table; returns the length-3
    float32 mean and the sample covariance computed by np.cov (each column
    treated as one variable).
    """
    pixels = image.reshape(-1, 3).astype(np.float32)
    channel_means = pixels.mean(axis=0)
    channel_cov = np.cov(pixels, rowvar=False)
    return channel_means, channel_cov

def pca_transfer_lab(target, reference):
    """
    Applies PCA-based transfer on the chrominance (A and B) channels in LAB.
    Luminance (L) is preserved from the target.

    Parameters:
      target, reference: BGR uint8 images (their sizes may differ; only the
                         reference's chroma statistics are used).
    Returns:
      BGR uint8 image with the target's shape: the target's L channel combined
      with A/B remapped so that (before clipping) their mean and covariance
      match the reference's.
    """
    # Convert images to LAB color space (8-bit OpenCV encoding)
    target_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
    ref_lab = cv2.cvtColor(reference, cv2.COLOR_BGR2LAB).astype(np.float32)

    # Luminance is kept as-is; A/B are normalised to [0, 1] as (num_pixels x 2)
    L_t = target_lab[:, :, 0].astype(np.uint8)
    AB_t_flat = target_lab[:, :, 1:3].reshape(-1, 2) / 255.0
    AB_r_flat = ref_lab[:, :, 1:3].reshape(-1, 2) / 255.0

    mu_t = AB_t_flat.mean(axis=0)
    mu_r = AB_r_flat.mean(axis=0)
    # Small ridge: a grayscale target has constant A/B, i.e. a zero chroma
    # covariance, whose square root cannot be inverted.
    ridge = 1e-8 * np.eye(2)
    cov_t = np.cov(AB_t_flat, rowvar=False) + ridge
    cov_r = np.cov(AB_r_flat, rowvar=False) + ridge

    # T = C_r^{1/2} C_t^{-1/2} with symmetric square roots
    sqrt_cov_t = sqrtm(cov_t, method="eigen")
    sqrt_cov_r = sqrtm(cov_r, method="eigen")
    transform = sqrt_cov_r @ np.linalg.inv(sqrt_cov_t)

    AB_transformed = (AB_t_flat - mu_t) @ transform.T + mu_r
    AB_transformed = np.clip(AB_transformed, 0, 1).reshape(target_lab.shape[:2] + (2,))
    # Round (not truncate) back to 8-bit so chroma is not biased downward.
    AB_transformed = np.rint(AB_transformed * 255).astype(np.uint8)

    # Combine original L with transformed AB channels and convert back to BGR
    lab_transferred = cv2.merge((L_t, AB_transformed[:, :, 0], AB_transformed[:, :, 1]))
    return cv2.cvtColor(lab_transferred, cv2.COLOR_LAB2BGR)

def process_sequential(target, ref_list):
    """
    Chain LAB PCA colour transfers over a list of reference images.

    Starting from `target`, each step transfers the next reference's chroma
    onto the previous step's output.  A reference whose shape differs from the
    current image is first resized to the current image's width and height.

    Parameters:
      target: initial target image (BGR).
      ref_list: list of reference images (BGR).
    Returns:
      List of (label, image) tuples: ("Target", copy of target) followed by
      ("Output i", copy of the image after the i-th transfer).
    """
    history = [("Target", target.copy())]
    working = target.copy()
    for step_no, reference in enumerate(ref_list, start=1):
        if reference.shape != working.shape:
            height, width = working.shape[0], working.shape[1]
            reference = cv2.resize(reference, (width, height))
        working = pca_transfer_lab(working, reference)
        history.append((f"Output {step_no}", working.copy()))
    return history

#############################################
# Advanced Tkinter UI
#############################################

class ColorTransferApp(tk.Tk):
    """
    Main window for sequential LAB PCA colour transfer.

    The user loads a target image and one or more reference images.
    "Run Sequential Transfer" passes them to process_sequential and shows every
    intermediate result in a horizontally scrollable strip; "Save Final Output"
    writes the last result to disk.
    """

    def __init__(self):
        super().__init__()
        self.title("Sequential Color Transfer")
        self.geometry("1000x700")

        self.target_img = None          # BGR numpy array
        self.ref_imgs = []              # List of BGR numpy arrays
        self.intermediate_results = []  # List of (label, BGR image)

        self.create_widgets()

    def create_widgets(self):
        """Build the button bar, the target/reference previews and the results strip."""
        # Top frame for image selection
        top_frame = tk.Frame(self)
        top_frame.pack(padx=10, pady=10, fill=tk.X)

        self.btn_load_target = tk.Button(top_frame, text="Load Target Image", command=self.load_target)
        self.btn_load_target.pack(side=tk.LEFT, padx=5)

        self.btn_load_refs = tk.Button(top_frame, text="Load Reference Images", command=self.load_references)
        self.btn_load_refs.pack(side=tk.LEFT, padx=5)

        self.btn_run = tk.Button(top_frame, text="Run Sequential Transfer", command=self.run_transfer)
        self.btn_run.pack(side=tk.LEFT, padx=5)

        self.btn_save = tk.Button(top_frame, text="Save Final Output", command=self.save_final)
        self.btn_save.pack(side=tk.LEFT, padx=5)

        # Middle frame for displaying target and reference thumbnails
        mid_frame = tk.Frame(self)
        mid_frame.pack(padx=10, pady=5, fill=tk.X)

        # For a text-only Label, width/height count characters and lines, so
        # width=200 would make the placeholder wider than the window.  Giving
        # the label a 1x1 image (text composited on top) makes Tk measure in
        # pixels, matching the 200x200 thumbnails shown later.
        self._placeholder_img = tk.PhotoImage(width=1, height=1)
        self.lbl_target = tk.Label(mid_frame, text="Target Image Not Loaded", image=self._placeholder_img,
                                   compound=tk.CENTER, bd=2, relief=tk.SOLID, width=200, height=200)
        self.lbl_target.pack(side=tk.LEFT, padx=10)

        ref_frame = tk.Frame(mid_frame)
        ref_frame.pack(side=tk.LEFT, padx=10, fill=tk.Y)

        tk.Label(ref_frame, text="Reference Images").pack()
        self.lst_refs = tk.Listbox(ref_frame, width=30, height=10)
        self.lst_refs.pack(side=tk.LEFT, padx=5)

        self.ref_thumbs = []  # PhotoImage thumbnails of the loaded references (not displayed in this class)

        # Scrollable frame for intermediate outputs
        output_frame = tk.Frame(self)
        output_frame.pack(padx=10, pady=10, fill=tk.BOTH, expand=True)

        # Pack the scrollbar before the expanding canvas; packed afterwards it
        # only gets the space left beside the canvas and can end up hidden.
        scrollbar = tk.Scrollbar(output_frame, orient="horizontal")
        scrollbar.pack(side=tk.BOTTOM, fill=tk.X)

        canvas = tk.Canvas(output_frame, xscrollcommand=scrollbar.set)
        canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        scrollbar.configure(command=canvas.xview)

        self.results_frame = tk.Frame(canvas)
        canvas.create_window((0, 0), window=self.results_frame, anchor="nw")
        # Recompute the scroll region whenever the results strip changes size
        # (i.e. after outputs are added), not only when the canvas is resized.
        self.results_frame.bind("<Configure>", lambda e: canvas.configure(scrollregion=canvas.bbox("all")))

    def load_target(self):
        """Ask for a target image file, load it with OpenCV and preview it."""
        path = filedialog.askopenfilename(title="Select Target Image", filetypes=[("Image Files", "*.png *.jpg *.jpeg")])
        if path:
            img = cv2.imread(path)
            if img is None:
                messagebox.showerror("Error", "Failed to load target image.")
                return
            self.target_img = img
            self.display_image(self.lbl_target, img)

    def load_references(self):
        """Ask for reference images, replacing any previously loaded ones."""
        paths = filedialog.askopenfilenames(title="Select Reference Images", filetypes=[("Image Files", "*.png *.jpg *.jpeg")])
        if not paths:
            return
        self.ref_imgs = []
        self.lst_refs.delete(0, tk.END)
        self.ref_thumbs = []
        failed = []
        for path in paths:
            img = cv2.imread(path)
            if img is None:
                failed.append(path)
                continue
            self.ref_imgs.append(img)
            self.ref_thumbs.append(self.create_thumbnail(img, 100, 100))
            self.lst_refs.insert(tk.END, path.split("/")[-1])
        if failed:
            # Tell the user instead of silently dropping unreadable files.
            messagebox.showwarning("Skipped Images", "Could not load:\n" + "\n".join(failed))

    def run_transfer(self):
        """Run process_sequential on the loaded images and show all outputs."""
        if self.target_img is None:
            messagebox.showwarning("No Target", "Please load a target image.")
            return
        if not self.ref_imgs:
            messagebox.showwarning("No References", "Please load at least one reference image.")
            return

        try:
            self.intermediate_results = process_sequential(self.target_img, self.ref_imgs)
        except Exception as exc:
            # UI boundary: otherwise Tk only prints the traceback to stderr
            # and the user sees nothing happen.
            messagebox.showerror("Transfer Failed", str(exc))
            return
        self.display_intermediate_results()

    def display_intermediate_results(self):
        """Replace the results strip with one captioned thumbnail per result."""
        for widget in self.results_frame.winfo_children():
            widget.destroy()

        for label_text, img in self.intermediate_results:
            frame = tk.Frame(self.results_frame, bd=2, relief=tk.RIDGE)
            frame.pack(side=tk.LEFT, padx=5, pady=5)
            tk.Label(frame, text=label_text).pack()
            img_label = tk.Label(frame)
            img_label.pack()
            self.display_image(img_label, img, width=200, height=200)

    def save_final(self):
        """Write the last intermediate result to a user-chosen file."""
        if not self.intermediate_results:
            messagebox.showinfo("No Output", "No output available to save.")
            return
        _, final_img = self.intermediate_results[-1]
        path = filedialog.asksaveasfilename(title="Save Final Output", defaultextension=".jpg",
                                            filetypes=[("JPEG Files", "*.jpg"), ("PNG Files", "*.png")])
        if not path:
            return
        try:
            saved = cv2.imwrite(path, final_img)
        except cv2.error:
            saved = False
        # cv2.imwrite can report failure via its return value instead of raising.
        if saved:
            messagebox.showinfo("Saved", f"Final output saved to {path}")
        else:
            messagebox.showerror("Error", f"Could not save output to {path}")

    def display_image(self, label, cv_img, width=None, height=None):
        """Show BGR `cv_img` on `label`, shrunk to fit width x height (default 200x200)."""
        rgb = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(rgb)
        if width and height:
            pil_img.thumbnail((width, height))
        else:
            pil_img.thumbnail((200, 200))
        imgtk = ImageTk.PhotoImage(pil_img)
        # Clear any placeholder text so it is not drawn over the image.
        label.configure(image=imgtk, text="")
        label.image = imgtk  # keep a Python reference so the image is not garbage-collected

    def create_thumbnail(self, cv_img, width, height):
        """Return a PhotoImage of BGR `cv_img` shrunk to fit width x height."""
        rgb = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(rgb)
        pil_img.thumbnail((width, height))
        return ImageTk.PhotoImage(pil_img)

if __name__ == '__main__':
    # Create the main window, then block in Tk's event loop until it is closed.
    app = ColorTransferApp()
    app.mainloop()


In [4]:
import cv2
import numpy as np
from skimage.measure import label, regionprops
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk

#############################################
# Processing Functions (Example: LAB-based PCA Transfer)
#############################################

def sqrtm(matrix, method="svd"):
    """
    Return the symmetric square root of a symmetric positive semi-definite matrix.

    Computed by eigen-decomposition: V @ diag(sqrt(w)) @ V.T.

    matrix: symmetric (covariance-like) square array.
    method: accepted for call compatibility only; it is ignored and the
        eigen-decomposition is always used.
    """
    eigvals, eigvecs = np.linalg.eigh(matrix)
    # Covariances are PSD in theory, but round-off can give tiny negative
    # eigenvalues whose square root would be NaN; treat them as zero.
    eigvals = np.clip(eigvals, 0.0, None)
    # Scaling the columns of V equals V @ diag(sqrt(w)) without building the diagonal.
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T

def compute_mean_and_cov(image):
    """
    Return the per-channel mean and the 3x3 channel covariance of a 3-channel image.

    Every pixel is treated as one observation of three variables.
    """
    samples = image.reshape(-1, 3).astype(np.float32)
    return samples.mean(axis=0), np.cov(samples, rowvar=False)

def pca_transfer_lab(target, reference):
    """
    Applies PCA-based transfer on the chrominance (A and B) channels in LAB.
    Luminance (L) is preserved from the target.

    Each target chroma vector x (a, b scaled to [0, 1]) is mapped to
    sqrt(C_r) @ pinv(sqrt(C_t)) @ (x - mu_t) + mu_r, where mu/C are the
    mean/covariance of the target (t) and reference (r) chroma.

    target, reference: BGR images. Presumably 8-bit: the /255 scaling assumes
        OpenCV's 8-bit Lab encoding. Sizes may differ; only the reference's
        statistics are used.
    Returns a uint8 BGR image with the target's height and width.
    """
    target_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
    ref_lab = cv2.cvtColor(reference, cv2.COLOR_BGR2LAB).astype(np.float32)

    L_t = target_lab[:, :, 0]
    # Flatten each (a, b) plane into N x 2 samples in [0, 1].
    ab_t = target_lab[:, :, 1:3].reshape(-1, 2) / 255.0
    ab_r = ref_lab[:, :, 1:3].reshape(-1, 2) / 255.0

    mu_t = ab_t.mean(axis=0)
    mu_r = ab_r.mean(axis=0)
    sqrt_cov_t = sqrtm(np.cov(ab_t, rowvar=False), method="eigen")
    sqrt_cov_r = sqrtm(np.cov(ab_r, rowvar=False), method="eigen")

    # pinv instead of inv: a target with constant chroma (typically a greyscale
    # photo) has a singular covariance, where inv raises LinAlgError. pinv then
    # maps every pixel to the reference's mean chroma. For invertible matrices
    # pinv equals inv.
    transform = sqrt_cov_r @ np.linalg.pinv(sqrt_cov_t)
    ab_new = (ab_t - mu_t) @ transform.T + mu_r

    # Round rather than truncate: astype alone floors, biasing a/b downward, and
    # the bias accumulates when outputs are fed back in as new targets.
    ab_new = np.rint(np.clip(ab_new, 0.0, 1.0) * 255.0).astype(np.uint8)
    ab_new = ab_new.reshape(L_t.shape + (2,))
    L_u8 = np.clip(L_t, 0, 255).astype(np.uint8)
    lab_transferred = cv2.merge((L_u8, ab_new[:, :, 0], ab_new[:, :, 1]))
    return cv2.cvtColor(lab_transferred, cv2.COLOR_LAB2BGR)

def process_sequential(target, ref_list):
    """
    Chain the colour transfer through each reference in turn.

    target: starting BGR image.
    ref_list: BGR reference images, applied in list order. A reference whose
        shape differs from the current image is resized to its width/height first.
    Returns [("Target", copy of target), ("Output 1", ...), ...], where every
    output is a copy of the image after that step.
    """
    outputs = [("Target", target.copy())]
    image = target.copy()
    for step, reference in enumerate(ref_list, start=1):
        if reference.shape != image.shape:
            h, w = image.shape[:2]
            reference = cv2.resize(reference, (w, h))
        image = pca_transfer_lab(image, reference)
        outputs.append((f"Output {step}", image.copy()))
    return outputs

#############################################
# Advanced Tkinter UI with Reference Ordering
#############################################

class ColorTransferApp(tk.Tk):
    """
    Tk window that chains PCA colour transfers over an ordered list of references.

    The user loads one target image and one or more reference images, reorders
    the references in the listbox, and runs ``process_sequential``. Every
    intermediate result appears as a thumbnail in a horizontally scrollable
    strip, and the last one can be saved to disk.
    """

    def __init__(self):
        super().__init__()
        self.title("Sequential Color Transfer")
        self.geometry("1100x750")

        self.target_img = None          # BGR numpy array for target image
        self.ref_imgs = []              # BGR numpy arrays for references, kept in listbox (transfer) order
        self.ref_files = []             # Display names of ref_imgs, same order
        self.intermediate_results = []  # List of (label, BGR image)

        self.create_widgets()

    def create_widgets(self):
        """Build the button bar, target preview, reference list and scrollable results strip."""
        # Top frame for buttons
        top_frame = tk.Frame(self)
        top_frame.pack(padx=10, pady=10, fill=tk.X)
        for text, command in (
            ("Load Target Image", self.load_target),
            ("Load Reference Images", self.load_references),
            ("Run Sequential Transfer", self.run_transfer),
            ("Save Final Output", self.save_final),
        ):
            tk.Button(top_frame, text=text, command=command).pack(side=tk.LEFT, padx=5)

        # Middle frame for target and reference order
        mid_frame = tk.Frame(self)
        mid_frame.pack(padx=10, pady=5, fill=tk.X)

        # Target image display. For a text-only Label, width/height are counted in
        # characters/lines, so width=200 would make it ~200 characters wide. A 1x1
        # placeholder image with compound="center" makes Tk measure them in pixels
        # while the placeholder text is still shown.
        self._placeholder_img = tk.PhotoImage(master=self, width=1, height=1)
        self.lbl_target = tk.Label(mid_frame, text="Target Image Not Loaded",
                                   image=self._placeholder_img, compound=tk.CENTER,
                                   bd=2, relief=tk.SOLID, width=200, height=200)
        self.lbl_target.pack(side=tk.LEFT, padx=10)

        # Frame for reference list with reordering buttons
        ref_list_frame = tk.Frame(mid_frame)
        ref_list_frame.pack(side=tk.LEFT, padx=10, fill=tk.Y)

        tk.Label(ref_list_frame, text="Reference Images (Order)").pack()
        self.lst_refs = tk.Listbox(ref_list_frame, width=40, height=10)
        self.lst_refs.pack(side=tk.LEFT, padx=5, pady=5)

        btns_frame = tk.Frame(ref_list_frame)
        btns_frame.pack(side=tk.LEFT, padx=5)
        tk.Button(btns_frame, text="Move Up", command=self.move_up).pack(pady=2)
        tk.Button(btns_frame, text="Move Down", command=self.move_down).pack(pady=2)

        # Scrollable frame for intermediate outputs (horizontal)
        output_frame = tk.Frame(self)
        output_frame.pack(padx=10, pady=10, fill=tk.BOTH, expand=True)

        # Pack the scrollbar before the canvas so it claims the whole bottom edge.
        # Packed after an expanding canvas, it would only get a leftover sliver.
        out_scrollbar = tk.Scrollbar(output_frame, orient="horizontal")
        out_scrollbar.pack(side=tk.BOTTOM, fill=tk.X)
        self.out_canvas = tk.Canvas(output_frame, xscrollcommand=out_scrollbar.set)
        self.out_canvas.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
        out_scrollbar.configure(command=self.out_canvas.xview)

        self.results_frame = tk.Frame(self.out_canvas)
        self.out_canvas.create_window((0, 0), window=self.results_frame, anchor="nw")
        # Keep the scroll region in step with the strip's size.
        self.results_frame.bind("<Configure>", lambda e: self.out_canvas.configure(scrollregion=self.out_canvas.bbox("all")))

    def load_target(self):
        """Ask for a target image and show it in the preview label."""
        path = filedialog.askopenfilename(title="Select Target Image", filetypes=[("Image Files", "*.png *.jpg *.jpeg")])
        if not path:
            return
        img = cv2.imread(path)
        if img is None:
            messagebox.showerror("Error", "Failed to load target image.")
            return
        self.target_img = img
        self.display_image(self.lbl_target, img, 200, 200)

    def load_references(self):
        """Replace the reference list with the chosen files, warning about any that fail to load."""
        paths = filedialog.askopenfilenames(title="Select Reference Images", filetypes=[("Image Files", "*.png *.jpg *.jpeg")])
        if not paths:
            return
        self.ref_imgs = []
        self.ref_files = []
        self.lst_refs.delete(0, tk.END)
        skipped = []
        for path in paths:
            name = path.split("/")[-1]
            img = cv2.imread(path)
            if img is None:
                skipped.append(name)
                continue
            self.ref_imgs.append(img)
            self.ref_files.append(name)
            self.lst_refs.insert(tk.END, name)
        if skipped:
            messagebox.showwarning("Some References Skipped",
                                   "Could not load:\n" + "\n".join(skipped))

    def move_up(self):
        """Move the selected reference one step earlier in the transfer order."""
        self._move_selected(-1)

    def move_down(self):
        """Move the selected reference one step later in the transfer order."""
        self._move_selected(1)

    def _move_selected(self, offset):
        """Swap the selected reference with the one `offset` rows away, keeping listbox and lists in sync."""
        selected = self.lst_refs.curselection()
        if not selected:
            return
        index = selected[0]
        new_index = index + offset
        if not 0 <= new_index < self.lst_refs.size():
            return
        item_text = self.lst_refs.get(index)
        self.lst_refs.delete(index)
        self.lst_refs.insert(new_index, item_text)
        self.lst_refs.selection_clear(0, tk.END)
        self.lst_refs.selection_set(new_index)
        self.lst_refs.activate(new_index)
        for seq in (self.ref_imgs, self.ref_files):
            seq[index], seq[new_index] = seq[new_index], seq[index]

    def run_transfer(self):
        """Run the chained transfer and show every intermediate result."""
        if self.target_img is None:
            messagebox.showwarning("No Target", "Please load a target image.")
            return
        if not self.ref_imgs:
            messagebox.showwarning("No References", "Please load at least one reference image.")
            return

        # Process sequentially in the order from the listbox (ref_imgs mirrors it).
        try:
            self.intermediate_results = process_sequential(self.target_img, self.ref_imgs)
        except (np.linalg.LinAlgError, cv2.error) as exc:
            messagebox.showerror("Transfer Failed", f"Color transfer failed: {exc}")
            return
        self.display_intermediate_results()

    def display_intermediate_results(self):
        """Rebuild the results strip with one captioned thumbnail per result."""
        # Clear previous results
        for widget in self.results_frame.winfo_children():
            widget.destroy()

        for label_text, img in self.intermediate_results:
            frame = tk.Frame(self.results_frame, bd=2, relief=tk.RIDGE)
            frame.pack(side=tk.LEFT, padx=5, pady=5)
            tk.Label(frame, text=label_text).pack()
            img_label = tk.Label(frame)
            img_label.pack()
            self.display_image(img_label, img, 200, 200)

    def save_final(self):
        """
        Write the last result to a user-chosen file.

        Success is reported only if ``cv2.imwrite`` returns True; a False return
        or a raised ``cv2.error`` is shown as an error.
        """
        if not self.intermediate_results:
            messagebox.showinfo("No Output", "No output available to save.")
            return
        _, final_img = self.intermediate_results[-1]
        path = filedialog.asksaveasfilename(title="Save Final Output", defaultextension=".jpg",
                                            filetypes=[("JPEG Files", "*.jpg"), ("PNG Files", "*.png")])
        if not path:
            return  # dialog cancelled
        try:
            written = cv2.imwrite(path, final_img)
        except cv2.error:
            written = False
        if written:
            messagebox.showinfo("Saved", f"Final output saved to {path}")
        else:
            messagebox.showerror("Save Failed", f"Could not write the output image to {path}")

    def display_image(self, label, cv_img, width, height):
        """Show a BGR image on `label`, shrunk to fit within width x height pixels."""
        rgb = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(rgb)
        pil_img.thumbnail((width, height))
        imgtk = ImageTk.PhotoImage(pil_img)
        # compound=NONE shows only the image, hiding any placeholder text; width and
        # height stay in pixels because an image is set.
        label.configure(image=imgtk, compound=tk.NONE)
        label.image = imgtk  # keep a reference

if __name__ == '__main__':
    # Create the sequential-transfer window and run Tk's event loop until it closes.
    app = ColorTransferApp()
    app.mainloop()


In [4]:
import cv2
import numpy as np
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
from sklearn.cluster import KMeans
import warnings
# Silence UserWarning/RuntimeWarning raised from inside scikit-learn (used for
# KMeans in this module) without suppressing them process-wide. A global
# RuntimeWarning filter would also hide real numeric problems in this module's
# own code, such as NaNs from a bad matrix square root.
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
warnings.filterwarnings("ignore", category=RuntimeWarning, module="sklearn")
#############################################
# Processing Functions (Example: LAB-based PCA Transfer)
#############################################

def sqrtm(matrix, method="svd"):
    """
    Return the symmetric square root of a symmetric positive semi-definite matrix.

    Computed by eigen-decomposition: V @ diag(sqrt(w)) @ V.T.

    matrix: symmetric (covariance-like) square array.
    method: accepted for call compatibility only; it is ignored and the
        eigen-decomposition is always used.
    """
    eigvals, eigvecs = np.linalg.eigh(matrix)
    # Covariances are PSD in theory, but round-off can give tiny negative
    # eigenvalues whose square root would be NaN; treat them as zero.
    eigvals = np.clip(eigvals, 0.0, None)
    # Scaling the columns of V equals V @ diag(sqrt(w)) without building the diagonal.
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T

def compute_mean_and_cov(image):
    """
    Return the per-channel mean and the 3x3 channel covariance of a 3-channel image.

    Every pixel is treated as one observation of three variables.
    """
    samples = image.reshape(-1, 3).astype(np.float32)
    return samples.mean(axis=0), np.cov(samples, rowvar=False)

def pca_transfer_lab(target, reference):
    """
    Applies PCA-based transfer on the chrominance (A and B) channels in LAB.
    Luminance (L) is preserved from the target.

    Each target chroma vector x (a, b scaled to [0, 1]) is mapped to
    sqrt(C_r) @ pinv(sqrt(C_t)) @ (x - mu_t) + mu_r, where mu/C are the
    mean/covariance of the target (t) and reference (r) chroma.

    target, reference: BGR images. Presumably 8-bit: the /255 scaling assumes
        OpenCV's 8-bit Lab encoding. Sizes may differ; only the reference's
        statistics are used.
    Returns a uint8 BGR image with the target's height and width.
    """
    target_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
    ref_lab = cv2.cvtColor(reference, cv2.COLOR_BGR2LAB).astype(np.float32)

    L_t = target_lab[:, :, 0]
    # Flatten each (a, b) plane into N x 2 samples in [0, 1].
    ab_t = target_lab[:, :, 1:3].reshape(-1, 2) / 255.0
    ab_r = ref_lab[:, :, 1:3].reshape(-1, 2) / 255.0

    mu_t = ab_t.mean(axis=0)
    mu_r = ab_r.mean(axis=0)
    sqrt_cov_t = sqrtm(np.cov(ab_t, rowvar=False), method="eigen")
    sqrt_cov_r = sqrtm(np.cov(ab_r, rowvar=False), method="eigen")

    # pinv instead of inv: a target with constant chroma (typically a greyscale
    # photo) has a singular covariance, where inv raises LinAlgError. pinv then
    # maps every pixel to the reference's mean chroma. For invertible matrices
    # pinv equals inv.
    transform = sqrt_cov_r @ np.linalg.pinv(sqrt_cov_t)
    ab_new = (ab_t - mu_t) @ transform.T + mu_r

    # Round rather than truncate: astype alone floors, biasing a/b downward, and
    # the bias accumulates when outputs are fed back in as new targets.
    ab_new = np.rint(np.clip(ab_new, 0.0, 1.0) * 255.0).astype(np.uint8)
    ab_new = ab_new.reshape(L_t.shape + (2,))
    L_u8 = np.clip(L_t, 0, 255).astype(np.uint8)
    lab_transferred = cv2.merge((L_u8, ab_new[:, :, 0], ab_new[:, :, 1]))
    return cv2.cvtColor(lab_transferred, cv2.COLOR_LAB2BGR)

def process_sequential(target, ref_list):
    """
    Chain the colour transfer through each reference in turn.

    target: starting BGR image.
    ref_list: BGR reference images, applied in list order. A reference whose
        shape differs from the current image is resized to its width/height first.
    Returns [("Target", copy of target), ("Output 1", ...), ...], where every
    output is a copy of the image after that step.
    """
    outputs = [("Target", target.copy())]
    image = target.copy()
    for step, reference in enumerate(ref_list, start=1):
        if reference.shape != image.shape:
            h, w = image.shape[:2]
            reference = cv2.resize(reference, (w, h))
        image = pca_transfer_lab(image, reference)
        outputs.append((f"Output {step}", image.copy()))
    return outputs


#############################################
# Utility / Processing Functions
#############################################

def cv2_to_tk(cv_img, maxsize=(500,500)):
    """Convert a BGR OpenCV image to a Tk PhotoImage shrunk to fit inside `maxsize` (w, h)."""
    picture = Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
    picture.thumbnail(maxsize)
    return ImageTk.PhotoImage(picture)

def create_polygon_mask(points, img_shape, disp_dims):
    """
    Rasterise a polygon drawn in display coordinates into a full-resolution mask.

    points: (x, y) vertices measured in a display of size disp_dims.
    img_shape: (img_height, img_width) of the target image.
    disp_dims: (display_width, display_height).
    Each vertex is scaled independently in x and y and truncated to int.
    Returns a bool array of shape (img_height, img_width), True inside the polygon.
    """
    disp_w, disp_h = disp_dims
    img_h, img_w = img_shape
    sx, sy = img_w / disp_w, img_h / disp_h
    polygon = np.array([(int(x * sx), int(y * sy)) for x, y in points], dtype=np.int32)
    polygon = polygon.reshape((-1, 1, 2))  # layout expected by cv2.fillPoly
    canvas = np.zeros((img_h, img_w), dtype=np.uint8)
    cv2.fillPoly(canvas, [polygon], 255)
    return canvas.astype(bool)

def get_top_colors(image, k=10):
    """
    Uses KMeans to extract top k colors from the image.

    Returns a list of colors as integer lists [B, G, R] (cluster centres
    rounded to the nearest integer). If KMeans cannot form k clusters (fewer
    pixels than k, or an invalid k), returns the image's unique colors instead,
    in the same integer format.
    """
    pixels = image.reshape((-1, 3)).astype(np.float32)
    # KMeans needs at least one sample per cluster; skip it when that can't hold.
    if 1 <= k <= pixels.shape[0]:
        try:
            kmeans = KMeans(n_clusters=k, random_state=0)
            kmeans.fit(pixels)
        except ValueError:
            pass  # parameter rejected by KMeans: use the unique-colour fallback below
        else:
            centers = np.clip(np.rint(kmeans.cluster_centers_), 0, 255)
            return centers.astype(np.uint8).tolist()
    return np.unique(pixels, axis=0).astype(np.uint8).tolist()

def create_feathered_alpha(mask, feather=15):
    """
    Compute a smooth alpha matte from a region mask using a distance transform.

    Pixels inside the region get alpha 1. Outside, alpha falls off linearly to 0
    over `feather` pixels of Euclidean distance from the region. The result is
    then Gaussian-smoothed to soften the kink at the boundary.

    mask: 2-D array; True / non-zero entries are inside (bool and 0/255 both work).
    feather: falloff width in pixels; values below 1 give a hard (unfeathered) matte.
    Returns a float32 array in [0, 1] with the mask's shape.
    """
    inside = np.asarray(mask).astype(bool)
    if feather < 1 or not inside.any():
        return inside.astype(np.float32)
    # distanceTransform measures each pixel's distance to the nearest ZERO pixel,
    # so encode the region as 0 and the background as 255.
    background = np.where(inside, 0, 255).astype(np.uint8)
    dist = cv2.distanceTransform(background, cv2.DIST_L2, 5)
    alpha = np.clip(1.0 - dist / float(feather), 0.0, 1.0).astype(np.float32)
    ksize = int(feather) | 1  # GaussianBlur requires an odd kernel size
    alpha = cv2.GaussianBlur(alpha, (ksize, ksize), 0)
    return np.clip(alpha, 0, 1)

def apply_tint_feathering(img_bgr, mask, tint_color, blend=0.6, feather=15):
    """
    Mix a flat tint into img_bgr over the masked region, with soft edges.

    The per-pixel mixing weight is blend * alpha, where alpha comes from
    create_feathered_alpha(mask, feather). Returns a uint8 image clipped to [0, 255].
    """
    weight = blend * create_feathered_alpha(mask, feather=feather)[..., None]
    tint = np.full_like(img_bgr, tint_color, dtype=np.float32)
    base = img_bgr.astype(np.float32)
    mixed = (1 - weight) * base + weight * tint
    return np.clip(mixed, 0, 255).astype(np.uint8)

def compute_mean_color(img_bgr):
    """Return the average colour of a 3-channel image as a plain [B, G, R] list of floats."""
    flat = img_bgr.reshape(-1, 3)
    return flat.mean(axis=0).tolist()

def apply_local_mapping(target_img, target_mask, mapping, threshold=40):
    """
    Blend masked pixels toward mapped colours.

    For each pixel in target_img within target_mask, find the first key of
    `mapping` (dict iteration order) whose Euclidean colour distance d to the
    pixel is below `threshold`. Blend that pixel toward the mapped colour with
    weight w = 1 - d / threshold: an exact match becomes the mapped colour, and
    the weight fades to 0 at the threshold. Pixels with no close key are left
    unchanged.

    'mapping' is a dict with keys and values as tuples: {target_color: ref_color}.
    target_mask: array broadcastable to target_img's height x width; non-zero
        entries select pixels.
    Returns a new uint8 image; target_img is not modified.

    Vectorised over pixels, with one pass per mapping entry, instead of a Python
    loop over every pixel x entry pair.
    """
    out = target_img.astype(np.float32)  # astype copies, so target_img is untouched
    region = np.asarray(target_mask).astype(bool)
    if not mapping or not region.any():
        return np.clip(out, 0, 255).astype(np.uint8)

    pixels = out[region]                               # (K, C) float32 selected pixels
    blended = pixels.copy()
    pending = np.ones(pixels.shape[0], dtype=bool)     # pixels not yet claimed by an entry
    for t_color, r_color in mapping.items():
        if not pending.any():
            break
        dist = np.linalg.norm(pixels - np.asarray(t_color, dtype=np.float32), axis=1)
        hit = pending & (dist < threshold)
        if not hit.any():
            continue
        w = (1 - dist[hit] / threshold)[:, None]
        blended[hit] = (1 - w) * pixels[hit] + w * np.asarray(r_color, dtype=np.float64)
        pending &= ~hit  # first matching entry wins, as in a per-pixel `break`
    out[region] = blended
    return np.clip(out, 0, 255).astype(np.uint8)

#############################################
# Interactive UI Application
#############################################

class InteractiveLocalMappingApp(tk.Tk):
    """
    Tk window for a global PCA colour transfer followed by user-guided local recolouring.

    Workflow, mirrored in the status bar:
      1. load a target and a reference image;
      2. optionally apply ``pca_transfer_lab`` to the whole target;
      3. click polygon vertices on each canvas to outline a region;
      4. extract 4 dominant colours from each region;
      5. click a target swatch, then a reference swatch, to pair them;
      6. blend target pixels near each paired colour toward its partner.
    """

    def __init__(self):
        super().__init__()
        self.title("Interactive Local Color Mapping")
        self.geometry("1400x900")

        # Images (BGR)
        self.target_img = None
        self.ref_img = None
        self.processed_img = None  # target after global transfer and/or local mapping

        # Polygon vertices (canvas coordinates) and the canvas ids of their outlines
        self.target_poly_points = []
        self.ref_poly_points = []
        self.target_poly_id = None
        self.ref_poly_id = None

        # (width, height) of each image as drawn on its canvas; used to scale
        # polygon vertices back to full-resolution image coordinates.
        self.target_disp_dims = None
        self.ref_disp_dims = None

        # Local masks: bool arrays in full-resolution image coordinates
        self.target_mask = None
        self.ref_mask = None

        # Dominant colors extracted from local regions.
        self.target_local_colors = []
        self.ref_local_colors = []

        # Mapping dictionary: keys and values as tuples: (target_color) -> (ref_color)
        self.mapping = {}

        # Target swatch clicked first, waiting for its reference partner.
        self.temp_target_color = None

        self.create_widgets()

    def create_widgets(self):
        """Build the button bar, both drawing canvases, the palette/mapping area and the status bar."""
        top_frame = tk.Frame(self)
        top_frame.pack(padx=10, pady=5, fill=tk.X)
        for text, command in (
            ("Load Target Image", self.load_target),
            ("Load Reference Image", self.load_reference),
            ("Apply Global Transfer", self.apply_global_transfer),
            ("Finish Target Polygon", self.finish_target_polygon),
            ("Finish Reference Polygon", self.finish_ref_polygon),
            ("Extract Local Colors", self.extract_local_colors),
            ("Apply Local Mapping", self.apply_local_mapping),
        ):
            tk.Button(top_frame, text=text, command=command).pack(side=tk.LEFT, padx=5)

        main_frame = tk.Frame(self)
        main_frame.pack(padx=10, pady=10, fill=tk.BOTH, expand=True)

        # Left: Target canvas for polygon selection.
        target_frame = tk.Frame(main_frame)
        target_frame.pack(side=tk.LEFT, padx=10, pady=10)
        tk.Label(target_frame, text="Target Image").pack()
        self.target_canvas = tk.Canvas(target_frame, bg="gray", width=600, height=600)
        self.target_canvas.pack()
        self.target_canvas.bind("<Button-1>", self.on_target_click)

        # Right: Reference canvas for polygon selection.
        ref_frame = tk.Frame(main_frame)
        ref_frame.pack(side=tk.LEFT, padx=10, pady=10)
        tk.Label(ref_frame, text="Reference Image").pack()
        self.ref_canvas = tk.Canvas(ref_frame, bg="gray", width=600, height=600)
        self.ref_canvas.pack()
        self.ref_canvas.bind("<Button-1>", self.on_ref_click)

        # Bottom: Local palettes and mapping display.
        bottom_frame = tk.Frame(self)
        bottom_frame.pack(padx=10, pady=5, fill=tk.X)

        tk.Label(bottom_frame, text="Target Local Colors:").grid(row=0, column=0, sticky="w")
        self.target_palette_frame = tk.Frame(bottom_frame)
        self.target_palette_frame.grid(row=1, column=0, padx=5)

        tk.Label(bottom_frame, text="Reference Local Colors:").grid(row=0, column=1, sticky="w")
        self.ref_palette_frame = tk.Frame(bottom_frame)
        self.ref_palette_frame.grid(row=1, column=1, padx=5)

        tk.Label(bottom_frame, text="Mappings (click target then reference):").grid(row=0, column=2, sticky="w")
        self.mapping_frame = tk.Frame(bottom_frame)
        self.mapping_frame.grid(row=1, column=2, padx=5)

        self.status_label = tk.Label(self, text="Load images to begin...", bd=1, relief=tk.SUNKEN, anchor=tk.W)
        self.status_label.pack(fill=tk.X)

    def load_target(self):
        """Ask for a target image, show it, and discard the previous target polygon and mask."""
        path = filedialog.askopenfilename(title="Select Target Image", filetypes=[("Image Files", "*.png *.jpg *.jpeg")])
        if not path:
            return
        img = cv2.imread(path)
        if img is None:
            messagebox.showerror("Error", "Failed to load target image.")
            return
        self.target_img = img
        self.processed_img = img.copy()
        # display_on_canvas clears the canvas, so the old polygon items are already gone.
        self.target_disp_dims = self.display_on_canvas(self.target_canvas, self.processed_img)
        self.target_poly_points = []
        self.target_poly_id = None
        self.target_mask = None  # belonged to the previous (possibly differently sized) image
        self.status_label.config(text="Target image loaded. Now load reference image.")

    def load_reference(self):
        """Ask for a reference image, show it, and discard the previous reference polygon and mask."""
        path = filedialog.askopenfilename(title="Select Reference Image", filetypes=[("Image Files", "*.png *.jpg *.jpeg")])
        if not path:
            return
        img = cv2.imread(path)
        if img is None:
            messagebox.showerror("Error", "Failed to load reference image.")
            return
        self.ref_img = img
        self.ref_disp_dims = self.display_on_canvas(self.ref_canvas, self.ref_img)
        self.ref_poly_points = []
        self.ref_poly_id = None
        self.ref_mask = None
        self.status_label.config(text="Reference image loaded. Click 'Apply Global Transfer' when ready.")

    def apply_global_transfer(self):
        """Replace the working target with the PCA chroma transfer from the reference."""
        if self.target_img is None or self.ref_img is None:
            messagebox.showwarning("Missing Images", "Please load both target and reference images.")
            return
        try:
            self.processed_img = pca_transfer_lab(self.target_img, self.ref_img)
        except (np.linalg.LinAlgError, cv2.error) as exc:
            messagebox.showerror("Transfer Failed", f"Global color transfer failed: {exc}")
            return
        self.target_disp_dims = self.display_on_canvas(self.target_canvas, self.processed_img)
        # Redrawing wiped the canvas; restore any polygon the user already drew.
        self.target_poly_id = self._redraw_polygon(self.target_canvas, self.target_poly_points, "red")
        self.status_label.config(text="Global transfer applied. Now draw polygons on both images.")

    def on_target_click(self, event):
        """Add a target polygon vertex at the click position and redraw the outline."""
        self.target_poly_points.append((event.x, event.y))
        self._draw_vertex(self.target_canvas, event.x, event.y, "red")
        if self.target_poly_id:
            self.target_canvas.delete(self.target_poly_id)
        self.target_poly_id = self.target_canvas.create_polygon(self.target_poly_points, outline="red", fill="", width=2)
        self.status_label.config(text=f"Target polygon: {self.target_poly_points}")

    def on_ref_click(self, event):
        """Add a reference polygon vertex at the click position and redraw the outline."""
        self.ref_poly_points.append((event.x, event.y))
        self._draw_vertex(self.ref_canvas, event.x, event.y, "blue")
        if self.ref_poly_id:
            self.ref_canvas.delete(self.ref_poly_id)
        self.ref_poly_id = self.ref_canvas.create_polygon(self.ref_poly_points, outline="blue", fill="", width=2)
        self.status_label.config(text=f"Reference polygon: {self.ref_poly_points}")

    @staticmethod
    def _draw_vertex(canvas, x, y, color, r=3):
        """Mark a polygon vertex with a small filled circle of radius r."""
        canvas.create_oval(x - r, y - r, x + r, y + r, fill=color)

    def _redraw_polygon(self, canvas, points, color):
        """Draw vertex markers and the outline for points; return the outline's item id, or None if there are no points."""
        for x, y in points:
            self._draw_vertex(canvas, x, y, color)
        if not points:
            return None
        return canvas.create_polygon(points, outline=color, fill="", width=2)

    @staticmethod
    def _to_image_coords(canvas, points):
        """Shift canvas click points so (0, 0) is the top-left corner of the image drawn on canvas."""
        x0, y0 = getattr(canvas, "image_offset", (0, 0))
        return [(x - x0, y - y0) for x, y in points]

    def finish_target_polygon(self):
        """Check that the target polygon has at least 3 vertices."""
        if len(self.target_poly_points) < 3:
            messagebox.showwarning("Insufficient Points", "Please select at least 3 points on the target image.")
            return
        self.status_label.config(text="Target polygon finished.")

    def finish_ref_polygon(self):
        """Check that the reference polygon has at least 3 vertices."""
        if len(self.ref_poly_points) < 3:
            messagebox.showwarning("Insufficient Points", "Please select at least 3 points on the reference image.")
            return
        self.status_label.config(text="Reference polygon finished.")

    def extract_local_colors(self):
        """Build both polygon masks and show the 4 dominant colours inside each as swatches."""
        if self.processed_img is None or self.ref_img is None:
            messagebox.showwarning("Missing Images", "Please load both target and reference images.")
            return
        if len(self.target_poly_points) < 3 or len(self.ref_poly_points) < 3:
            messagebox.showwarning("Missing Polygon", "Please finish drawing polygons on both images.")
            return
        # Clicks are relative to the canvas, but the image is a centred thumbnail:
        # shift by its offset, then scale by its drawn size (not the canvas size).
        self.target_mask = create_polygon_mask(
            self._to_image_coords(self.target_canvas, self.target_poly_points),
            self.processed_img.shape[:2], self.target_disp_dims)
        self.ref_mask = create_polygon_mask(
            self._to_image_coords(self.ref_canvas, self.ref_poly_points),
            self.ref_img.shape[:2], self.ref_disp_dims)
        self.target_local_colors = extract_dominant_colors(self.processed_img, self.target_mask, n_colors=4)
        self.ref_local_colors = extract_dominant_colors(self.ref_img, self.ref_mask, n_colors=4)
        self.display_local_palettes()
        self.mapping = {}
        self.temp_target_color = None
        if not self.target_local_colors or not self.ref_local_colors:
            self.status_label.config(text="A polygon covers too few pixels to extract colors; draw a larger region.")
            return
        self.status_label.config(text="Local colors extracted. Click a target swatch then a reference swatch to map them.")

    def display_local_palettes(self):
        """Show clickable swatches for both local palettes and clear the mapping list."""
        self._fill_palette(self.target_palette_frame, self.target_local_colors, self.on_target_color_select)
        self._fill_palette(self.ref_palette_frame, self.ref_local_colors, self.on_ref_color_select)
        for widget in self.mapping_frame.winfo_children():
            widget.destroy()

    @staticmethod
    def _fill_palette(frame, colors, on_select):
        """Replace frame's children with one swatch per BGR colour; clicking a swatch calls on_select(color)."""
        for widget in frame.winfo_children():
            widget.destroy()
        for color in colors:
            # Tk expects #RRGGBB; colours here are in OpenCV's B, G, R order.
            hex_color = f"#{color[2]:02x}{color[1]:02x}{color[0]:02x}"
            swatch = tk.Label(frame, bg=hex_color, width=4, height=2, relief=tk.RAISED, bd=2)
            swatch.pack(side=tk.LEFT, padx=5)
            swatch.bind("<Button-1>", lambda e, col=color: on_select(col))

    def on_target_color_select(self, color):
        """Remember the clicked target swatch until a reference swatch is clicked."""
        # Convert list to tuple for hashing.
        self.temp_target_color = tuple(color)
        self.status_label.config(text=f"Selected target color: {self.temp_target_color}. Now select corresponding reference color.")

    def on_ref_color_select(self, color):
        """Pair the pending target swatch with the clicked reference swatch."""
        if self.temp_target_color is None:
            self.status_label.config(text="Select a target color first.")
            return
        ref_color = tuple(color)
        self.mapping[self.temp_target_color] = ref_color
        # Every pair in self.mapping is applied, so list them all, not only the newest.
        self._show_mappings()
        self.status_label.config(text=f"Mapping added: {self.temp_target_color} -> {ref_color}.")
        self.temp_target_color = None

    def _show_mappings(self):
        """Rebuild the mapping panel with one line per target -> reference pair."""
        for widget in self.mapping_frame.winfo_children():
            widget.destroy()
        for t_color, r_color in self.mapping.items():
            tk.Label(self.mapping_frame, text=f"Mapping: {t_color} -> {r_color}").pack()

    def apply_local_mapping(self):
        """Blend target pixels inside the target polygon toward their mapped reference colours."""
        if not self.mapping:
            messagebox.showwarning("No Mapping", "Please map at least one color pair.")
            return
        if self.target_mask is None:
            messagebox.showerror("Error", "Local target mask not available.")
            return
        # The name below resolves to the module-level function, not this method.
        new_img = apply_local_mapping(self.processed_img, self.target_mask, self.mapping, threshold=40)
        self.processed_img = new_img
        self.target_disp_dims = self.display_on_canvas(self.target_canvas, self.processed_img)
        self.target_poly_id = self._redraw_polygon(self.target_canvas, self.target_poly_points, "red")
        self.status_label.config(text="Local mapping applied.")

    def display_on_canvas(self, canvas, cv_img):
        """
        Draw cv_img on canvas, shrunk to fit and centred, replacing everything on it.

        Stores the drawn image's top-left offset on ``canvas.image_offset`` and
        returns its drawn (width, height). Together these map canvas clicks back
        to image pixels.
        """
        cw, ch = canvas.winfo_width(), canvas.winfo_height()
        photo = cv2_to_tk(cv_img, maxsize=(cw, ch))
        canvas.photo = photo  # keep a reference so Tk's image isn't garbage-collected
        canvas.delete("all")
        w, h = photo.width(), photo.height()
        x0, y0 = (cw - w) // 2, (ch - h) // 2
        canvas.create_image(x0, y0, image=photo, anchor="nw")
        canvas.image_offset = (x0, y0)
        return w, h
    
#############################################
# Additional Functions
#############################################

def extract_dominant_colors(img, mask, n_colors=8):
    """
    Return the n_colors KMeans cluster centres of the pixels selected by mask.

    img: H x W x 3 image; the channel order is kept (BGR in this file).
    mask: H x W array; True / non-zero entries select pixels. Both bool and
        0/255 masks work: the mask is coerced to bool, because an integer array
        used as an index would do fancy indexing instead of selection.
    Returns a list of [c0, c1, c2] integer lists (centres rounded to the
    nearest integer), or [] when fewer than n_colors pixels are selected,
    since KMeans needs at least one sample per cluster.
    """
    pixels = img[np.asarray(mask).astype(bool)].reshape(-1, 3)
    if pixels.shape[0] < n_colors:
        return []
    kmeans = KMeans(n_clusters=n_colors, random_state=0)
    kmeans.fit(pixels)
    # Round rather than truncate so swatches are not biased darker.
    centers = np.clip(np.rint(kmeans.cluster_centers_), 0, 255)
    return centers.astype(np.uint8).tolist()

#############################################
# Main
#############################################

if __name__ == '__main__':
    # Launch the interactive local-mapping window. get_top_colors is defined at
    # module level and is deliberately not redefined here: a redefinition inside
    # this guard would silently replace the module-level version for the whole
    # module.
    app = InteractiveLocalMappingApp()
    app.mainloop()


