# In [None]:
"""
Contains various utility functions to build PyTorch Geometric graphs from the test set and run predictions.
"""

import os
import torch
import numpy as np
import pandas as pd
import h5py
from torch_geometric.data import Data, DataLoader as PyGDataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_scatter import scatter
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import rdkit
from rdkit import Chem
import duckdb
from multiprocessing import Pool
from sklearn.model_selection import train_test_split
import argparse


# Printed once every import above has succeeded.
print('import ok!')


# Configuration
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Running on:', device)

# Number of worker processes in the multiprocessing Pool that converts SMILES to graphs.
NUM_WORKERS = 4
# Batch size used by main() for the test DataLoader.
DEFAULT_BATCH_SIZE = 128
# Default value of the --test_path command-line option.
DEFAULT_TEST_PATH = '/user1/icmub/lg361770/Calculs/IA/leash_compet/test.csv'

# Load a CSV file from a specified path
def download_data(path):
    """Read the CSV file at ``path`` with DuckDB and return it as a DataFrame.

    Args:
        path: Location of the CSV file; it is converted to ``str`` before use.

    Returns:
        The result of ``SELECT *`` over the file, converted with ``.df()``.

    Raises:
        Exception: Any error raised by DuckDB is printed and then re-raised.
    """
    # The path ends up inside a single-quoted SQL string literal, so every
    # single quote must be doubled; otherwise a path containing an apostrophe
    # would terminate the literal early and break (or alter) the query.
    escaped_path = str(path).replace("'", "''")
    sql_query = "(SELECT * FROM read_csv('{}'))".format(escaped_path)

    con = duckdb.connect()
    try:
        df = con.query(sql_query).df()
    except Exception as e:
        print("An error occurred: " + str(e))
        raise
    finally:
        # Always release the connection, even when the query fails.
        con.close()
    return df

############## Preprocess the data
def preprocessing(df):
    """Extract the SMILES strings and the row ids from the test table.

    Every ``[Dy]`` token in a SMILES string is replaced by ``C``.

    Args:
        df: Mapping-like table with ``"molecule_smiles"`` and ``"id"`` columns.

    Returns:
        A tuple ``(smiles, ids)``: a new list of rewritten SMILES strings and
        the untouched ``df["id"]`` column.
    """
    data_test = []
    for raw_smiles in df["molecule_smiles"]:
        data_test.append(raw_smiles.replace('[Dy]', 'C'))
    return data_test, df["id"]


def one_of_k_encoding(x, allowable_set, allow_unk=False):
    """One-hot encode ``x`` against ``allowable_set``.

    Args:
        x: Value to encode.
        allowable_set: Sequence of accepted values, in output order.
        allow_unk: When true, a value missing from ``allowable_set`` is
            encoded as the last entry instead of raising.

    Returns:
        A list of booleans, one per entry of ``allowable_set``.

    Raises:
        Exception: If ``x`` is not in ``allowable_set`` and ``allow_unk`` is false.
    """
    known = x in allowable_set
    if not known and not allow_unk:
        raise Exception(f'input {x} not in allowable set{allowable_set}!!!')
    if not known:
        # Unknown values fall back to the last slot.
        x = allowable_set[-1]
    return [x == candidate for candidate in allowable_set]


# Get features of an atom (one-hot encoding):
'''
	1. atom element: 44 dimensions (the extra 'Unknown' slot is commented out in ATOM_SYMBOL)
	2. the atom's hybridization: 5 dimensions
	3. degree of atom: 6 dimensions
	4. total number of H bound to atom: 6 dimensions
	5. implicit valence of atom: 6 dimensions
	6. whether the atom is in a ring: 1 dimension
	7. whether the atom is aromatic: 1 dimension
	Total: 69 dimensions, bit-packed by np.packbits into 9 bytes
'''

# Element symbols accepted by get_atom_feature; the list order is the one-hot order.
ATOM_SYMBOL = [
	'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg',
	'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl',
	'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H',
	'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr',
	'Pt', 'Hg', 'Pb', 'Dy',
	# 'Unknown'  (catch-all slot, currently disabled)
]
# len(ATOM_SYMBOL) == 44
# RDKit hybridization states accepted by get_atom_feature; the list order is the one-hot order.
HYBRIDIZATION_TYPE = [
	Chem.rdchem.HybridizationType.S,
	Chem.rdchem.HybridizationType.SP,
	Chem.rdchem.HybridizationType.SP2,
	Chem.rdchem.HybridizationType.SP3,
	Chem.rdchem.HybridizationType.SP3D
]

