In [4]:
import cogflow as cf
from cogflow import InputPath, OutputPath
import numpy as np
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import pandas as pd
import os
from datetime import datetime, timedelta
from enum import Enum
import random
!pip install pytorch_lightning



In [5]:
class Actor(nn.Module):
    """Policy network that maps a state batch to an action batch.

    The state passes through two 256-unit ReLU layers and a linear head. The
    head output is normalised with a softmax over the last dimension and then
    multiplied by ``max_action``, so every output row is non-negative and sums
    to ``max_action``.
    """

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()
        hidden_units = 256
        self.l1 = nn.Linear(state_dim, hidden_units)
        self.l2 = nn.Linear(hidden_units, hidden_units)
        self.l3 = nn.Linear(hidden_units, action_dim)
        self.max_action = max_action

    def forward(self, state):
        """Return the scaled softmax action for ``state``."""
        hidden = torch.relu(self.l1(state))
        hidden = torch.relu(self.l2(hidden))
        logits = self.l3(hidden)
        return self.max_action * F.softmax(logits, dim=-1)

class Critic(nn.Module):
    """Pair of independent Q-networks evaluated on (state, action) inputs.

    Layers ``l1``-``l3`` form the first estimator and ``l4``-``l6`` the second.
    Each estimator concatenates state and action along dimension 1, applies two
    256-unit ReLU layers and ends in a single linear output.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        joint_dim = state_dim + action_dim

        # First Q estimator
        self.l1 = nn.Linear(joint_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)

        # Second Q estimator
        self.l4 = nn.Linear(joint_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    def forward(self, state, action):
        """Return ``(q1, q2)``, the outputs of both estimators."""
        joint = torch.cat((state, action), dim=1)

        first = torch.relu(self.l1(joint))
        first = torch.relu(self.l2(first))
        first = self.l3(first)

        second = torch.relu(self.l4(joint))
        second = torch.relu(self.l5(second))
        second = self.l6(second)
        return first, second

    def Q1(self, state, action):
        """Return only the first estimator's output (skips ``l4``-``l6``)."""
        joint = torch.cat((state, action), dim=1)
        hidden = torch.relu(self.l1(joint))
        hidden = torch.relu(self.l2(hidden))
        return self.l3(hidden)

class SplitComputingAgent(cf.pyfunc.PythonModel):
    """TD3-style actor-critic agent exposed as a ``cf.pyfunc.PythonModel``.

    The agent owns an ``Actor`` and a twin ``Critic`` plus a target copy of
    each. ``train`` performs one critic update per call and, every
    ``policy_freq`` calls, one actor update followed by a soft (Polyak) update
    of both target networks.

    Parameters
    ----------
    state_dim, action_dim : int
        Input and output sizes of the actor; the critic takes their sum.
    max_action : float
        Scale applied to the actor output and bound used when clamping the
        smoothed target action.
    discount : float
        Discount factor applied to the bootstrapped target value.
    tau : float
        Interpolation weight of the soft target-network update.
    policy_noise, noise_clip : float
        Standard deviation and clamp of the noise added to the target action.
    policy_freq : int
        Number of ``train`` calls between actor/target updates.
    """

    def __init__(
            self,
            state_dim,
            action_dim,
            max_action,
            discount=0.99,
            tau=0.005,
            policy_noise=0.2,
            noise_clip=0.5,
            policy_freq=2
    ):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.actor = Actor(state_dim, action_dim, max_action).to(self.device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_dim, action_dim).to(self.device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.total_it = 0  # number of train() calls so far

    def _as_batch_tensor(self, batch):
        """Return ``batch`` as a float32 tensor on the agent's device."""
        return torch.as_tensor(batch, dtype=torch.float32, device=self.device)

    def select_action(self, state):
        """Return the actor's action for one state as a flat NumPy array.

        ``state`` is flattened to a single row before being fed to the actor.
        The forward pass runs without autograd tracking.
        """
        row = np.asarray(state, dtype=np.float32).reshape(1, -1)
        state = torch.as_tensor(row, device=self.device)
        with torch.no_grad():
            return self.actor(state).cpu().numpy().flatten()

    def train(self, replay_buffer, batch_size=256):
        """Run one training iteration on a batch drawn from ``replay_buffer``.

        ``replay_buffer.sample(batch_size)`` must return five batches in the
        order ``(state, action, next_state, reward, not_done)``; NumPy arrays
        and tensors are both accepted. ``not_done`` is a mask that multiplies
        the bootstrapped value (0 stops bootstrapping).

        NOTE(review): the ``ReplayBuffer`` class defined in this file returns
        ``(obs, action, reward, next_obs, done)`` -- a different order and a
        done flag rather than a not-done mask -- so it cannot be passed here
        without an adapter.
        """
        self.total_it += 1

        # Sample replay buffer; coerce every part to a float32 device tensor so
        # buffers that hand back NumPy arrays work as well.
        state, action, next_state, reward, not_done = (
            self._as_batch_tensor(part) for part in replay_buffer.sample(batch_size)
        )
        # The critic outputs have shape (batch, 1). A 1-D reward or mask would
        # silently broadcast the target to (batch, batch), so force columns.
        reward = reward.reshape(-1, 1)
        not_done = not_done.reshape(-1, 1)

        with torch.no_grad():
            # Target policy smoothing: clipped Gaussian noise on the target action.
            noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)

            # Clipped double-Q: bootstrap from the smaller of the two target estimates.
            target_q1, target_q2 = self.critic_target(next_state, next_action)
            target_q = torch.min(target_q1, target_q2)
            target_q = reward + not_done * self.discount * target_q

        current_q1, current_q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy update: the actor and targets move every policy_freq calls.
        if self.total_it % self.policy_freq == 0:
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def predict(self, context, model_input):
        """Return ``select_action(model_input)``; ``context`` is not used."""
        return self.select_action(model_input)

    def save(self, filename):
        """Save network and optimizer state dicts under ``filename``-derived names.

        Each state dict is passed to ``cf.pytorch.save_state_dict`` with the
        suffixes ``_critic``, ``_critic_optimizer``, ``_actor`` and
        ``_actor_optimizer``. Returns a dict of the four names used.
        """
        critic_filename = filename + "_critic"
        critic_optimizer_filename = filename + "_critic_optimizer"
        actor_filename = filename + "_actor"
        actor_optimizer_filename = filename + "_actor_optimizer"

        cf.pytorch.save_state_dict(self.critic.state_dict(), critic_filename)
        cf.pytorch.save_state_dict(self.critic_optimizer.state_dict(), critic_optimizer_filename)
        cf.pytorch.save_state_dict(self.actor.state_dict(), actor_filename)
        cf.pytorch.save_state_dict(self.actor_optimizer.state_dict(), actor_optimizer_filename)

        return {
            "criticFileName": critic_filename,
            "criticOptimizerFileName": critic_optimizer_filename,
            "actorFileName": actor_filename,
            "actorOptimizerFileName": actor_optimizer_filename
        }

    def load(self, filename):
        """Restore networks and optimizers from files named as in ``save``.

        Assumes ``cf.pytorch.save_state_dict`` leaves a file at each name that
        ``torch.load`` can read -- TODO confirm. Tensors are mapped onto
        ``self.device`` so a checkpoint written on a GPU host also loads on a
        CPU-only host. Both target networks are rebuilt as copies of the loaded
        networks.
        """
        critic_filename = filename + "_critic"
        critic_optimizer_filename = filename + "_critic_optimizer"
        actor_filename = filename + "_actor"
        actor_optimizer_filename = filename + "_actor_optimizer"

        self.critic.load_state_dict(torch.load(critic_filename, map_location=self.device))
        self.critic_optimizer.load_state_dict(
            torch.load(critic_optimizer_filename, map_location=self.device))
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(torch.load(actor_filename, map_location=self.device))
        self.actor_optimizer.load_state_dict(
            torch.load(actor_optimizer_filename, map_location=self.device))
        self.actor_target = copy.deepcopy(self.actor)
        
