<a href="https://colab.research.google.com/github/GalJakob/Toxicity-prediction-WS/blob/main/Non_pretrained_gnn_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# install Pytorch Geometric => Build Graph Neural Network
# The compiled extensions below are installed from the PyG wheel page whose
# name is built from the installed torch version string.
import torch
pytorch_version = f"torch-{torch.__version__}.html"
# NOTE(review): with --no-index pip only looks at the -f page; if it lists no
# wheel for this exact torch build the install fails, and a failing `!` line
# does not stop the notebook. TODO confirm URL/version format against the
# current PyG installation docs.
!pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install torch-geometric

In [None]:
# install RDKit => Handle Molecule Data
from logging import getLogger, StreamHandler, INFO

logger = getLogger(__name__)
logger.addHandler(StreamHandler())
logger.setLevel(INFO)

# !curl -Lo conda_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py
# import conda_installer
# conda_installer.install() # takes ~ 4 minutes #TODO: pre-install it
# !/root/miniconda/bin/conda info -e

# !pip install --pre deepchem
# import deepchem
# deepchem.__version__
!pip install rdkit
import rdkit
logger.info("rdkit-{} installation finished!".format(rdkit.__version__))



rdkit-2023.03.2 installation finished!
rdkit-2023.03.2 installation finished!
INFO:__main__:rdkit-2023.03.2 installation finished!


In [None]:
#load data
import io
import pandas as pd
from google.colab import files,drive

AUGMENTED_CASE = "none augmented"  # derived from ds_train / ds_test below: "none augmented" / "only train augmented" / "both augmented"
GNN = "GNN" # constant (NOTE: the name GNN is rebound to the model class in a later cell)
dataset_name = "cardio" # change to cardio / tox21 / clintox
ds_test = dataset_name + "_test" # change to _test / _test_aug
ds_train = dataset_name + "_train_aug" # change to _train / _train_aug
path_test = f"/content/drive/MyDrive/workshop/datasets/test datasets/{ds_test}.csv" #data is at google drive
path_train = f"/content/drive/MyDrive/workshop/datasets/train datasets/{ds_train}.csv" #data is at google drive
drive.mount("/content/drive")


def _read_uploaded_csv(uploaded, name):
  '''Parse one CSV out of the {filename: bytes} dict returned by files.upload().

  The file may have been uploaded as "<name>.csv" (the Drive filename) or as a
  bare "<name>".

  Raises:
    KeyError: if neither filename was uploaded.
  '''
  for filename in (f"{name}.csv", name):
    if filename in uploaded:
      return pd.read_csv(io.BytesIO(uploaded[filename]))
  raise KeyError(f"expected an uploaded file named {name}.csv, got {sorted(uploaded)}")


try: #getting data from drive
  test_data = pd.read_csv(path_test)
  train_data = pd.read_csv(path_train)
except OSError: # e.g. FileNotFoundError: CSV not on Drive -> upload it manually instead
  data = files.upload()
  # Parse into DataFrames so later cells can index train_data["smiles"] etc.
  train_data = _read_uploaded_csv(data, ds_train)
  test_data = _read_uploaded_csv(data, ds_test)

train_is_augmented = "_train_aug" in ds_train
test_is_augmented = "_test_aug" in ds_test
if train_is_augmented and test_is_augmented:
  AUGMENTED_CASE = "both augmented"
elif train_is_augmented:
  AUGMENTED_CASE = "only train augmented"
elif test_is_augmented:
  # Augmenting only the test set is not a supported configuration: stop here
  # instead of printing an error and silently continuing the run.
  raise ValueError("ERROR, only test is augmented")

print(AUGMENTED_CASE)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
only train augmented


In [None]:
#convert smiles to graphs
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
import pandas as pd

