# 对数据处理进行初步尝试

In [None]:
import pandas as pd
from torch.utils.data import Dataset, DataLoader, Subset
from typing import List, Tuple
import re


In [76]:
# tokenizer，鉴于SMILES的特性，这里需要自己定义tokenizer和vocab
# 这里直接将smiles str按字符拆分，并替换为词汇表中的序号
class Smiles_tokenizer():
    def __init__(self, pad_token, regex, vocab_file, max_length):
        self.pad_token = pad_token
        self.regex = regex
        self.vocab_file = vocab_file
        self.max_length = max_length

        with open(self.vocab_file, "r") as f:
            lines = f.readlines()
        lines = [line.strip("\n") for line in lines]
        vocab_dic = {}
        for index, token in enumerate(lines):
            vocab_dic[token] = index
        self.vocab_dic = vocab_dic

    def _regex_match(self, smiles):
        regex_string = r"(" + self.regex + r"|"
        regex_string += r".)"
        prog = re.compile(regex_string)

        tokenised = []
        for smi in smiles:
            tokens = prog.findall(smi)
            if len(tokens) > self.max_length:
                tokens = tokens[:self.max_length]
            tokenised.append(tokens) # 返回一个所有的字符串列表
        return tokenised
    
    def tokenize(self, smiles):
        tokens = self._regex_match(smiles)
        # 添加上表示开始和结束的token：<cls>, <end>
        tokens = [["<CLS>"] + token + ["<SEP>"] for token in tokens]
        tokens = self._pad_seqs(tokens, self.pad_token)
        token_idx = self._pad_token_to_idx(tokens)
        return tokens, token_idx

    def _pad_seqs(self, seqs, pad_token):
        pad_length = max([len(seq) for seq in seqs])
        padded = [seq + ([pad_token] * (pad_length - len(seq))) for seq in seqs]
        return padded

    def _pad_token_to_idx(self, tokens):
        idx_list = []
        for token in tokens:
            tokens_idx = []
            for i in token:
                if i in self.vocab_dic.keys():
                    tokens_idx.append(self.vocab_dic[i])
                else:
                    self.vocab_dic[i] = max(self.vocab_dic.values()) + 1
                    tokens_idx.append(self.vocab_dic[i])
            idx_list.append(tokens_idx)
        
        return idx_list

# REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
# tokenizer = Smiles_tokenizer("<PAD>", REGEX, "../vocab_full.txt", 10)

# # 注意，这里一定要输入tuple（smiles1, smiles2）
# res = tokenizer.tokenize(('[PH+][CCL2+]Clc1c(Cl)ccc(c1)[C@@H](C)[C@H](N=[N+]=[N-])C(=O)OC.C>C=O', 
#                         'Cc1cc(Br)c2c(c1)N(C(=O)O'))
# # # res2 = tokenizer.tokenize(('[PH+][CCL2+]Clc1c(Cl)ccc(c1)[C@@H](C)[C@H](N=[N+]=[N-])C(=O)OC.C>C=O', 
# # #                           'Cc1cc(Br)c2c(c1)N(C(=O)O'))
# print(res[0])
# print(res[1])
# # print(res2[0])
# print(res2[1])

In [77]:
# 处理数据
def read_data(file_path):
    df = pd.read_csv(file_path)
    reactant1 = df["Reactant1"].tolist()
    reactant2 = df["Reactant2"].tolist()
    product = df["Product"].tolist()
    additive = df["Additive"].tolist()
    solvent = df["Solvent"].tolist()
    react_yield = df["Yield"].tolist()
    
    # 将reactant\additive\solvent拼到一起，之间用.分开。product也拼到一起，用>>分开
    input_data_list = []
    for react1, react2, prod, addi, sol in zip(reactant1, reactant2, product, additive, solvent):
        input_info = ".".join([react1, react2, addi, sol])
        input_info = ">".join([input_info, prod])
        input_data_list.append(input_info)
    output = [(react, y) for react, y in zip(input_data_list, react_yield)]
    return output
# read_data("../dataset/train_data_demo.csv")

In [78]:
# 定义数据集
class ReactionDataset(Dataset):
    def __init__(self, data: List[Tuple[List[str], float]], SMILES_tokenizer):
        self.data = data
        self.smiles_tokenizer = SMILES_tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        input_info, react_yeild = self.data[idx]
        # input_info = self.smiles_tokenizer.tokenize(input_info) # 对SMILES进行tokenize

        return input_info, react_yeild
    
def collate_fn(batch):
    REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
    tokenizer = Smiles_tokenizer("<PAD>", REGEX, "../vocab_full.txt", 10)
    smi_list = []
    yield_list = []
    for i in batch:
        smi_list.append(i[0])
        yield_list.append(i[1])
    tokenizer_batch = tokenizer.tokenize(smi_list)
    return tokenizer_batch, yield_list

# res1, res2 = collate_fn(('Clc1c(Cl)ccc(c1)[C@@H](C)[C@H](N=[N+]=[N-])C(=O)OC.C>C=O', 'Cc1cc(Br)c2c(c1)N(C(=O)O'))
# print(res1)
# print(res2)


In [79]:
data = read_data("../dataset/train_data_demo.csv")
# REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
# tokenizer = Smiles_tokenizer("<PAD>", REGEX, "../vocab_full.txt") 
dataset = ReactionDataset(data, None)
# 选择数据集的前N个样本进行训练
N = 1  #int(len(dataset) * 1)  # 或者你可以设置为数据集大小的一定比例，如 int(len(dataset) * 0.1)
subset_indices = list(range(3))
subset_dataset = Subset(dataset, subset_indices)
train_loader = DataLoader(subset_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
for i, (src, trg) in enumerate(train_loader):
    print(src)
    print(trg)
    raise KeyError

([['<CLS>', 'c', '1', 'c', 'c', 'c', '2', 'c', '(', 'c', '1', '<SEP>'], ['<CLS>', 'c', '1', 'c', 'c', 'c', '2', 'c', '(', 'c', '1', '<SEP>']], [[1, 7, 8, 7, 7, 7, 11, 7, 9, 7, 8, 3], [1, 7, 8, 7, 7, 7, 11, 7, 9, 7, 8, 3]])
[0.85, 0.9]


KeyError: 