# UE Type Enumeration
class UEType(Enum):
    """Device categories; used as keys of ``UE.DEFAULT_RESOURCES``."""
    SMARTPHONE = 1
    TABLET = 2
    LAPTOP = 3
    IOT = 4

# Wireless Channel Classes
class WirelessLink:
    """Radio link with a fixed bandwidth and the latest channel measurements.

    All four measurements start as ``None`` and are replaced together by
    ``update_channel_conditions``.
    """

    def __init__(self, bandwidth):
        self.bandwidth = bandwidth
        self.rsrp = self.rsrq = self.sinr = self.cqi = None

    def update_channel_conditions(self, rsrp, rsrq, sinr, cqi):
        """Store a new set of channel measurements."""
        self.rsrp, self.rsrq, self.sinr, self.cqi = rsrp, rsrq, sinr, cqi

    def data_rate(self):
        """Return ``bandwidth * log2(1 + 10 ** (sinr / 10))``.

        Raises
        ------
        ValueError
            If no SINR has been set yet.
        """
        if self.sinr is None:
            raise ValueError("Channel conditions not set")
        # The stored SINR is treated as dB and converted to a linear ratio.
        linear_sinr = np.power(10, (self.sinr / 10))
        return self.bandwidth * np.log2(1 + linear_sinr)

    def latency(self, data_size):
        """Return ``data_size`` divided by the current data rate."""
        rate = self.data_rate()
        return data_size / rate

class WirelessChannel:
    """Channel trace loaded from a CSV file and indexed by step number.

    The CSV is read with pandas and its ``timestamp`` column is parsed to
    datetimes. Step ``k`` refers to the k-th distinct timestamp in ascending
    order. Lookups read the ``UE_ID``, ``RSRP``, ``RSRQ``, ``SINR`` and ``CQI``
    columns.
    """

    def __init__(self, channel_file):
        frame = pd.read_csv(channel_file)
        frame['timestamp'] = pd.to_datetime(frame['timestamp'])
        self.channel_data = frame
        self.timestamps = sorted(frame['timestamp'].unique())
        self.current_step = 0

    def get_channel_conditions(self, ue_id, step):
        """Return the measurements of ``ue_id`` at the ``step``-th timestamp.

        The result is a dict with keys ``rsrp``, ``rsrq``, ``sinr`` and ``cqi``
        taken from the first matching row.

        Raises
        ------
        ValueError
            If ``step`` is not smaller than the number of timestamps.
        IndexError
            If no row matches the UE and timestamp (from ``.iloc[0]``).
        """
        total_steps = len(self.timestamps)
        if step >= total_steps:
            raise ValueError(f"Step {step} is out of range. Max step is {total_steps - 1}")

        table = self.channel_data
        wanted = (table['UE_ID'] == ue_id) & (table['timestamp'] == self.timestamps[step])
        row = table[wanted].iloc[0]

        return {column.lower(): row[column] for column in ('RSRP', 'RSRQ', 'SINR', 'CQI')}

    def reset(self):
        """Rewind the internal step counter to the first timestamp."""
        self.current_step = 0

    def step(self):
        """Advance the internal step counter, wrapping to 0 past the last timestamp."""
        advanced = self.current_step + 1
        self.current_step = advanced if advanced < len(self.timestamps) else 0

# MEC Server Class
class MECServer:
    """Edge server with fixed CPU, memory and GPU capacities.

    The class tracks a total and an available amount per resource. Nothing in
    this class lowers the available amounts; ``reset`` restores them to the
    totals and clears ``tasks``.
    """

    def __init__(self, cpu, mem, gpu):
        self.total_cpu = cpu  # in MIPS
        self.total_mem = mem  # in GB
        self.total_gpu = gpu  # in FLOPS

        self.available_cpu = cpu
        self.available_mem = mem
        self.available_gpu = gpu

        self.tasks = []

    def can_accept_task(self, mem_req):
        """Return True when the available memory covers ``mem_req``."""
        return self.available_mem >= mem_req

    @staticmethod
    def _used_fraction(total, available):
        """Return the used share of a resource, or 0.0 when the server has none of it."""
        return (total - available) / total if total > 0 else 0.0

    def get_utilization(self):
        """Return the used fraction of each resource as a dict (``cpu``, ``mem``, ``gpu``)."""
        return {
            'cpu': self._used_fraction(self.total_cpu, self.available_cpu),
            'mem': self._used_fraction(self.total_mem, self.available_mem),
            'gpu': self._used_fraction(self.total_gpu, self.available_gpu)
        }

    def process_task(self, cpu_demand, gpu_demand, memory_demand):
        """Return the processing time of a task, or ``inf`` if it cannot be served.

        The time is the smaller of ``cpu_demand / available_cpu`` and
        ``gpu_demand / available_gpu``. A resource with nothing available
        contributes ``inf`` instead of raising ``ZeroDivisionError``, matching
        how ``UE.compute_local`` treats an exhausted resource. The task is
        rejected (``inf``) when the memory demand exceeds the available memory.
        """
        if not self.can_accept_task(memory_demand):
            return float('inf')

        cpu_time = cpu_demand / self.available_cpu if self.available_cpu > 0 else float('inf')
        gpu_time = gpu_demand / self.available_gpu if self.available_gpu > 0 else float('inf')
        return min(cpu_time, gpu_time)

    def reset(self):
        """Restore all available resources to their totals and drop recorded tasks."""
        self.available_cpu = self.total_cpu
        self.available_mem = self.total_mem
        self.available_gpu = self.total_gpu
        self.tasks = []

