In [1]:
"""
nn_style_transfer_with_mask.py

A Gatys-style neural style transfer script extended to mimic the functionality described in:
"Style-Transfer via Texture-Synthesis" (Elad & Milanfar, 2016) — multi-scale, palette matching,
and segmentation mask W to preserve content regions.

Usage:
    python nn_style_transfer_with_mask.py 
        --content content.jpg 
        --style style.jpg 
        --mask face_mask.png   # optional (grayscale, same size as content). White=preserve
        --out result.jpg

Dependencies:
    pip install torch torchvision pillow numpy scikit-image
"""
#Colab for GPU: https://colab.research.google.com/drive/1bHCWFXnwWDOGI3knMbplxyGr9Dxocax3?usp=sharing

import argparse
import copy
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import torchvision.models as models
from skimage.exposure import match_histograms
import os
 
 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
#utility functions
def load_image(path, target_size=None):
    img = Image.open(path).convert('RGB')
    if target_size is not None:
        img = img.resize(target_size, Image.LANCZOS) # خوارزمية إعادة التحجيم (resampling) تعطي نتائج ناعمة ودقيقة
    return img

#تحويل من صيغة بايثون للثور لتنسور
def pil_to_tensor(img):
    transform = T.Compose([
        T.ToTensor(),
        T.Lambda(lambda x: x.mul(255))
    ])
    return transform(img).unsqueeze(0).to(device)

def tensor_to_pil(tensor):
    t = tensor.detach().cpu().squeeze(0).clamp(0,255).div(255)
    transform = T.ToPILImage()
    return transform(t)


# نُسطّح (flatten) الأبعاد المكانية  
# [C,H×W]
# نحسب مصفوفة غرام بضرب المصفوفة في محورها المنقول:
# G=F⋅FT
# النتيجة: مصفوفة بُعدها 
# [C,C]، كل عنصر فيها يُظهر مدى ارتباط قناة معينة بأخرى عبر كل موضع في الصورة.
def gram_matrix(features):
    (b, ch, h, w) = features.size()
    features = features.view(b, ch, h*w)
    G = torch.bmm(features, features.transpose(1,2))
    return G / (ch * h * w)

def total_variation_loss(img):
    # كلما كانت الصورة أكثر سلاسة (مثل لوحة زيتية)، قلّت هذه القيمة.
    # Horizontal differences (along width)
    x_diff = img[..., :, 1:] - img[..., :, :-1]
    # Vertical differences (along height)
    y_diff = img[..., 1:, :] - img[..., :-1, :]  
    return (x_diff.abs().mean() + y_diff.abs().mean())

# Helper: resize tensor image (1,3,H,W) to PIL-size (H,W) keeping [0..255]
def tensor_resize(tensor, size):
    # size is (H, W) (note PIL uses width,height but we pass height,width)
    t = tensor.detach().cpu().clamp(0,255).squeeze(0) / 255.0
    transform = T.Compose([
        T.ToPILImage(),
        T.Resize((size[0], size[1]), Image.LANCZOS),
        T.ToTensor(),
        T.Lambda(lambda x: x.mul(255))
    ])
    return transform(t).unsqueeze(0).to(device)



In [3]:
class VGGFeatures(nn.Module):
    def __init__(self, layers):
        super().__init__()
        vgg = models.vgg19(pretrained=True).features.to(device).eval()
        self.selected_layers = layers
        self.layers = vgg
        # freeze params
        for p in self.layers.parameters():
            p.requires_grad = False     #لا تحديث أوزان VGG

    def forward(self, x):
        results = {}
        for i, l in enumerate(self.layers):
            x = l(x)
            name = f"l{i}"
            if name in self.selected_layers:
                results[name] = x
        return results


