In [None]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from deap import base, creator, tools
import random

# ------------------------------
# 1. Setup Pretrained Feature Extractor with PyTorch (GPU)
# ------------------------------
class VGG16FeatureExtractor(nn.Module):
    """Frozen, ImageNet-pretrained VGG16 convolutional trunk used as a feature encoder.

    All parameters of ``vgg16.features`` are frozen. The trunk output is
    adaptively average-pooled to 7x7 and flattened, so a batch of shape
    (N, 3, H, W) maps to (N, C * 49), where C is the trunk's channel count.
    """

    def __init__(self):
        super().__init__()
        # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
        # ``weights`` enum. IMAGENET1K_V1 is the checkpoint ``pretrained=True``
        # selected, so both branches load the same weights.
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            vgg16 = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:  # older torchvision without the multi-weight API
            vgg16 = models.vgg16(pretrained=True)
        self.features = vgg16.features
        for param in self.features.parameters():
            param.requires_grad = False
        # Fixes the spatial size so the flattened length does not depend on
        # the input resolution.
        self.pool = nn.AdaptiveAvgPool2d((7, 7))

    def forward(self, x):
        """Return flattened pooled trunk activations, one row per batch item."""
        x = self.features(x)
        x = self.pool(x)
        return torch.flatten(x, 1)

# Device: use CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Frozen feature network in inference mode (``eval()`` returns the module itself).
feature_extractor = VGG16FeatureExtractor().to(device).eval()

# Preprocessing for the feature extractor: uint8 H x W x C array -> PIL image
# -> 224x224 resize -> tensor -> per-channel ImageNet mean/std normalisation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

def extract_features_torch(img: np.ndarray) -> np.ndarray:
    """Return the flattened deep-feature vector of an RGB image.

    Args:
        img: H x W x 3 float array, RGB channel order, values in [0, 1].

    Returns:
        1-D numpy array of ``feature_extractor`` activations for the image.
    """
    # Round, don't truncate: ``v / 255 * 255`` in floating point can land just
    # below the integer ``v`` and a bare ``astype`` would drop one level.
    # Clipping guards against values marginally outside [0, 1].
    img_u8 = np.clip(np.rint(np.asarray(img) * 255.0), 0, 255).astype(np.uint8)
    input_tensor = transform(img_u8).unsqueeze(0).to(device)
    with torch.no_grad():
        feats = feature_extractor(input_tensor)
    return feats.cpu().numpy().flatten()

# ------------------------------
# 2. GPU-based Image Enhancement Operators
# ------------------------------
def enhance_image_gpu(image: np.ndarray, individual: list) -> np.ndarray:
    """Apply brightness, contrast and gamma adjustments on ``device``.

    Args:
        image: float image with values in [0, 1].
        individual: ``(brightness, contrast, gamma)``. Brightness is an offset
            on the 0-255 scale, contrast a multiplicative gain, and the gamma
            step computes ``x ** (1 / gamma)``.

    Returns:
        The adjusted image as a float32 numpy array; each step is clamped to [0, 1].
    """
    offset, gain, gamma = individual
    pixels = torch.tensor(image, dtype=torch.float32, device=device)
    pixels = (pixels + offset / 255.0).clamp(0.0, 1.0)
    pixels = (gain * pixels).clamp(0.0, 1.0)
    pixels = pixels.pow(1.0 / gamma).clamp(0.0, 1.0)
    return pixels.cpu().numpy()