def get_atom_feature(atom):
    """Build the bit-packed feature vector of one RDKit atom.

    The one-hot groups are concatenated in this order: element symbol,
    hybridization, degree, total number of hydrogens, implicit valence
    (the last three over the values 0..5), followed by the in-ring and
    aromatic flags. The boolean list is then packed with ``np.packbits``.

    Args:
        atom: Object exposing the RDKit atom accessors used below.

    Returns:
        The ``np.packbits`` result for the concatenated flags.

    Raises:
        Exception: Propagated from ``one_of_k_encoding`` when a value is
            outside its allowable set.
    """
    zero_to_five = [0, 1, 2, 3, 4, 5]
    bits = []
    bits += one_of_k_encoding(atom.GetSymbol(), ATOM_SYMBOL)
    bits += one_of_k_encoding(atom.GetHybridization(), HYBRIDIZATION_TYPE)
    bits += one_of_k_encoding(atom.GetDegree(), zero_to_five)
    bits += one_of_k_encoding(atom.GetTotalNumHs(), zero_to_five)
    bits += one_of_k_encoding(atom.GetImplicitValence(), zero_to_five)
    bits.append(atom.IsInRing())
    bits.append(atom.GetIsAromatic())
    return np.packbits(bits)


# Get features of an edge (one-hot encoding)
'''
	1. single/double/triple/aromatic: 4 dimensions
	2. whether the bond is conjugated: 1 dimension
	3. whether the bond is in a ring: 1 dimension
	Total: 6 dimensions, bit-packed by np.packbits into 1 byte
'''

def get_bond_feature(bond):
    """Build the bit-packed feature vector of one RDKit bond.

    The flags are, in order: bond type is single / double / triple /
    aromatic, bond is conjugated, bond is in a ring. The boolean list is
    packed with ``np.packbits``.

    Args:
        bond: Object exposing the RDKit bond accessors used below.

    Returns:
        The ``np.packbits`` result for the six flags.
    """
    kind = bond.GetBondType()
    bond_types = Chem.rdchem.BondType
    references = (
        bond_types.SINGLE,
        bond_types.DOUBLE,
        bond_types.TRIPLE,
        bond_types.AROMATIC,
    )
    flags = [kind == reference for reference in references]
    flags.append(bond.GetIsConjugated())
    flags.append(bond.IsInRing())
    return np.packbits(flags)

##############
## Turn the graph tuples built from SMILES into PyG Data objects; only for the test set, which has no targets
def to_pyg_list(graph):
    """Replace, in place, every graph tuple of ``graph`` by a ``Data`` object.

    Args:
        graph: List of ``(N, edge, node_feature, edge_feature, Id)`` tuples
            whose array members are NumPy arrays.

    Returns:
        The same list object, now holding ``Data`` instances whose ``idx``
        is the position of the graph in the list.
    """
    for position in tqdm(range(len(graph))):
        _, edge, node_feature, edge_feature, Id = graph[position]
        fields = dict(
            idx=position,
            edge_index=torch.from_numpy(edge.T).int(),
            x=torch.from_numpy(node_feature).byte(),
            edge_attr=torch.from_numpy(edge_feature).byte(),
            Id=torch.tensor(Id, dtype=torch.int32),
        )
        graph[position] = Data(**fields)
    return graph


def to_pyg_format(N,edge,node_feature,edge_feature, Id):
    """Wrap the arrays of a single graph in a ``Data`` object with ``idx=-1``.

    Args:
        N: Number of atoms; not used here.
        edge: NumPy array of edges; its transpose becomes ``edge_index``.
        node_feature: NumPy array stored as byte tensor ``x``.
        edge_feature: NumPy array stored as byte tensor ``edge_attr``.
        Id: Identifier stored as an int32 tensor.

    Returns:
        The assembled ``Data`` object.
    """
    edge_index = torch.from_numpy(edge.T).int()
    x = torch.from_numpy(node_feature).byte()
    edge_attr = torch.from_numpy(edge_feature).byte()
    graph_id = torch.tensor(Id, dtype=torch.int32)
    return Data(idx=-1, edge_index=edge_index, x=x, edge_attr=edge_attr, Id=graph_id)


