In [2]:
import pandas as pd
import numpy as np
import random
from rdkit import Chem
from snn_model import SNNet, device, train_snn, test_snn, get_loss_fn, num_hidden
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
from snntorch import spikegen, surrogate
import matplotlib.pyplot as plt
from utils import load_dataset_df, fp_generator
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score, f1_score, precision_score
from sklearn.utils.class_weight import compute_class_weight

import os
import tempfile
from ray import train, tune
from ray.tune.schedulers import ASHAScheduler
from ray.train import Checkpoint, report

#### Load DataFrame

In [5]:
# Select the dataset file and load it together with its list of target columns.
files = ['tox21.csv','sider.csv', 'BBBP.csv']
dt_file = files[2]

df, targets = load_dataset_df(filename=dt_file)
print(targets)

# Keep only the first target column and the SMILES strings; drop rows missing either.
target_name = targets[0]
df = df[[target_name, 'smiles']].dropna()
# NOTE: DataFrame.size counts cells (rows x columns), not rows.
print(df.size)
# drop_duplicates() returns a new frame; the result must be assigned back,
# otherwise the duplicates stay in `df` and the two printed sizes are always equal.
df = df.drop_duplicates()
print(df.size)

['p_np']
4100
4100


#### SMILES to Fingerprint

In [3]:
# Convert every SMILES string in `df` into a fixed-length fingerprint vector.
fp_types = [['morgan', 1024], ['maccs', 167], ['RDKit', 1024]]
fp_type, num_bits = fp_types[1]
print(fp_type, '-', num_bits)
num_rows = len(df)
# Pre-allocate for the case where every SMILES parses; trimmed after the loop.
fp_array = np.zeros((num_rows, num_bits))
target_array = np.zeros((num_rows, 1))
i = 0  # number of molecules converted so far (also the next free row)

img = None
# Smile to Fingerprint of size {num_bits}
fp_gen = fp_generator(fp_type)
for idx, row in df.iterrows():
    mol = Chem.MolFromSmiles(row['smiles'])
    #TODO: sanitize molecules to remove the warnings (?)

    # Rows whose SMILES could not be turned into a molecule are skipped.
    if mol is not None:
        fingerprint = fp_gen(mol)

        fp_array[i] = np.array(fingerprint)
        target_array[i] = row[target_name]
        i += 1

# Skipped molecules leave unused rows at the end of both arrays. Without this
# trim they would enter the dataset as all-zero fingerprints labelled 0.
fp_array = fp_array[:i]
target_array = target_array[:i].ravel()

maccs - 167


[19:14:49] Explicit valence for atom # 1 N, 4, is greater than permitted
[19:14:49] Explicit valence for atom # 6 N, 4, is greater than permitted
[19:14:50] Explicit valence for atom # 6 N, 4, is greater than permitted
[19:14:50] Explicit valence for atom # 11 N, 4, is greater than permitted
[19:14:51] Explicit valence for atom # 12 N, 4, is greater than permitted
[19:14:51] Explicit valence for atom # 5 N, 4, is greater than permitted
[19:14:51] Explicit valence for atom # 5 N, 4, is greater than permitted
[19:14:51] Explicit valence for atom # 5 N, 4, is greater than permitted
[19:14:51] Explicit valence for atom # 5 N, 4, is greater than permitted
[19:14:51] Explicit valence for atom # 5 N, 4, is greater than permitted
[19:14:51] Explicit valence for atom # 5 N, 4, is greater than permitted


In [4]:
# Wrap the fingerprint matrix and the label vector in a TensorDataset.
dtype = torch.float32
# Features are stored as float32 tensors.
fp_tensor = torch.from_numpy(fp_array).to(dtype)
# Labels go through float32 first and are then cast to int64.
target_tensor = torch.from_numpy(target_array).to(dtype).long()

dataset = TensorDataset(fp_tensor, target_tensor)

#### Train/Validation/Test Split

