In [2]:
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
import cv2

In [3]:
def load_img (filename, norm=True,):
    img = np.array(Image.open(filename).convert("RGB"))
    if norm:   
        img = img / 255.
        img = img.astype(np.float32)
    return img

In [4]:
# -------------------------
# Candidate Color Grading Options
# -------------------------

color_grades = {
    "cinematic_warm": {
        "wb": np.array([1.10, 1.00, 0.90]),
        "ccm": np.array([
            [ 1.05, -0.02, -0.03],
            [-0.01,  1.08, -0.07],
            [ 0.00, -0.05,  1.05],
        ])
    },
    "cool_high_contrast": {
        "wb": np.array([0.95, 1.00, 1.10]),
        "ccm": np.array([
            [ 1.10, -0.05, -0.05],
            [-0.05,  1.20, -0.05],
            [-0.05, -0.05,  1.10],
        ])
    },
    "vibrant_boost": {
        "wb": np.array([1.00, 1.05, 1.00]),
        "ccm": np.array([
            [ 1.20, -0.10, -0.10],
            [-0.10,  1.20, -0.10],
            [-0.10, -0.10,  1.20],
        ])
    },
    "neutral_realistic": {
        "wb": np.array([1.00, 1.00, 1.00]),
        "ccm": np.eye(3)
    },
    "vintage_film": {
        "wb": np.array([1.05, 1.00, 0.95]),
        "ccm": np.array([
            [ 0.95,  0.05,  0.00],
            [ 0.00,  1.00,  0.00],
            [ 0.00, -0.05,  1.05],
        ])
    },
     "kodak_vision3_arri_alexa": {
        "wb": np.array([1.05, 1.00, 0.95]),
        "ccm": np.array([
            [1.10, -0.02, -0.08],
            [-0.01, 1.05, -0.04],
            [0.00, -0.05, 1.08]
        ])
    },
    "filmic_natural_sony_venice_red": {
        "wb": np.array([1.00, 1.00, 1.00]),
        "ccm": np.array([
            [1.08, -0.05, -0.03],
            [-0.03, 1.05, -0.02],
            [-0.02, -0.05, 1.10]
        ])
    },
    "cinematic_teal_orange": {
        "wb": np.array([1.10, 0.95, 0.90]),
        "ccm": np.array([
            [1.15, -0.10, -0.05],
            [-0.10, 1.05, -0.10],
            [-0.05, -0.05, 1.20]
        ])
    },
    "fujifilm_classic_chrome": {
        "wb": np.array([1.00, 1.02, 0.98]),
        "ccm": np.array([
            [1.05, -0.02, -0.03],
            [-0.02, 1.03, -0.01],
            [0.00, -0.01, 1.02]
        ])
    },
    "cool_balanced_modern": {
        "wb": np.array([0.99, 1.00, 1.05]),
        "ccm": np.array([
            [1.10, -0.04, -0.04],  
            [-0.03, 1.08, -0.02],  
            [-0.02, -0.04, 1.12]  
        ])
    },
    "photorealistic_balanced": {
    "wb": np.array([1.03, 1.00, 0.97]),
    "ccm": np.array([
        [1.08, -0.03, -0.05],
        [-0.02, 1.05, -0.03],
        [-0.01, -0.04, 1.10]
    ])
},
"realistic_nature_boost": {
    "wb": np.array([1.02, 1.00, 0.98]),
    "ccm": np.array([
        [1.06, -0.02, -0.04],
        [-0.02, 1.05, -0.03],
        [-0.01, -0.02, 1.08]
    ])
},
"urban_neutral_film": {
    "wb": np.array([1.00, 1.00, 1.00]),
    "ccm": np.array([
        [1.05, -0.01, -0.04],
        [-0.01, 1.06, -0.03],
        [-0.02, -0.02, 1.07]
    ])
},
"subtle_filmic_cool": {
    "wb": np.array([1.00, 0.98, 1.02]),
    "ccm": np.array([
        [1.04, -0.02, -0.02],
        [-0.01, 1.03, -0.01],
        [-0.01, -0.02, 1.05]
    ])
}
    
}

