In [None]:
import pandas as pd
import numpy as np
import random
from rdkit import Chem
from snn_model import Net, device, batch_size, num_steps, train_net, test_net, get_loss_fn, num_hidden
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
from snntorch import spikegen, surrogate
import matplotlib.pyplot as plt
from utils import load_dataset_df, fp_generator
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score, f1_score, precision_score

#### Load DataFrame

In [None]:
# Available dataset CSVs; the index below picks which one this run uses.
files = ['tox21.csv', 'sider.csv', 'BBBP.csv']
dt_file = files[2]

df, targets = load_dataset_df(filename=dt_file)
print(targets)

# Model a single task: the first target column paired with its SMILES strings.
# Rows missing either value are dropped.
target_name = targets[0]
df = df.dropna(subset=[target_name, 'smiles'])[[target_name, 'smiles']]


#### Molecule to Fingerprint Visualization

In [None]:
from rdkit.Chem import Draw, AllChem
from IPython.display import display


# Pick one random molecule. randrange(n) draws from 0..n-1; the previous
# randint(0, len(df)) is inclusive at both ends and could index one past the last row.
row = df.iloc[random.randrange(len(df))]
print(row['smiles'])

mol = Chem.MolFromSmiles(row['smiles'])
if mol is None:
    # Without a parsed molecule there is nothing to fingerprint or draw.
    raise ValueError(f"RDKit could not parse SMILES {row['smiles']!r}; re-run the cell to sample another molecule")

# Morgan fingerprint, also collecting the bit-info map
# (bit id -> entries whose first element is an atom index, as used below).
fpgen = fp_generator('morgan')
ao = AllChem.AdditionalOutput()
ao.CollectBitInfoMap()
fp = np.array(fpgen(mol, additionalOutput=ao))
bi = ao.GetBitInfoMap()

# Up to 5 bits: top row draws the environment behind each bit,
# bottom row highlights the corresponding atoms on the whole molecule.
fig, axs = plt.subplots(1, 5, figsize=(20, 20))
fig2, axs2 = plt.subplots(1, 5, figsize=(20, 20))

for ax_env, ax_mol, bit_id in zip(axs, axs2, list(bi.keys())[:5]):
    print(bit_id, bi[bit_id])
    # NOTE(review): imshow needs a raster image; if this RDKit version makes
    # DrawMorganBit return SVG text, pass useSVG=False here.
    env_img = Draw.DrawMorganBit(mol, bitId=bit_id, bitInfo=bi)

    atoms = [info[0] for info in bi[bit_id]]
    colors = {idx: (1, 1, 0) for idx in atoms}  # yellow highlight
    mol_img = Draw.MolToImage(mol, highlightAtoms=atoms, highlightAtomColors=colors)

    ax_env.imshow(env_img)
    ax_env.set_title(f"bit {bit_id}")
    ax_env.axis('off')
    ax_mol.imshow(mol_img)
    ax_mol.axis('off')




#### SMILES to Fingerprint

In [None]:
# Fingerprint type and its bit length; the index below picks one.
fp_types = [['morgan', 1024], ['maccs', 167], ['RDKit', 1024]]
fp_type, num_bits = fp_types[0]
print(fp_type, '-', num_bits)

num_rows = len(df)
# Preallocate for the case where every SMILES parses; trimmed to the valid rows below.
fp_array = np.zeros((num_rows, num_bits))
target_array = np.zeros((num_rows, 1))
n_valid = 0

# SMILES -> fingerprint vector of length num_bits (assignment fails loudly on a length mismatch).
fp_gen = fp_generator(fp_type)
for smiles, target in zip(df['smiles'], df[target_name]):
    mol = Chem.MolFromSmiles(smiles)
    # TODO: sanitize molecules to remove the RDKit warnings (?)
    if mol is None:
        continue  # unparseable SMILES are skipped
    fp_array[n_valid] = np.array(fp_gen(mol))
    target_array[n_valid] = target
    n_valid += 1