In [None]:
# Fixed-seed generator so the 80/10/10 split is the same on every run.
generator = torch.Generator().manual_seed(1)
# NOTE: `train` rebinds the name imported earlier via `from ray import train`.
train, val, test = random_split(
    dataset,
    [0.8, 0.1, 0.1],  # try 0.9, 0.1
    generator=generator,
)

# Materialise each subset once and keep only its label tensor.
(_, train_label), (_, val_label), (_, test_label) = (
    subset[:] for subset in (train, val, test)
)


#### Loss Function

In [None]:
# Choose the loss variant and, optionally, per-class weights for it.
loss_types = ['cross_entropy', 'rate_loss', 'temporal_loss']
loss_type = loss_types[1]
print(loss_type)

use_weights = True

#Loss Function
if not use_weights:
    class_weights = None
    weighted = ''
else:
    # "balanced" weights computed from the training labels only.
    class_weights = torch.tensor(
        compute_class_weight(
            class_weight='balanced',
            classes=np.array([0, 1]),
            y=np.array(train_label),
        ),
        dtype=torch.float,
    )
    # Class 0 additionally gets +1 on top of its balanced weight.
    class_weights[0] += 1
    weighted = 'class_weights'

loss_fn = get_loss_fn(loss_type=loss_type, class_weights=class_weights)

cross_entropy


#### Ray Tune

In [9]:
# Seed every RNG this notebook relies on, not only Python's `random`:
# torch drives DataLoader shuffling and weight initialisation, numpy is used
# for array work.
# NOTE(review): Ray Tune trials may run in separate worker processes, in which
# case this seeding does not reach them - confirm whether `tune_train` needs its own.
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

In [None]:
def tune_train(config,  checkpoint_dir=None):
    """Ray Tune trainable: train one SNN configuration and report its ROC-AUC.

    Args:
        config: Trial hyperparameters. Keys read here: "time_steps", "lr",
            "batch_size", "num_iterations", "num_epochs", "loss_type",
            "loss_fn" and "dtype".
        checkpoint_dir: Unused; kept so existing callers keep working.

    Reads the notebook globals `train`, `val`, `test`, `num_bits` and `device`.
    After every iteration it reports {"mean_auc": ...} to Tune, attaching a
    checkpoint of the model weights every 5th iteration and on the last one.
    """
    # The model class imported from snn_model is `SNNet`; the previous name `Net`
    # is not defined or imported anywhere in this notebook.
    # NOTE(review): keyword arguments are kept as written - confirm they match
    # SNNet's constructor.
    net = SNNet(num_inputs=num_bits,num_steps=config["time_steps"], spike_grad=None, use_l2=True).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"], betas=(0.9, 0.999), weight_decay=0)

    batch_size = config["batch_size"]
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=True)
    # ROC-AUC does not depend on sample order, so the test loader needs no shuffling.
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)

    num_iterations = config["num_iterations"]
    for i in range(num_iterations):
        net, _ = train_snn(
                net=net,
                optimizer=optimizer,
                device=device,
                num_epochs=config["num_epochs"],
                train_loader=train_loader,
                val_loader=val_loader,
                loss_type=config["loss_type"],
                loss_fn=config["loss_fn"],
                dtype=config["dtype"]
        )
        # NOTE(review): trials are scored on the test split, so hyperparameters
        # are selected on test data - consider scoring on `val_loader` instead.
        all_preds, all_targets = test_snn(net, device, test_loader)
        auc_roc = roc_auc_score(all_targets, all_preds)

        # Checkpoint every 5th iteration and always on the final one; with
        # fewer than 5 iterations the modulo rule alone would never save a model.
        is_last_iteration = (i + 1) == num_iterations
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint = None
            if (i + 1) % 5 == 0 or is_last_iteration:
                # This saves the model to the trial directory
                torch.save(
                    net.state_dict(),
                    os.path.join(temp_checkpoint_dir, "model.pth")
                )
                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            # Send the current training result back to Tune while the
            # temporary directory still exists.
            #suggestion: auc * acc
            report({"mean_auc": auc_roc}, checkpoint=checkpoint)