# DNN Task Class (simplified version)
class DNNTask:
    """Layered inference task with per-layer resource demands.

    Layer ``i`` (0-based) has demands equal to a fixed base value multiplied by
    ``i + 1``, so deeper layers are proportionally more expensive.
    """

    def __init__(self, num_layers=4):
        self.num_layers = num_layers
        self.layer_demands = self._generate_layer_demands()

    def _generate_layer_demands(self):
        """Build the list of per-layer demand dicts, one entry per layer."""
        demands = []
        for depth in range(1, self.num_layers + 1):
            demands.append({
                'local_cpu_demand': 1e9 * depth,
                'local_gpu_demand': 2e9 * depth,
                'local_memory_demand': 1e8 * depth,
                'remote_cpu_demand': 5e8 * depth,
                'remote_gpu_demand': 1e9 * depth,
                'remote_memory_demand': 5e7 * depth,
                'transmision_data_demand': 1e6 * depth
            })
        return demands

    def get_split_info(self, split_point):
        """Aggregate the demands on each side of ``split_point``.

        Layers before ``split_point`` contribute to the ``local_*`` totals; the
        remaining layers contribute to the ``remote_*`` totals and to
        ``transmision_data_demand``.

        Raises
        ------
        ValueError
            If ``split_point`` is negative or greater than ``num_layers``.
        """
        out_of_range = split_point < 0 or split_point > self.num_layers
        if out_of_range:
            raise ValueError(f"Invalid split point: {split_point}")

        device_side = self.layer_demands[:split_point]
        server_side = self.layer_demands[split_point:]

        def total(layers, key):
            return sum(layer[key] for layer in layers)

        summary = {}
        for key in ('local_cpu_demand', 'local_gpu_demand', 'local_memory_demand'):
            summary[key] = total(device_side, key)
        for key in ('remote_cpu_demand', 'remote_gpu_demand', 'remote_memory_demand',
                    'transmision_data_demand'):
            summary[key] = total(server_side, key)
        return summary

# UE Class
class UE:
    """Simulated user equipment with fluctuating load, a battery and a radio link.

    ``DEFAULT_RESOURCES`` maps each ``UEType`` to its capacities (``cpu``,
    ``gpu``, ``mem``, ``bat``), energy coefficients (``e_cpu``, ``e_gpu``,
    ``e_mem``) and power terms (``p_base``, ``p_tx``). Loads are stored in the
    same absolute units as the capacities; ``get_state`` reports them as
    fractions of capacity.
    """

    DEFAULT_RESOURCES = {
        UEType.SMARTPHONE: dict(
            cpu=2.0e9, gpu=50e9, mem=4e9, bat=3000,
            e_cpu=1e-9, e_gpu=1e-12, e_mem=1e-6,
            p_base=0.1, p_tx=0.5,
        ),
        UEType.TABLET: dict(
            cpu=2.5e9, gpu=7e9, mem=8e9, bat=7000,
            e_cpu=9e-10, e_gpu=9e-13, e_mem=9e-7,
            p_base=0.15, p_tx=0.6,
        ),
        UEType.LAPTOP: dict(
            cpu=3.5e9, gpu=2e12, mem=16e9, bat=5000,
            e_cpu=8e-10, e_gpu=8e-13, e_mem=8e-7,
            p_base=0.2, p_tx=0.7,
        ),
        UEType.IOT: dict(
            cpu=1.0e9, gpu=0, mem=1e9, bat=1000,
            e_cpu=1.2e-9, e_gpu=0, e_mem=1.2e-6,
            p_base=0.05, p_tx=0.3,
        ),
    }

    def __init__(self, ue_id, ue_type, num_layers=4):
        profile = dict(self.DEFAULT_RESOURCES[ue_type])  # private copy per device
        self.ue_id = ue_id
        self.ue_type = ue_type
        self.resources = profile
        self.cpu_load = 20
        self.gpu_load = 10 if profile['gpu'] > 0 else 0
        self.mem_load = 10
        self.bat_level = profile['bat']
        self.wireless_link = WirelessLink(10e6)
        self.task = DNNTask(num_layers=num_layers)
        self.time = 0

    @staticmethod
    def _clamp(value, upper):
        """Limit ``value`` to the range from 0 to ``upper``."""
        return max(0, min(upper, value))

    def update(self):
        """Advance the device by one tick.

        CPU, GPU (only when the device has one) and memory loads each take a
        zero-mean Gaussian step and are clamped to ``[0, capacity]``. The
        battery is then drained by the load-dependent energy term divided by
        3600, and refilled to capacity with probability 0.1 or whenever it
        falls below 10% of capacity.
        """
        self.time += 1
        spec = self.resources

        cpu_jitter = random.normalvariate(0, 0.05 * spec['cpu'])
        self.cpu_load = self._clamp(self.cpu_load + cpu_jitter, spec['cpu'])

        if spec['gpu'] > 0:
            gpu_jitter = random.normalvariate(0, 0.07 * spec['gpu'])
            self.gpu_load = self._clamp(self.gpu_load + gpu_jitter, spec['gpu'])

        mem_jitter = random.normalvariate(0, 0.07 * spec['mem'])
        self.mem_load = self._clamp(self.mem_load + mem_jitter, spec['mem'])

        drain = (self.cpu_load * spec['e_cpu'] +
                 self.gpu_load * spec['e_gpu'] +
                 self.mem_load * spec['e_mem'] +
                 spec['p_base']) / 3600
        self.bat_level -= drain

        # The random draw comes first so it is consumed on every tick.
        recharge = random.random() < 0.1 or self.bat_level < 0.1 * spec['bat']
        if recharge:
            self.bat_level = spec['bat']

    def get_state(self):
        """Return the normalised observation dict for this device.

        Loads and battery are divided by their capacities; RSRP, RSRQ and SINR
        are divided by 100, 20 and 10 respectively. The link measurements must
        have been set, otherwise dividing ``None`` raises ``TypeError``.
        """
        spec = self.resources
        link = self.wireless_link
        gpu_fraction = self.gpu_load / spec['gpu'] if spec['gpu'] > 0 else 0
        return {
            'cpu_load': self.cpu_load / spec['cpu'],
            'gpu_load': gpu_fraction,
            'mem_load': self.mem_load / spec['mem'],
            'bat_level': self.bat_level / spec['bat'],
            'rsrp': link.rsrp/100,
            'rsrq': link.rsrq/20,
            'sinr': link.sinr/10
        }

    def compute_local(self, cpu_demand, gpu_demand, memory_demand):
        """Return the on-device processing time, or ``inf`` if it cannot run.

        The time is the smaller of the CPU and GPU times computed against the
        currently free capacity; a resource with no free capacity counts as
        ``inf``. Insufficient free memory yields ``inf`` outright.
        """
        spec = self.resources
        if spec['mem'] - self.mem_load < memory_demand:
            return float('inf')

        free_cpu = spec['cpu'] - self.cpu_load
        free_gpu = spec['gpu'] - self.gpu_load

        if free_cpu > 0:
            cpu_time = cpu_demand / free_cpu
        else:
            cpu_time = float('inf')

        if free_gpu > 0:
            gpu_time = gpu_demand / free_gpu
        else:
            gpu_time = float('inf')

        return min(cpu_time, gpu_time)

    def compute_communication(self, data_size):
        """Return the link latency for ``data_size / 80``."""
        payload = data_size/80
        return self.wireless_link.latency(payload)

    def calculate_energy_consumption(self, cpu_usage, gpu_usage, mem_usage, duration):
        """Return the energy for the given usage held over ``duration``.

        Each usage is weighted by its energy coefficient, ``p_base`` is added,
        and the sum is multiplied by ``duration``.
        """
        spec = self.resources
        power = (cpu_usage * spec['e_cpu'] +
                 gpu_usage * spec['e_gpu'] +
                 mem_usage * spec['e_mem'] +
                 spec['p_base'])
        return power * duration

