<a href="https://colab.research.google.com/github/GitHrsh/Noma-ISAC/blob/main/NOMA_ISAC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from __future__ import annotations
from dataclasses import dataclass, field
import numpy as np, numpy.linalg as la, math, typing as tp, itertools, os, time
from multiprocessing import Pool, cpu_count
#parameters
@dataclass
class Sys:
    """System and learner parameters for the NOMA-ISAC bandit simulation.

    Field order defines the positional constructor signature, so it is kept
    stable.
    """

    # --- dimensions -------------------------------------------------------
    K: int = 3        # users; each policy returns one (m, ell, q) action per k
    M: int = 4        # candidate targets per user (theta / beta are (K, M))
    Nt: int = 8       # columns of each user channel matrix H[k]
    Nr: int = 4       # rows of each user channel matrix H[k]
    Ns: int = 8       # antenna count of the DFT codebook and steering vectors
    C: int = 1        # NOTE(review): not referenced in this file section — confirm use
    Lb: int = 16      # number of beams in the DFT codebook (ell in [0, Lb))

    # --- power --------------------------------------------------------------
    L_power: tp.Sequence[float] = (0.2, 0.4, 0.6, 0.8, 1.0)  # levels, as fractions of PSCD
    PBS: float = 10.0   # total of the per-user communication powers pk
    PSCD: float = 1.0   # scale applied to L_power to get the sensing power q

    # --- noise and reward mix -------------------------------------------------
    sigma_c: float = 1.0  # noise term in the communication SINR
    sigma_r: float = 1.0  # noise term in the radar SINR
    alpha: float = 0.5    # reward = alpha * comm_rate + (1 - alpha) * radar_rate

    # --- learners -------------------------------------------------------------
    eps_greedy: float = 0.05   # exploration probability of GreedyPol
    alpha_linucb: float = 2.5  # confidence-width multiplier of LinUCBPol
    seed: int = 42             # base seed; run k uses seed + k

    # --- simulation -------------------------------------------------------------
    block_len: int = 50            # slots per channel block
    oracle_thresh: int = 200_000   # max joint-action count for exhaustive oracle search
    oracle_samples: int = 20_000   # random joint actions tried otherwise
    use_mp: bool = True            # NOTE(review): not referenced in this file section

@dataclass
class SimCfg:
    """Top-level run configuration: horizon, repetitions and system params."""

    T: int = 2000                           # slots simulated per seed
    seeds: int = 3                          # independent runs (seed offsets 0..seeds-1)
    sys: Sys = field(default_factory=Sys)   # a fresh Sys per SimCfg instance


def dft_codebook(N: int, L: int):
    """Return an (L, N) DFT codebook; row ``l`` is beam ``l``.

    Entry ``[l, n]`` equals ``exp(-j*2*pi*n*l/L) / sqrt(N)``, so every row
    has unit norm.
    """
    antenna = np.arange(N)
    beam = np.arange(L)
    phase = -1j * 2 * np.pi * antenna[None, :] * beam[:, None] / L
    return np.exp(phase) / np.sqrt(N)

def steering(N: int, th):
    """Return a (len(th), N) matrix of unit-norm steering vectors.

    Row ``i`` is ``exp(j*pi*n*sin(th[i])) / sqrt(N)`` for ``n = 0..N-1``;
    a scalar angle is promoted to a one-element array.
    """
    angles = np.atleast_1d(th)
    antenna = np.arange(N)[:, None]
    columns = np.exp(1j * np.pi * antenna * np.sin(angles)) / np.sqrt(N)
    return columns.T


