In [None]:
"""
Contains functions for embedding the training set, and for training and testing a PyTorch model.
"""

import os
import torch
import numpy as np
import h5py
import pandas as pd
from torch_geometric.data import Data, DataLoader as PyGDataLoader
from tqdm import tqdm
from rdkit import Chem
import duckdb
from multiprocessing import Pool
from sklearn.model_selection import train_test_split
import argparse
from pathlib import Path
import model_builder, engine  # Assuming these are your custom modules

# Feature sizes and runtime configuration shared by the functions in this file.
# Atom features are bit-packed with np.packbits in get_atom_feature:
# 69 one-hot/flag bits -> 9 bytes per atom.
PACK_NODE_DIM = 9
# Bond features are bit-packed in get_bond_feature: 6 flag bits -> 1 byte per bond.
PACK_EDGE_DIM = 1
# Bit widths of the packed features (8 bits per byte). Not referenced in the
# code visible here -- presumably consumed by model_builder/engine; TODO confirm.
NODE_DIM = PACK_NODE_DIM * 8
EDGE_DIM = PACK_EDGE_DIM * 8
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of worker processes for the multiprocessing Pool in get_data_good_format.
NUM_WORKERS = 4
# Batch size used for the DataLoaders built in main().
DEFAULT_BATCH_SIZE = 128
# Seed PyTorch's RNG; note that this runs as a side effect of importing the module.
torch.manual_seed(42)


# Download data from a specified path
def download_data(path):
    """Load a sample of the training CSV into a pandas DataFrame via DuckDB.

    The query keeps at most 15,000,000 randomly ordered rows with
    ``binds = 0`` plus every row with ``binds = 1`` (also randomly ordered),
    and concatenates the fetched chunks into a single frame.

    Args:
        path: Location of the CSV file passed to DuckDB's ``read_csv``.

    Returns:
        A ``pandas.DataFrame`` holding all selected rows with a fresh index.

    Raises:
        Exception: Any error raised while querying or concatenating is
            printed and then re-raised unchanged (this includes the
            ``ValueError`` from ``pd.concat`` when the query yields no rows).
    """
    # The path is embedded in a SQL string literal, so double any single quote
    # it contains; otherwise a path with "'" would terminate the literal early.
    quoted_path = str(path).replace("'", "''")
    sql_query = (
        "(SELECT * FROM read_csv('{}') WHERE binds = 0 ORDER BY random() LIMIT 15000000)"
        " UNION ALL "
        "(SELECT * FROM read_csv('{}') WHERE binds = 1 ORDER BY random())"
    ).format(quoted_path, quoted_path)

    con = duckdb.connect()
    df_list = []  # fetched DataFrame chunks, in arrival order
    try:
        result = con.execute(sql_query)
        while True:
            # NOTE(review): the argument is passed straight to DuckDB's
            # fetch_df_chunk; per the DuckDB docs it counts vectors per chunk,
            # not rows -- confirm the intended chunk size.
            df_chunk = result.fetch_df_chunk(1000)
            if df_chunk.empty:
                break
            df_list.append(df_chunk)
        # Merge every chunk into one DataFrame with a continuous index.
        df = pd.concat(df_list, ignore_index=True)
    except Exception as e:
        print("An error occurred: " + str(e))
        raise
    finally:
        # Always release the DuckDB connection, even on failure.
        con.close()
    return df

# Preprocess the data
def preprocessing(df):
    """Extract SMILES strings and binding labels from the raw frame.

    Every ``[Dy]`` placeholder in a SMILES string is rewritten to a plain
    carbon atom ``C``.

    Args:
        df: DataFrame with a ``molecule_smiles`` column of strings and a
            ``binds`` column.

    Returns:
        A ``(smiles_list, labels)`` tuple: a Python list of cleaned SMILES
        strings and the untouched ``binds`` column.
    """
    cleaned = []
    for raw_smiles in df["molecule_smiles"]:
        cleaned.append(raw_smiles.replace('[Dy]', 'C'))
    binding = df["binds"]
    return cleaned, binding