# Split Computing Environment
class SplitComputingEnv:
    """Multi-UE split-computing environment driven by a recorded channel trace.

    Each UE chooses a split point for its ``DNNTask``: layers before the split
    run on the UE, the rest on a shared ``MECServer``. ``step`` handles one UE
    at a time; the trace index advances once every UE has been stepped.

    Parameters
    ----------
    channel_file : str
        CSV file with at least ``UE_ID`` and ``timestamp`` columns (plus the
        columns ``WirelessChannel`` reads).
    num_ues : int
        Number of UEs to simulate; the first ``num_ues`` distinct ``UE_ID``
        values of the file are used.

    Raises
    ------
    ValueError
        If ``num_ues`` exceeds the number of distinct UE IDs in the file.
    """

    def __init__(self, channel_file, num_ues):
        super(SplitComputingEnv, self).__init__()
        self.channel_data = pd.read_csv(channel_file)
        distinct_ue_ids = self.channel_data['UE_ID'].nunique()
        self.max_steps = self.channel_data['timestamp'].nunique()

        if num_ues > distinct_ue_ids:
            raise ValueError(f"Requested number of UEs ({num_ues}) exceeds the number of distinct UE IDs in the channel file ({distinct_ue_ids})")

        self.num_ues = num_ues
        self.wireless_channel = WirelessChannel(channel_file)
        self.action_space = 5
        self.observation_space = 7
        self.actual_ue_ids = self.channel_data['UE_ID'].unique()[:num_ues]

        # Device types are drawn at random with these fixed probabilities,
        # listed in UEType declaration order.
        ue_type_probabilities = [0.5, 0.15, 0.25, 0.1]
        ue_types = np.random.choice(list(UEType), size=num_ues, p=ue_type_probabilities)
        self.ues = {ue_id: UE(int(ue_id), ue_type, num_layers=self.action_space)
                   for ue_id, ue_type in zip(self.actual_ue_ids, ue_types)}

        self.mec_server = MECServer(cpu=100e9, mem=256e9, gpu=1e12)
        self.current_step = 0
        self.ues_processed_this_step = set()
        self.total_energy_consumed = {ue_id: [] for ue_id in self.actual_ue_ids}
        self.total_time_taken = {ue_id: [] for ue_id in self.actual_ue_ids}
        self.ue_data_rate = {ue_id: [] for ue_id in self.actual_ue_ids}
        self.sla_violation = {ue_id: 0 for ue_id in self.actual_ue_ids}

    def step(self, action, ue_id):
        """Apply split point ``action`` for ``ue_id`` and return the transition.

        Returns ``(state, reward, done, info)`` where ``state`` is the UE's
        observation dict and ``info`` holds ``energy_consumption`` and
        ``total_time``. When the total time is infinite (a stage could not run)
        the UE's SLA-violation counter is incremented and the reward is -1.

        Raises
        ------
        ValueError
            If ``ue_id`` is unknown, or (from the helpers) if the split point
            or the current trace step is out of range.
        """
        if ue_id not in self.ues:
            raise ValueError(f"Invalid UE ID: {ue_id}")
        energy_consumption = 0
        ue = self.ues[ue_id]
        ue.update()
        channel_conditions = self.wireless_channel.get_channel_conditions(ue_id, self.current_step)
        ue.wireless_link.update_channel_conditions(**channel_conditions)

        split_info = ue.task.get_split_info(action)
        local_time = ue.compute_local(split_info['local_cpu_demand'], split_info['local_gpu_demand'],
                                      split_info['local_memory_demand'])
        comm_time = ue.compute_communication(split_info['transmision_data_demand'])

        if self.mec_server.can_accept_task(split_info['remote_memory_demand']):
            remote_time = self.mec_server.process_task(split_info['remote_cpu_demand'], split_info['remote_gpu_demand'],
                                                       split_info['remote_memory_demand'])
        else:
            remote_time = float('inf')

        total_time = local_time + comm_time + remote_time
        if total_time < float('inf'):
            # Energy counts the local computation plus the transmit power over
            # the communication time.
            energy_consumption = ue.calculate_energy_consumption(
                split_info['local_cpu_demand'],
                split_info['local_gpu_demand'],
                split_info['local_memory_demand'],
                local_time
            ) + ue.resources['p_tx'] * comm_time
            self.total_energy_consumed[ue_id].append(energy_consumption)
            self.total_time_taken[ue_id].append(total_time)
            self.ue_data_rate[ue_id].append(ue.wireless_link.data_rate())
            reward = self.compute_reward(total_time, energy_consumption, ue.bat_level / ue.resources['bat'])
        else:
            self.sla_violation[ue_id] += 1
            reward = -1

        self.ues_processed_this_step.add(ue_id)

        # Check if all UEs have been processed for this step
        if len(self.ues_processed_this_step) == self.num_ues:
            self.current_step += 1
            self.ues_processed_this_step.clear()

        done = self.current_step >= self.max_steps
        info = {
            'energy_consumption': energy_consumption,
            'total_time': total_time
        }
        return ue.get_state(), reward, done, info

    def reset(self):
        """Reset statistics, UEs and the server; return initial states per UE.

        Initial states use the channel sample at step 0, after which the step
        counter is advanced to 1. Each UE is re-initialised with the layer
        count it was created with, so the valid split points do not change
        across resets.
        """
        self.current_step = 0
        self.ues_processed_this_step.clear()
        self.total_energy_consumed = {ue_id: [] for ue_id in self.actual_ue_ids}
        self.total_time_taken = {ue_id: [] for ue_id in self.actual_ue_ids}
        self.ue_data_rate = {ue_id: [] for ue_id in self.actual_ue_ids}
        self.sla_violation = {ue_id: 0 for ue_id in self.actual_ue_ids}
        # Reset UEs and update their initial channel conditions
        initial_states = {}
        for ue_id, ue in self.ues.items():
            # Re-run __init__ in place; pass num_layers explicitly because the
            # constructor default (4) differs from the value used in __init__ above.
            ue.__init__(ue.ue_id, ue.ue_type, num_layers=ue.task.num_layers)

            # Get initial channel conditions for this UE
            initial_conditions = self.wireless_channel.get_channel_conditions(ue_id, self.current_step)

            # Update UE's wireless link with initial conditions
            ue.wireless_link.update_channel_conditions(**initial_conditions)

            # Get the initial state for this UE
            initial_states[ue_id] = ue.get_state()

        self.mec_server.reset()
        self.current_step += 1
        return initial_states

    def compute_reward(self, delay, energy, battery_level):
        """Return the weighted reward for one completed task.

        ``delay`` is multiplied by 10 and ``energy`` divided by 10 before the
        reward ``0.4 / delay + 0.4 / energy + 0.2 * battery_level`` is formed
        (a 1e-6 term in each denominator avoids division by zero).
        """
        # before computing the reward we normalize the delay and energy:
        delay = delay*10
        energy = energy / 10
        w1, w2, w3 = 0.4, 0.4, 0.2
        r_time = 1 / (delay + 1e-6)
        r_energy = 1 / (energy + 1e-6)
        r_battery = battery_level
        return w1 * r_time + w2 * r_energy + w3 * r_battery

    def render(self, mode='human'):
        """No-op placeholder."""
        pass

    def close(self):
        """No-op placeholder."""
        pass

    def get_average_energy_consumption(self):
        """Return the mean recorded energy per UE (NaN for a UE with no records)."""
        return {ue_id: np.array(self.total_energy_consumed[ue_id]).mean() for ue_id in self.total_energy_consumed.keys()}

    def get_average_total_time(self):
        """Return the mean recorded total time per UE, multiplied by 100."""
        return {ue_id: (np.array(self.total_time_taken[ue_id])*100).mean() for ue_id in self.total_time_taken.keys()}

    def get_total_sla_violation(self):
        """Return a copy of the per-UE SLA-violation counters."""
        return dict(self.sla_violation)

    def get_total_data_rate(self):
        """Return the per-UE lists of recorded data rates (not a copy)."""
        return self.ue_data_rate

    def get_overall_average_energy_consumption(self):
        """Return the mean of per-UE mean energies, skipping UEs with no records (0 if none)."""
        means = [np.mean(energy) for energy in self.total_energy_consumed.values() if len(energy) > 0]
        result = np.mean(means) if means else 0
        return result

    def get_overall_average_total_time(self):
        """Return the mean of per-UE mean times (x100), skipping UEs with no records (0 if none)."""
        means = [np.mean(time) * 100 for time in self.total_time_taken.values() if len(time) > 0]
        result = np.mean(means) if means else 0
        return result

    def get_overall_ue_data_rate(self):
        """Return the mean of per-UE mean data rates divided by 1e6 (0 if none recorded)."""
        means = [np.mean(rate)/1e6 for rate in self.ue_data_rate.values() if len(rate) > 0]
        result = np.mean(means) if means else 0
        return result

    def get_overall_average_sla_violation(self):
        """Return the mean SLA-violation count across UEs."""
        violations = [arr for arr in self.sla_violation.values()]
        return np.mean(violations)
        