In [None]:
# Search space handed to Ray Tune. Plain values are shared by every trial;
# tune.grid_search entries are enumerated exhaustively.
config = dict(
    loss_type=loss_type,
    loss_fn=loss_fn,
    lr=tune.grid_search([1e-6, 1e-5, 1e-4, 1e-3]),
    #"batch_size": tune.grid_search([16, 32, 64]),
    batch_size=tune.grid_search([32, 64]),
    num_epochs=1000,  # change to 1000 with early stopping
    dtype=torch.float32,
    num_iterations=1,
    time_steps=tune.grid_search([5, 10, 15, 20]),
)


In [None]:
# ASHA ranks trials by the "mean_auc" value each trial reports (higher is better).
asha_scheduler = ASHAScheduler(metric="mean_auc", mode="max")

# NOTE(review): every tuned entry in `config` is a grid_search; per the Ray Tune
# docs the grid is repeated `num_samples` times, i.e. 50 passes over the
# 4 x 2 x 4 grid - confirm that many trials is intended.
tune_settings = tune.TuneConfig(
    num_samples=50,
    scheduler=asha_scheduler,
)

tuner = tune.Tuner(tune_train, tune_config=tune_settings, param_space=config)
results = tuner.fit()

0,1
Current time:,2024-11-13 21:04:12
Running for:,01:12:38.04
Memory:,20.1/31.9 GiB

Trial name,status,loc,batch_size,lr,iter,total time (s),mean_auc
tune_train_afbc1_00018,RUNNING,127.0.0.1:28548,32,0.0001,3.0,3205.32,0.783437
tune_train_afbc1_00026,RUNNING,127.0.0.1:17544,32,0.0001,2.0,1996.93,0.810223
tune_train_afbc1_00027,RUNNING,127.0.0.1:30408,64,0.0001,3.0,2342.15,0.791227
tune_train_afbc1_00040,RUNNING,127.0.0.1:27832,32,1e-05,1.0,833.49,0.822327
tune_train_afbc1_00041,RUNNING,127.0.0.1:26652,64,1e-05,1.0,528.329,0.81999
tune_train_afbc1_00042,RUNNING,127.0.0.1:26712,32,0.0001,,,
tune_train_afbc1_00043,RUNNING,127.0.0.1:29900,64,0.0001,,,
tune_train_afbc1_00044,RUNNING,127.0.0.1:30344,32,0.001,,,
tune_train_afbc1_00045,RUNNING,127.0.0.1:26764,64,0.001,,,
tune_train_afbc1_00046,RUNNING,127.0.0.1:29220,32,0.01,,,


  return torch.load(io.BytesIO(b))
[36m(pid=26476)[0m   return torch.load(io.BytesIO(b))
[36m(tune_train pid=26476)[0m Stack (most recent call first):
[36m(tune_train pid=26476)[0m   File "<frozen importlib._bootstrap>", line 488 in _call_with_frames_removed
[36m(tune_train pid=26476)[0m   File "<frozen importlib._bootstrap_external>", line 1288 in create_module
[36m(tune_train pid=26476)[0m   File "<frozen importlib._bootstrap>", line 813 in module_from_spec
[36m(tune_train pid=26476)[0m   File "<frozen importlib._bootstrap>", line 921 in _load_unlocked
[36m(tune_train pid=26476)[0m   File "<frozen importlib._bootstrap>", line 1331 in _find_and_load_unlocked
[36m(tune_train pid=26476)[0m   File "<frozen importlib._bootstrap>", line 1360 in _find_and_load
[36m(tune_train pid=26476)[0m   File "c:\Users\knsve\Desktop\MEI\Tese\torch\pt_venv\Lib\site-packages\onnx\__init__.py", line 77 in <module>
[36m(tune_train pid=26476)[0m   File "<frozen importlib._bootstrap>", lin