In [None]:
import pandas as pd
import numpy as np
import random
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
from utils import load_dataset_df, smile_to_fp, get_spiking_net, make_filename, calc_metrics
from csnn_model import get_loss_fn, bias
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import roc_auc_score

#### Load DataFrame

In [None]:
# Benchmark CSV files this notebook can run on; the index below picks one.
files = ['tox21.csv', 'sider.csv']
dt_file = files[1]

# Dataset name without the extension; later cells branch on this value.
dirname = dt_file.removesuffix('.csv')

# The loader hands back the data frame together with the indexable
# collection of target names for the chosen dataset.
df, targets = load_dataset_df(
    filename=dt_file,
)

In [None]:
# Pick the single target column to model for the selected dataset.
if dirname == 'tox21':
    # SR-ARE
    target_name = targets[7]
elif dirname == 'sider':
    # Hepatobiliary disorders
    target_name = targets[0]
else:
    # Fail here with a clear message instead of a NameError on `target_name`
    # further down when a dataset without a configured target is selected.
    raise ValueError(f"No target configured for dataset '{dirname}'")

# Keep only the label and the SMILES string, dropping rows where either is missing.
df = df[[target_name, 'smiles']].dropna()

#### SMILES to Fingerprint

In [None]:
# Supported fingerprints as [type, bit length]. Entry 0 (morgan) is also the
# secondary fingerprint used when two fingerprints are mixed.
fp_types = [['morgan', 1024], ['maccs', 167], ['RDKit', 1024]]
mix = True
fp_type, num_bits = fp_types[1]

if mix:
    if fp_type == 'RDKit':
        # Shrink RDKit so that primary + secondary lengths add up to 1024.
        num_bits = 512
    elif fp_type == 'morgan':  # keep morgan as 2nd MF
        # The primary is already morgan, so mixing is switched off.
        mix = False

# Settings handed to the SMILES -> fingerprint conversion. The secondary
# fingerprint takes whatever is left of the 1024 bits after the primary.
fp_config = dict(
    fp_type=fp_type,
    num_bits=num_bits,
    radius=2,
    fp_type_2=fp_types[0][0],
    num_bits_2=1024 - num_bits,
    mix=mix,
)

print(fp_type, '-', num_bits)
if mix:
    print(fp_config['fp_type_2'], '-', fp_config['num_bits_2'])

In [None]:
# Run-wide settings reused by later cells.
dtype = torch.float32
split = "random"
# Reset before conversion; presumably so a failed run of this cell does not
# leave a previously built dataset in the kernel.
dataset = None

# Convert every SMILES string to its fingerprint and collect the labels.
fp_array, target_array = smile_to_fp(
    df,
    fp_config=fp_config,
    target_name=target_name,
)

# Wrap features and labels as a torch dataset. Labels go through `dtype`
# first and are then cast to 64-bit integers.
fp_tensor = torch.tensor(fp_array, dtype=dtype)
target_tensor = torch.tensor(target_array, dtype=dtype).to(torch.long)
dataset = TensorDataset(fp_tensor, target_tensor)

#### Train Loop

In [None]:
# Model family label and neuron/loss settings shared by the cells below.
net_type = 'CSNN'
spike_grad = None
beta = 0.95
loss_type = 'count_loss'

# Architecture settings passed to the network factory.
net_config = dict(
    # tox21 uses one conv stage, every other dataset two.
    conv_num=1 if dirname == 'tox21' else 2,
    # Mixed fingerprints always total 1024 bits; otherwise use the single
    # fingerprint's own length.
    input_size=1024 if fp_config['mix'] else num_bits,
    time_steps=10,
    spike_grad=spike_grad,
    beta=beta,
    encoding='rate',
    bias=bias,
    out_num=2,
)

print(net_config)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optimiser / schedule hyper-parameters.
lr=1e-4
iterations = 30  # number of independent seeded runs performed by the training loop
weight_decay = 0
optim_type = 'Adam'  # label passed to make_filename; the optimiser itself is built in the loop
batch_size = 16

# Settings forwarded to the train/test helpers. `loss_fn` and `val_net` are
# placeholders that the training loop fills in on every run.
train_config = {"num_epochs": 20,
                "batch_size": batch_size,
                "device": device,
                "loss_type": loss_type,
                "loss_fn": None,
                'dtype': dtype,
                'num_steps': net_config['time_steps'],
                'val_net': None,
                }
print(device)
print(train_config)

drop_last = True
# Compare the device *type*: a torch.device never equals the plain string
# "cuda", so the previous `device == "cuda"` was always False and pinned
# memory was never enabled on GPU runs.
pin_memory = device.type == "cuda"
save = True
# One list per metric; calc_metrics is called with this after every run.
results = [[], [], [], [], [], []]

