In [None]:
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-2.1.0+cu118.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-2.1.0+cu118.html
!pip install torch-geometric
!pip install ogb

In [2]:
import torch_geometric.transforms as T
from ogb.graphproppred import PygGraphPropPredDataset
import torch
dataset_name = 'ogbg-molhiv'
# Load the dataset and transform it to sparse tensor
dataset = PygGraphPropPredDataset(name=dataset_name,
                                transform=T.ToSparseTensor())
print('The {} dataset has {} graph'.format(dataset_name, len(dataset)))

# Extract the graph
print(dataset[0])
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# If you use GPU, the device should be cuda
print('Device: {}'.format(device))

Downloading http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/hiv.zip


Downloaded 0.00 GB: 100%|██████████| 3/3 [00:02<00:00,  1.05it/s]
Processing...


Extracting dataset/hiv.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 41127/41127 [00:02<00:00, 20005.38it/s]


Converting graphs into PyG objects...


100%|██████████| 41127/41127 [00:04<00:00, 9247.06it/s]


Saving...
The ogbg-molhiv dataset has 41127 graph
Data(edge_attr=[40, 3], x=[19, 9], y=[1, 1], num_nodes=19, adj_t=[19, 19, nnz=40])
Device: cuda


Done!


In [33]:
class GCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers,
                 dropout):
        super(GCN, self).__init__()

        # A list of GCNConv layers
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(input_dim, hidden_dim))
        for _ in range(num_layers -2):
          self.convs.append(GCNConv(hidden_dim, hidden_dim))
        self.convs.append(GCNConv(hidden_dim, output_dim))

        # A list of 1D batch normalization layers
        self.bns = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
          self.bns.append(torch.nn.BatchNorm1d(hidden_dim))
        self.dropout = dropout

        self.linear = torch.nn.Linear(output_dim, 1)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t, batch):

        out = None

        for i in range(len(self.convs) - 1):
          x = self.convs[i](x, adj_t)
          x = self.bns[i](x)
          x = F.relu(x)
          x = F.dropout(x, self.dropout, training=self.training)
        x = self.convs[len(self.convs) - 1](x, adj_t)
        x = global_mean_pool(x, batch)

        x = self.linear(x)

        out = x

        return out

In [34]:
import torch
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.datasets import OGB_MAG
import numpy as np
from torch_geometric.nn.pool import global_mean_pool
from tqdm import tqdm

split_idx = dataset.get_idx_split()
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=64, shuffle=True)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=64, shuffle=False)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=64, shuffle=False)

# Define the GCN model
# class GCNModel(torch.nn.Module):
#     def __init__(self, in_channels, hidden_channels, out_channels):
#         super(GCNModel, self).__init__()
#         self.conv1 = GCNConv(in_channels, hidden_channels)
#         self.conv2 = GCNConv(hidden_channels, out_channels)
#         self.linear = Linear(out_channels, 1)
#         self.sigmoid = torch.nn.Sigmoid()

#     def forward(self, x, edge_index, batch):
#         x = self.conv1(x, edge_index)
#         x = F.relu(x)
#         x = self.conv2(x, edge_index)
#         x = F.relu(x)
#         x = global_add_pool(x, batch)  # Global pooling operation (sum in this case)
#         x = self.linear(x)
#         x = self.sigmoid(x)
#         return x

# model = GCNModel(in_channels=dataset.num_node_features,
#                  hidden_channels=64,
#                  out_channels=64).to(device)
args = {
      'device': device,
      'num_layers': 5,
      'hidden_dim': 256,
      'dropout': 0.5,
      'lr': 0.01,
      'epochs': 50,
      "out_channels": 128,
  }

model = GCN(dataset.num_node_features, args['hidden_dim'],
            args["out_channels"], args['num_layers'],
            args['dropout']).to(device)
# evaluator = Evaluator(name='ogbg-molhiv')
# Loss and optimizer
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])

# Training loop
def train():
  model.train()
  for epoch in range(args["epochs"]):
      train_loss = []
      for data in tqdm(train_loader, total=len(train_loader)):
          optimizer.zero_grad()
          data = data.to(device)
          data.adj_t = data.adj_t.to_symmetric()
          data.x = data.x.float()
          out = model(data.x, data.adj_t, data.batch)
          loss = criterion(out, data.y.view(-1, 1).to(torch.float32))
          loss.backward()
          optimizer.step()
          train_loss.append(loss.item())

      print(f"Epoch #{epoch + 1}. Loss:{sum(train_loss)/len(train_loss)}")




In [35]:
train()

100%|██████████| 515/515 [00:20<00:00, 25.32it/s]