# Convert SMILES strings to graph data
def smile_to_graph(args):
    """Convert one ``(smiles, label)`` pair into packed NumPy graph arrays.

    Takes a single tuple argument so it can be mapped directly with
    ``Pool.imap``.

    Args:
        args: ``(smiles, label)`` -- the SMILES string and its label.

    Returns:
        A tuple ``(N, edge, node_feature, edge_feature, label)`` where
        ``N`` is the atom count, ``edge`` is an ``(E, 2)`` integer array of
        directed ``[i, j]`` pairs (both directions of every bond, sorted by
        ``(i, j)``), ``node_feature`` stacks the packed atom features and
        ``edge_feature`` stacks the packed bond features row-aligned with
        ``edge``.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    smiles, label = args
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None.
        raise ValueError("could not parse SMILES: {!r}".format(smiles))
    N = mol.GetNumAtoms()

    node_rows = [get_atom_feature(mol.GetAtomWithIdx(i)) for i in range(N)]

    # Walk the bonds once instead of probing all N*N atom pairs, and emit each
    # bond in both directions; the packed feature is shared by both directions.
    directed = []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        packed = get_bond_feature(bond)
        directed.append((begin, end, packed))
        directed.append((end, begin, packed))
    # Keep the row order of the pairwise scan: ascending source, then target.
    directed.sort(key=lambda item: (item[0], item[1]))

    if node_rows:
        node_feature = np.stack(node_rows)
    else:
        # np.stack rejects an empty list; keep a well-formed (0, width) array.
        node_feature = np.zeros((0, PACK_NODE_DIM), dtype=np.uint8)

    # uint8 keeps memory low but only holds atom indices 0..255; widen the
    # index dtype for larger molecules instead of overflowing.
    index_dtype = np.uint8 if N <= 256 else np.int32
    if directed:
        edge = np.array([[src, dst] for src, dst, _ in directed], dtype=index_dtype)
        edge_feature = np.stack([feat for _, _, feat in directed])
    else:
        # Molecules without bonds (e.g. a single atom) still need 2-D arrays
        # so that the later transpose yields a (2, 0) edge index.
        edge = np.zeros((0, 2), dtype=index_dtype)
        edge_feature = np.zeros((0, PACK_EDGE_DIM), dtype=np.uint8)
    return N, edge, node_feature, edge_feature, label

######### Helper functions for feature extraction from smile
def one_of_k_encoding(x, allowable_set, allow_unk=False):
    """One-hot encode ``x`` against ``allowable_set``.

    Args:
        x: Value to encode.
        allowable_set: Sequence of accepted values; its order defines the
            position of each flag in the result.
        allow_unk: When true, a value missing from ``allowable_set`` is
            mapped to the last entry instead of raising.

    Returns:
        A list of booleans, one per entry of ``allowable_set``, true where
        the entry equals ``x``.

    Raises:
        ValueError: If ``x`` is not in ``allowable_set`` and ``allow_unk``
            is false.
    """
    if x not in allowable_set:
        if not allow_unk:
            # The previous message mixed a positional "{}" with a named
            # "{allowable_set}" placeholder, so building it raised KeyError
            # and the real problem was never reported.
            raise ValueError(
                f"input {x} not in allowable set {allowable_set}!!!"
            )
        # Fall back to the last slot, which acts as the "unknown" bucket.
        x = allowable_set[-1]
    return [x == candidate for candidate in allowable_set]


## Get features of an atom (one-hot encodings and flags):
'''
	1.atom element: 44 dimensions (the extra 'Unknown' slot is disabled)
	2.the atom's hybridization: 5 dimensions
	3.degree of atom: 6 dimensions
	4.total number of H bound to atom: 6 dimensions
	5.implicit valence of the atom: 6 dimensions
	6.whether the atom is on ring: 1 dimension
	7.whether the atom is aromatic: 1 dimension
	Total: 69 dimensions (bit-packed into 9 bytes)
'''

# Element symbols accepted by get_atom_feature; an element's position in this
# list is its one-hot index. The 'Unknown' fallback entry is disabled, so a
# symbol that is not listed makes one_of_k_encoding raise.
ATOM_SYMBOL = [
	'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg',
	'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl',
	'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H',
	'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr',
	'Pt', 'Hg', 'Pb', 'Dy',
	# 'Unknown'  (disabled fallback slot)
]
# len(ATOM_SYMBOL) == 44
# Hybridization states accepted by get_atom_feature, in one-hot order; any
# other state returned by RDKit makes one_of_k_encoding raise.
HYBRIDIZATION_TYPE = [
	Chem.rdchem.HybridizationType.S,
	Chem.rdchem.HybridizationType.SP,
	Chem.rdchem.HybridizationType.SP2,
	Chem.rdchem.HybridizationType.SP3,
	Chem.rdchem.HybridizationType.SP3D
]

def get_atom_feature(atom):
    """Build the bit-packed feature vector of a single RDKit atom.

    The flags are, in order: one-hot element symbol, one-hot hybridization,
    one-hot degree (0-5), one-hot total hydrogen count (0-5), one-hot
    implicit valence (0-5), ring membership and aromaticity.

    Args:
        atom: RDKit atom to describe.

    Returns:
        A ``uint8`` NumPy array produced by ``np.packbits`` over the flags.
    """
    zero_to_five = [0, 1, 2, 3, 4, 5]

    bits = []
    bits.extend(one_of_k_encoding(atom.GetSymbol(), ATOM_SYMBOL))
    bits.extend(one_of_k_encoding(atom.GetHybridization(), HYBRIDIZATION_TYPE))
    bits.extend(one_of_k_encoding(atom.GetDegree(), zero_to_five))
    bits.extend(one_of_k_encoding(atom.GetTotalNumHs(), zero_to_five))
    bits.extend(one_of_k_encoding(atom.GetImplicitValence(), zero_to_five))
    bits.append(atom.IsInRing())
    bits.append(atom.GetIsAromatic())

    # Pack eight flags per byte to keep the per-atom footprint small.
    return np.packbits(bits)


## Get features of an edge (one-hot encoding and flags)
'''
	1.single/double/triple/aromatic: 4 dimensions
	2.whether the bond is conjugated: 1 dimension
	3.whether the bond is on ring: 1 dimension
	Total: 6 dimensions (bit-packed into 1 byte)
'''

def get_bond_feature(bond):
    """Build the bit-packed feature vector of a single RDKit bond.

    The flags are, in order: single, double, triple, aromatic, conjugated
    and ring membership.

    Args:
        bond: RDKit bond to describe.

    Returns:
        A ``uint8`` NumPy array produced by ``np.packbits`` over the flags.
    """
    kind = bond.GetBondType()
    bond_kinds = (
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    )
    flags = [kind == candidate for candidate in bond_kinds]
    flags.append(bond.GetIsConjugated())
    flags.append(bond.IsInRing())
    # Pack the six flags into a single byte.
    return np.packbits(flags)
#########


# Convert into a list of graphs 
def to_pyg_list(graph):
    """Replace every raw graph tuple in ``graph`` with a PyG ``Data`` object.

    The list is overwritten in place, entry by entry, so the raw NumPy
    tuples and the ``Data`` objects never coexist for the whole dataset.

    Args:
        graph: List of ``(N, edge, node_feature, edge_feature, label)``
            tuples as produced by ``smile_to_graph``.

    Returns:
        The same list object, now holding ``Data`` instances whose tensors
        live on the module-level ``device``.
    """
    for i in tqdm(range(len(graph))):
        _, edge, node_feature, edge_feature, label = graph[i]
        graph[i] = Data(
            idx=i,
            # edge is stored as (E, 2); the transpose gives the (2, E) layout.
            edge_index=torch.from_numpy(edge.T).int().to(device),
            x=torch.from_numpy(node_feature).byte().to(device),
            edge_attr=torch.from_numpy(edge_feature).byte().to(device),
            y=torch.tensor(label).long().to(device),
        )
    # Release unused cached CUDA memory once at the end; doing this on every
    # iteration adds a costly allocator call per graph while freeing nothing,
    # because every tensor created in the loop is kept alive.
    torch.cuda.empty_cache()
    return graph


# Convert a list of graphs to PyTorch Geometric DataLoader
def to_pyg_loader(graphs, batch_size=32):
    """Wrap ``graphs`` in a PyTorch Geometric ``DataLoader``.

    Args:
        graphs: Collection of graph objects to batch.
        batch_size: Number of graphs per batch.

    Returns:
        The ``DataLoader`` built over ``graphs`` with the given batch size.
    """
    loader = PyGDataLoader(graphs, batch_size=batch_size)
    return loader


# Main function to get data in the required format
def get_data_good_format(path, batch_size=32):
    """Load the CSV at ``path`` and return train/validation/test graph lists.

    Steps: read the data, clean the SMILES strings, convert every molecule to
    graph arrays in a worker pool, wrap the arrays in PyG ``Data`` objects and
    split them 90/10 into train+validation/test, then 90/10 again into
    train/validation (both splits use ``random_state=42``).

    Args:
        path: Location of the training CSV file.
        batch_size: Kept for backward compatibility; it is not used here
            because the DataLoaders are built by the caller.

    Returns:
        A ``(train_graphs, val_graphs, test_graphs)`` tuple of lists.
    """
    df = download_data(path)
    smiles, labels = preprocessing(df)

    # Convert the SMILES strings to graphs in parallel.
    train_data = list(zip(smiles, labels))
    num_train = len(train_data)
    # Send molecules to the workers in batches: the default chunksize of 1
    # costs one inter-process round trip per molecule. The cap keeps several
    # chunks per worker on small inputs; imap still yields results in order.
    chunksize = max(1, min(1024, num_train // (NUM_WORKERS * 4)))
    with Pool(NUM_WORKERS) as pool:
        train_graphs = list(
            tqdm(
                pool.imap(smile_to_graph, train_data, chunksize=chunksize),
                total=num_train,
            )
        )

    # Turn the raw graph arrays into PyTorch Geometric Data objects.
    train_graphs = to_pyg_list(train_graphs)

    # Split into training, validation and test sets.
    train_val_graphs, test_graphs = train_test_split(train_graphs, test_size=0.1, random_state=42)
    train_graphs, val_graphs = train_test_split(train_val_graphs, test_size=0.1, random_state=42)
    return train_graphs, val_graphs, test_graphs

def main():
    """Command-line entry point: build the dataset, train the model, save it.

    Command-line options:
        --train_path: CSV file with the training data.
        --epochs: Number of training epochs.
        --batch_size: Batch size of the train/validation/test loaders
            (defaults to ``DEFAULT_BATCH_SIZE``).
        --model_dir: Directory that receives the saved ``state_dict``
            (defaults to the previously hard-coded location); it is created
            if missing.
    """
    parser = argparse.ArgumentParser(description="Train a GNN model on chemical data.")
    parser.add_argument('--train_path', type=str, default='/user1/icmub/lg361770/Calculs/IA/leash_compet/train.csv', help='Path to the training data CSV file.')
    parser.add_argument('--epochs', type=int, default=15, help='Number of epochs to train.')
    # Previously hard-coded values, now overridable without editing the file.
    parser.add_argument('--batch_size', type=int, default=DEFAULT_BATCH_SIZE, help='Batch size for the data loaders.')
    parser.add_argument('--model_dir', type=str, default='/user1/icmub/lg361770/Calculs/IA/leash_compet/data_15M', help='Directory where the trained model is saved.')
    args = parser.parse_args()

    train_graphs, val_graphs, test_graphs = get_data_good_format(args.train_path, args.batch_size)
    # Only the training loader shuffles; evaluation order stays fixed.
    train_loader = PyGDataLoader(train_graphs, batch_size=args.batch_size, shuffle=True)
    val_loader = PyGDataLoader(val_graphs, batch_size=args.batch_size, shuffle=False)
    test_loader = PyGDataLoader(test_graphs, batch_size=args.batch_size, shuffle=False)

    model = model_builder.Net().to(device)
    engine.run_experiment(
        model=model,
        model_name=model.__class__.__name__,
        val_loader=val_loader,
        test_loader=test_loader,
        train_loader=train_loader,
        n_epochs=args.epochs,
    )

    model_dir = Path(args.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    model_save_path = model_dir / "01_pytorch_GNN_15_000_000.pth"
    # Persist only the learned parameters, not the whole module object.
    torch.save(obj=model.state_dict(), f=model_save_path)
    print(f"Model saved to: {model_save_path}")

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()




