In [9]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import io
from PIL import Image
from skimage.metrics import structural_similarity as ssim

class ImageCompressionEnv(gym.Env):
    def __init__(self, images):
        super(ImageCompressionEnv, self).__init__()
        self.images = images
        self.current_index = 0
        self.current_image = self.images[self.current_index]
        self.compression_ratio = 1.0  # Initialize with a default value
        
        # Define action and observation space
        self.action_space = spaces.Discrete(101)  # Compression levels [0, 100]
        size = self.current_image.size + 1  # Include compression ratio in the state
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(size,), dtype=np.uint8
        )

    def step(self, action):
        compressed_image, compressed_size = self.compress_image(self.current_image, action)
        image = Image.fromarray(self.current_image)
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG", quality=100)
        original_size = buffer.tell()
        max_size = original_size * self.compression_ratio
        
        if compressed_size > max_size:
            reward = 0.0
        else:
            grey_original = np.dot(self.current_image[...,:3], [0.2989, 0.5870, 0.1140]).astype(np.uint8)
            grey_compressed = np.dot(compressed_image[...,:3], [0.2989, 0.5870, 0.1140]).astype(np.uint8)
            reward = ssim(grey_original, grey_compressed, multichannel=True)
        
        self.current_index = (self.current_index + 1) % len(self.images)
        self.current_image = self.images[self.current_index]
        done = self.current_index == 0
        
        # Include compression ratio in the state
        state = np.append(compressed_image.flatten(), self.compression_ratio)
        return state, float(reward), done, {}

    def reset(self):
        self.current_index = 0
        self.current_image = self.images[self.current_index]
        self.compression_ratio = np.random.uniform(0.1, 1.0)  # Random compression ratio between 30% and 100%
        
        # Include compression ratio in the state
        state = np.append(self.current_image.flatten(), self.compression_ratio)
        return state

    def compress_image(self, image: np.ndarray, compression_level):
        image = Image.fromarray(image)
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG", quality=compression_level)
        compressed_size = buffer.tell()
        buffer.seek(0)
        compressed_image = Image.open(io.BytesIO(buffer.getvalue()))
        compressed_image = np.array(compressed_image)
        return compressed_image, compressed_size

# Example usage
images = [np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8) for _ in range(100)]
env = ImageCompressionEnv(images)


12288
<class 'int'>


In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

class PPO(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PPO, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.policy_layer = nn.Linear(128, output_dim)
        self.value_layer = nn.Linear(128, 1)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        policy = self.policy_layer(x)
        value = self.value_layer(x)
        return policy, value

    def get_action(self, state):
        policy, _ = self.forward(state)
        policy_dist = Categorical(logits=policy)
        action = policy_dist.sample()
        return action.item(), policy_dist.log_prob(action), policy_dist.entropy()

    def evaluate_action(self, state, action):
        policy, value = self.forward(state)
        policy_dist = Categorical(logits=policy)
        action_log_probs = policy_dist.log_prob(action)
        dist_entropy = policy_dist.entropy()
        return action_log_probs, torch.squeeze(value), dist_entropy

def compute_gae(rewards, masks, values, gamma=0.99, tau=0.95):
    returns = []
    gae = 0
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * tau * masks[step] * gae
        returns.insert(0, gae + values[step])
    return returns

def ppo_update(agent, optimizer, trajectories, clip_param=0.2):
    states = torch.stack([trajectory[0] for trajectory in trajectories])
    actions = torch.tensor([trajectory[1] for trajectory in trajectories])
    log_probs_old = torch.tensor([trajectory[2] for trajectory in trajectories])
    returns = torch.tensor([trajectory[3] for trajectory in trajectories])
    advantages = returns - torch.tensor([trajectory[4] for trajectory in trajectories])
    
    for _ in range(4):
        log_probs, state_values, dist_entropy = agent.evaluate_action(states, actions)
        ratio = torch.exp(log_probs - log_probs_old)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = (returns - state_values).pow(2).mean()
        loss = policy_loss + 0.5 * value_loss - 0.01 * dist_entropy.mean()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Example usage
input_dim = 64 * 64 * 3 + 1  # Example image size plus the compression ratio
output_dim = 101  # Compression levels from 0 to 100
agent = PPO(input_dim, output_dim)
optimizer = optim.Adam(agent.parameters(), lr=3e-4)





In [11]:
def train(env, agent, optimizer, num_episodes=100, gamma=0.99, clip_param=0.2):
    all_rewards = []
    for episode in range(num_episodes):
        state = env.reset()
        state = torch.FloatTensor(state).unsqueeze(0)
        episode_reward = 0
        done = False
        trajectories = []

        while not done:
            action, log_prob, _ = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            next_state = torch.FloatTensor(next_state).unsqueeze(0)

            mask = 1 if not done else 0
            trajectories.append((state, action, log_prob, reward, mask))

            state = next_state
            episode_reward += reward

        all_rewards.append(episode_reward)

        states = torch.stack([trajectory[0] for trajectory in trajectories])
        actions = torch.tensor([trajectory[1] for trajectory in trajectories])
        log_probs = torch.tensor([trajectory[2] for trajectory in trajectories])
        rewards = torch.tensor([trajectory[3] for trajectory in trajectories])
        masks = torch.tensor([trajectory[4] for trajectory in trajectories])

        with torch.no_grad():
            _, next_value = agent(states[-1])
            values = torch.cat([agent(states)[1], next_value.unsqueeze(0)])

        returns = compute_gae(rewards, masks, values, gamma)
        ppo_update(agent, optimizer, trajectories, clip_param)

        if episode % 10 == 0:
            print(f"Episode {episode}, Reward: {episode_reward}")

    return all_rewards

# Example usage
train(env, agent, optimizer)

Episode 0, Reward: 69.60677360535132
Episode 10, Reward: 88.79381830309961
Episode 20, Reward: 0.0
Episode 30, Reward: 88.79381830309961
Episode 40, Reward: 88.79381830309961
Episode 50, Reward: 88.79381830309961
Episode 60, Reward: 88.79381830309961
Episode 70, Reward: 88.79381830309961
Episode 80, Reward: 88.79381830309961
Episode 90, Reward: 88.79381830309961


[69.60677360535132,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 0.0,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 0.0,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 0.0,
 0.0,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 0.0,
 0.0,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 0.0,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 0.0,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 0.0,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79381830309961,
 88.79

In [12]:
def evaluate(env, agent, num_episodes=100):
    all_rewards = []
    for _ in range(num_episodes):
        state = env.reset()
        state = torch.FloatTensor(state).unsqueeze(0)
        episode_reward = 0
        done = False

        while not done:
            action, _, _ = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            next_state = torch.FloatTensor(next_state).unsqueeze(0)
            state = next_state
            episode_reward += reward

        all_rewards.append( episode_reward )
    return all_rewards

# Example usage
evaluation_rewards = evaluate(env, agent)
print(f"Average evaluation reward: {np.mean(evaluation_rewards)}")