class ISACEnv:
    """Block-fading NOMA-ISAC environment shared by all policies.

    Each block draws user channels ``H`` (K, Nr, Nt), target angles ``theta``
    (K, M) in [-pi/3, pi/3) and complex target gains ``beta`` (K, M).
    :meth:`step` advances one slot; the block is redrawn when a step starts
    on a non-zero multiple of ``sys.block_len``.
    """

    def __init__(self, sys: "Sys", rng: np.random.Generator):
        self.s = sys
        self.rng = rng
        self.slot = 0
        # (Lb, Ns) DFT codebook of sensing beams, indexed by ``ell``.
        self.Wc = dft_codebook(sys.Ns, sys.Lb)
        self._regen_block()

    def _regen_block(self):
        """Draw a new channel block and invalidate the cached oracle value."""
        s, rng = self.s, self.rng
        # Entries are complex Gaussian with total variance 1/Nt.
        self.H = (rng.normal(size=(s.K, s.Nr, s.Nt)) +
                  1j * rng.normal(size=(s.K, s.Nr, s.Nt))) / math.sqrt(2 * s.Nt)
        self.theta = rng.uniform(-math.pi/3, math.pi/3, (s.K, s.M))
        self.beta = (rng.normal(size=(s.K, s.M)) +
                     1j * rng.normal(size=(s.K, s.M))) / math.sqrt(2)

        # Unit-norm BS beam per user: normalised mean of the rows of H[k].
        self.Wbs = np.stack([self.H[k].mean(axis=0) / la.norm(self.H[k].mean(axis=0))
                             for k in range(s.K)])
        # Power weights decrease with the user *index* (1.0 .. 0.2), scaled to
        # sum to PBS. NOTE(review): not tied to self.order — confirm intent.
        self.pk = np.linspace(1.0, 0.2, s.K)
        self.pk = self.pk / self.pk.sum() * s.PBS
        # User indices sorted by descending Frobenius norm of H[k].
        self.order = np.argsort([la.norm(self.H[k]) for k in range(s.K)])[::-1]

        # gain[k, m, l] = |Wc[l]^H a(theta[k, m])|^2 for every codebook beam l.
        a_mat = steering(s.Ns, self.theta.reshape(-1)).reshape(s.K, s.M, s.Ns)
        self.gain = np.abs(np.einsum('ln,kmn->kml', self.Wc.conj(), a_mat))**2

        self._oracle = None

    def comm_rate(self, k):
        """Return log2(1 + SINR) of user k on the current block.

        Interference comes from the users that follow k in ``self.order``.
        NOTE(review): those are the *weaker* users (descending norm order);
        textbook downlink SIC treats the stronger users as interference —
        confirm the intended convention.
        """
        s = self.s
        idx = self.order.tolist().index(k)
        sig = self.pk[k] * la.norm(self.H[k] @ self.Wbs[k][:, None])**2
        interf = (sum(self.pk[j] * la.norm(self.H[k] @ self.Wbs[j][:, None])**2
                      for j in self.order[idx+1:]) + s.sigma_c)
        return math.log2(1 + sig / interf)

    def radar_rate(self, k, m, q, ell, acts):
        """Return log2(1 + SINR) for user k sensing target m with beam ell, power q.

        Interference is the power other users' chosen beams (from ``acts``)
        place on the same target direction.
        """
        s = self.s
        num = q * abs(self.beta[k, m])**2 * self.gain[k, m, ell]
        inter = sum(qj * abs(self.beta[k, m])**2 * self.gain[k, m, ellj]
                    for j, (_, ellj, qj) in enumerate(acts) if j != k)
        return math.log2(1 + num / (s.sigma_r + inter))

    def _evaluate(self, acts):
        """Return ``(global_reward, per_user_rewards)`` for ``acts`` on the current block.

        Pure function of the current block: it does not advance ``slot`` or
        redraw the channel, so the oracle can call it freely.
        """
        s = self.s
        R = np.zeros(s.K)
        S = np.zeros(s.K)
        for k, (m, ell, q) in enumerate(acts):
            R[k] = self.comm_rate(k)
            S[k] = self.radar_rate(k, m, q, ell, acts)
        r_k = s.alpha * R + (1 - s.alpha) * S
        return r_k.sum(), r_k

    def step(self, acts):
        """Play one slot with ``acts`` (one ``(m, ell, q)`` tuple per user).

        Redraws the block first when the slot starts a new block, then
        returns ``(global_reward, per_user_rewards)`` and advances the slot.
        """
        if self.slot and self.slot % self.s.block_len == 0:
            self._regen_block()
        glob, r_k = self._evaluate(acts)
        self.slot += 1
        return glob, r_k

    def oracle(self):
        """Return the best global reward on the current block (cached per block).

        Exhaustive when ``(M*Lb*P)**K <= oracle_thresh``, otherwise the best
        of ``oracle_samples`` random joint actions drawn from ``self.rng``.
        The sampled value is a lower bound on the true optimum, so measured
        regret against it can be negative. Candidates are scored with
        :meth:`_evaluate`: the environment's slot counter and channel are
        left untouched (scoring via ``step`` used to redraw the channel
        mid-search and push the slot counter thousands of slots ahead).
        """
        if self._oracle is not None:
            return self._oracle

        s = self.s
        M, Lb, P = s.M, s.Lb, len(s.L_power)
        joint = (M * Lb * P) ** s.K
        best = -1e9

        if joint <= s.oracle_thresh:          # exhaustive search
            base = list(itertools.product(range(M), range(Lb), range(P)))
            for combo in itertools.product(base, repeat=s.K):
                acts = [(m, ell, s.L_power[p] * s.PSCD) for (m, ell, p) in combo]
                best = max(best, self._evaluate(acts)[0])
        else:                                 # Monte-Carlo, serial
            rng = self.rng
            for _ in range(s.oracle_samples):
                acts = [(rng.integers(M),
                         rng.integers(Lb),
                         rng.choice(s.L_power) * s.PSCD) for _ in range(s.K)]
                best = max(best, self._evaluate(acts)[0])

        self._oracle = best
        return best

