# In [None]:
from __future__ import annotations

import os
import time
import numpy as np
import tensorflow as tf


from ddpg_ide.system.mec_env_var import MecTermLD, MecSvrEnv

def _setup_tf():
    """Turn on memory growth for every GPU TensorFlow can see.

    Returns immediately when TensorFlow lists no GPU. A ``RuntimeError`` from
    ``set_memory_growth`` is printed and swallowed. Per the TensorFlow docs,
    this error is raised when the runtime is already initialised. Swallowing
    it lets the test run continue on whatever configuration is already active.
    """
    devices = tf.config.list_physical_devices('GPU')
    if not devices:
        return

    try:
        for device in devices:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as err:
        print(f"GPU Setup Error: {err}")
    else:
        print(f"✓ Found {len(devices)} GPU(s). Memory growth enabled.")

# Keys of the per-user quantities accumulated from each ``env.step`` call.
# Order: reward, total power, wasted power, transmitted/offloaded bits,
# processed/local bits, arrived bits, data buffer size, channel gain, overflows.
_STEP_METRICS = ('r', 'p', 'op', 'ts', 'ps', 'rs', 'ds', 'ch', 'of')


def _load_users(user_config_list, config, train_config, model_prefix, k):
    """Build one ``MecTermLD`` per user config from that user's checkpoint dir.

    The checkpoint directory of user ``id`` is ``f"{model_prefix}_user_{id}"``.
    Returns the list of loaded users, or ``None`` as soon as a directory is
    missing or a user fails to load. In that case the caller skips trial ``k``.
    """
    user_list = []
    for template in user_config_list:
        # Copy so the shared per-user template is never mutated across trials.
        info = {**template, **config}
        info['ckpt_dir'] = f"{model_prefix}_user_{info['id']}"

        if not os.path.exists(info['ckpt_dir']):
            print(f"WARNING: Checkpoint directory not found: {info['ckpt_dir']}")
            print(f"Skipping trial {k}")
            return None

        try:
            user = MecTermLD(info, train_config)
        except Exception as e:  # deliberate: a broken checkpoint only skips this trial
            print(f"Error loading User {info['id']}: {e}")
            return None
        user_list.append(user)
        print(f"Loaded User {info['id']} (DDPG+IDE) from {info['ckpt_dir']}")
    return user_list


def _run_test_episode(env, n_users, max_episode_len):
    """Run one non-random test episode and return per-user step averages.

    Calls ``env.step(is_random=False)`` until the env reports ``done`` or
    ``max_episode_len`` steps have been executed.

    Returns:
        ``(avg, init_ds, final_ds, steps, done)``:

        - ``avg`` maps each key in ``_STEP_METRICS`` to the per-user sum of
          that ``env.step`` output divided by ``steps``.
        - ``init_ds`` is whatever ``env.reset`` returned.
        - ``final_ds`` is a copy of the last ``cur_ds`` (``None`` if no step ran).
        - ``steps`` is the number of steps actually executed.
        - ``done`` tells whether the env signalled the end of the episode.
    """
    init_ds = env.reset(is_train=False)

    sums = {key: np.zeros(n_users, dtype=np.float32) for key in _STEP_METRICS}
    last_ds = None
    steps = 0
    done = False
    while steps < max_episode_len and not done:
        (cur_r, done, cur_p, cur_op, _noise, cur_ts, cur_ps, cur_rs,
         cur_ds, cur_ch, cur_of) = env.step(is_random=False)
        steps += 1
        values = (cur_r, cur_p, cur_op, cur_ts, cur_ps, cur_rs, cur_ds, cur_ch, cur_of)
        for key, value in zip(_STEP_METRICS, values):
            sums[key] += value
        last_ds = cur_ds

    # Average over the steps actually executed, so an episode the env ends
    # early is not biased toward zero. max(..., 1) guards max_episode_len == 0.
    divisor = max(steps, 1)
    avg = {key: total / divisor for key, total in sums.items()}
    # Copy in case the env hands back (and later mutates) an internal buffer.
    final_ds = None if last_ds is None else np.array(last_ds)
    return avg, init_ds, final_ds, steps, bool(done)


