# Graph Neural Network

## Importing

In [1]:
from selfdist_toolkit.pyg_tools import gnn_load, GIN_nn, execution, sd_utils
import pandas as pd
import torch
import torch_geometric
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


## Loading aid_list

In [2]:
aid_list = pd.read_csv("results/random_forest/experiments_check/chem-desc_good-aid_1.csv").aid.to_numpy().astype(int)

In [3]:
aid_list

array([    884,     891,     899,     914,    1418,    1431,    1770,
          1771,    1795,  493073,  493102,  493177,  493191,  493240,
        588834,  651741,  651812,  651814,  686978,  687022,  720691,
        743036,  743040,  743065, 1053173, 1259381, 1346982])

In [4]:
# for one aid now
aid = aid_list[0]
random_state = 131313

## Select mode: smooth for 2 dim label and hard for 1-dim label

In [5]:
# mode = "smooth"
mode = "hard"

## Load pytorch data

In [6]:
whole_data = gnn_load.load_pyg_data_aid(aid=aid, label_type=mode, do_in_parallel=True)

In [7]:
whole_data[:4]

[Data(x=[19, 9], edge_index=[2, 42], edge_attr=[42, 3], smiles='CC1(C=CC2=C(O1)C3=C(C=CC(=C3)OC)NC2=O)C', y=[1]),
 Data(x=[10, 9], edge_index=[2, 20], edge_attr=[20, 3], smiles='C1CN=C(N1)SCC(=O)O', y=[1]),
 Data(x=[29, 9], edge_index=[2, 64], edge_attr=[64, 3], smiles='CC(=O)OCC(=O)[C@]1(CC[C@@H]2[C@@]1(CC(=O)[C@H]3[C@H]2CCC4=CC(=O)CC[C@]34C)C)O', y=[1]),
 Data(x=[25, 9], edge_index=[2, 56], edge_attr=[56, 3], smiles='C[C@]12CCC(=O)C=C1CC[C@@H]3[C@@H]2CC[C@]4([C@H]3CC[C@@]4(C(=O)CO)O)C', y=[1])]

In [8]:
# build labels for split
labels_hard = np.array([
    data.y.detach().cpu().numpy()
    for data in whole_data
]).flatten().astype(int)

In [9]:
labels_hard.shape

(9593,)

In [10]:
labels_hard.sum()

3274

In [11]:
# data splitting
for train_idx, test_idx in StratifiedShuffleSplit(n_splits=1, random_state=random_state, test_size=0.2).split(whole_data, labels_hard):
    break

In [12]:
labels_hard[train_idx].sum()/len(train_idx)

0.34128225175918686

In [13]:
labels_hard[test_idx].sum()/len(test_idx)

0.341323606044815

In [14]:
# create the data loader
dl_train = torch_geometric.loader.DataLoader([whole_data[idx] for idx in train_idx], batch_size=100)
dl_test = torch_geometric.loader.DataLoader([whole_data[idx] for idx in test_idx], batch_size=100)

## Instantiate model

In [15]:
# GNN model
model = GIN_nn.GIN_basic(1)

In [16]:
# loss
loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([len(train_idx)/labels_hard[train_idx].sum()]))

In [17]:
# device
device = torch.device('cuda')

In [18]:
device

device(type='cuda')

In [19]:
model = model.to(device)
loss = loss.to(device)

In [20]:
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

## Train model

In [21]:
# training:
execution.training(model, dl_train, device, optimizer, loss)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 77/77 [00:03<00:00, 20.75it/s]


0.8897513

## Testing model

In [22]:
# testing:
pred_hard = execution.predict(model, dl_test, device, reduce_to_hard_label=True)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 73.39it/s]


In [23]:
pred_hard.sum()/pred_hard.shape[0]

0.39030745179781134

In [24]:
for batch in dl_train:
    break

In [25]:
batch

DataBatch(x=[2366, 9], edge_index=[2, 5100], edge_attr=[5100, 3], smiles=[100], y=[100], batch=[2366], ptr=[101])

In [26]:
batch = batch.to(device)

In [27]:
pred = model(batch).flatten()

In [28]:
loss.to(device)

BCEWithLogitsLoss()

In [29]:
loss(pred, batch.y)

tensor(0.9469, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

## Self distillation elements

In [30]:
# define number_to_pick
number_to_pick = 100

In [31]:
# load data for self distillation
selfdistillation_data = sd_utils.generate_gnn_sd_data(aid=aid)

In [32]:
# get good elements to insert into training pool for student model
selected, remaining = execution.self_distillation_procedure_1dim(
    model=model, 
    self_distillation_data=selfdistillation_data, 
    number_to_pick=number_to_pick, 
    device=device, 
    correct_label=False, 
    batch_size=100
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4455/4455 [01:05<00:00, 68.17it/s]


In [33]:
selected

[Data(x=[9, 9], edge_index=[2, 16], edge_attr=[16, 3], smiles='C(C(=O)C(=O)O)C(=O)O', y=[1]),
 Data(x=[10, 9], edge_index=[2, 18], edge_attr=[18, 3], smiles='C(C(C(=O)O)S)(C(=O)O)S', y=[1]),
 Data(x=[6, 9], edge_index=[2, 10], edge_attr=[10, 3], smiles='CC=CC(=O)O', y=[1]),
 Data(x=[6, 9], edge_index=[2, 10], edge_attr=[10, 3], smiles='C(=NO)C=NO', y=[1]),
 Data(x=[10, 9], edge_index=[2, 18], edge_attr=[18, 3], smiles='C(C(C(=O)O)O)(C(=O)O)O', y=[1]),
 Data(x=[9, 9], edge_index=[2, 16], edge_attr=[16, 3], smiles='C([C@H](C(=O)O)O)C(=O)O', y=[1]),
 Data(x=[4, 9], edge_index=[2, 6], edge_attr=[6, 3], smiles='C(=O)(N)O', y=[1]),
 Data(x=[6, 9], edge_index=[2, 10], edge_attr=[10, 3], smiles='C(=O)NNC=O', y=[1]),
 Data(x=[9, 9], edge_index=[2, 16], edge_attr=[16, 3], smiles='C(C(C(=O)O)O)C(=O)O', y=[1]),
 Data(x=[8, 9], edge_index=[2, 14], edge_attr=[14, 3], smiles='CC(C(=O)O)C(=O)O', y=[1]),
 Data(x=[12, 9], edge_index=[2, 22], edge_attr=[22, 3], smiles='C(/C(=C/C(=O)O)/C(=O)O)C(=O)O', y=[

In [34]:
label_analysis = np.array([data.y.numpy() for data in selected])

In [35]:
label_analysis.sum()/label_analysis.shape[0]

0.010000615119934083

In [36]:
# todo: ausbauen auf viele epochen und student - teacher übergang. außerdem begutachten ob positive oder negative label bevorzugt werden sollten