In [1]:
import torch
import numpy as np
from torch_geometric.loader import DataLoader

# from Models.GNN import IGCNet
from Utils.data_gen import Generate_Input, create_graph
from Utils.training import sr_loss_matrix, average_weights
%reload_ext autoreload
%autoreload 2


# Parameters

In [2]:
num_H = 8  # NOTE(review): not referenced anywhere in this notebook
# num_test is defined in the training-parameters cell; a `num_test = 20` here was
# always overwritten before use (and silently restored on re-run), so it is omitted.
K = 5  # number of terminals
M = 10  # number of access points
B = 20  # bandwidth in MHz
D = 1  # km
tau = 10  # NOTE(review): presumably the pilot length — confirm against Generate_Input

# Seeded, dedicated generator so the random basis below is reproducible run-to-run.
# NOTE(review): U, S, V are not referenced elsewhere in this notebook.
pilot_rng = np.random.default_rng(2017)
random_matrix = pilot_rng.standard_normal((tau, tau))
U, S, V = np.linalg.svd(random_matrix)

Hb = 15  # Base station height in m
Hm = 1.65  # Mobile height in m
f = 1900  # Frequency in MHz
# Mobile-antenna correction and distance-independent part of the median path loss
# (the constants match the COST-231 Hata model; no log10(d) term is included here).
aL = (1.1 * np.log10(f) - 0.7) * Hm - (1.56 * np.log10(f) - 0.8)
L = 46.3 + 33.9 * np.log10(f) - 13.82 * np.log10(Hb) - aL

power_f = 0.2  # downlink power — NOTE(review): not used by Pd below
# Pd = power_f / 10 ** ((-203.975 + 10 * np.log10(20 * 10 ** 6) + 9) / 10) # normalized receive SNR
# Thermal noise: 10**(-17.4) is -174 dB per Hz in linear scale, times the
# bandwidth in Hz, times 1e-3 (presumably mW -> W).
Ther_noise = (B * 10**6) * 10**(-17.4) * 10**-3
Pd = 1 / Ther_noise  # unit transmit power normalized by the noise power
Pu = Pd

d0 = 0.01  # km
d1 = 0.05  # km

N = 50

R_cf_min = np.zeros(N)  # Cell Free
R_cf_sum = np.zeros(N)
R_cf_opt_min = np.zeros(N)

In [3]:
# Dataset sizes and federated schedule.
num_train, num_test = 4, 2
batchSize = 32
num_rounds = 20

# Optimizer settings.
# NOTE(review): num_epochs, step_size and gamma are not referenced in this notebook
# (no LR scheduler is constructed); step_size/gamma look like StepLR arguments — confirm.
num_epochs = 500
lr = 0.0001
step_size, gamma = 5, 0.9




## Create per-AP data loaders for training and testing

In [40]:
# Build one DataLoader per access point for the training and test splits.
# NOTE(review): create_graph appears to return a per-AP sequence of graph datasets
# (it is indexed by AP below) — confirm in Utils.data_gen.
num_AP = M

# Distinct seeds: with the same seed for both splits, the test scenarios would
# presumably replay the first training scenarios (train/test leakage).
# Confirm Generate_Input derives all of its randomness from `seed`.
train_seed, test_seed = 2017, 2018


def build_ap_loaders(num_samples, seed, shuffle):
    """Generate `num_samples` scenarios and wrap each AP's graphs in a DataLoader.

    Reads tau, K, M, Pd, D, Hb, Hm, f, d0, d1, batchSize and num_AP from earlier cells.

    Args:
        num_samples: number of scenarios passed to Generate_Input.
        seed: seed forwarded to Generate_Input.
        shuffle: whether each per-AP DataLoader reshuffles every epoch.

    Returns:
        (Beta, Phi, graphs, loaders): the two Generate_Input outputs, the
        create_graph result, and a list with one DataLoader per AP.
    """
    beta, phi = Generate_Input(num_samples, tau, K, M, Pd, D=D, Hb=Hb, Hm=Hm, f=f,
                               var_noise=1, Pmin=0, seed=seed, d0=d0, d1=d1)
    graphs = create_graph(beta, phi)
    loaders = [
        DataLoader(graphs[i], batch_size=batchSize, shuffle=shuffle)
        for i in range(num_AP)
    ]
    return beta, phi, graphs, loaders


