In [None]:
# Toggle for running on Google Colab. When True, the project repository is
# cloned, its contents are moved into the working directory and the package is
# installed in editable mode.
COLAB: bool = True
if COLAB:
  # NOTE(review): these shell steps are not idempotent; re-running the cell
  # after a successful clone will make `git clone` / `mv` complain.
  !git clone https://github.com/RubenCid35/6GSmartRRM
  !mv 6GSmartRRM/* .
  !pip install -e .

In [None]:
# Reload imported modules automatically before each cell runs, so edits to the
# editable-installed project package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# PyTorch Geometric and its optional compiled extensions.
# The wheel index on the second line is specific to torch 2.6.0 + CUDA 12.4;
# it has to match the torch build installed in the runtime.
%pip install -q torch_geometric
%pip install -q pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.6.0+cu124.html

In [None]:
# simple data manipulation
import numpy  as np
import pandas as pd
import numpy.typing as npt

# deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
from   torch.utils.data import DataLoader, TensorDataset, random_split

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

from collections import defaultdict

# results logging
import wandb
wandb.login()

# progress bar
from   tqdm.notebook import tqdm, trange

# remove warnings (remove deprecated warnings)
import warnings
warnings.simplefilter('ignore')

# visualization of results
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from   matplotlib.ticker import MaxNLocator
import seaborn           as sns

# Graph Algorithms.
import networkx as nx

# Google Colab (many lines are removed)
import os
import zipfile

# `google.colab` only exists inside a Colab runtime, so it is imported only
# when COLAB is set; otherwise the local code path would fail right here.
if COLAB:
    from google.colab import drive

# `distutils` was removed from the standard library in Python 3.12. Fall back
# to an equivalent helper built on `shutil.copytree` so that `copy_tree`
# remains available under the same name.
try:
    from distutils.dir_util import copy_tree
except ImportError:
    import shutil

    def copy_tree(src: str, dst: str):
        """Recursively copy `src` into `dst`, merging into an existing directory."""
        return shutil.copytree(src, dst, dirs_exist_ok=True)

# whether we are using colab or not: when running locally from a sub-folder,
# move one level up so that the relative `./data` paths resolve
if not COLAB and not os.path.exists('./data/simulations'):
    os.chdir('..')

# Simulation Settings
from g6smart.sim_config import SimConfig
from g6smart.evaluation import rate_torch as rate_metrics
from g6smart.baseline.subband.sisa import sisa_algoritm
config = SimConfig(0)
config

In [None]:
def setup_wandb(name: str, group: str, config: dict[str, float], id: str | None = None):
    """
    Start (or resume) a Weights & Biases run in the "6GSmartRRM" project.

    Args:
        name:   display name of the run; also stored in the run config under 'name'.
        group:  W&B group the run belongs to.
        config: hyper-parameters to log. The dictionary is copied, so the
                caller's object is left untouched.
        id:     id of an existing run to resume. When given, the run must
                exist (resume="must"); when omitted, resume="allow" is used.

    Returns:
        Whatever `wandb.init` returns (the run handle).
    """
    # work on a shallow copy so the caller's dictionary is not mutated
    run_config = dict(config)
    run_config['name'] = name
    return wandb.init(
        project="6GSmartRRM",
        name   = name,
        id     = id,
        group  = group,
        config = run_config,
        resume = "allow" if id is None else "must"
    )

## Simulations and Information

Thanks to the provided scripts, we can load a set of generated simulations. They come without any solutions (not even approximate ones).

In [None]:
def _extract_zip_archives(source_dir: str, target_dir: str) -> None:
    """Extract every `.zip` archive found directly in `source_dir` into `target_dir`."""
    for zip_file in os.listdir(source_dir):
        if zip_file.endswith('.zip'):
            print(" ----> " + zip_file)
            with zipfile.ZipFile(os.path.join(source_dir, zip_file), 'r') as zip_ref:
                zip_ref.extractall(target_dir)


# Mount Google Drive (Colab) or prepare the local data folders
if COLAB:
    drive.mount('/content/drive')

    # Move Simulations to avoid cluttering the drive folder
    os.makedirs('/content/simulations', exist_ok=True)

    print("simulations folder:", list(os.listdir('/content/simulations')))
    if len(os.listdir('/content/simulations')) == 0:
      copy_tree('/content/drive/MyDrive/TFM/simulations', '/content/simulations')

    # unzip all simulations
    print("Name of the already simulated data: \n", )
    _extract_zip_archives('/content/simulations', '/content/simulations/')

    SIMULATIONS_PATH: str = "/content/simulations"
    MODELS_PATH: str = "/content/drive/MyDrive/TFM/models/"
else:
    # `makedirs` also creates the parent `./data` folder when it is missing
    os.makedirs('./data/simulations', exist_ok=True)
    _extract_zip_archives('./data', './data/simulations')
    SIMULATIONS_PATH: str = "./data/simulations"
    MODELS_PATH: str = "./models/"

# make sure the models folder exists in both environments
os.makedirs(MODELS_PATH, exist_ok=True)


In [None]:
# Load the simulated channel gain matrices and the stored SISA allocations
# from the simulations folder.
cmg   = np.load(SIMULATIONS_PATH + '/Channel_matrix_gain.npy')
sisa_alloc = np.load(SIMULATIONS_PATH + '/sisa-allocation.npy')

# get sample from all
# NOTE(review): 110_000 equals TRAIN + VALID + TESTS sample counts used by the
# split cell; keep the two in sync.
n_sample = 110_000
cmg   = cmg[:n_sample]
sisa_alloc = sisa_alloc[:n_sample].astype(int)

# actual number of samples (may be smaller than requested) and the remaining
# dimensions of the channel tensor; the last two dimensions are unpacked as N
# and an ignored value.
n_sample = B = cmg.shape[0]
K, N, _  = cmg.shape[1:]

# pretty-printer for array shapes, e.g. "110000 x  4 x 20 x 20"
shape    = lambda s: " x".join([f"{d:3d}" for d in s])
print(f"channel    matrix shape: {shape(cmg.shape)} \nallocation matrix shape: {shape(sisa_alloc.shape)}")

In [None]:
def metrics(C, A, P):
  """
  Summarise the per-sample Shannon rate of an allocation.

  `C`, `A` and (optionally) `P` are converted with `torch.tensor`; `A` is then
  expanded with `rate_metrics.onehot_allocation(A, 4, 20)` and the SINR is
  obtained from `rate_metrics.signal_interference_ratio` using the global
  `config`. `P` may be None, in which case None is forwarded.

  Returns:
      (mean, min, max) of `sum(A * log2(1 + sinr), dim=1)` as Python floats.
  """
  channel = torch.tensor(C)
  alloc   = torch.tensor(A)
  power   = None if P is None else torch.tensor(P)

  alloc = rate_metrics.onehot_allocation(alloc, 4, 20)
  sinr  = rate_metrics.signal_interference_ratio(config, channel, alloc, power)

  per_entry = alloc * torch.log2(1 + sinr)
  shannon   = per_entry.sum(dim = 1).numpy()
  return float(shannon.mean()), float(shannon.min()), float(shannon.max())

In [None]:
from torch.utils.data import DataLoader, random_split, TensorDataset

# === CONFIG ===
# The three split sizes must add up to the number of loaded samples, otherwise
# `random_split` rejects the lengths.
BATCH_SIZE: int = 512
TRAIN_SAMPLE: int = 65_000
VALID_SAMPLE: int = 25_000
TESTS_SAMPLE: int = 20_000

# === Convert numpy arrays to tensors ===
whole_data       = torch.tensor(cmg).float()                  # [total_samples, K, N, N]
sisa_data        = torch.tensor(sisa_alloc).long()            # [total_samples, N]
# constant power target: same shape as the allocation, filled with the configured transmit power
power_data       = torch.full_like(sisa_data, config.transmit_power, dtype = torch.float32)

# === Create data splits ===
# Only the indices are split (with a fixed seed for reproducibility); the
# tensors themselves are sliced below.
train_idx, valid_idx, tests_idx = random_split(
    range(len(whole_data)),
    [TRAIN_SAMPLE, VALID_SAMPLE, TESTS_SAMPLE],
    generator=torch.Generator().manual_seed(101)
)

# === Slice tensors ===
# NOTE(review): `*_idx` are `Subset` objects used directly as tensor indices;
# `*_idx.indices` would be the explicit form — confirm this indexing behaves as intended.
train_cmg   = whole_data[train_idx]
valid_cmg   = whole_data[valid_idx]
tests_cmg   = whole_data[tests_idx]

train_sisa  = sisa_data[train_idx]
valid_sisa  = sisa_data[valid_idx]
tests_sisa  = sisa_data[tests_idx]

train_power = power_data[train_idx]
valid_power = power_data[valid_idx]
tests_power = power_data[tests_idx]

# === Create PyTorch Datasets ===
# each item is a (channel gains, SISA allocation, power) triple
train_dataset = TensorDataset(train_cmg, train_sisa, train_power)
valid_dataset = TensorDataset(valid_cmg, valid_sisa, valid_power)
tests_dataset = TensorDataset(tests_cmg, tests_sisa, tests_power)

# === Wrap in DataLoaders ===
# only the training loader is shuffled
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
tests_loader = DataLoader(tests_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Joint Allocation: Graph Neural Network

This section implements the following paper:
* *Graph Neural Networks Approach for Joint Wireless Power Control and Spectrum Allocation*


## Graph Representation


In [None]:
from torch_geometric.data import Data
from typing import List, Tuple
def transform2graphs(cmg: torch.Tensor) -> List[List[Data]]:
  """
  Convert a batch of channel matrices into fully connected directed graphs.

  One graph is built per (sample, k) pair of the first two dimensions:
    * node feature of node i : the diagonal entry `cmg[b, k, i, i]`
    * edge j -> i (i != j)   : carries the off-diagonal entry `cmg[b, k, i, j]`

  Args:
      cmg: tensor of shape (B, K, N, N).

  Returns:
      A list of length B whose items are lists of K `Data` objects. Every
      tensor inside the graphs lives on the same device as `cmg`.
  """
  B, K, N, _ = cmg.shape
  dev = cmg.device

  # precompute all node features: (B, K, N, 1)
  node_feats = cmg.diagonal(dim1 = 2, dim2 = 3).unsqueeze(-1)

  # Off-diagonal positions of an N x N matrix. `nonzero` enumerates them in
  # row-major order, which is also the order boolean-mask indexing uses below,
  # so edge `e` and attribute `e` always refer to the same (row, col) pair.
  off_diag = ~torch.eye(N, dtype=torch.bool, device=dev)
  rows, cols = off_diag.nonzero(as_tuple=True)

  # shared connectivity: first row = column index, second row = row index
  edge_index = torch.stack([cols, rows], dim = 0)

  # mask self signal: (B, K, N * (N - 1), 1); also valid for an empty batch
  edge_attrs = cmg[:, :, off_diag].unsqueeze(-1)

  return [
      [
          Data(
              x=node_feats[b, k],
              edge_index=edge_index.clone(),
              edge_attr=edge_attrs[b, k]
          )
          for k in range(K)
      ]
      for b in range(B)
  ]

In [None]:
# Quick sanity check: convert the first two simulations into graphs and report
# how many samples / per-sample graphs were produced.
g = torch.tensor(cmg[[0, 1]]).float()
print("channel state matrix:", g.shape)
graphs = transform2graphs(g)
print("nº of samples:", len(graphs), "\nnª of subgraphs:", len(graphs[0]))

## Graph Neural Network

In [None]:
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing

class MLP(nn.Module):
  def __init__(
      self, input_size: int, output_size: int,
      hidden_layers: int, hidden_dim: int,
      normalize: bool = False, dropout: float | None = None
    ):
    super().__init__()

    layers = []
    dims   = [input_size] + [hidden_dim] * (hidden_layers) + [output_size]
    for _in, _out in zip(dims[:-1], dims[1:]):
      layers.append(nn.Linear(_in, _out))
      layers.append(nn.ReLU())
      if normalize: layers.append(nn.BatchNorm1d(_out))
      if isinstance(dropout, float) and 0 < dropout < 1:
        layers.append(nn.Dropout(dropout))

    avoid_layers = int(normalize) + int(isinstance(dropout, float) and 0 < dropout < 1)
    if avoid_layers > 1:
      layers = layers[:-avoid_layers]
    self.mlp = nn.Sequential(*layers)

  def forward(self, X: torch.Tensor) -> torch.Tensor:
    return self.mlp(X)


class GNNExtractor(MessagePassing):
  """
  Message-passing layer with max aggregation.

  Messages are produced by an MLP applied to the concatenation of the
  neighbour features and the edge attributes; the aggregated message is then
  concatenated with the node's own features and passed through a second MLP.
  Both MLPs output `hidden_dim` features.
  """
  def __init__(
      self, node_dim: int, edge_dim: int,
      hidden_layers: int, hidden_dim: int,
      normalize: bool = False, dropout: float | None = None
    ):
    super().__init__(aggr='max')
    mlp_options = dict(normalize = normalize, dropout = dropout)
    # network applied to [neighbour features | edge attributes]
    self.mlp_msg = MLP(node_dim + edge_dim, hidden_dim, hidden_layers, hidden_dim, **mlp_options)
    # network applied to [own features | aggregated message]
    self.mlp_upd = MLP(node_dim + hidden_dim, hidden_dim, hidden_layers, hidden_dim, **mlp_options)

  def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
      """Run one round of message passing and return the updated node features."""
      updated = self.propagate(edge_index, x=x, edge_attr=edge_attr)
      return updated

  def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
      """Message for each edge, built from the neighbour features and the edge attributes."""
      msg_input = torch.cat([x_j, edge_attr], dim=1)
      return self.mlp_msg(msg_input)

  def update(self, aggr_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
      """Combine every node's own features with its aggregated message."""
      upd_input = torch.cat([x, aggr_out], dim=1)
      return self.mlp_upd(upd_input)

In [None]:
from torch_geometric.data import Batch

class GNNEncoder(nn.Module):
  """
  Stack of `GNNExtractor` layers applied to all graphs of a batch at once.

  The first layer consumes `node_dim` features, the remaining ones consume
  `hidden_dim` features; every layer outputs `hidden_dim` features.
  """
  def __init__(
    self,node_dim: int, edge_dim: int,
    hidden_dim: int, hidden_layers: int, num_layers: int,
    normalize: bool = False, dropout: float | None = None,
  ):
    super().__init__()
    stack = []
    for depth in range(num_layers):
      in_features = hidden_dim if depth > 0 else node_dim
      stack.append(
        GNNExtractor(
          in_features, edge_dim, hidden_layers, hidden_dim,
          normalize=normalize, dropout=dropout,
        )
      )
    self.layers = nn.ModuleList(stack)

  def forward(self, graphs: List[List[Data]]) -> torch.Tensor:
    """
    Encode a nested list of graphs (outer: samples, inner: graphs per sample).

    Returns a tensor of shape (B, N, K, D) where B = len(graphs),
    K = len(graphs[0]), D is the feature size produced by the last layer and
    N is derived from the total node count (assumes every graph has the same
    number of nodes).
    """
    # flatten the nested list and merge everything into one big disjoint graph
    merged = Batch.from_data_list([graph for sample in graphs for graph in sample])

    hidden = merged.x
    for layer in self.layers:
        hidden = layer(hidden, merged.edge_index, merged.edge_attr)

    n_samples = len(graphs)
    n_per_sample = len(graphs[0])
    n_nodes = hidden.size(0) // (n_samples * n_per_sample)
    n_feats = hidden.size(-1)

    # (B*K*N, D) -> (B, K, N, D) -> (B, N, K, D)
    return hidden.view(n_samples, n_per_sample, n_nodes, n_feats).transpose(1, 2)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax

class EGATLayer(MessagePassing):
    """
    Multi-head graph attention layer whose attention scores also depend on the
    edge attributes.

    Node and edge features are linearly projected to `heads * out_dim`
    features. For every edge an attention logit is computed from the projected
    source node, target node and edge features; logits are normalised with a
    softmax over the edges grouped by `index`, and the projected source
    features are summed with those weights.

    Args:
        node_dim:       number of input node features.
        edge_dim:       number of input edge features.
        out_dim:        number of output features per head.
        heads:          number of attention heads.
        dropout:        dropout probability applied to the attention weights.
        negative_slope: slope of the LeakyReLU applied to the attention logits.

    Output: tensor with `heads * out_dim` features per node.
    """
    def __init__(
        self,
        node_dim: int,
        edge_dim: int,
        out_dim: int,
        heads: int = 1,
        dropout: float = 0.0,
        negative_slope: float = 0.2,
    ):
        # `node_dim=0` below is MessagePassing's own option: the tensor axis
        # along which nodes are indexed during propagation.
        super().__init__(aggr='add', node_dim=0)
        # The input feature size is stored as `in_dim`. It must NOT be stored
        # as `self.node_dim`: that attribute belongs to MessagePassing and
        # overwriting it makes `propagate` gather along the wrong axis.
        self.in_dim = node_dim
        self.edge_dim = edge_dim
        self.out_dim = out_dim
        self.heads = heads
        self.neg_slope = negative_slope
        self.dropout = dropout

        # Linear transformations for nodes and edges
        self.lin_node = nn.Linear(node_dim, heads * out_dim, bias=False)
        self.lin_edge = nn.Linear(edge_dim, heads * out_dim, bias=False)

        # Attention logit parameters (for nodes + edges)
        self.attn_src = nn.Parameter(torch.empty(1, heads, out_dim))
        self.attn_dst = nn.Parameter(torch.empty(1, heads, out_dim))
        self.attn_edge = nn.Parameter(torch.empty(1, heads, out_dim))

        # Final bias added after the heads are concatenated
        self.bias = nn.Parameter(torch.empty(heads * out_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the projections and attention vectors; zero the bias."""
        nn.init.xavier_uniform_(self.lin_node.weight)
        nn.init.xavier_uniform_(self.lin_edge.weight)
        nn.init.xavier_uniform_(self.attn_src)
        nn.init.xavier_uniform_(self.attn_dst)
        nn.init.xavier_uniform_(self.attn_edge)
        nn.init.zeros_(self.bias)

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> Tensor:
        """Return the attended node features, shape (num_nodes, heads * out_dim)."""
        # Project nodes and edges into multi-head space: (-1, heads, out_dim)
        x_proj = self.lin_node(x).view(-1, self.heads, self.out_dim)
        edge_proj = self.lin_edge(edge_attr).view(-1, self.heads, self.out_dim)

        # Propagate messages (calls message() and aggregate())
        out = self.propagate(
            edge_index,
            x=x_proj,
            edge_attr=edge_proj,
        )

        # Combine heads and add bias
        out = out.view(-1, self.heads * self.out_dim) + self.bias
        return out

    def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor, index: Tensor) -> Tensor:
        """Weight the projected source features by the per-edge attention."""
        # Compute attention scores using both node and edge features
        alpha_src = (x_j * self.attn_src).sum(-1)
        alpha_dst = (x_i * self.attn_dst).sum(-1)
        alpha_edge = (edge_attr * self.attn_edge).sum(-1)

        alpha = alpha_src + alpha_dst + alpha_edge
        alpha = F.leaky_relu(alpha, self.neg_slope)
        alpha = softmax(alpha, index)  # Normalize attention weights

        # Apply dropout to attention weights
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # Weight messages by attention
        return x_j * alpha.unsqueeze(-1)

In [None]:
class GNNEncoder(nn.Module):
    """
    Stack of `EGATLayer` attention layers applied to all graphs of a batch at once.

    This definition replaces any `GNNEncoder` defined earlier in the kernel,
    so it accepts the same keyword arguments as the MLP-based encoder.

    Args:
        node_dim:      number of input node features.
        edge_dim:      number of edge features.
        hidden_dim:    output features per attention head of every layer.
        hidden_layers: accepted for signature compatibility; not used by the
                       attention layers.
        num_layers:    number of stacked attention layers.
        heads:         attention heads per layer (layer output has
                       `hidden_dim * heads` features).
        dropout:       attention dropout probability; None or 0 disables it.
        normalize:     accepted for signature compatibility with the MLP-based
                       encoder (configurations that pass it would otherwise
                       fail with a TypeError); the attention layers have no
                       normalisation, so the value is ignored.
    """
    def __init__(
        self,
        node_dim: int,
        edge_dim: int,
        hidden_dim: int,
        hidden_layers: int,
        num_layers: int,
        heads: int = 1,
        dropout: float | None = None,
        normalize: bool = False,
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            EGATLayer(
                node_dim if i == 0 else hidden_dim * heads,
                edge_dim,
                hidden_dim,
                heads=heads,
                dropout=dropout if dropout else 0.0,
            )
            for i in range(num_layers)
        ])

    def forward(self, graphs: List[List[Data]]) -> torch.Tensor:
        """
        Encode a nested list of graphs (outer: samples, inner: graphs per sample).

        Returns a tensor of shape (B, N, K, D) with B = len(graphs),
        K = len(graphs[0]) and D the feature size of the last layer; N is
        derived from the total node count (assumes equally sized graphs).
        """
        # merge every graph into a single disjoint batch graph
        data_list = [g for batch in graphs for g in batch]
        batch_obj = Batch.from_data_list(data_list)

        x = batch_obj.x
        for layer in self.layers:
            x = layer(x, batch_obj.edge_index, batch_obj.edge_attr)

        B = len(graphs)
        K = len(graphs[0])
        N = x.size(0) // (B * K)
        D = x.size(-1)

        # (B*K*N, D) -> (B, K, N, D) -> (B, N, K, D)
        return x.view(B, K, N, D).transpose(1, 2)

In [None]:
class CNNBlock(nn.Module):
    """
    Sequence of `Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d` stages.

    `mid_channels` lists the output channels of every stage; the input of the
    first stage has `in_channels` channels.
    """
    def __init__(self, in_channels: int, mid_channels: List[int]):
        super().__init__()
        widths = [in_channels, *mid_channels]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(c_out),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Apply the convolutions to a channels-last input.

        z: (B, N, K, D) -> moved to channels-first (B, D, N, K) -> convolved;
        the channel axis is squeezed away when the last stage has one channel,
        giving (B, N, K).
        """
        channels_first = z.permute(0, 3, 1, 2)
        return self.conv(channels_first).squeeze(1)

class AllocationHead(nn.Module):
    """
    Output head producing a sub-band distribution and a transmit power per node.

    The embedding is reduced to one score per (node, sub-band) pair by a
    `CNNBlock` with channels [8, 4, 1]. `K` is accepted but not used here.
    """
    def __init__(self, embed_dim: int, K: int, p_max: float):
        super().__init__()
        self.cnn = CNNBlock(embed_dim, [8, 4, 1])
        self.p_max = p_max

    def forward(self, Z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Z: (B, N, K, D).

        Returns:
            channel_probs: (B, K, N), softmax over the K axis.
            power:         (B, N), mean sigmoid of the scores scaled by `p_max`.
        """
        scores = self.cnn(Z)                                    # (B, N, K)
        # normalise over the sub-band axis, then put that axis second: (B, K, N)
        channel_probs = F.softmax(scores, dim=-1).permute(0, 2, 1)
        power = torch.sigmoid(scores).mean(dim=-1) * self.p_max  # (B, N)
        return channel_probs, power


class GNNJointResourceModel(nn.Module):
    """
    Full model: a `GNNEncoder` followed by an `AllocationHead`.

    `gnn_args` and `head_args` are forwarded as keyword arguments to the
    constructors of the encoder and of the head respectively.
    """
    def __init__(
        self, gnn_args: dict, head_args: dict,
    ):
        super().__init__()
        self.encoder = GNNEncoder(**gnn_args)
        self.head = AllocationHead(**head_args)

    def forward(self, graphs: List[List[Data]]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode the graphs and return the head's (channel probabilities, power)."""
        embedding = self.encoder(graphs)
        outputs = self.head(embedding)
        return outputs

In [None]:
# Small configuration used only for the shape check in the next cell.
# `embed_dim` of the head must match the feature size produced by the encoder.
HIDDEN_DIM: int = 10
encoder_args = dict(
  node_dim=1, edge_dim=1, hidden_dim=HIDDEN_DIM, num_layers=3,
  hidden_layers=1, dropout=None
)
alloc_args = dict(embed_dim=HIDDEN_DIM, K=4, p_max=1e-3)


In [None]:
# Shape check: run an untrained model on the first two simulations.
g = torch.tensor(cmg[[0, 1]]).float()
print("channel state matrix:", g.shape)
graphs = transform2graphs(g)
print("nº of samples:", len(graphs), "\nnª of subgraphs:", len(graphs[0]))
model = GNNJointResourceModel(encoder_args, alloc_args)
channel, power = model(graphs)
print("power alloc:", power.shape)
print("channel alloc:", channel.shape)

# outputs for the first node of the first sample
print("graph 1 & node 1:")
print("power: ", power[0, 0])
print("channel: ", channel[0, :, 0])

## Training Loop

In [None]:
from IPython.display import clear_output
def real_time_plot(*metrics):
    """
    Redraw the live training figure.

    The positional arguments are split in two halves: the first half holds the
    loss histories and the second half the metric histories. Within each half
    the series are labelled 'training' and 'validation' in that order, so at
    most two series per half are supported.
    """
    names = ['training', 'validation']
    assert len(metrics) % 2 == 0, "An even number of metric series is required"
    clear_output(wait=True)  # Clear the previous plot

    half = len(metrics) // 2
    fig, ax = plt.subplots(1, 2, figsize=(14, 4))  # Two subplots, side by side
    # Plot loss
    for i, loss in enumerate(metrics[:half]):
      ax[0].plot(loss, label = f"loss: {names[i]}")
    ax[0].set_title('Real-Time Loss')
    ax[0].set_xlabel('Epochs')
    ax[0].set_ylabel('Loss')
    ax[0].legend()

    # Plot metric (bit rate)
    for i, metric in enumerate(metrics[half:]):
      ax[1].plot(metric, label = f"rate: {names[i]}")
    ax[1].set_title('Real-Time Metric')
    ax[1].set_xlabel('Epochs')
    ax[1].set_ylabel('Bit Rate (Mbps)')
    ax[1].legend()

    plt.tight_layout()  # Adjust layout to avoid overlap
    plt.show()
    # the function is called once per epoch: release the figure explicitly
    plt.close(fig)

In [None]:
def supervised_loss_from_targets(
    channel_probs: torch.Tensor,      # [B, N, K]
    sisa_alloc: torch.Tensor,         # [B, N] → integers in [0, K-1]
    pred_power: torch.Tensor,         # [B, N]
    power_alloc_max: torch.Tensor     # [B, N]
) -> torch.Tensor:
    """
    Supervised loss against a reference allocation.

    Sum of
      * the negative log-likelihood of the reference sub-band under
        `channel_probs` (classes on the last dimension), and
      * the mean squared error between predicted and reference power.

    `channel_probs` must already be probabilities. `nn.CrossEntropyLoss`
    expects unnormalised logits and would apply a second softmax, so the
    log-probabilities are fed to the NLL loss instead.
    """
    # clamp before the log so that exact zeros do not produce -inf
    eps = torch.finfo(channel_probs.dtype).tiny
    log_probs = torch.log(channel_probs.clamp_min(eps))

    # reshape to 2D for batch processing (`reshape` also accepts permuted,
    # non-contiguous inputs, unlike `view`)
    L1 = F.nll_loss(log_probs.reshape(-1, channel_probs.size(-1)), sisa_alloc.reshape(-1))
    L2 = F.mse_loss(pred_power, power_alloc_max)
    return L1 + L2

def min_approx(x: torch.Tensor, p: float = 1e2):
    """
    Differentiable Approximation of Minimum Function. This function approximates
    the value of min(x) along dimension 1:

        min(x) ≈ -(1 / p) * log( mean( exp(-p * x) ) )

    The expression is evaluated with `logsumexp`, which is mathematically
    identical but does not underflow: with the naive form, `exp(-p * x)`
    becomes 0 for large `p * x` and the result degenerates to +inf.

      # based on fC https://mathoverflow.net/questions/35191/a-differentiable-approximation-to-the-minimum-function

    Args:
        x: tensor reduced along dimension 1.
        p: sharpness; larger values approximate the true minimum more closely.
    """
    scaled = -p * x
    # log(mean(exp(s))) == logsumexp(s) - log(n)
    log_n = torch.log(torch.tensor(float(x.size(1)), dtype=scaled.dtype, device=scaled.device))
    return -(torch.logsumexp(scaled, dim = 1) - log_n) / p

def base_loss( rate: torch.Tensor, mode: str = 'sum', p: int = 10 ) -> torch.Tensor:
    """
    Negative rate objective, reduced along dimension 1.

    Args:
        rate: tensor of rates.
        mode: 'sum', 'mean' or 'min' (smooth minimum via `min_approx`).
        p:    sharpness forwarded to `min_approx`; only used when mode == 'min'.

    Returns:
        The negated reduction, one value per row of `rate`.

    Raises:
        ValueError: if `mode` is not one of the supported reductions.
    """
    if mode == 'sum':
      loss_rate = torch.sum(rate, dim = 1)
    elif mode == 'min':
      loss_rate = min_approx(rate, p)
    elif mode == 'mean':
      loss_rate = torch.mean(rate, dim = 1)
    else:
      # fail with a clear message instead of an unbound local variable
      raise ValueError(f"unknown mode {mode!r}; expected 'sum', 'min' or 'mean'")
    return - loss_rate

def requirement_loss( rate: torch.Tensor, req: float = 1 ):
  """
  Mean shortfall with respect to a required rate.

  For every entry the shortfall is `max(req - rate, 0)`; the result is the
  mean shortfall along dimension 1.
  """
  # in ideal conditions, this is equal to the spectral efficiency as all
  # subbands have the bandwidth.
  shortfall = F.relu(req - rate)
  return shortfall.mean(dim = 1)

def joint_loss(
    config: SimConfig, C: torch.Tensor, A: torch.Tensor, P: torch.Tensor,
    req: float = 1., violation_penality: float = 1.
):
  """
  Unsupervised objective for the joint sub-band / power allocation.

  The allocation is passed through `rate_metrics.onehot_allocation(A, 4, 20)`,
  the SINR is computed by `rate_metrics.signal_interference_ratio` and the
  Shannon rate is `sum(A * log2(1 + sinr), dim=1)`.

  Returns:
      (total, base, reqs) where `base` is the negated smooth minimum rate
      (`base_loss` with mode "min", p = 10), `reqs` is the mean shortfall with
      respect to `req`, and `total = base + violation_penality * reqs`.
  """
  alloc = rate_metrics.onehot_allocation(A, 4, 20)
  sinr  = rate_metrics.signal_interference_ratio(config, C, alloc, P)
  rate  = (alloc * torch.log2(1 + sinr)).sum(dim = 1)

  base  = base_loss(rate, mode = "min", p = 10)
  reqs  = requirement_loss(rate, req = req)
  total = base + violation_penality * reqs
  return total, base, reqs


In [None]:
def binarization_error(alloc: torch.Tensor) -> torch.Tensor:
    """
    Mean absolute distance between `alloc` and its rounded version.

    Returns a 0-dimensional tensor (callers use `.item()` to obtain a float);
    0 means every entry is already an integer.
    """
    return (alloc - torch.round(alloc)).abs().mean()

def update_metrics(metrics, A, C, P, config, req):
    """
    Add the evaluation metrics of one batch to the running totals in `metrics`.

    `metrics` is a mapping of accumulated sums (updated in place and returned).
    The accumulated entries are the batch means of: bit rate (divided by 1e6),
    Jain fairness, spectral efficiency, proportional loss factor and the
    fraction of entries whose Shannon rate is at least `req`.
    """
    onehot = rate_metrics.onehot_allocation(A, 4, 20)
    sinr   = rate_metrics.signal_interference_ratio(config, C, onehot, P, False)
    rate   = rate_metrics.bit_rate(config, sinr, onehot)

    fairness = rate_metrics.jain_fairness(rate)
    spectral = rate_metrics.spectral_efficency(config, rate)
    plf      = rate_metrics.proportional_loss_factor(config, C, onehot, P)

    shannon   = (onehot * torch.log2(1 + sinr)).sum(dim = 1)
    satisfied = (shannon >= req).float().mean(dim = 1)

    batch_values = [
        ('bit-rate', rate.mean().item() / 1e6),
        ('jain-fairness', fairness.mean().item()),
        ('spectral-efficency', spectral.mean().item()),
        ('proportional-loss', plf.mean().item()),
        ('over-requirement', satisfied.mean().item()),
    ]
    for key, value in batch_values:
        metrics[key] += value
    return metrics

In [None]:
MAX_EPOCH : int = 7
LR: float  = 1e-1

# under ideal conditions, the sisa ideal shannon rate is around 4.
REQ: float      = 3.
VIOLATION_PENALITY: float = 10.

# values describing the optimisation setup
learning_config = {
    'loss': 'pure-min-rate',
    'max-epoch': MAX_EPOCH,
    'batch-size': BATCH_SIZE,
    'learning-rate': LR,
    'desired-min-rate' : REQ,
    'qos-penality': VIOLATION_PENALITY,
    'train-valid-split' : f"{TRAIN_SAMPLE}-{VALID_SAMPLE}"
}

# training config
HIDDEN_DIM: int = 32
encoder_args = dict(
  node_dim=1, edge_dim=1, hidden_dim=HIDDEN_DIM, num_layers=3,
  hidden_layers=2, normalize=False, dropout=None
)
alloc_args = dict(embed_dim=HIDDEN_DIM, K=4, p_max=config.transmit_power)


name  = "p1-joint-000-v1"
# only the model arguments are sent to W&B as the run configuration
training_config = {}
training_config["encoder"] = encoder_args
training_config["alloc"] = alloc_args

# Best effort: close a run left open by a previous execution of this cell.
# `Exception` (not a bare except) keeps KeyboardInterrupt / SystemExit working.
try: wandb.finish(quiet = True)
except Exception: pass
run = setup_wandb(name, 'rate-confirming', training_config, id = None)
print("run config:", run.config)

model = GNNJointResourceModel(encoder_args, alloc_args).to(device)
optimizer = optim.Adam(model.parameters(), LR, weight_decay=1e-5)
scheduler = lrs.CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-4)

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# Per-epoch histories used by the live plot.
train_loss, valid_loss, train_rate, valid_rate = [], [], [], []
for epoch in trange(MAX_EPOCH, desc = "training epoch", unit = "epoch"):
    real_time_plot(train_loss, valid_loss, train_rate, valid_rate)

    # training step
    model.train()
    training_loss = 0.
    train_binary_loss = 0.
    train_rate_loss = 0.
    train_reqs_loss = 0.

    desc = f'training step (epoch: {epoch:03d}):'
    training_metrics = defaultdict(lambda : 0)
    # The reference allocation and power are not needed by the unsupervised
    # loss; underscore names avoid overwriting the global `sisa_alloc` array.
    for sample, _sisa_target, _base_power in tqdm(train_loader, desc = desc, unit = 'batch', total = len(train_loader), leave=False):
        optimizer.zero_grad()

        sample = sample.to(device)                          # [B, K, N, N]
        # Vectorized graph processing
        graphs = transform2graphs(sample)
        graphs = [
            [g.to(device, non_blocking=True) for g in batch]
            for batch in graphs
        ]

        alloc_prob, power_alloc = model(graphs)
        loss, sub_loss, req_loss = joint_loss(
          config = config, C = sample, A = alloc_prob, P = power_alloc,
          req = REQ, violation_penality = VIOLATION_PENALITY
        )
        loss = loss.mean()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        training_loss += loss.item()
        train_binary_loss += binarization_error(alloc_prob).item()
        train_rate_loss += sub_loss.mean().item()
        train_reqs_loss += req_loss.mean().item()
        # Update metrics (ensure no GPU retention)
        training_metrics = update_metrics(training_metrics, alloc_prob.detach(), sample, power_alloc.detach(), config, REQ)

        # Cleanup
        del sample, graphs, alloc_prob, power_alloc, loss, sub_loss, req_loss
        torch.cuda.empty_cache()

    training_loss = training_loss / len(train_loader)
    train_binary_loss = train_binary_loss / len(train_loader)
    train_rate_loss   = train_rate_loss / len(train_loader)
    train_reqs_loss   = train_reqs_loss / len(train_loader)
    training_metrics = { 'train-' + key: val / len(train_loader) for key, val in training_metrics.items()}

    # validation step
    model.eval()
    validation_loss = 0.
    valid_binary_loss = 0.
    valid_rate_loss = 0.
    valid_reqs_loss = 0.

    desc = f'validation step (epocch: {epoch:03d}):'
    validation_metrics = defaultdict(lambda : 0)
    # no gradients are needed for validation: skip building the autograd graph
    with torch.no_grad():
        for sample, _sisa_target, _base_power in tqdm(valid_loader, desc = desc, unit = 'batch', total = len(valid_loader), leave=False):
            sample = sample.to(device)                          # [B, K, N, N]
            graphs = transform2graphs(sample)
            graphs = [
                [g.to(device, non_blocking=True) for g in batch]
                for batch in graphs
            ]

            alloc_prob, power_alloc = model(graphs)
            loss, sub_loss, req_loss = joint_loss(
              config = config, C = sample, A = alloc_prob, P = power_alloc,
              req = REQ, violation_penality = VIOLATION_PENALITY
            )
            loss = loss.mean()

            validation_loss   += loss.item()
            valid_binary_loss += binarization_error(alloc_prob).item()
            valid_rate_loss += sub_loss.mean().item()
            valid_reqs_loss += req_loss.mean().item()
            # Update metrics (ensure no GPU retention)
            validation_metrics = update_metrics(
                validation_metrics, alloc_prob.detach(), sample, power_alloc.detach(),
                config, REQ
            )

            # Cleanup
            del sample, graphs, alloc_prob, power_alloc, loss, sub_loss, req_loss
            torch.cuda.empty_cache()

    validation_loss = validation_loss / len(valid_loader)
    valid_binary_loss = valid_binary_loss / len(valid_loader)
    valid_rate_loss   = valid_rate_loss / len(valid_loader)
    valid_reqs_loss   = valid_reqs_loss / len(valid_loader)
    validation_metrics = { 'valid-' + key: val / len(valid_loader) for key, val in validation_metrics.items()}

    # advance the learning-rate schedule once per epoch; without this call the
    # scheduler has no effect and the learning rate stays at its initial value
    scheduler.step()

    logged_values = {
      'train-loss': training_loss, 'valid-loss': validation_loss,
      'train-base-loss': train_rate_loss, 'valid-base-loss': valid_rate_loss,
      'train-violation-loss': train_reqs_loss, 'valid-violation-loss': valid_reqs_loss,
      'train-binary-loss': train_binary_loss, 'valid-binary-loss': valid_binary_loss
    }

    logged_values.update(training_metrics)
    logged_values.update(validation_metrics)

    train_loss.append(training_loss)
    valid_loss.append(validation_loss)
    train_rate.append(training_metrics['train-bit-rate'])
    valid_rate.append(validation_metrics['valid-bit-rate'])
    wandb.log(logged_values)

# the plot is refreshed at the start of every epoch, so draw once more to
# include the results of the last one
real_time_plot(train_loss, valid_loss, train_rate, valid_rate)
wandb.finish()

In [None]:
# Manual cleanup, useful after an interrupted training run.
wandb.finish()
# The training loop deletes these names at the end of every batch, so after a
# complete run they no longer exist; remove only the ones that are still there
# instead of failing with a NameError.
for _leftover in ("sample", "graphs", "alloc_prob", "power_alloc", "loss", "sub_loss", "req_loss"):
    globals().pop(_leftover, None)
torch.cuda.empty_cache()

In [None]:
# Bit rates (Mbps) on the test split: `orates` for the trained model,
# `srates` for the stored SISA allocation at the reference power.
orates = []
srates = []

model.eval()
# inference only: no autograd graph is needed
with torch.no_grad():
    for sample, sisa_target, ploc in tqdm(tests_loader, desc = 'testing step:', unit = 'batch', total = len(tests_loader), leave=False):
        sample = sample.to(device)                          # [B, K, N, N]
        sisa_target = sisa_target.to(device)                # [B, N]
        ploc = ploc.to(device)                              # [B, N]
        graphs = transform2graphs(sample)
        graphs = [
            [g.to(device, non_blocking=True) for g in batch]
            for batch in graphs
        ]
        alloc_prob, P = model(graphs)
        # hard decision: most likely sub-band for every node
        A = torch.argmax(alloc_prob, dim = 1)

        A    = rate_metrics.onehot_allocation(A, 4, 20)
        sinr = rate_metrics.signal_interference_ratio(config, sample, A.detach(), P.detach())
        rate = rate_metrics.bit_rate(config, sinr, A).cpu().numpy()
        orates.append(rate / 1e6)

        # Baseline: evaluate the SISA allocation exactly like the model's one,
        # i.e. with the one-hot allocation rather than the raw integer labels.
        A    = rate_metrics.onehot_allocation(sisa_target, 4, 20)
        sinr = rate_metrics.signal_interference_ratio(config, sample, A, ploc)
        rate = rate_metrics.bit_rate(config, sinr, A).cpu().numpy()
        srates.append(rate / 1e6)

        del sample, graphs, alloc_prob, P, A, sisa_target, ploc
        torch.cuda.empty_cache()

# `np.concatenate` is available in every NumPy version (`np.concat` only in >= 2.0)
orates = np.concatenate(orates, axis = 0).flatten()
srates = np.concatenate(srates, axis = 0).flatten()

In [None]:
from g6smart import evaluation as evals
from scipy.interpolate import interp1d

def add_cdf_plot(sample, label, color = None, ax = None, show_legend = True):
    """
    Draw the empirical CDF of `sample` plus vertical markers at a few quantiles.

    Args:
        sample:      values whose CDF is drawn (via `evals.get_cdf`).
        label:       legend label of the CDF curve.
        color:       colour of the CDF curve.
        ax:          target axes; the current axes are used when omitted.
        show_legend: when False the quantile markers get no legend entry.
    """
    axis = ax or plt.gca()

    # CDF curve
    xs, ys = evals.get_cdf(sample)
    axis.plot(xs, ys, label = label, color = color)

    # interpolated CDF, clamped to 0 / 1 outside the observed range
    cdf_at = interp1d(xs, ys, bounds_error=False, fill_value=(0, 1))

    # quantile markers: a vertical segment from the x axis up to the curve
    levels = [ 0.0005, 0.05, 0.25, 0.50]
    line_styles = ['-', '-.', '--', ":"]
    for line_style, level, value in zip(line_styles, levels, np.quantile(sample, levels)):
        marker_label = f'Q{level}' if show_legend else None
        axis.plot(
            [value, value], [0, float(cdf_at(value))],
            linestyle=line_style, alpha=0.6, color = "black",
            label=marker_label
        )

# Compare the bit-rate CDF of the trained model with the SISA baseline.
plt.figure(figsize = (5.5, 5))

# NOTE(review): `rates` is not used anywhere in this cell — leftover placeholder?
rates = np.random.random(size = 500) * 40
# quantile markers get legend entries only once (first call)
add_cdf_plot(orates, "Ours", "red", show_legend=True)
add_cdf_plot(srates, "SISA", "blue", show_legend=False)

# Sort legend: group quantiles at the end
handles, labels = plt.gca().get_legend_handles_labels()
normal = [(h, l) for h, l in zip(handles, labels) if not l.startswith("Q")]
quantiles = [(h, l) for h, l in zip(handles, labels) if l.startswith("Q")]
handles, labels = zip(*(normal + quantiles))
plt.legend(handles, labels)
plt.ylim(bottom=0)
plt.xlim(left=0)
plt.xlabel("Bit Rate (Mbps)")
plt.ylabel("CDF")
plt.title("Bit Rate's CDF")

plt.tight_layout()
plt.show()

In [None]:
# Store the test bit rates: on Google Drive when running in Colab, in a local
# folder otherwise (the Drive path does not exist outside Colab).
BASE = "/content/drive/MyDrive/TFM/results" if COLAB else "./results"
os.makedirs(BASE, exist_ok=True)  # np.save does not create missing folders
np.save(BASE + "/test-ours.npy", orates)
np.save(BASE + "/test-sisa.npy", srates)