def _parse_smiles(smiles_column, split_name):
  '''Parse every SMILES string with RDKit and return the list of Mol objects.

  Chem.MolFromSmiles returns None for a string it cannot parse. Such a None
  would otherwise only fail later, inside the graph-feature extractors, so the
  offending row positions are reported here instead.

  Raises:
    ValueError: if any SMILES in `smiles_column` cannot be parsed.
  '''
  mols = [Chem.MolFromSmiles(smiles) for smiles in smiles_column]
  failed_rows = [row for row, mol in enumerate(mols) if mol is None]
  if failed_rows:
    raise ValueError(
        f"{len(failed_rows)} {split_name} SMILES could not be parsed by RDKit "
        f"(first row positions: {failed_rows[:10]})")
  return mols


mol_objs_train = _parse_smiles(train_data["smiles"], "train")
train_data_by_mols = pd.DataFrame({"mol_objs": mol_objs_train, "label": train_data["label"]})

mol_objs = _parse_smiles(test_data["smiles"], "test")
test_data_by_mols = pd.DataFrame({"mol_objs": mol_objs, "label": test_data["label"]})

print(test_data_by_mols["mol_objs"][0].GetNumAtoms())  # atom count of the first test molecule
# draw mols (rows 5..8 inclusive: .loc slicing includes the end label):
sample = train_data_by_mols.loc[5:8]
grid = Draw.MolsToGridImage(sample["mol_objs"], molsPerRow=2, subImgSize=(400, 400))
#grid
#grid



28


In [None]:
# functions for setting graph molecule features
# every node is atom, every edge is bond

# TODO:EXPLORE ATOM AND EDGE FEATURES THAT MIGHT INDICATE TOXICITY
import pandas as pd
import numpy as np
from rdkit import Chem
import torch
import torch_geometric

def get_nodes_features(mol):
  '''Return a (num_atoms, 9) float tensor holding one feature row per atom of `mol`.

  Columns, in order: atomic number, chiral tag, degree, formal charge,
  hybridization, is-aromatic, total number of Hs, number of radical electrons,
  is-in-ring. RDKit enum values (chiral tag, hybridization) are stored as their
  integer codes; booleans are stored as 0/1.
  '''
  num_node_features = 9
  features_of_all_nodes = [
      [
          atom.GetAtomicNum(),
          int(atom.GetChiralTag()),
          atom.GetDegree(),
          atom.GetFormalCharge(),
          int(atom.GetHybridization()),
          int(atom.GetIsAromatic()),
          atom.GetTotalNumHs(),
          atom.GetNumRadicalElectrons(),
          int(atom.IsInRing()),
      ]
      for atom in mol.GetAtoms()
  ]
  # reshape keeps a 2-D (0, 9) shape for an atom-less molecule; np.asarray([])
  # alone would be 1-D and break feature-size lookups and batching.
  features = np.asarray(features_of_all_nodes, dtype=np.float32).reshape(-1, num_node_features)
  return torch.tensor(features, dtype=torch.float)


def get_edges_features(mol):
  '''Return a (2 * num_bonds, 2) float tensor of bond features for `mol`.

  Columns: is-in-ring (0/1) and bond type as a double (e.g. 1.5 for aromatic).
  Every bond contributes two identical rows, one per direction, so that row k
  lines up with column k of get_mat_edges(mol).
  '''
  num_edge_features = 2
  features_of_all_edges = []
  for bond in mol.GetBonds():
    edge_feats = [float(bond.IsInRing()), bond.GetBondTypeAsDouble()]
    features_of_all_edges += [edge_feats, edge_feats]  # twice because this is an undirected graph
  # reshape keeps a 2-D (0, 2) shape for bond-less molecules (e.g. single ions);
  # a 1-D empty tensor breaks edge_dim-aware convolutions and batching.
  features = np.asarray(features_of_all_edges, dtype=np.float32).reshape(-1, num_edge_features)
  return torch.tensor(features, dtype=torch.float)

def get_mat_edges(mol):
  '''Return the (2, 2 * num_bonds) edge_index of `mol` as a contiguous torch.long tensor.

  Each bond (a, b) is emitted as the two directed edges a->b and b->a, in the
  same order as the rows of get_edges_features(mol). A bond-less molecule
  gives an empty (2, 0) tensor.
  '''
  edge_indices = []
  for bond in mol.GetBonds():
    begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
    edge_indices += [[begin, end], [end, begin]]
  # PyG expects edge_index as torch.long (int32 breaks scatter / self-loop ops),
  # and .contiguous() materializes the transpose in (2, num_edges) layout.
  return torch.tensor(edge_indices, dtype=torch.long).reshape(-1, 2).t().contiguous()


