In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2
import os
import random

In [None]:
# Path to the input frame; override with the FRAME_PATH environment variable
# instead of editing the notebook.
IMAGE_PATH = os.environ.get("FRAME_PATH", "C:\\Users\\Asus\\Downloads\\frame_711.jpg")

input_img = cv2.imread(IMAGE_PATH)
# cv2.imread returns None (it does not raise) when the file is missing/unreadable.
if input_img is None:
    raise FileNotFoundError(f"Could not read image at {IMAGE_PATH!r}")

# OpenCV loads BGR; matplotlib expects RGB.
plt.imshow(cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB));

In [None]:
def calculate_features(image):
    """Return the number of ORB keypoints found in ``image``.

    Keypoints are detected with a default-configured ORB detector and then passed
    through ``orb.compute``. The count is taken after ``compute``. Per the OpenCV
    ``Feature2D::compute`` docs, that step may remove keypoints for which no
    descriptor can be computed. NOTE(review): ``cv2.ORB_create()`` defaults to
    ``nfeatures=500`` per the OpenCV docs, which would cap this count; confirm if
    the state bins depend on it.

    Args:
        image: image accepted by ``cv2.ORB.detect`` (e.g. a BGR frame from ``cv2.imread``).

    Returns:
        int: number of keypoints remaining after descriptor computation.
    """
    orb = cv2.ORB_create()
    keypoints = orb.detect(image, None)
    # Descriptors are not needed, but compute() can prune keypoints, so keep it
    # to preserve the exact count.
    keypoints, _descriptors = orb.compute(image, keypoints)
    return len(keypoints)

calculate_features(input_img)

In [None]:
# Display labels for the enhancement actions; index i is the action id passed to
# perform_action: 0 white_balance, 1 Contrast_Up, 2 Contrast_Down,
# 3 Brightness_Up, 4 Brightness_Down, 5 CLAHE.
image_enhancement_algorithms = ['WB', 'C_Up', 'C_Down', 'B_Up', 'B_Down', 'CLAHE']

# State labels: buckets of the ORB feature count in steps of roughly one hundred
# (see check_state for the exact boundaries). get_feature_value maps label Fi to i*100.
state_space = ['F0', 'F1', 'F2', 'F3', 'F4', 'F5']

In [None]:
class PolicyNetwork(nn.Module):
    """Two-layer MLP that turns a scalar state value into action probabilities.

    The input's last dimension must be 1 (one state feature). The output is a
    softmax over ``num_actions`` taken along the last dimension.
    """

    def __init__(self, num_actions):
        super().__init__()
        self.dense1 = nn.Linear(1, 32)             # scalar state -> 32 hidden units
        self.dense2 = nn.Linear(32, num_actions)   # hidden units -> one logit per action

    def forward(self, state):
        hidden = F.relu(self.dense1(state))
        logits = self.dense2(hidden)
        return F.softmax(logits, dim=-1)

# Seed every RNG the notebook draws from (weight init and action sampling via
# torch, random initial states via np.random) so Restart & Run All is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Instantiate the policy network on CUDA if available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_net = PolicyNetwork(len(image_enhancement_algorithms)).to(device)

# Define the optimizer
optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.01)

In [None]:
def CLAHE(image):
    """Apply CLAHE (clip limit 2.0, 8x8 tiles) to the grayscale version of ``image``.

    A 3-D (colour) input is converted BGR -> gray, equalised, and converted back
    to a 3-channel BGR image, so colour information is discarded. Any other input
    is passed to CLAHE as-is (assumed to already be single-channel gray).
    """
    is_color = len(image.shape) == 3
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if is_color else image
    equalized = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)
    return cv2.cvtColor(equalized, cv2.COLOR_GRAY2BGR) if is_color else equalized

In [None]:
def white_balance(image, clip_limit=3.0, tile_grid_size=(8, 8)):
    """Equalise the lightness channel of a BGR image with CLAHE.

    NOTE(review): despite the name, this is not a white balance. Only the L
    channel of the LAB conversion is modified; a and b pass through unchanged,
    so a colour cast is not corrected.

    Args:
        image: BGR image (expected by ``cv2.COLOR_BGR2LAB``).
        clip_limit: CLAHE contrast clip limit; default 3.0 matches the original.
        tile_grid_size: CLAHE tile grid; default (8, 8) matches the original.

    Returns:
        BGR image with CLAHE-equalised lightness.
    """
    l, a, b = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2LAB))
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    balanced_lab = cv2.merge((clahe.apply(l), a, b))
    return cv2.cvtColor(balanced_lab, cv2.COLOR_LAB2BGR)

In [None]:
def Contrast_Up(image):
    """Multiply every pixel by 4.0 via ``cv2.convertScaleAbs`` (alpha=4.0, beta=0).

    For uint8 input, any pixel value >= 64 saturates to 255.
    NOTE(review): scaling is about 0 rather than the image mean, so this also
    brightens the image.
    """
    return cv2.convertScaleAbs(image, beta=0, alpha=4.0)