class ReplayBuffer(object):
    """Fixed-capacity ring buffer of transitions with uniform random sampling."""

    def __init__(self, size):
        """Create Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        """
        self._storage = []
        self._maxsize = size
        self._next_idx = 0  # slot the next transition is written to

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self._storage)

    def add(self, obs_t, action, reward, obs_tp1, done):
        """Store one transition, overwriting the oldest one when full.

        Parameters
        ----------
        obs_t: observation before the action
        action: action taken
        reward: reward received
        obs_tp1: observation after the action
        done: whether the transition ended the episode
        """
        data = (obs_t, action, reward, obs_tp1, done)

        if self._next_idx >= len(self._storage):
            self._storage.append(data)
        else:
            self._storage[self._next_idx] = data
        self._next_idx = (self._next_idx + 1) % self._maxsize

    def _encode_sample(self, idxes):
        """Stack the transitions at ``idxes`` into five NumPy arrays."""
        obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
        for i in idxes:
            obs_t, action, reward, obs_tp1, done = self._storage[i]
            # np.asarray avoids a copy when possible and copies otherwise.
            # np.array(..., copy=False) raises on NumPy >= 2 whenever a copy is
            # required, e.g. for list observations.
            obses_t.append(np.asarray(obs_t))
            actions.append(np.asarray(action))
            rewards.append(reward)
            obses_tp1.append(np.asarray(obs_tp1))
            dones.append(done)
        return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones)

    def sample(self, batch_size):
        """Sample a batch of experiences.

        Indices are drawn uniformly with replacement.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in
            the end of an episode and 0 otherwise.

        Raises
        ------
        ValueError
            If the buffer is empty (raised by ``random.randint``).
        """
        idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
        return self._encode_sample(idxes)

