In [20]:
import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from tqdm import tqdm
from collections import defaultdict
from rdkit import Chem
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
import math
print(f"Torch version: {torch.__version__}")
print(f"Cuda available: {torch.cuda.is_available()}")
print(f"Torch geometric version: {torch_geometric.__version__}")
import random

# Seed every RNG the pipeline can touch. Seeding only Python's `random` (which
# nothing below uses) does not make runs reproducible: the training DataLoader
# (shuffle=True) draws from torch's global generator.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dataset root; all raw inputs live in <root>\raw_dir (Windows-style paths).
root = r"D:\PHD\Codes\From_GitHub\BIBM 2022\GraphCDR-main\data\Drug\dataset"
_RAW_DIR = root + "\\raw_dir\\"
mutation_file = _RAW_DIR + "genomic_mutation_34673_demap_features.csv"
gene_expression_file = _RAW_DIR + "genomic_expression_561celllines_697genes_demap_features.csv"
methylation_file = _RAW_DIR + "genomic_methylation_561celllines_808genes_demap_features.csv"
gdsc_file = _RAW_DIR + "GDSC_IC50.csv"
smiles_file = _RAW_DIR + "222drugs_pubchem_smiles.txt"
pubchem_file = _RAW_DIR + "Drug_listMon Jun 24 09_00_55 2019.csv"