class PolicyBase:
    """Common base for bandit policies bound to one environment.

    Exposes the environment as ``e``, its system parameters as ``s`` and the
    user count as ``K``. ``act`` and ``update`` are no-op hooks that
    subclasses override.
    """

    def __init__(self, env):
        self.e = env
        self.s = env.s
        self.K = self.s.K

    def act(self):
        """Return one action per user; the base implementation returns None."""

    def update(self, *_):
        """Consume feedback for the last actions; the base implementation ignores it."""


class RandomPol(PolicyBase):
    """Baseline that picks target, beam and power uniformly at random each slot.

    Draws come from the environment generator ``env.rng``.
    """

    def act(self):
        gen = self.e.rng
        cfg = self.s
        actions = []
        for _ in range(cfg.K):
            # Draw order (target, beam, power) matters for reproducibility.
            target = gen.integers(cfg.M)
            beam = gen.integers(cfg.Lb)
            power = gen.choice(cfg.L_power) * cfg.PSCD
            actions.append((target, beam, power))
        return actions

    def update(self, *_):
        """Random play learns nothing from feedback."""
        return None


class _TabularBase(PolicyBase):
    """Per-user pull counts ``N`` and reward sums ``S`` over a flat action index.

    Both are (K, M*P*Lb) arrays with P = len(L_power); the flat index is
    row-major over (m, p, ell): ``m*P*Lb + p*Lb + ell``.
    """

    def __init__(self, env):
        super().__init__(env)
        cfg = self.s
        self.na = cfg.M * len(cfg.L_power) * cfg.Lb
        table_shape = (self.K, self.na)
        self.N = np.zeros(table_shape)
        self.S = np.zeros(table_shape)

    def _idx(self, m, ell, p):
        """Flatten target m, beam ell, power index p into a column of N / S."""
        return (m * len(self.s.L_power) + p) * self.s.Lb + ell

    def _decode(self, idx):
        """Invert :meth:`_idx`; note the result order is ``(m, p, ell)``."""
        m, within_target = divmod(idx, len(self.s.L_power) * self.s.Lb)
        p, ell = divmod(within_target, self.s.Lb)
        return m, p, ell

