In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from dgl.data import DGLDataset
import dgl
import dgl.function as fn
import torch.nn as nn
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, Sigmoid, BatchNorm1d as BN, ReLU6 as ReLU6

In [2]:
train_layouts = 100000
test_layouts = 2000

train_n = 70
test_n  = [4,5,6,35]

In [3]:
def data_generate(n_node,K):
    inputs = np.zeros((K,n_node))
    label = np.zeros((K,n_node))
    for i_sample in range(K):
        x = np.random.random(n_node)
        y = []
        for i in range(n_node):
            y_ = x.copy()[i]
            x_copy    = np.delete(x.copy(),i)
            max_index = np.argmax(x_copy)
            x_copy    = np.sort(np.delete(x_copy,max_index))
            for j in range(n_node-2):
                y_ += (x_copy[j]**2)*x_copy[n_node-3-j]
            y.append(y_/n_node)
        y = np.array(y)
        inputs[i_sample] = x
        label[i_sample]  = y
    return inputs,label

In [4]:
class PCDataset(DGLDataset):
    def __init__(self, data, label,n_node):
        self.data = data
        self.n_node = n_node
        self.label = np.expand_dims(label, axis = -1)
        self.get_cg()
        super().__init__(name='power_control')
        
        
    def build_graph(self, idx): 
        H = self.data[idx,:]
        
        graph = dgl.graph(self.adj, num_nodes=self.n_node)
        
        node_features = torch.tensor(np.expand_dims(H,axis=1), dtype = torch.float)
        ## Node feature of the k-th node is the direct link channel of k-th pair
        node_labels = torch.tensor(self.label[idx,:], dtype = torch.float)
        ## Node label is the power obtained by FPlinQ
        
        # edge_features  = []
        # for e in self.adj:
        #     edge_features.append([H[e[0],e[1]],H[e[1],e[0]]])
        ## Edge feature between node e[0] and e[1] is the interference channel between e[0]-th pair and e[1]-th pair

        graph.ndata['feat'] = node_features
        graph.ndata['label'] = node_labels
        # graph.edata['feat'] = torch.tensor(edge_features, dtype = torch.float)
        return graph
    
    def get_cg(self):
        ## The graph is a complete graph
        self.adj = []
        for i in range(0,self.n_node):
            for j in range(0,self.n_node):
                if(not(i==j)):
                    self.adj.append([i,j])
                    
    def __len__(self):
        'Denotes the total number of samples'
        return len(self.data)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        return self.graph_list[index]

    def process(self):
        n = len(self.data)
        self.graph_list = []
        for i in range(n):
            graph = self.build_graph(i)
            self.graph_list.append(graph)

def collate(samples):
    '''DGL collate function'''
    graphs = samples
    batched_graph = dgl.batch(graphs)
    return batched_graph

In [5]:
x_train, y_train = data_generate(train_n, train_layouts)

In [6]:
test_data_list = []
for n_test in test_n:
    x_test, y_test = data_generate(n_test, test_layouts)
    test_data_list.append((x_test, y_test,n_test))

In [7]:
train_data = PCDataset(x_train, y_train, train_n)
test_data_list = [PCDataset(x_test, y_test, n_test) for (x_test, y_test, n_test) in test_data_list]

In [8]:
batch_size = 64
train_loader = DataLoader(train_data, batch_size, shuffle=True, collate_fn=collate)
test_loader_list = [DataLoader(test_data_list[i], test_layouts, shuffle=False, collate_fn=collate) for i in range(len(test_n))]

In [9]:
def MLP(channels, batch_norm=True):
    return Seq(*[
        Seq(Lin(channels[i - 1], channels[i]), ReLU(), BN(channels[i]))
        for i in range(1, len(channels))
    ])
class EdgeConv(nn.Module):
    def __init__(self, mlp, **kwargs):
        super(EdgeConv, self).__init__()
        self.mlp = mlp
        #self.reset_parameters()

    def concat_message_function(self, edges):
        return {'out': torch.cat([edges.src['hid'], edges.dst['hid']], axis=1)}
    
    def forward(self, g):
        g.apply_edges(self.concat_message_function)
        g.edata['out'] = self.mlp(g.edata['out'])
        g.update_all(fn.copy_e('out', 'm'),
                     fn.max('m', 'hid'))

In [10]:
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = EdgeConv(MLP([2, 16]))
        self.conv2 = EdgeConv(MLP([2*16, 32]))
        self.mlp = MLP([32, 16])
        self.mlp = Seq(*[self.mlp, Seq(Lin(16, 1))])

    def forward(self, g):
        g.ndata['hid'] = g.ndata['feat']
        self.conv1(g)
        self.conv2(g)
        out = self.mlp(g.ndata['hid'])
        return out

In [11]:
def train(epoch):
    """ Train for one epoch. """
    model.train()
    loss_all = 0
    for batch_idx, g in enumerate(train_loader):
        #data = data.to(device)
        optimizer.zero_grad()
        g = g.to("cuda:0")
        output = model(g)
        loss = F.mse_loss(output, g.ndata['label'])
        loss.backward()
        loss_all += loss.item() * len(g.ndata['feat'])
        optimizer.step()
    return loss_all / len(train_loader.dataset)

In [12]:
def test(loader, train_K, test_mode = False):
    model.eval()
    mse = nmse = 0
    for i,g in enumerate(loader) :
        n = len(g.ndata['feat'])
        bs = len(g.ndata['feat'])//train_K
        g = g.to("cuda:0")
        output = model(g).reshape(bs,-1)
        y_test = g.ndata['label'].reshape(bs,-1)
        loss = F.mse_loss(output, y_test)
        mse += loss.item() * bs
        # if i==0:
        #     print(output.shape)
        if test_mode:
            nmse += (((output - y_test)**2).sum(axis = -1)/(y_test**2).sum(axis = -1)).sum().item()
    if test_mode:
        return mse / len(loader.dataset), nmse / len(loader.dataset)
    return mse / len(loader.dataset)

In [13]:
model = GCN().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [14]:
for epoch in range(0, 20):
    if(epoch % 1 == 0):
        loss = test(train_loader,train_n)
        print('Epoch {:03d}, Train Loss: {:.4f}'.format(epoch, loss),end=" ")
        print("Val MSE:", end=" ")
        for (loader, n_test) in zip(test_loader_list,test_n):
            mse, nmse = test(loader, n_test, True)
            print('{:.4f}'.format(mse),end=", ")
        print("\n")
    train(epoch)
    #scheduler.step()

Epoch 000, Train Loss: 0.0307 Val MSE: 0.0748, 0.0643, 0.0573, 0.0329, 

Epoch 001, Train Loss: 0.0017 Val MSE: 11.5856, 6.6241, 4.2646, 0.0066, 

Epoch 002, Train Loss: 0.0012 Val MSE: 10.7167, 5.9837, 3.8112, 0.0048, 

Epoch 003, Train Loss: 0.0011 Val MSE: 10.1985, 5.6252, 3.5651, 0.0042, 

Epoch 004, Train Loss: 0.0009 Val MSE: 10.5501, 5.8063, 3.6765, 0.0033, 

Epoch 005, Train Loss: 0.0009 Val MSE: 10.3010, 5.5790, 3.5104, 0.0035, 

Epoch 006, Train Loss: 0.0008 Val MSE: 9.7856, 5.3010, 3.3344, 0.0029, 

Epoch 007, Train Loss: 0.0008 Val MSE: 9.7293, 5.1906, 3.2431, 0.0031, 

Epoch 008, Train Loss: 0.0008 Val MSE: 9.6087, 5.1152, 3.1969, 0.0030, 