Epoch #1. Loss:0.16835968867215573


100%|██████████| 515/515 [00:19<00:00, 26.06it/s]


Epoch #2. Loss:0.1575283563976149


100%|██████████| 515/515 [00:20<00:00, 25.62it/s]


Epoch #3. Loss:0.15420119037978278


100%|██████████| 515/515 [00:20<00:00, 25.69it/s]


Epoch #4. Loss:0.15357839612706195


100%|██████████| 515/515 [00:19<00:00, 25.87it/s]


Epoch #5. Loss:0.15068097683121856


100%|██████████| 515/515 [00:19<00:00, 26.54it/s]


Epoch #6. Loss:0.14938519436829878


100%|██████████| 515/515 [00:19<00:00, 26.07it/s]


Epoch #7. Loss:0.1494376913219401


100%|██████████| 515/515 [00:19<00:00, 26.21it/s]


Epoch #8. Loss:0.14794742027749713


100%|██████████| 515/515 [00:19<00:00, 26.00it/s]


Epoch #9. Loss:0.14733127321216088


100%|██████████| 515/515 [00:19<00:00, 25.76it/s]


Epoch #10. Loss:0.14585479467528537


100%|██████████| 515/515 [00:19<00:00, 27.03it/s]


Epoch #11. Loss:0.14459006704245378


100%|██████████| 515/515 [00:20<00:00, 25.35it/s]


Epoch #12. Loss:0.14411384937470978


100%|██████████| 515/515 [00:20<00:00, 25.69it/s]


Epoch #13. Loss:0.14337992800887928


100%|██████████| 515/515 [00:20<00:00, 25.62it/s]


Epoch #14. Loss:0.1433820041131626


100%|██████████| 515/515 [00:20<00:00, 25.25it/s]


Epoch #15. Loss:0.14302920326590537


100%|██████████| 515/515 [00:20<00:00, 25.49it/s]


Epoch #16. Loss:0.14160348629597033


100%|██████████| 515/515 [00:20<00:00, 25.29it/s]


Epoch #17. Loss:0.1427118523643144


100%|██████████| 515/515 [00:19<00:00, 26.00it/s]


Epoch #18. Loss:0.140979259265858


100%|██████████| 515/515 [00:19<00:00, 26.21it/s]


Epoch #19. Loss:0.1397275684383309


100%|██████████| 515/515 [00:19<00:00, 25.92it/s]


Epoch #20. Loss:0.1390751602652582


100%|██████████| 515/515 [00:20<00:00, 25.69it/s]


Epoch #21. Loss:0.13888396271007153


100%|██████████| 515/515 [00:19<00:00, 26.00it/s]


Epoch #22. Loss:0.14243143180982


100%|██████████| 515/515 [00:21<00:00, 23.70it/s]


Epoch #23. Loss:0.13867002385261568


100%|██████████| 515/515 [00:20<00:00, 25.02it/s]


Epoch #24. Loss:0.13643974880281004


100%|██████████| 515/515 [00:19<00:00, 26.56it/s]


Epoch #25. Loss:0.13797422500634657


100%|██████████| 515/515 [00:20<00:00, 25.35it/s]


Epoch #26. Loss:0.13724316035225553


100%|██████████| 515/515 [00:20<00:00, 25.52it/s]


Epoch #27. Loss:0.13614902920586971


100%|██████████| 515/515 [00:19<00:00, 26.09it/s]


Epoch #28. Loss:0.1372689637454968


100%|██████████| 515/515 [00:19<00:00, 25.84it/s]


Epoch #29. Loss:0.13713945907007144


100%|██████████| 515/515 [00:19<00:00, 25.95it/s]


Epoch #30. Loss:0.13465549430439194


100%|██████████| 515/515 [00:19<00:00, 26.22it/s]


Epoch #31. Loss:0.13548332924583872


100%|██████████| 515/515 [00:20<00:00, 25.53it/s]


Epoch #32. Loss:0.13408766666662345


100%|██████████| 515/515 [00:18<00:00, 27.25it/s]


Epoch #33. Loss:0.1340686790211108


100%|██████████| 515/515 [00:19<00:00, 26.89it/s]


Epoch #34. Loss:0.13449123681351108


100%|██████████| 515/515 [00:19<00:00, 26.53it/s]


Epoch #35. Loss:0.1348413998568521


100%|██████████| 515/515 [00:18<00:00, 27.62it/s]


Epoch #36. Loss:0.13534355958905614


100%|██████████| 515/515 [00:19<00:00, 26.25it/s]


Epoch #37. Loss:0.13407988651718908


