In [None]:
# final_caching.py

import gymnasium as gym

from gymnasium import spaces
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque, OrderedDict, defaultdict
import matplotlib.pyplot as plt
import sys
import os




# Reproducible seeds: every RNG used by this script starts from the same value.
GLOBAL_SEED = 42

# Python's `random`, NumPy and PyTorch each keep their own generator.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(GLOBAL_SEED)

# Seed all CUDA devices as well when a GPU is present.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(GLOBAL_SEED)




class CacheEnv(gym.Env):
    """Cache-replacement environment driven by a synthetic address trace.

    Each step serves one address from ``access_sequence``. A hit is rewarded
    with +1 and a miss with -1. On a miss with a full cache, the action picks
    which cache slot (by position in the cache's current ordering) is evicted
    to make room for the new address.

    Observation layout (``cache_size * 3 + 4`` float32 values):
    cached addresses (padded with -1), per-slot access frequency and recency
    (both clipped to 1.0), the address about to be served, and the three
    addresses that follow it (-1 past the end of the trace).
    """

    # Patterns for which reproducible evaluation traces are pre-generated.
    _FIXED_PATTERNS = ('easy', 'medium', 'hard', 'mixed')
    # Number of pre-generated traces per pattern (seeds 0..N-1).
    _NUM_FIXED_SEQUENCES = 50
    # Length of one working-set phase in the 'hard' pattern.
    _HARD_PHASE_LENGTH = 25

    def __init__(self, cache_size=8, memory_size=64, sequence_length=100):
        """Create the environment.

        Args:
            cache_size: Number of cache slots; also the number of actions.
            memory_size: Size of the address space (addresses are 0..memory_size-1).
            sequence_length: Number of accesses in one episode.
        """
        super().__init__()

        self.cache_size = cache_size
        self.memory_size = memory_size
        self.sequence_length = sequence_length

        self.cache = OrderedDict()
        self.access_sequence = []
        self.fixed_sequences = {}
        self.current_step = 0

        self.action_space = spaces.Discrete(cache_size)

        # state: cache contents (cache_size), freq (cache_size), recency (cache_size), next_addr, 3 future addrs
        self.observation_space = spaces.Box(low=-1.0, high=float(memory_size),
                                            shape=(cache_size * 3 + 4,), dtype=np.float32)

        self.hits = 0
        self.misses = 0
        self.total_accesses = 0

        self.access_freq = defaultdict(int)
        self.last_access_time = {}

        self._generate_fixed_sequences()

    def _generate_fixed_sequences(self):
        """Pre-generate the seeded traces used for reproducible evaluation."""
        for pattern in self._FIXED_PATTERNS:
            self.fixed_sequences[pattern] = [
                self._generate_access_pattern(pattern, seed=i)
                for i in range(self._NUM_FIXED_SEQUENCES)
            ]

    def _generate_access_pattern(self, pattern_type='mixed', seed=None):
        """Return a list of ``sequence_length`` addresses for ``pattern_type``.

        Args:
            pattern_type: One of 'easy', 'sequential', 'loop', 'random',
                'hard', 'medium'; any other value produces the 'mixed' pattern.
            seed: When given, the trace is generated from this seed and the
                global ``random`` state is restored afterwards, so seeded
                generation does not disturb the caller's random stream.
        """
        if seed is None:
            return self._build_pattern(pattern_type)

        saved_state = random.getstate()
        random.seed(seed)
        try:
            return self._build_pattern(pattern_type)
        finally:
            # Restore even if generation raises, so a failure cannot leave the
            # global RNG stuck on the temporary seed.
            random.setstate(saved_state)

    def _build_pattern(self, pattern_type):
        """Generate one trace using the current state of the ``random`` module."""
        length = self.sequence_length
        mem = self.memory_size

        if pattern_type == 'easy':
            # Loop size must not exceed cache capacity, nor the address space.
            loop_size = max(1, min(self.cache_size, mem))
            return [i % loop_size for i in range(length)]

        if pattern_type == 'sequential':
            return [i % mem for i in range(length)]

        if pattern_type == 'loop':
            loop_size = min(10, mem)
            return [i % loop_size for i in range(length)]

        if pattern_type == 'random':
            return [random.randint(0, mem - 1) for _ in range(length)]

        if pattern_type == 'hard':
            phase_len = self._HARD_PHASE_LENGTH
            # Ceiling division: a trailing partial phase is still generated so
            # the trace always reaches ``sequence_length`` before truncation.
            phases = max(1, -(-length // phase_len))
            seq = []
            for _ in range(phases):
                phase_set = random.sample(range(mem), min(15, mem))
                seq.extend(random.choice(phase_set) for _ in range(phase_len))
            return seq[:length]

        if pattern_type == 'medium':
            return self._tiered_pattern(8, 20, 0.5, 0.8)

        # 'mixed' (also the fallback for unrecognised pattern names)
        return self._tiered_pattern(6, 12, 0.6, 0.85)

    def _tiered_pattern(self, hot_end, warm_end, hot_threshold, warm_threshold):
        """Draw addresses from hot / warm / cold address ranges.

        Addresses ``[0, hot_end)`` are hot, ``[hot_end, warm_end)`` warm and
        the rest cold (bounds are capped at ``memory_size``). For each access
        a uniform draw ``r`` selects hot when ``r < hot_threshold``, warm when
        ``r < warm_threshold``, and cold otherwise; an empty tier falls
        through to the next one.
        """
        mem = self.memory_size
        hot_end = min(hot_end, mem)
        warm_end = min(warm_end, mem)

        hot_set = list(range(hot_end))
        warm_set = list(range(hot_end, warm_end))
        cold_set = list(range(warm_end, mem))

        seq = []
        for _ in range(self.sequence_length):
            r = random.random()
            if r < hot_threshold and hot_set:
                seq.append(random.choice(hot_set))
            elif r < warm_threshold and warm_set:
                seq.append(random.choice(warm_set))
            elif cold_set:
                seq.append(random.choice(cold_set))
            else:
                seq.append(random.randint(0, mem - 1))
        return seq

    def _get_state(self):
        """Build the observation vector for the current step."""
        slots = list(self.cache.keys())[:self.cache_size]
        slots += [-1] * (self.cache_size - len(slots))

        frequencies = []
        recencies = []
        for addr in slots:
            if addr == -1:
                # Empty slot: no statistics.
                frequencies.append(0.0)
                recencies.append(0.0)
                continue
            freq = self.access_freq.get(addr, 0) / 10.0
            age = (self.current_step - self.last_access_time.get(addr, self.current_step)) / 50.0
            frequencies.append(min(freq, 1.0))
            recencies.append(min(age, 1.0))

        # Address about to be served plus the three after it; -1 past the end.
        trace = self.access_sequence
        lookahead = [
            trace[idx] if idx < len(trace) else -1
            for idx in range(self.current_step, self.current_step + 4)
        ]

        return np.array(slots + frequencies + recencies + lookahead, dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Optional seed. It is forwarded to ``gym.Env.reset`` and,
                when no ``sequence_idx`` is requested, also used to generate
                this episode's trace reproducibly.
            options: Optional dict. ``'pattern'`` selects the access pattern
                (default ``'mixed'``). ``'sequence_idx'`` selects one of the
                pre-generated traces for that pattern (wrapping around); for
                patterns without pre-generated traces the index is used as
                the generation seed instead.
        """
        super().reset(seed=seed)

        options = options or {}
        pattern = options.get('pattern', 'mixed')

        if 'sequence_idx' in options:
            seq_idx = int(options['sequence_idx'])
            seq_list = self.fixed_sequences.get(pattern)
            if seq_list:
                self.access_sequence = seq_list[seq_idx % len(seq_list)]
            else:
                # No pre-generated traces for this pattern: generate a
                # deterministic one rather than dividing by zero.
                self.access_sequence = self._generate_access_pattern(pattern, seed=seq_idx)
        else:
            self.access_sequence = self._generate_access_pattern(pattern, seed=seed)

        # clear internal state
        self.cache = OrderedDict()
        self.current_step = 0

        self.hits = 0
        self.misses = 0
        self.total_accesses = 0

        self.access_freq = defaultdict(int)
        self.last_access_time = {}

        return self._get_state(), {}

    def step(self, action):
        """Serve the next address and return ``(obs, reward, done, truncated, info)``.

        Args:
            action: Index of the cache slot to evict if this access misses
                while the cache is full. An out-of-range index falls back to
                evicting the first entry in the cache ordering.
        """
        if self.current_step >= len(self.access_sequence):
            # Trace already exhausted: report a terminal step with no reward.
            return self._get_state(), 0.0, True, False, {}

        addr = self.access_sequence[self.current_step]
        self.access_freq[addr] += 1
        self.last_access_time[addr] = self.current_step

        if addr in self.cache:
            self.hits += 1
            reward = 1.0
            # A hit moves the entry to the end of the cache ordering.
            self.cache.move_to_end(addr)
        else:
            self.misses += 1
            reward = -1.0

            if len(self.cache) >= self.cache_size:
                cache_list = list(self.cache.keys())
                if 0 <= action < len(cache_list):
                    del self.cache[cache_list[action]]
                else:
                    self.cache.popitem(last=False)

            self.cache[addr] = 1

        self.total_accesses += 1
        self.current_step += 1

        done = self.current_step >= len(self.access_sequence)

        return self._get_state(), float(reward), done, False, {}





# Baseline policies

class LRUPolicy:
    """Least-recently-used replacement baseline backed by an ``OrderedDict``."""

    def __init__(self, cache_size):
        self.cache = OrderedDict()
        self.cache_size = cache_size

    def access(self, addr):
        """Access ``addr``; return True on a hit and False on a miss.

        A hit moves the entry to the end of the ordering. A miss inserts the
        address, first dropping the entry at the front if the cache is full.
        """
        entries = self.cache

        if addr not in entries:
            if len(entries) >= self.cache_size:
                entries.popitem(last=False)
            entries[addr] = 1
            return False

        entries.move_to_end(addr)
        return True

    def reset(self):
        """Discard all cached entries."""
        self.cache = OrderedDict()




class FIFOPolicy:
    """First-in-first-out replacement baseline; hits do not change the order."""

    def __init__(self, cache_size):
        self.cache = OrderedDict()
        self.cache_size = cache_size

    def access(self, addr):
        """Access ``addr``; return True on a hit and False on a miss.

        On a miss the address is appended, after dropping the oldest inserted
        entry when the cache is already full.
        """
        entries = self.cache
        was_cached = addr in entries

        if not was_cached:
            if len(entries) >= self.cache_size:
                entries.popitem(last=False)
            entries[addr] = 1

        return was_cached

    def reset(self):
        """Discard all cached entries."""
        self.cache = OrderedDict()




class LFUPolicy:
    """Least-frequently-used replacement baseline.

    ``cache`` holds the resident addresses and ``freq`` their hit counts
    since insertion (a newly inserted address starts at 1).
    """

    def __init__(self, cache_size):
        self.cache = {}
        self.freq = {}
        self.cache_size = cache_size

    def access(self, addr):
        """Access ``addr``; return True on a hit and False on a miss.

        On a miss with a full cache, the resident address with the smallest
        count is evicted; ties go to the one that appears first in ``freq``.
        """
        if addr not in self.cache:
            if len(self.cache) >= self.cache_size:
                victim, _ = min(self.freq.items(), key=lambda item: item[1])
                del self.cache[victim]
                del self.freq[victim]
            self.cache[addr] = 1
            self.freq[addr] = 1
            return False

        self.freq[addr] += 1
        return True

    def reset(self):
        """Discard all cached entries and their counts."""
        self.cache = {}
        self.freq = {}





# Networks and agents

class DQNetwork(nn.Module):
    """Three-layer fully connected Q-network with ReLU hidden activations."""

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        """Build the layers: input -> hidden -> hidden -> output."""
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return one value per action for each input row (no output activation)."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)




class ReplayMemory:
    """Fixed-capacity buffer of transitions; the oldest entries are dropped first."""

    def __init__(self, capacity):
        # A bounded deque discards from the left once `capacity` is reached.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Store one transition as a tuple of the given positional values."""
        # `args` is already a tuple, so it can be stored directly.
        self.memory.append(args)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)
    
    


class DQNAgent:
    """DQN agent with an epsilon-greedy policy, replay memory and a target network."""

    def __init__(self, state_dim, action_dim, learning_rate=1e-3, gamma=0.95,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay_steps=50000,
                 replay_memory_size=50000, batch_size=64, target_update_freq=500):
        """Create the online/target networks, the optimizer and the replay memory.

        Args:
            state_dim: Length of the observation vector.
            action_dim: Number of discrete actions.
            learning_rate: Adam learning rate for the online network.
            gamma: Discount factor used in the bootstrapped target.
            epsilon_start: Exploration rate before any step has been taken.
            epsilon_end: Exploration rate once the decay has finished.
            epsilon_decay_steps: Number of steps over which epsilon is
                interpolated linearly from start to end.
            replay_memory_size: Capacity of the replay memory.
            batch_size: Number of transitions per optimisation step.
            target_update_freq: Stored for the training loop, which uses it
                to decide when to call ``update_target_network``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.action_dim = action_dim
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq

        # Exploration schedule.
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay_steps = epsilon_decay_steps
        self.epsilon = epsilon_start

        # Online network first, then the target network initialised as its copy.
        self.q_network = DQNetwork(state_dim, action_dim).to(self.device)
        self.target_network = DQNetwork(state_dim, action_dim).to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.target_network.eval()

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)
        self.replay_memory = ReplayMemory(replay_memory_size)

        self.steps = 0

    def select_action(self, state, evaluation=False):
        """Pick an action for ``state``.

        With ``evaluation=True`` the greedy action is always returned;
        otherwise a uniformly random action is taken with probability epsilon.
        """
        explore = not evaluation and random.random() <= self.epsilon
        if explore:
            return random.randrange(self.action_dim)

        with torch.no_grad():
            state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.q_network(state_t)
            return int(q_values.argmax().item())

    def store_transition(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory (``done`` stored as a float)."""
        self.replay_memory.push(state, action, reward, next_state, float(done))

    def update_epsilon(self):
        """Advance the step counter and recompute the linearly decayed epsilon."""
        self.steps += 1
        fraction = min(1.0, self.steps / max(1, self.epsilon_decay_steps))
        self.epsilon = self.epsilon_start + fraction * (self.epsilon_end - self.epsilon_start)

    def optimize_model(self):
        """Run one gradient step on a sampled batch.

        Returns:
            The loss as a float, or None when the replay memory holds fewer
            than ``batch_size`` transitions.
        """
        if len(self.replay_memory) < self.batch_size:
            return None

        batch = self.replay_memory.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        state_b = torch.FloatTensor(np.array(states)).to(self.device)
        action_b = torch.LongTensor(actions).unsqueeze(1).to(self.device)
        reward_b = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
        next_b = torch.FloatTensor(np.array(next_states)).to(self.device)
        done_b = torch.FloatTensor(dones).unsqueeze(1).to(self.device)

        # Q-value of the action that was actually taken.
        predicted_q = self.q_network(state_b).gather(1, action_b)

        with torch.no_grad():
            # Target network supplies the maximum next-state value.
            best_next_q = self.target_network(next_b).max(dim=1, keepdim=True).values
            target_q = reward_b + (1 - done_b) * self.gamma * best_next_q

        criterion = nn.MSELoss()
        loss = criterion(predicted_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return float(loss.item())

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())





class DoubleDQNAgent(DQNAgent):
    """DQN variant whose target selects the next action with the online network."""

    def optimize_model(self):
        """Run one Double-DQN gradient step on a sampled batch.

        The online network chooses the next action and the target network
        evaluates it.

        Returns:
            The loss as a float, or None when the replay memory holds fewer
            than ``batch_size`` transitions.
        """
        if len(self.replay_memory) < self.batch_size:
            return None

        batch = self.replay_memory.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        state_b = torch.FloatTensor(np.array(states)).to(self.device)
        action_b = torch.LongTensor(actions).unsqueeze(1).to(self.device)
        reward_b = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
        next_b = torch.FloatTensor(np.array(next_states)).to(self.device)
        done_b = torch.FloatTensor(dones).unsqueeze(1).to(self.device)

        # Q-value of the action that was actually taken.
        predicted_q = self.q_network(state_b).gather(1, action_b)

        with torch.no_grad():
            # Online network picks the action, target network scores it.
            greedy_next = self.q_network(next_b).argmax(dim=1, keepdim=True)
            next_q = self.target_network(next_b).gather(1, greedy_next)
            target_q = reward_b + (1 - done_b) * self.gamma * next_q

        criterion = nn.MSELoss()
        loss = criterion(predicted_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return float(loss.item())
    
    


# Training and evaluation

def train_agent(env, agent, num_episodes=700, max_steps=100, print_every=50):
    """Train ``agent`` on ``env`` and return per-episode statistics.

    Every environment step stores the transition, runs one optimisation
    step, advances the epsilon schedule, and refreshes the target network
    whenever ``agent.steps`` is a multiple of ``agent.target_update_freq``.
    A progress line is printed every ``print_every`` episodes and after the
    final episode.

    Args:
        env: Environment exposing ``reset``, ``step``, ``hits`` and
            ``total_accesses``.
        agent: Agent exposing the methods and attributes used below.
        num_episodes: Number of episodes to run.
        max_steps: Upper bound on steps per episode.
        print_every: Reporting interval and averaging window, in episodes.

    Returns:
        ``(episode_rewards, epsilon_history)``, one entry per episode.
    """
    episode_rewards, epsilon_history, hit_rates = [], [], []

    for episode in range(num_episodes):
        state, _ = env.reset()
        episode_reward = 0.0
        steps_taken = 0

        while steps_taken < max_steps:
            steps_taken += 1

            action = agent.select_action(state)
            next_state, reward, done, truncated, _ = env.step(action)

            agent.store_transition(state, action, reward, next_state, done)
            agent.optimize_model()
            agent.update_epsilon()

            # Sync the target network on the agent's own step counter.
            if agent.steps % agent.target_update_freq == 0:
                agent.update_target_network()

            state = next_state
            episode_reward += reward

            if done or truncated:
                break

        episode_rewards.append(episode_reward)
        epsilon_history.append(agent.epsilon)
        hit_rates.append(env.hits / max(1, env.total_accesses) * 100)

        is_report_episode = (episode + 1) % print_every == 0
        is_last_episode = episode == num_episodes - 1
        if is_report_episode or is_last_episode:
            mean_reward = np.mean(episode_rewards[-print_every:])
            mean_hit = np.mean(hit_rates[-print_every:])
            print(f"Ep {episode+1}/{num_episodes} | Reward (last {print_every}): {mean_reward:.2f} | Hit: {mean_hit:.2f}% | ε: {agent.epsilon:.3f}", flush=True)

    return episode_rewards, epsilon_history



def evaluate_agent(env, agent, pattern='mixed', num_episodes=20):
    """Evaluate ``agent`` greedily on the environment's fixed traces.

    Episode ``i`` resets the environment with ``sequence_idx=i`` for the
    given ``pattern`` and runs until the environment reports done or
    truncated, or ``env.sequence_length`` steps have been taken.

    Returns:
        ``(eval_rewards, hit_rates)``: total reward and hit rate (percent)
        for each episode.
    """
    eval_rewards, hit_rates = [], []

    for seq_idx in range(num_episodes):
        state, _ = env.reset(options={'pattern': pattern, 'sequence_idx': seq_idx})

        episode_reward = 0.0
        steps = 0

        while steps < env.sequence_length:
            action = agent.select_action(state, evaluation=True)
            state, reward, done, truncated, _ = env.step(action)

            episode_reward += reward
            steps += 1

            if done or truncated:
                break

        hit_rates.append(env.hits / max(1, env.total_accesses) * 100)
        eval_rewards.append(episode_reward)

    return eval_rewards, hit_rates






def evaluate_baseline(env, policy, pattern='mixed', num_episodes=20):
    """Replay the environment's fixed traces through a baseline policy.

    For episode ``i`` the environment is reset with ``sequence_idx=i`` only
    to select the trace; the policy is reset and then fed every address of
    ``env.access_sequence`` in order. A hit scores +1 and a miss -1.

    Returns:
        ``(rewards, hit_rates)``: total reward and hit rate (percent) for
        each episode.
    """
    rewards, hit_rates = [], []

    for seq_idx in range(num_episodes):
        # Fresh trace and a fresh policy, so nothing carries over between episodes.
        env.reset(options={'pattern': pattern, 'sequence_idx': seq_idx})
        policy.reset()

        outcomes = [bool(policy.access(addr)) for addr in env.access_sequence]
        total = len(outcomes)
        hits = sum(outcomes)
        misses = total - hits

        rewards.append(float(hits - misses))
        hit_rates.append(hits / max(1, total) * 100)

    return rewards, hit_rates





def plot_results(dqn_rewards, ddqn_rewards, dqn_epsilon, ddqn_epsilon, results, out='cache_final_results.png'):
    """Render the training curves and per-pattern hit rates to an image file.

    The figure is a 2x3 grid: smoothed training rewards, the epsilon
    schedules, and one bar chart per access pattern comparing the five
    methods.

    Args:
        dqn_rewards, ddqn_rewards: Per-episode training rewards.
        dqn_epsilon, ddqn_epsilon: Per-episode epsilon values.
        results: ``results[pattern][method]`` gives ``(mean, std)`` of the
            hit rate in percent.
        out: Path of the image file to write.
    """
    fig = plt.figure(figsize=(18, 9))

    # Panel 1: moving-average rewards (drawn only for runs of at least one window).
    window = 50
    reward_ax = plt.subplot(2, 3, 1)
    for series, label in ((dqn_rewards, 'DQN'), (ddqn_rewards, 'Double DQN')):
        if len(series) >= window:
            smoothed = np.convolve(series, np.ones(window) / window, mode='valid')
            reward_ax.plot(smoothed, label=label)
    reward_ax.set_xlabel('Episode')
    reward_ax.set_ylabel('Reward (moving avg)')
    reward_ax.set_title('Training Rewards')
    reward_ax.legend()
    reward_ax.grid(True, alpha=0.3)

    # Panel 2: exploration schedules.
    eps_ax = plt.subplot(2, 3, 2)
    for series, label in ((dqn_epsilon, 'DQN'), (ddqn_epsilon, 'Double DQN')):
        eps_ax.plot(series, label=label)
    eps_ax.set_xlabel('Episode')
    eps_ax.set_ylabel('Epsilon')
    eps_ax.set_title('Exploration Rate')
    eps_ax.legend()
    eps_ax.grid(True, alpha=0.3)

    # Panels 3-6: hit-rate comparison, one panel per access pattern.
    methods = ['LRU', 'FIFO', 'LFU', 'DQN', 'DDQN']
    for panel, pattern in enumerate(['easy', 'medium', 'hard', 'mixed'], start=3):
        ax = plt.subplot(2, 3, panel)

        per_method = results[pattern]
        means = [per_method[m][0] for m in methods]
        stds = [per_method[m][1] for m in methods]

        bars = ax.bar(methods, means, yerr=stds, capsize=5, alpha=0.8)
        ax.set_ylabel('Hit Rate (%)')
        ax.set_title(f'{pattern.capitalize()} Pattern')
        ax.set_ylim(0, 100)
        ax.grid(True, axis='y', alpha=0.3)

        # Write each mean just above its bar.
        for bar, mean in zip(bars, means):
            x_center = bar.get_x() + bar.get_width() / 2.
            ax.text(x_center, bar.get_height() + 1.0, f'{mean:.1f}%',
                    ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.savefig(out, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"✓ Results saved to '{out}'")




def train_cache_comparison(num_episodes=700):
    """Train DQN and Double DQN, compare them with LRU/FIFO/LFU, and plot.

    Both agents are trained one after the other on the same environment
    with identical hyperparameters, then every method is evaluated on each
    access pattern. A summary is printed and a results figure is saved via
    ``plot_results``.

    Args:
        num_episodes: Training episodes per agent.

    Returns:
        ``results[pattern][method] = (mean_hit_rate, std_hit_rate)`` in percent.
    """
    env = CacheEnv(cache_size=8, memory_size=64, sequence_length=100)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    print("CACHE REPLACEMENT RL - TRAINING")
    print(f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")
    print(f"Cache: {env.cache_size} | Memory: {env.memory_size} | Seq: {env.sequence_length}")
    print("=" * 60, flush=True)

    # Hyperparameters shared by both agents.
    agent_kwargs = dict(learning_rate=1e-3, gamma=0.95,
                        epsilon_start=1.0, epsilon_end=0.01, epsilon_decay_steps=50000,
                        replay_memory_size=30000, batch_size=64, target_update_freq=300)
    train_kwargs = dict(max_steps=env.sequence_length, print_every=50)

    # The DQN agent is built and fully trained before the Double DQN agent is created.
    dqn_agent = DQNAgent(state_dim, action_dim, **agent_kwargs)
    dqn_rewards, dqn_epsilon = train_agent(env, dqn_agent, num_episodes, **train_kwargs)

    ddqn_agent = DoubleDQNAgent(state_dim, action_dim, **agent_kwargs)
    ddqn_rewards, ddqn_epsilon = train_agent(env, ddqn_agent, num_episodes, **train_kwargs)

    results = {}

    for pattern in ['easy', 'medium', 'hard', 'mixed']:
        # Per-episode hit rates for every method, baselines first.
        hit_rates = {
            'LRU': evaluate_baseline(env, LRUPolicy(env.cache_size), pattern)[1],
            'FIFO': evaluate_baseline(env, FIFOPolicy(env.cache_size), pattern)[1],
            'LFU': evaluate_baseline(env, LFUPolicy(env.cache_size), pattern)[1],
            'DQN': evaluate_agent(env, dqn_agent, pattern)[1],
            'DDQN': evaluate_agent(env, ddqn_agent, pattern)[1],
        }
        summary = {method: (np.mean(rates), np.std(rates)) for method, rates in hit_rates.items()}

        lru_mean, lru_std = summary['LRU']
        fifo_mean, fifo_std = summary['FIFO']
        lfu_mean, lfu_std = summary['LFU']
        dqn_mean, dqn_std = summary['DQN']
        ddqn_mean, ddqn_std = summary['DDQN']

        best_baseline = max(lru_mean, fifo_mean, lfu_mean)

        print(f"\n{pattern.upper()} Pattern:")
        print(f" LRU:  {lru_mean:.2f}% ± {lru_std:.2f}%")
        print(f" FIFO: {fifo_mean:.2f}% ± {fifo_std:.2f}%")
        print(f" LFU:  {lfu_mean:.2f}% ± {lfu_std:.2f}%")
        print(f" DQN:  {dqn_mean:.2f}% ± {dqn_std:.2f}%  {'✓' if dqn_mean > best_baseline else '✗'} ({dqn_mean - best_baseline:+.2f}%)")
        print(f" DDQN: {ddqn_mean:.2f}% ± {ddqn_std:.2f}%  {'✓' if ddqn_mean > best_baseline else '✗'} ({ddqn_mean - best_baseline:+.2f}%)")

        results[pattern] = summary

    plot_results(dqn_rewards, ddqn_rewards, dqn_epsilon, ddqn_epsilon, results)

    return results



if __name__ == "__main__":
    # Script entry point: run the full training and comparison with 700
    # episodes per agent and keep the returned per-pattern results.
    print("Starting training...", flush=True)
    
    results = train_cache_comparison(num_episodes=700)
    





##### I also attempted to build models using the PPO and A2C RL algorithms, but because of the complexity of the
##### caching environment and of the memory-caching task, it proved very challenging to implement or integrate an
##### A2C or PPO agent that could learn and converge properly to the desired result.

"""
Attempts at developing PPO and A2C that were not converging properly for given complex tasks:


class ActorCritic(nn.Module):

    def __init__(self, state_dim, action_dim, hidden_dim=128):
    
    
        super().__init__()
        
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        self.actor = nn.Linear(hidden_dim, action_dim)
        self.critic = nn.Linear(hidden_dim, 1)
        
        
    def forward(self, x):
    
        x = self.shared(x)
        return self.actor(x), self.critic(x)
        
        

class A2CAgent:


    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.95):
    
    
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.gamma = gamma
        self.model = ActorCritic(state_dim, action_dim).to(self.device)
        
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        
        
        
    def select_action(self, state):
    
    
        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        
        logits, _ = self.model(state_t)
        
        probs = torch.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        
        
        action = dist.sample()
        return action.item(), dist.log_prob(action)
        
        
    def update(self, trajectories):
    
        R = 0
        returns = []
        
        for r, _, _ in reversed(trajectories):
            R = r + self.gamma * R
            returns.insert(0, R)
            
        returns = torch.FloatTensor(returns).to(self.device)
        
        
        states = torch.FloatTensor([s for _, s, _ in trajectories]).to(self.device)
        actions = torch.LongTensor([a for _, _, a in trajectories]).to(self.device)
        
        log_probs = torch.stack([lp for lp, _, _ in trajectories])
        
        logits, values = self.model(states)
        values = values.squeeze()
        
        advantage = returns - values.detach()
        policy_loss = -(log_probs * advantage).mean()
        
        value_loss = nn.MSELoss()(values, returns)
        self.optimizer.zero_grad()
        
        (policy_loss + value_loss).backward()
        self.optimizer.step()



class PPOAgent:


    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.95, eps_clip=0.2):
    
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.gamma = gamma
        self.eps_clip = eps_clip
        
        self.model = ActorCritic(state_dim, action_dim).to(self.device)
        
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        
        
        
    def select_action(self, state):
    
        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        
        logits, _ = self.model(state_t)
        probs = torch.softmax(logits, dim=-1)
        
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        
        return action.item(), dist.log_prob(action)
        
        
        
    def update(self, trajectories, epochs=4):
    
        states = torch.FloatTensor([s for _, s, _ in trajectories]).to(self.device)
        actions = torch.LongTensor([a for _, _, a in trajectories]).to(self.device)
        
        old_log_probs = torch.stack([lp for lp, _, _ in trajectories]).detach()
        
        
        returns = []
        
        R = 0
        
        for r, _, _ in reversed(trajectories):
        
            R = r + self.gamma * R
            returns.insert(0, R)
            
        returns = torch.FloatTensor(returns).to(self.device)
        
        
        for _ in range(epochs):
        
            logits, values = self.model(states)
            values = values.squeeze()
            
            
            probs = torch.softmax(logits, dim=-1)
            dist = torch.distributions.Categorical(probs)
            
            log_probs = dist.log_prob(actions)
            ratio = torch.exp(log_probs - old_log_probs)
            advantage = returns - values.detach()
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1.0 - self.eps_clip, 1.0 + self.eps_clip) * advantage

            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = nn.MSELoss()(values, returns)
            self.optimizer.zero_grad()
            (policy_loss + value_loss).backward()
            
            self.optimizer.step()
"""

"""
References

https://www.cs.utexas.edu/~lin/papers/micro19c.pdf

https://www.cs.cmu.edu/~weinaw/pdf/delayed-hits.pdf

https://gymnasium.farama.org/index.html

https://docs.pytorch.org/tutorials/intermediate/reinforcement_q_learning.html

"""

Starting training...
CACHE REPLACEMENT RL - TRAINING
Device: cpu
Cache: 8 | Memory: 64 | Seq: 100
Ep 50/700 | Reward (last 50): -10.36 | Hit: 44.82% | ε: 0.901
Ep 100/700 | Reward (last 50): -10.96 | Hit: 44.52% | ε: 0.802
Ep 150/700 | Reward (last 50): -13.60 | Hit: 43.20% | ε: 0.703
Ep 200/700 | Reward (last 50): -9.32 | Hit: 45.34% | ε: 0.604
Ep 250/700 | Reward (last 50): -10.56 | Hit: 44.72% | ε: 0.505
Ep 300/700 | Reward (last 50): -6.92 | Hit: 46.54% | ε: 0.406
Ep 350/700 | Reward (last 50): -4.00 | Hit: 48.00% | ε: 0.307
Ep 400/700 | Reward (last 50): -1.16 | Hit: 49.42% | ε: 0.208
Ep 450/700 | Reward (last 50): 3.40 | Hit: 51.70% | ε: 0.109
Ep 500/700 | Reward (last 50): 3.08 | Hit: 51.54% | ε: 0.010
Ep 550/700 | Reward (last 50): 2.84 | Hit: 51.42% | ε: 0.010
Ep 600/700 | Reward (last 50): 3.68 | Hit: 51.84% | ε: 0.010
Ep 650/700 | Reward (last 50): 8.44 | Hit: 54.22% | ε: 0.010
Ep 700/700 | Reward (last 50): 4.24 | Hit: 52.12% | ε: 0.010
Ep 50/700 | Reward (last 50): -11.08 

'\nReferences\n\nhttps://www.cs.utexas.edu/~lin/papers/micro19c.pdf\n\nhttps://www.cs.cmu.edu/~weinaw/pdf/delayed-hits.pdf\n\nhttps://gymnasium.farama.org/index.html\n\nhttps://docs.pytorch.org/tutorials/intermediate/reinforcement_q_learning.html\n\n'