In [8]:
# Star import kept first so the explicit imports below take precedence over any
# same-named objects it might export.
from dataset_copy_optimize_label import *

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from torch_geometric.nn import GCNConv
# Input CSVs handed to process_data: the network file first, then the drug-info file.
NETWORK_CSV = "../data/one_graph/network_file_0self.csv"
DRUGINFO_CSV = "../data/jaaks_druginfo_ttid.csv"

# Seed *before* building the dataset so that any randomness inside process_data
# (e.g. a train/test split, if it performs one -- TODO confirm) is reproducible.
# The later config cell re-seeds before model creation, so weight init is unaffected.
setup_seed(2023)
args = (NETWORK_CSV, DRUGINFO_CSV)
data = process_data(args)

Create Network
--------------

In [2]:
# Seed via the project helper before any model weights are created.
setup_seed(2023)

# GCN layer sizes: node-feature width in, a 128-wide hidden layer,
# and one output unit per column of the label matrix.
input_dim = data.num_node_features
embed_dim = 128
output_dim = data.y.size(1)

In [3]:
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network with per-class sigmoid outputs.

    Architecture: GCNConv -> ReLU -> dropout -> GCNConv -> sigmoid.

    Layer sizes default to the notebook-level ``input_dim``, ``embed_dim`` and
    ``output_dim`` globals, read at instantiation time exactly as before, but can
    now be passed explicitly so the class does not depend on hidden notebook state.

    Args:
        in_channels: Node feature width; defaults to the global ``input_dim``.
        hidden_channels: Hidden embedding width; defaults to the global ``embed_dim``.
        out_channels: Number of outputs per node; defaults to the global ``output_dim``.
        dropout: Dropout probability between the two convolutions (0.5 is the
            ``F.dropout`` default that was used implicitly before).
    """

    def __init__(self, in_channels=None, hidden_channels=None, out_channels=None, dropout=0.5):
        super().__init__()
        # Resolve defaults at call time (not def time) to keep the old behaviour
        # of reading whatever the notebook globals hold when GCN() is called.
        in_channels = input_dim if in_channels is None else in_channels
        hidden_channels = embed_dim if hidden_channels is None else hidden_channels
        out_channels = output_dim if out_channels is None else out_channels

        self.dropout = dropout
        # conv1 is created before conv2, as originally, so seeded weight init
        # and state_dict keys are unchanged.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        """Return per-node, per-class probabilities in [0, 1].

        Args:
            data: Graph object exposing ``x`` (node features) and ``edge_index``.
        """
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is active only in train mode (self.training).
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)

        # Sigmoid output is paired with F.binary_cross_entropy in the training loop.
        # NOTE(review): returning logits with F.binary_cross_entropy_with_logits is
        # more numerically stable, but would change this method's contract.
        return torch.sigmoid(x)

Instantiate model
-----------------

In [4]:
# Prefer the GPU when one is visible; otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

# Model parameters and graph tensors must live on the same device.
model = GCN().to(device)
data = data.to(device)

# Adam with an L2 penalty applied through weight_decay.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=5e-4)

Train model
-----------

In [5]:
# Full-batch training: each step runs the GCN over the whole graph, but the
# loss is computed only on the training nodes.
# NOTE(review): training selects nodes with `train_idx` while evaluation uses
# `test_mask` -- confirm both come from the same disjoint split.
model.train()
train_nodes = data.train_idx
for epoch in range(600):
    optimizer.zero_grad()
    out = model(data)
    loss = F.binary_cross_entropy(out[train_nodes], data.y[train_nodes])
    if not epoch % 10:
        # Log the scalar loss every 10 epochs.
        print(f'Loss:{loss.item()}')
    loss.backward()
    optimizer.step()

Loss:0.7119203209877014
Loss:0.4707825481891632
Loss:0.2671552896499634
Loss:0.1779192090034485
Loss:0.15560877323150635
Loss:0.14707860350608826
Loss:0.14734981954097748
Loss:0.13702495396137238
Loss:0.1361154466867447
Loss:0.13740789890289307
Loss:0.13260656595230103
Loss:0.1307443529367447
Loss:0.12728293240070343
Loss:0.12386992573738098
Loss:0.11757905781269073
Loss:0.12647682428359985
Loss:0.11863357573747635
Loss:0.11520705372095108
Loss:0.11335065960884094
Loss:0.11027797311544418
Loss:0.11165044456720352
Loss:0.10897687077522278
Loss:0.1066279485821724
Loss:0.10652278363704681
Loss:0.1069236621260643
Loss:0.10305405408143997
Loss:0.1011844128370285
Loss:0.10217678546905518
Loss:0.09754178673028946
Loss:0.10168968141078949
Loss:0.09983055293560028
Loss:0.10412181913852692
Loss:0.09785719960927963
Loss:0.09754735976457596
Loss:0.09873894602060318
Loss:0.09363508969545364
Loss:0.09452679753303528
Loss:0.09341519325971603
Loss:0.09407299011945724
Loss:0.09698154032230377
Loss:0.09

Evaluate test-set accuracy
--------------------------

In [6]:
model.eval()
# Inference only: no_grad skips building the autograd graph.
with torch.no_grad():
    scores = model(data)

# .numpy() only works on CPU tensors, and model/data may have been moved to
# CUDA when the device was selected, so copy to the host first.
pred = scores[data.test_mask].cpu().numpy()
y_test = data.y[data.test_mask].cpu().numpy()

# Oracle-k decoding: for each test node, mark as positive the k highest-scoring
# classes, where k is that node's true number of positive labels.
pred_label = np.zeros(y_test.shape)
for i, (node_scores, node_labels) in enumerate(zip(pred, y_test)):
    k = int(node_labels.sum())
    # Guard k == 0: argsort(...)[-0:] is [0:], which would mark *every* class
    # positive for a node that has no positive labels.
    if k > 0:
        pred_label[i, np.argsort(node_scores)[-k:]] = 1

# accuracy_score takes (y_true, y_pred); on multilabel indicator arrays it is
# subset (exact-match) accuracy over the test nodes.
acc = accuracy_score(y_test, pred_label)

In [7]:
print('Accuracy:', acc)  # same text as the f-string form: "Accuracy: <acc>"

Accuracy: 0.6058394160583942