In [4]:
# Gatys Neural Style Transfer (coarse-to-fine)
def run_style_transfer(content_img_pil, style_img_pil,
                       out_size=(400,400), num_steps=300, style_weight=1e6,
                       content_weight=1e0, tv_weight=1e-6, lr=0.02,
                       pyramid_scales=[(100,100),(200,200),(400,400)]):
    """
    Performs coarse-to-fine style transfer.
    - mask_pil: grayscale PIL image same size as content. White=preserve content (W high), black=free to stylize.
    """
    # Prepare VGG extractor layers (common choices)
    # We'll use conv layers indices as names: l0..lN (matching VGG19 features)
    #طبقاته المبكرة تلتقط تفاصيل بسيطة (حواف، ألوان، نقوش)
    #طبقاته العميقة تلتقط مفاهيم معقدة (كائنات، هياكل عالية المستوى).
    style_layers = ['l1','l6','l11','l20']   # shallow->deep conv layers for style Gram
    content_layer = 'l21'                   # deeper layer for content
    all_layers = style_layers + [content_layer]

    extractor = VGGFeatures(all_layers).to(device)

    # convert style image to tensor and precompute style grams per scale
    style_pil_original = style_img_pil.copy()
    content_pil_original = content_img_pil.copy()
   
    # initial image X: we will start with content + noise as in the paper
    X = pil_to_tensor(content_pil_original).clone()
    noise = torch.randn_like(X) * 50.0  # sigma=50 as in paper initialization
    #requires_grad_(True) يجعل هذا الـ tensor قابلًا للتفاضل — أي أن PyTorch سيحسب التدرجات بالنسبة له أثناء الـ backward()
    X = (X + noise).clamp(0,255).detach().requires_grad_(True).to(device)

    # iterate scales from coarse->fine
    scales = pyramid_scales
    for (w,h) in scales:
        print(f"\n--- Scale {(w,h)} ---")
        # resize images
        Cp = content_pil_original.resize((w,h), Image.LANCZOS)
        Sp = style_pil_original.resize((w,h), Image.LANCZOS)
        Sp_tensor = pil_to_tensor(Sp)

        # palette matching: match content to style palette before optimization at each scale
        # we perform histogram matching using skimage (works on numpy)
        content_np = np.array(Cp).astype(np.uint8)
        style_np = np.array(Sp).astype(np.uint8)
        # تُعدّل توزيع ألوان صورة المحتوى لتشبه توزيع ألوان صورة النمط
        matched = match_histograms(content_np, style_np, channel_axis=-1).astype(np.uint8)
        Cp_matched_pil = Image.fromarray(matched)
        Cp_tensor_matched = pil_to_tensor(Cp_matched_pil)

        # compute style grams
        with torch.no_grad():
            sp_feats = extractor(Sp_tensor/255.0)
            style_grams = {k: gram_matrix(v) for k,v in sp_feats.items() if k in style_layers}

        # resize current X to scale
        X = tensor_resize(X, (h,w))  # helper below
        X = X.detach().requires_grad_(True)

        # optimizer
        optimizer = optim.LBFGS([X], lr=lr)  # LBFGS is classic for style transfer
        run = [0]
        # Optimize the generated image X for the current scale using LBFGS.
        # We use a closure because LBFGS requires a re-evaluatable function for line search.
        while run[0] < num_steps:
            def closure():
                # Zero gradients from the previous step
                optimizer.zero_grad()
                
                # Normalize input to [0, 1] as expected by the pretrained VGG model
                X_norm = X / 255.0
                feats = extractor(X_norm)  # Extract feature maps from selected VGG layers
                
                # --- Content Loss ---
                # Get features of the current generated image at the content layer
                content_feat = feats[content_layer]
                # Get target content features from the (color-matched) content image (detached to avoid gradients)
                target_content = extractor(Cp_tensor_matched / 255.0)[content_layer].detach()
                # Compute Mean Squared Error between generated and target content features
                c_loss = nn.MSELoss()(content_feat, target_content)

                # --- Style Loss ---
                # Accumulate style loss across all selected style layers using Gram matrices
                s_loss = 0.0
                for l in style_layers:
                    G = gram_matrix(feats[l])        # Gram matrix of current generated image
                    A = style_grams[l]               # Precomputed Gram matrix of the style image
                    s_loss = s_loss + nn.MSELoss()(G, A.expand_as(G))  # Match style representations

                # --- Total Variation (TV) Loss ---
                # Encourage spatial smoothness in the output image (reduce high-frequency noise)
                tv = total_variation_loss(X)

                # --- Total Loss ---
                # Weighted combination of content, style, and TV losses
                loss = content_weight * c_loss + style_weight * s_loss + tv_weight * tv
                
                # Backpropagate gradients through the generated image X
                loss.backward()
                
                # Log losses every 50 iterations for monitoring
                if run[0] % 50 == 0:
                    print(f"iter {run[0]} loss: content {c_loss.item():.4e} style {s_loss.item():.4e} tv {tv.item():.4e}")
                
                # Increment the step counter (using list to allow mutation inside closure)
                run[0] += 1
                return loss  # LBFGS needs the loss value returned by closure
            
            # Perform one optimization step (LBFGS may call closure multiple times internally)
            optimizer.step(closure)
            # نمرّر X عبر VGG → نأخذ ميزاته في طبقة عميقة (مثل l21).
            # نمرّر صورة المحتوى عبر VGG → نأخذ ميزاته في نفس الطبقة.
            # نحسب MSE بين الميزتين:
            # هذا يقيس: "هل الهيكل العام لـ X مشابه للمحتوى؟"
            # ب. خسارة النمط:
            # نحسب مصفوفة غرام لـ X من طبقات متعددة.
            # نحسب مصفوفة غرام لصورة النمط من نفس الطبقات.
            # نحسب MSE بين كل زوج

        # --- Post-optimization color refinement ---
        # After optimizing at this scale, re-match the color palette of the output to the style image
        # to maintain consistent colors before proceeding to the next (finer) scale.
        X_img = tensor_to_pil(X)  # Convert tensor back to PIL image
        X_np = np.array(X_img).astype(np.uint8)
        Sp_np = np.array(Sp)      # Style image at current scale as NumPy array
        # Apply histogram matching to align color distributions
        X_matched = match_histograms(X_np, Sp_np, channel_axis=-1).astype(np.uint8)
        # Convert back to tensor and re-enable gradients for the next scale optimization
        X = pil_to_tensor(Image.fromarray(X_matched)).detach().requires_grad_(True).to(device)
        
    # final result (resize to requested out_size)
    final = tensor_resize(X, (out_size[1], out_size[0]))
    return tensor_to_pil(final)