# ------------------------------
# 3. Corrected Fitness Evaluation
# ------------------------------
def evaluate_individual(individual, original_img, orig_features):
    """Score one enhancement individual on two objectives.

    Args:
        individual: ``(brightness, contrast, gamma)`` for ``enhance_image_gpu``.
        original_img: BGR float image with values in [0, 1].
        orig_features: deep-feature vector of the original image (RGB order).

    Returns:
        ``(entropy, distance)``:
        - entropy: Shannon entropy in bits of the 256-bin grey histogram of
          the enhanced image.
        - distance: Euclidean distance between the enhanced and original deep
          features, plus a penalty when the mean brightness leaves [0.35, 0.7].
    """
    enhanced = enhance_image_gpu(original_img, individual)
    img_tensor = torch.as_tensor(enhanced, dtype=torch.float32, device=device)

    # Luma from BGR channel order (B, G, R weighted 0.114, 0.587, 0.299).
    # Clamp so float rounding of the weighted sum can never push a pixel just
    # above 1.0, where ``histc`` would silently ignore it.
    gray = (0.114 * img_tensor[:, :, 0]
            + 0.587 * img_tensor[:, :, 1]
            + 0.299 * img_tensor[:, :, 2]).clamp(0.0, 1.0)
    hist = torch.histc(gray, bins=256, min=0.0, max=1.0)
    prob = hist / torch.sum(hist)
    # Only non-empty bins contribute (0 * log 0 := 0). An epsilon inside the
    # log biases the estimate and makes a single-bin histogram score < 0.
    prob = prob[prob > 0]
    entropy = -torch.sum(prob * torch.log2(prob)).item()

    # Deep features. Taking channels 2, 1, 0 is exactly BGR -> RGB. Passing the
    # float image straight through avoids quantising to uint8 twice.
    enhanced_rgb = np.ascontiguousarray(enhanced[:, :, 2::-1])
    feat_enh = extract_features_torch(enhanced_rgb)
    feat_dist = np.linalg.norm(orig_features - feat_enh)

    # Brightness penalty: linear cost for mean intensity outside [low, up].
    mean_b = torch.mean(img_tensor).item()
    low, up, factor = 0.35, 0.7, 30.0
    penalty = factor * max(0.0, low - mean_b, mean_b - up)

    return entropy, feat_dist + penalty

# ------------------------------
# 4. GA Configuration
# ------------------------------
# Search space for (brightness offset, contrast gain, gamma) as (low, high) pairs.
PARAM_BOUNDS = [(-10, 60), (1.0, 2.0), (1.0, 2.0)]

# Objective weights: +1.0 maximises entropy, -1.0 minimises feature distance.
creator.create("FitnessMulti", base.Fitness, weights=(1.0, -1.0))
creator.create("Individual", list, fitness=creator.FitnessMulti)

toolbox = base.Toolbox()

# Register one uniform sampler per gene (attr_0, attr_1, ...). Collect them so
# initCycle draws each gene once, in PARAM_BOUNDS order.
_attr_samplers = []
for _idx, _bounds in enumerate(PARAM_BOUNDS):
    _alias = f"attr_{_idx}"
    toolbox.register(_alias, random.uniform, *_bounds)
    _attr_samplers.append(getattr(toolbox, _alias))

toolbox.register("individual", tools.initCycle, creator.Individual,
                 tuple(_attr_samplers), n=1)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("evaluate", evaluate_individual)

def bounded_cx(ind1, ind2):
    """Blend crossover (alpha=0.5), then clip every gene of both children into PARAM_BOUNDS.

    Both individuals are modified in place and returned as a pair.
    """
    tools.cxBlend(ind1, ind2, alpha=0.5)
    for child in (ind1, ind2):
        for pos, (lo, hi) in enumerate(PARAM_BOUNDS):
            child[pos] = np.clip(child[pos], lo, hi)
    return ind1, ind2

def bounded_mut(ind, indpb):
    """Per-gene Gaussian mutation, clipped to PARAM_BOUNDS.

    Each gene is mutated with probability ``indpb``. A mutated gene gets
    N(0, 0.1 * (high - low)) noise and is then clipped to its bounds.
    ``ind`` is modified in place and returned as a 1-tuple.
    """
    for pos, (lo, hi) in enumerate(PARAM_BOUNDS):
        if random.random() >= indpb:
            continue
        sigma = (hi - lo) * 0.1
        ind[pos] = np.clip(ind[pos] + random.gauss(0, sigma), lo, hi)
    return (ind,)

# Variation and selection operators consumed by run_ga.
toolbox.register("select", tools.selNSGA2)
toolbox.register("mutate", bounded_mut, indpb=0.4)  # per-gene mutation probability
toolbox.register("mate", bounded_cx)