class GreedyPol(_TabularBase):
    """Per-user epsilon-greedy learner over the tabular (m, p, ell) actions.

    With probability ``eps_greedy`` (or before user k has any data) a
    uniformly random action is played; otherwise the action with the best
    empirical mean reward. Randomness comes from ``env.rng`` so runs are
    reproducible from ``Sys.seed`` (the unseeded global ``np.random`` state
    was used previously).
    """

    def act(self):
        s = self.s
        rng = self.e.rng
        acts = []
        # Power *index* chosen per user; q alone cannot be mapped back safely.
        self._chosen_p = []
        for k in range(self.K):
            if rng.random() < s.eps_greedy or self.N[k].sum() == 0:
                m = rng.integers(s.M)
                p = rng.integers(len(s.L_power))
                ell = rng.integers(s.Lb)
            else:
                idx = np.argmax(self.S[k] / np.maximum(1, self.N[k]))
                m, p, ell = self._decode(idx)
            q = s.L_power[p] * s.PSCD
            acts.append((m, ell, q))
            self._chosen_p.append(p)
        return acts

    def update(self, acts, r_k):
        """Record reward ``r_k[k]`` for each user's action from the last :meth:`act`."""
        for k, (m, ell, _q) in enumerate(acts):
            idx = self._idx(m, ell, self._chosen_p[k])
            self.N[k, idx] += 1
            self.S[k, idx] += r_k[k]

class CUCBPol(_TabularBase):
    """Per-user UCB1 over the tabular (m, p, ell) action space.

    Index of arm a for user k: ``mean_k(a) + sqrt(2 ln t / N_k(a))``. Arms
    never pulled get an infinite index, so each arm is tried once before the
    confidence bonus takes over (standard UCB1 initialisation). Previously an
    unpulled arm scored as mean 0 with the same bonus as a once-pulled arm.
    Because rewards here are non-negative log-rates, it could never outscore
    a once-pulled arm that had earned a positive reward.
    """

    def __init__(self, env):
        super().__init__(env)
        self.t = 1  # round counter for the ln(t) exploration bonus
        self._chosen_p = []

    def act(self):
        s = self.s
        acts = []
        self._chosen_p = []
        log_t = np.log(self.t)
        for k in range(self.K):
            pulls = self.N[k]
            mean = self.S[k] / np.maximum(1, pulls)
            bonus = np.sqrt(2 * log_t / np.maximum(1, pulls))
            ucb = np.where(pulls > 0, mean + bonus, np.inf)
            m, p, ell = self._decode(np.argmax(ucb))
            q = s.L_power[p] * s.PSCD
            acts.append((m, ell, q))
            self._chosen_p.append(p)
        self.t += 1
        return acts

    def update(self, acts, r_k):
        """Record reward ``r_k[k]`` for each user's action from the last :meth:`act`."""
        for k, (m, ell, _q) in enumerate(acts):
            idx = self._idx(m, ell, self._chosen_p[k])
            self.N[k, idx] += 1
            self.S[k, idx] += r_k[k]


class LinUCBPol(PolicyBase):
    """Disjoint LinUCB: one ridge-regression model ``(A[k], b[k])`` per user.

    Feature vector (length 4 + Lb) of action (m, ell, q) for user k:
    ``[||H[k]||, sin(theta[k, m]), cos(theta[k, m]), q / PSCD, one_hot(ell)]``.
    """

    def __init__(self, env):
        super().__init__(env)
        self.d = 4 + self.s.Lb
        self.A = [np.eye(self.d) for _ in range(self.K)]
        self.b = [np.zeros(self.d) for _ in range(self.K)]

    def _phi(self, k, m, ell, q):
        """Build the feature vector of action (m, ell, q) on the current block."""
        env = self.e
        angle = env.theta[k, m]
        feat = np.zeros(self.d)
        feat[0] = la.norm(env.H[k])
        feat[1] = math.sin(angle)
        feat[2] = math.cos(angle)
        feat[3] = q / self.s.PSCD
        feat[4 + ell] = 1
        return feat

    def act(self):
        cfg = self.s
        chosen = []
        for k in range(self.K):
            A_inv = la.inv(self.A[k])
            weights = A_inv @ self.b[k]
            top_score = -1e9
            # Same scan order as nested loops over m, then ell, then power.
            for m, ell, frac in itertools.product(range(cfg.M), range(cfg.Lb), cfg.L_power):
                q = frac * cfg.PSCD
                phi = self._phi(k, m, ell, q)
                score = phi @ weights + cfg.alpha_linucb * math.sqrt(phi @ A_inv @ phi)
                if score > top_score:
                    top_score, top = score, (m, ell, q, phi)
            m, ell, q, phi = top
            chosen.append((m, ell, q))
            setattr(self, f'_phi{k}', phi)
        return chosen

    def update(self, acts, r_k):
        """Rank-one update of each user's model with its last feature vector."""
        for k, (A_k, b_k) in enumerate(zip(self.A, self.b)):
            phi = getattr(self, f'_phi{k}')
            A_k += np.outer(phi, phi)
            b_k += r_k[k] * phi