Beta_all, Phi_all, train_data, train_loader = build_ap_loaders(num_train, train_seed, shuffle=True)
Beta_test, Phi_test, test_data, test_loader = build_ap_loaders(num_test, test_seed, shuffle=False)

In [41]:
hidden_channels = 32  # hidden width (original note: keep > 4)
num_gnn_layers = 2

# Read the input feature widths off the first graph of AP 0's training set.
first_graph = train_data[0][0]
ap_dim = first_graph['AP'].x.shape[1]
ue_dim = first_graph['UE'].x.shape[1]
edge_dim = first_graph['down'].edge_attr.shape[1]

# Edge types handed to APHetNet as `metadata`, plus the per-type input widths.
tt_meta = [('UE', 'up', 'AP'), ('AP', 'down', 'UE')]
dim_dict = dict(UE=ue_dim, AP=ap_dim, edge=edge_dim)

In [42]:
from Models.GNN import APHetNet

# Smoke test: build a single heterogeneous GNN from the dimensions inferred above.
# NOTE(review): `model` is reassigned by the training cell below.
model = APHetNet(metadata=tt_meta, dim_dict=dim_dict,
                 out_channels=hidden_channels, hid_layers=hidden_channels,
                 num_layers=num_gnn_layers)

# Main training pipeline

### test

In [70]:
from Models.GNN import APHetNet
# NOTE(review): importing `eval` shadows Python's builtin eval() for the rest of the
# notebook; consider `from Utils.training import eval as evaluate`.
from Utils.training import train, eval
from Utils.synthetic_graph import return_graph, combine_graph

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def build_model():
    """Construct an APHetNet on `device` using the architecture settings above."""
    return APHetNet(
        metadata=tt_meta,
        dim_dict=dim_dict,
        out_channels=hidden_channels,
        num_layers=num_gnn_layers,
        hid_layers=hidden_channels,
    ).to(device)


global_model = build_model()
local_models, optimizers = [], []

# Init every client model/optimizer from the global weights: FedAvg (below)
# averages parameters element-wise, which requires a shared starting point.
for _ in range(num_AP):
    model = build_model()
    model.load_state_dict(global_model.state_dict())
    local_models.append(model)
    optimizers.append(torch.optim.Adam(model.parameters(), lr=lr))

for round_idx in range(num_rounds):
    # zip loaders to align batches across APs: one synchronized step per iteration
    for step_batches in zip(*train_loader):
        # Phase 1: every AP embeds its own batch; only detached tensors are sent.
        send_to_server, step_batches_on_device = [], []
        for model, batch in zip(local_models, step_batches):
            model.eval()
            batch = batch.to(device)
            with torch.no_grad():  # outputs are detached anyway; skip autograd bookkeeping
                x_dict, edge_dict, edge_index = model(batch)
            step_batches_on_device.append(batch)
            send_to_server.append({
                'num_graphs': batch.num_graphs,
                "AP": x_dict['AP'].detach(),
                "edge_attr": edge_dict['AP', 'down', 'UE'].detach(),
                "edge_index": edge_index['AP', 'down', 'UE'].detach()}
            )

        # Phase 2: exchange with the server every step (previously this ran once per
        # round, so every step except the last was silently discarded).
        return_from_server = return_graph(send_to_server)

        # Phase 3: local update on the SAME batch each AP just sent. Re-iterating
        # train_loader here would reshuffle it (shuffle=True), so the complement
        # would no longer match the graphs it was computed from.
        for model, opt, batch, complements in zip(local_models, optimizers,
                                                  step_batches_on_device, return_from_server):
            model.train()
            opt.zero_grad()
            # NOTE(review): keeps the earlier behaviour of pairing the batch with the first
            # element of this AP's server reply — confirm return_graph's output layout.
            complement = next(iter(complements)).to(device)
            modified_batch = combine_graph(batch, complement)  # TODO: verify combine_graph output

            # TODO: local training step (not implemented yet)
            # local_loss = train(batches, complements, model, opt)
            # local_loss.backward()
            # opt.step()

    # # --- FedAvg (optional, by sample count) ---
    # local_weights = [m.state_dict() for m in local_models]
    # sizes = [len(dl.dataset) for dl in train_loader]
    # new_state = average_weights(local_weights, sizes)
    # global_model.load_state_dict(new_state)

    # # Update client models
    # for m in local_models:
    #     m.load_state_dict(global_model.state_dict())
        