# ------------------------------
# 5. Local Search
# ------------------------------
def local_search(ind, original_img, orig_features, max_iter=8, step_frac=0.1,
                 evaluate=None, bounds=None):
    """Random-perturbation hill climb around ``ind``.

    ``max_iter`` neighbours are drawn. Each one perturbs every parameter by
    ``uniform(-w, w)`` with ``w = step_frac * (high - low)`` and clips it to
    its bounds. A neighbour replaces the incumbent when it has higher entropy,
    or equal entropy and lower distance.

    Args:
        ind: parameter list. If it carries a valid DEAP fitness, that fitness
            is reused instead of re-evaluating the starting point.
        original_img: image forwarded to ``evaluate``.
        orig_features: feature vector forwarded to ``evaluate``.
        max_iter: number of neighbours tried.
        step_frac: perturbation half-width as a fraction of each parameter's
            range. Scaling by range (as ``bounded_mut`` does for its sigma)
            keeps steps comparable across parameters. A fixed +/-0.5 is half
            the contrast/gamma range but under 1% of the brightness range.
        evaluate: fitness function; defaults to ``evaluate_individual``.
        bounds: ``[(low, high), ...]``; defaults to ``PARAM_BOUNDS``.

    Returns:
        The best parameters found, as a plain list (``ind`` is not modified).
    """
    if evaluate is None:
        evaluate = evaluate_individual
    if bounds is None:
        bounds = PARAM_BOUNDS

    best = list(ind)
    fitness = getattr(ind, "fitness", None)
    if fitness is not None and fitness.valid:
        best_fit = tuple(fitness.values)  # skip a redundant feature-network pass
    else:
        best_fit = evaluate(best, original_img, orig_features)

    for _ in range(max_iter):
        neigh = []
        for p, (lo, hi) in zip(best, bounds):
            half_width = step_frac * (hi - lo)
            neigh.append(np.clip(p + random.uniform(-half_width, half_width), lo, hi))
        fit_n = evaluate(neigh, original_img, orig_features)
        if fit_n[0] > best_fit[0] or (fit_n[0] == best_fit[0] and fit_n[1] < best_fit[1]):
            best, best_fit = neigh, fit_n
    return best

# ------------------------------
# 6. GA Routine Without Display
# ------------------------------
def run_ga(original_img, ngen=35, pop_size=50):
    """Search enhancement parameters for one image with NSGA-II plus local search.

    Args:
        original_img: BGR image with 0-255 values, as produced by ``cv2.imread``.
        ngen: number of generations.
        pop_size: population size.

    Returns:
        The image enhanced with the best individual (highest entropy, ties
        broken by lower distance), as a float array in [0, 1], BGR order.
    """
    orig_f = original_img.astype(np.float32) / 255.0
    # Convert the 8-bit source directly; going through orig_f * 255 and
    # truncating can lose one level per pixel.
    orig_rgb = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    feats = extract_features_torch(orig_rgb)

    pop = toolbox.population(n=pop_size)
    for ind in pop:
        ind.fitness.values = toolbox.evaluate(ind, orig_f, feats)
    # NSGA-II selection with k == len(pop) keeps everyone but assigns the
    # crowding distances that selTournamentDCD relies on.
    pop = toolbox.select(pop, len(pop))

    # selTournamentDCD draws tournaments in groups of four. DEAP raises when
    # k == len(pop) is not a multiple of four (e.g. the default pop_size=50),
    # so use the largest multiple of four that fits.
    n_offspring = len(pop) - len(pop) % 4

    for gen in range(ngen):
        if n_offspring:
            parents = tools.selTournamentDCD(pop, n_offspring)
        else:  # fewer than four individuals: no DCD tournament possible
            parents = list(pop)
        offspring = [toolbox.clone(ind) for ind in parents]

        # Crossover & mutation on the clones; changed individuals lose their fitness.
        for i1, i2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < 0.85:
                toolbox.mate(i1, i2)
                del i1.fitness.values, i2.fitness.values
        mut_rate = 0.3 if gen < ngen // 2 else 0.2
        for mutant in offspring:
            if random.random() < mut_rate:
                toolbox.mutate(mutant)
                del mutant.fitness.values

        for ind in offspring:
            if not ind.fitness.valid:
                ind.fitness.values = toolbox.evaluate(ind, orig_f, feats)
        pop = toolbox.select(pop + offspring, pop_size)

        # Local search on the top 20%; improved individuals are updated in place.
        improved = False
        for ind in tools.selBest(pop, max(1, int(0.2 * pop_size))):
            new_p = local_search(ind, orig_f, feats)
            if list(new_p) == list(ind):
                continue  # no neighbour beat the incumbent; skip re-evaluation
            new_fit = evaluate_individual(new_p, orig_f, feats)
            old_fit = ind.fitness.values
            if new_fit[0] > old_fit[0] or (new_fit[0] == old_fit[0] and new_fit[1] < old_fit[1]):
                ind[:] = new_p
                ind.fitness.values = new_fit
                improved = True
        if improved:
            # Recompute crowding distances, which are stale after in-place
            # fitness changes, before the next selTournamentDCD round.
            pop = toolbox.select(pop, len(pop))

    best = tools.selBest(pop, 1)[0]
    return enhance_image_gpu(orig_f, best)