# Drop the unfilled trailing rows. Without this, each unparseable SMILES left an
# all-zero fingerprint labelled 0 in the dataset.
fp_array = fp_array[:n_valid]
target_array = target_array[:n_valid].ravel()
print(f"{num_rows - n_valid} of {num_rows} SMILES could not be parsed and were skipped")

In [None]:
# Create Torch Dataset: float fingerprints as inputs, integer (int64) class labels as targets.
# The train/test split and the DataLoaders are built per run inside the experiment loop,
# so each run gets its own seeded split.
dtype = torch.float32
fp_tensor = torch.tensor(fp_array, dtype=dtype)
target_tensor = torch.tensor(target_array, dtype=dtype).long()

dataset = TensorDataset(fp_tensor, target_tensor)

#### Loss Function + Optimizer

In [None]:
#------------------
# Loss ideas to explore:
# - binary cross entropy
# - ratio of positives to negatives (pos:neg) for class balancing
# - BCEWithLogitsLoss: pos_weight parameter
# -----------------
# output of the SNN (for rate encoding) -> a single number

In [None]:
from sklearn.utils.class_weight import compute_class_weight

# Loss selection, forwarded to get_loss_fn / train_net by the experiment loop.
loss_types = ['cross_entropy', 'rate_loss', 'temporal_loss']
loss_type = loss_types[1]
print(loss_type)

# When True, the experiment loop computes 'balanced' class weights from each run's
# training split only (never the test split) and passes them to get_loss_fn.
use_weights = False


In [None]:
# Training schedule: epochs per training run, and the number of independent runs
# (each run re-splits the data and builds a fresh network).
num_epochs, iterations = 30, 30

In [None]:
def _binary_metrics(y_true, y_pred):
    """Return (accuracy, auc_roc, sensitivity, specificity, f1, precision) for 0/1 labels.

    labels=[0, 1] forces a 2x2 confusion matrix so the 4-way unpack never breaks.
    Sensitivity / specificity are nan (without a RuntimeWarning) when their denominator is 0;
    F1 / precision use zero_division=0 (sklearn's default value, minus the warning).

    NOTE(review): roc_auc_score receives hard 0/1 predictions, not scores, so the value equals
    balanced accuracy rather than a true ROC AUC. Pass spike counts / probabilities from
    test_net instead if it can provide them.
    """
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) else float('nan')
    specificity = tn / (tn + fp) if (tn + fp) else float('nan')
    return (
        accuracy_score(y_true, y_pred),
        roc_auc_score(y_true, y_pred),
        sensitivity,
        specificity,
        f1_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
    )


# One list per metric, in the order returned by _binary_metrics; one entry per run.
results = [[], [], [], [], [], []]
# `run` rather than `iter` so the builtin iter() is not shadowed for later cells.
for run in range(iterations):
    print(f"Iteration:{run + 1}/{iterations}")
    # Seed every RNG that affects a run: Python's random, numpy, and torch
    # (network weight init and DataLoader shuffling draw from torch's global RNG).
    random.seed(run)
    np.random.seed(run)
    torch.manual_seed(run)

    spike_grad = None  # e.g. surrogate.fast_sigmoid() to train with a surrogate gradient
    net = Net(num_inputs=num_bits, spike_grad=spike_grad).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999), weight_decay=0)

    # 70/30 split with a per-run seed: different split each run, reproducible across re-runs.
    generator = torch.Generator().manual_seed(run + 1)
    train, test = random_split(dataset, [0.7, 0.3], generator=generator)
    _, train_label = train[:]
    _, test_label = test[:]
    print("positive labels in the test set:", int(test_label.sum()))

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
    # Evaluation does not need shuffling; a fixed order keeps the test pass deterministic.
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)

    # Loss function. Class weights come from the training split only,
    # so no test-label information leaks into training.
    if use_weights:
        class_weights = compute_class_weight(class_weight='balanced', classes=np.array([0, 1]), y=np.array(train_label))
        class_weights = torch.tensor(class_weights, dtype=torch.float)
        weighted = 'class_weights'
    else:
        class_weights = None
        weighted = ''
    loss_fn = get_loss_fn(loss_type=loss_type, class_weights=class_weights)

    net, loss_hist = train_net(net, optimizer, num_steps, device, num_epochs, train_loader, loss_type, loss_fn, dtype)

    all_preds, all_targets = test_net(net, device, test_loader)
    for metric_values, value in zip(results, _binary_metrics(all_targets, all_preds)):
        metric_values.append(value)