In [6]:
# Channel Data Generation Component
def generate_channel_data_op(num_ues: int, 
                           duration_minutes: int,
                           sampling_rate_ms: int,
                           output_file: OutputPath('CSV')):
    """
    Generate channel data following 3GPP specifications for urban macro cell scenarios.

    The result is written as CSV to ``output_file`` with the columns
    ``timestamp``, ``UE_ID``, ``RSRP``, ``RSRQ``, ``SINR``, ``CQI`` and
    ``distance``; a short summary is printed. The function returns nothing.

    Parameters:
    -----------
    num_ues : int
        Number of UEs to simulate
    duration_minutes : int
        Duration of the simulation in minutes
    sampling_rate_ms : int
        Sampling rate in milliseconds
    output_file : str
        Path the CSV file is written to

    Raises:
    -------
    ValueError
        If ``num_ues`` or ``sampling_rate_ms`` is not positive, or if the
        duration is too short to produce a single sample.

    Notes:
    ------
    Following 3GPP TR 38.901 for channel modeling:
    - Path loss model: Urban Macro (UMa)
    - Frequency: 2GHz
    - BS height: 25m
    - UE height: 1.5m
    - Distance range: 10-500m
    """
    # Imports stay inside the function so it is self-contained when wrapped as
    # a component.
    import numpy as np
    import pandas as pd
    from datetime import datetime, timedelta

    if num_ues <= 0:
        raise ValueError(f"num_ues must be positive, got {num_ues}")
    if sampling_rate_ms <= 0:
        raise ValueError(f"sampling_rate_ms must be positive, got {sampling_rate_ms}")

    def calculate_path_loss(distance, f_c=2.0):
        """Calculate path loss based on 3GPP UMa model"""
        # 3GPP TR 38.901 Urban Macro path loss model
        h_BS = 25  # Base station height in meters
        h_UT = 1.5  # User terminal height in meters
        
        # Calculate break point distance
        h_E = 1  # Effective environment height
        h_BP = 4 * (h_BS - h_E) * (h_UT - h_E) * f_c / 0.3
        
        if distance < h_BP:
            # LOS path loss before break point
            PL = 28.0 + 22*np.log10(distance) + 20*np.log10(f_c)
        else:
            # LOS path loss after break point
            PL = 28.0 + 40*np.log10(distance) + 20*np.log10(f_c) - 9*np.log10(h_BP**2 + (h_BS-h_UT)**2)
        
        return PL

    def generate_shadow_fading(size):
        """Generate shadow fading with log-normal distribution"""
        # 3GPP specifies 4dB standard deviation for UMa LOS
        return np.random.normal(0, 4, size)
    
    def calculate_sinr(path_loss, shadow_fading, interference_power=-90):
        """Calculate SINR based on path loss and interference"""
        tx_power = 43  # BS transmission power in dBm (typical macro cell)
        noise_floor = -174 + 10*np.log10(20e6)  # Thermal noise for 20MHz bandwidth
        
        # Received power = Tx power - path loss + shadow fading
        rx_power = tx_power - path_loss + shadow_fading
        
        # SINR calculation: interference and noise are summed in linear units
        interference_plus_noise = 10*np.log10(10**(interference_power/10) + 10**(noise_floor/10))
        sinr = rx_power - interference_plus_noise
        
        return sinr
    
    def map_sinr_to_cqi(sinr):
        """Map SINR to CQI according to 3GPP specifications"""
        # Simplified CQI mapping based on SINR ranges
        cqi_ranges = [
            (-float('inf'), -6.9),
            (-6.9, -5.1), (-5.1, -3.3), (-3.3, -1.5), (-1.5, 0.3),
            (0.3, 2.1), (2.1, 3.9), (3.9, 5.7), (5.7, 7.5),
            (7.5, 9.3), (9.3, 11.1), (11.1, 12.9), (12.9, 14.7),
            (14.7, 16.5), (16.5, float('inf'))
        ]
        
        for cqi, (min_sinr, max_sinr) in enumerate(cqi_ranges):
            if min_sinr <= sinr < max_sinr:
                return cqi + 1
        return 15
    
    def calculate_rsrp(path_loss, shadow_fading):
        """Calculate RSRP based on path loss"""
        tx_power = 43  # BS transmission power in dBm
        return tx_power - path_loss + shadow_fading
    
    def calculate_rsrq(rsrp, num_resource_blocks=100):
        """Calculate RSRQ based on RSRP and RSSI"""
        # Simplified RSRQ calculation
        noise_per_rb = -120  # Noise per resource block in dBm
        rssi = 10*np.log10(10**(rsrp/10) + num_resource_blocks * 10**(noise_per_rb/10))
        rsrq = rsrp - rssi + 10*np.log10(num_resource_blocks)
        return np.clip(rsrq, -19.5, -3)

    # Calculate number of samples
    num_samples = int((duration_minutes * 60 * 1000) / sampling_rate_ms)
    if num_samples <= 0:
        raise ValueError(
            f"duration_minutes={duration_minutes} with sampling_rate_ms={sampling_rate_ms} yields no samples")

    # Read the clock once so consecutive timestamps are exactly
    # sampling_rate_ms apart.
    start_time = datetime.now()
    timestamps = [start_time + timedelta(milliseconds=i*sampling_rate_ms) for i in range(num_samples)]
    
    # Generate UE distances (assuming some mobility)
    ue_distances = {ue_id: [] for ue_id in range(num_ues)}
    for ue_id in range(num_ues):
        # Initial distance
        current_distance = np.random.uniform(10, 500)
        # Random walk for distance changes
        for _ in range(num_samples):
            # Add some random movement (-2 to +2 meters per sample)
            current_distance += np.random.uniform(-2, 2)
            current_distance = np.clip(current_distance, 10, 500)  # Keep within valid range
            ue_distances[ue_id].append(current_distance)
    
    # Generate data for each UE and timestamp
    data = []
    for t_idx, timestamp in enumerate(timestamps):
        for ue_id in range(num_ues):
            distance = ue_distances[ue_id][t_idx]
            path_loss = calculate_path_loss(distance)
            shadow_fading = generate_shadow_fading(1)[0]
            
            # Calculate channel conditions; one shadow-fading draw is shared
            # by RSRP and SINR
            rsrp = calculate_rsrp(path_loss, shadow_fading)
            rsrq = calculate_rsrq(rsrp)
            sinr = calculate_sinr(path_loss, shadow_fading)
            cqi = map_sinr_to_cqi(sinr)
            
            data.append({
                'timestamp': timestamp,
                'UE_ID': ue_id,
                'RSRP': round(rsrp, 2),
                'RSRQ': round(rsrq, 2),
                'SINR': round(sinr, 2),
                'CQI': int(cqi),
                'distance': round(distance, 2)
            })
    
    # Create DataFrame and save to the path supplied by the caller
    df = pd.DataFrame(data)
    df.to_csv(output_file, index=False)
    
    # Print summary statistics
    print("\n=== Channel Data Statistics ===")
    print(f"Number of UEs: {num_ues}")
    print(f"Duration: {duration_minutes} minutes")
    print(f"Sampling rate: {sampling_rate_ms}ms")
    print("\nMetrics ranges:")
    for metric in ['RSRP', 'RSRQ', 'SINR', 'CQI']:
        print(f"{metric}: {df[metric].min():.2f} to {df[metric].max():.2f}")

# Build a component object from generate_channel_data_op via cogflow and name
# 'channel-gen-component.yaml' as the component file to emit.
# NOTE(review): presumably mirrors kfp.components.create_component_from_func -- confirm.
channel_gen_op = cf.create_component_from_func(
    func=generate_channel_data_op,
    output_component_file='channel-gen-component.yaml'
)

In [7]:
# Environment Setup Component
def setup_environment_op(channel_file: InputPath('CSV'),
                        num_ues: int,
                        output_config: OutputPath('JSON')):
    """Write the environment configuration as JSON to ``output_config``.

    The configuration holds ``num_ues``, the action and observation space
    sizes, and the MEC server capacities.

    Parameters
    ----------
    channel_file : str
        Path to the channel CSV. It is accepted as an input but not read here.
    num_ues : int
        Number of UEs recorded in the configuration.
    output_config : str
        Path the JSON configuration is written to.
    """
    # Only json is needed. The former imports of UE, wireless_channel and
    # mec_server were unused and would fail wherever those modules are not
    # installed.
    import json

    # Initialize environment configuration
    config = {
        'num_ues': num_ues,
        'action_space': 5,
        'observation_space': 7,
        'mec_server': {
            'cpu': 100e9,
            'mem': 256e9,
            'gpu': 1e12
        }
    }

    with open(output_config, 'w') as f:
        json.dump(config, f)