In [5]:
# -------------------------
# Helper Functions
# -------------------------

def white_balance(image, wb):
    # Per-channel multiplication for white balance
    wb_img = image * wb.reshape((1, 1, 3))
    return np.clip(wb_img, 0.0, 1.0)

def color_correction(image, ccm):
    # Reshape image to (H*W, 3), apply CCM, then reshape back
    h, w, _ = image.shape
    corrected = np.dot(image.reshape(-1, 3), ccm.T)
    return np.clip(corrected.reshape(h, w, 3), 0.0, 1.0)

def adjust_contrast_saturation(image, contrast=1.2, saturation=1.2):
    # Convert to HSV, adjust saturation and value (contrast), then convert back
    image_8bit = np.uint8(np.clip(image * 255, 0, 255))
    hsv = cv2.cvtColor(image_8bit, cv2.COLOR_RGB2HSV).astype(np.float32)
    hsv[..., 1] *= saturation  # Saturation channel
    hsv[..., 2] = np.clip(hsv[..., 2] * contrast, 0, 255)  # Value channel
    hsv = np.clip(hsv, 0, 255).astype(np.uint8)
    result = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    return result.astype(np.float32) / 255.0

def shadow_correction(image):
    # Convert to LAB and apply CLAHE on the L channel for shadow recovery
    image_8bit = np.uint8(np.clip(image * 255, 0, 255))
    lab = cv2.cvtColor(image_8bit, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    l_eq = clahe.apply(l)
    lab_eq = cv2.merge((l_eq, a, b))
    corrected = cv2.cvtColor(lab_eq, cv2.COLOR_LAB2RGB)
    return corrected.astype(np.float32) / 255.0

def lens_blur(image, kernel_size=3):
    # Apply a slight Gaussian blur to mimic lens imperfections.
    # kernel_size should be an odd number.
    blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigmaX=0.5)
    return blurred

In [6]:
import torch
from torch import nn
import torch.nn.functional as F

class CNILUT(nn.Module):
    """
    Simple residual coordinate-based neural network for fitting 3D LUTs
    Official code: https://github.com/mv-lab/nilut
    """
    def __init__(self, in_features=3, hidden_features=256, hidden_layers=3, out_features=3, styles=3, res=True):
        super().__init__()
        
        self.styles = styles
        self.res = res
        self.net = []
        self.net.append(nn.Linear(in_features+styles, hidden_features))
        self.net.append(nn.ReLU())
        
        for _ in range(hidden_layers):
            self.net.append(nn.Linear(hidden_features, hidden_features))
            self.net.append(nn.Tanh())
        
        self.net.append(nn.Linear(hidden_features, out_features))
        if not self.res:
            self.net.append(torch.nn.Sigmoid())
        
        self.net = nn.Sequential(*self.net)
    
    def forward(self, intensity):
        output = self.net(intensity)
        if self.res:
            output = output + intensity[:,:self.styles]
            output = torch.clamp(output, 0.,1.)
        
        return output


In [7]:
PATH = "/kaggle/input/nilut-data/nilutx3style.pt"
lut_model  = CNILUT(in_features=3, out_features=3, hidden_features=256, hidden_layers=2, styles=3, res=True)
lut_model.load_state_dict(torch.load(PATH, weights_only=True)['model'])
lut_model.cuda()
lut_model.eval()

CNILUT(
  (net): Sequential(
    (0): Linear(in_features=6, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): Tanh()
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): Tanh()
    (6): Linear(in_features=256, out_features=3, bias=True)
  )
)