In [None]:
# get/create the custom dataset from google drive
# creating custom dataset for the GNN which includes mol properties
import pandas as pd
from rdkit import Chem
import torch
import torch_geometric
from torch_geometric.data import Dataset, Data
# torch_geometric.loader holds the graph DataLoader; the torch_geometric.data
# alias is deprecated.
from torch_geometric.loader import DataLoader
import numpy as np
import io
import warnings
# Warnings are deliberately not silenced globally: a blanket "ignore" would also
# hide deprecation notices and sklearn's ill-defined-metric warnings later on.

NUM_GRAPHS_PER_BATCH = 128

#train_aug_gnn = AUGMENTED_CASE == "none augmented" and "not_augmented" or "augmented"
#test_aug_gnn = AUGMENTED_CASE == "none augmented" and "not_augmented" or (AUGMENTED_CASE == "only train augmented" and "not_augmented" ) or "augmented"


def _mols_to_graphs(mols_df, source_df):
  '''Build one torch_geometric Data object per molecule.

  mols_df: frame with a "mol_objs" column of RDKit molecules.
  source_df: the raw frame the molecules were parsed from; its "label" and
    "smiles" columns are looked up with the same index as the molecule.
  '''
  graphs = []
  for idx, mol_obj in enumerate(mols_df["mol_objs"]):
    graphs.append(Data(x=get_nodes_features(mol_obj),
                       edge_index=get_mat_edges(mol_obj),
                       edge_attr=get_edges_features(mol_obj),
                       y=torch.tensor(source_df["label"][idx]),
                       smiles=source_df["smiles"][idx]))
  return graphs


custom_train_data = _mols_to_graphs(train_data_by_mols, train_data)
print(custom_train_data[0].x.shape[1]) #features size of each node
custom_test_data = _mols_to_graphs(test_data_by_mols, test_data)

