In [1]:
%load_ext autoreload
%autoreload 2

In [20]:
import torch
import random

import pandas as pd

from torch.utils.data import Dataset
import torch.nn.functional as F

from torch.optim import Adam

import torch_geometric.transforms as T

from torch_geometric.data import Batch

from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn import global_add_pool
from torch_geometric.nn import GraphConv
from torch_geometric.loader import DataLoader

from pathlib import Path

from tqdm import tqdm

In [21]:
import sys
import os
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
sys.path.append(parent_dir)
from Dataset_Creation.script_pairs_creation_with_torch import get_subgraph_with_terminal_nodes, smiles_to_torch_geometric

In [22]:
class ZincSubgraphDataset(Dataset):

    def __init__(self, data_path):
        self.data_list = torch.load(data_path)


    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        
        preprocessed_graph = self.data_list[idx]
        
        mol_size = len(preprocessed_graph.x)
        num_atoms = random.choice(range(3, mol_size + 1))
        subgraph, terminal_nodes, id_map = get_subgraph_with_terminal_nodes(preprocessed_graph, num_atoms)

        subgraph.x[id_map[terminal_nodes[0]]][9] = 1

        #get the embedding of all the first element of terminal_nodes[1] and make them into a list to take the mean, if terminal_nodes[1] empty make torch.zeros(10)
        label_gnn1 = torch.zeros(10)
        neighbor_atom_list = [neighbor[1] for neighbor in terminal_nodes[1]]

        if len(neighbor_atom_list) != 0:
            label_gnn1 += torch.mean(torch.stack(neighbor_atom_list, dim=0), dim=0)

        subgraph.y = label_gnn1

        return subgraph

In [23]:
import os.path as osp
import torch
import random
from torch_geometric.data import Dataset as geomDataset

class ZincSubgraphDataset(geomDataset):
    def __init__(self, root, data_path, transform=None, pre_transform=None):
        self.data_path = data_path
        self.data_list = torch.load(data_path)
        super(ZincSubgraphDataset, self).__init__(root, transform, pre_transform) 

    def len(self):
        return len(self.data_list)
    
    def get(self, idx):
        preprocessed_graph = self.data_list[idx]
        
        mol_size = len(preprocessed_graph.x)
        num_atoms = random.choice(range(3, mol_size + 1))
        subgraph, terminal_nodes, id_map = get_subgraph_with_terminal_nodes(preprocessed_graph, num_atoms)

        subgraph.x[id_map[terminal_nodes[0]]][9] = 1

        label_gnn1 = torch.zeros(10)
        neighbor_atom_list = [neighbor[1] for neighbor in terminal_nodes[1]]

        if len(neighbor_atom_list) != 0:
            label_gnn1 += torch.mean(torch.stack(neighbor_atom_list, dim=0), dim=0)

        subgraph.y = label_gnn1

        return subgraph

In [24]:
datapath = Path('..') / 'Dataset_Creation/preprocessed_graph.pt'
dataset = ZincSubgraphDataset(data_path = datapath, root='fdp')

In [26]:
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=1)

In [27]:
i = 0
for batch in dataloader:
    print(batch)
    break

In [11]:
#dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

from pyinstrument import Profiler

profiler = Profiler()

profiler.start()
i = 0
for batch in dataloader:
    i += 1
    if i == 100:
        break
profiler.stop()

print(profiler.output_text(unicode=True, color=True))


  _     ._   __/__   _ _  _  _ _/_   Recorded: 16:22:57  Samples:  3497
 /_//_/// /_\ / //_// / //_'/ //     Duration: 3.543     CPU time: 4.578
/   _/                      v4.4.0

Program: c:\Users\goupi\.conda\envs\torch_geometric\lib\site-packages\ipykernel_launcher.py --ip=127.0.0.1 --stdin=9019 --control=9017 --hb=9016 --Session.signature_scheme="hmac-sha256" --Session.key=b"915a24f9-0548-49ba-bdae-7ff7a6bee7d0" --shell=9018 --transport="tcp" --iopub=9020 --f=c:\Users\goupi\AppData\Roaming\jupyter\runtime\kernel-v2-6060P19y4jU9Faff.json

[31m3.542[0m [48;5;24m[38;5;15m<module>[0m  [2m..\..\..\AppData\Local\Temp\ipykernel_20896\3692288471.py:1[0m
└─ [31m3.534[0m _SingleProcessDataLoaderIter.__next__[0m  [2mtorch\utils\data\dataloader.py:623[0m
      [175 frames hidden]  [2mtorch, torch_geometric, <built-in>, _...[0m
         [31m3.307[0m ZincSubgraphDataset.__getitem__[0m  [2mtorch_geometric\data\dataset.py:238[0m
         └─ [31m3.286[0m [48;5;24m[38;5;15mZ

In [15]:
class ModelGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels_list, mlp_hidden_channels, num_classes=10):
        super(ModelGCN, self).__init__()
        torch.manual_seed(12345)

        assert len(hidden_channels_list) == 6, "hidden_channels_list must have 6 elements"

        self.conv1 = GCNConv(in_channels, hidden_channels_list[0])
        self.conv2 = GCNConv(hidden_channels_list[0], hidden_channels_list[1])
        self.conv3 = GCNConv(hidden_channels_list[1], hidden_channels_list[2])
        self.conv4 = GCNConv(hidden_channels_list[2], hidden_channels_list[3])
        self.conv5 = GCNConv(hidden_channels_list[3], hidden_channels_list[4])
        self.conv6 = GCNConv(hidden_channels_list[4], hidden_channels_list[5])

        self.fc1 = torch.nn.Linear(hidden_channels_list[5], mlp_hidden_channels)
        self.fc2 = torch.nn.Linear(mlp_hidden_channels, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        x = self.conv4(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        x = self.conv5(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        x = self.conv6(x, edge_index)
        x = F.relu(x)

        # Aggregation function to obtain graph embedding
        x = global_add_pool(x, batch)

        # Two-layer MLP for classification
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

In [17]:
model = ModelGCN(in_channels=10, hidden_channels_list=[16, 32, 64, 128, 256, 512], mlp_hidden_channels=128)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


# Set up the optimizer and loss function
optimizer = Adam(model.parameters(), lr=0.001)
#crossentropy
criterion = torch.nn.CrossEntropyLoss()

# Training function

def train(loader):
    model.train()
    total_loss = 0
    mse_sum = 0
    progress_bar = tqdm(loader, desc="Training", unit="batch")
    i = 0

    for batch in progress_bar:

        i += 1
        data = batch
        print(data)
        terminal_node_infos = data.y
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        terminal_node_infos = terminal_node_infos.to(device)
        loss = criterion(out, terminal_node_infos)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
        loss_value = total_loss / (data.num_graphs * (progress_bar.last_print_n + 1))

        # Compute MSE
        mse = torch.mean((terminal_node_infos.detach().cpu() - out.detach().cpu()) ** 2)
        mse_sum += mse * data.num_graphs
        mse_value = mse_sum / (data.num_graphs * (progress_bar.last_print_n + 1))

        
        progress_bar.set_postfix(loss=loss_value, mse=mse_value)
    return total_loss / len(loader.dataset)
# Train the model
n_epochs = 100
for epoch in range(1, n_epochs+1):
    loss = train(dataloader)
    print(f'Epoch: {epoch}, Loss: {loss:.4f}')

Training:   0%|          | 0/7796 [00:00<?, ?batch/s]

DataBatch(x=[742, 10], edge_index=[2, 1596], edge_attr=[1596, 4], batch=[742], ptr=[33])





AttributeError: 'NoneType' object has no attribute 'to'