In [7]:
def main():
    
    C = load_image("../Data/content/wave_home.jpg")
    S = load_image("../Data/style/waves.jpg")
    output='model_result.jpg'

    res = run_style_transfer(C, S, num_steps=150 )

    res.save(output)
    print("Saved:", output)

if __name__ == '__main__':
    main()





--- Scale (100, 100) ---
iter 0 loss: content 2.5637e+00 style 3.4549e-05 tv 3.3182e+01
iter 50 loss: content 2.5553e+00 style 1.8787e-05 tv 3.3614e+01
iter 100 loss: content 3.1608e+00 style 1.1014e-05 tv 3.5820e+01
iter 150 loss: content 3.4340e+00 style 8.1848e-06 tv 3.8047e+01

--- Scale (200, 200) ---
iter 0 loss: content 3.1148e+00 style 2.3348e-05 tv 2.1710e+01
iter 50 loss: content 3.1148e+00 style 2.3348e-05 tv 2.1710e+01
iter 100 loss: content 3.1148e+00 style 2.3348e-05 tv 2.1710e+01

--- Scale (400, 400) ---
iter 0 loss: content 2.1720e+00 style 1.5054e-05 tv 1.2334e+01
iter 50 loss: content 2.1720e+00 style 1.5054e-05 tv 1.2334e+01
iter 100 loss: content 2.1720e+00 style 1.5054e-05 tv 1.2334e+01
Saved: model_result.jpg