custom_train_loader = DataLoader(custom_train_data, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
# Evaluation needs no shuffling; a fixed order makes printed test predictions
# comparable between epochs.
custom_test_loader = DataLoader(custom_test_data, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)


9


In [None]:
# Training rows 234 and 150 hold two different SMILES strings that a pairwise
# HasSubstructMatch check (run in both directions) reported as the same
# molecule:
#   234: c1ccc2c(c1)N(C[C@H](C)C[NH+](C)C)c1c(ccc(C#N)c1)S2
#   150: C[NH+](C[C@H](CN1c2ccccc2Sc2ccc(C#N)cc21)C)C
# Both became Data(x=[23, 9], edge_index=[2, 50], edge_attr=[50, 2], y=1).
# Print their node-feature matrices and draw the two molecules side by side.
# (Comparing canonical SMILES, Chem.MolToSmiles, would presumably be a cheaper
# way to find such pairs than the O(n^2) substructure scan.)
duplicate_pair = (234, 150)
for row in duplicate_pair:
  print(custom_train_data[row].x)

grid = Draw.MolsToGridImage([mol_objs_train[row] for row in duplicate_pair],
                            molsPerRow=2,
                            subImgSize=(400, 400))
grid

In [None]:
# creating model (simple)
import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
from torch_geometric.nn import TransformerConv, GATConv, TopKPooling, BatchNorm
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.nn.conv.x_conv import XConv
#torch.manual_seed(42)

class GNN(torch.nn.Module):
    """Graph classifier: three GAT + TopK-pooling blocks followed by an MLP head.

    Each block applies GATConv (3 concatenated heads), a Linear layer mapping
    the 3*embedding_size head outputs back to embedding_size, a ReLU and
    TopKPooling. After every block the remaining nodes are read out per graph
    with global max- and mean-pooling. The three readouts are summed and passed
    through Linear -> ReLU -> dropout -> Linear, giving 2 logits per graph.
    """

    def __init__(self, feature_size):
        """feature_size: number of input features per node (columns of x)."""
        super(GNN, self).__init__()
        num_classes = 2
        n_heads = 3
        # NOTE: lower emb size in augmented results in much faster training, for size 1024 until epoch 100 model predicts only 0
        embedding_size = 8  # TODO: adjust embedding size and not by rule of thumb

        # GNN layers
        self.conv1 = GATConv(feature_size, embedding_size, heads=n_heads, dropout=0.1)
        self.head_transform1 = Linear(embedding_size * n_heads, embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=0.8)
        self.conv2 = GATConv(embedding_size, embedding_size, heads=n_heads, dropout=0.1)
        self.head_transform2 = Linear(embedding_size * n_heads, embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=0.5)
        self.conv3 = GATConv(embedding_size, embedding_size, heads=n_heads, dropout=0.1)
        self.head_transform3 = Linear(embedding_size * n_heads, embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=0.2)

        # Linear layers; each readout is [max-pool | mean-pool] -> 2 * embedding_size
        self.linear1 = Linear(embedding_size * 2, 1024)
        self.linear2 = Linear(1024, num_classes)

    def forward(self, x, edge_attr, edge_index, batch_index, epoch):
        """Return logits of shape (num_graphs, 2).

        Args:
            x: node feature matrix, one row per node.
            edge_attr: edge features. Accepted for interface compatibility but
                not used; pooling is called with edge_attr=None.
            edge_index: connectivity, shape (2, num_edges).
            batch_index: graph id of every node (PyG `batch.batch`).
            epoch: unused; kept because the training loop passes it.
        """
        blocks = (
            (self.conv1, self.head_transform1, self.pool1),
            (self.conv2, self.head_transform2, self.pool2),
            (self.conv3, self.head_transform3, self.pool3),
        )
        readout = None
        for conv, head_transform, pool in blocks:
            x = conv(x, edge_index)
            # GATConv applies no output activation (per PyG docs). Without this
            # ReLU, the node features pass through no non-linearity between
            # blocks, unlike the TransformerConv model further down.
            x = F.relu(head_transform(x))
            x, edge_index, _, batch_index, _, _ = pool(x, edge_index, None, batch_index)
            pooled = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)
            # The per-block readouts are summed (not concatenated).
            readout = pooled if readout is None else readout + pooled

        # Output block
        x = self.linear1(readout).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        return self.linear2(x)

In [None]:
import torch
from torch_geometric.data import DataLoader
from sklearn.metrics import confusion_matrix, f1_score, \
    accuracy_score, precision_score, recall_score, roc_auc_score,precision_recall_curve,auc