import plotly.graph_objects as go
import json, os, time

def run_static(cfg: SimCfg, verbose: bool = True):
    """Run all four policies on a channel frozen for the whole horizon.

    For every seed a fresh environment is built and its block regeneration
    disabled, the oracle is computed once, and each slot every policy acts,
    is stepped and updated in turn.

    Returns ``{alg: {"cum_reg", "inst_reg", "thr", "stab", "cpu"}}`` with one
    entry per slot, concatenated over seeds. ``stab`` is the fraction of
    users whose target index is unchanged from the previous slot.
    """
    sysp = cfg.sys
    algos = ["LinUCB", "CUCB", "Greedy", "Random"]
    metrics = ("cum_reg", "inst_reg", "thr", "stab", "cpu")
    log = {algo: {key: [] for key in metrics} for algo in algos}

    header = "{:<8}{:<7}{:>9}{:>9}{:>9}{:>7}".format(
        "slot", "alg", "thr", "instReg", "stab", "cpu")
    row_fmt = "{:<8}{:<7}{:>9.3f}{:>9.3f}{:>9.2f}{:>7.4f}"
    if verbose:
        print("--- STATIC CHANNEL RUN ---")
        print(header)
        print("-" * len(header))

    for seed_offset in range(cfg.seeds):
        env = ISACEnv(sysp, np.random.default_rng(sysp.seed + seed_offset))
        # Freeze the channel: block regeneration becomes a no-op.
        env._regen_block = lambda: None
        policies = {
            "LinUCB": LinUCBPol(env),
            "CUCB": CUCBPol(env),
            "Greedy": GreedyPol(env),
            "Random": RandomPol(env),
        }

        last_targets = {algo: -np.ones(sysp.K, dtype=int) for algo in algos}
        best = env.oracle()

        for slot in range(cfg.T):
            for algo, policy in policies.items():
                start = time.perf_counter()
                actions = policy.act()
                throughput, per_user = env.step(actions)
                policy.update(actions, per_user)
                elapsed = time.perf_counter() - start

                record = log[algo]
                regret = best - throughput
                cum = regret + record["cum_reg"][-1] if slot else regret
                targets = np.array([a[0] for a in actions[:sysp.K]])
                stab = np.mean(targets == last_targets[algo])

                record["thr"].append(throughput)
                record["inst_reg"].append(regret)
                record["cum_reg"].append(cum)
                record["stab"].append(stab)
                record["cpu"].append(elapsed)
                last_targets[algo] = targets

                if verbose:
                    print(row_fmt.format(slot + 1, algo, throughput, regret, stab, elapsed))
    return log