100%|██████████| 515/515 [00:18<00:00, 27.70it/s]


Epoch #38. Loss:0.13316822811888837


100%|██████████| 515/515 [00:19<00:00, 26.62it/s]


Epoch #39. Loss:0.1333301894981595


100%|██████████| 515/515 [00:19<00:00, 27.00it/s]


Epoch #40. Loss:0.13379097670967718


100%|██████████| 515/515 [00:19<00:00, 26.55it/s]


Epoch #41. Loss:0.13402390969031067


100%|██████████| 515/515 [00:19<00:00, 26.19it/s]


Epoch #42. Loss:0.13390067190799898


100%|██████████| 515/515 [00:18<00:00, 27.19it/s]


Epoch #43. Loss:0.13243675884763592


100%|██████████| 515/515 [00:19<00:00, 26.25it/s]


Epoch #44. Loss:0.13381724636097556


100%|██████████| 515/515 [00:18<00:00, 27.67it/s]


Epoch #45. Loss:0.1318578983394845


100%|██████████| 515/515 [00:19<00:00, 26.26it/s]


Epoch #46. Loss:0.13218484734845104


100%|██████████| 515/515 [00:19<00:00, 26.49it/s]


Epoch #47. Loss:0.13178720999038915


100%|██████████| 515/515 [00:18<00:00, 27.70it/s]


Epoch #48. Loss:0.13215007878550628


100%|██████████| 515/515 [00:19<00:00, 26.49it/s]


Epoch #49. Loss:0.13185780259590704


100%|██████████| 515/515 [00:18<00:00, 27.65it/s]


Epoch #50. Loss:0.1316569222155416


100%|██████████| 515/515 [00:20<00:00, 25.72it/s]


Epoch #51. Loss:0.13189710203232696


100%|██████████| 515/515 [00:19<00:00, 26.92it/s]


Epoch #52. Loss:0.13129919654536015


100%|██████████| 515/515 [00:19<00:00, 26.31it/s]


Epoch #53. Loss:0.13125096568204825


100%|██████████| 515/515 [00:19<00:00, 26.35it/s]


Epoch #54. Loss:0.13100889432198792


100%|██████████| 515/515 [00:19<00:00, 26.90it/s]


Epoch #55. Loss:0.13246402627996448


100%|██████████| 515/515 [00:19<00:00, 26.84it/s]


Epoch #56. Loss:0.13054570040656524


100%|██████████| 515/515 [00:18<00:00, 27.39it/s]


Epoch #57. Loss:0.1301269798530537


100%|██████████| 515/515 [00:19<00:00, 26.22it/s]


Epoch #58. Loss:0.1301761113215708


100%|██████████| 515/515 [00:19<00:00, 26.53it/s]


Epoch #59. Loss:0.12980358979474863


100%|██████████| 515/515 [00:19<00:00, 26.97it/s]


Epoch #60. Loss:0.13116266848609864


100%|██████████| 515/515 [00:19<00:00, 26.56it/s]


Epoch #61. Loss:0.13015427439319857


100%|██████████| 515/515 [00:18<00:00, 27.36it/s]


Epoch #62. Loss:0.12953638474194748


100%|██████████| 515/515 [00:19<00:00, 25.89it/s]


Epoch #63. Loss:0.1305488144707622


100%|██████████| 515/515 [00:19<00:00, 25.75it/s]


Epoch #64. Loss:0.13023425561638133


100%|██████████| 515/515 [00:18<00:00, 27.39it/s]


Epoch #65. Loss:0.13047486305309153


100%|██████████| 515/515 [00:19<00:00, 26.77it/s]


Epoch #66. Loss:0.1293928845848852


100%|██████████| 515/515 [00:19<00:00, 27.09it/s]


Epoch #67. Loss:0.13011891986198218


100%|██████████| 515/515 [00:19<00:00, 26.76it/s]


Epoch #68. Loss:0.1284836667363794


100%|██████████| 515/515 [00:19<00:00, 26.53it/s]


Epoch #69. Loss:0.1294177018576166


100%|██████████| 515/515 [00:18<00:00, 27.67it/s]


Epoch #70. Loss:0.12895445773757777


100%|██████████| 515/515 [00:19<00:00, 26.92it/s]


Epoch #71. Loss:0.12873392074721532


100%|██████████| 515/515 [00:19<00:00, 27.07it/s]


Epoch #72. Loss:0.12945422799624864


100%|██████████| 515/515 [00:19<00:00, 26.06it/s]


Epoch #73. Loss:0.12656444669927208


100%|██████████| 515/515 [00:19<00:00, 26.65it/s]


