# Graph Neural Network

## Importing

In [1]:
from selfdist_toolkit.pyg_tools import gnn_load, GIN_nn, execution, sd_utils, accuracy
import pandas as pd
import torch
import torch_geometric
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


## Loading aid_list

In [2]:
aid_list = pd.read_csv("results/random_forest/experiments_check/chem-desc_good-aid_1.csv").aid.to_numpy().astype(int)

In [3]:
aid_list

array([    884,     891,     899,     914,    1418,    1431,    1770,
          1771,    1795,  493073,  493102,  493177,  493191,  493240,
        588834,  651741,  651812,  651814,  686978,  687022,  720691,
        743036,  743040,  743065, 1053173, 1259381, 1346982])

In [4]:
# for one aid now
aid = aid_list[0]
random_state = 131313

## Select mode: smooth for 2 dim label and hard for 1-dim label

In [5]:
# mode = "smooth"
mode = "hard"

## Load pytorch data

In [6]:
whole_data = gnn_load.load_pyg_data_aid(aid=aid, label_type=mode, do_in_parallel=True)

In [7]:
whole_data[:4]

[Data(x=[19, 9], edge_index=[2, 42], edge_attr=[42, 3], smiles='CC1(C=CC2=C(O1)C3=C(C=CC(=C3)OC)NC2=O)C', y=[1]),
 Data(x=[10, 9], edge_index=[2, 20], edge_attr=[20, 3], smiles='C1CN=C(N1)SCC(=O)O', y=[1]),
 Data(x=[29, 9], edge_index=[2, 64], edge_attr=[64, 3], smiles='CC(=O)OCC(=O)[C@]1(CC[C@@H]2[C@@]1(CC(=O)[C@H]3[C@H]2CCC4=CC(=O)CC[C@]34C)C)O', y=[1]),
 Data(x=[25, 9], edge_index=[2, 56], edge_attr=[56, 3], smiles='C[C@]12CCC(=O)C=C1CC[C@@H]3[C@@H]2CC[C@]4([C@H]3CC[C@@]4(C(=O)CO)O)C', y=[1])]

In [8]:
# build labels for split
labels_hard = np.array([
    data.y.detach().cpu().numpy()
    for data in whole_data
]).flatten().astype(int)

In [9]:
labels_hard.shape

(9593,)

In [10]:
labels_hard.sum()

3274

In [11]:
# data splitting
for train_idx, test_idx in StratifiedShuffleSplit(n_splits=1, random_state=random_state, test_size=0.2).split(whole_data, labels_hard):
    break

In [12]:
labels_hard[train_idx].sum()/len(train_idx)

0.34128225175918686

In [13]:
labels_hard[test_idx].sum()/len(test_idx)

0.341323606044815

In [14]:
# create the data loader
dl_train = torch_geometric.loader.DataLoader([whole_data[idx] for idx in train_idx], batch_size=100)
dl_test = torch_geometric.loader.DataLoader([whole_data[idx] for idx in test_idx], batch_size=100)

## Instantiate model

In [15]:
# GNN model
model = GIN_nn.GIN_basic(1)

In [16]:
# loss
loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([len(train_idx)/labels_hard[train_idx].sum()]))

In [17]:
# device
device = torch.device('cuda')

In [18]:
device

device(type='cuda')

In [19]:
model = model.to(device)
loss = loss.to(device)

In [20]:
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

## Train model

In [71]:
# training:
execution.training(model, dl_train, device, optimizer, loss)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 77/77 [00:03<00:00, 22.52it/s]


0.6475654

## Testing model

In [72]:
# testing:
pred_hard = execution.predict(model, dl_test, device, reduce_to_hard_label=True)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 36.75it/s]


In [73]:
pred_hard.sum()/pred_hard.shape[0]

0.43772798332464824

In [58]:
for batch in dl_train:
    break

In [59]:
batch

DataBatch(x=[2366, 9], edge_index=[2, 5100], edge_attr=[5100, 3], smiles=[100], y=[100], batch=[2366], ptr=[101])

In [60]:
batch = batch.to(device)

In [61]:
pred = model(batch).flatten()

In [62]:
loss.to(device)

BCEWithLogitsLoss()

In [63]:
loss(pred, batch.y)

tensor(0.7008, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

In [64]:
pred_hard[:5]

array([0, 1, 1, 0, 1])

In [74]:
# true labels
pred_hard_true = np.array([data.y.numpy().astype(int) for data in [whole_data[idx] for idx in test_idx]]).flatten()

In [75]:
pred_hard_true[:5]

array([0, 1, 0, 0, 1])

In [76]:
# get accuracy scores
acc = accuracy.calculate_accuracies_1d(y_true=pred_hard_true, y_pred=pred_hard)

In [77]:
acc.to_dict()

{'accuracy_score': 0.7972902553413236,
 'balanced_accuracy': 0.8086089235674945,
 'roc_score': 0.8086089235674945,
 'precision': 0.8211151214872718,
 'recall': 0.7972902553413236}

## Self distillation elements

In [35]:
# define number_to_pick
number_to_pick = 100

In [36]:
# load data for self distillation
selfdistillation_data = sd_utils.generate_gnn_sd_data(aid=aid)

In [78]:
# get good elements to insert into training pool for student model
selected, remaining = sd_utils.self_distillation_procedure_1dim(
    model=model, 
    self_distillation_data=selfdistillation_data, 
    number_to_pick=number_to_pick, 
    device=device, 
    correct_label=False, 
    batch_size=100
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4455/4455 [01:02<00:00, 71.14it/s]


In [79]:
selected[:5]

[Data(x=[5, 9], edge_index=[2, 8], edge_attr=[8, 3], smiles='OS(=O)(=O)O', y=[1]),
 Data(x=[11, 9], edge_index=[2, 20], edge_attr=[20, 3], smiles='C(P(=O)(O)O)(P(=O)(O)O)(Cl)Cl', y=[1]),
 Data(x=[4, 9], edge_index=[2, 6], edge_attr=[6, 3], smiles='Cl[Sb](Cl)Cl', y=[1]),
 Data(x=[9, 9], edge_index=[2, 18], edge_attr=[18, 3], smiles='C1(=O)NC(=O)NC(=O)N1', y=[1]),
 Data(x=[5, 9], edge_index=[2, 8], edge_attr=[8, 3], smiles='CS(=O)(=O)O', y=[1])]

In [80]:
label_analysis = np.array([data.y.numpy() for data in selected])

In [81]:
label_analysis.sum()/label_analysis.shape[0]

1.4962436580390205e-07

In [82]:
selected2, remaining2 = sd_utils.self_distillation_procedure_1dim_advanced(
    model=model, 
    self_distillation_data=selfdistillation_data, 
    number_to_pick=number_to_pick, 
    pos_elem_to_pick_ratio=0.3,
    device=device, 
    correct_label=False, 
    batch_size=100
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4455/4455 [01:00<00:00, 73.64it/s]


In [83]:
label_analysis2 = np.array([data.y.numpy() for data in selected2])

In [84]:
label_analysis2.sum()/label_analysis2.shape[0]

0.29977272033691404

In [None]:
# todo: ausbauen auf viele epochen und student - teacher übergang. außerdem begutachten ob positive oder negative label bevorzugt werden sollten