def smile_to_graph(args):
    """Convert one ``(smiles, Id)`` pair into the arrays describing its graph.

    Takes a single tuple argument so it can be mapped over a
    ``multiprocessing.Pool``.

    Args:
        args: Tuple ``(smiles, Id)``.

    Returns:
        Tuple ``(N, edge, node_feature, edge_feature, Id)`` where ``N`` is the
        atom count, ``edge`` is an int32 array of shape ``(E, 2)`` listing
        every bond in both directions ordered by ``(i, j)``, ``node_feature``
        stacks the ``get_atom_feature`` rows, and ``edge_feature`` stacks the
        ``get_bond_feature`` rows in the same order as ``edge``.

    Raises:
        ValueError: If RDKit cannot parse the SMILES or the molecule has no atoms.
    """
    smiles, Id = args
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles returns None on failure; fail with the offending input
        # instead of an AttributeError on None.
        raise ValueError(f'RDKit could not parse SMILES {smiles!r} (id={Id})')
    N = mol.GetNumAtoms()
    if N == 0:
        raise ValueError(f'SMILES {smiles!r} (id={Id}) contains no atoms')

    node_feature = []
    edge_feature = []
    edge = []
    for i in range(N):
        atom_i = mol.GetAtomWithIdx(i)
        node_feature.append(get_atom_feature(atom_i))

        # Walk only the bonds attached to atom i instead of probing every
        # (i, j) pair; sorting by neighbour index keeps the edges in the same
        # (i, j) order as an all-pairs scan.
        neighbours = [(bond.GetOtherAtomIdx(i), bond) for bond in atom_i.GetBonds()]
        neighbours.sort(key=lambda pair: pair[0])
        for j, bond_ij in neighbours:
            edge.append([i, j])
            edge_feature.append(get_bond_feature(bond_ij))

    node_feature = np.stack(node_feature)
    if edge:
        edge_feature = np.stack(edge_feature)
        # int32 rather than uint8: atom indices above 255 must not wrap around.
        edge = np.array(edge, dtype=np.int32)
    else:
        # Molecule without bonds (e.g. a single atom): np.stack rejects an
        # empty list, so build correctly shaped empty arrays. The width of 1
        # matches the single packed byte produced by get_bond_feature.
        edge_feature = np.zeros((0, 1), dtype=np.uint8)
        edge = np.zeros((0, 2), dtype=np.int32)
    return N, edge, node_feature, edge_feature, Id

