In [1]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import io
from PIL import Image
from skimage.metrics import structural_similarity as ssim
import os
import logging
import cv2

MAX_INT = 2**31 - 1
logger = logging.getLogger("ENOFT_LOGER")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler("model.json")
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)

class ImageCompressionEnv(gym.Env):
    def __init__(self, images, org_images):
        super(ImageCompressionEnv, self).__init__()
        self.images = images
        self.current_index = 0
        self.current_image = self.images[self.current_index]
        self.org_images = org_images
        self.original_size = 0
        self.compression_ratio = 1.0  # Initialize with a default value

        # Define action and observation space
        self.action_space = spaces.MultiDiscrete([101, 65536, 101, 101, 5])
        feature_dim = len(self.current_image) + 1  # Include compression ratio in the state
        self.observation_space = spaces.Box(
            low=0, high=np.inf, shape=(feature_dim,), dtype=np.float32
        )

    def step(self, action):
        compressed_image, compressed_size = self.compress_image(self.org_images[self.current_index], action)
        max_size = self.original_size * self.compression_ratio / 100

        grey_original = cv2.cvtColor(self.org_images[self.current_index], cv2.COLOR_BGR2GRAY)
        grey_compressed = cv2.cvtColor(compressed_image, cv2.COLOR_BGR2GRAY)
        got_ssim = ssim(grey_original, grey_compressed, multichannel=True)

        # Weighted reward calculation
        ssim_weight = 0.7
        size_weight = 30
        size_penalty = (compressed_size - max_size) / max_size if compressed_size > max_size else 0

        reward = ssim_weight * got_ssim - size_weight * size_penalty

        self.current_index = (self.current_index + 1) % len(self.images)
        done = self.current_index == 0

        state = self.update_environment_state()
        log = {
            "original_size": self.original_size,
            "compressed_size": compressed_size,
            "max_size": max_size,
            "got_ssim": got_ssim,
            "reward": reward,
            "action": action
        }
        logger.info(log)
        return state, float(reward), done, {}

    def update_environment_state(self):
        self.current_image = self.images[self.current_index]
        org_img = self.org_images[self.current_index]
        success, buffer = cv2.imencode('.jpg', org_img, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
        buffer = io.BytesIO(buffer)
        self.original_size = len(buffer.getvalue())
        self.compression_ratio = int(np.random.uniform(30, 100))
        state = np.append(self.current_image, self.compression_ratio)
        return state

    def reset(self):
        self.current_index = 0
        return self.update_environment_state()

    def compress_image(self, image: np.ndarray, action):
        compression_level, rst_interval, luma_quality, chroma_quality, sampling_factor = action

        sampling_map = {
            0: cv2.IMWRITE_JPEG_SAMPLING_FACTOR_411,
            1: cv2.IMWRITE_JPEG_SAMPLING_FACTOR_420,
            2: cv2.IMWRITE_JPEG_SAMPLING_FACTOR_422,
            3: cv2.IMWRITE_JPEG_SAMPLING_FACTOR_440,
            4: cv2.IMWRITE_JPEG_SAMPLING_FACTOR_444
        }
        sampling_factor = np.clip(sampling_factor, 0, 4)
        sampling = sampling_map[sampling_factor]

        encode_params = [
            int(cv2.IMWRITE_JPEG_QUALITY), compression_level,
            int(cv2.IMWRITE_JPEG_RST_INTERVAL), rst_interval,
            int(cv2.IMWRITE_JPEG_LUMA_QUALITY), luma_quality,
            int(cv2.IMWRITE_JPEG_CHROMA_QUALITY), chroma_quality,
            int(cv2.IMWRITE_JPEG_SAMPLING_FACTOR), sampling
        ]
        success, buffer = cv2.imencode('.jpg', image, encode_params)

        if not success:
            raise ValueError("Failed to compress image")

        buffer = io.BytesIO(buffer)
        compressed_size = len(buffer.getvalue())

        buffer.seek(0)
        compressed_image = Image.open(buffer)
        compressed_image = np.array(compressed_image)

        return compressed_image, compressed_size


png_folder = "datasets/kaggle_Kodak/train"
images = []
org_images = []
for filename in os.listdir(png_folder):
    img_path = os.path.join(png_folder, filename)
    img = cv2.imread(img_path)
    feature_vector = []

    # Add height and width
    h, w, c = img.shape
    feature_vector.append(h)
    feature_vector.append(w)

    # Add aspect ratio
    feature_vector.append(h / w)

    # Add mean and std of each channel
    for i in range(c):
        mean = np.mean(img[:, :, i])
        std = np.std(img[:, :, i])
        feature_vector.append(mean)
        feature_vector.append(std)

    # Add entropy
    hist, _ = np.histogram(img.flatten(), bins=256, range=[0, 256])
    hist = hist / np.sum(hist)
    entropy = -np.sum(hist * np.log2(hist + 1e-6))
    feature_vector.append(entropy)

    # Add edge density
    edges = cv2.Canny(img, 100, 200)
    edge_density = np.mean(edges)
    feature_vector.append(edge_density)

    # Add brightness
    brightness = np.mean(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
    feature_vector.append(brightness)

    # Add contrast
    contrast = np.std(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
    feature_vector.append(contrast)

    # Add sharpness
    sharpness = np.mean(cv2.Laplacian(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), cv2.CV_64F))
    feature_vector.append(sharpness)

    # Add saturation
    saturation = np.mean(cv2.cvtColor(img, cv2.COLOR_BGR2HSV)[:, :, 1])
    feature_vector.append(saturation)

    images.append(feature_vector)
    org_images.append(img)

env = ImageCompressionEnv(images, org_images)


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

# Proximal Policy Optimization
class PPO(nn.Module):
    def __init__(self, input_dim, output_dims):
        super(PPO, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 128)
        self.policy_layers = nn.ModuleList([
            nn.Linear(128, 101),    # For first action dimension (101 values)
            nn.Linear(128, 65536),  # For second action dimension (65536 values)
            nn.Linear(128, 101),    # For third action dimension (101 values)
            nn.Linear(128, 101),    # For fourth action dimension (101 values)
            nn.Linear(128, 5)       # For fifth action dimension (5 values)
        ])
        self.value_layer = nn.Linear(128, 1)
        self.output_dims = output_dims  # List of number of action values per dimension
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        policy_logits = [layer(x) for layer in self.policy_layers]
        value = self.value_layer(x)
        return policy_logits, value

    def get_action(self, state):
        policy_logits, _ = self.forward(state)
        actions = []
        log_probs = []
        entropies = []
        for logits in policy_logits:
            policy_dist = Categorical(logits=logits)
            action = policy_dist.sample()
            actions.append(action.item())
            log_probs.append(policy_dist.log_prob(action))
            entropies.append(policy_dist.entropy())
        return actions, torch.stack(log_probs), torch.stack(entropies)

    def evaluate_action(self, state, action):
        policy_logits, value = self.forward(state)
        log_probs = []
        entropies = []
        for i, logits in enumerate(policy_logits):
            policy_dist = Categorical(logits=logits)
            log_prob = policy_dist.log_prob(action[:, i])
            log_probs.append(log_prob)
            entropies.append(policy_dist.entropy())
        return torch.stack(log_probs, dim=1), torch.squeeze(value), torch.stack(entropies, dim=1)

def compute_gae(rewards, masks, values, gamma, lam=0.95):
    values = values.squeeze()
    returns = []
    gae = 0
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * lam * masks[step] * gae
        returns.insert(0, gae + values[step])
    return torch.tensor(returns)

def ppo_update(agent, optimizer, trajectories, clip_param, gamma):
    states = torch.cat([trajectory[0] for trajectory in trajectories]).detach()
    actions = torch.stack([torch.tensor(trajectory[1]) for trajectory in trajectories]).detach()
    log_probs = torch.stack([trajectory[2] for trajectory in trajectories]).detach()
    rewards = torch.tensor([trajectory[3] for trajectory in trajectories], dtype=torch.float32)
    masks = torch.tensor([trajectory[4] for trajectory in trajectories], dtype=torch.float32)
    
    with torch.no_grad():
        _, values = agent(states)
        values = values.squeeze()
    values = torch.cat((values, torch.zeros(1, dtype=values.dtype)))  # Add the value for the final state
    returns = compute_gae(rewards, masks, values, gamma=gamma)
    
    advantages = returns - values[:-1]
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-10)
    
    for _ in range(4):  # Optimize policy for K epochs
        policy_logits, _ = agent(states)
        
        new_log_probs = []
        for i, logits in enumerate(policy_logits):
            dist = Categorical(logits=logits)
            new_log_prob = dist.log_prob(actions[:, i])
            new_log_probs.append(new_log_prob)
        new_log_probs = torch.stack(new_log_probs, dim=1).sum(dim=1)
        
        ratio = torch.exp(new_log_probs - log_probs.sum(dim=1))
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        
        value_loss = nn.MSELoss()(returns, values[:-1])
        loss = policy_loss + 0.5 * value_loss - 0.01 * new_log_probs.mean()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


In [3]:
def train(env, agent, optimizer, num_episodes=500, gamma=0.99, clip_param=0.2):
    all_rewards = []
    for episode in range(num_episodes):
        state = env.reset()
        state = torch.FloatTensor(state).unsqueeze(0)
        episode_reward = 0
        done = False
        trajectories = []

        while not done:
            action, log_prob, _ = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            next_state = torch.FloatTensor(next_state).unsqueeze(0)

            mask = 1 if not done else 0
            trajectories.append((state, action, log_prob, reward, mask))

            state = next_state
            episode_reward += reward

        all_rewards.append(episode_reward)

        ppo_update(agent, optimizer, trajectories, clip_param, gamma)

        if episode % 10 == 0:
            print(f"Episode {episode}, Reward: {episode_reward}")
            torch.save(agent.state_dict(), f"model_episode_{episode}_reward_{episode_reward}.pth")
            
    return all_rewards

input_dim = 15 + 1  # Example feature size plus the compression ratio
output_dims = [101, 65536, 101, 101, 5]  # Matching the action space dimensions
agent = PPO(input_dim, len(output_dims))
optimizer = optim.Adam(agent.parameters(), lr=3e-4)
train(env, agent, optimizer)


Episode 0, Reward: 7.536042101004448
Episode 10, Reward: 12.20064886701238
Episode 20, Reward: -6.4791647041845515
Episode 30, Reward: 10.161764210289704
Episode 40, Reward: 10.064753092589733
Episode 50, Reward: 10.31331581751886
Episode 60, Reward: 8.413726598551383
Episode 70, Reward: 10.090268042313461
Episode 80, Reward: 9.879557151480846
Episode 90, Reward: 10.617147391162721
Episode 100, Reward: 9.504449301906616
Episode 110, Reward: -31.309742075105966
Episode 120, Reward: -98.0719115759418
Episode 130, Reward: -81.77643241181316
Episode 140, Reward: -104.25828426128892
Episode 150, Reward: -20.82367630145023
Episode 160, Reward: -23.31499948305776
Episode 170, Reward: -9.00556477990353
Episode 180, Reward: -42.69874558275102
Episode 190, Reward: -38.591498029570815
Episode 200, Reward: -77.80231287273577
Episode 210, Reward: -61.56921916997036
Episode 220, Reward: 3.3864150082949775
Episode 230, Reward: -31.638835019312133
Episode 240, Reward: -47.376460164356295
Episode 250, 

[7.536042101004448,
 7.621902477993324,
 9.064770315127621,
 7.567444506816582,
 10.890478730958653,
 11.582953947478892,
 10.79826918221752,
 10.551752818026841,
 12.018727300333406,
 12.15389888090696,
 12.20064886701238,
 7.185988840938401,
 7.8594551637369845,
 12.045695826930801,
 -1.824278766791597,
 2.5883785441828744,
 7.844491641230423,
 12.168036292075525,
 -18.948835977460234,
 1.2480700057271061,
 -6.4791647041845515,
 -22.8475457621455,
 10.713374912306296,
 1.344336567141489,
 10.895980402542776,
 10.457935993131205,
 9.023744224149057,
 9.979481599278946,
 -3.3801551009560806,
 2.632414933450627,
 10.161764210289704,
 9.507143618166939,
 9.794860820176895,
 10.977390290861278,
 9.586443295563889,
 6.8682802842090664,
 10.1018064442513,
 10.283384197610406,
 9.669269448443519,
 10.613396330811394,
 10.064753092589733,
 10.218713627366824,
 8.971634452468754,
 9.201877605931427,
 9.973876636556831,
 10.188565607333825,
 10.477913294959501,
 9.625081884367356,
 10.071473358

In [4]:
def evaluate(env, agent, num_episodes=10):
    all_rewards = []
    for _ in range(num_episodes):
        state = env.reset()
        state = torch.FloatTensor(state).unsqueeze(0)
        episode_reward = 0
        done = False

        while not done:
            action, _, _ = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            next_state = torch.FloatTensor(next_state).unsqueeze(0)
            state = next_state
            episode_reward += reward

        all_rewards.append( episode_reward )
    return all_rewards

# Example usage
# evaluation_rewards = evaluate(env, agent)
# print(f"Average evaluation reward: {np.mean(evaluation_rewards)}")


In [5]:
def save_model(agent, path):
    torch.save(agent.state_dict(), path)
    
def load_model(agent, path):
    agent.load_state_dict(torch.load(path))
    return agent

# Example usage
save_model(agent, "model.pth")
# loaded_agent = PPO(input_dim, output_dim)
# loaded_agent = load_model(loaded_agent, "model.pth")
# evaluation_rewards = evaluate(env, loaded_agent)