# Build a component object from setup_environment_op via cogflow and name
# 'setup-env-component.yaml' as the component file to emit.
# NOTE(review): presumably mirrors kfp.components.create_component_from_func -- confirm.
setup_env_op = cf.create_component_from_func(
    func=setup_environment_op,
    output_component_file='setup-env-component.yaml'
)

In [None]:
# Model Training Component
def train_model_op(channel_file: InputPath('CSV'),
                  env_config: InputPath('JSON'),
                  max_timesteps: int,
                  batch_size: int) -> str:
    """Train the agent inside a cogflow run and log the resulting model.

    Each iteration of the outer loop plays one full episode (until the
    environment reports ``done``), then logs the environment's overall
    energy, time, data-rate and SLA-violation figures for that episode.

    Args:
        channel_file: Path to the channel CSV (not read directly in this body).
        env_config: Path to the environment JSON (not read directly in this body).
        max_timesteps: Number of outer-loop iterations to run.
        batch_size: Batch size passed to ``agent.train`` and logged as a param.

    Returns:
        ``"<run artifact URI>/<logged model artifact path>"`` for the logged model.
    """
    import torch
    import numpy as np
    import cogflow as cf
    from datetime import datetime
    from tqdm import tqdm

    # NOTE(review): this body reads `env`, `agent`, `replay_buffer`, `num_ues`,
    # `seed`, `discount`, `start_timesteps`, `max_action` and `expl_noise`, none
    # of which are defined in this function. They must be resolvable from the
    # enclosing scope; confirm where they are initialised.
    cf.pytorch.autolog()

    with cf.start_run() as run:
        # Log parameters
        cf.log_param("num_ues", num_ues)
        cf.log_param("seed", seed)
        cf.log_param("batch_size", batch_size)
        cf.log_param("discount", discount)
        cf.log_param("start_timesteps", start_timesteps)

        # Training loop
        episode_num = 0

        for t in tqdm(range(int(max_timesteps)), desc="Training Progress"):
            # Start every episode from a fresh environment state; without this
            # every episode after the first would continue from a terminated one.
            states = env.reset()
            episode_rewards = {ue_id: 0 for ue_id in env.actual_ue_ids}
            done = False

            while not done:
                for ue_id in env.actual_ue_ids:
                    state = states[ue_id]

                    # Select action: uniform random during warm-up, otherwise
                    # the agent's action with clipped Gaussian exploration noise.
                    if t < start_timesteps:
                        action = np.random.randint(0, env.action_space)
                    else:
                        action = agent.select_action(np.array(state))
                        action = np.clip(
                            action + np.random.normal(0, max_action * expl_noise),
                            0,
                            max_action
                        )

                    # Execute action
                    # NOTE(review): int(action) requires a scalar or size-1
                    # value; evaluate_model_op uses argmax instead - confirm
                    # what agent.select_action returns.
                    next_state, reward, done, info = env.step(int(action), ue_id)

                    # Store transition
                    replay_buffer.add(state, action, next_state, reward, float(done))

                    episode_rewards[ue_id] += reward
                    states[ue_id] = next_state

                    # Train agent once the warm-up phase is over
                    if t >= start_timesteps:
                        agent.train(replay_buffer, batch_size)

            # Log metrics
            avg_energy = env.get_overall_average_energy_consumption()
            avg_time = env.get_overall_average_total_time()
            avg_data_rate = env.get_overall_ue_data_rate()
            avg_violations = env.get_overall_average_sla_violation()

            cf.log_metric("average_energy", avg_energy, step=episode_num)
            cf.log_metric("average_time", avg_time, step=episode_num)
            cf.log_metric("average_data_rate", avg_data_rate, step=episode_num)
            cf.log_metric("sla_violations", avg_violations, step=episode_num)

            # Advance the step so each episode gets its own point in the
            # metric history instead of all being written at step 0.
            episode_num += 1

        # Save model
        model_name = "split-computing-model"
        result = cf.log_model(agent, "model",
                            registered_model_name=model_name)
        return f"{run.info.artifact_uri}/{result.artifact_path}"

# Build a pipeline component from train_model_op; the YAML file name below is
# passed as the component-definition output file.
train_op = cf.create_component_from_func(
    func=train_model_op,
    output_component_file='train-component.yaml'
)

