In [None]:
COLAB: bool = False
if COLAB:
  %git clone https://github.com/RubenCid35/6GSmartRRM
  %mv 6GSmartRRM/* .
  %pip install -e .

In [None]:
# Automatically reload edited imported modules (e.g. the project's `g6smart`
# package imported below) before code runs, so library edits are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# standard library
import os
from collections import defaultdict

# Hide deprecation noise only (set before the third-party imports below so their
# import-time deprecations are silenced too); other warnings stay visible.
import warnings
warnings.simplefilter('ignore', category=DeprecationWarning)
warnings.simplefilter('ignore', category=FutureWarning)

# simple data manipulation
import numpy  as np
import pandas as pd

# deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
from   torch.utils.data import DataLoader, TensorDataset, random_split

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# progress bars and experiment tracking
from   tqdm.notebook import tqdm, trange
import wandb

# visualization of results
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from   matplotlib.ticker import MaxNLocator
import seaborn           as sns
import plotly.express as px


# Outside Colab, if ./data/simulations is not found from where the kernel started,
# move to the parent folder. The starting directory is remembered in the kernel
# namespace so that re-running this cell does not keep climbing the directory tree.
_NOTEBOOK_CWD = globals().get('_NOTEBOOK_CWD', os.getcwd())
if not COLAB and not os.path.exists(os.path.join(_NOTEBOOK_CWD, 'data', 'simulations')):
    os.chdir(os.path.join(_NOTEBOOK_CWD, '..'))

# Simulation settings and project modules (imported after the directory change
# above — presumably so a local `g6smart` checkout is importable)
from g6smart.sim_config import SimConfig
from g6smart.evaluation import rate
from g6smart.evaluation import rate_torch as rate_metrics
from g6smart.proposals  import loss as loss_funcs, rate_cnn, rate_dnn
from g6smart.data import load_data, create_datasets, download_simulations_data
from g6smart.track import setup_wandb, real_time_plot

config = SimConfig(0)
config

## Datasets

In [None]:
# Resolve the simulation-data and model-checkpoint locations (the helper is told
# whether we run on Colab), then load 60,000 samples into `csi_data`.
simulation_path, models_path = download_simulations_data(COLAB)
csi_data = load_data(
    simulation_path,
    n_samples=60_000,
)

In [None]:
# Split the loaded samples 40k / 10k / 10k into train / validation / test
# loaders (batches of 128, split seed 101).
train_loader, valid_loader, tests_loader = create_datasets(
    csi_data,
    split_sizes = [40_000, 10_000, 10_000],
    batch_size  = 128,
    seed        = 101,
)

## FNN Approach

In [None]:
# ---------------------------------------------------------------------------
# Learning hyper-parameters
# ---------------------------------------------------------------------------
BATCH_SIZE: int = 128   # only logged; the loaders were built with batch_size=128 earlier
MAX_EPOCH : int = 40
LR: float  = 3e-4
SEED: int  = 101        # seeds the weight initialisation (same value as the split seed)

# Cosine LR schedule. NOTE(review): CosineAnnealingLR keeps following the cosine
# past T_MAX, so with T_MAX < MAX_EPOCH the learning rate rises again from ETA_MIN
# during the last epochs -- confirm this is intended (otherwise T_MAX = MAX_EPOCH).
T_MAX: int     = 25
ETA_MIN: float = 1e-4

# temperature passed to the model's forward pass; constant (no temperature schedule yet)
TEMPERATURE: float = 1.0

# under ideal conditions, the sisa ideal shannon rate is around 4.
REQ: float      = 8.

# Fail loudly if BATCH_SIZE drifts from the batch size the loaders were built with.
assert getattr(train_loader, 'batch_size', BATCH_SIZE) == BATCH_SIZE, \
    f"BATCH_SIZE={BATCH_SIZE} does not match train_loader.batch_size"

learning_config = {
    'loss': 'pure-min-rate',
    'max-epoch': MAX_EPOCH,
    'batch-size': BATCH_SIZE,
    'learning-rate': LR,
    'desired-norm-rate' : REQ,
    'lr-t-max': T_MAX,
    'lr-eta-min': ETA_MIN,
    'seed': SEED,
}

# ---------------------------------------------------------------------------
# Model hyper-parameters
# ---------------------------------------------------------------------------
HS: int    = 1024   # presumably the hidden width of the DNN
HL: int    = 6      # presumably the number of hidden layers
DP: float  = 0.1    # dropout
KEEP_BANDS: bool = True
WEIGHTED_GAIN: bool = True

model_config = {
    'dropout': DP,
    'weighted-gain': WEIGHTED_GAIN,
    'hidden-size': HS,
    'hidden-layers': HL,
}

name  = "p3-fnn-00-01-base"
# learning_config wins on key clashes, as with the original update order
training_config = {**model_config, **learning_config}

run = setup_wandb(name, 'cnn-rate-confirming', training_config, id = None)
print("run config:", run.config)

torch.manual_seed(SEED)
# NOTE(review): KEEP_BANDS / WEIGHTED_GAIN are defined above but the constructor
# receives literal True flags; confirm which positional flags these are before
# wiring the constants in.
model = rate_dnn.RateConfirmAllocModel(20, 4, HS, HL, DP, True, True).to(device)
optimizer = optim.Adam(model.parameters(), LR, weight_decay=1e-5)
scheduler = lrs.CosineAnnealingLR(optimizer, T_max=T_MAX, eta_min=ETA_MIN)

temp = TEMPERATURE
train_loss, valid_loss, train_rate, valid_rate = [], [], [], []
try:
    for epoch in trange(MAX_EPOCH, desc = "training epoch", unit = "epoch"):
        epoch_lr = optimizer.param_groups[0]['lr']  # LR used for this epoch's updates

        # ---- training step ----
        model.train()
        training_loss = 0.
        train_binary_loss = 0.
        training_metrics = defaultdict(float)
        for sample in tqdm(train_loader, desc = 'training step:', unit = 'batch', total = len(train_loader), leave=False):
            optimizer.zero_grad()

            sample     = sample[0].to(device)
            alloc_prob = model(sample, temp)
            loss       = loss_funcs.loss_pure_rate(config, sample, alloc_prob, None, 'min').mean()
            training_loss += loss.item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_binary_loss += loss_funcs.binarization_error(alloc_prob).item()
            training_metrics = loss_funcs.update_metrics(training_metrics, alloc_prob, sample, None, config, REQ)

        scheduler.step()
        training_loss = training_loss / len(train_loader)
        train_binary_loss = train_binary_loss / len(train_loader)
        training_metrics = { 'train-' + key: val / len(train_loader) for key, val in training_metrics.items()}

        # ---- validation step: no parameter updates, so skip autograd bookkeeping ----
        model.eval()
        validation_loss = 0.
        valid_binary_loss = 0.
        validation_metrics = defaultdict(float)
        with torch.no_grad():
            for sample in tqdm(valid_loader, desc = 'validation step:', unit = 'batch', total = len(valid_loader), leave = False):
                sample     = sample[0].to(device)
                alloc_prob = model(sample, temp)
                loss       = loss_funcs.loss_pure_rate(config, sample, alloc_prob, None, 'min').mean()
                validation_loss += loss.item()

                valid_binary_loss += loss_funcs.binarization_error(alloc_prob).item()
                validation_metrics = loss_funcs.update_metrics(validation_metrics, alloc_prob, sample, None, config, REQ)

        validation_loss = validation_loss / len(valid_loader)
        valid_binary_loss = valid_binary_loss / len(valid_loader)
        validation_metrics = { 'valid-' + key: val / len(valid_loader) for key, val in validation_metrics.items()}

        logged_values = {
            'train-loss': training_loss, 'valid-loss': validation_loss, 'temperature': temp,
            'train-binary-loss': train_binary_loss, 'valid-binary-loss': valid_binary_loss,
            'lr': epoch_lr,
        }
        logged_values.update(training_metrics)
        logged_values.update(validation_metrics)

        train_loss.append(training_loss)
        valid_loss.append(validation_loss)
        train_rate.append(training_metrics['train-bit-rate'])
        valid_rate.append(validation_metrics['valid-bit-rate'])
        wandb.log(logged_values)

        # Redraw after the histories are updated so the final epoch is plotted too.
        real_time_plot(train_loss, valid_loss, train_rate, valid_rate)

    # Checkpoint named after the run. The previous fixed name "cnn-min-00-01.pt"
    # labelled this FNN model as a CNN and would collide with a CNN checkpoint of
    # that name if one exists.
    torch.save(model.state_dict(), os.path.join(models_path, f"{name}.pt"))
finally:
    # Close the W&B run even if training is interrupted or raises.
    wandb.finish()

In [None]:
# Shape of one freshly drawn test batch. Reusing the loop variable `sample` would
# only show the last (possibly partial) validation batch left over from training.
next(iter(tests_loader))[0].shape