# ------------------------------
# 7. Batch Processing
# ------------------------------
if __name__ == "__main__":
    # Batch-enhance every image under input_root, mirroring the folder tree
    # into output_root. Existing outputs are overwritten.
    input_root = r"D:\mit-5k-subset\d"
    output_root = r"D:\mit-5k-subset\d_enhanced"
    image_exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
    os.makedirs(output_root, exist_ok=True)

    for subdir, _, files in os.walk(input_root):
        rel_path = os.path.relpath(subdir, input_root)
        # normpath drops the "." component relpath yields for the root itself.
        output_subdir = os.path.normpath(os.path.join(output_root, rel_path))
        os.makedirs(output_subdir, exist_ok=True)

        for fname in files:
            if not fname.lower().endswith(image_exts):
                continue
            in_path = os.path.join(subdir, fname)
            img = cv2.imread(in_path)
            if img is None:
                print(f"Warning: unable to read {in_path}")
                continue
            print(f"Enhancing: {in_path}")
            enhanced = run_ga(img, ngen=5, pop_size=60)
            # Round rather than truncate so the output is not biased darker.
            out_img = np.clip(np.rint(enhanced * 255), 0, 255).astype(np.uint8)
            out_path = os.path.join(output_subdir, fname)
            # Write to a temp file with the same extension (cv2 picks the
            # encoder from it) and then rename. An interrupted run therefore
            # never leaves a truncated file under the final name, which a
            # resume pass would otherwise skip.
            root, ext = os.path.splitext(out_path)
            tmp_path = f"{root}.partial{ext}"
            if not cv2.imwrite(tmp_path, out_img):
                print(f"Warning: unable to write {out_path}")
                continue
            os.replace(tmp_path, out_path)
            print(f"Saved enhanced image: {out_path}")





Enhancing: D:\mit-5k-subset\d\a0034-LSYD4O2202.jpg


In [1]:
# Run this cell only if the first cell was stopped before completion: it skips images that already have an output file.
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from deap import base, creator, tools
import random

# ------------------------------
# 1. Setup Pretrained Feature Extractor with PyTorch (GPU)
# ------------------------------
class VGG16FeatureExtractor(nn.Module):
    """Frozen, ImageNet-pretrained VGG16 convolutional trunk used as a feature encoder.

    All parameters of ``vgg16.features`` are frozen. The trunk output is
    adaptively average-pooled to 7x7 and flattened, so a batch of shape
    (N, 3, H, W) maps to (N, C * 49), where C is the trunk's channel count.
    """

    def __init__(self):
        super().__init__()
        # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
        # ``weights`` enum. IMAGENET1K_V1 is the checkpoint ``pretrained=True``
        # selected, so both branches load the same weights.
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            vgg16 = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:  # older torchvision without the multi-weight API
            vgg16 = models.vgg16(pretrained=True)
        self.features = vgg16.features
        for param in self.features.parameters():
            param.requires_grad = False
        # Fixes the spatial size so the flattened length does not depend on
        # the input resolution.
        self.pool = nn.AdaptiveAvgPool2d((7, 7))

    def forward(self, x):
        """Return flattened pooled trunk activations, one row per batch item."""
        x = self.features(x)
        x = self.pool(x)
        return torch.flatten(x, 1)