In [71]:
# Debug: run the last AP's local model (left over from the training loop) on its
# combined graph, then alias the results under the names the loss test below uses.
num_graph = batch.num_graphs
x_dict, edge_dict, edge_index = model(modified_batch)
graphData = modified_batch
nodeFeatDict = x_dict
edgeAttrDict = edge_dict

## Loss function test

In [None]:
# Unpack the combined graph and model outputs into per-graph tensors for the loss.
num_graphs = graphData.num_graphs
ue_total = graphData['UE'].x.shape[0]
ap_total = graphData['AP'].x.shape[0]
# Integer division below is only meaningful if every graph has the same node counts.
assert ue_total % num_graphs == 0 and ap_total % num_graphs == 0, "graphs in the batch differ in size"
num_UE = ue_total // num_graphs
# Distinct name: assigning to num_AP here would clobber the global AP count (M)
# that the loader and training cells iterate over.
num_AP_graph = ap_total // num_graphs
pilot_matrix = graphData['UE'].x.reshape(num_graphs, num_UE, -1)

down_edge_attr = edgeAttrDict[('AP', 'down', 'UE')]
# The reshape of [:, :-1] to (graphs, APs, UEs) only works when there is exactly one
# column before the last, so the last column is a separate quantity.
# NOTE(review): assumes column 0 is large-scale fading and the last column is power —
# previously both were sliced with [:, :-1], making `power` a copy of the fading.
large_scale_fading = down_edge_attr[:, :-1].reshape(num_graphs, num_AP_graph, num_UE)
power = down_edge_attr[:, -1].reshape(num_graphs, num_AP_graph, num_UE)

In [75]:
# Display the last AP's mini-batch left over from the training cell;
# swap in `modified_batch` to inspect the graph after combine_graph.
batch

HeteroDataBatch(
  AP={
    x=[4, 1],
    batch=[4],
    ptr=[5],
  },
  UE={
    x=[20, 10],
    batch=[20],
    ptr=[5],
  },
  (AP, down, UE)={
    edge_index=[2, 20],
    edge_attr=[20, 1],
  },
  (UE, up, AP)={
    edge_index=[2, 20],
    edge_attr=[20, 1],
  }
)

# Main training (old flow)

In [None]:
# Legacy flow: duplicate of the per-AP train/test DataLoader setup earlier in the notebook.
num_AP = M

# Scenario constants (D, Hb, Hm, f) come from the parameters cell instead of repeated literals.
Beta_all, Phi_all = Generate_Input(num_train, tau, K, M, Pd, D=D, Hb=Hb, Hm=Hm, f=f,
                                   var_noise=1, Pmin=0, seed=2017, d0=d0, d1=d1)
train_data = create_graph(Beta_all, Phi_all)
train_loader = [
    DataLoader(train_data[i], batch_size=batchSize, shuffle=True)
    for i in range(num_AP)
]

# Different seed from training: reusing 2017 would presumably regenerate the first
# training scenarios as the test set (leakage). Confirm in Generate_Input.
Beta_test, Phi_test = Generate_Input(num_test, tau, K, M, Pd, D=D, Hb=Hb, Hm=Hm, f=f,
                                     var_noise=1, Pmin=0, seed=2018, d0=d0, d1=d1)
test_data = create_graph(Beta_test, Phi_test)
test_loader = [
    DataLoader(test_data[i], batch_size=batchSize, shuffle=False)
    for i in range(num_AP)
]