In [None]:
import lmdb
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import random
import pickle

import pandas as pd
# Alternative dataset (disabled):
# data = pd.read_csv("/home/jovyan/PharmaBench/data/final_datasets/logd_reg_final_data.csv")
data = pd.read_csv("/home/jovyan/PharmaBench/data/final_datasets/bbb_cls_final_data.csv")
# Split rows on the precomputed scaffold label. reset_index() renumbers the
# rows from 0 and keeps the previous row number in an "index" column.
scaffold_training = data.loc[data['scaffold_train_test_label'] == 'train'].reset_index()
scaffold_test = data.loc[data['scaffold_train_test_label'] == 'test'].reset_index()


class Balance_LMDBDataset(Dataset):
    """Paired positive/negative molecule dataset backed by an LMDB file.

    Each item is a 10-tuple holding the features of one positive sample
    (selected by the requested index) interleaved with those of one randomly
    drawn negative sample, so a batch assembled by :meth:`collate_fn` always
    contains as many positives as negatives.

    Labels are read from the module-level ``scaffold_training`` /
    ``scaffold_test`` frames by matching the record's ``smi`` field against
    the ``Smiles_unify`` column; a ``value`` of 0 marks a negative sample and
    anything else a positive one.

    Args:
        db_path: Path of the LMDB file (opened with ``subdir=False``).
        conf: Number of conformers sampled per molecule.
        pad_len: Number of atoms every sample is padded or truncated to.
        mode: ``"train"`` selects ``scaffold_training`` and enables the random
            augmentations; any other value selects ``scaffold_test``.

    Attributes:
        pos_idx_list / neg_idx_list: positions into ``self._keys`` of the
            positive / negative records.
    """

    def __init__(self, db_path, conf, pad_len=200, mode="train"):
        self.db_path = db_path
        assert os.path.isfile(self.db_path), "{} not found".format(
            self.db_path
        )

        pd_data = scaffold_training if mode == "train" else scaffold_test

        # Map every SMILES to its first row label once, instead of scanning
        # the whole column again for each LMDB record.
        first_row = {}
        for row_label, smi in zip(pd_data.index, pd_data['Smiles_unify']):
            first_row.setdefault(smi, row_label)
        values = pd_data['value']

        self.value_list = []
        self.pos_idx_list = []
        self.neg_idx_list = []
        self.max_len = 0
        self.max_value = 0.
        self.min_value = 999.

        # This environment is only used for the initial scan; __getitem__
        # opens its own lazily (see connect_db(save_to_self=True)).
        env = self.connect_db(self.db_path)
        with env.begin() as txn:
            self._keys = list(txn.cursor().iternext(values=False))
            for key_pos, key in enumerate(self._keys):
                # NOTE: pickle.loads can execute arbitrary code; only open
                # LMDB files from a trusted source.
                record = pickle.loads(txn.get(key))
                current_len = len(record['atoms'])
                if self.max_len < current_len:
                    self.max_len = current_len

                smi = record['smi']
                if smi not in first_row:
                    raise IndexError(
                        f"SMILES {smi!r} from {self.db_path} has no row in the label table"
                    )
                value = values[first_row[smi]]
                self.value_list.append(value)

                # Remember the LMDB position (not the CSV row) so that
                # __getitem__ can fetch a record of the wanted class.
                if value == 0:
                    self.neg_idx_list.append(key_pos)
                else:
                    self.pos_idx_list.append(key_pos)

                # Value range, kept for normalising regression targets.
                if self.max_value < value:
                    self.max_value = value
                if self.min_value > value:
                    self.min_value = value
        env.close()

        print(f"number of postive/negtive samples {len(self.pos_idx_list)}/{len(self.neg_idx_list)}.")

        # Fixed atom-symbol -> embedding-index vocabulary (computed once
        # offline and frozen so the indices stay stable across runs).
        self.atom_dict = {'Br': 0, 'C': 1, 'H': 2, 'N': 3, 'O': 4, 'F': 5, 'Cl': 6, 'S': 7, 'P': 8, 'I': 9, 'B': 10, 'Se': 11, 'Ar': 12, 'Kr': 13, 'Li': 14, 'Ne': 15, 'Xe': 16, 'Si': 17}
        print(self.atom_dict)

        self.pad_len = pad_len
        self.conf = conf # conformation
        self.mode = mode
        print(f"{mode} set is initialized successfully. The max length of the atom is {self.max_len}. The number of dataset is {len(self._keys)}. Padding length is {self.pad_len}.")

    def connect_db(self, lmdb_path, save_to_self=False):
        """Open ``lmdb_path`` read-only.

        Returns the environment, or stores it on ``self.env`` (and returns
        ``None``) when ``save_to_self`` is true.
        """
        env = lmdb.open(
            lmdb_path,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        if not save_to_self:
            return env
        else:
            self.env = env

    def __len__(self):
        """Number of positive samples (one item is produced per positive)."""
        return len(self.pos_idx_list)

    def min_max_norm(self, x):
        """Rescale ``x`` to [0, 1] using its own global minimum and maximum.

        A constant input yields all zeros instead of NaN from a 0/0 division.
        """
        _min = x.min()
        _max = x.max()
        span = _max - _min
        if span == 0:
            return x - _min
        return (x - _min) / span

    def _load_sample(self, key_pos):
        """Build the feature arrays for the record at ``self._keys[key_pos]``.

        Returns:
            Tuple ``(atoms, coordinate, distance, SPD, edge)`` of numpy arrays
            with ``pad_len`` entries along every atom axis.
        """
        # NOTE: pickle.loads can execute arbitrary code; only open LMDB files
        # from a trusted source.
        with self.env.begin() as txn:
            data = pickle.loads(txn.get(self._keys[key_pos]))

        # Randomly pick `conf` of the stored conformers (however many exist).
        all_conf = torch.tensor(np.array(data['coordinates']), dtype=torch.float32)
        conf_idx = torch.randperm(all_conf.shape[0])[:self.conf]
        coordinates = all_conf[conf_idx, :, :]
        emb_idx = torch.tensor([self.atom_dict[atom] for atom in data['atoms']], dtype=torch.long)
        cur_len = len(emb_idx)

        # Shortest-path distances and adjacency (self loops added).
        spd = torch.tensor(np.array(data["SPD"]), dtype=torch.float32)
        edge = torch.tensor(np.array(data["edge"]), dtype=torch.float32) + torch.eye(cur_len)

        # Augmentation for large molecules: drop the trailing 1-10 atoms.
        # (cur_len > 100, so at least 91 atoms always remain.)
        if np.random.rand() < 0.5 and self.mode == "train" and cur_len > 100:
            num = random.choice((1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
            emb_idx = emb_idx[:-num]
            coordinates = coordinates[:, :-num, :]
            spd = spd[:-num, :-num]
            edge = edge[:-num, :-num]
            cur_len = len(emb_idx)

        if cur_len < self.pad_len:
            # Pad atoms with the index pad_len - 1, everything else with 0.
            new_emb = torch.full(size=(self.pad_len,), fill_value=self.pad_len-1, dtype=torch.long)
            new_emb[:cur_len] = emb_idx

            new_cor = torch.full(size=(self.conf, self.pad_len, 3), fill_value=0, dtype=torch.float32)
            new_cor[:, :cur_len, :] = coordinates

            new_spd = torch.full(size=(self.pad_len, self.pad_len), fill_value=0, dtype=torch.float32)
            new_spd[:cur_len, :cur_len] = spd
            new_edge = torch.full(size=(self.pad_len, self.pad_len), fill_value=0, dtype=torch.float32)
            new_edge[:cur_len, :cur_len] = edge
        else:
            new_emb = emb_idx[:self.pad_len]
            new_cor = coordinates[:, :self.pad_len, :]
            new_spd = spd[:self.pad_len, :self.pad_len]
            new_edge = edge[:self.pad_len, :self.pad_len]

        if self.mode == "train":
            # Gaussian jitter with a random scale; half of the time it is
            # applied only to a random subset of the coordinates.
            scale = random.choice((0.1, 0.2, 0.3, 0.5, 0.7, 0.01, 0.001))
            noise = scale * torch.randn_like(new_cor)
            if np.random.rand() < 0.5:
                mask = torch.randint_like(noise, 0, 2, dtype=torch.float32)
                noise = noise * mask
            new_cor = new_cor + noise

        # Pairwise Euclidean distances per conformer, flattened to
        # (pad_len, conf * pad_len) and rescaled to [-1, 1].
        atom_expanded = new_cor.unsqueeze(2)   # (conf, pad_len, 1, 3)
        coor_expanded = new_cor.unsqueeze(1)   # (conf, 1, pad_len, 3)
        distance = torch.sqrt((atom_expanded - coor_expanded).pow(2).sum(dim=-1))
        distance = distance.permute(1, 0, 2).reshape(-1, self.conf*self.pad_len)
        distance = (self.min_max_norm(distance) - 0.5) / 0.5

        # Coordinates rescaled to [-1, 1] and flattened to (pad_len, conf * 3).
        new_cor = (self.min_max_norm(new_cor) - 0.5) / 0.5
        new_cor = new_cor.permute(1, 0, 2).reshape(-1, self.conf*3)

        return (
            new_emb.numpy(),
            new_cor.numpy(),
            distance.numpy(),
            new_spd.numpy(),
            new_edge.numpy(),
        )

    def __getitem__(self, idx_p):
        """Return the ``idx_p``-th positive sample paired with a random negative.

        Returns:
            ``(atoms_p, atoms_n, coord_p, coord_n, dist_p, dist_n,
            spd_p, spd_n, edge_p, edge_n)`` as numpy arrays.
        """
        # Opened lazily so every DataLoader worker gets its own environment.
        if not hasattr(self, 'env'):
            self.connect_db(self.db_path, save_to_self=True)

        positive = self._load_sample(self.pos_idx_list[idx_p])
        idx_n = np.random.randint(0, len(self.neg_idx_list))
        negative = self._load_sample(self.neg_idx_list[idx_n])

        # Interleave: (emb_p, emb_n, cor_p, cor_n, ...).
        return tuple(arr for pair in zip(positive, negative) for arr in pair)

    def collate_fn(self, batch):
        """Stack a list of items into one dict batch, positives first.

        The returned ``label`` tensor is 1 for the first half of the batch
        (positives) and 0 for the second half (negatives).
        """
        new_emb_p, new_emb_n, \
        new_cor_p, new_cor_n, \
        distance_p, distance_n, \
        new_spd_p, new_spd_n, \
        new_edge_p, new_edge_n = zip(*batch)

        def _stack(pos, neg):
            # np.stack first: building a tensor from a tuple of arrays is slow.
            return torch.from_numpy(np.stack(pos + neg))

        data = {}
        data['atoms'] = _stack(new_emb_p, new_emb_n)
        data['coordinate'] = _stack(new_cor_p, new_cor_n).float()
        data['distance'] = _stack(distance_p, distance_n).float()
        data['SPD'] = _stack(new_spd_p, new_spd_n).float()
        data['edge'] = _stack(new_edge_p, new_edge_n).float()
        data['label'] = torch.tensor([1]*len(new_emb_p)+[0]*len(new_emb_n))
        return data

    def worker_init_fn(self, worker_id):
        """Give each DataLoader worker its own NumPy seed.

        The sum is wrapped into the 32-bit range accepted by
        ``np.random.seed``.
        """
        base_seed = int(np.random.get_state()[1][0])
        np.random.seed((base_seed + worker_id) % 2**32)
        
if __name__ == "__main__":
    # Smoke test: build the balanced training set and iterate over it once,
    # printing the shape of each batch's atom tensor.
    lmdb_file = './results/bbb_train.lmdb'
    train_dataset = Balance_LMDBDataset(
        lmdb_file,
        conf=10,
        pad_len=150,
        mode="train",
    )
    train_set = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=4,
        shuffle=False,
        pin_memory=True,
        num_workers=0,
        collate_fn=train_dataset.collate_fn,
        worker_init_fn=train_dataset.worker_init_fn,
    )
    for datas in train_set:
        # Other keys available for inspection: "coordinate", "label".
        print(datas["atoms"].shape)


In [None]:
# Scratch cell: a random permutation of 0..10 truncated to its first 10 entries.
torch.randperm(11)[:10]

In [None]:
import lmdb
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import random
import pickle
import pandas as pd


def find_min_max(x_list):
    """Print the maximum followed by the minimum of ``x_list``."""
    values = np.array(x_list)
    highest, lowest = values.max(), values.min()
    print(highest, lowest)
    

def min_max_norm(x_list, max_v, min_v):
    """Scale ``x_list`` with fixed bounds and return the result as a list.

    Each element is mapped to ``(x - min_v) / (max_v - min_v)``, so values
    equal to ``min_v`` become 0 and values equal to ``max_v`` become 1.

    Args:
        x_list: Sequence of numbers to scale.
        max_v: Value that maps to 1.
        min_v: Value that maps to 0.

    Returns:
        A plain Python list of the scaled values.
    """
    x_np = np.array(x_list)
    x_np = (x_np - min_v) / (max_v - min_v)
    # tolist() is an ndarray method; there is no np.tolist function.
    return x_np.tolist()


class HERG_Multi_Class_LMDBDataset(Dataset):
    """hERG classification dataset that also loads predicted property columns.

    The LMDB file and label CSV are chosen by ``mode``. Besides the ``class``
    label, six predicted property columns are read from the CSV into
    ``logd_list``, ``logp_list``, ``pka_list``, ``pkb_list``, ``logsol_list``
    and ``wlogsol_list``; the four continuous ones are min-max scaled with
    fixed bounds. Items returned by ``__getitem__`` contain only the molecular
    features and the class label.

    Args:
        conf: Number of conformers used per molecule (the first ``conf``).
        pad_len: Number of atoms every sample is padded or truncated to.
        mode: One of ``"train"``, ``"val"``, ``"week1"`` ... ``"week4"``.
            ``"train"`` additionally enables the random augmentations.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """

    # mode -> (LMDB file, label CSV)
    _SPLIT_FILES = {
        "train": ("/home/jovyan/prompts_learning/results/herg_cls_train.lmdb",
                  "/home/jovyan/prompts_learning/herg_dataset/hERGDB_cls_train_data.csv"),
        "val": ("/home/jovyan/prompts_learning/results/herg_cls_val.lmdb",
                "/home/jovyan/prompts_learning/herg_dataset/hERGDB_cls_valid_data.csv"),
        "week1": ("/home/jovyan/prompts_learning/results/herg_cls_week1.lmdb",
                  "/home/jovyan/prompts_learning/herg_dataset/hERGDB_cls_week1_1201.csv"),
        "week2": ("/home/jovyan/prompts_learning/results/herg_cls_week2.lmdb",
                  "/home/jovyan/prompts_learning/herg_dataset/hERGDB_cls_week2_1201.csv"),
        "week3": ("/home/jovyan/prompts_learning/results/herg_cls_week3.lmdb",
                  "/home/jovyan/prompts_learning/herg_dataset/hERGDB_cls_week3_1201.csv"),
        "week4": ("/home/jovyan/prompts_learning/results/herg_cls_week4.lmdb",
                  "/home/jovyan/prompts_learning/herg_dataset/hERGDB_cls_week4_1201.csv"),
    }

    # attribute name -> CSV column holding the predicted property
    _PROPERTY_COLUMNS = (
        ("logd_list", "LogD_pred"),
        ("logp_list", "LogP_pred"),
        ("pka_list", "pKa_class_pred"),
        ("pkb_list", "pKb_class_pred"),
        ("logsol_list", "LogSol_pred"),
        ("wlogsol_list", "wLogSol_pred"),
    )

    # attribute name -> (max, min) used for min-max scaling. These are the
    # values an earlier find_min_max run printed for the four continuous
    # properties. NOTE(review): confirm which split they were measured on.
    _PROPERTY_RANGES = {
        "logd_list": (5.75390625, -0.93310546875),
        "logp_list": (7.89453125, -1.91796875),
        "logsol_list": (2.880859375, -0.54931640625),
        "wlogsol_list": (3.5625, -9.6171875),
    }

    def __init__(self, conf, pad_len=200, mode="train"):
        if mode not in self._SPLIT_FILES:
            raise ValueError(
                f"unknown mode {mode!r}; expected one of {sorted(self._SPLIT_FILES)}"
            )
        db_path, csv_path = self._SPLIT_FILES[mode]
        pd_data = pd.read_csv(csv_path)

        self.db_path = db_path
        assert os.path.isfile(self.db_path), "{} not found".format(
            self.db_path
        )

        # Map every SMILES to its first row label once, instead of scanning
        # the whole column again for each LMDB record.
        first_row = {}
        for row_label, smi in zip(pd_data.index, pd_data['smiles']):
            first_row.setdefault(smi, row_label)

        self.value_list = []
        for attr, _ in self._PROPERTY_COLUMNS:
            setattr(self, attr, [])
        self.max_len = 0
        self.max_value = 0.
        self.min_value = 999.

        # This environment is only used for the initial scan; __getitem__
        # opens its own lazily (see connect_db(save_to_self=True)).
        env = self.connect_db(self.db_path)
        with env.begin() as txn:
            self._keys = list(txn.cursor().iternext(values=False))
            for key in self._keys:
                # NOTE: pickle.loads can execute arbitrary code; only open
                # LMDB files from a trusted source.
                record = pickle.loads(txn.get(key))
                current_len = len(record['atoms'])
                if self.max_len < current_len:
                    self.max_len = current_len

                smi = record['smi']
                if smi not in first_row:
                    raise IndexError(
                        f"SMILES {smi!r} from {self.db_path} has no row in {csv_path}"
                    )
                row_label = first_row[smi]
                self.value_list.append(pd_data['class'][row_label])
                for attr, column in self._PROPERTY_COLUMNS:
                    getattr(self, attr).append(pd_data[column][row_label])
        env.close()

        # Scale the continuous properties with the fixed bounds above.
        for attr, (max_v, min_v) in self._PROPERTY_RANGES.items():
            scaled = (np.array(getattr(self, attr)) - min_v) / (max_v - min_v)
            setattr(self, attr, scaled.tolist())

        # Class balance report (label 1 counts as positive, the rest as negative).
        num_p = sum(1 for lab in self.value_list if lab == 1)
        num_n = len(self.value_list) - num_p
        print(f"number of positive/negative samples {num_p}/{num_n}.")

        # Fixed atom-symbol -> embedding-index vocabulary (computed once
        # offline and frozen so the indices stay stable across runs).
        self.atom_dict = {'Br': 0, 'C': 1, 'H': 2, 'N': 3, 'O': 4, 'F': 5, 'Cl': 6, 'S': 7, 'P': 8, 'I': 9, 'B': 10, 'Se': 11, 'Ar': 12, 'Kr': 13, 'Li': 14, 'Ne': 15, 'Xe': 16, 'Si': 17, 'Na': 18}
        print(self.atom_dict)

        self.pad_len = pad_len
        self.conf = conf # conformation
        self.mode = mode
        print(f"{mode} set is initialized successfully. The max length of the atom is {self.max_len}. The number of dataset is {len(self._keys)}. Padding length is {self.pad_len}.")

    def connect_db(self, lmdb_path, save_to_self=False):
        """Open ``lmdb_path`` read-only.

        Returns the environment, or stores it on ``self.env`` (and returns
        ``None``) when ``save_to_self`` is true.
        """
        env = lmdb.open(
            lmdb_path,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        if not save_to_self:
            return env
        else:
            self.env = env

    def __len__(self):
        """Number of records in the LMDB file."""
        return len(self._keys)

    def min_max_norm(self, x):
        """Rescale ``x`` to [0, 1] using its own global minimum and maximum.

        A constant input yields all zeros instead of NaN from a 0/0 division.
        """
        _min = x.min()
        _max = x.max()
        span = _max - _min
        if span == 0:
            return x - _min
        return (x - _min) / span

    def __getitem__(self, idx):
        """Return the features and class label of record ``idx``.

        Returns:
            Dict with tensors ``atoms`` (pad_len,), ``coordinate``
            (pad_len, conf*3), ``distance`` (pad_len, conf*pad_len), ``SPD``
            and ``edge`` (pad_len, pad_len), and a scalar float ``label``.
        """
        # Opened lazily so every DataLoader worker gets its own environment.
        if not hasattr(self, 'env'):
            self.connect_db(self.db_path, save_to_self=True)
        # NOTE: pickle.loads can execute arbitrary code; only open LMDB files
        # from a trusted source.
        with self.env.begin() as txn:
            data = pickle.loads(txn.get(self._keys[idx]))

        coordinates = torch.tensor(np.array(data['coordinates']), dtype=torch.float32)[:self.conf, :, :]
        emb_idx = torch.tensor([self.atom_dict[atom] for atom in data['atoms']], dtype=torch.long)
        cur_len = len(emb_idx)

        # Shortest-path distances and adjacency (self loops added).
        spd = torch.tensor(np.array(data["SPD"]), dtype=torch.float32)
        edge = torch.tensor(np.array(data["edge"]), dtype=torch.float32) + torch.eye(cur_len)

        # Augmentation for large molecules: drop the trailing 1-10 atoms.
        # (cur_len > 100, so at least 91 atoms always remain.)
        if np.random.rand() < 0.5 and self.mode == "train" and cur_len > 100:
            num = random.choice((1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
            emb_idx = emb_idx[:-num]
            coordinates = coordinates[:, :-num, :]
            spd = spd[:-num, :-num]
            edge = edge[:-num, :-num]
            cur_len = len(emb_idx)

        if cur_len < self.pad_len:
            # Pad atoms with the index pad_len - 1, everything else with 0.
            new_emb = torch.full(size=(self.pad_len,), fill_value=self.pad_len-1, dtype=torch.long)
            new_emb[:cur_len] = emb_idx

            new_cor = torch.full(size=(self.conf, self.pad_len, 3), fill_value=0, dtype=torch.float32)
            new_cor[:, :cur_len, :] = coordinates

            new_spd = torch.full(size=(self.pad_len, self.pad_len), fill_value=0, dtype=torch.float32)
            new_spd[:cur_len, :cur_len] = spd
            new_edge = torch.full(size=(self.pad_len, self.pad_len), fill_value=0, dtype=torch.float32)
            new_edge[:cur_len, :cur_len] = edge
        else:
            new_emb = emb_idx[:self.pad_len]
            new_cor = coordinates[:, :self.pad_len, :]
            new_spd = spd[:self.pad_len, :self.pad_len]
            new_edge = edge[:self.pad_len, :self.pad_len]

        if self.mode == "train":
            # Gaussian jitter with a random scale; half of the time it is
            # applied only to a random subset of the coordinates.
            scale = random.choice((0.1, 0.2, 0.3, 0.5, 0.7, 0.01, 0.001))
            noise = scale * torch.randn_like(new_cor)
            if np.random.rand() < 0.5:
                mask = torch.randint_like(noise, 0, 2, dtype=torch.float32)
                noise = noise * mask
            new_cor = new_cor + noise

        # Pairwise Euclidean distances per conformer, flattened to
        # (pad_len, conf * pad_len) and rescaled to [-1, 1].
        atom_expanded = new_cor.unsqueeze(2)   # (conf, pad_len, 1, 3)
        coor_expanded = new_cor.unsqueeze(1)   # (conf, 1, pad_len, 3)
        distance = torch.sqrt((atom_expanded - coor_expanded).pow(2).sum(dim=-1))
        distance = distance.permute(1, 0, 2).reshape(-1, self.conf*self.pad_len)
        distance = (self.min_max_norm(distance) - 0.5) / 0.5

        label = torch.tensor(self.value_list[idx], dtype=torch.float32)

        # Coordinates rescaled to [-1, 1] and flattened to (pad_len, conf * 3).
        new_cor = (self.min_max_norm(new_cor) - 0.5) / 0.5
        new_cor = new_cor.permute(1, 0, 2).reshape(-1, self.conf*3)

        return {"atoms": new_emb, "coordinate": new_cor, "distance": distance, "SPD": new_spd, "edge": new_edge, "label": label}

    def worker_init_fn(self, worker_id):
        """Give each DataLoader worker its own NumPy seed.

        The sum is wrapped into the 32-bit range accepted by
        ``np.random.seed``.
        """
        base_seed = int(np.random.get_state()[1][0])
        np.random.seed((base_seed + worker_id) % 2**32)


class HERG_LMDBDataset(Dataset):
    """hERG classification dataset backed by an LMDB file and a label CSV.

    The LMDB file and label CSV are chosen by ``mode``; each record's ``smi``
    field is matched against the CSV ``smiles`` column and the ``class``
    column of the first matching row becomes its label.

    Args:
        conf: Number of conformers used per molecule (the first ``conf``).
        pad_len: Number of atoms every sample is padded or truncated to.
        mode: One of ``"train"``, ``"val"``, ``"week1"`` ... ``"week4"``.
            ``"train"`` additionally enables the random augmentations.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """

    # mode -> (LMDB file, label CSV)
    _SPLIT_FILES = {
        "train": ("/home/jovyan/prompts_learning/results/herg_cls_train.lmdb",
                  "/home/jovyan/prompts_learning/herg_dataset/hERGDB_cls_train_data.csv"),
        "val": ("/home/jovyan/prompts_learning/results/herg_cls_val.lmdb",
                "/home/jovyan/prompts_learning/herg_dataset/hERGDB_cls_valid_data.csv"),
        "week1": ("/home/jovyan/prompts_learning/results/herg_cls_week1.lmdb",
                  "/home/jovyan/prompts_learning/herg_dataset/hERGDB_cls_week1_1201.csv"),
        "week2": ("/home/jovyan/prompts_learning/results/herg_cls_week2.lmdb",
                  "/home/jovyan/prompts_learning/herg_dataset/hERGDB_cls_week2_1201.csv"),
        "week3": ("/home/jovyan/prompts_learning/results/herg_cls_week3.lmdb",
                  "/home/jovyan/prompts_learning/herg_dataset/hERGDB_cls_week3_1201.csv"),
        "week4": ("/home/jovyan/prompts_learning/results/herg_cls_week4.lmdb",
                  "/home/jovyan/prompts_learning/herg_dataset/hERGDB_cls_week4_1201.csv"),
    }

    def __init__(self, conf, pad_len=200, mode="train"):
        if mode not in self._SPLIT_FILES:
            raise ValueError(
                f"unknown mode {mode!r}; expected one of {sorted(self._SPLIT_FILES)}"
            )
        db_path, csv_path = self._SPLIT_FILES[mode]
        pd_data = pd.read_csv(csv_path)

        self.db_path = db_path
        assert os.path.isfile(self.db_path), "{} not found".format(
            self.db_path
        )

        # Map every SMILES to its first row label once, instead of scanning
        # the whole column again for each LMDB record.
        first_row = {}
        for row_label, smi in zip(pd_data.index, pd_data['smiles']):
            first_row.setdefault(smi, row_label)
        classes = pd_data['class']

        self.value_list = []
        self.max_len = 0
        self.max_value = 0.
        self.min_value = 999.

        # This environment is only used for the initial scan; __getitem__
        # opens its own lazily (see connect_db(save_to_self=True)).
        env = self.connect_db(self.db_path)
        with env.begin() as txn:
            self._keys = list(txn.cursor().iternext(values=False))
            for key in self._keys:
                # NOTE: pickle.loads can execute arbitrary code; only open
                # LMDB files from a trusted source.
                record = pickle.loads(txn.get(key))
                current_len = len(record['atoms'])
                if self.max_len < current_len:
                    self.max_len = current_len

                smi = record['smi']
                if smi not in first_row:
                    raise IndexError(
                        f"SMILES {smi!r} from {self.db_path} has no row in {csv_path}"
                    )
                self.value_list.append(classes[first_row[smi]])
        env.close()

        # Class balance report (label 1 counts as positive, the rest as negative).
        num_p = sum(1 for lab in self.value_list if lab == 1)
        num_n = len(self.value_list) - num_p
        print(f"number of positive/negative samples {num_p}/{num_n}.")

        # Fixed atom-symbol -> embedding-index vocabulary (computed once
        # offline and frozen so the indices stay stable across runs).
        self.atom_dict = {'Br': 0, 'C': 1, 'H': 2, 'N': 3, 'O': 4, 'F': 5, 'Cl': 6, 'S': 7, 'P': 8, 'I': 9, 'B': 10, 'Se': 11, 'Ar': 12, 'Kr': 13, 'Li': 14, 'Ne': 15, 'Xe': 16, 'Si': 17, 'Na': 18}
        print(self.atom_dict)

        self.pad_len = pad_len
        self.conf = conf # conformation
        self.mode = mode
        print(f"{mode} set is initialized successfully. The max length of the atom is {self.max_len}. The number of dataset is {len(self._keys)}. Padding length is {self.pad_len}.")

    def connect_db(self, lmdb_path, save_to_self=False):
        """Open ``lmdb_path`` read-only.

        Returns the environment, or stores it on ``self.env`` (and returns
        ``None``) when ``save_to_self`` is true.
        """
        env = lmdb.open(
            lmdb_path,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        if not save_to_self:
            return env
        else:
            self.env = env

    def __len__(self):
        """Number of records in the LMDB file."""
        return len(self._keys)

    def min_max_norm(self, x):
        """Rescale ``x`` to [0, 1] using its own global minimum and maximum.

        A constant input yields all zeros instead of NaN from a 0/0 division.
        """
        _min = x.min()
        _max = x.max()
        span = _max - _min
        if span == 0:
            return x - _min
        return (x - _min) / span

    def __getitem__(self, idx):
        """Return the features and class label of record ``idx``.

        Returns:
            Dict with tensors ``atoms`` (pad_len,), ``coordinate``
            (pad_len, conf*3), ``distance`` (pad_len, conf*pad_len), ``SPD``
            and ``edge`` (pad_len, pad_len), and a scalar float ``label``.
        """
        # Opened lazily so every DataLoader worker gets its own environment.
        if not hasattr(self, 'env'):
            self.connect_db(self.db_path, save_to_self=True)
        # NOTE: pickle.loads can execute arbitrary code; only open LMDB files
        # from a trusted source.
        with self.env.begin() as txn:
            data = pickle.loads(txn.get(self._keys[idx]))

        coordinates = torch.tensor(np.array(data['coordinates']), dtype=torch.float32)[:self.conf, :, :]
        emb_idx = torch.tensor([self.atom_dict[atom] for atom in data['atoms']], dtype=torch.long)
        cur_len = len(emb_idx)

        # Shortest-path distances and adjacency (self loops added).
        spd = torch.tensor(np.array(data["SPD"]), dtype=torch.float32)
        edge = torch.tensor(np.array(data["edge"]), dtype=torch.float32) + torch.eye(cur_len)

        # Augmentation for large molecules: drop the trailing 1-10 atoms.
        # (cur_len > 100, so at least 91 atoms always remain.)
        if np.random.rand() < 0.5 and self.mode == "train" and cur_len > 100:
            num = random.choice((1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
            emb_idx = emb_idx[:-num]
            coordinates = coordinates[:, :-num, :]
            spd = spd[:-num, :-num]
            edge = edge[:-num, :-num]
            cur_len = len(emb_idx)

        if cur_len < self.pad_len:
            # Pad atoms with the index pad_len - 1, everything else with 0.
            new_emb = torch.full(size=(self.pad_len,), fill_value=self.pad_len-1, dtype=torch.long)
            new_emb[:cur_len] = emb_idx

            new_cor = torch.full(size=(self.conf, self.pad_len, 3), fill_value=0, dtype=torch.float32)
            new_cor[:, :cur_len, :] = coordinates

            new_spd = torch.full(size=(self.pad_len, self.pad_len), fill_value=0, dtype=torch.float32)
            new_spd[:cur_len, :cur_len] = spd
            new_edge = torch.full(size=(self.pad_len, self.pad_len), fill_value=0, dtype=torch.float32)
            new_edge[:cur_len, :cur_len] = edge
        else:
            new_emb = emb_idx[:self.pad_len]
            new_cor = coordinates[:, :self.pad_len, :]
            new_spd = spd[:self.pad_len, :self.pad_len]
            new_edge = edge[:self.pad_len, :self.pad_len]

        if self.mode == "train":
            # Gaussian jitter with a random scale; half of the time it is
            # applied only to a random subset of the coordinates.
            scale = random.choice((0.1, 0.2, 0.3, 0.5, 0.7, 0.01, 0.001))
            noise = scale * torch.randn_like(new_cor)
            if np.random.rand() < 0.5:
                mask = torch.randint_like(noise, 0, 2, dtype=torch.float32)
                noise = noise * mask
            new_cor = new_cor + noise

        # Pairwise Euclidean distances per conformer, flattened to
        # (pad_len, conf * pad_len) and rescaled to [-1, 1].
        atom_expanded = new_cor.unsqueeze(2)   # (conf, pad_len, 1, 3)
        coor_expanded = new_cor.unsqueeze(1)   # (conf, 1, pad_len, 3)
        distance = torch.sqrt((atom_expanded - coor_expanded).pow(2).sum(dim=-1))
        distance = distance.permute(1, 0, 2).reshape(-1, self.conf*self.pad_len)
        distance = (self.min_max_norm(distance) - 0.5) / 0.5

        label = torch.tensor(self.value_list[idx], dtype=torch.float32)

        # Coordinates rescaled to [-1, 1] and flattened to (pad_len, conf * 3).
        new_cor = (self.min_max_norm(new_cor) - 0.5) / 0.5
        new_cor = new_cor.permute(1, 0, 2).reshape(-1, self.conf*3)

        return {"atoms": new_emb, "coordinate": new_cor, "distance": distance, "SPD": new_spd, "edge": new_edge, "label": label}

    def worker_init_fn(self, worker_id):
        """Give each DataLoader worker its own NumPy seed.

        The sum is wrapped into the 32-bit range accepted by
        ``np.random.seed``.
        """
        base_seed = int(np.random.get_state()[1][0])
        np.random.seed((base_seed + worker_id) % 2**32)

        
if __name__ == "__main__":
    # Smoke test: construct the multi-class hERG training set and wrap it in a
    # DataLoader with the default collate function.
    # Single-label variant: HERG_LMDBDataset(conf=10, pad_len=150, mode="train")
    train_dataset = HERG_Multi_Class_LMDBDataset(
        conf=10,
        pad_len=150,
        mode="train",
    )
    train_set = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=False,
        pin_memory=True,
        num_workers=0,
        worker_init_fn=train_dataset.worker_init_fn,
    )
    # To inspect batches, iterate the loader and print e.g.
    # datas["atoms"].shape, datas["coordinate"].shape or datas["label"].

In [None]:
import lmdb
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import random
import pickle

import pandas as pd
# Load the multi-class BBB table and split it on its scaffold label column.
# Alternative source: "/home/jovyan/PharmaBench/data/final_datasets/bbb_cls_final_data.csv"
data = pd.read_csv("/home/jovyan/prompts_learning/bbb_cls_final_data_multi_class.csv")
# reset_index() keeps the original row labels in a new 'index' column.
scaffold_training = data[data['scaffold_train_test_label'] == 'train'].reset_index()
scaffold_test = data[data['scaffold_train_test_label'] == 'test'].reset_index()


def find_min_max(x_list):
    """Print the largest and then the smallest value of ``x_list``.

    Debug helper: nothing is returned, the two numbers go to stdout on one
    line separated by a space.
    """
    values = np.array(x_list)
    largest = values.max()
    smallest = values.min()
    print(largest, smallest)
    

def list_min_max_norm(x_list, max_v, min_v):
    """Rescale values with fixed bounds: ``(x - min_v) / (max_v - min_v)``.

    Args:
        x_list: Sequence of numbers to rescale.
        max_v: Value that maps to 1.
        min_v: Value that maps to 0.

    Returns:
        A plain Python list of the rescaled values. Inputs outside
        ``[min_v, max_v]`` are not clipped. An empty input returns ``[]``.

    Note:
        ``max_v == min_v`` is not guarded: NumPy then produces ``inf``/``nan``
        entries (with a RuntimeWarning) instead of raising.
    """
    x_np = (np.array(x_list) - min_v) / (max_v - min_v)
    # Debug output of the resulting range. Skipped for empty input, where
    # min()/max() would raise on a zero-size array.
    if x_np.size:
        print(x_np.min(), x_np.max())
    return x_np.tolist()


class LMDBDataset(Dataset):
    """LMDB-backed molecule dataset that also serves auxiliary property columns.

    Each LMDB value is a pickled dict; this class reads its ``'atoms'``,
    ``'coordinates'``, ``'SPD'``, ``'edge'`` and ``'smi'`` entries. The label
    (``'value'``) and the ``*_pred`` columns are looked up by SMILES in the
    module-level ``scaffold_training`` / ``scaffold_test`` frames.
    """

    def __init__(self, db_path, conf, pad_len=200, mode="train"):
        """Index the LMDB file and collect the per-sample CSV columns.

        Args:
            db_path: Path to a single-file LMDB database.
            conf: Number of conformers kept per molecule.
            pad_len: Atom count every sample is padded or truncated to.
            mode: ``"train"`` reads ``scaffold_training`` and enables the
                random augmentations; any other value reads ``scaffold_test``.

        Raises:
            AssertionError: If ``db_path`` is not an existing file.
            IndexError: If a record's SMILES has no row in the CSV frame.
        """
        self.db_path = db_path
        assert os.path.isfile(self.db_path), "{} not found".format(
            self.db_path
        )
        pd_data = scaffold_training if mode == "train" else scaffold_test

        # Per-sample columns copied from the CSV, stored in LMDB key order.
        self.value_list = []
        self.logd_list = []
        self.logp_list = []
        self.pka_list = []
        self.pkb_list = []
        self.logsol_list = []
        self.wlogsol_list = []
        self.max_len = 0
        # Placeholders for label normalisation; never updated in this class.
        self.max_value = 0.
        self.min_value = 999.

        # This handle is only used for indexing and is closed before returning;
        # __getitem__ opens its own handle lazily.
        env = self.connect_db(self.db_path)
        try:
            # One read transaction for the whole scan instead of one per key.
            with env.begin() as txn:
                self._keys = list(txn.cursor().iternext(values=False))
                for key in self._keys:
                    # NOTE: pickle.loads can execute arbitrary code; only open
                    # LMDB files from a trusted source.
                    data = pickle.loads(txn.get(key))
                    self.max_len = max(self.max_len, len(data['atoms']))
                    # First CSV row whose SMILES equals this record's.
                    idx = pd_data['Smiles_unify'][pd_data['Smiles_unify'] == data['smi']].index[0]
                    self.value_list.append(pd_data['value'][idx])
                    self.logd_list.append(pd_data['LogD_pred'][idx])
                    self.logp_list.append(pd_data['LogP_pred'][idx])
                    self.pka_list.append(pd_data['pKa_class_pred'][idx])
                    self.pkb_list.append(pd_data['pKb_class_pred'][idx])
                    self.logsol_list.append(pd_data['LogSol_pred'][idx])
                    self.wlogsol_list.append(pd_data['wLogSol_pred'][idx])
        finally:
            env.close()

        # Fixed (max, min) bounds recorded from an earlier find_min_max() run:
        #   LogD     8.7890625    -3.41796875
        #   LogP    12.796875     -5.53125
        #   LogSol   3.017578125  -0.51806640625
        #   wLogSol  3.92578125  -12.3984375
        self.logd_list = list_min_max_norm(self.logd_list, 8.7890625, -3.41796875)
        self.logp_list = list_min_max_norm(self.logp_list, 12.796875, -5.53125)
        self.logsol_list = list_min_max_norm(self.logsol_list, 3.017578125, -0.51806640625)
        self.wlogsol_list = list_min_max_norm(self.wlogsol_list, 3.92578125, -12.3984375)

        # Class balance report: label 1 counts as positive, everything else
        # as negative.
        num_p = sum(1 for lab in self.value_list if lab == 1)
        num_n = len(self.value_list) - num_p
        print(f"number of postive/negtive samples {num_p}/{num_n}.")

        # Fixed element -> embedding index table.
        self.atom_dict = {'Br': 0, 'C': 1, 'H': 2, 'N': 3, 'O': 4, 'F': 5, 'Cl': 6, 'S': 7, 'P': 8, 'I': 9, 'B': 10, 'Se': 11, 'Ar': 12, 'Kr': 13, 'Li': 14, 'Ne': 15, 'Xe': 16, 'Si': 17, 'Na': 18}
        print(self.atom_dict)

        self.pad_len = pad_len
        self.conf = conf  # number of conformers
        self.mode = mode
        print(f"{mode} set is initialized successfully. The max length of the atom is {self.max_len}. The number of dataset is {len(self._keys)}. Padding length is {self.pad_len}.")

    def connect_db(self, lmdb_path, save_to_self=False):
        """Open ``lmdb_path`` read-only.

        Returns the environment, or stores it on ``self.env`` and returns
        ``None`` when ``save_to_self`` is true.
        """
        env = lmdb.open(
            lmdb_path,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        if not save_to_self:
            return env
        else:
            self.env = env

    def __len__(self):
        """Return the number of LMDB records."""
        return len(self._keys)

    def min_max_norm(self, x):
        """Rescale tensor ``x`` to [0, 1] using its own global min and max.

        A constant tensor (max == min) maps to all zeros instead of the NaNs
        a 0/0 division would produce.
        """
        _min = x.min()
        span = x.max() - _min
        if span == 0:
            return torch.zeros_like(x)
        return (x - _min) / span

    def __getitem__(self, idx):
        """Load, augment (train mode only), pad and normalise one sample.

        Returns:
            dict with keys ``atoms`` (pad_len,), ``coordinate``
            (pad_len, conf*3), ``distance`` (pad_len, conf*pad_len), ``SPD``
            and ``edge`` (pad_len, pad_len), plus the scalars ``label``,
            ``logd``, ``logp``, ``pka``, ``pkb``, ``logsol`` and ``wlogsol``.
            The shapes assume the record holds at least ``conf`` conformers —
            TODO confirm against the data.
        """
        if not hasattr(self, 'env'):
            self.connect_db(self.db_path, save_to_self=True)
        with self.env.begin() as txn:
            data = pickle.loads(txn.get(self._keys[idx]))
        coordinates = torch.tensor(np.array(data['coordinates']), dtype=torch.float32)[:self.conf, :, :]
        emb_idx = torch.tensor([self.atom_dict[atom] for atom in data['atoms']], dtype=torch.long)
        cur_len = len(emb_idx)

        # Shortest-path distances, and adjacency with self-loops added.
        spd = torch.tensor(np.array(data["SPD"]), dtype=torch.float32)
        edge = torch.tensor(np.array(data["edge"]), dtype=torch.float32) + torch.eye(cur_len)

        # "Random dropout" for large molecules (train mode, 50% chance).
        # NOTE(review): this keeps only the trailing `num` (1-10) atoms rather
        # than removing `num` atoms; if dropping atoms was the intent the
        # slices should probably be `[:-num]`. Left unchanged pending
        # confirmation.
        if np.random.rand() < 0.5 and self.mode == "train" and cur_len > 100:
            num_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            num = random.choice(num_list)
            emb_idx = emb_idx[-num:]
            coordinates = coordinates[:, -num:, :]
            spd = spd[-num:, -num:]
            edge = edge[-num:, -num:]
            cur_len = len(emb_idx)

        # Pad to pad_len, or truncate when the molecule is longer.
        if cur_len < self.pad_len:
            # Padding id is pad_len - 1, which is not an atom_dict entry;
            # presumably the embedding table is sized for it — confirm.
            new_emb = torch.full(size=(self.pad_len,), fill_value=self.pad_len-1, dtype=torch.long)
            new_emb[:cur_len] = emb_idx

            new_cor = torch.full(size=(self.conf, self.pad_len, 3), fill_value=0, dtype=torch.float32)
            new_cor[:, :cur_len, :] = coordinates

            new_spd = torch.full(size=(self.pad_len, self.pad_len), fill_value=0, dtype=torch.float32)
            new_spd[:cur_len, :cur_len] = spd
            new_edge = torch.full(size=(self.pad_len, self.pad_len), fill_value=0, dtype=torch.float32)
            new_edge[:cur_len, :cur_len] = edge
        else:
            new_emb = emb_idx[:self.pad_len]
            new_cor = coordinates[:, :self.pad_len, :]
            new_spd = spd[:self.pad_len, :self.pad_len]
            new_edge = edge[:self.pad_len, :self.pad_len]

        # Train-time augmentation: Gaussian coordinate noise of a random scale,
        # applied to a random half of the entries 50% of the time.
        if self.mode == "train":
            weight_list = [0.1, 0.2, 0.3, 0.5, 0.7, 0.01, 0.001]
            scale = random.choice(weight_list)
            noise = scale * torch.randn_like(new_cor)
            if np.random.rand() < 0.5:
                mask = torch.randint_like(noise, 0, 2, dtype=torch.float32)
                noise = noise * mask
            new_cor = new_cor + noise

        # Pairwise Euclidean distances per conformer, rescaled to [-1, 1].
        atom_expanded = new_cor.unsqueeze(2)   # (conf, pad_len, 1, 3)
        coor_expanded = new_cor.unsqueeze(1)   # (conf, 1, pad_len, 3)
        distance = torch.sqrt((atom_expanded - coor_expanded).pow(2).sum(dim=-1))
        distance = distance.permute(1, 0, 2).reshape(-1, self.conf*self.pad_len)
        distance = (self.min_max_norm(distance) - 0.5) / 0.5

        label = torch.tensor(self.value_list[idx], dtype=torch.float32)
        logd = torch.tensor(self.logd_list[idx], dtype=torch.float32)
        logp = torch.tensor(self.logp_list[idx], dtype=torch.float32)
        pka = torch.tensor(self.pka_list[idx], dtype=torch.float32)
        pkb = torch.tensor(self.pkb_list[idx], dtype=torch.float32)
        logsol = torch.tensor(self.logsol_list[idx], dtype=torch.float32)
        wlogsol = torch.tensor(self.wlogsol_list[idx], dtype=torch.float32)

        # Coordinates rescaled to [-1, 1] and laid out one row per atom.
        new_cor = (self.min_max_norm(new_cor) - 0.5) / 0.5
        new_cor = new_cor.permute(1, 0, 2).reshape(-1, self.conf*3)

        return {
            "atoms": new_emb, "coordinate": new_cor, "distance": distance,
            "SPD": new_spd, "edge": new_edge, "label": label,
            "logd": logd, "logp": logp, "pka": pka, "pkb": pkb,
            # Previously this entry was filled with wlogsol by mistake.
            "logsol": logsol, "wlogsol": wlogsol,
        }

    def worker_init_fn(self, worker_id):
        """Re-seed NumPy's global RNG for one DataLoader worker.

        The seed is the first word of the current RNG state plus
        ``worker_id``, wrapped into the [0, 2**32 - 1] range that
        ``np.random.seed`` accepts.
        """
        base_seed = int(np.random.get_state()[1][0])
        np.random.seed((base_seed + worker_id) % (2 ** 32))

        
if __name__ == "__main__":
    # Smoke test: load the BBB training LMDB, print the atom-tensor shape of
    # the first batch, then stop.
    lmdb_file = './results/bbb_train.lmdb'
    train_dataset = LMDBDataset(db_path=lmdb_file, conf=10, pad_len=150, mode="train")
    train_set = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=2,
        shuffle=False,
        pin_memory=True,
        num_workers=0,
        # collate_fn=train_dataset.collate_fn,
        worker_init_fn=train_dataset.worker_init_fn,
    )
    for datas in train_set:
        print(datas["atoms"].shape)
        # print(datas["coordinate"].shape)
        # print(datas["label"])
        exit(0)

In [None]:
import lmdb
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import random
import pickle

import pandas as pd
# Load the multi-class BBB table and split it on its scaffold label column.
# Alternative source: "/home/jovyan/PharmaBench/data/final_datasets/bbb_cls_final_data.csv"
data = pd.read_csv("/home/jovyan/prompts_learning/bbb_cls_final_data_multi_class.csv")
# reset_index() keeps the original row labels in a new 'index' column.
scaffold_training = data[data['scaffold_train_test_label'] == 'train'].reset_index()
scaffold_test = data[data['scaffold_train_test_label'] == 'test'].reset_index()



class LMDBDataset(Dataset):
    """LMDB-backed molecule dataset with two pre-training targets.

    Each LMDB value is a pickled dict; this class reads its ``'atoms'``,
    ``'coordinates'``, ``'SPD'``, ``'edge'`` and ``'smi'`` entries. Labels are
    looked up by SMILES in the module-level ``scaffold_training`` /
    ``scaffold_test`` frames. Every sample additionally carries masked-atom
    labels and the noise that was added to its distance matrix.
    """

    def __init__(self, db_path, conf, pad_len=200, mode="train"):
        """Index the LMDB file and collect the label of every record.

        Args:
            db_path: Path to a single-file LMDB database.
            conf: Number of conformers kept per molecule.
            pad_len: Atom count every sample is padded or truncated to.
            mode: ``"train"`` reads ``scaffold_training`` and enables the
                coordinate augmentations; any other value reads
                ``scaffold_test``.

        Raises:
            AssertionError: If ``db_path`` is not an existing file.
            IndexError: If a record's SMILES has no row in the CSV frame.
        """
        self.db_path = db_path
        assert os.path.isfile(self.db_path), "{} not found".format(
            self.db_path
        )
        pd_data = scaffold_training if mode == "train" else scaffold_test

        self.value_list = []
        self.max_len = 0
        # Placeholders for label normalisation; never updated in this class.
        self.max_value = 0.
        self.min_value = 999.

        # This handle is only used for indexing and is closed before returning;
        # __getitem__ opens its own handle lazily.
        env = self.connect_db(self.db_path)
        try:
            # One read transaction for the whole scan instead of one per key.
            with env.begin() as txn:
                self._keys = list(txn.cursor().iternext(values=False))
                for key in self._keys:
                    # NOTE: pickle.loads can execute arbitrary code; only open
                    # LMDB files from a trusted source.
                    data = pickle.loads(txn.get(key))
                    self.max_len = max(self.max_len, len(data['atoms']))
                    # First CSV row whose SMILES equals this record's.
                    idx = pd_data['Smiles_unify'][pd_data['Smiles_unify'] == data['smi']].index[0]
                    self.value_list.append(pd_data['value'][idx])
        finally:
            env.close()

        # Fixed element -> embedding index table, plus the mask token.
        self.atom_dict = {'Br': 0, 'C': 1, 'H': 2, 'N': 3, 'O': 4, 'F': 5, 'Cl': 6, 'S': 7, 'P': 8, 'I': 9, 'B': 10, 'Se': 11, 'Ar': 12, 'Kr': 13, 'Li': 14, 'Ne': 15, 'Xe': 16, 'Si': 17, 'Na': 18, 'Mask': 19}
        print(self.atom_dict)
        self.mask_token_id = self.atom_dict['Mask']

        self.pad_len = pad_len
        self.conf = conf  # number of conformers
        self.mode = mode
        print(f"{mode} set is initialized successfully. The max length of the atom is {self.max_len}. The number of dataset is {len(self._keys)}. Padding length is {self.pad_len}.")

    def connect_db(self, lmdb_path, save_to_self=False):
        """Open ``lmdb_path`` read-only.

        Returns the environment, or stores it on ``self.env`` and returns
        ``None`` when ``save_to_self`` is true.
        """
        env = lmdb.open(
            lmdb_path,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        if not save_to_self:
            return env
        else:
            self.env = env

    def __len__(self):
        """Return the number of LMDB records."""
        return len(self._keys)

    def min_max_norm(self, x):
        """Rescale tensor ``x`` to [0, 1] using its own global min and max.

        A constant tensor (max == min) maps to all zeros instead of the NaNs
        a 0/0 division would produce.
        """
        _min = x.min()
        span = x.max() - _min
        if span == 0:
            return torch.zeros_like(x)
        return (x - _min) / span

    def random_mask(self, atoms, mask_ratio=0.2):
        """Replace a random subset of ``atoms`` with the mask token, in place.

        ``int(len(atoms) * mask_ratio)`` positions are chosen uniformly at
        random over the whole sequence (padding positions included).

        Args:
            atoms: 1-D tensor of atom indices; modified in place.
            mask_ratio: Fraction of positions to mask.

        Returns:
            ``(atoms, output_label)``: the same (now masked) tensor and a copy
            of the original values at the masked positions.
        """
        len_masked = int(len(atoms) * mask_ratio)
        noise = np.random.rand(len(atoms))  # one uniform draw per position

        # The positions holding the smallest draws are the ones masked.
        ids_masked = np.argsort(noise)[:len_masked]

        # Indexing with an index array copies, so the labels keep the
        # original ids even after the in-place overwrite below.
        output_label = atoms[ids_masked]
        atoms[ids_masked] = self.mask_token_id
        return atoms, output_label

    def __getitem__(self, idx):
        """Load, augment, pad and normalise one sample.

        Returns:
            dict with keys ``atoms`` (pad_len,, partly masked), ``coordinate``
            (pad_len, conf*3), ``distance`` (pad_len, conf*pad_len, with
            ``noise_gt`` added), ``SPD`` and ``edge`` (pad_len, pad_len),
            ``label``, ``masked_label`` and ``noise_gt``. The shapes assume the
            record holds at least ``conf`` conformers — TODO confirm.
        """
        if not hasattr(self, 'env'):
            self.connect_db(self.db_path, save_to_self=True)
        with self.env.begin() as txn:
            data = pickle.loads(txn.get(self._keys[idx]))
        coordinates = torch.tensor(np.array(data['coordinates']), dtype=torch.float32)[:self.conf, :, :]
        emb_idx = torch.tensor([self.atom_dict[atom] for atom in data['atoms']], dtype=torch.long)
        cur_len = len(emb_idx)

        # Shortest-path distances, and adjacency with self-loops added.
        spd = torch.tensor(np.array(data["SPD"]), dtype=torch.float32)
        edge = torch.tensor(np.array(data["edge"]), dtype=torch.float32) + torch.eye(cur_len)

        # "Random dropout" for large molecules (train mode, 50% chance).
        # NOTE(review): this keeps only the trailing `num` (1-10) atoms rather
        # than removing `num` atoms; if dropping atoms was the intent the
        # slices should probably be `[:-num]`. Left unchanged pending
        # confirmation.
        if np.random.rand() < 0.5 and self.mode == "train" and cur_len > 100:
            num_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            num = random.choice(num_list)
            emb_idx = emb_idx[-num:]
            coordinates = coordinates[:, -num:, :]
            spd = spd[-num:, -num:]
            edge = edge[-num:, -num:]
            cur_len = len(emb_idx)

        # Pad to pad_len, or truncate when the molecule is longer.
        if cur_len < self.pad_len:
            # Padding id is pad_len - 1, which is not an atom_dict entry;
            # presumably the embedding table is sized for it — confirm.
            new_emb = torch.full(size=(self.pad_len,), fill_value=self.pad_len-1, dtype=torch.long)
            new_emb[:cur_len] = emb_idx

            new_cor = torch.full(size=(self.conf, self.pad_len, 3), fill_value=0, dtype=torch.float32)
            new_cor[:, :cur_len, :] = coordinates

            new_spd = torch.full(size=(self.pad_len, self.pad_len), fill_value=0, dtype=torch.float32)
            new_spd[:cur_len, :cur_len] = spd
            new_edge = torch.full(size=(self.pad_len, self.pad_len), fill_value=0, dtype=torch.float32)
            new_edge[:cur_len, :cur_len] = edge
        else:
            new_emb = emb_idx[:self.pad_len]
            new_cor = coordinates[:, :self.pad_len, :]
            new_spd = spd[:self.pad_len, :self.pad_len]
            new_edge = edge[:self.pad_len, :self.pad_len]

        # Train-time augmentation: Gaussian coordinate noise of a random scale,
        # applied to a random half of the entries 50% of the time.
        if self.mode == "train":
            weight_list = [0.1, 0.2, 0.3, 0.5, 0.7, 0.01, 0.001]
            scale = random.choice(weight_list)
            noise = scale * torch.randn_like(new_cor)
            if np.random.rand() < 0.5:
                mask = torch.randint_like(noise, 0, 2, dtype=torch.float32)
                noise = noise * mask
            new_cor = new_cor + noise

        # Pairwise Euclidean distances per conformer, rescaled to [-1, 1].
        atom_expanded = new_cor.unsqueeze(2)   # (conf, pad_len, 1, 3)
        coor_expanded = new_cor.unsqueeze(1)   # (conf, 1, pad_len, 3)
        distance = torch.sqrt((atom_expanded - coor_expanded).pow(2).sum(dim=-1))
        distance = distance.permute(1, 0, 2).reshape(-1, self.conf*self.pad_len)
        distance = (self.min_max_norm(distance) - 0.5) / 0.5

        label = torch.tensor(self.value_list[idx], dtype=torch.float32)

        # Coordinates rescaled to [-1, 1] and laid out one row per atom.
        new_cor = (self.min_max_norm(new_cor) - 0.5) / 0.5
        new_cor = new_cor.permute(1, 0, 2).reshape(-1, self.conf*3)

        # Pre-training task 1: masked-atom prediction.
        new_emb, masked_label = self.random_mask(new_emb)

        # Pre-training task 2: predict the noise added to the distances.
        weight_list = [0.1, 0.2, 0.3, 0.5, 0.7, 0.9, 0.01, 0.001]
        scale = random.choice(weight_list)
        noise_gt = scale * torch.randn_like(distance)
        if np.random.rand() < 0.5:  # sparsify the noise half of the time
            # The mask must match noise_gt. The coordinate noise used here
            # before has another shape and does not exist outside train mode.
            mask = torch.randint_like(noise_gt, 0, 2, dtype=torch.float32)
            noise_gt = noise_gt * mask
        distance = distance + noise_gt

        return {
            "atoms": new_emb, "coordinate": new_cor, "distance": distance,
            "SPD": new_spd, "edge": new_edge, "label": label,
            "masked_label": masked_label, "noise_gt": noise_gt,
        }

    def worker_init_fn(self, worker_id):
        """Re-seed NumPy's global RNG for one DataLoader worker.

        The seed is the first word of the current RNG state plus
        ``worker_id``, wrapped into the [0, 2**32 - 1] range that
        ``np.random.seed`` accepts.
        """
        base_seed = int(np.random.get_state()[1][0])
        np.random.seed((base_seed + worker_id) % (2 ** 32))

        
if __name__ == "__main__":
    # Smoke test: load the BBB training LMDB and print the atom-tensor shape
    # of every batch.
    lmdb_file = './results/bbb_train.lmdb'
    train_dataset = LMDBDataset(db_path=lmdb_file, conf=10, pad_len=150, mode="train")
    train_set = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=2,
        shuffle=False,
        pin_memory=True,
        num_workers=0,
        # collate_fn=train_dataset.collate_fn,
        worker_init_fn=train_dataset.worker_init_fn,
    )
    for datas in train_set:
        print(datas["atoms"].shape)
        # print(datas["coordinate"].shape)
        # print(datas["label"])

In [1]:
import lmdb
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import random
import pickle

import pandas as pd
# Load the multi-class BBB table and split it on its scaffold label column.
# Alternative source: "/home/jovyan/PharmaBench/data/final_datasets/bbb_cls_final_data.csv"
data = pd.read_csv("/home/jovyan/prompts_learning/bbb_cls_final_data_multi_class.csv")
# reset_index() keeps the original row labels in a new 'index' column.
scaffold_training = data[data['scaffold_train_test_label'] == 'train'].reset_index()
scaffold_test = data[data['scaffold_train_test_label'] == 'test'].reset_index()



class Pretrain_LMDBDataset(Dataset):
    """LMDB-backed pre-training dataset with masked atoms and distances.

    The LMDB file and label CSV are selected by ``mode``. Each LMDB value is a
    pickled dict; this class reads its ``'atoms'``, ``'coordinates'``,
    ``'SPD'``, ``'edge'`` and ``'smi'`` entries.
    """

    # mode -> (LMDB path, CSV path, scaffold split to keep; None = whole CSV)
    _SOURCES = {
        "herg": (
            "/home/jovyan/prompts_learning/results/herg_cls_train.lmdb",
            "/home/jovyan/prompts_learning/herg_dataset/hERGDB_cls_train_data.csv",
            None,
        ),
        "bbb": (
            "/home/jovyan/prompts_learning/results/bbb_train.lmdb",
            "/home/jovyan/prompts_learning/bbb_cls_final_data_multi_class.csv",
            "train",
        ),
        "logd": (
            "/home/jovyan/prompts_learning/results/logd_train.lmdb",
            "/home/jovyan/PharmaBench/data/final_datasets/logd_reg_final_data.csv",
            "train",
        ),
        "bbb_test": (
            "/home/jovyan/prompts_learning/results/bbb_test.lmdb",
            "/home/jovyan/prompts_learning/bbb_cls_final_data_multi_class.csv",
            "test",
        ),
    }

    def __init__(self, conf, pad_len=150, mode="herg"):
        """Pick the data source for ``mode`` and collect every record's label.

        Args:
            conf: Number of conformers kept per molecule.
            pad_len: Atom count every sample is padded or truncated to.
            mode: One of ``"herg"``, ``"bbb"``, ``"logd"``, ``"bbb_test"``.

        Raises:
            ValueError: If ``mode`` is not one of the supported names.
            AssertionError: If the LMDB file for ``mode`` does not exist.
            IndexError: If a record's SMILES has no row in the CSV.
        """
        if mode not in self._SOURCES:
            # An unknown mode used to fall through and fail later with an
            # unbound local variable.
            raise ValueError(
                "unknown mode {!r}; expected one of {}".format(mode, sorted(self._SOURCES))
            )
        db_path, csv_path, split = self._SOURCES[mode]
        pd_data = pd.read_csv(csv_path)
        if split is not None:
            pd_data = pd_data[pd_data['scaffold_train_test_label'] == split].reset_index()

        self.db_path = db_path
        assert os.path.isfile(self.db_path), "{} not found".format(
            self.db_path
        )

        # The hERG CSV uses different column names from the other tables.
        if mode == "herg":
            smiles_col, label_col = 'smiles', 'class'
        else:
            smiles_col, label_col = 'Smiles_unify', 'value'

        self.value_list = []
        self.max_len = 0
        # Placeholders for label normalisation; never updated in this class.
        self.max_value = 0.
        self.min_value = 999.

        # This handle is only used for indexing and is closed before returning;
        # __getitem__ opens its own handle lazily.
        env = self.connect_db(self.db_path)
        try:
            # One read transaction for the whole scan instead of one per key.
            with env.begin() as txn:
                self._keys = list(txn.cursor().iternext(values=False))
                for key in self._keys:
                    # NOTE: pickle.loads can execute arbitrary code; only open
                    # LMDB files from a trusted source.
                    data = pickle.loads(txn.get(key))
                    self.max_len = max(self.max_len, len(data['atoms']))
                    # First CSV row whose SMILES equals this record's.
                    idx = pd_data[smiles_col][pd_data[smiles_col] == data['smi']].index[0]
                    self.value_list.append(pd_data[label_col][idx])
        finally:
            env.close()

        # Fixed element -> embedding index table, plus mask and pad tokens.
        self.atom_dict = {'Br': 0, 'C': 1, 'H': 2, 'N': 3, 'O': 4, 'F': 5, 'Cl': 6, 'S': 7, 'P': 8, 'I': 9, 'B': 10, 'Se': 11, 'Ar': 12, 'Kr': 13, 'Li': 14, 'Ne': 15, 'Xe': 16, 'Si': 17, 'Na': 18, 'Mask': 19, "Pad": 20}
        print(self.atom_dict)
        self.mask_token_id = self.atom_dict['Mask']
        self.pad_token_id = self.atom_dict["Pad"]

        self.pad_len = pad_len
        self.conf = conf  # number of conformers
        self.mode = mode
        print(f"{mode} set is initialized successfully. The max length of the atom is {self.max_len}. The number of dataset is {len(self._keys)}. Padding length is {self.pad_len}.")

    def connect_db(self, lmdb_path, save_to_self=False):
        """Open ``lmdb_path`` read-only.

        Returns the environment, or stores it on ``self.env`` and returns
        ``None`` when ``save_to_self`` is true.
        """
        env = lmdb.open(
            lmdb_path,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        if not save_to_self:
            return env
        else:
            self.env = env

    def __len__(self):
        """Return the number of LMDB records."""
        return len(self._keys)

    def min_max_norm(self, x):
        """Rescale tensor ``x`` to [0, 1] using its own global min and max.

        A constant tensor (max == min) maps to all zeros instead of the NaNs
        a 0/0 division would produce.
        """
        _min = x.min()
        span = x.max() - _min
        if span == 0:
            return torch.zeros_like(x)
        return (x - _min) / span

    def random_mask(self, atoms, dist, mask_ratio=0.75):
        """Mask a random subset of atoms and the matching distance rows.

        ``int(len(atoms) * mask_ratio)`` positions are chosen uniformly at
        random over the whole sequence (padding positions included).

        Args:
            atoms: 1-D tensor of atom indices; masked in place.
            dist: 2-D tensor with one row per atom; left untouched.
            mask_ratio: Fraction of positions to mask.

        Returns:
            ``(atoms, masked_dist, output_label, ids_masked)``: the masked
            atom tensor, a copy of ``dist`` with the masked rows zeroed, the
            original atom ids at the masked positions, and the masked
            positions as a NumPy index array.
        """
        len_masked = int(len(atoms) * mask_ratio)
        noise = np.random.rand(len(atoms))  # one uniform draw per position

        # The positions holding the smallest draws are the ones masked.
        ids_masked = np.argsort(noise)[:len_masked]

        # Indexing with an index array copies, so the labels keep the
        # original ids even after the in-place overwrite below.
        output_label = atoms[ids_masked]
        atoms[ids_masked] = self.mask_token_id

        # Zero the rows on a copy: zeroing in place made the caller's
        # unmasked "distance" the very same tensor as "masked_distance".
        masked_dist = dist.clone()
        masked_dist[ids_masked, :] = 0.

        return atoms, masked_dist, output_label, ids_masked

    def __getitem__(self, idx):
        """Load, augment, pad, normalise and mask one sample.

        Returns:
            dict with keys ``atoms`` (pad_len,, partly masked), ``coordinate``
            (pad_len, conf*3), ``masked_distance`` and ``distance``
            (pad_len, conf*pad_len), ``SPD`` and ``edge`` (pad_len, pad_len),
            ``label``, ``masked_label``, ``ids_masked`` and ``noise_gt``
            (currently all zeros). The shapes assume the record holds at
            least ``conf`` conformers — TODO confirm.
        """
        if not hasattr(self, 'env'):
            self.connect_db(self.db_path, save_to_self=True)
        with self.env.begin() as txn:
            data = pickle.loads(txn.get(self._keys[idx]))
        coordinates = torch.tensor(np.array(data['coordinates']), dtype=torch.float32)[:self.conf, :, :]
        emb_idx = torch.tensor([self.atom_dict[atom] for atom in data['atoms']], dtype=torch.long)
        cur_len = len(emb_idx)

        # Shortest-path distances, and adjacency with self-loops added.
        spd = torch.tensor(np.array(data["SPD"]), dtype=torch.float32)
        edge = torch.tensor(np.array(data["edge"]), dtype=torch.float32) + torch.eye(cur_len)

        # "Random dropout" for large molecules (applied in every mode).
        # NOTE(review): this keeps only the trailing `num` (1-10) atoms rather
        # than removing `num` atoms; if dropping atoms was the intent the
        # slices should probably be `[:-num]`. Left unchanged pending
        # confirmation.
        if cur_len > 100:
            num_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            num = random.choice(num_list)
            emb_idx = emb_idx[-num:]
            coordinates = coordinates[:, -num:, :]
            spd = spd[-num:, -num:]
            edge = edge[-num:, -num:]
            cur_len = len(emb_idx)

        # Pad to pad_len, or truncate when the molecule is longer.
        if cur_len < self.pad_len:
            new_emb = torch.full(size=(self.pad_len,), fill_value=self.pad_token_id, dtype=torch.long)
            new_emb[:cur_len] = emb_idx

            new_cor = torch.full(size=(self.conf, self.pad_len, 3), fill_value=0, dtype=torch.float32)
            new_cor[:, :cur_len, :] = coordinates

            new_spd = torch.full(size=(self.pad_len, self.pad_len), fill_value=0, dtype=torch.float32)
            new_spd[:cur_len, :cur_len] = spd
            new_edge = torch.full(size=(self.pad_len, self.pad_len), fill_value=0, dtype=torch.float32)
            new_edge[:cur_len, :cur_len] = edge
        else:
            new_emb = emb_idx[:self.pad_len]
            new_cor = coordinates[:, :self.pad_len, :]
            new_spd = spd[:self.pad_len, :self.pad_len]
            new_edge = edge[:self.pad_len, :self.pad_len]

        # Augmentation (every mode): Gaussian coordinate noise of a random
        # scale, applied to a random half of the entries 50% of the time.
        weight_list = [0.1, 0.2, 0.3, 0.5, 0.7, 0.01, 0.001]
        scale = random.choice(weight_list)
        noise = scale * torch.randn_like(new_cor)
        if np.random.rand() < 0.5:
            mask = torch.randint_like(noise, 0, 2, dtype=torch.float32)
            noise = noise * mask
        new_cor = new_cor + noise

        # Pairwise Euclidean distances per conformer, rescaled to [-1, 1].
        atom_expanded = new_cor.unsqueeze(2)   # (conf, pad_len, 1, 3)
        coor_expanded = new_cor.unsqueeze(1)   # (conf, 1, pad_len, 3)
        distance = torch.sqrt((atom_expanded - coor_expanded).pow(2).sum(dim=-1))
        distance = distance.permute(1, 0, 2).reshape(-1, self.conf*self.pad_len)
        distance = (self.min_max_norm(distance) - 0.5) / 0.5

        label = torch.tensor(self.value_list[idx], dtype=torch.float32)

        # Coordinates rescaled to [-1, 1] and laid out one row per atom.
        new_cor = (self.min_max_norm(new_cor) - 0.5) / 0.5
        new_cor = new_cor.permute(1, 0, 2).reshape(-1, self.conf*3)

        # Pre-training task 1: masked-atom prediction. `distance` stays
        # unmasked; `masked_distance` has the masked rows zeroed.
        new_emb, masked_distance, masked_label, ids_masked = self.random_mask(new_emb, distance)

        # Pre-training task 2 (noised distance prediction) is switched off:
        # the only scale is 0, so noise_gt is all zeros and nothing is added
        # to `distance`.
        # weight_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        weight_list = [0]
        scale = random.choice(weight_list)
        noise_gt = scale * torch.randn_like(distance)

        return {
            "atoms": new_emb, "coordinate": new_cor,
            "masked_distance": masked_distance, "distance": distance,
            "SPD": new_spd, "edge": new_edge, "label": label,
            "masked_label": masked_label, "ids_masked": ids_masked,
            "noise_gt": noise_gt,
        }

    def worker_init_fn(self, worker_id):
        """Re-seed NumPy's global RNG for one DataLoader worker.

        The seed is the first word of the current RNG state plus
        ``worker_id``, wrapped into the [0, 2**32 - 1] range that
        ``np.random.seed`` accepts.
        """
        base_seed = int(np.random.get_state()[1][0])
        np.random.seed((base_seed + worker_id) % (2 ** 32))

        
if __name__ == "__main__":
    # Smoke test: build the "bbb_test" pre-training set and print the
    # atom-tensor shape of every batch.
    train_dataset = Pretrain_LMDBDataset(mode="bbb_test", pad_len=150, conf=10)
    train_set = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=2,
        shuffle=False,
        pin_memory=True,
        num_workers=0,
        # collate_fn=train_dataset.collate_fn,
        worker_init_fn=train_dataset.worker_init_fn,
    )
    for datas in train_set:
        print(datas["atoms"].shape)
        # print(datas["coordinate"].shape)
        # print(datas["label"])

{'Br': 0, 'C': 1, 'H': 2, 'N': 3, 'O': 4, 'F': 5, 'Cl': 6, 'S': 7, 'P': 8, 'I': 9, 'B': 10, 'Se': 11, 'Ar': 12, 'Kr': 13, 'Li': 14, 'Ne': 15, 'Xe': 16, 'Si': 17, 'Na': 18, 'Mask': 19, 'Pad': 20}
bbb_test set is initialized successfully. The max length of the atom is 196. The number of dataset is 1660. Padding length is 150.
torch.Size([150, 1500])
torch.Size([150, 1500])
torch.Size([2, 150])
torch.Size([150, 1500])
torch.Size([150, 1500])
torch.Size([2, 150])
torch.Size([150, 1500])
torch.Size([150, 1500])
torch.Size([2, 150])
torch.Size([150, 1500])
torch.Size([150, 1500])
torch.Size([2, 150])
torch.Size([150, 1500])
torch.Size([150, 1500])
torch.Size([2, 150])
torch.Size([150, 1500])
torch.Size([150, 1500])
torch.Size([2, 150])
torch.Size([150, 1500])
torch.Size([150, 1500])
torch.Size([2, 150])
torch.Size([150, 1500])
torch.Size([150, 1500])
torch.Size([2, 150])
torch.Size([150, 1500])
torch.Size([150, 1500])
torch.Size([2, 150])
torch.Size([150, 1500])
torch.Size([150, 1500])
torch


KeyboardInterrupt

