In [1]:
import torch
import numpy as np
from torch_geometric.loader import DataLoader

%reload_ext autoreload
%autoreload 2


# Parameters


In [104]:

num_UE = 5 # number of terminals
num_AP = 30 # number of access points
B = 20 # Mhz
D = 1 # km
tau=10
random_matrix = np.random.randn(tau, tau)
U, S, V = np.linalg.svd(random_matrix)

Hb = 15 # Base station height in m
Hm = 1.65 # Mobile height in m
f = 1900 # Frequency in MHz
aL = (1.1 * np.log10(f) - 0.7) * Hm - (1.56 * np.log10(f) - 0.8)
L = 46.3+33.9*np.log10(f)-13.82*np.log10(Hb)-aL

power_f=0.2 # downlink power
rho_p, rho_d = power_f, power_f

# Pd = power_f / 10 ** ((-203.975 + 10 * np.log10(20 * 10 ** 6) + 9) / 10) # normalized receive SNR
Ther_noise = 20000000 * 10**(-17.4) * 10**-3
Pd = 1/Ther_noise
Pu=Pd

d0=0.01 # km
d1=0.05 # km

num_antenna = 50


In [105]:
num_train = 500
num_test = 200
batchSize = 32

lr = 1e-4
step_size = 5
gamma = 0.9

# FL
num_rounds = 5

num_client = num_AP 
num_epochs = 10
eval_round = num_rounds//10 if num_rounds//10 else 1

# Create data loader for training and testing 

In [106]:
from Utils.data_gen import Generate_Input, create_graph

Beta_all, Phi_all = Generate_Input(num_train, tau, num_UE, num_AP, Pd, D=1, Hb=15, Hm=1.65, f=1900,
                    var_noise=1, Pmin=0, seed=2017, d0=d0, d1=d1)
train_data = create_graph(Beta_all, Phi_all, 'het')
train_loader = [
    DataLoader(train_data[i], batch_size=batchSize, shuffle=True)
    for i in range(num_AP)
]


Beta_test, Phi_test = Generate_Input(num_test, tau, num_UE, num_AP, Pd, D=1, Hb=15, Hm=1.65, f=1900,
                    var_noise=1, Pmin=0, seed=2017, d0=d0, d1=d1)
test_data = create_graph(Beta_test, Phi_test, 'het')
test_loader = [
    DataLoader(test_data[i], batch_size=batchSize, shuffle=False)
    for i in range(num_AP)
]


train_data_cen = create_graph(Beta_all, Phi_all, 'het', isDecentralized=False)
train_loader_cen = DataLoader(train_data_cen, batch_size=batchSize, shuffle=True)
test_data_cen = create_graph(Beta_test, Phi_test, 'het', isDecentralized=False)
test_loader_cen = DataLoader(test_data_cen, batch_size=batchSize, shuffle=False)

In [107]:
hidden_channels = 32 # > 4
num_gnn_layers = 2


ap_dim = train_data[0][0]['AP'].x.shape[1]
ue_dim = train_data[0][0]['UE'].x.shape[1]
edge_dim = train_data[0][0]['down'].edge_attr.shape[1]
tt_meta = [('UE', 'up', 'AP'), ('AP', 'down', 'UE')]
dim_dict = {
    'UE': ue_dim,
    'AP': ap_dim,
    'edge': edge_dim,
}

In [108]:
from Models.GNN import APHetNet

model = APHetNet(
    metadata=tt_meta,
    dim_dict=dim_dict,
    out_channels=hidden_channels,
    num_layers=num_gnn_layers,
    hid_layers=hidden_channels
)

In [111]:
from Models.GNN import APHetNet
from Utils.training import train, eval, package_calculate
from Utils.synthetic_graph import return_graph, combine_graph

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

global_model = APHetNet(
    metadata=tt_meta,
    dim_dict=dim_dict,
    out_channels=hidden_channels,
    num_layers=num_gnn_layers,
    hid_layers=hidden_channels
).to(device)
local_models, optimizers = [], []