In [8]:
def run_nilut (image, style=1):
    image_tensor = torch.from_numpy(image)
    image_tensor = image_tensor.reshape((image_tensor.shape[0]*image_tensor.shape[1],3)) # [hw, 3]

    style_vector = np.zeros(3).astype(np.float32)
    style_vector[style] = 1.
    style_vector    = torch.from_numpy(style_vector)
    style_vector_re = style_vector.repeat(image_tensor.shape[0]).view(image_tensor.shape[0],3)

    img = torch.cat([image_tensor,style_vector_re], dim=-1)

    with torch.no_grad():
        out = lut_model(img.cuda())

    np_out  = out.cpu().view(image.shape[0],image.shape[1],3).detach().numpy().astype(np.float32)
    return np_out

In [9]:
# -------------------------
# Enhanced Realism + Plan B Pipeline
# -------------------------

def add_vignette(image, strength=0.1):
    h, w = image.shape[:2]
    Y, X = np.ogrid[:h, :w]
    center = (h/2, w/2)
    distance = np.sqrt((X - center[1])**2 + (Y - center[0])**2)
    max_dist = np.sqrt(center[0]**2 + center[1]**2)
    mask = 1 - strength * (distance / max_dist)
    mask = np.clip(mask, 0.8, 1)
    return (image * mask[..., np.newaxis])

def sharpen(image, amount=0.3):
    blurred = cv2.GaussianBlur(image, (0, 0), 3)
    return np.clip(image + amount * (image - blurred), 0, 1)

def slight_desaturation(image, factor=0.95):
    image_8bit = np.uint8(np.clip(image * 255, 0, 255))
    hsv = cv2.cvtColor(image_8bit, cv2.COLOR_RGB2HSV).astype(np.float32)
    hsv[...,1] *= factor
    hsv = np.clip(hsv, 0, 255).astype(np.uint8)
    image = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    return image.astype(np.float32) / 255.0

def enhanced_pipeline(image, wb, ccm, 
                      gamma_val=2.2, 
                      contrast=1.2, 
                      saturation=1.2,
                      blur_kernel=3, 
                      lut=None,
                      apply_vignette=True,
                      apply_sharpen=True,
                      apply_desaturation=True):
    """
    Full enhanced ISP pipeline for photorealism (merged with Plan B).
    """

    # 0. Optional: Add tiny random WB jitter (for diversity)
    wb = wb + np.random.normal(0, 0.005, wb.shape)
    wb = np.clip(wb, 0.95, 1.05)

    # 1. White Balance
    image = white_balance(image, wb)
    
    # 2. Color Correction with CCM
    image = color_correction(image, ccm)
    
    # 3. Adaptive Gamma Correction
    avg_brightness = image.mean()
    gamma_val = 1.8 if avg_brightness > 0.5 else 2.2
    image = np.clip(np.power(image, 1.0 / gamma_val), 0, 1)
    
    # 4. Contrast and Saturation Adjustment
    image = adjust_contrast_saturation(image, contrast=contrast, saturation=saturation)
    
    # 5. Shadow Correction (via LAB CLAHE)
    image = shadow_correction(image)
    
    # 6. Optional Lens Blur (simulate optical imperfections)
    image = lens_blur(image, kernel_size=blur_kernel)
    
    # 7. Neural LUT-based Color Mapping
    image = run_nilut(image, style=1)

    # 8. Vignette (optional)
    if apply_vignette:
        image = add_vignette(image, strength=0.1)

    # 9. Sharpening (optional)
    if apply_sharpen:
        image = sharpen(image, amount=0.1)   # lighter sharpen for final polish

    # 10. Slight Desaturation (optional for realism)
    if apply_desaturation:
        image = slight_desaturation(image, factor=0.95)

    return image


In [10]:
import os
import numpy as np
import cv2
from tqdm import tqdm
from glob import glob

