In [None]:
import pandas as pd
from SmilesPE.pretokenizer import atomwise_tokenizer
from hdlib.space import Vector, Space
from hdlib.arithmetic import bundle, bind
import random

In [None]:
# Load Tox21 and keep only the rows that carry an NR-ER-LBD label.
df = (
    pd.read_csv("tox21.csv")
    .dropna(subset=["NR-ER-LBD"])
    .reset_index(drop=True)
)

In [None]:
# Containers for SMILES strings of label-0 and label-1 molecules.
zero_set, one_set = [], []

In [None]:
# Split the SMILES strings by their NR-ER-LBD label: exactly 1 -> one_set, anything else -> zero_set.
# Both lists are rebuilt from `df` here (rather than appended to), so re-running this cell
# does not duplicate entries.
is_label_one = df["NR-ER-LBD"] == 1
one_set = df.loc[is_label_one, "smiles"].tolist()
zero_set = df.loc[~is_label_one, "smiles"].tolist()

In [None]:
# Draw a fixed-size random subset from each class. A dedicated, seeded RNG makes the
# selected molecules (and every vector derived from them) reproducible across runs
# without touching the global `random` state.
# NOTE: random.sample raises ValueError if a class has fewer than SAMPLE_SIZE molecules.
SAMPLE_SEED = 42
SAMPLE_SIZE = 100
sample_rng = random.Random(SAMPLE_SEED)
zero_sample = sample_rng.sample(zero_set, SAMPLE_SIZE)
one_sample = sample_rng.sample(one_set, SAMPLE_SIZE)

In [None]:
# Both classes must take their token hypervectors from the SAME space. With two separate
# Space() objects, the same token would presumably get an independently generated vector in
# each, so the two class vectors (and any query encoded later) could not be compared.
# The original names are kept as aliases of the single space for the cells below.
# NOTE(review): this relies on Space.bulk_insert tolerating names that are already present
# (the original cells already pass repeated tokens to it) — confirm for the hdlib version used.
shared_space = Space()
zero_shared_space = shared_space
one_shared_space = shared_space

In [None]:
def encode_sample_set(sample, shared_space):
    """Encode a collection of SMILES strings into one class hypervector.

    Every SMILES string is split into atom-level tokens with
    ``atomwise_tokenizer`` and each distinct token is registered in
    ``shared_space``. A molecule is encoded by binding its token
    hypervectors, each cyclically permuted by the token's position so that
    token order affects the result. All molecule hypervectors are then
    bundled into a single class hypervector.

    The token hypervectors stored in ``shared_space`` are never modified:
    permutations are applied to copies, so the encoding of one molecule does
    not depend on which molecules were encoded before it.

    Parameters
    ----------
    sample : iterable of str
        SMILES strings to encode. Duplicate strings each contribute to the bundle.
    shared_space : hdlib.space.Space
        Space that receives one hypervector per distinct token.

    Returns
    -------
    hdlib.space.Vector
        Bundle of all molecule hypervectors, or the single molecule
        hypervector when ``sample`` holds exactly one string.

    Raises
    ------
    ValueError
        If ``sample`` is empty or a SMILES string yields no tokens.
    """
    import copy

    def _permuted(vec, shift):
        # Rotate a *copy* so the vector held by ``shared_space`` stays intact.
        # NOTE(review): depending on the hdlib version, ``Vector.permute`` either
        # rotates in place (returning None) or returns a new Vector; accept both.
        rotated = copy.deepcopy(vec)
        result = rotated.permute(rotate_by=shift)
        return rotated if result is None else result

    # Tokenize each molecule once; the tokens are reused for insertion and encoding.
    tokenized = [(smiles, atomwise_tokenizer(smiles)) for smiles in sample]
    if not tokenized:
        raise ValueError("sample must contain at least one SMILES string")

    # Insert every distinct token exactly once, in first-seen order.
    unique_tokens = list(dict.fromkeys(tok for _, tokens in tokenized for tok in tokens))
    shared_space.bulk_insert(unique_tokens)

    mol_vecs = []
    for smiles, tokens in tokenized:
        if not tokens:
            raise ValueError(f"SMILES string {smiles!r} produced no tokens")
        # Bind position-permuted token vectors; a single-token molecule is just
        # its (copied) token vector and still goes into the bundle below.
        mol_vec = None
        for position, token in enumerate(tokens):
            token_vec = _permuted(shared_space.get(names=[token])[0], position)
            mol_vec = token_vec if mol_vec is None else bind(mol_vec, token_vec)
        mol_vecs.append(mol_vec)

    # Left fold with bundle: bundle(bundle(m0, m1), m2) ...
    class_vec = mol_vecs[0]
    for mol_vec in mol_vecs[1:]:
        class_vec = bundle(class_vec, mol_vec)
    return class_vec

In [None]:
# Tox21 assay labels (MoleculeNet convention) use 1 = active/toxic and 0 = inactive,
# so the label-1 sample defines the toxic class and the label-0 sample the non-toxic one.
toxic_class_vec = encode_sample_set(one_sample, one_shared_space)
nontoxic_class_vec = encode_sample_set(zero_sample, zero_shared_space)

In [None]:
# Show both class hypervectors, one per line.
print(toxic_class_vec, nontoxic_class_vec, sep="\n")