Epoch #74. Loss:0.12821203915818224


100%|██████████| 515/515 [00:18<00:00, 27.17it/s]


Epoch #75. Loss:0.12890594054411336


100%|██████████| 515/515 [00:19<00:00, 26.38it/s]


Epoch #76. Loss:0.12752900194485212


100%|██████████| 515/515 [00:18<00:00, 27.59it/s]


Epoch #77. Loss:0.12886881298448855


100%|██████████| 515/515 [00:19<00:00, 26.48it/s]


Epoch #78. Loss:0.12744790784725288


100%|██████████| 515/515 [00:18<00:00, 27.26it/s]


Epoch #79. Loss:0.12786139718537193


100%|██████████| 515/515 [00:19<00:00, 26.63it/s]


Epoch #80. Loss:0.12867951837268848


100%|██████████| 515/515 [00:19<00:00, 26.50it/s]


Epoch #81. Loss:0.12796434079414432


100%|██████████| 515/515 [00:19<00:00, 26.90it/s]


Epoch #82. Loss:0.1274068976454075


100%|██████████| 515/515 [00:19<00:00, 26.62it/s]


Epoch #83. Loss:0.1273817383367749


100%|██████████| 515/515 [00:18<00:00, 27.49it/s]


Epoch #84. Loss:0.1286149532000706


100%|██████████| 515/515 [00:19<00:00, 26.20it/s]


Epoch #85. Loss:0.1275926939745262


100%|██████████| 515/515 [00:18<00:00, 27.37it/s]


Epoch #86. Loss:0.12703833896124248


100%|██████████| 515/515 [00:19<00:00, 26.27it/s]


Epoch #87. Loss:0.12756840226792016


100%|██████████| 515/515 [00:19<00:00, 26.72it/s]


Epoch #88. Loss:0.12806099535336773


100%|██████████| 515/515 [00:19<00:00, 27.02it/s]


Epoch #89. Loss:0.12923624527468844


100%|██████████| 515/515 [00:20<00:00, 25.72it/s]


Epoch #90. Loss:0.12634014516200834


100%|██████████| 515/515 [00:18<00:00, 27.15it/s]


Epoch #91. Loss:0.1269565778739244


100%|██████████| 515/515 [00:19<00:00, 25.77it/s]


Epoch #92. Loss:0.12790496219130396


100%|██████████| 515/515 [00:19<00:00, 25.93it/s]


Epoch #93. Loss:0.1291959153651034


100%|██████████| 515/515 [00:19<00:00, 26.30it/s]


Epoch #94. Loss:0.12699606327466595


100%|██████████| 515/515 [00:20<00:00, 25.71it/s]


Epoch #95. Loss:0.12638294635684166


100%|██████████| 515/515 [00:20<00:00, 25.53it/s]


Epoch #96. Loss:0.1261535531904512


100%|██████████| 515/515 [00:19<00:00, 27.04it/s]


Epoch #97. Loss:0.12707570465733703


100%|██████████| 515/515 [00:19<00:00, 26.16it/s]


Epoch #98. Loss:0.12533599470642584


100%|██████████| 515/515 [00:19<00:00, 26.36it/s]


Epoch #99. Loss:0.12485843604629479


100%|██████████| 515/515 [00:19<00:00, 26.14it/s]

Epoch #100. Loss:0.12656981169417936





In [32]:
# Evaluation
model.eval()
with torch.no_grad():
    y_true = []
    y_pred = []
    for data in valid_loader:
        data = data.to(device)
        data.x = data.x.float()
        data.adj_t = data.adj_t.to_symmetric()
        out = model(data.x, data.adj_t, data.batch)
        y_true.append(data.y.view(-1).cpu().numpy())
        y_pred.append((out > 0.5).view(-1).cpu().numpy())

y_true = np.concatenate(y_true)
y_pred = np.concatenate(y_pred)

# Evaluate using appropriate metrics (e.g., accuracy, F1 score, ROC AUC)
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
print(f"F1 Score: {f1_score(y_true, y_pred)}")
print(f"ROC AUC Score: {roc_auc_score(y_true, y_pred)}")

Accuracy: 0.9807926088013615
F1 Score: 0.20202020202020202
ROC AUC Score: 0.560736331569665


In [None]:
print(list(y_true).count(0))
print(list(y_true).count(1))

4032
81


In [25]:
for i in range(len(y_true)):
  if y_pred[i] == True:
    print(y_true[i], y_pred[i])

1 True
0 True
0 True
1 True
1 True
1 True
1 True
1 True
1 True
0 True
1 True
1 True