In [None]:
# Repeat the whole split / train / evaluate cycle `iterations` times, one seed
# per run, passing each run's test predictions to calc_metrics with `results`.
for run_idx in range(iterations):
    print(f"Iteration:{run_idx + 1}/{iterations}")
    seed = run_idx + 1
    print(f"Seed:{seed}")
    # Seed every RNG in use, not just Python's. Without the torch seed,
    # anything drawn from torch's global generator (e.g. DataLoader
    # shuffling) changes between executions even for the same `seed`.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    net, train_net, val_net, test_net = get_spiking_net(net_config)
    net = net.to(device)
    train_config['val_net'] = val_net
    optimizer = torch.optim.Adam(
        net.parameters(),
        lr=lr,
        betas=(0.9, 0.999),
        weight_decay=weight_decay,
    )

    # DATA SPLIT
    # A dedicated generator keeps the 80/10/10 split tied to the seed alone.
    generator = torch.Generator().manual_seed(int(seed))
    train, val, test = random_split(dataset, [0.8, 0.1, 0.1], generator=generator)
    _, train_label = train[:]
    _, val_label = val[:]
    _, test_label = test[:]
    train_loader = DataLoader(
        train,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)

    # LOSS FN
    # Balanced class weights are computed from the training labels only.
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.array([0, 1], dtype=np.int8),
        y=np.array(train_label, dtype=np.int8),
    )
    class_weights = torch.tensor(class_weights, dtype=torch.float, device=device)
    train_config["loss_fn"] = get_loss_fn(loss_type=loss_type, class_weights=class_weights)

    # TRAINING
    net, loss_hist, val_acc_hist, val_auc_hist, net_list = train_net(
        net=net,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        train_config=train_config,
        net_config=net_config,
    )

    # TESTING
    # `net_list` holds state dicts (presumably one per epoch - confirm in
    # train_net). Each is loaded in turn and scored on the test loader.
    # NOTE(review): the checkpoint is chosen by its AUC on the *test* set, so
    # the metrics reported below are optimistically biased; consider
    # selecting on the validation split instead.
    model = net
    model.to(device)  # load_state_dict copies into existing tensors, so once is enough
    best_test_auc = 0
    best_epoch = 0
    for index, model_dict in enumerate(net_list):
        model.load_state_dict(model_dict)
        all_preds2, all_targets2 = test_net(model, device, test_loader, train_config)
        auc_roc_test = roc_auc_score(all_targets2, all_preds2)
        if auc_roc_test > best_test_auc:
            best_test_auc, best_epoch = (auc_roc_test, index)

    print(best_epoch, best_test_auc)
    # Restore the selected checkpoint before saving and final scoring.
    model.load_state_dict(net_list[best_epoch])
    filename = make_filename(dirname, target_name, net_type, fp_config, lr, weight_decay, optim_type, net_config, train_config, model, model = True)
    model_name = filename.removesuffix('.csv') + f"_seed-{seed}" +'.pth'
    # Honour the notebook-wide `save` switch, as the metrics CSV cell does.
    if save:
        torch.save(model.state_dict(), model_name)
    print(filename)
    all_preds, all_targets = test_net(model, device, test_loader, train_config)
    calc_metrics(results, all_preds=all_preds, all_targets=all_targets)

#### Save Metrics

In [None]:
# Flat layout: [mean_0, std_0, mean_1, std_1, ...] for the metric lists in `results`.
metrics_np = np.zeros(12)

for slot, values in enumerate(results):
    mean_pos = 2 * slot
    # Mean and standard deviation across runs, rounded to three decimals.
    metrics_np[mean_pos:mean_pos + 2] = np.round([np.mean(values), np.std(values)], 3)

# Print Results
# Each entry pairs a label with the position of that metric's mean; the
# standard deviation sits in the slot right after it.
for label, mean_pos in (
    ("Accuracy:  ", 0),
    ("AUC ROC: ", 2),
    ("Sensitivity: ", 4),
    ("Specificity: ", 6),
):
    print(f"{label}{metrics_np[mean_pos]:.3f} ± {metrics_np[mean_pos + 1]:.3f}")

In [None]:
# Short metric labels, in the same order as the mean/std pairs in `metrics_np`.
metric_names = ['Acc', 'AUC', 'Sn', 'Sp', 'F1', 'Precision']

# One row holding every mean/std value.
metrics_np = metrics_np.reshape(1, -1)

# Column headers alternate "Mean <name>", "Std <name>" for each metric.
columns = [
    f'{stat} {name}'
    for name in metric_names
    for stat in ('Mean', 'Std')
]

df_metrics = pd.DataFrame(metrics_np, columns=columns)

# Output path built from the run configuration.
filename = make_filename(
    dirname,
    target_name,
    net_type,
    fp_config,
    lr,
    weight_decay,
    optim_type,
    net_config,
    train_config,
    model,
)
if save:
    df_metrics.to_csv(filename, index=False)

print(filename)