def Contrast_Down(image):
    """Multiply every pixel by 0.2 via ``cv2.convertScaleAbs`` (alpha=0.2, beta=0).

    For uint8 input, values are compressed into roughly [0, 51], which also darkens.
    """
    return cv2.convertScaleAbs(image, beta=0, alpha=0.2)

def Brightness_Up(image):
    """Add 150 to every pixel via ``cv2.convertScaleAbs`` (alpha=1.0, beta=150).

    For uint8 input, pixel values above 105 saturate to 255.
    """
    return cv2.convertScaleAbs(image, beta=150, alpha=1.0)

def Brightness_Down(image, amount=10):
    """Darken ``image`` by subtracting ``amount`` from every pixel, clipping at 0.

    ``cv2.convertScaleAbs`` is deliberately not used. It computes
    ``saturate(|alpha*src + beta|)``: a positive beta brightens, and a negative
    beta reflects dark pixels back upward because of the absolute value.

    Args:
        image: pixel array, typically uint8 BGR from ``cv2.imread``.
        amount: value subtracted from every channel (default 10).

    Returns:
        uint8 array of the same shape (8-bit output, like ``convertScaleAbs``).
    """
    # Widen before subtracting so uint8 values do not wrap around below 0.
    darkened = np.clip(image.astype(np.int32) - amount, 0, 255)
    return darkened.astype(np.uint8)

In [None]:
def perform_action(ind, img_inp):
    """Apply enhancement action ``ind`` to ``img_inp`` and return the new image.

    Action index -> function (same order as ``image_enhancement_algorithms``):
        0 white_balance, 1 Contrast_Up, 2 Contrast_Down,
        3 Brightness_Up, 4 Brightness_Down, 5 CLAHE

    Raises:
        ValueError: if ``ind`` is not one of 0..5, rather than silently
            returning None for an unknown action.
    """
    # Validate first so a bad index fails fast; 6 must match the table below.
    if ind not in range(6):
        raise ValueError(f"unknown action index {ind!r}; expected 0..5")
    actions = (white_balance, Contrast_Up, Contrast_Down,
               Brightness_Up, Brightness_Down, CLAHE)
    return actions[int(ind)](img_inp)

In [None]:
def check_state(image):
    """Map the ORB feature count of ``image`` to a state label.

    With ``n = calculate_features(image)``:
        n < 0            -> 'F0'  (unreachable for a keypoint count)
        0   <= n <= 100  -> 'F1'
        100 <  n <= 200  -> 'F2'
        200 <  n <= 300  -> 'F3'
        300 <  n <= 400  -> 'F4'
        n > 400          -> 'F5'

    Every count maps to a label, so callers such as ``get_feature_value`` never
    receive None.
    """
    num_of_features = calculate_features(image)
    if num_of_features < 0:
        return 'F0'
    # Upper-inclusive bins. A strict `< 100` here left a count of exactly 100
    # matching no branch.
    if num_of_features <= 100:
        return 'F1'
    if num_of_features <= 200:
        return 'F2'
    if num_of_features <= 300:
        return 'F3'
    if num_of_features <= 400:
        return 'F4'
    return 'F5'

In [None]:
def next_state(image, action):
    """Apply enhancement ``action`` to ``image`` and classify the result.

    Returns:
        list: ``[enhanced_image, state_label]``, where the label comes from ``check_state``.
    """
    enhanced = perform_action(action, image)
    label = check_state(enhanced)
    return [enhanced, label]

In [None]:
def get_feature_value(ft1):
    """Return 100 times the position of state label ``ft1`` in ``state_space``.

    Raises:
        ValueError: if ``ft1`` is not in ``state_space`` (from ``list.index``).
    """
    return 100 * state_space.index(ft1)

In [None]:
def update_reward(img1, img2):
    """Reward for the change in ORB feature count from ``img2`` to ``img1``.

    The callers in this notebook pass the enhanced image as ``img1`` and the
    previous image as ``img2``. With ``delta = features(img1) - features(img2)``:
    a decrease gives -5 and no change gives -1. A gain gives 1..4 for each
    started block of 100 up to 400, and 5 for a gain above 400.
    """
    delta = calculate_features(img1) - calculate_features(img2)
    if delta < 0:
        return -5
    if delta == 0:
        return -1
    for upper, reward in ((100, 1), (200, 2), (300, 3), (400, 4)):
        if delta <= upper:
            return reward
    return 5