def run(cfg: SimCfg, verbose: bool = False, log_dir: str = "logs"):
    """Run all four policies on a block-fading channel.

    The channel is redrawn every ``cfg.sys.block_len`` slots and the oracle
    is recomputed for each new block. All policies share one environment and
    each calls ``env.step`` once per slot, so the environment's own per-call
    block regeneration is disabled here and replaced by an explicit redraw
    at slot-level block boundaries. Otherwise the block would change every
    ``block_len / 4`` slots, out of sync with the per-block oracle.

    Each decision is appended as JSON to ``<log_dir>/decisions_seed<sd>.txt``.
    ``verbose`` is accepted for signature parity with ``run_static`` and is
    currently unused.

    Returns ``{alg: {"cum_reg", "inst_reg", "thr", "stab", "cpu"}}`` with one
    entry per slot, concatenated over seeds. ``stab`` is NaN on the first
    slot of every block.
    """
    os.makedirs(log_dir, exist_ok=True)
    s, names = cfg.sys, ["LinUCB", "CUCB", "Greedy", "Random"]
    log = {n: {"cum_reg": [], "inst_reg": [], "thr": [],
               "stab": [], "cpu": []} for n in names}

    for sd in range(cfg.seeds):
        rng = np.random.default_rng(s.seed + sd)
        env = ISACEnv(s, rng)
        # Take over block regeneration: once per slot-level block boundary,
        # never per policy call and never inside env.oracle().
        regen_block = env._regen_block
        env._regen_block = lambda: None
        pols = dict(LinUCB=LinUCBPol(env), CUCB=CUCBPol(env),
                    Greedy=GreedyPol(env), Random=RandomPol(env))
        prev_target = {n: -np.ones(s.K, dtype=int) for n in names}

        path = os.path.join(log_dir, f"decisions_seed{sd}.txt")
        with open(path, "w", encoding="utf-8") as log_file:
            log_file.write("# slot  algorithm   decisions_json\n")

            for t in range(cfg.T):
                block_pos = t % s.block_len
                if block_pos == 0:
                    if t:
                        regen_block()  # new channel; also clears the cached oracle
                    oracle = env.oracle()

                for name, pol in pols.items():
                    tic = time.perf_counter()
                    acts = pol.act()
                    glob, r_k = env.step(acts)
                    pol.update(acts, r_k)
                    cpu = time.perf_counter() - tic

                    serial_acts = [(int(m), int(ell), float(q)) for (m, ell, q) in acts]
                    log_file.write(f"{t+1:<5} {name:<9} {json.dumps(serial_acts)}\n")

                    inst = oracle - glob
                    cum = inst if t == 0 else inst + log[name]["cum_reg"][-1]
                    # Target switches across a channel redraw are expected,
                    # so stability is undefined on a block's first slot.
                    stab = (np.mean([acts[k][0] == prev_target[name][k]
                                     for k in range(s.K)])
                            if block_pos else np.nan)

                    log[name]["thr"].append(glob)
                    log[name]["inst_reg"].append(inst)
                    log[name]["cum_reg"].append(cum)
                    log[name]["stab"].append(stab)
                    log[name]["cpu"].append(cpu)
                    prev_target[name] = np.array([acts[k][0] for k in range(s.K)])
    return log

if __name__ == "__main__":
    cfg = SimCfg(T=5000, seeds=1)
    res = run_static(cfg, verbose=False)
    bl = cfg.sys.block_len

    for alg, rec in res.items():
        n_slots = len(rec["stab"])
        # Last complete channel block of the (seed-concatenated) series;
        # clamped so a horizon shorter than one block uses every slot.
        blk = slice(max(0, (n_slots // bl - 1) * bl), n_slots)
        print(f"{alg:8}",
              f"throughput={np.mean(rec['thr']):6.3f}",
              f"instantaneous average regret={np.mean(rec['inst_reg']):6.3f}",
              f"stability(single block)={np.nanmean(rec['stab'][blk]):5.2f}")

    # figure key -> (metric key, title, y-axis label)
    panels = {
        "Throughput": ("thr", "Throughput per Slot", "bits"),
        "CumReg": ("cum_reg", "Cumulative Regret", "regret"),
        "Stability": ("stab", "Stability", "stability"),
    }
    figs = {}
    for key, (metric, title, ylabel) in panels.items():
        fig = go.Figure()
        for alg, rec in res.items():
            # x spans the whole series, which covers seeds * T slots.
            x = np.arange(1, len(rec[metric]) + 1)
            fig.add_trace(go.Scatter(x=x, y=rec[metric], mode='lines', name=alg))
        fig.update_layout(title=title, xaxis_title="slot", yaxis_title=ylabel)
        figs[key] = fig

    for fig in figs.values():
        fig.show()

LinUCB   threshold= 2.909 instantaneous average regret=-0.162 stability(single block)= 1.00
CUCB     threshold= 1.472 instantaneous average regret= 1.275 stability(single block)= 0.47
Greedy   threshold= 2.846 instantaneous average regret=-0.099 stability(single block)= 0.96
Random   threshold= 1.355 instantaneous average regret= 1.392 stability(single block)= 0.30