def main():
    """Evaluate the saved DDPG+IDE checkpoints of every trial and save metrics.

    For each trial ``k`` this does three things:

    1. Loads one ``MecTermLD`` per user from
       ``models_<trial_name>/trial_<k>_user_<id>``. The trial is skipped if
       any user cannot be loaded.
    2. Wraps the users in a ``MecSvrEnv`` and runs ``MAX_EPISODE`` test
       episodes with ``is_random=False``.
    3. Prints per-episode per-user averages and saves every collected array
       to ``test_results_<trial_name>/test_trial_<k>_<timestamp>.npz``.
    """
    _setup_tf()

    # Experiment constants: identical for every trial, so defined once.
    NUM_TRIALS = 7
    MAX_EPISODE = 100
    MAX_EPISODE_LEN = 10000
    NUM_R = 4
    SIGMA2 = 1e-9
    t_factor = 0.5

    config = {'state_dim': 3, 'action_dim': 2}

    trial_name = "ddpg_ide_deterministic"
    res_path = f'test_results_{trial_name}/'
    model_root = f'models_{trial_name}/'
    os.makedirs(res_path, exist_ok=True)

    user_config_list = [
        {'id': '1', 'model': 'AR', 'num_r': NUM_R, 'rate': 1.0, 'dis': 100,
         'action_bound': 2, 'data_buf_size': 100, 't_factor': t_factor, 'penalty': 1000},

        {'id': '2', 'model': 'AR', 'num_r': NUM_R, 'rate': 2.0, 'dis': 100,
         'action_bound': 2, 'data_buf_size': 100, 't_factor': t_factor, 'penalty': 1000},

        {'id': '3', 'model': 'AR', 'num_r': NUM_R, 'rate': 3.0, 'dis': 100,
         'action_bound': 2, 'data_buf_size': 100, 't_factor': t_factor, 'penalty': 1000},
    ]

    for k in range(NUM_TRIALS):
        print(f"\n{'='*70}")
        print(f'TESTING DDPG+IDE TRIAL: {k}')
        print(f"{'='*70}\n")

        train_config = {
            'sigma2': SIGMA2,
            # NOTE(review): seed is derived from the wall clock, so repeated
            # runs are not reproducible despite the "deterministic" trial name.
            'random_seed': int(time.time() * 1000) % 1000,
            'agent_type': 'ddpg',
        }
        model_prefix = os.path.join(model_root, f"trial_{k}")

        print("Testing DDPG + IDE Agent")
        print(f"Model path: {model_root}")
        print(f"User Config: {user_config_list}\n")

        user_list = _load_users(user_config_list, config, train_config, model_prefix, k)
        if user_list is None:
            print(f"Skipping trial {k} due to missing checkpoints\n")
            continue

        env = MecSvrEnv(user_list, NUM_R, SIGMA2, MAX_EPISODE_LEN)
        print("Environment initialized\n")

        # One entry per episode; each entry is a per-user array of averages.
        res_r, res_p, res_b, res_o, res_d = [], [], [], [], []
        res_op = []  # Wasted power
        res_ts = []  # Transmitted/Offloaded bits
        res_ps = []  # Processed/Local bits

        print(f"{'='*70}")
        print(f"Starting Test for Trial {k} ({MAX_EPISODE} episodes)")
        print(f"{'='*70}\n")

        start_time = time.time()

        for i in range(MAX_EPISODE):
            avg, init_ds, final_ds, steps, done = _run_test_episode(
                env, len(user_list), MAX_EPISODE_LEN)
            if not done:
                # Still record the episode so the result arrays stay aligned
                # with the episode index instead of silently dropping it.
                print(f"WARNING: episode {i} did not signal done after {steps} steps")

            res_r.append(avg['r'])
            res_p.append(avg['p'])
            res_op.append(avg['op'])
            res_b.append(avg['ds'])
            res_o.append(avg['of'])
            res_ts.append(avg['ts'])
            res_ps.append(avg['ps'])
            res_d.append(final_ds)

            print('%d:r:%s,p:%s,op:%s,tr:%s,pr:%s,rev:%s,dbuf:%s,ch:%s,ibuf:%s,rbuf:%s' %
                  (i,
                   avg['r'],
                   avg['p'],
                   avg['op'],
                   avg['ts'],
                   avg['ps'],
                   avg['rs'],
                   avg['ds'],
                   avg['ch'],
                   init_ds,
                   final_ds))

        elapsed_time = time.time() - start_time

        print(f"\n{'='*70}")
        print(f"TRIAL {k} TEST COMPLETED in {elapsed_time/60:.2f} minutes")
        print(f"{'='*70}")

        timestamp = time.strftime("%Y%m%d_%H%M%S")
        name = os.path.join(res_path, f'test_trial_{k}_{timestamp}')

        np.savez(name,
                 rewards=np.array(res_r),
                 powers=np.array(res_p),
                 wasted_powers=np.array(res_op),
                 buffers=np.array(res_b),
                 overflows=np.array(res_o),
                 final_buffers=np.array(res_d),
                 transmitted_bits=np.array(res_ts),
                 processed_bits=np.array(res_ps))

        print("\nFinal Statistics:")
        if res_r:
            # Mean/std across episodes of each episode's mean over users.
            ep_r = [r.mean() for r in res_r]
            ep_p = [p.mean() for p in res_p]
            ep_b = [b.mean() for b in res_b]
            ep_op = [op.mean() for op in res_op]
            print(f"  Avg Reward: {np.mean(ep_r):.4f} ± {np.std(ep_r):.4f}")
            print(f"  Avg Power: {np.mean(ep_p):.4f} ± {np.std(ep_p):.4f}")
            print(f"  Avg Buffer: {np.mean(ep_b):.4f} ± {np.std(ep_b):.4f}")
            print(f"  Avg Wasted Power: {np.mean(ep_op):.4f}")
        else:
            print("  No episodes were recorded.")
        print(f"\nResults saved to: {name}.npz")
        print(f"Tested model from: {model_prefix}\n")

    print(f"\n{'='*70}")
    print("ALL TRIALS TESTING COMPLETED!")
    print(f"{'='*70}\n")


# In [None]:
# Guarded so importing this module does not start the long test run.
# In a notebook cell, __name__ is "__main__", so the cell still runs main().
if __name__ == "__main__":
    main()
# Observed runtime: about 390 minutes for the 7 trial models.