import numpy as np
from tqdm import tqdm
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters of `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


#%% Loading the model
model = GNN(feature_size=custom_train_data[0].x.shape[1]).to(device)

#%% Loss and Optimizer
# Equal class weights, i.e. plain cross-entropy; raise weights[1] to up-weight
# class 1.
weights = torch.tensor([1, 1], dtype=torch.float32).to(device)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

#%% Prepare training
# The batch size is fixed when custom_train_loader / custom_test_loader are
# built (NUM_GRAPHS_PER_BATCH in the dataset cell); rebuild them to change it.

def _collect_batch_outputs(logits, labels, preds_out, scores_out, labels_out):
    """Append one batch's argmax predictions, class-1 probabilities and labels (numpy) to the given lists."""
    logits = logits.detach()
    preds_out.append(logits.argmax(dim=1).cpu().numpy())
    scores_out.append(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())
    labels_out.append(labels.detach().cpu().numpy())


def train(epoch):
    """Run one optimisation pass of the global `model` over `custom_train_loader`.

    Prints train-split metrics and returns the batch-size-weighted mean training
    loss of the epoch as a 0-dim tensor. The epoch mean is less noisy than the
    loss of the last batch alone.
    """
    all_preds, all_scores, all_labels = [], [], []
    loss_sum, num_graphs = 0.0, 0
    for batch in tqdm(custom_train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(batch.x.float(),
                     batch.edge_attr.float(),
                     batch.edge_index,
                     batch.batch,
                     epoch)
        loss = loss_fn(pred, batch.y)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * batch.num_graphs
        num_graphs += batch.num_graphs
        _collect_batch_outputs(pred, batch.y, all_preds, all_scores, all_labels)

    calculate_metrics(np.concatenate(all_preds).ravel(),
                      np.concatenate(all_labels).ravel(),
                      epoch, "train",
                      y_score=np.concatenate(all_scores).ravel())
    return torch.tensor(loss_sum / max(num_graphs, 1))


def test(epoch):
    """Evaluate the global `model` on `custom_test_loader` without tracking gradients.

    Prints test-split metrics and returns the batch-size-weighted mean test loss
    as a 0-dim tensor. Does not change the model's train/eval mode.
    """
    all_preds, all_scores, all_labels = [], [], []
    loss_sum, num_graphs = 0.0, 0
    with torch.no_grad():
        for batch in custom_test_loader:
            batch = batch.to(device)
            pred = model(batch.x.float(),
                         batch.edge_attr.float(),
                         batch.edge_index,
                         batch.batch,
                         epoch)
            loss = loss_fn(pred, batch.y)
            loss_sum += loss.item() * batch.num_graphs
            num_graphs += batch.num_graphs
            _collect_batch_outputs(pred, batch.y, all_preds, all_scores, all_labels)

    calculate_metrics(np.concatenate(all_preds).ravel(),
                      np.concatenate(all_labels).ravel(),
                      epoch, "test",
                      y_score=np.concatenate(all_scores).ravel())
    return torch.tensor(loss_sum / max(num_graphs, 1))


def calculate_metrics(y_pred, y_true, epoch, type, y_score=None):
    """Print binary-classification metrics for one split.

    Args:
        y_pred: hard 0/1 predictions.
        y_true: 0/1 labels.
        epoch: epoch number, shown in the header line.
        type: split name ("train" / "test"), shown in the header line.
        y_score: optional class-1 probabilities used for ROC AUC and PR AUC.
            Falls back to y_pred, which only gives a degraded, single-threshold AUC.
    """
    # TODO: MODEL PREDICTS ALL MOLS AS 1
    print(f"--- {type} | epoch {epoch} ---")
    # labels=[0, 1] keeps the matrix 2x2 even when one class is missing from
    # both y_true and y_pred (e.g. the model predicts a single class).
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    print(f"tn: {tn} \n ")
    print(f"fp: {fp} \n ")
    print(f"fn: {fn} \n ")
    print(f"tp: {tp} \n ")
    print(f"\n Confusion matrix: \n {cm}")
    print(f"F1 Score: {f1_score(y_true, y_pred, zero_division=0)}")
    print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
    print(f"Precision: {precision_score(y_true, y_pred, zero_division=0)}")
    print(f"Recall: {recall_score(y_true, y_pred, zero_division=0)}")
    scores = y_pred if y_score is None else y_score
    try:
        roc = roc_auc_score(y_true, scores)
        print(f"ROC AUC: {roc}")
        precision, recall, _ = precision_recall_curve(y_true, scores)
        print(f"PR AUC: {auc(recall, precision)}")
    except ValueError:
        # roc_auc_score raises ValueError when y_true contains a single class.
        print("ROC AUC: notdefined")

# %% Run the training
def run_all_train(num_epochs=200, test_every=5):
    """Train the global `model` for `num_epochs` epochs.

    Evaluates on the test split every `test_every` epochs, starting at epoch 0,
    and steps the LR scheduler once per epoch.
    """
    for epoch in range(num_epochs):
        # Training
        model.train()
        loss = train(epoch=epoch)
        print(f"Epoch {epoch} | Train Loss {loss.detach().cpu().numpy()}")
        # Testing
        if epoch % test_every == 0:
            model.eval()
            loss = test(epoch=epoch)
            print(f"Epoch {epoch} | Test Loss {loss.detach().cpu().numpy()}")
        scheduler.step()
    print("Done.")

run_all_train()


  0%|          | 0/32 [00:00<?, ?it/s]

[tensor(-1.9453, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(-0.9860, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(-0.9209, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(-0.7685, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(-0.6665, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(-0.1691, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(-0.1365, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(-0.0737, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(-0.0621, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(0.2056, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(0.2364, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(0.4529, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(0.5287, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(0.7982, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(1.6788, device='cuda:0', grad_fn=<UnbindBackward0>), tensor(2.3759, device='cuda:0', grad_fn=<UnbindBackward0>)]





AttributeError: ignored

In [None]:
# creating model (complicated, TODO:TEST IT)
import torch
import torch.nn.functional as F
from torch.nn import Linear, BatchNorm1d, ModuleList
from torch_geometric.nn import TransformerConv, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
torch.manual_seed(42)

class GNN(torch.nn.Module):
    """TransformerConv graph classifier with periodic TopK pooling; one logit per graph.

    An initial TransformerConv (using edge features) -> Linear -> ReLU ->
    BatchNorm is followed by `model_layers` blocks of the same shape. After
    every `model_top_k_every_n`-th block, and always after the last block, the
    graph is TopK-pooled and read out with global max- and mean-pooling. The
    readouts are summed and passed through a 3-layer MLP that outputs one logit
    per graph (used with BCEWithLogitsLoss in the training cell).

    model_params keys: model_embedding_size, model_attention_heads,
    model_layers (>= 1), model_dropout_rate (attention dropout),
    model_top_k_ratio, model_top_k_every_n (>= 1), model_dense_neurons,
    model_edge_dim, and optionally model_dense_dropout_rate (MLP dropout,
    default 0.8).
    """

    def __init__(self, feature_size, model_params):
        super(GNN, self).__init__()
        embedding_size = model_params["model_embedding_size"]
        n_heads = model_params["model_attention_heads"]
        self.n_layers = model_params["model_layers"]
        dropout_rate = model_params["model_dropout_rate"]
        top_k_ratio = model_params["model_top_k_ratio"]
        self.top_k_every_n = model_params["model_top_k_every_n"]
        dense_neurons = model_params["model_dense_neurons"]
        edge_dim = model_params["model_edge_dim"]
        self.dense_dropout_rate = model_params.get("model_dense_dropout_rate", 0.8)
        if self.n_layers < 1:
            raise ValueError("model_layers must be >= 1 (at least one pooled readout is needed)")
        if self.top_k_every_n < 1:
            raise ValueError("model_top_k_every_n must be >= 1")

        self.conv_layers = ModuleList([])
        self.transf_layers = ModuleList([])
        self.pooling_layers = ModuleList([])
        self.bn_layers = ModuleList([])

        # Transformation layer
        self.conv1 = TransformerConv(feature_size,
                                    embedding_size,
                                    heads=n_heads,
                                    dropout=dropout_rate,
                                    edge_dim=edge_dim,
                                    beta=True)

        self.transf1 = Linear(embedding_size*n_heads, embedding_size)
        self.bn1 = BatchNorm1d(embedding_size)

        # Other layers
        for i in range(self.n_layers):
            self.conv_layers.append(TransformerConv(embedding_size,
                                                    embedding_size,
                                                    heads=n_heads,
                                                    dropout=dropout_rate,
                                                    edge_dim=edge_dim,
                                                    beta=True))

            self.transf_layers.append(Linear(embedding_size*n_heads, embedding_size))
            self.bn_layers.append(BatchNorm1d(embedding_size))
            # One pooling layer per pooled block, so forward() can consume them in order.
            if self._pools_after(i):
                self.pooling_layers.append(TopKPooling(embedding_size, ratio=top_k_ratio))

        # Linear layers; each readout is [max-pool | mean-pool] -> 2 * embedding_size
        self.linear1 = Linear(embedding_size*2, dense_neurons)
        self.linear2 = Linear(dense_neurons, int(dense_neurons/2))
        self.linear3 = Linear(int(dense_neurons/2), 1)

    def _pools_after(self, i):
        """True if block i (0-based) is followed by TopK pooling and a readout.

        That is every top_k_every_n-th block, and always the last block. The old
        check `i == self.n_layers` could never be true, since i < n_layers.
        """
        return i % self.top_k_every_n == 0 or i == self.n_layers - 1

    def forward(self, x, edge_attr, edge_index, batch_index):
        """Return logits of shape (num_graphs, 1)."""
        # Initial transformation
        x = self.conv1(x, edge_index, edge_attr)
        x = torch.relu(self.transf1(x))
        x = self.bn1(x)

        # Holds the intermediate graph representations
        global_representation = []
        pool_idx = 0  # index into self.pooling_layers, advanced once per pooled block

        for i in range(self.n_layers):
            x = self.conv_layers[i](x, edge_index, edge_attr)
            x = torch.relu(self.transf_layers[i](x))
            x = self.bn_layers[i](x)
            # Pool every top_k_every_n-th block and always the last one
            if self._pools_after(i):
                x, edge_index, edge_attr, batch_index, _, _ = self.pooling_layers[pool_idx](
                    x, edge_index, edge_attr, batch_index
                    )
                pool_idx += 1
                # Add current representation
                global_representation.append(torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1))

        # The per-block readouts are summed (not concatenated).
        x = sum(global_representation)

        # Output block
        x = torch.relu(self.linear1(x))
        x = F.dropout(x, p=self.dense_dropout_rate, training=self.training)
        x = torch.relu(self.linear2(x))
        x = F.dropout(x, p=self.dense_dropout_rate, training=self.training)
        x = self.linear3(x)

        return x

In [None]:
# loading model
# Build the model_params-based GNN class currently in scope and move it to
# the GPU when one is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
first_graph = custom_train_data[0]
num_of_node_features = first_graph.x.shape[1]  # features per node
params = dict(
    model_edge_dim=first_graph.edge_attr.shape[1],  # features per edge
    batch_size=128,  # NOTE: not read by the loaders, which use NUM_GRAPHS_PER_BATCH
    learning_rate=0.01,
    weight_decay=0.0001,
    sgd_momentum=0.8,
    scheduler_gamma=0.8,
    pos_weight=1.3,  # BCEWithLogitsLoss weight on the positive class
    model_embedding_size=64,
    model_attention_heads=3,
    model_layers=4,
    model_dropout_rate=0.2,
    model_top_k_ratio=0.5,
    model_top_k_every_n=1,
    model_dense_neurons=256,
)

model = GNN(feature_size=num_of_node_features, model_params=params).to(device)

In [None]:
#creating loss and optimizer

In [None]:
#%% imports
import torch
from torch_geometric.data import DataLoader
from sklearn.metrics import confusion_matrix, f1_score, \
    accuracy_score, precision_score, recall_score, roc_auc_score
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters of `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def _collect_binary_outputs(logits, labels, preds_out, scores_out, labels_out):
    """Append one batch's sigmoid probabilities, rounded 0/1 predictions and labels (numpy) to the given lists."""
    probs = torch.sigmoid(logits.detach()).cpu().numpy()
    scores_out.append(probs)
    preds_out.append(np.rint(probs))
    labels_out.append(labels.detach().cpu().numpy())