In [None]:
# Evaluation Component
def evaluate_model_op(model_path: str,
                     channel_file: InputPath('CSV'),
                     env_config: InputPath('JSON')):
    """Evaluate a logged model over 100 episodes and log the results to cogflow.

    Loads the model from ``model_path``, builds a ``SplitComputingEnv`` from
    ``channel_file`` and the ``'num_ues'`` entry of ``env_config``, then for
    every UE step records energy, time, data rate, chosen split point and
    reward. Averages, per-split-point statistics, a correlation matrix, two
    box plots and a summary dict are logged to the active run.

    Args:
        model_path: Location passed to ``cf.pyfunc.load_model``.
        channel_file: Path to the channel CSV handed to the environment.
        env_config: Path to the JSON config; only ``'num_ues'`` is read.

    Returns:
        dict with keys ``average_energy``, ``average_time``,
        ``average_data_rate``, ``sla_violations`` and ``average_reward``.
    """
    import json
    from tqdm import tqdm
    import torch
    import numpy as np
    import pandas as pd
    import cogflow as cf
    import matplotlib.pyplot as plt
    import seaborn as sns

    # NOTE(review): SplitComputingEnv is not imported here; it must be
    # resolvable from the enclosing scope.
    cf.autolog()

    with cf.start_run() as run:
        # Load model and configuration
        model = cf.pyfunc.load_model(model_path)
        with open(env_config, 'r') as f:
            config = json.load(f)

        # Initialize environment
        env = SplitComputingEnv(channel_file, num_ues=config['num_ues'])

        # Evaluation parameters
        n_evaluation_episodes = 100
        results = {
            'energy_consumption': [],
            'processing_time': [],
            'data_rate': [],
            'sla_violations': [],
            'split_decisions': [],
            'rewards': []
        }

        # Run evaluation episodes
        for episode in tqdm(range(n_evaluation_episodes), desc="Evaluating"):
            states = env.reset()
            done = False

            # NOTE(review): `done` is overwritten by every UE's step, so the
            # episode ends on the last UE's flag - confirm the environment
            # reports termination consistently across UEs.
            while not done:
                for ue_id in env.actual_ue_ids:
                    state = states[ue_id]
                    # Get model prediction; the split point is the index of
                    # the largest entry of the predicted action.
                    action = model.predict(np.array(state))
                    split_point = int(np.argmax(action))

                    # Execute action
                    next_state, reward, done, info = env.step(split_point, ue_id)

                    # Record metrics
                    results['energy_consumption'].append(info['energy_consumption'])
                    results['processing_time'].append(info['total_time'])
                    results['split_decisions'].append(split_point)
                    results['rewards'].append(reward)
                    results['data_rate'].append(env.get_total_data_rate()[ue_id][-1])

                    states[ue_id] = next_state

            results['sla_violations'].append(env.get_overall_average_sla_violation())

        # Calculate overall metrics
        avg_energy = np.mean(results['energy_consumption'])
        avg_time = np.mean(results['processing_time'])
        avg_data_rate = np.mean(results['data_rate'])
        avg_sla_violations = np.mean(results['sla_violations'])
        avg_reward = np.mean(results['rewards'])

        # Log metrics to cogflow
        cf.log_metric("average_energy_consumption", avg_energy)
        cf.log_metric("average_processing_time", avg_time)
        cf.log_metric("average_data_rate", avg_data_rate)
        cf.log_metric("average_sla_violations", avg_sla_violations)
        cf.log_metric("average_reward", avg_reward)

        # Log split point distribution
        # NOTE(review): MLflow's log_dict/log_figure take the object first and
        # the artifact file name second; confirm the argument order cogflow
        # expects for the log_dict/log_figure calls in this function.
        split_dist = pd.Series(results['split_decisions']).value_counts().to_dict()
        cf.log_dict("split_point_distribution", split_dist)

        # Create and log detailed performance analysis
        performance_df = pd.DataFrame({
            'Energy': results['energy_consumption'],
            'Time': results['processing_time'],
            'DataRate': results['data_rate'],
            'SplitPoint': results['split_decisions'],
            'Reward': results['rewards']
        })

        # Calculate per-split-point statistics
        split_stats = performance_df.groupby('SplitPoint').agg({
            'Energy': ['mean', 'std'],
            'Time': ['mean', 'std'],
            'DataRate': ['mean', 'std'],
            'Reward': ['mean', 'std']
        }).round(4)

        # agg() with lists yields MultiIndex columns, which to_dict() turns
        # into tuple keys; flatten them to "Metric_stat" strings so the dict
        # can be serialized (JSON rejects tuple keys).
        split_stats.columns = [
            f"{metric}_{stat}" for metric, stat in split_stats.columns
        ]

        # Log split point performance statistics
        cf.log_dict("split_point_statistics", split_stats.to_dict())

        # Calculate correlation matrix
        corr_matrix = performance_df.corr().round(4)
        cf.log_dict("metric_correlations", corr_matrix.to_dict())

        # Energy distribution per split point
        plt.figure(figsize=(10, 6))
        sns.boxplot(data=performance_df, x='SplitPoint', y='Energy')
        plt.title('Energy Consumption Distribution per Split Point')
        plt.xlabel('Split Point')
        plt.ylabel('Energy (J)')
        cf.log_figure("energy_distribution", plt.gcf())
        plt.close()

        # Processing time distribution per split point
        plt.figure(figsize=(10, 6))
        sns.boxplot(data=performance_df, x='SplitPoint', y='Time')
        plt.title('Processing Time Distribution per Split Point')
        plt.xlabel('Split Point')
        plt.ylabel('Time (ms)')
        cf.log_figure("time_distribution", plt.gcf())
        plt.close()

        # Create summary report
        summary = {
            "evaluation_episodes": n_evaluation_episodes,
            "average_metrics": {
                "energy_consumption": float(avg_energy),
                "processing_time": float(avg_time),
                "data_rate": float(avg_data_rate),
                "sla_violations": float(avg_sla_violations),
                "reward": float(avg_reward)
            },
            "split_point_distribution": split_dist,
            "performance_summary": "Success" if avg_sla_violations < 0.1 else "Needs Improvement"
        }

        # Log summary
        cf.log_dict("evaluation_summary", summary)

        return {
            "average_energy": avg_energy,
            "average_time": avg_time,
            "average_data_rate": avg_data_rate,
            "sla_violations": avg_sla_violations,
            "average_reward": avg_reward
        }

# Build the evaluation component used by the pipeline as `eval_op`, mirroring
# how the other component functions in this notebook are wrapped.
eval_op = cf.create_component_from_func(
    func=evaluate_model_op,
    output_component_file='eval-component.yaml'
)


In [None]:
# The leading "@" applies cf.pipeline as a decorator; without it the call is a
# discarded expression and the name/description never reach the pipeline.
@cf.pipeline(
    name="split-computing-pipeline",
    description="Split Computing DRL Training Pipeline"
)
def split_computing_pipeline(
    num_ues: int = 3,
    duration_minutes: int = 1000,
    sampling_rate_ms: int = 100,
    max_timesteps: int = 100000,
    batch_size: int = 256
):
    """Wire channel generation, environment setup, training and evaluation.

    Args:
        num_ues: Passed to the channel-generation and environment-setup steps.
        duration_minutes: Passed to the channel-generation step.
        sampling_rate_ms: Passed to the channel-generation step.
        max_timesteps: Passed to the training step.
        batch_size: Passed to the training step.
    """
    # Generate channel data
    channel_data_task = channel_gen_op(
        num_ues=num_ues,
        duration_minutes=duration_minutes,
        sampling_rate_ms=sampling_rate_ms
    )

    # Setup environment
    env_setup_task = setup_env_op(
        channel_file=channel_data_task.outputs['output_file'],
        num_ues=num_ues
    )

    # Train model
    train_task = train_op(
        channel_file=channel_data_task.outputs['output_file'],
        env_config=env_setup_task.outputs['output_config'],
        max_timesteps=max_timesteps,
        batch_size=batch_size
    )

    # Evaluate model
    # NOTE: eval_op must be a component built from evaluate_model_op (via
    # cf.create_component_from_func) before this pipeline function is run.
    eval_task = eval_op(
        model_path=train_task.output,
        channel_file=channel_data_task.outputs['output_file'],
        env_config=env_setup_task.outputs['output_config']
    )

In [None]:
# Run Pipeline
def run_pipeline():
    """Submit split_computing_pipeline as a run through the cogflow client."""
    # Arguments handed to the pipeline function for this run.
    pipeline_arguments = {
        "num_ues": 3,
        "duration_minutes": 1000,
        "sampling_rate_ms": 100,
        "max_timesteps": 100000,
        "batch_size": 256,
    }

    cf.client().create_run_from_pipeline_func(
        split_computing_pipeline,
        arguments=pipeline_arguments,
    )

run_pipeline()