# Init every client model/optimizer
for each_AP in range(num_AP):
    model = APHetNet(
        metadata=tt_meta,
        dim_dict=dim_dict,
        out_channels=hidden_channels,
        num_layers=num_gnn_layers,
        hid_layers=hidden_channels
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    local_models.append(model)
    optimizers.append(optimizer)

# Main Training

In [None]:
import copy
from Utils.training import (
    get_global_info, distribute_global_info, average_weights,
    fl_train, fl_eval, fl_eval_rate
)

log = []

print(f"Starting Federated Learning with {num_client} clients for {num_rounds} rounds")

for round in range(num_rounds):
    # print(f"\n=== Federated Round {round+1}/{num_rounds} ===")
    
    ## 1.Exchange global information
    send_to_server = get_global_info(
        train_loader, local_models, optimizers,
        tau=tau, rho_p=power_f, rho_d=power_f
    )
    response_all = distribute_global_info(send_to_server)
    
    
    ## 2. Training Local models    
    local_weights = []
    total_loss = 0.0
    for model, opt, batches , responses_ap in zip(local_models, optimizers, train_loader, response_all):
        model.train() 
        opt.zero_grad() 
        for epoch in range(num_epochs):
            train_loss = fl_train(
                batches, responses_ap, model, opt,
                tau=tau, rho_p=power_f, rho_d=power_f, num_antenna=num_antenna
            )
        local_weights.append(copy.deepcopy(model.state_dict()))
        total_loss += train_loss
    avg_loss = total_loss / num_client
    # print(f"Round {round+1}: Average local training loss = {avg_loss:.6f}")


    ## 3. Update global models
    global_weights = average_weights(local_weights)

    # Broadcast updated global weights to all clients
    for model in local_models:
        model.load_state_dict(global_weights)
        

    ## 4. Exchange global eval information
    # print("Evaluating global model(s)...")
    # send_to_server_eval = get_global_info(
    #     test_loader, local_models, optimizers,
    #     tau=tau, rho_p=power_f, rho_d=power_f
    # )
    # response_all_eval = distribute_global_info(send_to_server_eval)
    # total_eval_rate = 0.0
    # for client_idx, (model, loader, responses) in enumerate(zip(local_models, test_loader, response_all_eval)):
    #     model.eval() 
    #     for epoch in range(num_epochs):
    #         eval_metrics = fl_eval(
    #             loader, responses, model,
    #             tau=tau, rho_p=power_f, rho_d=power_f, num_antenna=num_antenna
    #         )
    #     total_eval_rate += eval_metrics
    # total_eval_rate = total_eval_rate/num_client
    #     # print(f"Client {client_idx}: {eval_metrics}")
    
    
    total_eval_rate = fl_eval_rate(
        test_loader, model,
        tau=tau, rho_p=power_f, rho_d=power_f, num_antenna=num_antenna
    )
    if round%eval_round==0:
        print(f"Round {round+1:02d}/{num_rounds}: Avg Training Rate = {-avg_loss:.6f} | Avg Eval rate = {total_eval_rate:.6f}")
            
    log.append({
        "round": round + 1,
        "train_loss": avg_loss,
        "eval": total_eval_rate
    })
    

Starting Federated Learning with 30 clients for 5 rounds


KeyboardInterrupt: 

In [114]:
len(test_loader)

30

In [136]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


all_power = []
all_large_scale = []

all_phi = []

with torch.no_grad():
    for batch_idx, batches_at_k in enumerate(zip(*test_loader)):
        per_batch_power = []
        per_batch_large_scale = []
        for ap_idx, (model, batch) in enumerate(zip(local_models, batches_at_k)):
            model.eval()
            # iterate over all batch of each AP
            batch = batch.to(device)
            num_graphs = batch.num_graphs
            num_UEs = batch['UE'].x.shape[0]//num_graphs
            num_APs = batch['AP'].x.shape[0]//num_graphs
            
            x_dict, edge_dict, edge_index = model(batch)
            power = x_dict['UE'].reshape(num_graphs, num_UEs, -1)
            power_matrix = power[:,:,-1][:, None, :]
            pilot_matrix = batch['UE'].x.reshape(num_graphs, num_UEs, -1)
            large_scale = batch['AP','down','UE'].edge_attr.reshape(num_graphs, num_APs, num_UEs)
            
            per_batch_power.append(power_matrix)
            per_batch_large_scale.append(large_scale)
            # per_batch_phi.append(power_matrix)
        per_batch_phi = pilot_matrix
        per_batch_power = torch.cat(per_batch_power, dim=1)
        per_batch_large_scale = torch.cat(per_batch_large_scale, dim=1)
         
        all_power.append(per_batch_power)
        all_large_scale.append(per_batch_large_scale)
        all_phi.append(per_batch_phi)

In [151]:
from Utils.training import rate_calculation, variance_calculate
total_min_rate = 0.0
total_samples = 0.0
for each_power, each_large_scale, each_phi in zip(all_power, all_large_scale, all_phi):
    num_graphs = len(each_power)
    each_channel_variance = variance_calculate(each_large_scale, each_phi, tau=tau, rho_p=rho_p)
    rate = rate_calculation(each_power, each_large_scale, each_channel_variance, each_phi, rho_d=rho_d, num_antenna=num_antenna)
    min_rate, _ = torch.min(rate, dim=1)
    min_rate = torch.mean(min_rate)
    total_min_rate += min_rate.item() * num_graphs
    total_samples += num_graphs
total_min_rate/total_samples

3.655353698730469

In [None]:
print("\nTraining complete!")
print(f"Final Round | Avg Train Loss: {log[-1]['train_loss']:.6f} | Avg Eval: {log[-1]['eval']}")


Training complete!
Final Round | Avg Train Loss: -5.516941 | Avg Eval: 5.171294670104981


# Centralized training

In [100]:
from Models.GNN import APHetNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR

cen_model = APHetNet(
    metadata=tt_meta,
    dim_dict=dim_dict,
    out_channels=hidden_channels,
    num_layers=num_gnn_layers,
    hid_layers=hidden_channels,
    edge_conv=True
).to(device)
cen_optimizer = torch.optim.Adam(cen_model.parameters(), lr=lr)
cen_scheduler = StepLR(cen_optimizer, step_size=10, gamma=0.5)

In [101]:
num_epochs = 500
eval_epochs = num_epochs//10 if num_epochs//10 else 1

In [102]:
from Utils.training import cen_eval, cen_train
for epoch in range(num_epochs):
    cen_model.eval()
    with torch.no_grad():
        train_eval = cen_eval(
            train_loader_cen, cen_model,
            tau=tau, rho_p=power_f, rho_d=power_f, num_antenna=num_antenna
        )
        test_eval = cen_eval(
            train_loader_cen, cen_model,
            tau=tau, rho_p=power_f, rho_d=power_f, num_antenna=num_antenna
        )
    cen_model.train()
    train_loss = cen_train(
        train_loader_cen, cen_model, cen_optimizer,
        tau=tau, rho_p=power_f, rho_d=power_f, num_antenna=num_antenna
    )
    cen_scheduler.step()
    if epoch%eval_epochs==0:
        print(
            f"Epoch {epoch+1:03d}/{num_epochs} | "
            f"Train Loss: {train_loss:.6f} | "
            f"Train Rate: {train_eval:.6f} | "
            f"Test Rate: {test_eval:.6f} "
        )

Epoch 001/500 | Train Loss: -0.974967 | Train Rate: 0.860557 | Test Rate: 0.815201 
Epoch 051/500 | Train Loss: -1.692961 | Train Rate: 1.788714 | Test Rate: 1.751839 
Epoch 101/500 | Train Loss: -1.897718 | Train Rate: 1.816060 | Test Rate: 1.841762 
Epoch 151/500 | Train Loss: -1.918362 | Train Rate: 1.705942 | Test Rate: 1.845506 
Epoch 201/500 | Train Loss: -1.926587 | Train Rate: 1.762285 | Test Rate: 1.931217 
Epoch 251/500 | Train Loss: -1.783407 | Train Rate: 1.703272 | Test Rate: 1.828142 
Epoch 301/500 | Train Loss: -1.983649 | Train Rate: 1.928573 | Test Rate: 1.871372 
Epoch 351/500 | Train Loss: -1.791572 | Train Rate: 1.922842 | Test Rate: 1.812562 
Epoch 401/500 | Train Loss: -1.881939 | Train Rate: 1.722239 | Test Rate: 1.850021 
Epoch 451/500 | Train Loss: -1.861650 | Train Rate: 1.798852 | Test Rate: 1.782266 