def train_one_epoch(epoch, model, train_loader, optimizer, loss_fn):
    """Run one optimisation pass of `model` over `train_loader`.

    Prints train-split metrics and returns the mean of the per-batch losses.
    Does not change the model's train/eval mode.
    """
    all_preds, all_scores, all_labels = [], [], []
    running_loss = 0.0
    step = 0
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(batch.x.float(),
                     batch.edge_attr.float(),
                     batch.edge_index,
                     batch.batch)
        # reshape(-1) instead of squeeze(): squeeze turns the (1, 1) output of a
        # single-graph batch into a 0-dim tensor, which no longer matches the
        # (1,) target and makes BCEWithLogitsLoss raise.
        logits = pred.reshape(-1)
        loss = loss_fn(logits, batch.y.float())
        loss.backward()
        optimizer.step()
        # Update tracking
        running_loss += loss.item()
        step += 1
        _collect_binary_outputs(logits, batch.y, all_preds, all_scores, all_labels)
    calculate_metrics(np.concatenate(all_preds).ravel(),
                      np.concatenate(all_labels).ravel(),
                      epoch, "train",
                      y_score=np.concatenate(all_scores).ravel())
    return running_loss / max(step, 1)


def test(epoch, model, test_loader, loss_fn):
    """Evaluate `model` on `test_loader` without tracking gradients.

    Prints the first 10 probabilities / predictions / labels and the test
    metrics, saves a confusion-matrix image, and returns the mean of the
    per-batch losses. Does not change the model's train/eval mode.
    """
    all_preds, all_scores, all_labels = [], [], []
    running_loss = 0.0
    step = 0
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            pred = model(batch.x.float(),
                         batch.edge_attr.float(),
                         batch.edge_index,
                         batch.batch)
            logits = pred.reshape(-1)  # see train_one_epoch: squeeze() breaks 1-graph batches
            loss = loss_fn(logits, batch.y.float())
            # Update tracking
            running_loss += loss.item()
            step += 1
            _collect_binary_outputs(logits, batch.y, all_preds, all_scores, all_labels)

    all_preds = np.concatenate(all_preds).ravel()
    all_scores = np.concatenate(all_scores).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    print(all_scores[:10])
    print(all_preds[:10])
    print(all_labels[:10])
    calculate_metrics(all_preds, all_labels, epoch, "test", y_score=all_scores)
    log_conf_matrix(all_preds, all_labels, epoch)
    return running_loss / max(step, 1)