class MoleculeDataset(Dataset):
    """PyG dataset pairing drug molecular graphs with cell-line omics and IC50.

    Every sample is one (drug, cell line) pair from the GDSC IC50 table: the
    drug is featurised from its SMILES string with RDKit, the cell line
    contributes its mutation, gene-expression and methylation profiles, and the
    label is the IC50 value. One ``.pt`` file per pair is written to
    ``processed_dir``.
    """

    def __init__(self, root, mutation_file, gene_expression_file, methylation_file,
                 gdsc_file, smiles_file, pubchem_file, task_type="single",
                 test=False, transform=None, pre_transform=None):
        """
        root = Where the dataset should be stored. This folder is split
        into raw_dir (downloaded dataset) and processed_dir (processed data).
        The six *_file arguments are reported as ``raw_file_names``.
        test = when True, processed files are named ``data_test_<i>.pt``
        instead of ``data_<i>.pt``.
        """
        self.task_type = task_type
        self.test = test
        self.mutation_file = mutation_file
        self.gene_expression_file = gene_expression_file
        self.methylation_file = methylation_file
        self.gdsc_file = gdsc_file
        self.smiles_file = smiles_file
        self.pubchem_file = pubchem_file
        # Per-instance caches: the drug-cell table is expensive to build, and
        # the omics CSVs only need to be parsed once, not once per pair.
        self._drug_cell_cache = None
        self._csv_cache = {}
        super(MoleculeDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """ If this file exists in raw_dir, the download is not triggered.
            (The download func. is not implemented here)
        """
        return (self.mutation_file, self.gene_expression_file, self.methylation_file,
                self.gdsc_file, self.smiles_file, self.pubchem_file)

    @property
    def processed_file_names(self):
        """One file per drug-cell pair; if all exist in processed_dir, processing is skipped."""
        self.DrugCelList = self._cached_drug_cell_list()
        prefix = self._file_prefix()
        return [f'{prefix}_{i}.pt' for i in self.DrugCelList.index]

    def download(self):
        pass

    def _file_prefix(self):
        """File-name prefix shared by ``processed_file_names``, ``process`` and ``get``."""
        return 'data_test' if self.test else 'data'

    def _cached_drug_cell_list(self):
        """Return the drug-cell table, building it at most once per instance."""
        if getattr(self, "_drug_cell_cache", None) is None:
            self._drug_cell_cache = self.get_drug_cell_list(self.raw_paths)
        return self._drug_cell_cache

    def process(self):
        """Featurise every drug-cell pair and save it as one ``Data`` object.

        Raises
        ------
        ValueError
            If RDKit cannot parse a drug's SMILES string.
        """
        self.DrugCelList = self._cached_drug_cell_list()
        smileslist = self.get_pdframe(self.smiles_file)
        prefix = self._file_prefix()
        for index, DrugCell in tqdm(self.DrugCelList.iterrows(), total=self.DrugCelList.shape[0]):
            smiles, mutation, gene_expression, methylation, label = self.get_variables(
                DrugCell, self.raw_paths, smileslist)

            # Featurise the molecule.
            mol_obj = Chem.MolFromSmiles(smiles)
            if mol_obj is None:
                # RDKit signals a parse failure by returning None; fail with
                # context instead of an AttributeError in the featurisers.
                raise ValueError(
                    f"RDKit could not parse SMILES {smiles!r} (CID {DrugCell['CID']})")
            node_feats = self._get_node_features(mol_obj)
            edge_feats = self._get_edge_features(mol_obj)
            edge_index = self._get_adjacency_info(mol_obj)
            # Label as a (1, 1) tensor.
            label = label.view(1, -1)

            data = Data(x=node_feats,
                        edge_index=edge_index,
                        edge_attr=edge_feats,
                        y=label,
                        num_classes=2,
                        mutation=mutation,
                        gene_expression=gene_expression,
                        methylation=methylation)

            # Same prefix as processed_file_names/get, so test=True datasets
            # are found again instead of being reprocessed on every load.
            torch.save(data, os.path.join(self.processed_dir, f'{prefix}_{index}.pt'))

    def get_drug_cell_list(self, raw_paths):
        """Build the table of usable (drug, cell line, IC50) pairs.

        A GDSC drug is kept when its PubChem CID has a SMILES entry; among
        drugs sharing a CID only the first (in GDSC row order) is kept. A pair
        is kept when its IC50 is not NaN and the cell line appears in the
        mutation, gene-expression and methylation tables. The table is also
        written to ``Drug-Cell-Response.csv`` in the GDSC file's directory.

        Parameters
        ----------
        raw_paths : sequence of str
            Mutation, gene-expression, methylation, GDSC, SMILES and PubChem
            file paths, in that order.

        Returns
        -------
        pandas.DataFrame
            Columns ``drug``, ``CID``, ``cell``, ``IC50`` with a 0-based
            RangeIndex (used to name the processed files).
        """
        (mutation_file, gene_expression_file, methylation_file,
         gdsc_file, smiles_file, pubchem_file) = raw_paths[:6]

        mutation = pd.read_csv(mutation_file, header=0, index_col=[0])
        gene_expression = pd.read_csv(gene_expression_file, header=0, index_col=[0])
        methylation = pd.read_csv(methylation_file, header=0, index_col=[0])
        gdsc = pd.read_csv(gdsc_file, header=0, index_col=[0])
        smiles = self.get_pdframe(smiles_file).set_index("drug_id")
        pubchem = pd.read_csv(pubchem_file, header=0, index_col=[0])

        # Membership sets built once instead of list scans inside the loops.
        smiles_ids = set(smiles.index.values)
        pubchem_ids = set(pubchem.index.values)
        omics_cells = (set(gene_expression.index.values)
                       & set(methylation.index.values)
                       & set(mutation.index.values))

        def drug_id_of(label):
            # GDSC row labels carry a 5-character prefix before the numeric id.
            return int(label[5:])

        # Remove drugs sharing a CID (first occurrence wins).
        candidates = []
        for idx in gdsc.index:
            drug_id = drug_id_of(idx)
            cid = pubchem.loc[drug_id].at["PubCHEM"]
            if cid in smiles_ids:
                candidates.append({"drug_id": drug_id, "CID": cid})
        drug_temp = pd.DataFrame(candidates, columns=["drug_id", "CID"])
        drug_temp = drug_temp.drop_duplicates(subset=["CID"], keep="first")
        kept_drugs = set(drug_temp["drug_id"].tolist())

        pairs = []
        drugs = set()
        cells = set()
        for idx in gdsc.index:
            drug_id = drug_id_of(idx)
            if drug_id not in kept_drugs or drug_id not in pubchem_ids:
                continue
            cid = pubchem.loc[drug_id, "PubCHEM"]
            # Reset per drug and set only when a pair is actually added.
            # Previously the flag was never reset, so every drug after the
            # first was reported, and it could be unbound on an all-NaN row.
            add_drug = False
            for column in gdsc.columns:
                ic50 = gdsc.loc[idx, column]
                if math.isnan(ic50) or column not in omics_cells:
                    continue
                cells.add(column)
                pairs.append({"drug": idx, "CID": cid, "cell": column, "IC50": ic50})
                add_drug = True
            if add_drug:
                print("drug:", drug_id)
                drugs.add(idx)

        # Built once from a list: repeated pd.concat in the loop was quadratic.
        DrugCelList = pd.DataFrame(pairs, columns=["drug", "CID", "cell", "IC50"])
        print("drug-cell list has", len(DrugCelList), "rows")
        DrugCelList.to_csv(os.path.join(os.path.dirname(gdsc_file), "Drug-Cell-Response.csv"))
        print("Total number of drugs:", len(drugs))
        print("Total number of cell-lines:", len(cells))
        return DrugCelList

    def get_variables(self, DrugCell, raw_paths, smileslist):
        """Collect the model inputs for one drug-cell pair.

        Parameters
        ----------
        DrugCell : pandas.Series
            One row of ``DrugCelList``; ``CID``, ``cell`` and ``IC50`` are read.
        raw_paths : sequence of str
            Raw file paths; entries 0-2 are the mutation, gene-expression and
            methylation CSVs.
        smileslist : pandas.DataFrame
            Output of ``get_pdframe`` (``drug_id`` and ``smiles`` columns).

        Returns
        -------
        tuple
            ``(smiles, mutation, gene_expression, methylation, label)``. The
            omics entries are 1-D float tensors in CSV column order, and
            ``label`` is a 1-element tensor holding the IC50.
        """
        mutation_file, gene_expression_file, methylation_file = raw_paths[0], raw_paths[1], raw_paths[2]

        # SMILES looked up by PubChem CID.
        smiles = smileslist.set_index('drug_id').loc[DrugCell["CID"]].at["smiles"]

        cell = DrugCell["cell"]
        mutation = self._omics_vector(mutation_file, cell)
        gene_expression = self._omics_vector(gene_expression_file, cell)
        methylation = self._omics_vector(methylation_file, cell)

        label = torch.Tensor([DrugCell["IC50"]])

        return smiles, mutation, gene_expression, methylation, label

    def _omics_vector(self, path, cell):
        """Return the row for ``cell`` of the CSV at ``path`` as a float tensor.

        Each CSV is parsed once and cached; previously all three omics files
        were re-read for every drug-cell pair.
        """
        cache = getattr(self, "_csv_cache", None)
        if cache is None:
            cache = self._csv_cache = {}
        frame = cache.get(path)
        if frame is None:
            frame = pd.read_csv(path, header=0, index_col=[0])
            cache[path] = frame
        row = frame.loc[cell]
        return torch.Tensor([row.at[col] for col in frame.columns])

    def get_pdframe(self, filename):
        """Parse a tab-separated molecule text file into a DataFrame.

        Each non-blank line is ``<drug_id><TAB><smiles>[<TAB>...]``; further
        fields are ignored.

        Parameters
        ----------
        filename : str
            Path of the text file listing the molecules.

        Returns
        -------
        pandas.DataFrame
            Columns ``drug_id`` and ``smiles`` (strings), with exact duplicate
            rows removed (last occurrence kept).
        """
        drug_ids = []
        smiles = []
        with open(filename, 'r') as txt:
            for line in txt:
                # Drop the line terminator; it used to end up inside the SMILES.
                line = line.rstrip("\r\n")
                if not line.strip():
                    continue
                fields = line.split("\t")
                drug_ids.append(fields[0])
                smiles.append(fields[1])

        data = pd.DataFrame({"drug_id": drug_ids, "smiles": smiles})
        data = data.drop_duplicates(keep='last')
        return data

    def _get_node_features(self, mol):
        """
        This will return a matrix / 2d array of the shape
        [Number of Nodes, Node Feature size]
        """
        all_node_feats = []

        for atom in mol.GetAtoms():
            node_feats = []
            # Feature 1: Atomic number
            node_feats.append(atom.GetAtomicNum())
            # Feature 2: Atom degree
            node_feats.append(atom.GetDegree())
            # Feature 3: Formal charge
            node_feats.append(atom.GetFormalCharge())
            # Feature 4: Hybridization
            node_feats.append(atom.GetHybridization())
            # Feature 5: Aromaticity
            node_feats.append(atom.GetIsAromatic())
            # Feature 6: Total Num Hs
            node_feats.append(atom.GetTotalNumHs())
            # Feature 7: Radical Electrons
            node_feats.append(atom.GetNumRadicalElectrons())
            # Feature 8: In Ring
            node_feats.append(atom.IsInRing())
            # Feature 9: Chirality
            node_feats.append(atom.GetChiralTag())

            # Append node features to matrix
            all_node_feats.append(node_feats)

        all_node_feats = np.asarray(all_node_feats)
        all_node_feats = torch.tensor(all_node_feats, dtype=torch.float)
        # Row-wise (per-atom) p=6 normalisation of the raw feature values.
        all_node_feats = F.normalize(all_node_feats, p=6, dim=-1)
        return all_node_feats

    def _get_edge_features(self, mol):
        """
        This will return a matrix / 2d array of the shape
        [Number of edges, Edge Feature size] (edge size 2; each bond counted
        twice, once per direction).
        """
        all_edge_feats = []

        for bond in mol.GetBonds():
            edge_feats = []
            # Feature 1: Bond type (as double)
            edge_feats.append(bond.GetBondTypeAsDouble())
            # Feature 2: Rings
            edge_feats.append(bond.IsInRing())
            # Append edge features to matrix (twice, per direction)
            all_edge_feats += [edge_feats, edge_feats]

        # reshape keeps the (0, 2) shape for bond-less molecules, which
        # previously came out as a 1-D (0,) tensor.
        all_edge_feats = np.asarray(all_edge_feats, dtype=float).reshape(-1, 2)
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """
        We could also use rdmolops.GetAdjacencyMatrix(mol)
        but we want to be sure that the order of the indices
        matches the order of the edge features
        """
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        edge_indices = torch.tensor(edge_indices)
        edge_indices = edge_indices.t().to(torch.long).view(2, -1)
        return edge_indices

    def _get_labels(self, label):
        """Wrap ``label`` as a 1-element int64 tensor (not used by ``process``)."""
        label = np.asarray([label])
        return torch.tensor(label, dtype=torch.int64)

    def len(self):
        return self.DrugCelList.shape[0]

    def get(self, idx):
        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        return torch.load(os.path.join(self.processed_dir, f'{self._file_prefix()}_{idx}.pt'))


def load_dataset(root, mutation_file, gene_expression_file, methylation_file, gdsc_file,
                 smiles_file, pubchem_file, batch_dim=32, shuffle_seed=None):
    """Build the dataset and split it into train/validation/test loaders.

    The first 10 % of samples form the test set, the next 10 % the validation
    set and the remainder the training set.

    Parameters
    ----------
    root : str
        Dataset root directory passed to ``MoleculeDataset``.
    mutation_file, gene_expression_file, methylation_file, gdsc_file, smiles_file, pubchem_file : str
        Raw input files passed to ``MoleculeDataset``.
    batch_dim : int, optional
        Batch size of every loader. The default is 32.
    shuffle_seed : int or None, optional
        When given, samples are permuted with this seed before splitting. With
        the default ``None`` the split follows the dataset's row order, which
        groups pairs by drug, so the test/validation sets then contain only a
        subset of the drugs.

    Returns
    -------
    tuple
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles between epochs.
    """
    dataset = MoleculeDataset(root, mutation_file, gene_expression_file, methylation_file,
                              gdsc_file, smiles_file, pubchem_file)
    # NOTE(review): assumes this PyG version allows assigning Dataset.num_classes.
    dataset.num_classes = dataset[0].num_classes

    if shuffle_seed is not None:
        # Dedicated generator: reproducible split without consuming global RNG state.
        generator = torch.Generator().manual_seed(shuffle_seed)
        dataset = dataset[torch.randperm(len(dataset), generator=generator)]

    n = int(len(dataset) * 10 / 100)
    test_dataset = dataset[:n]
    val_dataset = dataset[n:2 * n]
    train_dataset = dataset[2 * n:]
    print("training set size:", len(train_dataset))
    print("validation set size:", len(val_dataset))
    print("test set size:", len(test_dataset))

    train_loader = DataLoader(train_dataset, batch_size=batch_dim, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_dim, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_dim, shuffle=False)

    return train_loader, val_loader, test_loader


train_loader, val_loader, test_loader = load_dataset(root, mutation_file, gene_expression_file, methylation_file, gdsc_file, smiles_file, pubchem_file)


Torch version: 1.8.0
Cuda available: True
Torch geometric version: 2.0.1
drug: 1
drug: 1001
drug: 1004
drug: 1005
drug: 1006
drug: 1007
drug: 1008
drug: 1009
drug: 1010
drug: 1011
drug: 1012
drug: 1013
drug: 1014
drug: 1015
drug: 1016
drug: 1017
drug: 1018
drug: 1019
drug: 1020
drug: 1021
drug: 1022
drug: 1023
drug: 1024
drug: 1025
drug: 1026
drug: 1028
drug: 1029
drug: 1030
drug: 1031
drug: 1032
drug: 1033
drug: 1036
drug: 1037
drug: 1038
drug: 1039
drug: 104
drug: 1042
drug: 1043
drug: 1046
drug: 1047
drug: 1049
drug: 1050
drug: 1052
drug: 1053
drug: 1054
drug: 1057
drug: 1058
drug: 1059
drug: 106
drug: 1060
drug: 1061
drug: 1062
drug: 1066
drug: 1067
drug: 1069
drug: 1072
drug: 1091
drug: 11
drug: 110
drug: 111
drug: 1114
drug: 1129
drug: 1133
drug: 1149
drug: 1164
drug: 1170
drug: 1175
drug: 119
drug: 1192
drug: 1194
drug: 1199
drug: 1218
drug: 1219
drug: 1230
drug: 1236
drug: 1239
drug: 1241
drug: 1242
drug: 1243
drug: 1248
drug: 1259
drug: 1262
drug: 1264
drug: 1268
drug: 133
dru