#### Smoothed Loss

In [None]:
from scipy.signal import savgol_filter
from scipy.ndimage import gaussian_filter1d
from statsmodels.nonparametric.smoothers_lowess import lowess

# Last raw loss values of the final run (loss_hist is reassigned on every run of the loop).
# [-5:] stays correct even when fewer than 5 values were recorded.
print(loss_hist[-5:])

# Gaussian smoothing (sigma=6 entries) makes the trend readable through noisy raw values.
# Alternatives tried: moving average, savgol_filter, lowess.
fig, ax = plt.subplots(facecolor="w", figsize=(10, 5))
ax.plot(gaussian_filter1d(loss_hist, sigma=6), label='smoothed loss')
ax.axhline(y=1, color='r', linestyle='--', label='y = 1')
ax.set_title("Loss Curve")
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

#### Save Metrics

In [None]:
import os

# Summarise each metric across runs as mean and std, rounded to 3 decimals.
# `results` holds one list per metric, in the same order as metric_names.
metric_names = ['Acc', 'AUC', 'Sn', 'Sp', 'F1', 'Precision']
summary = {}
for name, values in zip(metric_names, results):
    summary[f'Mean {name}'] = np.round(np.mean(values), 3)
    summary[f'Std {name}'] = np.round(np.std(values), 3)

# Print Results (all six metrics; F1 and precision were previously saved but never shown)
print(f"Accuracy:  {summary['Mean Acc']:.2f} ± {summary['Std Acc']:.2f}")
print(f"AUC ROC: {summary['Mean AUC']:.2f} ± {summary['Std AUC']:.2f}")
print(f"Sensitivity: {summary['Mean Sn']:.2f} ± {summary['Std Sn']:.2f}")
print(f"Specificity: {summary['Mean Sp']:.2f} ± {summary['Std Sp']:.2f}")
print(f"F1: {summary['Mean F1']:.2f} ± {summary['Std F1']:.2f}")
print(f"Precision: {summary['Mean Precision']:.2f} ± {summary['Std Precision']:.2f}")

# One-row table with 'Mean <metric>' / 'Std <metric>' columns, metric by metric.
df_metrics = pd.DataFrame([summary])

# os.path.splitext, not dt_file.strip('.csv'): str.strip removes any of the characters
# '.', 'c', 's', 'v' from both ends, which turned 'sider.csv' into 'ider'.
dirname = os.path.splitext(dt_file)[0]
out_dir = os.path.join('results', dirname)
os.makedirs(out_dir, exist_ok=True)  # create results/<dataset>/ if it does not exist yet
# os.path.join instead of a hard-coded "\\" so the path also works outside Windows.
filename = os.path.join(out_dir, f"{target_name}_{fp_type}_{num_bits}_{num_hidden}_{loss_type}_{weighted}.csv")
df_metrics.to_csv(filename, index=False)


#### Visualize the Network

In [None]:
#from torchviz import make_dot
#train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
#batch = next(iter(train_loader))
#data, target = batch  # Assuming batch is structured as (data, target)
#net = Net(num_inputs=num_bits, spike_grad=None).to(device)
#yhat = net(data)
#make_dot(yhat, params=dict(list(net.named_parameters()))).render("snn_torchviz", format="png")