In [None]:
# Set COLAB to True when running on Google Colab. In that case the project
# repository is cloned, its content is moved to /content/, the package is
# installed in editable mode and Google Drive is mounted.
COLAB: bool = False
if COLAB:
  # fetch the sources and move them next to the notebook
  !git clone https://github.com/RubenCid35/6GSmartRRM
  !mv 6GSmartRRM/* /content/
  # editable install of the package found in the current directory
  # (assumes the working directory is /content — TODO confirm)
  !pip install -e .
  # mount Google Drive at /content/drive, remounting if it is already mounted
  from google.colab import drive
  drive.mount('/content/drive', force_remount=True)

In [None]:
# vast.ai sanity check: print the 'Power Limit' lines reported by nvidia-smi
# so that a GPU instance with invalid specs can be spotted before training
!nvidia-smi -q | grep 'Power Limit'

In [None]:
# reload edited project modules automatically (autoreload mode 2)
%load_ext autoreload
%autoreload 2
# quiet install of the third-party packages that are imported below
# (wandb is installed, but its import is commented out in the next cell)
!pip install -q wandb matplotlib seaborn
!pip install -q torch_geometric

In [None]:
# simple data manipulation
import numpy  as np

# deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
import torch.cuda.amp as amp # For Automatic Mixed Precision

from functools import partial

from torch.utils.data import Dataset
from torch_geometric.data import Data, Dataset as GeoDataset, Batch
from torch_geometric.loader import DataLoader # Import PyG DataLoader
import torch_geometric.nn as gnn
from torch_geometric.utils import to_dense_batch

# run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from collections import defaultdict

# progress bar
from   tqdm.notebook import tqdm, trange
# import wandb

# silence every Python warning (not only deprecation warnings)
import warnings
warnings.simplefilter('ignore')

# visualization of results
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from   matplotlib.ticker import MaxNLocator
import seaborn           as sns

# When running outside Colab from a folder that does not contain
# ./data/simulations, move one directory up and report the new location.
import os
if not (COLAB or os.path.exists('./data/simulations')):
    os.chdir('..')
    print("current path: ", os.getcwd())

# Simulation Settings
from g6smart.sim_config import SimConfig
from g6smart.evaluation import rate as rate
from g6smart.evaluation.utils import get_cdf
from g6smart.evaluation import rate_torch as rate_metrics
from g6smart.proposals  import loss as loss_funcs, rate_cnn, rate_dnn
from g6smart.data import load_data, create_datasets, download_simulations_data
# from g6smart.track import setup_wandb, real_time_plot
from g6smart.train import train_model

# simulation settings shared by the loss and metric helpers used below
# NOTE(review): the meaning of the argument 0 is not visible here (presumably
# a seed or scenario index) — confirm against g6smart.sim_config
config = SimConfig(0)
print(config)

## Datasets

In [None]:
# resolve where the simulation data and the saved models live; the helper
# receives the COLAB flag (presumably to pick Drive vs. local paths — confirm)
simulation_path, models_path = download_simulations_data(COLAB)
print("simulations data paths:", simulation_path)
print("saved model location  :", models_path)


In [None]:
# load 120_000 samples and split them with a fixed seed; the active split
# sizes (70_000 + 30_000 + 20_000) add up to the number of loaded samples and
# are unpacked as train / validation / test datasets.
# The commented lines are alternative split configurations.
csi_data = load_data(simulation_path, n_samples= 120_000)
train_dataset, valid_dataset, tests_dataset = create_datasets(
#    csi_data, split_sizes=[130_000, 60_000, 10_000], seed=101
    csi_data, split_sizes=[ 70_000, 30_000, 20_000], seed=101
#    csi_data, split_sizes=[ 7_000, 3_000, 2_000], seed=101
)

## Graph Transformation

In [None]:
class CustomGraphDataset(GeoDataset):
    """Expose a tensor dataset of CSI samples as PyTorch Geometric graphs.

    Every item is a pair ``(csi_tensor, data)``: the raw CSI tensor read from
    ``torch_dataset[idx][0]`` and a :class:`Data` graph built from it.  The
    tensor is unpacked as ``(K, N, _)`` and indexed as ``csi[:, i, j]``
    (presumably K sub-bands and N subnetworks — confirm against the loader).

    The graph has N nodes and one directed edge for every ordered pair
    ``(i, j)`` with ``i != j``:

    * ``x``         -- ``(N, K)`` uniform allocation, every entry ``1 / K``;
    * ``v``         -- ``(N, K)`` diagonal entries ``csi[:, i, i]``;
    * ``edge_attr`` -- ``(E, K)`` entries ``csi[:, i, j]`` of each edge.

    Args:
        torch_dataset: map-style dataset whose items hold the CSI tensor at
            position 0.
        remove_edges: number of directed edges dropped from each graph,
            starting with the ones whose mean attribute is smallest.
            ``None`` or a value ``<= 0`` keeps every edge; a value greater or
            equal to the number of directed edges yields a graph without
            edges.
    """

    def __init__(self, torch_dataset: Dataset, remove_edges: int | None = None):
        super().__init__(None, None)
        self.torch_dataset = torch_dataset
        self.remove_edges  = remove_edges

    @property
    def raw_file_names(self): return []

    @property
    def processed_file_names(self): return []

    def len(self) -> int: return len(self.torch_dataset)
    def __len__(self) -> int: return len(self.torch_dataset)

    def get(self, idx: int) -> tuple[torch.Tensor, Data]:
        """Build the graph of sample ``idx``.

        Returns:
            ``(csi_tensor, data)`` — the untouched CSI tensor and its graph.
        """
        # raw data
        csi_tensor = self.torch_dataset[idx][0]
        K, N, _ = csi_tensor.shape

        # node features: diagonal entries csi[:, i, i], one row per node
        diagonal    = torch.arange(N)
        self_gain   = csi_tensor[:, diagonal, diagonal].permute(1, 0)

        # optimization features: start from a uniform allocation
        band_alloc  = torch.ones((N, K)) / K

        # edge features: both directions of every unordered node pair
        row, col    = torch.combinations(torch.arange(N), 2).t()
        edge_index  = torch.cat(
            [torch.stack([row, col]), torch.stack([col, row])],
            dim = 1
        )
        edge_attr1   = csi_tensor[:, row, col].permute(1, 0)
        edge_attr2   = csi_tensor[:, col, row].permute(1, 0)
        edge_attr    = torch.cat([edge_attr1, edge_attr2], dim = 0)

        if self.remove_edges is not None and self.remove_edges > 0:
            # The removal works on the directed edge list, so the number of
            # available edges is the length of that list (twice the number
            # of unordered pairs).
            num_edges = edge_attr.shape[0]
            if self.remove_edges >= num_edges:
                # If removing all or more edges than available, return empty edges
                edge_index = torch.empty((2, 0), dtype=torch.long, device=edge_index.device)
                edge_attr = torch.empty((0, K), dtype=torch.float, device=edge_attr.device)
            else:
                # drop the `remove_edges` edges with the smallest mean attribute;
                # the surviving edges end up sorted by increasing mean attribute
                mean_attr = torch.mean(edge_attr, dim = 1)
                sorted_index = torch.argsort(mean_attr, descending = False)
                sorted_index = sorted_index[self.remove_edges:]
                edge_index   = edge_index[:, sorted_index]
                edge_attr    = edge_attr[sorted_index, :]

        data = Data(
            x = band_alloc, v = self_gain,
            edge_index = edge_index, edge_attr = edge_attr
        )

        return csi_tensor, data

# quick look at one graph: test split with 50 edges removed per graph;
# every item is a (csi_tensor, graph) pair, so [1] selects the graph
dataset = CustomGraphDataset(tests_dataset, 50)
dataset[0][1]

## Graph Models

In [None]:
from torch_geometric.utils import add_self_loops

class EdgeAwareMPNN(gnn.MessagePassing):
    """Message-passing layer whose messages depend on the edge features.

    For every edge the message is ``phi([h_j, e_ji])`` where ``h_j`` is the
    feature vector of the source node and ``e_ji`` the edge attribute.  The
    messages reaching a node are combined with ``aggr`` and the new node
    state is ``alpha([h_i, aggregated])``.

    Args:
        in_channels: width of the input node features.
        edge_dim: width of the edge attributes.
        out_channels: width of the messages and of the output node features.
        aggr: aggregation scheme handed to ``MessagePassing`` ('add' = sum).
    """

    def __init__(self, in_channels, edge_dim, out_channels, aggr='add'):
        super().__init__(aggr=aggr)

        message_width = in_channels + edge_dim
        update_width  = in_channels + out_channels

        # message network, applied to (source node features, edge features)
        self.phi = nn.Sequential(
            nn.Linear(message_width, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        )

        # update network, applied to (node features, aggregated messages)
        self.alpha = nn.Sequential(
            nn.Linear(update_width, out_channels),
            nn.ReLU(),
        )

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing; no self-loops are added here."""
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """Message of one edge: ``phi`` over source features and edge features."""
        return self.phi(torch.cat((x_j, edge_attr), dim=-1))

    def update(self, aggr_out, x):
        """New node state: ``alpha`` over the node's features and its aggregate."""
        return self.alpha(torch.cat((x, aggr_out), dim=-1))

In [None]:
from torch_geometric.utils import add_self_loops

class RRMMP(gnn.MessagePassing):
    """Message-passing layer that refines a per-node band allocation.

    Self-loops are always added before propagation.  The message of an edge
    is ``phi1([v_i, e])`` where ``v_i`` are the features of the *target*
    node and ``e`` the edge attribute.  Aggregated messages go through
    ``phi2`` and are concatenated with the current allocation ``x``; a linear
    layer followed by a softmax over dimension 1 gives the new allocation.

    Args:
        alloc_dim: width of the allocation vector ``x`` (input and output).
        node_dim: width of the node features ``v``.
        edge_dim: width of the edge attributes.
        hidden_dim: width of the messages.
        aggr: aggregation scheme handed to ``MessagePassing``.
    """

    def __init__(self,
        alloc_dim: int, node_dim: int, edge_dim: int,
        hidden_dim: int, aggr: str = "sum"
    ):
        # `node_dim = 0` below is MessagePassing's own keyword (the axis along
        # which nodes are indexed); it is unrelated to this constructor's
        # `node_dim` argument, which is the width of `v`.
        super().__init__(aggr=aggr, node_dim = 0, flow = "source_to_target")

        message_width = node_dim + edge_dim

        # message network over (target node features, edge features)
        self.phi1 = nn.Sequential(
            nn.Linear(message_width, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # transformation of the aggregated messages
        self.phi2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # maps (current allocation, transformed aggregate) to a distribution
        self.band_allocator = nn.Sequential(
            nn.Linear(alloc_dim + hidden_dim, alloc_dim),
            nn.Softmax(dim = 1),
        )

    def forward(self, x, v, edge_index, edge_attr):
        """Return the refined allocation, one softmax row per node."""
        # NOTE(review): no fill value is passed, so the self-loop attributes
        # get add_self_loops' default (1.0 according to the PyG docs) —
        # confirm that this is on the same scale as the real edge features.
        loop_index, loop_attr = add_self_loops(edge_index, edge_attr, num_nodes = x.size(0))

        aggregated = self.propagate(loop_index, x=x, v=v, edge_attr=loop_attr)
        return self.band_allocator(torch.cat([x, aggregated], dim = 1))

    def message(self, x_i, v_i, edge_attr):
        """Message of one edge from target features ``v_i`` and edge features.

        ``x_i`` is part of the signature but is not used.
        """
        return self.phi1(torch.cat([v_i, edge_attr], dim=-1))

    def update(self, aggr_out, x):
        """Transform the aggregated messages; ``x`` is not used here."""
        return self.phi2(aggr_out)

In [None]:
# smoke test of a single RRMMP layer on one batch of 64 test graphs
# (20 edges removed per graph)
tests_loader = DataLoader(CustomGraphDataset(tests_dataset, 20), batch_size = 64, shuffle=False)

# every batch is a (csi, graph) pair; [1] keeps the graph batch
graph = next(iter(tests_loader))[1]
print(graph)
# RRMMP(alloc_dim, node_dim, edge_dim, hidden_dim)
layer = RRMMP(4, 4, 4, 4)
result = layer(graph.x, graph.v, graph.edge_index, graph.edge_attr)

# allocation rows of the first five nodes of the batch
result[:5, :]

In [None]:
class GNNAllocator(nn.Module):
    """Stack of :class:`RRMMP` layers producing per-node band probabilities.

    Args:
        alloc_dim: width of the allocation vector refined by every layer.
        node_dim: width of the node features ``v``.
        edge_dim: width of the edge attributes.
        n_bands: size of the band axis used to reshape the output; the
            reshape expects the layer output width (``alloc_dim``) to match.
        hidden_dim: message width of every layer.
        n_layers: the stack holds ``n_layers + 1`` RRMMP layers
            (NOTE(review): presumably ``n_layers`` hidden layers plus one
            output layer — confirm the ``+ 1`` is intended).
    """

    def __init__(self,
        alloc_dim: int, node_dim: int, edge_dim: int, n_bands: int,
        hidden_dim: int = 128, n_layers: int = 4,
    ):
        super().__init__()

        layers = []
        for _ in range(n_layers + 1):
            l = RRMMP(alloc_dim, node_dim, edge_dim, hidden_dim)
            layers.append((l, "x, v, edge_index, edge_attr -> x"))
            layers.append(nn.ReLU())

        # the ReLU appended after the last layer is dropped
        self.gnn = gnn.Sequential("x, v, edge_index, edge_attr", layers[:-1])
        self.n_bands = n_bands

    def forward(self, data: Data | Batch, tau: float = 0.5):
        """Return band probabilities shaped ``(graphs, n_bands, nodes)``.

        Args:
            data: a PyG ``Batch`` or a single ``Data`` graph carrying ``x``,
                ``v``, ``edge_index`` and ``edge_attr``.  All graphs of a
                batch must have the same number of nodes for the reshape.
            tau: accepted for interface compatibility; currently unused.
        """
        x, v, edge_index, edge_attr = data.x, data.v, data.edge_index, data.edge_attr
        x = self.gnn(x, v, edge_index, edge_attr)

        # A single `Data` object does not define `batch_size`; treat it as
        # a batch holding one graph instead of failing.
        num_graphs = getattr(data, "batch_size", 1)

        # channel optimization: (graphs * nodes, bands) -> (graphs, bands, nodes)
        probs = x.reshape(num_graphs, -1, self.n_bands)
        probs = probs.permute(0, 2, 1)
        return probs



In [None]:
# smoke test of the full allocator on one batch of 64 fully connected test graphs
tests_loader = DataLoader(CustomGraphDataset(tests_dataset), batch_size = 64, shuffle=False)

# every batch is a (csi, graph) pair; [1] keeps the graph batch
graph = next(iter(tests_loader))[1]
print(graph)
# GNNAllocator(alloc_dim, node_dim, edge_dim, n_bands), default hidden size and depth
layer = GNNAllocator(4, 4, 4, 4)
result = layer(graph)

# first graph of the batch: every band, first four nodes
result[0, :, :4]

## Training Procedure

In [None]:
BATCH_SIZE   = 512   # graphs per batch
REMOVE_EDGES = 100   # directed edges dropped from every graph (smallest mean attribute first)

# Only the training data is shuffled. Validation and test metrics are averaged
# per batch, so a fixed order keeps them identical from one run to the next
# (the last batch may be smaller than the others).
train_loader = DataLoader(CustomGraphDataset(train_dataset, REMOVE_EDGES), batch_size = BATCH_SIZE, shuffle=True )
valid_loader = DataLoader(CustomGraphDataset(valid_dataset, REMOVE_EDGES), batch_size = BATCH_SIZE, shuffle=False)
tests_loader = DataLoader(CustomGraphDataset(tests_dataset, REMOVE_EDGES), batch_size = BATCH_SIZE, shuffle=False)

In [None]:
def _bit_rate(config: SimConfig, C: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    """Mean rate of a batch under the hard (argmax) allocation.

    Args:
        config: simulation settings forwarded to the SINR helper.
        C: batch of CSI tensors forwarded to the SINR helper.
        A: soft allocation; the selected band is the argmax over dimension 1.

    Returns:
        Scalar tensor: ``10 * log2(1 + sinr)`` summed over dimension 1, then
        averaged over dimension 1 of the result and over the batch.
    """
    hard_alloc = A.argmax(dim = 1)
    sinr = rate_metrics.signal_interference_ratio(config, C, hard_alloc, None)
    # NOTE(review): the factor 10 is presumably a bandwidth scaling — confirm
    summed_rate = (10 * torch.log2(1 + sinr)).sum(dim = 1)
    return summed_rate.mean(dim = 1).mean()

In [None]:
# model name -> value returned by get_cdf for that model's test rates;
# re-running this cell discards the curves collected so far
percentiles = {}

In [None]:
# GNNAllocator(alloc_dim, node_dim, edge_dim, n_bands, hidden_dim, n_layers)
model =  GNNAllocator(4, 4, 4, 4, 128, 4)
model_name = f"gnn-v3-{REMOVE_EDGES}"

model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0015)
# NOTE(review): T_max is 40 epochs while the loop below runs 5, so only the
# first part of the cosine schedule is used — confirm this is intended.
scheduler = lrs.CosineAnnealingLR(optimizer, T_max=40, eta_min=3e-4)

loss_func = partial(loss_funcs.loss_pure_rate, mode = 'min', p = 1e5, a = 0.6)
for step in range(5):
    # ---- training pass ----
    model.train()
    total_loss = 0.
    total_bin_error = 0.
    total_bit_rate  = 0.
    for sample, graph in tqdm(train_loader, desc = "training: ", unit=" batch", total = len(train_loader), leave = False):
        optimizer.zero_grad()

        sample = sample.to(device)
        graph  = graph.to(device)
        alloc_prob = model(graph)        # soft output
        loss = loss_func(config, sample, alloc_prob).mean()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
        optimizer.step()

        # Training metrics are monitoring only: compute them without autograd
        # and store plain floats, so the accumulators never keep a graph or
        # a device tensor alive across the epoch.
        total_loss += loss.item()
        with torch.no_grad():
            total_bin_error += float(loss_funcs.binarization_error(alloc_prob))
            total_bit_rate  += float(_bit_rate(config, sample, alloc_prob))

        del sample, alloc_prob, loss, graph
        torch.cuda.empty_cache()

    # per-batch averages of the training pass
    ttotal_loss = total_loss / len(train_loader)
    ttotal_bin_error = total_bin_error / len(train_loader)
    ttotal_bit_rate  = total_bit_rate / len(train_loader)

    # ---- validation pass ----
    model.eval()
    total_loss = 0.
    total_bin_error = 0.
    total_bit_rate  = 0.

    with torch.no_grad():
        for sample, graph in tqdm(valid_loader, desc = "validation: ", unit=" batch", total = len(valid_loader), leave = False):
            sample = sample.to(device)
            graph  = graph.to(device)
            alloc_prob = model(graph)        # soft output
            loss = loss_func(config, sample, alloc_prob).mean()

            total_loss += loss.item()
            total_bin_error += float(loss_funcs.binarization_error(alloc_prob))
            total_bit_rate  += float(_bit_rate(config, sample, alloc_prob))

            del sample, alloc_prob, loss, graph
            torch.cuda.empty_cache()

    # per-batch averages of the validation pass
    total_loss = total_loss / len(valid_loader)
    total_bin_error = total_bin_error / len(valid_loader)
    total_bit_rate  = total_bit_rate / len(valid_loader)

    # learning rate that was used during this epoch (read before stepping)
    lr = scheduler.get_last_lr()[-1]
    print(
        f"[{step:>3d}] (lr: {lr:1.2e})",
        f"train loss: {ttotal_loss:7.4f}",
        f"(bin error: {ttotal_bin_error:5.3e}, bit rate: {ttotal_bit_rate:4.2f})",
        f"valid loss: { total_loss:7.4f}",
        f"(bin error: { total_bin_error:5.3e}, bit rate: { total_bit_rate:4.2f})",
        sep = " "
    )

    scheduler.step()

In [None]:
import json

# Evaluate the trained model on the test split.  Every accumulator is
# initialised here so that the cell does not depend on values left behind by
# the training cell.
model.eval()
total_loss = 0.
total_bin_error = 0.
total_bit_rate  = 0.
metrics = defaultdict(lambda : 0)

rates = []
with torch.no_grad():
    for sample, graph in tqdm(tests_loader, desc = "testing: ", unit=" batch", total = len(tests_loader), leave = False):
        sample = sample.to(device)
        graph  = graph.to(device)
        A = model(graph)        # soft output
        loss = loss_func(config, sample, A).mean()

        total_loss += loss.item()
        total_bin_error += float(loss_funcs.binarization_error(A))
        total_bit_rate  += float(_bit_rate(config, sample, A))

        # hard allocation: selected band index for every node
        A = torch.argmax(A, dim = 1)
        metrics = loss_funcs.update_metrics(metrics, A, sample, None, config, 4)
        sinr = rate_metrics.signal_interference_ratio(config, sample, A, None)
        rate = torch.sum(10 * torch.log2(1 + sinr), dim = 1)
        # keep every individual rate for the percentile / CDF analysis
        rates.append(rate.cpu().flatten().numpy())


        del sample, A, loss, graph
        torch.cuda.empty_cache()

# per-batch averages
total_loss = total_loss / len(tests_loader)
total_bin_error = total_bin_error / len(tests_loader)
total_bit_rate  = total_bit_rate / len(tests_loader)

metrics = { key: val / len(tests_loader) for key, val in metrics.items()}

print("testing run:")
print("testing batches: ", len(tests_loader))
print("test test error: ", total_loss)
print("test test binarization error: ", total_bin_error)
print("test bit rate: ", total_bit_rate)
print("bit rate / quality metrics:\n", json.dumps(metrics, indent = 2))

percentiles[model_name] = get_cdf(np.hstack(rates))

# Simple

In [None]:
# all test rates as one flat array (every entry of `rates` is already 1-D)
data = np.concatenate(rates)
# selected percentiles of the rate distribution; `per` is given in percent
for per in (0.005, 0.05, 0.5, 1, 10, 25, 50, 95, 99):
    per_point = np.percentile(data, per)
    print(f"percentile: {per:6.3f} ----> {per_point:.2f}")


In [None]:
import pandas as pd

# curves stored by earlier runs; the first CSV column is used as the index
results = pd.read_csv("./results.csv", index_col = 0)

# add (or overwrite) one column per model evaluated in this session
for name, (pos, _) in percentiles.items():
    results[f"{name}_values"] = pos

display(results[ :20])
display(results[-20:])

_, ax = plt.subplots(1, 1, figsize = (6, 4))

# every column except the shared "percentiles" axis holds one model curve;
# selecting by name does not rely on "percentiles" being the first column
model_list = [column for column in results.columns.tolist() if column != "percentiles"]

# the vertical axis is the same for every model, so read it once
per = results["percentiles"].values
for name in model_list:
    pos = results[name].values
    ax.plot(pos, per, label = name)

ax.set_xlabel("rate")
ax.set_ylabel("percentiles")
ax.set_ylim(0, 1)
ax.legend()
plt.show()


In [None]:
# write the table back to the file it was read from, including the curves
# added in this session
results.to_csv('results.csv')

In [None]:
# Inspect the first test batch under the hard allocation chosen by the model.
H, graph = next(iter(tests_loader))
H = H.to(device)

# band index per node, then the one-hot form produced by the project helper
A = rate_metrics.onehot_allocation(torch.argmax(model(graph.to(device)), dim = 1), 4, 20)

# two views of the allocation: one with an extra axis at position 2, one at 3
Ar, Ac = A.unsqueeze(2), A.unsqueeze(3)

# Ar @ (H @ Ac), then one total per sample
interference = torch.matmul(Ar, torch.matmul(H, Ac))
interference.sum(dim = (1, 2, 3))


In [None]:
# Shape of the diagonal entries H[:, :, i, i] for the 20 node indices.
# `H` is the CSI batch loaded in the previous cell; the loop variable
# `sample` cannot be used here because the evaluation loops delete it.
di = torch.arange(20)
H[:, :, di, di].shape