# Device: use CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Frozen feature network in inference mode (``eval()`` returns the module itself).
feature_extractor = VGG16FeatureExtractor().to(device).eval()

# Preprocessing for the feature extractor: uint8 H x W x C array -> PIL image
# -> 224x224 resize -> tensor -> per-channel ImageNet mean/std normalisation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

def extract_features_torch(img: np.ndarray) -> np.ndarray:
    """Return the flattened deep-feature vector of an RGB image.

    Args:
        img: H x W x 3 float array, RGB channel order, values in [0, 1].

    Returns:
        1-D numpy array of ``feature_extractor`` activations for the image.
    """
    # Round, don't truncate: ``v / 255 * 255`` in floating point can land just
    # below the integer ``v`` and a bare ``astype`` would drop one level.
    # Clipping guards against values marginally outside [0, 1].
    img_u8 = np.clip(np.rint(np.asarray(img) * 255.0), 0, 255).astype(np.uint8)
    input_tensor = transform(img_u8).unsqueeze(0).to(device)
    with torch.no_grad():
        feats = feature_extractor(input_tensor)
    return feats.cpu().numpy().flatten()

# ------------------------------
# 2. GPU-based Image Enhancement Operators
# ------------------------------
def enhance_image_gpu(image: np.ndarray, individual: list) -> np.ndarray:
    """Apply brightness, contrast and gamma adjustments on ``device``.

    Args:
        image: float image with values in [0, 1].
        individual: ``(brightness, contrast, gamma)``. Brightness is an offset
            on the 0-255 scale, contrast a multiplicative gain, and the gamma
            step computes ``x ** (1 / gamma)``.

    Returns:
        The adjusted image as a float32 numpy array; each step is clamped to [0, 1].
    """
    offset, gain, gamma = individual
    pixels = torch.tensor(image, dtype=torch.float32, device=device)
    pixels = (pixels + offset / 255.0).clamp(0.0, 1.0)
    pixels = (gain * pixels).clamp(0.0, 1.0)
    pixels = pixels.pow(1.0 / gamma).clamp(0.0, 1.0)
    return pixels.cpu().numpy()

# ------------------------------
# 3. Corrected Fitness Evaluation
# ------------------------------
def evaluate_individual(individual, original_img, orig_features):
    """Score one enhancement individual on two objectives.

    Args:
        individual: ``(brightness, contrast, gamma)`` for ``enhance_image_gpu``.
        original_img: BGR float image with values in [0, 1].
        orig_features: deep-feature vector of the original image (RGB order).

    Returns:
        ``(entropy, distance)``:
        - entropy: Shannon entropy in bits of the 256-bin grey histogram of
          the enhanced image.
        - distance: Euclidean distance between the enhanced and original deep
          features, plus a penalty when the mean brightness leaves [0.35, 0.7].
    """
    enhanced = enhance_image_gpu(original_img, individual)
    img_tensor = torch.as_tensor(enhanced, dtype=torch.float32, device=device)

    # Luma from BGR channel order (B, G, R weighted 0.114, 0.587, 0.299).
    # Clamp so float rounding of the weighted sum can never push a pixel just
    # above 1.0, where ``histc`` would silently ignore it.
    gray = (0.114 * img_tensor[:, :, 0]
            + 0.587 * img_tensor[:, :, 1]
            + 0.299 * img_tensor[:, :, 2]).clamp(0.0, 1.0)
    hist = torch.histc(gray, bins=256, min=0.0, max=1.0)
    prob = hist / torch.sum(hist)
    # Only non-empty bins contribute (0 * log 0 := 0). An epsilon inside the
    # log biases the estimate and makes a single-bin histogram score < 0.
    prob = prob[prob > 0]
    entropy = -torch.sum(prob * torch.log2(prob)).item()

    # Deep features. Taking channels 2, 1, 0 is exactly BGR -> RGB. Passing the
    # float image straight through avoids quantising to uint8 twice.
    enhanced_rgb = np.ascontiguousarray(enhanced[:, :, 2::-1])
    feat_enh = extract_features_torch(enhanced_rgb)
    feat_dist = np.linalg.norm(orig_features - feat_enh)

    # Brightness penalty: linear cost for mean intensity outside [low, up].
    mean_b = torch.mean(img_tensor).item()
    low, up, factor = 0.35, 0.7, 30.0
    penalty = factor * max(0.0, low - mean_b, mean_b - up)

    return entropy, feat_dist + penalty