def log_conf_matrix(y_pred, y_true, epoch, out_dir="data/images"):
    """Save a confusion-matrix heatmap (rows = true label, columns = predicted) to <out_dir>/cm_<epoch>.png."""
    import os

    # sklearn's argument order is (y_true, y_pred); labels=[0, 1] keeps the
    # matrix 2x2 so it always fits the 2x2 DataFrame below.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    classes = ["0", "1"]
    df_cfm = pd.DataFrame(cm, index=classes, columns=classes)
    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(df_cfm, annot=True, cmap='Blues', fmt='g', ax=ax)
    ax.set(xlabel="Predicted label", ylabel="True label", title=f"Confusion matrix (epoch {epoch})")
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f"cm_{epoch}.png"))
    plt.close(fig)  # avoid accumulating open figures across epochs


def calculate_metrics(y_pred, y_true, epoch, type, y_score=None):
    """Print binary-classification metrics for one split.

    Args:
        y_pred: hard 0/1 predictions.
        y_true: 0/1 labels.
        epoch: epoch number (currently unused).
        type: split name ("train" / "test", currently unused).
        y_score: optional class-1 probabilities used for ROC AUC. Falls back to
            y_pred, which only gives a degraded, single-threshold AUC.
    """
    print(f"\n Confusion matrix: \n {confusion_matrix(y_true, y_pred, labels=[0, 1])}")
    print(f"F1 Score: {f1_score(y_true, y_pred, zero_division=0)}")
    print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    print(f"Precision: {prec}")
    print(f"Recall: {rec}")

    scores = y_pred if y_score is None else y_score
    try:
        roc = roc_auc_score(y_true, scores)
    except ValueError:
        # raised when y_true contains a single class: ROC AUC is undefined
        print("ROC AUC: undefined (only one class in y_true)")
        return
    print(f"ROC AUC: {roc}")



# Prepare training
weight = torch.tensor([params["pos_weight"]], dtype=torch.float32).to(device)
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=weight)
optimizer = torch.optim.SGD(model.parameters(),
                                    lr=params["learning_rate"],
                                    momentum=params["sgd_momentum"],
                                    weight_decay=params["weight_decay"])
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params["scheduler_gamma"])

NUM_EPOCHS = 10

model.train()
for epoch in range(NUM_EPOCHS):
  loss = train_one_epoch(epoch, model, custom_train_loader, optimizer, loss_fn)
  print(f"Epoch {epoch} | Train Loss {loss}")
  # Decay the learning rate once per epoch (the scheduler is otherwise unused).
  scheduler.step()
