In [1]:
import os
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision.transforms as T
from torchvision import models
from sklearn.neighbors import KDTree

# ---------------------- 1. Load Mask R-CNN Model ----------------------
def load_mask_rcnn():
    """Load a pre-trained Mask R-CNN (ResNet-50 FPN backbone) in inference mode.

    Uses torchvision's ``weights=`` API when it is available (torchvision >= 0.13)
    and falls back to the deprecated ``pretrained=True`` flag on older releases.

    Returns:
        torch.nn.Module: the detection model, already switched to ``eval()`` mode.
    """
    weights_enum = getattr(models.detection, "MaskRCNN_ResNet50_FPN_Weights", None)
    if weights_enum is not None:
        model = models.detection.maskrcnn_resnet50_fpn(weights=weights_enum.DEFAULT)
    else:
        # Older torchvision only understands the boolean flag.
        model = models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    model.eval()
    return model

# Smoke test: builds the model and lets the notebook display its architecture
# (output below). The returned model is not assigned; main() loads its own copy.
load_mask_rcnn()



MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(in

In [2]:

# ---------------------- 2. Foreground-Background Segmentation ----------------------
def segment_foreground_background(image, model, threshold=0.5):
    """Return a binary foreground mask for ``image`` using Mask R-CNN.

    Only the single highest-scoring detection is used.

    Args:
        image: HxWx3 uint8 array in OpenCV's BGR channel order (as returned by
            ``cv2.imread``); a 2-D uint8 grayscale array is also accepted.
        model: detection model in eval mode, called as ``model(batch)[0]`` and
            expected to return a dict with ``'masks'`` (N x 1 x H x W soft masks)
            and ``'scores'`` (N,).
        threshold: probability above which a mask pixel counts as foreground.

    Returns:
        np.ndarray: HxW uint8 mask, 1 = foreground, 0 = background. All zeros
        when the model detects nothing.
    """
    if image.ndim == 2:
        rgb = np.stack([image] * 3, axis=-1)
    else:
        # Torchvision detection models are trained on RGB; cv2.imread yields BGR.
        rgb = image[..., ::-1]

    # Same result as ToPILImage() + ToTensor() for uint8 input:
    # 1 x C x H x W float tensor scaled to [0, 1].
    input_tensor = (
        torch.from_numpy(np.ascontiguousarray(rgb))
        .permute(2, 0, 1)
        .float()
        .div(255.0)
        .unsqueeze(0)
    )

    with torch.no_grad():
        predictions = model(input_tensor)[0]

    if len(predictions['masks']) == 0:
        # Nothing detected: everything is background.
        return np.zeros(image.shape[:2], dtype=np.uint8)

    best_mask_idx = torch.argmax(predictions['scores']).item()
    # Threshold the soft probabilities *before* any integer cast: casting first
    # (e.g. .byte()) truncates every probability below 1.0 to 0.
    soft_mask = predictions['masks'][best_mask_idx, 0].detach().cpu().numpy()
    return (soft_mask > threshold).astype(np.uint8)

# ---------------------- 3. Palette-based Clustering ----------------------
def generate_palette(image_lab, bins=100, max_colors=32, threshold=30, radius=3):
    """Build a color palette from per-channel histogram peaks of a Lab image.

    Each of the three channels is histogrammed independently over [0, 255].
    Peak *bin indices* from ``peak_search`` are converted to the corresponding
    bin-centre values, so palette entries are in the same value scale as the
    image. Every (L, a, b) combination of peak values forms a candidate color.
    The candidates are then reduced to at most ``max_colors`` by ``merge_peaks``.

    Args:
        image_lab: HxWx3 Lab image, assumed on OpenCV's 8-bit 0-255 scale
            (TODO confirm for non-uint8 input).
        bins: number of histogram bins per channel.
        max_colors: upper bound on the returned palette size.
        threshold: minimum bin count for a bin to qualify as a peak.
        radius: half-width (in bins) of the local-maximum window.

    Returns:
        np.ndarray: (K, 3) float32 palette with K <= max_colors.
    """
    channel_peak_values = []
    for channel in range(3):
        hist, edges = np.histogram(image_lab[..., channel], bins=bins, range=[0, 255])
        centers = (edges[:-1] + edges[1:]) / 2.0
        peak_bins = np.asarray(peak_search(hist, threshold, radius), dtype=int)
        # peak_search reports bin indices; clip so any fallback index outside
        # the histogram lands on the nearest boundary bin.
        peak_bins = np.unique(np.clip(peak_bins, 0, len(centers) - 1))
        channel_peak_values.append(centers[peak_bins])
    peaks_l, peaks_a, peaks_b = channel_peak_values

    # Cartesian product of per-channel peaks forms the initial palette.
    raw_palette = np.array(
        [[l, a, b] for l in peaks_l for a in peaks_a for b in peaks_b],
        dtype=np.float32,
    )

    # Merge close colors and cap the palette size.
    return merge_peaks(raw_palette, max_colors)

def peak_search(hist, threshold, radius):
    """Find local-maximum bins in a histogram.

    Bin ``i`` is a peak when both conditions hold:

    * its count exceeds ``threshold``;
    * its count equals the maximum of the window ``hist[i - radius : i + radius + 1]``.

    The window is clipped at the histogram edges, so peaks in the first or last
    ``radius`` bins are found too.

    Args:
        hist: sequence of bin counts (list or 1-D array).
        threshold: minimum count for a peak.
        radius: half-width of the local-maximum window, in bins.

    Returns:
        list[int]: ascending bin indices of the peaks. If no bin qualifies, the
        first and last bin indices are returned so callers always get a
        non-empty list.
    """
    n_bins = len(hist)
    peaks = []
    for i in range(n_bins):
        lo = max(i - radius, 0)
        hi = min(i + radius + 1, n_bins)
        if hist[i] > threshold and hist[i] == max(hist[lo:hi]):
            peaks.append(i)
    if peaks:
        return peaks
    # Fallback is expressed in bin indices, the same unit as real peaks.
    return sorted({0, max(n_bins - 1, 0)})

def merge_peaks(raw_palette, max_colors, merge_radius=10.0):
    """Greedily merge nearby palette colors until at most ``max_colors`` remain.

    A palette already within budget is returned unchanged (the same object).
    Otherwise colors are visited in input order. Each color not yet absorbed
    seeds a cluster, and the mean of all colors within ``merge_radius``
    (Euclidean, via a KD-tree) of the seed becomes one merged color. Every
    color found this way is marked as used, so it cannot seed another cluster.

    NOTE(review): once ``max_colors`` clusters exist the remaining colors are
    dropped rather than merged, so the result favors colors that appear early
    in ``raw_palette``.

    Args:
        raw_palette: (N, 3) array of candidate colors.
        max_colors: maximum number of colors to return.
        merge_radius: neighborhood radius used for merging (default 10).

    Returns:
        np.ndarray: (K, 3) float32 array with K <= max_colors (or
        ``raw_palette`` itself when it is already small enough).
    """
    if len(raw_palette) <= max_colors:
        return raw_palette

    palette = np.asarray(raw_palette)
    tree = KDTree(palette)
    used = np.zeros(len(palette), dtype=bool)
    merged_palette = []

    for i, color in enumerate(palette):
        if len(merged_palette) >= max_colors:
            break
        if used[i]:
            continue
        neighbors = tree.query_radius([color], r=merge_radius)[0]
        merged_palette.append(palette[neighbors].mean(axis=0))
        used[neighbors] = True

    return np.array(merged_palette, dtype=np.float32)

# ---------------------- 4. Color Mapping Strategy ----------------------
def transfer_colors(source_lab, reference_lab, source_palette, reference_palette, mask):
    """Recolor every pixel of ``source_lab`` with its nearest reference-palette color.

    Pixels are processed in two groups: foreground (``mask != 0``) and
    background (``mask == 0``). Both groups currently use the same reference
    palette and the same nearest-neighbour lookup.

    NOTE(review): split foreground/background correspondence and
    chromatic-aberration control are not implemented here.
    ``reference_lab`` and ``source_palette`` are unused and kept only for
    interface compatibility.

    Args:
        source_lab: HxWxC image; modified in place.
        reference_lab: unused.
        source_palette: unused.
        reference_palette: (K, C) array of target colors.
        mask: HxW array. Any non-zero value marks foreground, so 0/1 and
            0/255 masks are both handled and every pixel lands in one group.

    Returns:
        The same ``source_lab`` array, after recoloring.
    """
    tree = KDTree(reference_palette)
    foreground = np.asarray(mask) != 0

    # Foreground pass, then background pass.
    for region in (foreground, ~foreground):
        region_indices = np.where(region)
        if len(region_indices[0]) > 0:
            transfer_foreground(source_lab, reference_lab, source_palette,
                                reference_palette, region_indices, tree)

    return source_lab

def transfer_foreground(source_lab, reference_lab, source_palette, reference_palette, indices, tree):
    """Replace the selected pixels with their nearest reference-palette color, in place.

    Despite the name, this is used for both foreground and background pixels.
    All selected pixels are looked up with a single batched ``tree.query``
    call instead of one call per pixel.

    Args:
        source_lab: HxWxC image; modified in place.
        reference_lab: unused; kept for interface compatibility.
        source_palette: unused; kept for interface compatibility.
        reference_palette: (K, C) palette that ``tree`` was built over.
        indices: ``(rows, cols)`` pair as produced by ``np.where``.
        tree: nearest-neighbour index exposing ``query(X, k=1) -> (dist, idx)``
            (e.g. sklearn ``KDTree``).
    """
    h, w = source_lab.shape[:2]
    rows = np.asarray(indices[0])
    cols = np.asarray(indices[1])

    # Ignore coordinates outside the image (e.g. a mask of a different size).
    inside = (rows < h) & (cols < w)
    rows, cols = rows[inside], cols[inside]
    if rows.size == 0:
        return

    _, nearest = tree.query(source_lab[rows, cols].astype(np.float64), k=1)
    nearest = np.asarray(nearest).reshape(-1)

    # Guard against a tree whose indices do not fit reference_palette.
    valid = (nearest >= 0) & (nearest < len(reference_palette))
    rows, cols, nearest = rows[valid], cols[valid], nearest[valid]

    new_colors = np.asarray(reference_palette)[nearest]
    if np.issubdtype(source_lab.dtype, np.integer):
        # Round (not truncate) and clamp float palette values before writing
        # them into an integer image.
        limits = np.iinfo(source_lab.dtype)
        new_colors = np.clip(np.rint(new_colors), limits.min, limits.max)
    source_lab[rows, cols] = new_colors

# ---------------------- 5. Internal and External Consistency ----------------------
def maintain_consistency(source_lab, reference_lab, source_palette, reference_palette):
    """Remap pixels whose value exactly equals a source-palette color.

    How the mapping is built:

    * ``source_palette[k] -> reference_palette[k]``, pairing entries by
      position with ``zip``, so the shorter palette bounds its size.
    * When a source color repeats, its target becomes
      ``0.5 * (previous target + new target)``.

    Every pixel whose value exactly matches a key is replaced by the mapped
    color. Matching always uses the pixel's value *before* any replacement,
    so a remapped pixel is never remapped a second time.

    NOTE(review): the two palettes come from different images with unrelated
    sizes and ordering, so pairing them by index has no guaranteed meaning.
    Confirm the intended correspondence.

    Args:
        source_lab: HxWxC image; modified in place.
        reference_lab: unused; kept for interface compatibility.
        source_palette: (N, C) colors to look for.
        reference_palette: (M, C) replacement colors.

    Returns:
        The same ``source_lab`` array.
    """
    mapping = {}
    for src, ref in zip(source_palette, reference_palette):
        key = tuple(src)
        mapping[key] = ref if key not in mapping else 0.5 * (mapping[key] + ref)

    # Vectorized replacement of the per-pixel loop. Matches are computed against
    # an untouched copy so results equal a single per-pixel pass.
    original = source_lab.copy()
    for key, target in mapping.items():
        hits = np.all(original == np.asarray(key), axis=-1)
        source_lab[hits] = target
    return source_lab

# ---------------------- 6. Lighting Optimization ----------------------
def optimize_lighting(transferred_lab, reference_lab, alpha=0.3):
    """Move the lightness of ``transferred_lab`` toward the reference's global statistics.

    Steps:

    1. Standardize the L channel of ``transferred_lab`` to the mean and
       standard deviation of the reference L channel.
    2. Blend the result with the original L using weight ``alpha``.
    3. Pass the image through ``enhance_global_lighting``.

    The a/b channels are left untouched.

    Args:
        transferred_lab: HxWx3 Lab image (assumed uint8 — TODO confirm); not modified.
        reference_lab: Lab image of any size; only its L statistics are used.
        alpha: weight of the statistics-matched L (0 keeps the input L as is).

    Returns:
        The Lab image returned by ``enhance_global_lighting``.
    """
    l_src = transferred_lab[..., 0].astype(np.float32)
    l_ref = reference_lab[..., 0].astype(np.float32)

    # Match global mean/std rather than blending a resized reference L channel
    # pixel-by-pixel. The two images show different scenes, so a per-pixel
    # blend would superimpose a faint ghost of the reference image.
    src_mean, src_std = float(l_src.mean()), float(l_src.std())
    ref_mean, ref_std = float(l_ref.mean()), float(l_ref.std())
    gain = ref_std / src_std if src_std > 1e-6 else 1.0
    l_matched = (l_src - src_mean) * gain + ref_mean

    l_optimized = (1.0 - alpha) * l_src + alpha * l_matched
    l_optimized = np.clip(np.rint(l_optimized), 0, 255)

    # Copy so the caller's array is untouched and the a/b dtype is preserved.
    optimized_lab = transferred_lab.copy()
    optimized_lab[..., 0] = l_optimized

    # Global histogram equalization of L (always applied).
    return enhance_global_lighting(optimized_lab)

def enhance_global_lighting(image_lab):
    """Histogram-equalize the lightness plane of a 3-channel Lab image.

    ``cv2.equalizeHist`` only accepts 8-bit single-channel input, so L is cast
    to uint8 first. The two chroma planes are passed through untouched.

    Returns:
        The Lab image re-assembled by ``cv2.merge``.
    """
    lightness, chroma_a, chroma_b = cv2.split(image_lab)
    equalized = cv2.equalizeHist(lightness.astype(np.uint8))
    return cv2.merge([equalized, chroma_a, chroma_b])

# ---------------------- 7. Main Execution ----------------------
def main(source_path, reference_path, output_path="output.jpg"):
    """Run the palette-based color-transfer pipeline and write the result.

    Pipeline:

    1. Read both images and convert them to Lab.
    2. Build per-image palettes.
    3. Segment the source with Mask R-CNN.
    4. Transfer colors.
    5. Apply consistency remapping.
    6. Optimize lighting.
    7. Write the BGR result to ``output_path``.

    Args:
        source_path: path of the image to recolor.
        reference_path: path of the image whose colors are transferred.
        output_path: where the result is written (default ``"output.jpg"``).

    Returns:
        np.ndarray: the final BGR image that was written.

    Raises:
        FileNotFoundError: if either input image cannot be read.
        OSError: if OpenCV reports that the output could not be written.
    """
    # Validate inputs before the (slow) model load so bad paths fail fast.
    source_image = cv2.imread(str(source_path))
    reference_image = cv2.imread(str(reference_path))

    if source_image is None or reference_image is None:
        raise FileNotFoundError("Source or Reference image not found!")

    # Load Mask R-CNN model
    model = load_mask_rcnn()

    # Convert images to Lab color space
    source_lab = cv2.cvtColor(source_image, cv2.COLOR_BGR2LAB)
    reference_lab = cv2.cvtColor(reference_image, cv2.COLOR_BGR2LAB)

    # Generate palettes
    source_palette = generate_palette(source_lab)
    reference_palette = generate_palette(reference_lab)

    # Foreground-background segmentation using Mask R-CNN
    source_mask = segment_foreground_background(source_image, model)

    # Color transfer followed by consistency remapping
    transferred_lab = transfer_colors(source_lab, reference_lab, source_palette, reference_palette, source_mask)
    transferred_lab = maintain_consistency(transferred_lab, reference_lab, source_palette, reference_palette)

    # Lighting optimization
    final_result = optimize_lighting(transferred_lab, reference_lab)

    # Convert back to BGR and save output; imwrite signals failure via its return value.
    final_bgr = cv2.cvtColor(final_result.astype(np.uint8), cv2.COLOR_LAB2BGR)
    if not cv2.imwrite(str(output_path), final_bgr):
        raise OSError(f"Could not write output image to {str(output_path)!r}")

    print(f"Color transfer completed successfully. Check '{output_path}'.")
    return final_bgr


In [3]:

# Run the pipeline. The image folder is configurable through the
# COLOR_TRANSFER_IMAGE_DIR environment variable (default: ./images, relative to
# the working directory) instead of a hard-coded absolute, user-specific path.
if __name__ == "__main__":
    IMAGE_DIR = Path(os.environ.get("COLOR_TRANSFER_IMAGE_DIR", "images"))
    source_path = IMAGE_DIR / "input.jpg"  # Input source image
    reference_path = IMAGE_DIR / "reference.jpeg"  # Input reference image
    main(str(source_path), str(reference_path))


Color transfer completed successfully. Check 'output.jpg'.