# ------------------------------
# 4. GA Configuration
# ------------------------------
# Search space for (brightness offset, contrast gain, gamma) as (low, high) pairs.
PARAM_BOUNDS = [(-10, 60), (1.0, 2.0), (1.0, 2.0)]

# Objective weights: +1.0 maximises entropy, -1.0 minimises feature distance.
creator.create("FitnessMulti", base.Fitness, weights=(1.0, -1.0))
creator.create("Individual", list, fitness=creator.FitnessMulti)

toolbox = base.Toolbox()

# Register one uniform sampler per gene (attr_0, attr_1, ...). Collect them so
# initCycle draws each gene once, in PARAM_BOUNDS order.
_attr_samplers = []
for _idx, _bounds in enumerate(PARAM_BOUNDS):
    _alias = f"attr_{_idx}"
    toolbox.register(_alias, random.uniform, *_bounds)
    _attr_samplers.append(getattr(toolbox, _alias))

toolbox.register("individual", tools.initCycle, creator.Individual,
                 tuple(_attr_samplers), n=1)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("evaluate", evaluate_individual)

def bounded_cx(ind1, ind2):
    """Blend crossover (alpha=0.5), then clip every gene of both children into PARAM_BOUNDS.

    Both individuals are modified in place and returned as a pair.
    """
    tools.cxBlend(ind1, ind2, alpha=0.5)
    for child in (ind1, ind2):
        for pos, (lo, hi) in enumerate(PARAM_BOUNDS):
            child[pos] = np.clip(child[pos], lo, hi)
    return ind1, ind2

def bounded_mut(ind, indpb):
    """Per-gene Gaussian mutation, clipped to PARAM_BOUNDS.

    Each gene is mutated with probability ``indpb``. A mutated gene gets
    N(0, 0.1 * (high - low)) noise and is then clipped to its bounds.
    ``ind`` is modified in place and returned as a 1-tuple.
    """
    for pos, (lo, hi) in enumerate(PARAM_BOUNDS):
        if random.random() >= indpb:
            continue
        sigma = (hi - lo) * 0.1
        ind[pos] = np.clip(ind[pos] + random.gauss(0, sigma), lo, hi)
    return (ind,)

# Variation and selection operators consumed by run_ga.
toolbox.register("select", tools.selNSGA2)
toolbox.register("mutate", bounded_mut, indpb=0.4)  # per-gene mutation probability
toolbox.register("mate", bounded_cx)

