In [None]:
# Colab-only bootstrap. Set COLAB = True when running on Google Colab: the cell clones the
# repository, moves its (non-hidden) contents into /content, installs the project from the
# current directory in editable mode (presumably /content on Colab) and mounts Google Drive.
# NOTE(review): not idempotent — `mv 6GSmartRRM/*` leaves dotfiles such as .git behind, so a
# second run fails at `git clone` (destination exists and is not empty); restart the runtime.
COLAB: bool = False
if COLAB:
  !git clone https://github.com/RubenCid35/6GSmartRRM
  !mv 6GSmartRRM/* /content/
  !pip install -e .
  from google.colab import drive
  drive.mount('/content/drive', force_remount=True)

In [None]:
# GPU sanity check for rented (vast.ai) machines: print the card's power-limit fields, which
# help spot instances whose GPU does not match the advertised specs.
!nvidia-smi -q | grep 'Power Limit' 

In [None]:
%load_ext autoreload
%autoreload 2
!pip install -q wandb matplotlib seaborn 

In [None]:
# simple data manipulation
import numpy  as np

# deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
from   torch.utils.data import DataLoader
import torch.cuda.amp as amp # For Automatic Mixed Precision

from functools import partial

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from collections import defaultdict

# progress bar
from   tqdm.notebook import tqdm, trange
import wandb

# remove warnings (remove deprecated warnings)
import warnings
warnings.simplefilter('ignore')

# visualization of results
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from   matplotlib.ticker import MaxNLocator
import seaborn           as sns

# Local runs only: when ./data/simulations is not visible from the current directory
# (e.g. the notebook was started from a sub-folder), step one directory up — presumably
# to the repository root.
# NOTE(review): if the folder is missing there as well, every re-run climbs one more level.
import os
if not (COLAB or os.path.exists('./data/simulations')):
    os.chdir('..')
    print("current path: ", os.getcwd())

# Simulation Settings
from g6smart.sim_config import SimConfig
from g6smart.evaluation import rate as rate
from g6smart.evaluation.utils import get_cdf
from g6smart.evaluation import rate_torch as rate_metrics
from g6smart.proposals  import loss as loss_funcs, rate_cnn, rate_dnn
from g6smart.data import load_data, create_datasets, download_simulations_data
from g6smart.track import setup_wandb, real_time_plot
from g6smart.train import train_model

# Simulation settings (g6smart.sim_config.SimConfig), printed for the record.
# NOTE(review): the meaning of the positional argument 0 is defined by SimConfig — confirm.
config = SimConfig(0)
print(config)

## Datasets

In [None]:
# Resolve where the simulation data and the saved models live; the COLAB flag is forwarded
# so the helper can choose the environment-specific location.
# NOTE(review): judging by its name the helper also downloads missing data — see g6smart.data.
simulation_path, models_path = download_simulations_data(COLAB)
print("simulations data paths:", simulation_path)
print("saved model location  :", models_path)


In [None]:
# Train / validation / test split sizes for this (reduced) run. A larger configuration used
# earlier was split_sizes=[130_000, 60_000, 10_000] with batch_size=2048 passed to create_datasets.
SPLIT_SIZES = [7_000, 3_000, 2_000]
SPLIT_SEED  = 101

# Load exactly as many CSI samples as the three splits consume (12_000 here), so the sample
# count and the split sizes cannot drift apart when one of them is edited.
csi_data = load_data(simulation_path, n_samples=sum(SPLIT_SIZES))
train_dataset, valid_dataset, tests_dataset = create_datasets(
    csi_data, split_sizes=SPLIT_SIZES, seed=SPLIT_SEED
)

## Graph Transformation

In [None]:
from torch.utils.data import Dataset
from torch_geometric.data import Data, Dataset as GeoDataset

class CustomGraphDataset(GeoDataset):
    """Serve each CSI sample of a map-style ``torch`` dataset as a PyG ``Data`` graph.

    Every cell (indexed by the trailing axes of the CSI tensor) becomes a node, and every
    ordered pair of distinct cells becomes a directed edge, so information can flow in both
    directions between two cells during message passing.

    Args:
        torch_dataset: dataset whose items are indexable; item ``[0]`` is the CSI tensor.
        node_dim: width of the placeholder node features (default 16, the previously
            hard-coded width).
    """

    def __init__(self, torch_dataset: Dataset, node_dim: int = 16):
        # root=None, transform=None: nothing is downloaded, processed or cached on disk.
        super().__init__(None, None)
        self.torch_dataset = torch_dataset
        self._node_dim = node_dim

    @property
    def raw_file_names(self): return []

    @property
    def processed_file_names(self): return []

    # Only ``len`` is overridden: PyG's ``Dataset.__len__`` delegates to ``indices()``, which
    # also reports the right length for index-selected views such as ``dataset[:100]``.
    def len(self) -> int:
        return len(self.torch_dataset)

    def get(self, idx: int) -> Data:
        """Build the graph for sample ``idx``.

        The CSI tensor is unpacked as ``(K, N, _)``; both trailing axes are indexed by cell.

        Returns:
            ``Data`` with
              * ``x``          -- ``(N, node_dim)`` placeholder features drawn from N(0, 1);
              * ``edge_index`` -- ``(2, N*(N-1))``: every pair ``i < j`` as ``i -> j``,
                followed by the same pairs reversed as ``j -> i``;
              * ``edge_attr``  -- ``(N*(N-1), K)``; edge ``s -> t`` carries ``csi[:, s, t]``.
        """
        csi_tensor = self.torch_dataset[idx][0]
        _, n_cells, _ = csi_tensor.shape

        # NOTE(review): node features are random and re-drawn on every access, so one sample
        # yields different inputs each epoch — presumably a placeholder for real features.
        node_attr = torch.randn(n_cells, self._node_dim)

        # Each unordered pair of distinct cells once (i < j), then both directions. Listing
        # only i -> j edges would leave cell 0 without any incoming edge under PyG's
        # source_to_target flow, so it would never receive a message.
        src, dst = torch.combinations(torch.arange(n_cells), 2).t()
        edge_index = torch.stack([torch.cat([src, dst]), torch.cat([dst, src])])

        # Edge s -> t carries the per-subband CSI entries [:, s, t]; the (K, E) gather is
        # transposed to PyG's (E, K) edge-feature layout.
        edge_attr = csi_tensor[:, edge_index[0], edge_index[1]].t()

        return Data(x=node_attr, edge_index=edge_index, edge_attr=edge_attr)


In [None]:
from torch_geometric.loader import DataLoader # Import PyG DataLoader

# Graph loaders built with the PyG DataLoader imported just above, which batches Data graphs.
BATCH_SIZE   = 1024
train_loader = DataLoader(CustomGraphDataset(train_dataset), batch_size=BATCH_SIZE, shuffle=True)
# Evaluation splits keep a fixed order: shuffling brings nothing to validation metrics and
# makes per-batch diagnostics non-reproducible between runs.
valid_loader = DataLoader(CustomGraphDataset(valid_dataset), batch_size=BATCH_SIZE, shuffle=False)
tests_loader = DataLoader(CustomGraphDataset(tests_dataset), batch_size=BATCH_SIZE, shuffle=False)

## Graph Models

In [None]:
import torch_geometric.nn as gnn

class CustomMessagePassing(gnn.MessagePassing):
    """Edge-conditioned message-passing layer with sum aggregation.

    For an edge ``j -> i`` the message is ``msg_mlp([x_j || edge_attr])``. Messages arriving
    at node ``i`` are summed, and the new node state is ``upt_mlp([summed || x_i])``.

    Args:
        node_dim: width of the incoming node features ``x``.
        edge_dim: width of ``edge_attr``.
        out_dim: width of the messages and of the returned node features.
        **kwargs: forwarded unchanged to ``MessagePassing``.
    """

    def __init__(self, node_dim: int, edge_dim: int, out_dim: int, **kwargs):
        super().__init__(aggr='add', flow='source_to_target', **kwargs)
        # Message MLP first, then update MLP: same construction (and init) order as before.
        self.msg_mlp = self._two_layer_mlp(node_dim + edge_dim, out_dim)
        self.upt_mlp = self._two_layer_mlp(node_dim + out_dim, out_dim)

    @staticmethod
    def _two_layer_mlp(in_dim: int, out_dim: int) -> nn.Sequential:
        """Linear -> ReLU -> BatchNorm1d -> Linear -> ReLU (indices 0-4 in the state_dict)."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(out_dim),
            nn.Linear(out_dim, out_dim),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, edge_index, edge_attr):
        # propagate() runs message -> aggregate ('add') -> update.
        return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # x_j: features of each edge's source node, row-aligned with edge_attr.
        return self.msg_mlp(torch.cat([x_j, edge_attr], dim=-1))

    def update(self, aggr_out, x):
        # Combine the summed incoming messages with the node's own current features.
        return self.upt_mlp(torch.cat([aggr_out, x], dim=-1))
    
class GNNAllocator(nn.Module):
    """Per-node subband allocator: linear embedding -> stacked message passing -> softmax.

    Args:
        n_subbands: number of subbands ``K``, the width of each node's output distribution.
        node_dim: width of the input node features ``data.x``.
        edge_dim: width of ``data.edge_attr``.
        hidden_dim: width of the node embedding and of every message-passing layer.
        gnn_layers: number of ``CustomMessagePassing`` layers to stack.
    """

    def __init__(self, n_subbands: int, node_dim: int, edge_dim: int, hidden_dim: int = 128, gnn_layers: int = 3):
        super().__init__()

        self.K = n_subbands
        self.node_embed = nn.Linear(node_dim, hidden_dim)

        # Every layer maps hidden_dim -> hidden_dim and re-reads the same edge_attr; the
        # signature strings tell gnn.Sequential how to route the arguments between layers.
        routing = 'x, edge_index, edge_attr -> x'
        self.gnn_layers = gnn.Sequential(
            'x, edge_index, edge_attr',
            [(CustomMessagePassing(hidden_dim, edge_dim, hidden_dim), routing) for _ in range(gnn_layers)],
        )

        half_dim = hidden_dim // 2
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(inplace=True),
            nn.Linear(half_dim, self.K),
            nn.Softmax(dim = 1),
        )

    def forward(self, data):
        """Return one row per node: a softmax distribution over the ``K`` subbands."""
        h = self.node_embed(data.x)
        h = self.gnn_layers(h, data.edge_index, data.edge_attr)
        return self.classifier(h)
    
# Smoke test: push one training mini-batch through the allocator.
# node_dim=16 matches CustomGraphDataset's node features; edge_dim must equal the number of
# subbands K in the CSI tensors (one edge feature per subband) — assumed 4, like n_subbands.
model = GNNAllocator(n_subbands=4, node_dim=16, edge_dim=4)
data = next(iter(train_loader))
print(model)
# Shape check only: no autograd graph is needed.
with torch.no_grad():
    allocation_probs = model(data)
# One probability row over the 4 subbands for every node in the batch.
assert allocation_probs.shape == (data.num_nodes, 4)
allocation_probs.shape

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import NNConv

class SubbandAllocatorGNN(torch.nn.Module):
    """Two edge-conditioned ``NNConv`` layers followed by a per-node softmax classifier.

    NNConv turns each edge's features into a dynamic weight matrix with an MLP, so each edge
    MLP must emit ``in_channels * out_channels`` values.

    Args:
        node_feat_dim: width of ``data.x``.
        edge_feat_dim: width of ``data.edge_attr``.
        hidden_dim: width of both convolution outputs (and of the edge MLPs' hidden layer).
        num_classes: number of output classes per node.
    """

    def __init__(self, node_feat_dim, edge_feat_dim, hidden_dim, num_classes):
        super().__init__()

        # Layer 1: node_feat_dim -> hidden_dim.
        self.edge_mlp1 = self._edge_weight_mlp(edge_feat_dim, hidden_dim, hidden_dim * node_feat_dim)
        self.conv1 = NNConv(in_channels=node_feat_dim, out_channels=hidden_dim,
                            nn=self.edge_mlp1, aggr='mean')

        # Layer 2: hidden_dim -> hidden_dim.
        self.edge_mlp2 = self._edge_weight_mlp(edge_feat_dim, hidden_dim, hidden_dim * hidden_dim)
        self.conv2 = NNConv(in_channels=hidden_dim, out_channels=hidden_dim,
                            nn=self.edge_mlp2, aggr='mean')

        self.classifier = nn.Sequential(nn.Linear(hidden_dim, num_classes), nn.Softmax(dim = 1))

    @staticmethod
    def _edge_weight_mlp(in_dim, hidden_dim, out_dim):
        """Linear -> ReLU -> Linear mapping edge features to a flattened weight matrix."""
        return nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, out_dim))

    def forward(self, data: Data):
        # data.x: [N, node_feat_dim], data.edge_index: [2, E], data.edge_attr: [E, edge_feat_dim]
        h = data.x
        for conv in (self.conv1, self.conv2):
            h = F.relu(conv(h, data.edge_index, data.edge_attr))
        return self.classifier(h)

# Smoke test on the batch `data` drawn earlier: node_feat_dim=16 and edge_feat_dim=4 must match
# the features built by CustomGraphDataset; 64 hidden units and 4 output classes (subbands).
model2 = SubbandAllocatorGNN(node_feat_dim=16, edge_feat_dim=4, hidden_dim=64, num_classes=4)
# Shape check only: no autograd graph is needed.
with torch.no_grad():
    subband_probs = model2(data)
assert subband_probs.shape == (data.num_nodes, 4)
subband_probs.shape

In [None]:
def _tensor_bytes(tensors) -> int:
    """Total storage, in bytes, of an iterable of tensors."""
    return sum(t.nelement() * t.element_size() for t in tensors)


# Memory footprint of model2: learnable parameters plus registered buffers, in MiB.
param_size  = _tensor_bytes(model2.parameters())
buffer_size = _tensor_bytes(model2.buffers())
size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

## Training Procedure