# Main function to get data in the required format
def get_data_good_format(path, batch_size=32):
    """Load the test CSV and turn every molecule into a PyG ``Data`` graph.

    Args:
        path: Path of the CSV file, passed to ``download_data``.
        batch_size: Unused; kept so existing callers keep working.

    Returns:
        List of ``Data`` objects, in the same order as the rows of the CSV.
    """
    df = download_data(path)
    smiles, Id = preprocessing(df)

    # Convert the SMILES strings to graph tuples in parallel.
    test_data = list(zip(smiles, Id))
    num_test = len(test_data)
    # Send tasks to the workers in chunks: one task per molecule makes the
    # inter-process overhead dominate on large inputs. imap still yields the
    # results in input order.
    chunksize = max(1, min(1024, num_test // (NUM_WORKERS * 4)))
    with Pool(NUM_WORKERS) as pool:
        test_graphs = list(
            tqdm(pool.imap(smile_to_graph, test_data, chunksize=chunksize), total=num_test)
        )

    # Convert the graph tuples to PyTorch Geometric Data objects.
    return to_pyg_list(test_graphs)



"""
Bit-unpacking helpers, the GNN model definition, and inference utilities for the test set.
"""


# helper
# torch version of np unpackbits
#https://gist.github.com/vadimkantorov/30ea6d278bc492abf6ad328c6965613a

def tensor_dim_slice(tensor, dim, dim_slice):
    """Index ``tensor`` with ``dim_slice`` along ``dim``, keeping earlier dims whole.

    Args:
        tensor: Tensor to index.
        dim: Dimension to slice; negative values count from the end.
        dim_slice: Index (typically a ``slice``) applied to that dimension.

    Returns:
        The result of indexing ``tensor`` with full slices on the leading
        dimensions followed by ``dim_slice``.
    """
    axis = dim if dim >= 0 else dim + tensor.dim()
    index = (slice(None),) * axis + (dim_slice,)
    return tensor[index]

# @torch.jit.script
def packshape(shape, dim: int = -1, mask: int = 0b00000001, dtype=torch.uint8, pack=True):
    """Compute the shape obtained by packing or unpacking bit fields along ``dim``.

    Args:
        shape: Shape of the input, as a tuple (or ``torch.Size``).
        dim: Dimension that is packed/unpacked; negative values count from the end.
        mask: Bit mask of one field; ``0b1``, ``0b11``, ``0b1111`` and
            ``0b11111111`` select fields of 1, 2, 4 and 8 bits.
        dtype: Storage dtype of the packed tensor (``uint8``, ``int16``,
            ``int32`` or ``int64``).
        pack: When true return the packed shape, otherwise the unpacked one.

    Returns:
        Tuple ``(shape, nibbles, nibble)``: the resulting shape, the number of
        fields per storage element, and the width of one field in bits.

    Raises:
        ValueError: If ``dtype`` or ``mask`` is not one of the supported values.
    """
    dim = dim if dim >= 0 else dim + len(shape)

    bits_by_dtype = {torch.uint8: 8, torch.int16: 16, torch.int32: 32, torch.int64: 64}
    nibble_by_mask = {0b00000001: 1, 0b00000011: 2, 0b00001111: 4, 0b11111111: 8}
    bits = bits_by_dtype.get(dtype, 0)
    nibble = nibble_by_mask.get(mask, 0)
    # bits = torch.iinfo(dtype).bits # does not JIT compile
    if bits == 0:
        raise ValueError(f'unsupported dtype for bit packing: {dtype}')
    if nibble == 0:
        raise ValueError(f'unsupported mask for bit packing: {mask:#b}')

    # Every supported field width (1, 2, 4, 8) divides every supported
    # storage width (8, 16, 32, 64), so this division is exact.
    nibbles = bits // nibble
    if pack:
        # Ceiling division in pure integer arithmetic (no math module needed).
        new_extent = -(-shape[dim] // nibbles)
    else:
        new_extent = shape[dim] * nibbles
    shape = shape[:dim] + (new_extent,) + shape[1 + dim:]
    return shape, nibbles, nibble

# @torch.jit.script
def F_unpackbits(tensor, dim: int = -1, mask: int = 0b00000001, shape=None, out=None, dtype=torch.uint8):
	"""Unpack the bit fields stored in ``tensor`` along ``dim``.

	Args:
		tensor: Packed tensor; its dtype determines the number of bits per element.
		dim: Dimension to unpack; negative values count from the end.
		mask: Bit mask selecting one field (see ``packshape``).
		shape: Optional shape of the result; defaults to the fully unpacked
			shape computed by ``packshape``.
		out: Optional pre-allocated output tensor; must have shape ``shape``.
		dtype: dtype of the output tensor when ``out`` is not given.

	Returns:
		Tensor holding the masked fields.
	"""
	dim = dim if dim >= 0 else dim + tensor.dim()
	# nibbles = fields per element, nibble = bits per field.
	shape_, nibbles, nibble = packshape(tensor.shape, dim=dim, mask=mask, dtype=tensor.dtype, pack=False)
	shape = shape if shape is not None else shape_
	out = out if out is not None else torch.empty(shape, device=tensor.device, dtype=dtype)
	assert out.shape == shape

	if shape[dim] % nibbles == 0:
		# Output extent is a whole number of elements: shift every element by
		# all field offsets at once. Shifts run from the highest offset down
		# to 0, so the most significant field comes first.
		shift = torch.arange((nibbles - 1) * nibble, -1, -nibble, dtype=torch.uint8, device=tensor.device)
		# Give the shift vector trailing singleton dims so it broadcasts
		# against the new axis inserted right after ``dim``.
		shift = shift.view(nibbles, *((1,) * (tensor.dim() - dim - 1)))
		return torch.bitwise_and((tensor.unsqueeze(1 + dim) >> shift).view_as(out), mask, out=out)

	else:
		# Requested extent is not a multiple of ``nibbles``: fill the output
		# one field position at a time through strided views of ``out``.
		# NOTE(review): here field i uses shift ``nibble * i`` (least
		# significant field first), the opposite order of the branch above;
		# confirm this is intended.
		for i in range(nibbles):
			shift = nibble * i
			sliced_output = tensor_dim_slice(out, dim, slice(i, None, nibbles))
			sliced_input = tensor.narrow(dim, 0, sliced_output.shape[dim])
			torch.bitwise_and(sliced_input >> shift, mask, out=sliced_output)
	return out

class dotdict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes."""

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails; a missing key is
        # reported as a missing attribute.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    # Attribute assignment and deletion go straight to the item operations.
    __delattr__ = dict.__delitem__
    __setattr__ = dict.__setitem__


# Setup hyperparameters
# Bytes per packed node-feature row: get_atom_feature emits 69 bits, which np.packbits pads to 9 bytes.
PACK_NODE_DIM =9
# Bytes per packed edge-feature row: get_bond_feature emits 6 bits, which np.packbits pads to 1 byte.
PACK_EDGE_DIM =1
# Feature widths after unpacking (8 bits per byte); passed to the model as in_dim / edge_dim.
NODE_DIM =PACK_NODE_DIM*8
EDGE_DIM =PACK_EDGE_DIM*8


### Model
class MPNNLayer(MessagePassing):
    """Message-passing layer with MLP message and update functions.

    For every edge the message MLP is applied to the concatenation of the
    target embedding, the source embedding and the edge features. Messages
    are reduced per target node with ``aggr``, and the update MLP is applied
    to the concatenation of the node embedding and its aggregated message.

    Args:
        emb_dim: Width of the node embeddings (input and output).
        edge_dim: Width of the edge feature vectors.
        aggr: Reduction used to aggregate messages (passed to ``scatter``).
    """

    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.edge_dim = edge_dim
        self.mlp_msg = nn.Sequential(
            nn.Linear(2 * emb_dim + edge_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(),
            nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU()
        )
        self.mlp_upd = nn.Sequential(
            nn.Linear(2 * emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(),
            nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU()
        )

    def forward(self, h, edge_index, edge_attr):
        """Run one round of message passing and return the updated embeddings."""
        out = self.propagate(edge_index, h=h, edge_attr=edge_attr)
        return out

    def message(self, h_i, h_j, edge_attr):
        """Compute one message per edge from both endpoint embeddings and the edge features."""
        msg = torch.cat([h_i, h_j, edge_attr], dim=-1)
        return self.mlp_msg(msg)

    def aggregate(self, inputs, index, dim_size=None):
        """Reduce the edge messages per target node.

        ``dim_size`` is supplied by ``propagate`` and fixes the number of
        output rows to the number of nodes. Without it ``scatter`` sizes the
        output from ``index.max() + 1``, which drops trailing nodes that
        receive no message and makes the concatenation in ``update`` fail.
        """
        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)

    def update(self, aggr_out, h):
        """Combine every node embedding with its aggregated message."""
        upd_out = torch.cat([h, aggr_out], dim=-1)
        return self.mlp_upd(upd_out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')


class MPNNModel(nn.Module):
    """Graph encoder: linear input projection, residual MPNN layers, mean pooling.

    Args:
        num_layers: Number of ``MPNNLayer`` blocks.
        emb_dim: Width of the node embeddings and of the pooled graph embedding.
        in_dim: Width of the unpacked node feature vectors.
        edge_dim: Width of the unpacked edge feature vectors.
    """

    def __init__(self, num_layers=4, emb_dim=64, in_dim=9, edge_dim=4):
        super().__init__()
        self.lin_in = nn.Linear(in_dim, emb_dim)
        self.convs = torch.nn.ModuleList()
        for layer in range(num_layers):
            self.convs.append(MPNNLayer(emb_dim, edge_dim, aggr='add'))
        self.pool = global_mean_pool

    def forward(self, batch):
        """Return one embedding per graph of ``batch``.

        ``batch.x`` and ``batch.edge_attr`` hold bit-packed features and are
        unpacked along their last dimension before use.
        """
        h = self.lin_in(F_unpackbits(batch.x,-1).float())

        # The edge tensors do not change between layers: convert them once
        # instead of once per layer.
        edge_index = batch.edge_index.long()
        edge_attr = F_unpackbits(batch.edge_attr,-1).float()

        for conv in self.convs:
            h = h + conv(h, edge_index, edge_attr)  # residual update, (n, d) -> (n, d)

        h_graph = self.pool(h, batch.batch)
        return h_graph

class Net(nn.Module):
    """Binding predictor: an ``MPNNModel`` graph encoder followed by an MLP head.

    ``output_type`` lists what ``forward`` returns: ``'loss'`` adds the BCE
    loss against ``batch.y``, ``'infer'`` adds probabilities and thresholded
    predictions.
    """

    def __init__(self):
        super().__init__()
        self.output_type = ['infer']
        graph_dim = 96
        self.smile_encoder = MPNNModel(
            num_layers=4, emb_dim=graph_dim, in_dim=NODE_DIM, edge_dim=EDGE_DIM,
        )

        # Head: three Linear -> ReLU -> Dropout stages, then a one-unit output layer.
        layers = []
        for fan_in, fan_out in ((graph_dim, 1024), (1024, 1024), (1024, 512)):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True), nn.Dropout(0.1)]
        layers.append(nn.Linear(512, 1))
        self.bind = nn.Sequential(*layers)

    def forward(self, batch):
        """Encode the whole batch and return a dict of the requested outputs."""
        # The complete batch is handed to the graph encoder.
        graph_embedding = self.smile_encoder(batch)
        logits = self.bind(graph_embedding).squeeze(-1)

        output = {}
        if 'loss' in self.output_type:
            # The targets are expected in batch.y.
            target = batch.y
            output['bce_loss'] = F.binary_cross_entropy_with_logits(logits, target.float())
        if 'infer' in self.output_type:
            probs = torch.sigmoid(logits)
            output.update(bind=probs, preds=(probs >= 0.5).float())
        return output


# Functions to load the model, run predictions and save the results

def load_model(model_path):
    """Build a ``Net`` and load the weights stored at ``model_path``.

    The checkpoint is read with ``torch.load`` and mapped onto the
    module-level ``device``; the model itself is not moved to that device.

    NOTE(review): ``torch.load`` unpickles the file, so only load checkpoints
    from a trusted source.
    """
    state_dict = torch.load(model_path, map_location=device)
    model = Net()
    model.load_state_dict(state_dict)
    return model


def predict(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and collect its probabilities.

    Args:
        model: Module whose output dict contains a ``'bind'`` tensor.
        loader: Iterable of batches exposing ``.to(device)`` and an ``Id`` tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(predictions, ids)``: a NumPy array with the concatenated
        ``'bind'`` values and a list with the matching ids. Both are empty
        when ``loader`` yields no batch.
    """
    model.eval()
    predictions = []
    Id = []
    with torch.no_grad():
        for batch in tqdm(loader):
            batch = batch.to(device)
            output = model(batch)
            predictions.append(output['bind'].cpu().numpy())
            Id.extend(batch.Id.cpu().numpy())

    if not predictions:
        # np.concatenate raises on an empty list; an empty loader yields an
        # empty result instead.
        return np.empty(0, dtype=np.float32), Id
    return np.concatenate(predictions, axis=0), Id


def save_predictions(predictions, ids, file_path):
    """Write the ids and predictions to ``file_path`` as a two-column CSV.

    The columns are named ``id`` and ``binds``; a header row is written and
    the DataFrame index is omitted. A confirmation line is printed afterwards.
    """
    columns = {'id': ids, 'binds': predictions}
    pd.DataFrame(columns).to_csv(file_path, index=False, header=True)
    print(f"Predictions saved to {file_path}")



# Script entry point
def main():
    """Predict binding probabilities for the test set and write them to a CSV file.

    Command-line options (all optional; the defaults reproduce the previously
    hard-coded paths):
        --test_path: CSV file with the test molecules.
        --model_path: Checkpoint holding the trained ``Net`` weights.
        --output_path: Destination CSV for the predictions.
    """
    parser = argparse.ArgumentParser(
        description="Run a trained GNN model on chemical test data and save its predictions."
    )
    parser.add_argument('--test_path', type=str, default=DEFAULT_TEST_PATH, help='Path to the testing data CSV file.')
    parser.add_argument(
        '--model_path', type=str,
        default="/user1/icmub/lg361770/Calculs/IA/leash_compet/data_15M/01_pytorch_GNN_15_000_000.pth",
        help='Path to the trained model checkpoint (.pth).',
    )
    parser.add_argument('--output_path', type=str, default='output_15M.csv', help='Path of the CSV file to write.')
    args = parser.parse_args()

    test_graphs = get_data_good_format(args.test_path, DEFAULT_BATCH_SIZE)
    # Keep the input order so predictions line up with the ids.
    test_loader = PyGDataLoader(test_graphs, batch_size=DEFAULT_BATCH_SIZE, shuffle=False)

    model = load_model(args.model_path)
    model.to(device).eval()

    predictions, ids = predict(model, test_loader, device)
    save_predictions(predictions, ids, args.output_path)

if __name__ == "__main__":
    main()