# ------------------------------
# 5. Local Search
# ------------------------------
def local_search(ind, original_img, orig_features, max_iter=8, step_frac=0.1,
                 evaluate=None, bounds=None):
    """Random-perturbation hill climb around ``ind``.

    ``max_iter`` neighbours are drawn. Each one perturbs every parameter by
    ``uniform(-w, w)`` with ``w = step_frac * (high - low)`` and clips it to
    its bounds. A neighbour replaces the incumbent when it has higher entropy,
    or equal entropy and lower distance.

    Args:
        ind: parameter list. If it carries a valid DEAP fitness, that fitness
            is reused instead of re-evaluating the starting point.
        original_img: image forwarded to ``evaluate``.
        orig_features: feature vector forwarded to ``evaluate``.
        max_iter: number of neighbours tried.
        step_frac: perturbation half-width as a fraction of each parameter's
            range. Scaling by range (as ``bounded_mut`` does for its sigma)
            keeps steps comparable across parameters. A fixed +/-0.5 is half
            the contrast/gamma range but under 1% of the brightness range.
        evaluate: fitness function; defaults to ``evaluate_individual``.
        bounds: ``[(low, high), ...]``; defaults to ``PARAM_BOUNDS``.

    Returns:
        The best parameters found, as a plain list (``ind`` is not modified).
    """
    if evaluate is None:
        evaluate = evaluate_individual
    if bounds is None:
        bounds = PARAM_BOUNDS

    best = list(ind)
    fitness = getattr(ind, "fitness", None)
    if fitness is not None and fitness.valid:
        best_fit = tuple(fitness.values)  # skip a redundant feature-network pass
    else:
        best_fit = evaluate(best, original_img, orig_features)

    for _ in range(max_iter):
        neigh = []
        for p, (lo, hi) in zip(best, bounds):
            half_width = step_frac * (hi - lo)
            neigh.append(np.clip(p + random.uniform(-half_width, half_width), lo, hi))
        fit_n = evaluate(neigh, original_img, orig_features)
        if fit_n[0] > best_fit[0] or (fit_n[0] == best_fit[0] and fit_n[1] < best_fit[1]):
            best, best_fit = neigh, fit_n
    return best

# ------------------------------
# 6. GA Routine Without Display
# ------------------------------
def run_ga(original_img, ngen=35, pop_size=50):
    """Search enhancement parameters for one image with NSGA-II plus local search.

    Args:
        original_img: BGR image with 0-255 values, as produced by ``cv2.imread``.
        ngen: number of generations.
        pop_size: population size.

    Returns:
        The image enhanced with the best individual (highest entropy, ties
        broken by lower distance), as a float array in [0, 1], BGR order.
    """
    orig_f = original_img.astype(np.float32) / 255.0
    # Convert the 8-bit source directly; going through orig_f * 255 and
    # truncating can lose one level per pixel.
    orig_rgb = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    feats = extract_features_torch(orig_rgb)

    pop = toolbox.population(n=pop_size)
    for ind in pop:
        ind.fitness.values = toolbox.evaluate(ind, orig_f, feats)
    # NSGA-II selection with k == len(pop) keeps everyone but assigns the
    # crowding distances that selTournamentDCD relies on.
    pop = toolbox.select(pop, len(pop))

    # selTournamentDCD draws tournaments in groups of four. DEAP raises when
    # k == len(pop) is not a multiple of four (e.g. the default pop_size=50),
    # so use the largest multiple of four that fits.
    n_offspring = len(pop) - len(pop) % 4

    for gen in range(ngen):
        if n_offspring:
            parents = tools.selTournamentDCD(pop, n_offspring)
        else:  # fewer than four individuals: no DCD tournament possible
            parents = list(pop)
        offspring = [toolbox.clone(ind) for ind in parents]

        # Crossover & mutation on the clones; changed individuals lose their fitness.
        for i1, i2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < 0.85:
                toolbox.mate(i1, i2)
                del i1.fitness.values, i2.fitness.values
        mut_rate = 0.3 if gen < ngen // 2 else 0.2
        for mutant in offspring:
            if random.random() < mut_rate:
                toolbox.mutate(mutant)
                del mutant.fitness.values

        for ind in offspring:
            if not ind.fitness.valid:
                ind.fitness.values = toolbox.evaluate(ind, orig_f, feats)
        pop = toolbox.select(pop + offspring, pop_size)

        # Local search on the top 20%; improved individuals are updated in place.
        improved = False
        for ind in tools.selBest(pop, max(1, int(0.2 * pop_size))):
            new_p = local_search(ind, orig_f, feats)
            if list(new_p) == list(ind):
                continue  # no neighbour beat the incumbent; skip re-evaluation
            new_fit = evaluate_individual(new_p, orig_f, feats)
            old_fit = ind.fitness.values
            if new_fit[0] > old_fit[0] or (new_fit[0] == old_fit[0] and new_fit[1] < old_fit[1]):
                ind[:] = new_p
                ind.fitness.values = new_fit
                improved = True
        if improved:
            # Recompute crowding distances, which are stale after in-place
            # fitness changes, before the next selTournamentDCD round.
            pop = toolbox.select(pop, len(pop))

    best = tools.selBest(pop, 1)[0]
    return enhance_image_gpu(orig_f, best)