def process_selected_folders(root_folder, selected_folders, output_folder,
                             wb, ccm, gamma_val=2.2, contrast=1.2, saturation=1.2,
                             blur_kernel=3, lut=None):

    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    image_files = []

    # Go through only selected folders and find images recursively
    for folder in selected_folders:
        folder_path = os.path.join(root_folder, folder)
        found = glob(os.path.join(folder_path, "**", "*.*"), recursive=True)
        image_files.extend(found)

    # Filter for image extensions
    image_files = [f for f in image_files if f.lower().endswith(('jpg', 'png', 'jpeg'))]

    print("📁 Selected folders:", selected_folders)
    print("🔍 Total image files found:", len(image_files))

    for image_path in tqdm(image_files):
        try:
            image = load_img(image_path)

            output_img = enhanced_pipeline(image, wb, ccm,
                                           gamma_val=gamma_val,
                                           contrast=contrast,
                                           saturation=saturation,
                                           blur_kernel=blur_kernel,
                                           lut=lut)

            output_display = np.uint8(np.clip(output_img * 255, 0, 255))
            filename = os.path.basename(image_path)
            output_path = os.path.join(output_folder, filename)
            cv2.imwrite(output_path, cv2.cvtColor(output_display, cv2.COLOR_RGB2BGR))

        except Exception as e:
            print(f"❌ Failed to process {image_path}: {e}")

    print(f"\n✅ Done! Processed {len(image_files)} images and saved to: {output_folder}")


Nilut

In [11]:
grade = color_grades["cool_balanced_modern"]
wb = grade["wb"]
ccm = grade["ccm"]
selected_folders = [f"{str(i).zfill(2)}_images" for i in range(10, 11)]  # 01_images to 05_images


root_folder = "/kaggle/input/playing-for-data"
output_folder = "/kaggle/working/isp_nilut_output_10_images"

process_selected_folders(root_folder, selected_folders, output_folder,
                         wb, ccm, gamma_val=1.8, contrast=1.1, saturation=1.15,
                         blur_kernel=1, lut=None)



📁 Selected folders: ['10_images']
🔍 Total image files found: 2466


100%|██████████| 2466/2466 [38:34<00:00,  1.07it/s]


✅ Done! Processed 2466 images and saved to: /kaggle/working/isp_nilut_output_10_images





In [12]:
!pip install clean-fid


Collecting clean-fid
  Downloading clean_fid-0.1.35-py3-none-any.whl.metadata (36 kB)
Downloading clean_fid-0.1.35-py3-none-any.whl (26 kB)
Installing collected packages: clean-fid
Successfully installed clean-fid-0.1.35


In [13]:
from cleanfid import fid


fdir1 = "/kaggle/input/photorealism-data/ours_Cityscape"   # Cityscapes (Ground Truth)
fdir2 = "/kaggle/input/playing-for-data/01_images/images"  # change this

# FID
score_fid = fid.compute_fid(fdir1, fdir2)
print(f"FID Score: {score_fid}")

print("cool_balanced_modern contrast= 1.1, saturation=1.15 and added vignette sharpen and slight jitter and final boost for desaturation")

compute FID between two folders




Found 19252 images in the folder /kaggle/input/photorealism-data/ours_Cityscape


FID ours_Cityscape : 100%|██████████| 602/602 [03:04<00:00,  3.27it/s]


Found 2500 images in the folder /kaggle/input/playing-for-data/01_images/images


FID images : 100%|██████████| 79/79 [01:48<00:00,  1.37s/it]


FID Score: 54.14864577416438
cool_balanced_modern contrast= 1.1, saturation=1.15 and added vignette sharpen and slight jitter and final boost for desaturation


In [14]:
from cleanfid import fid


score_kid = fid.compute_kid(fdir1, fdir2)
print(f"KID Score: {score_kid * 1000:.3f}")  #Multiply by 100 for proper reporting

compute KID between two folders
Found 19252 images in the folder /kaggle/input/photorealism-data/ours_Cityscape


KID ours_Cityscape : 100%|██████████| 602/602 [03:02<00:00,  3.30it/s]


Found 2500 images in the folder /kaggle/input/playing-for-data/01_images/images


KID images : 100%|██████████| 79/79 [01:45<00:00,  1.34s/it]


KID Score: 58.996
