In [None]:
!git clone https://github.com/LeGiangK62/GNN-Resource-Management.git
%cd GNN-Resource-Management/NewDir

Cloning into 'GNN-Resource-Management'...
remote: Enumerating objects: 209, done.[K
remote: Counting objects: 100% (209/209), done.[K
remote: Compressing objects: 100% (141/141), done.[K
remote: Total 209 (delta 110), reused 158 (delta 62), pack-reused 0[K
Receiving objects: 100% (209/209), 236.41 KiB | 1.27 MiB/s, done.
Resolving deltas: 100% (110/110), done.
/content/GNN-Resource-Management/NewDir


In [None]:
%%capture
import torch
!pip install torch_geometric

# Optional dependencies:
if torch.cuda.is_available():
  !pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
else:
  !pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cpu.html
#

In [1]:
import torch
import numpy as np

from torch.nn import Sequential as Seq, Linear as Lin, ReLU, Sigmoid
from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader
from torch_geometric.nn import Linear, HGTConv

from WSN_GNN import generate_channels_wsn

# Create HeteroData from the wireless system

In [2]:
#region Create HeteroData from the wireless system
def convert_to_hetero_data(channel_matrices):
    graph_list = []
    num_sam, num_aps, num_users = channel_matrices.shape
    for i in range(num_sam):
        x1 = torch.ones(num_users, 1)
        x2 = torch.ones(num_users, 1)  # power allocation
        x3 = torch.ones(num_users, 1)  # ap selection?
        user_feat = torch.cat((x1,x2,x3),1)  # features of user_node
        ap_feat = torch.zeros(num_aps, num_aps_features)  # features of user_node
        edge_feat_uplink = channel_matrices[i, :, :].reshape(-1, 1)
        edge_feat_downlink = channel_matrices[i, :, :].reshape(-1, 1)
        graph = HeteroData({
            'user': {'x': user_feat},
            'ap': {'x': ap_feat}
        })
        # Create edge types and building the graph connectivity:
        graph['user', 'uplink', 'ap'].edge_attr = torch.tensor(edge_feat_uplink, dtype=torch.float)
        graph['ap', 'downlink', 'user'].edge_attr = torch.tensor(edge_feat_downlink, dtype=torch.float)
        graph['user', 'uplink', 'ap'].edge_index = torch.tensor(adj_matrix(num_users, num_aps).transpose(), dtype=torch.int64)
        graph['ap', 'downlink', 'user'].edge_index = torch.tensor(adj_matrix(num_aps, num_users).transpose(),
                                                                dtype=torch.int64)

        # graph['ap', 'downlink', 'user'].edge_attr  = torch.tensor(edge_feat_downlink, dtype=torch.float)
        graph_list.append(graph)
    return graph_list


def adj_matrix(num_from, num_dest):
    adj = []
    for i in range(num_from):
        for j in range(num_dest):
            adj.append([i, j])
    return np.array(adj)


# Build Heterogeneous GNN

In [174]:
#region Build Heterogeneous GNN
class HetNetGNN(torch.nn.Module):
    def __init__(self, data, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),
                           num_heads, group='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

        self.lin1 = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        original = x_dict['user'].clone
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)

        original = x_dict['user'] # not original
        power = self.lin(x_dict['user'])
        ap_selection = self.lin1(x_dict['user'])
        ap_selection = torch.abs(ap_selection).int()
#         print(original, power, ap_selection)
        print(x for node_type, x in x_dict.items())
        
        out = torch.cat((original[:,1].unsqueeze(-1), power[:,1].unsqueeze(-1), ap_selection[:,1].unsqueeze(-1)), 1)
        return out

class EdgeConv(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def foward(self, graph, inputs):
        return 1


class RoundActivation(torch.nn.Module):
    def forward(self, x):
        return torch.round(torch.abs(x))

#endregion


# Training and Testing functions

In [165]:
#region Training and Testing functions
def loss_function(output, batch, is_train=True):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_user = batch['user']['x'].shape[0]
    num_ap = batch['ap']['x'].shape[0]
    ##
    channel_matrix = batch['user', 'ap']['edge_attr']
    ##
#     power_max = batch['user']['x'][:, 0]
#     power = batch['user']['x'][:, 1]
#     ap_selection = batch['user']['x'][:, 2]
    power_max = output[:, 0]
    power = output[:, 1]
    ap_selection = output[:, 2]
    ##
    ap_selection = ap_selection.int()
    index = torch.arange(num_user)

    G = torch.reshape(channel_matrix, (-1, num_ap, num_user))
    # P = torch.reshape(power, (-1, num_ap, num_user)) #* p_max
    P = torch.zeros_like(G, requires_grad=True).clone()
    P[0, ap_selection[index], index] = power_max * power
    ##
    # new_noise = torch.from_numpy(noise_matrix).to(device)
    desired_signal = torch.sum(torch.mul(P, G), dim=1).unsqueeze(-1)
    G_UE = torch.sum(G, dim=2).unsqueeze(-1)
    all_signal = torch.matmul(P.permute((0,2,1)), G_UE)
    interference = all_signal - desired_signal #+ new_noise
    rate = torch.log(1 + torch.div(desired_signal, interference))
    sum_rate = torch.mean(torch.sum(rate, 1))
    mean_power = torch.mean(torch.sum(P.permute((0,2,1)), 1))

    if is_train:
        return torch.neg(sum_rate / mean_power)
    else:
        return sum_rate / mean_power



def train(data_loader):
    model.train()
    device_type = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    total_examples = total_loss = 0
    for batch in data_loader:
        optimizer.zero_grad()
        batch = batch.to(device_type)
        # batch_size = batch['user'].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)
        tmp_loss = loss_function(out, batch, True)
        tmp_loss.backward()
        optimizer.step()
        #total_examples += batch_size
        total_loss += float(tmp_loss) #* batch_size

    return total_loss #/ total_examples


def test(data_loader):
    model.eval()
    device_type = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    total_examples = total_loss = 0
    for batch in data_loader:
        batch = batch.to(device_type)
        # batch_size = batch['user'].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)
        tmp_loss = loss_function(out, batch, False)
        #total_examples += batch_size
        total_loss += float(tmp_loss) #* batch_size

    return total_loss #/ total_examples
#endregion



# Main

In [5]:
K = 3  # number of APs
N = 5  # number of nodes
R = 10  # radius

num_users_features = 3
num_aps_features = 3

num_train = 2  # number of training samples
num_test = 4  # number of test samples

reg = 1e-2
pmax = 1
var_db = 10
var = 1 / 10 ** (var_db / 10)
var_noise = 10e-11

power_threshold = 2.0

X_train, noise_train, pos_train, adj_train, index_train = generate_channels_wsn(K, N, num_train, var_noise, R)
X_test, noise_test, pos_test, adj_test, index_test = generate_channels_wsn(K + 1, N + 10, num_test, var_noise, R)

In [28]:
# Maybe need normalization here
train_data = convert_to_hetero_data(X_train)
test_data = convert_to_hetero_data(X_test)

batchSize = 1

train_loader = DataLoader(train_data, batchSize, shuffle=True, num_workers=1)
test_loader = DataLoader(test_data, batchSize, shuffle=True, num_workers=1)

In [146]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data = train_data[0]
data = data.to(device)

model = HetNetGNN(data, hidden_channels=64, out_channels=4, num_heads=2, num_layers=1)
model = model.to(device)

# # print(data.edge_index_dict)
# with torch.no_grad():
#     output = model(data.x_dict, data.edge_index_dict)
# print(output)
# print(data)

# data = test_data[0]
# data = data.to(device)
#
# with torch.no_grad():
#     output = model(data.x_dict, data.edge_index_dict)
#     print(output)


## Training and testing

In [176]:

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)

for epoch in range(1, 101):
    loss = train(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Loss: {loss:.4f}, Test Reward: {test_acc:.4f}')

<generator object HetNetGNN.forward.<locals>.<genexpr> at 0x000002158C6EA9E0>
<generator object HetNetGNN.forward.<locals>.<genexpr> at 0x000002158EFB7970>
<generator object HetNetGNN.forward.<locals>.<genexpr> at 0x000002158EFB7970>
<generator object HetNetGNN.forward.<locals>.<genexpr> at 0x000002158EFB7C10>
<generator object HetNetGNN.forward.<locals>.<genexpr> at 0x000002158C6EA9E0>
<generator object HetNetGNN.forward.<locals>.<genexpr> at 0x000002158C6EA9E0>
Epoch: 001, Train Loss: -351.9912, Test Reward: 1687.1417
<generator object HetNetGNN.forward.<locals>.<genexpr> at 0x000002158C6EA9E0>
<generator object HetNetGNN.forward.<locals>.<genexpr> at 0x000002158EFB79E0>
<generator object HetNetGNN.forward.<locals>.<genexpr> at 0x000002158EFB79E0>
<generator object HetNetGNN.forward.<locals>.<genexpr> at 0x000002158EFB74A0>
<generator object HetNetGNN.forward.<locals>.<genexpr> at 0x000002158C6EA9E0>
<generator object HetNetGNN.forward.<locals>.<genexpr> at 0x000002158C6EA9E0>
Epoch:

KeyboardInterrupt: 

In [178]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data = train_data[0]
data = data.to(device)

model = HetNetGNN(data, hidden_channels=64, out_channels=4, num_heads=2, num_layers=1)
model = model.to(device)
model.train()
device_type = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
total_examples = total_loss = 0
for batch in train_loader:
    optimizer.zero_grad()
    batch = batch.to(device_type)
    # batch_size = batch['user'].batch_size
    break
print(batch.x_dict)
# out = model(batch.x_dict, batch.edge_index_dict)
# print(out.shape)
# tmp_loss = loss_function(out, data, True)
# print(tmp_loss)
# tmp_loss.backward()
# # Print computation graph for debugging
# print("Computation Graph:")
# for name, param in model.named_parameters():
#     if param.grad is not None:
#         print(name, param.grad.abs().sum())

# # optimizer.step()
# # #total_examples += batch_size
# # total_loss += float(tmp_loss) #* batch_size

{'user': tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]), 'ap': tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])}


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_user = batch['user']['x'].shape[0]
num_ap = batch['ap']['x'].shape[0]
##
channel_matrix = batch['user', 'ap']['edge_attr']
power_max = batch['user']['x'][:, 0]
power = batch['user']['x'][:, 1]
ap_selection = batch['user']['x'][:, 2]
ap_selection = ap_selection.int()
index = torch.arange(num_user)

G = torch.reshape(channel_matrix, (-1, num_ap, num_user))
# P = torch.reshape(power, (-1, num_ap, num_user)) #* p_max
P = torch.zeros_like(G, requires_grad=True).clone()
P[0, ap_selection[index], index] = power_max * power
##
# new_noise = torch.from_numpy(noise_matrix).to(device)
desired_signal = torch.sum(torch.mul(P, G), dim=1).unsqueeze(-1)
G_UE = torch.sum(G, dim=2).unsqueeze(-1)
all_signal = torch.matmul(P.permute((0,2,1)), G_UE)
interference = all_signal - desired_signal #+ new_noise
rate = torch.log(1 + torch.div(desired_signal, interference))
sum_rate = torch.mean(torch.sum(rate, 1))
mean_power = torch.mean(torch.sum(P.permute((0,2,1)), 1))

if is_train:
    return torch.neg(sum_rate / mean_power)
else:
    return sum_rate / mean_power