In [None]:
def train_REINFORCE(num_episodes, discount_factor, input_img):
    """Single-step policy-gradient loop (first variant).

    NOTE(review): a later cell redefines ``train_REINFORCE``, shadowing this one.

    Each episode does the following:
    - Draws a random state label.
    - Lets ``policy_net`` sample one enhancement action.
    - Applies that action to the unmodified ``input_img``.
    - Updates the policy with ``-log pi(a|s) * reward``.

    The loop stops early when the enhanced image's state value exceeds the drawn
    state's value by more than 400.

    NOTE(review): the state fed to the policy is random and not derived from
    ``input_img``, so the policy cannot condition on the image it edits.

    Args:
        num_episodes: maximum number of single-step episodes.
        discount_factor: unused (each episode is one step); kept so the
            signature matches the later multi-step variant.
        input_img: BGR image (as from ``cv2.imread``) every episode starts from.
    """
    cumulative_reward = 0
    curr_image = None

    for episode in range(num_episodes):
        state = state_space[np.random.choice(len(state_space))]  # Random initial state
        state_tensor = torch.tensor([[get_feature_value(state)]], dtype=torch.float32).to(device)

        action_probs = policy_net(state_tensor)
        action_distribution = torch.distributions.Categorical(probs=action_probs)
        action = action_distribution.sample().item()

        # next_state handles every action index, so no per-action branching is needed.
        den_img, den_state = next_state(input_img, action)
        reward = update_reward(den_img, input_img)
        cumulative_reward += reward

        loss = -torch.log(action_probs[0][action]) * reward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        curr_image = den_img

        # den_state is check_state(den_img), already computed inside next_state.
        if get_feature_value(den_state) - get_feature_value(state) > 400:
            break

        print(f"Episode {episode + 1}: State={state}, Action={image_enhancement_algorithms[action]}, Reward={reward}, Cumulative Reward={cumulative_reward}")

    # Draw once: calling imshow every episode stacks images on the same axes.
    # OpenCV images are BGR; convert so matplotlib shows true colours.
    if curr_image is not None:
        plt.imshow(cv2.cvtColor(curr_image, cv2.COLOR_BGR2RGB))

# Training parameters for the single-step variant defined above.
num_episodes, discount_factor = 1000, 0.99

# Train the agent
train_REINFORCE(num_episodes, discount_factor, input_img)

In [None]:
def train_REINFORCE(num_episodes, discount_factor, input_img, max_steps=100):
    """Train ``policy_net`` with Monte-Carlo REINFORCE on a single image.

    Each episode starts from the unmodified ``input_img`` and a randomly drawn
    state label. At each step:
    - The policy samples an action from the current state.
    - The action is applied to the current image.
    - ``update_reward`` gives the step reward.
    - The label of the enhanced image becomes the next state the policy observes.

    After the episode the policy is updated once with discounted returns.

    An episode ends after ``max_steps`` steps, or early once the new state's
    feature value exceeds the starting state's value by more than 400.

    Args:
        num_episodes: number of episodes to run.
        discount_factor: gamma used for the discounted returns.
        input_img: BGR image (as from ``cv2.imread``) every episode starts from.
        max_steps: per-episode step cap; the default 100 matches the previous limit.

    Returns:
        The enhanced image at the end of the last episode (``input_img`` if
        ``num_episodes`` is 0).

    Raises:
        ValueError: if ``max_steps`` < 1.
    """
    if max_steps < 1:
        raise ValueError(f"max_steps must be >= 1, got {max_steps}")

    cumulative_reward = 0
    curr_image = input_img

    for episode in range(num_episodes):
        state = state_space[np.random.choice(len(state_space))]  # Random initial state
        start_value = get_feature_value(state)
        state_tensor = torch.tensor([[start_value]], dtype=torch.float32, device=device)
        curr_image = input_img
        log_probs, rewards = [], []

        for _ in range(max_steps):
            action_probs = policy_net(state_tensor)
            action = torch.distributions.Categorical(probs=action_probs).sample().item()

            next_image, next_label = next_state(curr_image, action)
            reward = update_reward(next_image, curr_image)
            cumulative_reward += reward

            log_probs.append(torch.log(action_probs[0, action]))
            rewards.append(reward)

            # Advance both the image and the observation the policy sees next step.
            curr_image = next_image
            next_value = get_feature_value(next_label)
            state_tensor = torch.tensor([[next_value]], dtype=torch.float32, device=device)

            if next_value - start_value > 400:
                break

        # Discounted returns G_t = r_t + gamma * G_{t+1}, accumulated backwards.
        returns, running = [], 0.0
        for r in reversed(rewards):
            running = r + discount_factor * running
            returns.append(running)
        returns.reverse()
        returns_t = torch.tensor(returns, dtype=torch.float32, device=device)
        loss = -(torch.stack(log_probs) * returns_t).sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Episode {episode + 1}: State={state}, Action={image_enhancement_algorithms[action]}, Reward={reward}, Cumulative Reward={cumulative_reward}")
        # OpenCV images are BGR; convert so matplotlib shows true colours.
        plt.imshow(cv2.cvtColor(curr_image, cv2.COLOR_BGR2RGB))
        plt.show()

    return curr_image

# Training parameters
num_episodes = 1000
discount_factor = 0.99

# Train the agent; keep the final enhanced image for the display cell below.
curr_image = train_REINFORCE(num_episodes, discount_factor, input_img)


In [None]:
# Display the module-level `curr_image` (must be assigned by the training cell).
# OpenCV images are BGR; convert so matplotlib shows true colours.
plt.imshow(cv2.cvtColor(curr_image, cv2.COLOR_BGR2RGB));