# ------------------------------
# 7. Batch Processing
# ------------------------------
if __name__ == "__main__":
    # Resume-friendly batch run: enhance every image under input_root that
    # does not yet have an output in the mirrored folder tree under output_root.
    input_root = r"D:\mit-5k-subset\d"
    output_root = r"D:\mit-5k-subset\d_enhanced"
    image_exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
    os.makedirs(output_root, exist_ok=True)

    for subdir, _, files in os.walk(input_root):
        rel_path = os.path.relpath(subdir, input_root)
        # normpath drops the "." component relpath yields for the root itself.
        output_subdir = os.path.normpath(os.path.join(output_root, rel_path))
        os.makedirs(output_subdir, exist_ok=True)
        for fname in files:
            if not fname.lower().endswith(image_exts):
                continue
            in_path = os.path.join(subdir, fname)
            out_path = os.path.join(output_subdir, fname)

            # Outputs are only ever created by the rename below, so an
            # existing file is a finished result.
            if os.path.exists(out_path):
                print(f"Skipping (already processed): {out_path}")
                continue

            img = cv2.imread(in_path)
            if img is None:
                print(f"Warning: unable to read {in_path}")
                continue

            print(f"Enhancing: {in_path}")
            enhanced = run_ga(img, ngen=5, pop_size=60)
            # Round rather than truncate so the output is not biased darker.
            out_img = np.clip(np.rint(enhanced * 255), 0, 255).astype(np.uint8)
            # Write to a temp file with the same extension (cv2 picks the
            # encoder from it) and then rename. An interrupted write therefore
            # never leaves a truncated file under the final name.
            root, ext = os.path.splitext(out_path)
            tmp_path = f"{root}.partial{ext}"
            if not cv2.imwrite(tmp_path, out_img):
                print(f"Warning: unable to write {out_path}")
                continue
            os.replace(tmp_path, out_path)
            print(f"Saved enhanced image: {out_path}")






Skipping (already processed): D:\mit-5k-subset\d_enhanced\.\a0034-LSYD4O2202.jpg
Skipping (already processed): D:\mit-5k-subset\d_enhanced\.\a0035-dgw_048.jpg
Skipping (already processed): D:\mit-5k-subset\d_enhanced\.\a0291-IMG_0115.jpg
Skipping (already processed): D:\mit-5k-subset\d_enhanced\.\a0436-IMG_2583.jpg
Skipping (already processed): D:\mit-5k-subset\d_enhanced\.\a0452-IMG_1646.jpg
Skipping (already processed): D:\mit-5k-subset\d_enhanced\.\a0463-jmac_DSC2316.jpg
Skipping (already processed): D:\mit-5k-subset\d_enhanced\.\a0631-NKIM_MG_6442.jpg
Skipping (already processed): D:\mit-5k-subset\d_enhanced\.\a0648-IMG_5085.jpg
Skipping (already processed): D:\mit-5k-subset\d_enhanced\.\a0712-_DSC8911.jpg
Skipping (already processed): D:\mit-5k-subset\d_enhanced\.\a0752-20061213_134314__MG_3708.jpg
Skipping (already processed): D:\mit-5k-subset\d_enhanced\.\a0767-jn_20070824_0165.jpg
Skipping (already processed): D:\mit-5k-subset\d_enhanced\.\a0770-050703_161554__I2E9266.jpg
Skipp